package com.simiacryptus.mindseye.layers.tensorflow;

import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.protobuf.InvalidProtocolBufferException;
import com.simiacryptus.mindseye.lang.DataSerializer;
import com.simiacryptus.mindseye.lang.Delta;
import com.simiacryptus.mindseye.lang.DeltaSet;
import com.simiacryptus.mindseye.lang.LayerBase;
import com.simiacryptus.mindseye.lang.Result;
import com.simiacryptus.mindseye.lang.Singleton;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.lang.TensorArray;
import com.simiacryptus.mindseye.lang.TensorList;
import com.simiacryptus.mindseye.lang.tensorflow.TFIO;
import com.simiacryptus.mindseye.lang.tensorflow.TFUtil;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.lang.ReferenceCountingBase;
import com.simiacryptus.ref.wrappers.RefArrayList;
import com.simiacryptus.ref.wrappers.RefArrays;
import com.simiacryptus.ref.wrappers.RefCollection;
import com.simiacryptus.ref.wrappers.RefCollectors;
import com.simiacryptus.ref.wrappers.RefHashMap;
import com.simiacryptus.ref.wrappers.RefList;
import com.simiacryptus.ref.wrappers.RefMap;
import com.simiacryptus.ref.wrappers.RefSet;
import com.simiacryptus.ref.wrappers.RefStream;
import com.simiacryptus.tensorflow.TensorboardEventWriter;
import com.simiacryptus.tensorflow.TensorflowUtil;
import com.simiacryptus.util.Util;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Shape;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.Summary;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;

/* loaded from: input_file:com/simiacryptus/mindseye/layers/tensorflow/TFLayerBase.class */
public abstract class TFLayerBase extends LayerBase {
    private static final Logger log = LoggerFactory.getLogger(TFLayerBase.class);

    /**
     * Optional global TensorBoard sink. When non-null and {@link #getSummaryOut()} names a non-empty
     * node, every forward evaluation additionally fetches that node, parses it as a {@link Summary}
     * and writes it here.
     */
    @Nullable
    public static TensorboardEventWriter eventWriter;

    /** Weights of this layer, keyed by the graph node name each weight tensor is fed into. */
    private final RefMap<String, Tensor> weights;

    /**
     * Backward-pass callback created by {@link TFLayerBase#eval(TFSession, Result...)}.
     *
     * <p>It re-uses the forward pass's {@link Session.Runner}, which still carries the forward feeds
     * and fetches. It feeds the incoming output delta into the {@code <outputNode>_delta}
     * placeholder and additionally fetches the gradient outputs. In the run's outputs, the
     * {@code fwdFetches} forward fetches come first. They are followed by one gradient per input node
     * (routed to the inputs' accumulators) and one gradient per weight (added to the weight deltas).
     */
    public static class Accumulator extends Result.Accumulator {
        /** Runner of the forward pass; its forward fetches precede the gradients in every run. */
        private final Session.Runner runner;
        /** Number of fetches registered by the forward pass (1, or 2 when a summary is fetched). */
        private final int fwdFetches;
        /**
         * Weight names. Presumably in the same order as the weight gradients, since both come from
         * the weight map's key set; verify if the map type changes.
         */
        private final RefList<String> stateNames;
        private final Result[] inputs;
        private final TFSession tfsession;
        private final RefMap<String, Tensor> weights;
        private final String outputNode;
        private final UUID id;
        private final boolean invertRanks;
        private final List<String> inputNodes;
        private final boolean floatInputs;
        private final Output<?>[] gradients;

        /**
         * Creates the accumulator. It takes over the references to {@code stateNames},
         * {@code weights}, {@code tfsession} and {@code inputs}, and releases them in
         * {@link #_free()}.
         */
        public Accumulator(Session.Runner runner, int fwdFetches, RefList<String> stateNames, UUID id,
                           RefMap<String, Tensor> weights, String outputNode, boolean invertRanks,
                           List<String> inputNodes, boolean floatInputs, Output<?>[] gradients,
                           TFSession tfsession, Result... inputs) {
            this.runner = runner;
            this.fwdFetches = fwdFetches;
            this.stateNames = stateNames;
            this.inputs = inputs;
            this.weights = weights;
            this.outputNode = outputNode;
            this.id = id;
            this.invertRanks = invertRanks;
            this.inputNodes = inputNodes;
            this.floatInputs = floatInputs;
            this.gradients = gradients;
            this.tfsession = tfsession;
        }

