package com.simiacryptus.mindseye.util;

import com.google.common.collect.Streams;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.layers.cudnn.ActivationLayer;
import com.simiacryptus.mindseye.layers.cudnn.ImgBandBiasLayer;
import com.simiacryptus.mindseye.layers.cudnn.ImgConcatLayer;
import com.simiacryptus.mindseye.layers.cudnn.LRNLayer;
import com.simiacryptus.mindseye.layers.cudnn.PoolingLayer;
import com.simiacryptus.mindseye.layers.cudnn.conv.SimpleConvolutionLayer;
import com.simiacryptus.mindseye.layers.java.FullyConnectedLayer;
import com.simiacryptus.mindseye.layers.tensorflow.MatMulLayer;
import com.simiacryptus.mindseye.layers.tensorflow.TFLayer;
import com.simiacryptus.mindseye.layers.tensorflow.TFLayerBase;
import com.simiacryptus.mindseye.network.DAGNode;
import com.simiacryptus.mindseye.network.InnerNode;
import com.simiacryptus.mindseye.network.PipelineNetwork;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.RefArrays;
import com.simiacryptus.ref.wrappers.RefCollectors;
import com.simiacryptus.ref.wrappers.RefConcurrentHashMap;
import com.simiacryptus.ref.wrappers.RefHashMap;
import com.simiacryptus.ref.wrappers.RefIntStream;
import com.simiacryptus.ref.wrappers.RefList;
import com.simiacryptus.ref.wrappers.RefMap;
import com.simiacryptus.tensorflow.GraphModel;
import com.simiacryptus.tensorflow.ImageNetworkPipeline;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.stream.IntStream;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.jetbrains.annotations.NotNull;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;

/* loaded from: input_file:com/simiacryptus/mindseye/util/TFConverter.class */
public class TFConverter {
    /**
     * Assertion switch for this class, assigned once in the static initializer from
     * {@code TFConverter.class.desiredAssertionStatus()} (negated). The methods of this
     * class guard their {@link AssertionError} checks with {@code !$assertionsDisabled},
     * so those checks only run when assertions are enabled for the class.
     */
    static final /* synthetic */ boolean $assertionsDisabled;

    /**
     * Creates one {@link TFLayer} for every entry of {@code imageNetworkPipeline.graphDefs}.
     *
     * @param imageNetworkPipeline pipeline whose graph definitions are converted
     * @return the layers produced by {@link #getLayer(ImageNetworkPipeline, int)} for
     *         indices {@code 0 .. graphDefs.size() - 1}, in index order
     */
    public static RefList<TFLayer> getLayers(@Nonnull ImageNetworkPipeline imageNetworkPipeline) {
        final int layerCount = imageNetworkPipeline.graphDefs.size();
        return (RefList) RefIntStream.range(0, layerCount)
                .mapToObj(index -> getLayer(imageNetworkPipeline, index))
                .collect(RefCollectors.toList());
    }

    /**
     * Wraps the {@code i}-th graph definition of the pipeline in a {@link TFLayer}.
     *
     * <p>The layer is constructed from the serialized graph definition, a fresh empty
     * {@link RefHashMap}, the pipeline node id at index {@code i}, and either the literal
     * {@code "input"} (for {@code i == 0}) or the node id at index {@code i - 1}. The layer
     * is switched to float mode via {@code setFloat(true)} before being returned.
     *
     * @param imageNetworkPipeline pipeline providing {@code graphDefs} and {@code nodeIds()}
     * @param i                    index of the graph definition to wrap
     * @return the newly created layer; the caller owns its single reference
     */
    @Nonnull
    public static TFLayer getLayer(@Nonnull ImageNetworkPipeline imageNetworkPipeline, int i) {
        TFLayer tFLayer = new TFLayer(((GraphDef) imageNetworkPipeline.graphDefs.get(i)).toByteArray(), new RefHashMap(), (String) imageNetworkPipeline.nodeIds().get(i), i == 0 ? "input" : (String) imageNetworkPipeline.nodeIds().get(i - 1));
        tFLayer.setFloat(true);
        // The decompiled code took an extra reference through a decompiler-renamed method
        // ("mo2addRef") and immediately released the original one. That pair is net-neutral
        // on the reference count, so the instance is handed back directly.
        return tFLayer;
    }

