package com.simiacryptus.mindseye.lang.tensorflow;

import com.google.common.primitives.Floats;
import com.simiacryptus.mindseye.lang.TensorArray;
import com.simiacryptus.mindseye.lang.TensorList;
import com.simiacryptus.ref.lang.RecycleBin;
import com.simiacryptus.ref.wrappers.RefArrays;
import com.simiacryptus.ref.wrappers.RefDoubleStream;
import com.simiacryptus.ref.wrappers.RefIntStream;
import com.simiacryptus.ref.wrappers.RefLongStream;
import com.simiacryptus.ref.wrappers.RefSystem;
import com.simiacryptus.util.Util;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.jetbrains.annotations.NotNull;
import org.tensorflow.DataType;
import org.tensorflow.Tensor;

/* loaded from: input_file:com/simiacryptus/mindseye/lang/tensorflow/TFIO.class */
/**
 * Conversion utilities between TensorFlow {@link Tensor} objects and MindsEye
 * {@link com.simiacryptus.mindseye.lang.Tensor} / {@link TensorList} objects.
 *
 * <p>Most methods take a {@code boolean z} flag. When it is {@code true}, data is routed through
 * {@code invertDimensions()} and the dimension list is reversed with {@code Tensor.reverse(...)};
 * presumably this converts between the MindsEye and TensorFlow memory layouts (TODO confirm).
 * When it is {@code false}, the raw data array is copied as-is.
 *
 * <p>Reference counting: the {@code getFloatTensor}/{@code getDoubleTensor} overloads that take a
 * MindsEye tensor or tensor list call {@code freeRef()} on it, i.e. they consume the caller's
 * reference. TensorFlow tensors passed in are never closed here.
 */
public class TFIO {
    /**
     * Converts a TensorFlow tensor whose first axis indexes items into a {@link TensorArray},
     * using {@code z = true}.
     */
    @NotNull
    public static TensorArray getTensorList(Tensor<?> tensor) {
        return getTensorList(tensor, true);
    }

    /**
     * Converts a TensorFlow tensor whose first axis indexes items into a {@link TensorArray}.
     * The remaining axes give each item's dimensions.
     *
     * @param tensor a DOUBLE or FLOAT TensorFlow tensor (not closed by this method)
     * @param z      whether each item is built reversed and then {@code invertDimensions()}-ed
     * @throws IllegalArgumentException if the data type is neither DOUBLE nor FLOAT
     */
    @NotNull
    public static TensorArray getTensorList(Tensor<?> tensor, boolean z) {
        DataType dataType = tensor.dataType();
        if (dataType == DataType.DOUBLE) {
            return getTensorArray_Double(tensor.expect(Double.class), tensor.shape(), z);
        }
        if (dataType == DataType.FLOAT) {
            return getTensorArray_Float(tensor.expect(Float.class), tensor.shape(), z);
        }
        throw new IllegalArgumentException(dataType.toString());
    }

    /** Converts a TensorFlow tensor into a single MindsEye tensor, using {@code z = true}. */
    @NotNull
    public static com.simiacryptus.mindseye.lang.Tensor getTensor(Tensor<?> tensor) {
        return getTensor(tensor, true);
    }

    /**
     * Converts a TensorFlow tensor into a single MindsEye tensor with the same shape.
     *
     * @param tensor a DOUBLE or FLOAT TensorFlow tensor (not closed by this method)
     * @param z      whether the result is built reversed and then {@code invertDimensions()}-ed
     * @throws IllegalArgumentException if the data type is neither DOUBLE nor FLOAT
     */
    @NotNull
    public static com.simiacryptus.mindseye.lang.Tensor getTensor(Tensor<?> tensor, boolean z) {
        DataType dataType = tensor.dataType();
        if (dataType == DataType.DOUBLE) {
            return getTensor_Double(tensor.expect(Double.class), tensor.shape(), z);
        }
        if (dataType == DataType.FLOAT) {
            return getTensor_Float(tensor.expect(Float.class), tensor.shape(), z);
        }
        throw new IllegalArgumentException(dataType.toString());
    }

    /**
     * Converts a MindsEye tensor to a TensorFlow FLOAT tensor using {@code z = true}.
     * The argument must not actually be null, because the delegate dereferences it.
     * Its reference is consumed.
     */
    @NotNull
    public static Tensor<Float> getFloatTensor(@Nullable com.simiacryptus.mindseye.lang.Tensor tensor) {
        return getFloatTensor(tensor, true);
    }