        /**
         * Runs the gradient computation for one backward pass.
         *
         * <p>All tensors created or fetched here are closed before returning, including on failure.
         *
         * @param deltaSet   receives the weight deltas (must be non-null when the layer has weights);
         *                   the caller's reference is consumed
         * @param tensorList delta of the layer output; the caller's reference is consumed
         */
        @Override
        public void accept(@Nullable DeltaSet<UUID> deltaSet, @Nullable TensorList tensorList) {
            // Feed the output delta into the "<outputNode>_delta" placeholder built by getGradients().
            org.tensorflow.Tensor<?> deltaTensor;
            if (this.floatInputs) {
                deltaTensor = TFIO.getFloatTensor(tensorList == null ? null : tensorList.addRef());
            } else {
                deltaTensor = TFIO.getDoubleTensor(tensorList == null ? null : tensorList.addRef());
            }
            this.runner.feed(this.outputNode + "_delta", deltaTensor);
            if (null != tensorList) {
                tensorList.freeRef();
            }
            for (Output<?> gradient : this.gradients) {
                this.runner.fetch(gradient);
            }
            Session.Run run = null;
            try {
                run = this.runner.runAndFetchMetadata();
                // Input gradients follow the forward fetches.
                for (int i = 0; i < this.inputs.length; i++) {
                    org.tensorflow.Tensor inputGradient = (org.tensorflow.Tensor) run.outputs.get(this.fwdFetches + i);
                    Result.Accumulator accumulator = this.inputs[i].getAccumulator();
                    assert accumulator != null;
                    accumulator.accept(deltaSet == null ? null : deltaSet.addRef(), TFIO.getTensorList(inputGradient));
                    accumulator.freeRef();
                }
                // Weight gradients follow the input gradients.
                for (int i = 0; i < this.stateNames.size(); i++) {
                    String name = this.stateNames.get(i);
                    assert deltaSet != null;
                    Delta delta = deltaSet.get(UUID.nameUUIDFromBytes((this.id + "_" + name).getBytes()), this.weights.get(name));
                    org.tensorflow.Tensor weightGradient = (org.tensorflow.Tensor) run.outputs.get(i + this.fwdFetches + this.inputNodes.size());
                    Tensor converted = weightGradient.dataType() == DataType.FLOAT
                            ? TFIO.getTensor(weightGradient.expect(Float.class), this.invertRanks)
                            : TFIO.getTensor(weightGradient.expect(Double.class), this.invertRanks);
                    assert delta != null;
                    delta.addInPlace(converted);
                    delta.freeRef();
                }
            } finally {
                if (null != deltaSet) {
                    deltaSet.freeRef();
                }
                // Every fetched tensor is closed: re-fetched forward outputs, input and weight gradients.
                closeFetched(run);
                deltaTensor.close();
            }
        }

        /** Releases the references taken over in the constructor. */
        @Override
        public void _free() {
            super._free();
            this.weights.freeRef();
            RefUtil.freeRef(this.inputs);
            this.tfsession.freeRef();
            if (null != this.stateNames) {
                this.stateNames.freeRef();
            }
        }
    }

    /**
     * A TensorFlow {@link Graph} imported from a layer's {@link TFLayerBase#getGraphDef()}, plus a
     * {@link Session} over it. Gradient operations are added to the graph lazily by
     * {@link #getGradients()} and cached in {@link #outputSingleton}.
     */
    public static class TFSession extends ReferenceCountingBase {

        @Nonnull
        public final Session session;

        @Nullable
        private final TFLayerBase parent;
        public final Singleton<Output<?>[]> outputSingleton = new Singleton<>();

        @Nonnull
        public final Graph graph = new Graph();

        /**
         * Imports and validates the parent's graph and opens a session on it.
         *
         * @param parent owning layer; this session takes over the caller's reference and frees it in
         *               {@link #_free()}. Although annotated nullable, it is dereferenced here.
         */
        public TFSession(@Nullable TFLayerBase parent) {
            this.parent = parent;
            GraphDef graphDef = this.parent.getGraphDef();
            TensorflowUtil.validate(graphDef);
            this.graph.importGraphDef(graphDef.toByteArray());
            this.session = new Session(this.graph);
        }