    /**
     * Converts a {@link MatMulLayer} into a {@link FullyConnectedLayer} with equivalent weights.
     *
     * <p>The tensor stored under the key {@code "weights"} is reshaped to
     * {@code [outputDims..., inputDims reversed...]} and then permuted with the index list
     * {@code [(out + in - 1), (out + in - 2), ..., out, 0, 1, ..., out - 1]}, where
     * {@code in}/{@code out} are the lengths of the input/output dimension arrays. The
     * result is copied into the weights of a new {@code FullyConnectedLayer(inputDims, outputDims)}.
     *
     * @param matMulLayer source layer; its reference is released by this method
     * @return the populated fully connected layer
     */
    @Nonnull
    public FullyConnectedLayer getFCLayer(@Nonnull MatMulLayer matMulLayer) {
        RefMap<String, Tensor> weightMap = matMulLayer.getWeights();
        if (!$assertionsDisabled && weightMap == null) {
            throw new AssertionError();
        }
        Tensor sourceWeights = (Tensor) weightMap.get("weights");
        weightMap.freeRef();
        int[] inDims = matMulLayer.getIntputDims();
        int[] outDims = matMulLayer.getOutputDims();
        matMulLayer.freeRef();
        final int inRank = inDims.length;
        final int outRank = outDims.length;

        // Target shape: output sizes followed by the input sizes in reverse order.
        IntStream outputSizes = RefArrays.stream(outDims);
        IntStream reversedInputSizes = RefIntStream.range(0, inRank).map(k -> inDims[(inRank - 1) - k]);
        int[] reshapedDims = Streams.concat(new IntStream[]{outputSizes, reversedInputSizes}).toArray();

        // Permutation: trailing (input) axes counted down from the last axis, then the output axes in order.
        IntStream trailingAxesDescending = RefIntStream.range(0, inRank).map(k -> ((outRank + inRank) - 1) - k);
        IntStream outputAxes = RefIntStream.range(0, outRank);
        int[] permutation = Streams.concat(new IntStream[]{trailingAxesDescending, outputAxes}).toArray();

        if (!$assertionsDisabled && sourceWeights == null) {
            throw new AssertionError();
        }
        Tensor reshaped = sourceWeights.reshapeCast(reshapedDims);
        Tensor permuted = reshaped.permuteDimensions(permutation);
        reshaped.freeRef();
        sourceWeights.freeRef();

        FullyConnectedLayer fcLayer = new FullyConnectedLayer(inDims, outDims);
        Tensor targetWeights = fcLayer.getWeights();
        if (!$assertionsDisabled && targetWeights == null) {
            throw new AssertionError();
        }
        targetWeights.set(permuted.addRef());
        targetWeights.freeRef();
        permuted.freeRef();
        return fcLayer;
    }

    /**
     * Converts a TensorFlow layer into a single-input {@link PipelineNetwork}.
     *
     * <p>A {@link GraphModel} is built from the bytes of {@code tFLayerBase.constGraph()},
     * and the network is populated by resolving the layer's output node through
     * {@link #getNode}. The node returned by that call is released immediately; only the
     * network is returned.
     *
     * @param tFLayerBase layer to convert; its reference is released by this method
     * @return the network that was passed to {@code getNode} while resolving the output node
     */
    @Nonnull
    public PipelineNetwork convert(@Nonnull TFLayerBase tFLayerBase) {
        final PipelineNetwork network = new PipelineNetwork(1, new Layer[0]);
        final RefConcurrentHashMap nodeCache = new RefConcurrentHashMap();

        final String outputNodeId = tFLayerBase.getOutputNode();
        final PipelineNetwork networkRef = network.addRef();
        final GraphModel model = new GraphModel(tFLayerBase.constGraph().toByteArray());
        final RefMap cacheRef = (RefMap) RefUtil.addRef(nodeCache);

        RefUtil.freeRef(getNode(outputNodeId, networkRef, model, cacheRef));
        tFLayerBase.freeRef();
        nodeCache.freeRef();
        return network;
    }