    /**
     * Converts a MindsEye tensor to a TensorFlow FLOAT tensor whose shape is the MindsEye
     * tensor's dimensions. Values are narrowed from double to float.
     *
     * @param tensor source tensor; its reference is released ({@code freeRef}) even on failure
     * @param z      whether the data is taken from {@code tensor.invertDimensions()}
     */
    @NotNull
    public static Tensor<Float> getFloatTensor(@NotNull com.simiacryptus.mindseye.lang.Tensor tensor, boolean z) {
        com.simiacryptus.mindseye.lang.Tensor inverted = null;
        try {
            double[] data;
            if (z) {
                inverted = tensor.invertDimensions();
                data = inverted.getData();
            } else {
                data = tensor.getData();
            }
            return Tensor.create(Util.toLong(tensor.getDimensions()), FloatBuffer.wrap(Util.getFloats(data)));
        } finally {
            if (null != inverted) {
                inverted.freeRef();
            }
            tensor.freeRef();
        }
    }

    /**
     * Converts a tensor list to a TensorFlow FLOAT tensor using {@code z = true}.
     * The list's reference is consumed.
     */
    @NotNull
    public static Tensor<Float> getFloatTensor(@Nullable TensorList tensorList) {
        return getFloatTensor(tensorList, true);
    }

    /**
     * Converts a tensor list to a TensorFlow FLOAT tensor of shape {@code [length, dims...]}.
     * The list must not be null.
     *
     * @param tensorList source list; its reference is released by this call
     * @param z          whether each item is {@code invertDimensions()}-ed before copying
     */
    @NotNull
    public static Tensor<Float> getFloatTensor(@Nullable TensorList tensorList, boolean z) {
        long[] shape = batchShape(tensorList);
        double[] doubles = getDoubles(tensorList, z);
        try {
            return Tensor.create(shape, FloatBuffer.wrap(Util.getFloats(doubles)));
        } finally {
            RecycleBin.DOUBLES.recycle(doubles, doubles.length);
        }
    }

    /**
     * Converts a MindsEye tensor to a TensorFlow DOUBLE tensor using {@code z = true}.
     * The caller's reference is consumed.
     */
    @NotNull
    public static Tensor<Double> getDoubleTensor(@Nullable com.simiacryptus.mindseye.lang.Tensor tensor) {
        Tensor<Double> doubleTensor = getDoubleTensor(tensor == null ? null : tensor.addRef(), true);
        if (null != tensor) {
            tensor.freeRef();
        }
        return doubleTensor;
    }

    /**
     * Converts a MindsEye tensor to a TensorFlow DOUBLE tensor whose shape is the MindsEye
     * tensor's dimensions.
     *
     * @param tensor source tensor; its reference is released ({@code freeRef}) even on failure
     * @param z      whether the data is taken from {@code tensor.invertDimensions()}
     */
    @NotNull
    public static Tensor<Double> getDoubleTensor(@NotNull com.simiacryptus.mindseye.lang.Tensor tensor, boolean z) {
        com.simiacryptus.mindseye.lang.Tensor inverted = null;
        try {
            double[] data;
            if (z) {
                inverted = tensor.invertDimensions();
                data = inverted.getData();
            } else {
                data = tensor.getData();
            }
            return Tensor.create(Util.toLong(tensor.getDimensions()), DoubleBuffer.wrap(data));
        } finally {
            tensor.freeRef();
            if (null != inverted) {
                inverted.freeRef();
            }
        }
    }

    /**
     * Converts a tensor list to a TensorFlow DOUBLE tensor using {@code z = true}.
     * The caller's reference is consumed.
     */
    @NotNull
    public static Tensor<Double> getDoubleTensor(@Nullable TensorList tensorList) {
        Tensor<Double> doubleTensor = getDoubleTensor(tensorList == null ? null : tensorList.addRef(), true);
        if (null != tensorList) {
            tensorList.freeRef();
        }
        return doubleTensor;
    }

    /**
     * Converts a tensor list to a TensorFlow DOUBLE tensor of shape {@code [length, dims...]}.
     * The list must not be null.
     *
     * @param tensorList source list; its reference is released by this call
     * @param z          whether each item is {@code invertDimensions()}-ed before copying
     */
    @NotNull
    public static Tensor<Double> getDoubleTensor(@Nullable TensorList tensorList, boolean z) {
        long[] shape = batchShape(tensorList);
        double[] doubles = getDoubles(tensorList, z);
        try {
            return Tensor.create(shape, DoubleBuffer.wrap(doubles));
        } finally {
            RecycleBin.DOUBLES.recycle(doubles, doubles.length);
        }
    }