        /**
         * Returns the gradient outputs of the parent's output node, creating them on first use.
         *
         * <p>First use adds a {@code <outputNode>_delta} placeholder (float or double per
         * {@link TFLayerBase#floatInputs()}) as the upstream gradient. It then adds gradients with
         * respect to each input node followed by each weight node, in that order.
         */
        @Nonnull
        public Output<?>[] getGradients() {
            return (Output[]) this.outputSingleton.getOrInit(() -> {
                assert this.parent != null;
                RefMap<String, Tensor> parentWeights = this.parent.getWeights();
                assert parentWeights != null;
                RefSet<String> keySet = parentWeights.keySet();
                RefList<String> weightNames = (RefList<String>) keySet.stream().collect(RefCollectors.toList());
                keySet.freeRef();
                parentWeights.freeRef();
                String deltaNode = this.parent.getOutputNode() + "_delta";
                Ops.create(this.graph).withName(deltaNode).placeholder(
                        this.parent.floatInputs() ? Float.class : Double.class,
                        new Placeholder.Options[]{Placeholder.shape(Shape.unknown())});
                Output[] ys = new Output[]{TensorflowUtil.find(this.graph, this.parent.getOutputNode()).output(0)};
                Output[] xs = (Output[]) RefStream.concat(this.parent.getInputNodes().stream(), weightNames.stream())
                        .map(node -> TensorflowUtil.find(this.graph, node).output(0))
                        .toArray(n -> new Output[n]);
                Output[] dys = new Output[]{TensorflowUtil.find(this.graph, deltaNode).output(0)};
                Output[] result = this.graph.addGradients("gradient", ys, xs, dys);
                weightNames.freeRef();
                return result;
            });
        }

        /**
         * Frees the parent reference and the gradient cache. The session and graph are closed on a
         * separate thread (presumably so that freeing does not block the caller).
         */
        @Override
        public void _free() {
            if (null != this.parent) {
                this.parent.freeRef();
            }
            new Thread(() -> {
                this.session.close();
                this.graph.close();
            }).start();
            this.outputSingleton.freeRef();
            super._free();
        }

        /** Returns a new reference to this session. */
        @Nonnull
        public TFSession m13addRef() {
            return (TFSession) super.addRef();
        }
    }

    /**
     * Restores a layer from JSON. For every key in {@link #getDataKeys(JsonObject)}, the matching
     * member is decoded with {@link Tensor#fromJson} and stored as a weight under that key.
     *
     * @param json      serialized layer
     * @param resources binary resources referenced by the serialized tensors
     */
    public TFLayerBase(@Nonnull JsonObject json, Map<CharSequence, byte[]> resources) {
        super(json);
        this.weights = new RefHashMap();
        for (String key : getDataKeys(json)) {
            RefMap<String, Tensor> current = getWeights();
            assert current != null;
            RefUtil.freeRef(current.put(key, Tensor.fromJson(json.get(key), resources)));
            current.freeRef();
        }
    }

    /**
     * Creates a layer whose weights are copied from {@code weights} via {@code putAll}.
     *
     * @param weights initial weights, keyed by graph node name
     */
    public TFLayerBase(@Nullable RefMap<String, Tensor> weights) {
        this.weights = new RefHashMap();
        RefMap<String, Tensor> current = getWeights();
        assert current != null;
        current.putAll(weights);
        current.freeRef();
    }

    /** The TensorFlow graph implementing this layer. */
    public abstract GraphDef getGraphDef();

    /** Names of the graph nodes that receive this layer's inputs, in input order. */
    @Nullable
    public abstract List<String> getInputNodes();

    /** Name of the graph node producing this layer's output. */
    public abstract String getOutputNode();

    /** Name of a summary node to log to {@link #eventWriter}, or null/empty for none. */
    @Nullable
    public abstract String getSummaryOut();

    /**
     * Returns a new reference to the weight map; the caller must free it.
     *
     * @return the weights, or null if none are set
     */
    @Nullable
    public RefMap<String, Tensor> getWeights() {
        if (this.weights == null) {
            return null;
        }
        return this.weights.addRef();
    }