    /**
     * Returns the network node registered for the given graph node id, creating and
     * registering it through {@code getDagNode} when {@code refMap} has no entry yet.
     *
     * <p>Ownership: this method consumes the passed references. {@code pipelineNetwork} is
     * released here when the id is already present and handed to {@code getDagNode}
     * otherwise. {@code refMap} is always released before the method exits, including when
     * an exception is thrown (the decompiled code released it on the success path only).
     *
     * @param str             id of the graph node to resolve
     * @param pipelineNetwork network the node is added to
     * @param graphModel      graph model the node definition is read from
     * @param refMap          registry of already converted nodes, keyed by node id
     * @return the value stored in {@code refMap} under {@code str} after resolution
     * @throws RuntimeException wrapping any {@link Throwable} raised during resolution, with
     *                          the message {@code "Error converting " + str}
     */
    @Nullable
    protected DAGNode getNode(@Nonnull String str, @Nonnull PipelineNetwork pipelineNetwork, @Nonnull GraphModel graphModel, @Nonnull RefMap<String, DAGNode> refMap) {
        try {
            try {
                if (refMap.containsKey(str)) {
                    pipelineNetwork.freeRef();
                } else {
                    DAGNode dagNode = getDagNode(str, pipelineNetwork, graphModel, refMap.addRef());
                    // Resolving inputs recursively may already have registered this id.
                    if (refMap.containsKey(str)) {
                        dagNode.freeRef();
                    } else {
                        RefUtil.freeRef(refMap.put(str, dagNode));
                    }
                }
                return (DAGNode) refMap.get(str);
            } finally {
                // Release our reference on both the normal and the exceptional path.
                refMap.freeRef();
            }
        } catch (Throwable th) {
            throw new RuntimeException("Error converting " + str, th);
        }
    }

    /**
     * Builds a max-pooling layer from the attributes of a graph node.
     *
     * <p>When assertions are enabled, the node's {@code "padding"} attribute must equal
     * {@code "SAME"}. If present, elements 1 and 2 of the {@code "ksize"} list become the
     * window X/Y sizes, and elements 1 and 2 of the {@code "strides"} list become the
     * stride X/Y values.
     *
     * @param graphNode graph node whose attribute map is read
     * @return the configured pooling layer
     */
    @Nonnull
    protected PoolingLayer getPoolingLayer(@Nonnull GraphModel.GraphNode graphNode) {
        PoolingLayer created = new PoolingLayer();
        created.setMode(PoolingLayer.PoolingMode.Max);
        PoolingLayer pooling = created.addRef();
        created.freeRef();

        Map<String, AttrValue> attributes = graphNode.getNodeDef().getAttrMap();
        if (!$assertionsDisabled) {
            String padding = attributes.get("padding").getS().toStringUtf8();
            if (!"SAME".equals(padding)) {
                throw new AssertionError();
            }
        }

        AttrValue kernelAttr = attributes.get("ksize");
        if (kernelAttr != null) {
            List<Long> kernel = kernelAttr.getList().getIList();
            pooling.setWindowX(Math.toIntExact(kernel.get(1)));
            pooling.setWindowY(Math.toIntExact(kernel.get(2)));
        }

        AttrValue stridesAttr = attributes.get("strides");
        if (stridesAttr != null) {
            List<Long> strides = stridesAttr.getList().getIList();
            pooling.setStrideX(Math.toIntExact(strides.get(1)));
            pooling.setStrideY(Math.toIntExact(strides.get(2)));
        }
        return pooling;
    }

    /**
     * Builds an {@link ImgBandBiasLayer} from the second input of a graph node.
     *
     * <p>The data of input 1 is wrapped in a one-dimensional tensor of the same length and
     * set on a bias layer created with that length. When assertions are enabled, input 1
     * must have op {@code "Const"} and non-null data.
     *
     * @param graphNode graph node whose input at index 1 supplies the bias values
     * @return the bias layer holding those values
     */
    @Nonnull
    protected ImgBandBiasLayer getBiasAdd(@Nonnull GraphModel.GraphNode graphNode) {
        GraphModel.GraphNode biasNode = (GraphModel.GraphNode) graphNode.getInputs().get(1);
        if (!$assertionsDisabled && !biasNode.getOp().equals("Const")) {
            throw new AssertionError();
        }
        double[] biasValues = biasNode.getData();
        if (!$assertionsDisabled && biasValues == null) {
            throw new AssertionError();
        }
        final int bands = biasValues.length;
        Tensor biasTensor = new Tensor(biasValues, new int[]{bands});
        ImgBandBiasLayer created = new ImgBandBiasLayer(bands);
        created.set(biasTensor.addRef());
        ImgBandBiasLayer biasLayer = created.addRef();
        created.freeRef();
        biasTensor.freeRef();
        return biasLayer;
    }