    /**
     * Returns the leaf {@code double[]} / {@code float[]} arrays of a (possibly nested) primitive
     * array to the matching {@link RecycleBin}, recursing through {@code Object[]} levels.
     */
    /* JADX INFO: Access modifiers changed from: private */
    public static void free(Object obj) {
        if (obj instanceof double[]) {
            double[] doubles = (double[]) obj;
            RecycleBin.DOUBLES.recycle(doubles, doubles.length);
        } else if (obj instanceof float[]) {
            float[] floats = (float[]) obj;
            RecycleBin.FLOATS.recycle(floats, floats.length);
        } else {
            for (Object child : (Object[]) obj) {
                free(child);
            }
        }
    }

    /**
     * Allocates a float array with the given shape, suitable as a {@code Tensor.copyTo} target.
     * Rank 1 is taken from {@link RecycleBin}. Higher ranks are allocated as fully populated
     * nested arrays ({@code float[d0][d1]...}).
     *
     * @throws RuntimeException      for rank 0
     * @throws ArithmeticException   if a dimension does not fit in an int
     */
    private static Object createFloatArray(@NotNull long[] jArr) {
        if (jArr.length == 1) {
            return RecycleBin.FLOATS.obtain(jArr[0]);
        }
        if (jArr.length == 0) {
            throw new RuntimeException("Rank " + jArr.length);
        }
        return java.lang.reflect.Array.newInstance(float.class, toIntDims(jArr, 0));
    }

    /**
     * Allocates a double array with the given shape, suitable as a {@code Tensor.copyTo} target.
     * Rank 1 is taken from {@link RecycleBin}. Higher ranks are allocated as fully populated
     * nested arrays ({@code double[d0][d1]...}).
     *
     * @throws RuntimeException      for rank 0
     * @throws ArithmeticException   if a dimension does not fit in an int
     */
    private static Object createDoubleArray(@NotNull long[] jArr) {
        if (jArr.length == 1) {
            return RecycleBin.DOUBLES.obtain(jArr[0]);
        }
        if (jArr.length == 0) {
            throw new RuntimeException("Rank " + jArr.length);
        }
        return java.lang.reflect.Array.newInstance(double.class, toIntDims(jArr, 0));
    }

    /**
     * Flattens a nested double array (or a boxed {@link Double}) into a stream, depth-first in
     * index order.
     */
    /* JADX INFO: Access modifiers changed from: private */
    @NotNull
    public static RefDoubleStream flattenDoubles(Object obj) {
        return obj instanceof double[] ? RefArrays.stream((double[]) obj) : obj instanceof Double ? RefDoubleStream.of(((Double) obj).doubleValue()) : RefArrays.stream((Object[]) obj).flatMapToDouble(obj2 -> {
            return flattenDoubles(obj2);
        });
    }

    /** Flattens a nested float array into a boxed stream, depth-first in index order. */
    /* JADX INFO: Access modifiers changed from: private */
    public static Stream<Float> flattenFloats(Object obj) {
        return obj instanceof float[] ? Floats.asList((float[]) obj).stream() : RefArrays.stream((Object[]) obj).flatMap(obj2 -> {
            return flattenFloats(obj2);
        });
    }

    /**
     * Concatenates the data of every item in {@code tensorList} into one array obtained from
     * {@link RecycleBin#DOUBLES}. The caller is responsible for recycling the result. The list's
     * reference and each streamed item's reference are released.
     *
     * @param z whether each item is {@code invertDimensions()}-ed before its data is copied
     */
    private static double[] getDoubles(@NotNull TensorList tensorList, boolean z) {
        try {
            // Checked multiply: fail loudly instead of requesting a wrapped-around size.
            long total = Math.multiplyExact(tensorList.length(),
                    com.simiacryptus.mindseye.lang.Tensor.length(tensorList.getDimensions()));
            double[] dArr = (double[]) RecycleBin.DOUBLES.obtain(total);
            DoubleBuffer out = DoubleBuffer.wrap(dArr);
            tensorList.stream().forEach(item -> {
                if (z) {
                    com.simiacryptus.mindseye.lang.Tensor inverted = item.invertDimensions();
                    item.freeRef();
                    out.put(inverted.getData());
                    inverted.freeRef();
                } else {
                    out.put(item.getData());
                    item.freeRef();
                }
            });
            return dArr;
        } finally {
            tensorList.freeRef();
        }
    }