    /** Returns a {@link TFLayer} whose graph has the current weights baked in as constants. */
    @Nonnull
    public TFLayer asConstLayer() {
        return new TFLayer(constGraph().toByteArray(), new RefHashMap(), getOutputNode(), getInputNodes().toArray(new String[0]));
    }

    /** Returns {@link #getGraphDef()} with the current weights implanted as constants. */
    @Nonnull
    public GraphDef constGraph() {
        return TFUtil.implantConstants(getGraphDef(), getWeights());
    }

    /**
     * Serializes this layer: the JSON stub plus one member per weight, named by the weight key.
     *
     * @param resources      sink for binary resources produced by tensor serialization
     * @param dataSerializer precision/format used for the tensor data
     */
    public JsonObject getJson(Map<CharSequence, byte[]> resources, @Nonnull DataSerializer dataSerializer) {
        JsonObject json = getJsonStub();
        RefMap<String, Tensor> current = getWeights();
        assert current != null;
        current.forEach((key, tensor) -> {
            JsonElement element = tensor.getJson(resources, dataSerializer);
            tensor.freeRef();
            json.add(key, element);
        });
        current.freeRef();
        return json;
    }

    /** Returns the raw data arrays of all weights. */
    @Nullable
    public RefList<double[]> state() {
        RefCollection<Tensor> values = this.weights.values();
        RefList<double[]> data = (RefList<double[]>) values.stream().map(tensor -> {
            try {
                return tensor.getData();
            } finally {
                tensor.freeRef();
            }
        }).collect(RefCollectors.toList());
        values.freeRef();
        return data;
    }

    /** Evaluates the layer in a fresh {@link TFSession}, which holds a new reference to this layer. */
    @Nullable
    public Result eval(@Nullable Result... inputs) {
        return eval(new TFSession(mo2addRef()), inputs);
    }

    /** Hook invoked from {@link #_free()}; does nothing by default. */
    public void close() {
    }

    /** Whether weight tensors have their ranks inverted when converted to/from TensorFlow. */
    public boolean invertWeights() {
        return true;
    }

    /** Returns {@code graphDef} with the current weights implanted as constants. */
    @Nonnull
    public GraphDef getConstGraph(GraphDef graphDef) {
        return TFUtil.implantConstants(graphDef, getWeights());
    }

    /** Releases the weight map, then calls {@link #close()}. */
    @Override
    public void _free() {
        if (null != this.weights) {
            this.weights.freeRef();
        }
        close();
        super._free();
    }

    /** Returns a new reference to this layer. */
    @Override
    @Nonnull
    public TFLayerBase mo2addRef() {
        return (TFLayerBase) super.addRef();
    }

    /**
     * Runs the forward pass in {@code session} and wires up the backward pass.
     *
     * @param session session to run in; its reference is handed to the returned result's accumulator
     * @param inputs  layer inputs, one per {@link #getInputNodes()} entry; handed to the accumulator
     */
    @Nonnull
    Result eval(@Nonnull TFSession session, @Nonnull Result... inputs) {
        RefMap<String, Tensor> current = getWeights();
        assert current != null;
        RefSet<String> keySet = current.keySet();
        RefList<String> stateNames = (RefList<String>) keySet.stream().collect(RefCollectors.toList());
        keySet.freeRef();
        Session.Runner runner = session.session.runner();
        RefArrayList<org.tensorflow.Tensor<?>> feedTensors = setTensors(runner, current, (Result[]) RefUtil.addRef(inputs));
        boolean summaryFetched = run(runner);
        TensorArray output = getOutput(runner, feedTensors, summaryFetched);
        return new Result(output, new Accumulator(runner, summaryFetched ? 2 : 1, stateNames, getId(), getWeights(),
                getOutputNode(), invertWeights(), getInputNodes(), floatInputs(), session.getGradients(), session, inputs));
    }

    /** Names of the JSON members that hold serialized weights. */
    @Nonnull
    protected abstract Set<String> getDataKeys(JsonObject jsonObject);

    /** Whether tensors are fed to TensorFlow as float (true) or double (false). */
    protected boolean floatInputs() {
        return false;
    }