    /**
     * Builds a {@link SimpleConvolutionLayer} from a graph node and its kernel input.
     *
     * <p>The kernel data of input 1 is loaded into a tensor whose dimensions are the node's
     * shape in reverse order, then passed through {@code invertDimensions()}. Each value is
     * copied into a new tensor with its first two coordinates mirrored
     * ({@code dim - 1 - coord}) and the remaining two unchanged, and that tensor is set on
     * the layer. If a {@code "strides"} attribute exists and element 1 or 2 exceeds 1, both
     * are applied as stride X/Y.
     *
     * <p>NOTE(review): the shape is indexed at positions 0..3, so a shape with fewer than
     * four entries (including the single-entry fallback used for an empty shape) fails with
     * an index error here.
     *
     * @param graphNode graph node whose input at index 1 holds the kernel
     * @return the configured convolution layer
     */
    @Nonnull
    protected Layer getConv2D(@Nonnull GraphModel.GraphNode graphNode) {
        GraphModel.GraphNode kernelNode = (GraphModel.GraphNode) graphNode.getInputs().get(1);
        if (!$assertionsDisabled && !kernelNode.getOp().equals("Const")) {
            throw new AssertionError();
        }
        int[] kernelShape = RefArrays.stream(kernelNode.getShape()).mapToInt(extent -> {
            return (int) extent;
        }).toArray();
        double[] kernelData = kernelNode.getData();
        if (kernelShape.length == 0) {
            if (!$assertionsDisabled && kernelData == null) {
                throw new AssertionError();
            }
            kernelShape = new int[]{kernelData.length};
        }

        Tensor rawKernel = new Tensor(kernelData, new int[]{kernelShape[3], kernelShape[2], kernelShape[1], kernelShape[0]});
        Tensor kernel = rawKernel.invertDimensions();
        rawKernel.freeRef();
        int[] dims = kernel.getDimensions();

        SimpleConvolutionLayer convLayer = new SimpleConvolutionLayer(dims[0], dims[1], dims[2] * dims[3]);
        Tensor mirrored = new Tensor(new int[]{dims[0], dims[1], dims[2], dims[3]});
        // Mirror the two leading axes while copying; the two trailing axes keep their position.
        kernel.coordStream(false).forEach((Consumer) RefUtil.wrapInterface(coordinate -> {
            int[] at = coordinate.getCoords();
            int mirroredX = (dims[0] - 1) - at[0];
            int mirroredY = (dims[1] - 1) - at[1];
            mirrored.set(mirroredX, mirroredY, at[2], at[3], kernel.get(coordinate));
        }, new Object[]{mirrored.addRef(), kernel.addRef()}));
        kernel.freeRef();
        convLayer.set(mirrored.addRef());
        mirrored.freeRef();

        AttrValue stridesAttr = (AttrValue) graphNode.getNodeDef().getAttrMap().get("strides");
        if (null != stridesAttr) {
            int[] strides = stridesAttr.getList().getIList().stream().mapToInt(value -> {
                return Math.toIntExact(value.longValue());
            }).toArray();
            int strideX = strides[1];
            int strideY = strides[2];
            if (strideX > 1 || strideY > 1) {
                convLayer.setStrideX(strideX);
                convLayer.setStrideY(strideY);
            }
        }
        return convLayer;
    }