    /**
     * Splits a FLOAT TensorFlow tensor of shape {@code [n, dims...]} into {@code n} MindsEye
     * tensors of shape {@code dims}.
     */
    @NotNull
    private static TensorArray getTensorArray_Float(Tensor<Float> tensor, @NotNull long[] jArr, boolean z) {
        float[] floats = getFloats(tensor);
        try {
            int[] itemDims = toIntDims(jArr, 1);
            // Loop-invariant: computed once rather than once per item.
            int[] reversedDims = z ? com.simiacryptus.mindseye.lang.Tensor.reverse(itemDims) : null;
            int stride = com.simiacryptus.mindseye.lang.Tensor.length(itemDims);
            return new TensorArray((com.simiacryptus.mindseye.lang.Tensor[]) RefIntStream.range(0, Math.toIntExact(jArr[0])).mapToObj(i -> {
                int offset = i * stride;
                if (!z) {
                    com.simiacryptus.mindseye.lang.Tensor item = new com.simiacryptus.mindseye.lang.Tensor(itemDims);
                    item.set(j -> floats[j + offset]);
                    return item;
                }
                com.simiacryptus.mindseye.lang.Tensor reversed = new com.simiacryptus.mindseye.lang.Tensor(reversedDims);
                reversed.set(j -> floats[j + offset]);
                com.simiacryptus.mindseye.lang.Tensor inverted = reversed.invertDimensions();
                reversed.freeRef();
                return inverted;
            }).toArray(com.simiacryptus.mindseye.lang.Tensor[]::new));
        } finally {
            RecycleBin.FLOATS.recycle(floats, floats.length);
        }
    }

    /** Converts a FLOAT TensorFlow tensor into one MindsEye tensor with the same shape. */
    @NotNull
    private static com.simiacryptus.mindseye.lang.Tensor getTensor_Float(Tensor<Float> tensor, @NotNull long[] jArr, boolean z) {
        int[] dims = toIntDims(jArr, 0);
        if (0 == tensor.numElements()) {
            return new com.simiacryptus.mindseye.lang.Tensor(dims);
        }
        float[] floats = getFloats(tensor);
        try {
            if (!z) {
                com.simiacryptus.mindseye.lang.Tensor result = new com.simiacryptus.mindseye.lang.Tensor(dims);
                result.set(i -> floats[i]);
                return result;
            }
            com.simiacryptus.mindseye.lang.Tensor reversed = new com.simiacryptus.mindseye.lang.Tensor(com.simiacryptus.mindseye.lang.Tensor.reverse(dims));
            reversed.set(i -> floats[i]);
            com.simiacryptus.mindseye.lang.Tensor inverted = reversed.invertDimensions();
            reversed.freeRef();
            return inverted;
        } finally {
            RecycleBin.FLOATS.recycle(floats, floats.length);
        }
    }

    /**
     * Splits a DOUBLE TensorFlow tensor of shape {@code [n, dims...]} into {@code n} MindsEye
     * tensors of shape {@code dims}.
     */
    @NotNull
    private static TensorArray getTensorArray_Double(Tensor<Double> tensor, @NotNull long[] jArr, boolean z) {
        double[] doubles = getDoubles(tensor);
        try {
            int[] itemDims = toIntDims(jArr, 1);
            // Loop-invariant: computed once rather than once per item.
            int[] reversedDims = z ? com.simiacryptus.mindseye.lang.Tensor.reverse(itemDims) : null;
            return new TensorArray((com.simiacryptus.mindseye.lang.Tensor[]) RefIntStream.range(0, Math.toIntExact(jArr[0])).mapToObj(i -> {
                if (!z) {
                    com.simiacryptus.mindseye.lang.Tensor item = new com.simiacryptus.mindseye.lang.Tensor(itemDims);
                    RefSystem.arraycopy(doubles, i * item.length(), item.getData(), 0, item.length());
                    return item;
                }
                com.simiacryptus.mindseye.lang.Tensor reversed = new com.simiacryptus.mindseye.lang.Tensor(reversedDims);
                RefSystem.arraycopy(doubles, i * reversed.length(), reversed.getData(), 0, reversed.length());
                com.simiacryptus.mindseye.lang.Tensor inverted = reversed.invertDimensions();
                reversed.freeRef();
                return inverted;
            }).toArray(com.simiacryptus.mindseye.lang.Tensor[]::new));
        } finally {
            RecycleBin.DOUBLES.recycle(doubles, doubles.length);
        }
    }