    /**
     * Registers the forward fetches on {@code runner}. The output node is always fetched. The summary
     * node is also fetched when {@link #eventWriter} is set and {@link #getSummaryOut()} is non-empty.
     *
     * @return whether the summary node was fetched (it is then output #1 of the run)
     */
    private boolean run(Session.Runner runner) {
        runner.fetch(getOutputNode());
        if (null == eventWriter) {
            return false;
        }
        String summaryOut = getSummaryOut();
        if (null == summaryOut || summaryOut.isEmpty()) {
            return false;
        }
        runner.fetch(summaryOut);
        return true;
    }

    /**
     * Feeds every weight (under its key) and every input (under its input node name) into
     * {@code runner}.
     *
     * @param weightMap weights to feed; this reference is consumed
     * @param inputs    input results, indexed like {@link #getInputNodes()}; this reference is consumed
     * @return the TensorFlow tensors created for the feeds
     */
    @NotNull
    private RefArrayList<org.tensorflow.Tensor<?>> setTensors(Session.Runner runner, RefMap<String, Tensor> weightMap, @Nonnull Result[] inputs) {
        RefArrayList<org.tensorflow.Tensor<?>> feedTensors = new RefArrayList<>();
        weightMap.forEach((name, tensor) -> {
            boolean invert = invertWeights();
            org.tensorflow.Tensor fed = floatInputs() ? TFIO.getFloatTensor(tensor, invert) : TFIO.getDoubleTensor(tensor, invert);
            runner.feed(name, fed);
            feedTensors.add(fed);
        });
        weightMap.freeRef();
        List<String> inputNodes = getInputNodes();
        assert inputNodes != null;
        for (int i = 0; i < inputNodes.size(); i++) {
            TensorList data = inputs[i].getData();
            org.tensorflow.Tensor fed = floatInputs() ? TFIO.getFloatTensor(data, true) : TFIO.getDoubleTensor(data, true);
            runner.feed(inputNodes.get(i), fed);
            feedTensors.add(fed);
        }
        RefUtil.freeRef(inputs);
        return feedTensors;
    }

    /**
     * Executes the forward run and converts output #0 into the layer output. If
     * {@code writeSummary} is set, output #1 is parsed as a {@link Summary} and written to
     * {@link #eventWriter}.
     *
     * <p>All tensors fetched by this run are closed before returning. The feed tensors are
     * <em>not</em> closed here, because {@link Accumulator#accept} re-runs the same runner with them.
     * NOTE(review): they appear never to be closed explicitly; {@code feedTensors} is only
     * freeRef'd. Confirm whether RefArrayList releases them.
     */
    @NotNull
    private TensorArray getOutput(Session.Runner runner, RefArrayList<org.tensorflow.Tensor<?>> feedTensors, boolean writeSummary) {
        Session.Run run = null;
        try {
            run = runner.runAndFetchMetadata();
            TensorArray output = TFIO.getTensorList((org.tensorflow.Tensor) run.outputs.get(0));
            if (writeSummary) {
                try {
                    // InvalidProtocolBufferException (from parseFrom) is an IOException.
                    Summary summary = Summary.parseFrom(((org.tensorflow.Tensor<?>) run.outputs.get(1)).expect(String.class).bytesValue());
                    TensorboardEventWriter writer = eventWriter;
                    if (null != writer) {
                        writer.write(summary);
                    }
                } catch (IOException e) {
                    output.freeRef();
                    throw Util.throwException(e);
                }
            }
            return output;
        } finally {
            closeFetched(run);
            feedTensors.freeRef();
        }
    }

    /** Closes every tensor fetched by a session run; does nothing for {@code null}. */
    private static void closeFetched(@Nullable Session.Run run) {
        if (null == run) {
            return;
        }
        for (org.tensorflow.Tensor<?> fetched : run.outputs) {
            fetched.close();
        }
    }

    /** Raw-typed bridge to {@link #getJson(Map, DataSerializer)}. */
    public /* bridge */ /* synthetic */ JsonElement mo4getJson(Map map, @Nonnull DataSerializer dataSerializer) {
        return getJson((Map<CharSequence, byte[]>) map, dataSerializer);
    }
}