    /**
     * Creates the network node for one graph node, dispatching on the node's op.
     *
     * <ul>
     *   <li>{@code Conv2D}, {@code BiasAdd}, {@code Relu}, {@code LRN}, {@code MaxPool}:
     *       the matching layer is added to the network on top of the node resolved for
     *       input key 0.</li>
     *   <li>{@code Concat}: an {@link ImgConcatLayer} is added on top of the nodes resolved
     *       for every input key except the first.</li>
     *   <li>{@code Placeholder}: input 0 of the network is returned.</li>
     *   <li>Any other op: {@link IllegalArgumentException} carrying the op name.</li>
     * </ul>
     *
     * <p>Ownership: {@code pipelineNetwork} is released on every path, including the
     * failure path. {@code refMap} is either passed on to {@code getNode} (single-input
     * ops) or released here (all other branches).
     *
     * @param str             id of the graph node to convert
     * @param pipelineNetwork network receiving the new node
     * @param graphModel      graph model the node is looked up in
     * @param refMap          registry of already converted nodes
     * @return the node created for (or mapped to) the graph node
     */
    @NotNull
    private DAGNode getDagNode(@Nonnull String str, @Nonnull PipelineNetwork pipelineNetwork, @Nonnull GraphModel graphModel, @Nonnull RefMap<String, DAGNode> refMap) {
        GraphModel.GraphNode child = graphModel.getChild(str);
        if (!$assertionsDisabled && null == child) {
            throw new AssertionError();
        }
        try {
            final String op = child.getOp();
            final Layer layer;
            switch (op) {
                case "Conv2D":
                    layer = getConv2D(child);
                    break;
                case "BiasAdd":
                    layer = getBiasAdd(child);
                    break;
                case "Relu":
                    layer = new ActivationLayer(ActivationLayer.Mode.RELU);
                    break;
                case "LRN":
                    layer = getLRNLayer(child);
                    break;
                case "MaxPool":
                    layer = getPoolingLayer(child);
                    break;
                case "Concat": {
                    ImgConcatLayer concatLayer = new ImgConcatLayer();
                    // The first input key is skipped; every remaining key is resolved to a node.
                    DAGNode[] concatInputs = (DAGNode[]) child.getInputKeys().stream().skip(1L).map(inputKey -> {
                        return getNode(inputKey, pipelineNetwork.addRef(), graphModel, (RefMap) RefUtil.addRef(refMap));
                    }).toArray(size -> {
                        return new DAGNode[size];
                    });
                    InnerNode concatNode = pipelineNetwork.add(concatLayer, concatInputs);
                    refMap.freeRef();
                    pipelineNetwork.freeRef();
                    return concatNode;
                }
                case "Placeholder": {
                    refMap.freeRef();
                    DAGNode networkInput = pipelineNetwork.getInput(0);
                    pipelineNetwork.freeRef();
                    return networkInput;
                }
                default:
                    refMap.freeRef();
                    throw new IllegalArgumentException(op);
            }
            // Shared tail for single-input ops: stack the layer on the node of input key 0.
            DAGNode upstream = getNode((String) child.getInputKeys().get(0), pipelineNetwork.addRef(), graphModel, refMap);
            InnerNode added = pipelineNetwork.add(layer, new DAGNode[]{upstream});
            pipelineNetwork.freeRef();
            return added;
        } catch (Throwable th) {
            pipelineNetwork.freeRef();
            throw th;
        }
    }

    /**
     * Builds an {@link LRNLayer} from the {@code depth_radius}, {@code alpha}, {@code bias}
     * and {@code beta} attributes of a graph node.
     *
     * <p>The layer width is {@code 2 * depth_radius + 1}. Alpha is multiplied by that width
     * before being set; beta is set unchanged and bias is set as the layer's {@code k}.
     * NOTE(review): the alpha scaling presumably compensates for {@code LRNLayer} dividing
     * alpha by the window size where TensorFlow does not; confirm against the layer
     * implementation.
     *
     * <p>The decompiled code referenced an undefined register name ({@code r0}) for the
     * alpha factor. It is taken here to be the window width already computed for the
     * constructor argument.
     *
     * @param graphNode graph node whose attribute map is read
     * @return the configured layer; the caller owns its single reference
     */
    @Nonnull
    private LRNLayer getLRNLayer(@Nonnull GraphModel.GraphNode graphNode) {
        Map<String, AttrValue> attrMap = graphNode.getNodeDef().getAttrMap();
        long depthRadius = attrMap.get("depth_radius").getI();
        float alpha = attrMap.get("alpha").getF();
        float bias = attrMap.get("bias").getF();
        float beta = attrMap.get("beta").getF();
        // Window spans depthRadius positions on each side plus the centre.
        final long width = (depthRadius * 2) + 1;
        LRNLayer lrnLayer = new LRNLayer((int) width);
        lrnLayer.setAlpha(alpha * ((float) width));
        lrnLayer.setBeta(beta);
        lrnLayer.setK(bias);
        // The original addRef/freeRef chain was net-neutral on the reference count,
        // so the instance is returned directly.
        return lrnLayer;
    }

    // Caches the class's assertion status once at load time: the flag is true when
    // assertions are NOT enabled for TFConverter, which disables the guarded
    // AssertionError checks throughout this class.
    static {
        $assertionsDisabled = !TFConverter.class.desiredAssertionStatus();
    }
}