    /** Converts a DOUBLE TensorFlow tensor into one MindsEye tensor with the same shape. */
    @NotNull
    private static com.simiacryptus.mindseye.lang.Tensor getTensor_Double(Tensor<Double> tensor, @NotNull long[] jArr, boolean z) {
        double[] doubles = getDoubles(tensor);
        try {
            int[] dims = toIntDims(jArr, 0);
            if (!z) {
                com.simiacryptus.mindseye.lang.Tensor result = new com.simiacryptus.mindseye.lang.Tensor(dims);
                RefSystem.arraycopy(doubles, 0, result.getData(), 0, result.length());
                return result;
            }
            com.simiacryptus.mindseye.lang.Tensor reversed = new com.simiacryptus.mindseye.lang.Tensor(com.simiacryptus.mindseye.lang.Tensor.reverse(dims));
            RefSystem.arraycopy(doubles, 0, reversed.getData(), 0, reversed.length());
            com.simiacryptus.mindseye.lang.Tensor inverted = reversed.invertDimensions();
            reversed.freeRef();
            return inverted;
        } finally {
            RecycleBin.DOUBLES.recycle(doubles, doubles.length);
        }
    }

    /**
     * Copies a DOUBLE TensorFlow tensor's contents into a new flat array, in the order produced
     * by {@code copyTo} (depth-first over the nested array). The temporary nested array is
     * recycled.
     */
    private static double[] getDoubles(Tensor<Double> tensor) {
        // Mirrors getFloats: skip copyTo entirely for empty tensors.
        if (0 == tensor.numElements()) {
            return new double[0];
        }
        Object copy = tensor.copyTo(createDoubleArray(tensor.shape()));
        try {
            double[] result = new double[Math.toIntExact(tensor.numElements())];
            copyLeaves(copy, DoubleBuffer.wrap(result));
            return result;
        } finally {
            free(copy);
        }
    }

    /**
     * Copies a FLOAT TensorFlow tensor's contents into a new flat array, in the order produced
     * by {@code copyTo} (depth-first over the nested array). The temporary nested array is
     * recycled.
     */
    @NotNull
    private static float[] getFloats(Tensor<Float> tensor) {
        if (0 == tensor.numElements()) {
            return new float[0];
        }
        Object copy = tensor.copyTo(createFloatArray(tensor.shape()));
        try {
            // Direct primitive copy; avoids boxing every element through Stream<Float>.
            float[] result = new float[Math.toIntExact(tensor.numElements())];
            copyLeaves(copy, FloatBuffer.wrap(result));
            return result;
        } finally {
            free(copy);
        }
    }

    /** Appends every leaf {@code float[]} of a nested array to {@code out}, depth-first. */
    private static void copyLeaves(Object node, FloatBuffer out) {
        if (node instanceof float[]) {
            out.put((float[]) node);
        } else {
            for (Object child : (Object[]) node) {
                copyLeaves(child, out);
            }
        }
    }

    /** Appends every leaf {@code double[]} of a nested array to {@code out}, depth-first. */
    private static void copyLeaves(Object node, DoubleBuffer out) {
        if (node instanceof double[]) {
            out.put((double[]) node);
        } else {
            for (Object child : (Object[]) node) {
                copyLeaves(child, out);
            }
        }
    }

    /**
     * Narrows {@code dims[from..]} to ints.
     *
     * @throws ArithmeticException if any dimension does not fit in an int
     */
    private static int[] toIntDims(@NotNull long[] dims, int from) {
        int[] result = new int[dims.length - from];
        for (int d = 0; d < result.length; d++) {
            result[d] = Math.toIntExact(dims[from + d]);
        }
        return result;
    }

    /** Builds the TensorFlow shape {@code [length, dims...]} for a tensor list. */
    private static long[] batchShape(@NotNull TensorList tensorList) {
        int[] dims = tensorList.getDimensions();
        long[] shape = new long[dims.length + 1];
        shape[0] = tensorList.length();
        for (int d = 0; d < dims.length; d++) {
            shape[d + 1] = dims[d];
        }
        return shape;
    }
}
