package com.simiacryptus.mindseye.layers.tensorflow;

import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.protobuf.InvalidProtocolBufferException;
import com.simiacryptus.mindseye.lang.DataSerializer;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.RefArrays;
import com.simiacryptus.ref.wrappers.RefHashMap;
import com.simiacryptus.ref.wrappers.RefIntStream;
import com.simiacryptus.ref.wrappers.RefMap;
import com.simiacryptus.util.JsonUtil;
import com.simiacryptus.util.Util;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.function.DoubleSupplier;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.tensorflow.Graph;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.linalg.MatMul;

/* loaded from: input_file:com/simiacryptus/mindseye/layers/tensorflow/MatMulLayer.class */
/**
 * TensorFlow-backed dense (matrix multiply) layer.
 *
 * <p>Each input sample is flattened to a row of {@code Tensor.length(inputDims)} values and
 * multiplied by the learned {@code "weights"} matrix. The result is then reshaped to
 * {@code [-1, outputDims...]}. The graph computes
 * {@code transpose(weights x flatInput^T)}. With weights of shape
 * {@code [outputLength, inputLength]} (as created by {@link #defaultStates}), this yields
 * {@code [batch, outputLength]} before the final reshape.
 */
public class MatMulLayer extends TFLayerBase {
    /** Shape of a single input sample. The public getter keeps its historical spelling. */
    private final int[] inputDims;
    /** Shape of a single output sample. */
    private final int[] outputDims;

    /**
     * Creates a layer with randomly initialised weights.
     *
     * @param inputDims  shape of one input sample
     * @param outputDims shape of one output sample
     */
    public MatMulLayer(int[] inputDims, int[] outputDims) {
        super(defaultStates(inputDims, outputDims));
        this.inputDims = inputDims;
        this.outputDims = outputDims;
    }

    /**
     * Restores a layer from its JSON form. See {@link #getJson}.
     *
     * @param json      serialized layer; must contain {@code "inputDims"} and {@code "outputDims"} arrays
     * @param resources binary resources passed through to the base class
     */
    public MatMulLayer(@Nonnull JsonObject json, Map<CharSequence, byte[]> resources) {
        super(json, resources);
        this.inputDims = JsonUtil.toIntArray(json.get("inputDims").getAsJsonArray());
        this.outputDims = JsonUtil.toIntArray(json.get("outputDims").getAsJsonArray());
    }

    /**
     * Builds the TensorFlow graph for this layer and returns it as a {@link GraphDef}.
     *
     * <p>The native {@link Graph} is closed on every path, including when op construction or
     * protobuf parsing fails.
     *
     * @return the serialized graph definition
     */
    @Override // com.simiacryptus.mindseye.layers.tensorflow.TFLayerBase
    public GraphDef getGraphDef() {
        try (Graph graph = new Graph()) {
            Ops tf = Ops.create(graph);
            // Op creation order matches the original nested expression, so auto-generated
            // node names (e.g. the unnamed constants) are unchanged.
            Placeholder<Double> weights = tf.withName("weights").placeholder(Double.class);
            Placeholder<Double> input = tf.withName(getInputNodes().get(0)).placeholder(Double.class);
            long inputLength = Tensor.length(getIntputDims());
            // weights x flatInput^T, where flatInput is reshaped to [-1, inputLength].
            MatMul<Double> product = tf.linalg.matMul(
                    weights,
                    tf.reshape(input, tf.constant(new long[]{-1, inputLength})),
                    MatMul.transposeB(true));
            // Transpose back to one row per sample, then apply the configured output shape.
            int[] outputShape = RefIntStream.concat(RefIntStream.of(-1), RefArrays.stream(getOutputDims())).toArray();
            tf.withName(getOutputNode()).reshape(
                    tf.linalg.transpose(product, tf.constant(new int[]{1, 0})),
                    tf.constant(outputShape));
            return GraphDef.parseFrom(graph.toGraphDef());
        } catch (InvalidProtocolBufferException e) {
            throw Util.throwException(e);
        }
    }

    /** @return the single input placeholder name, {@code "input"} */
    @Override // com.simiacryptus.mindseye.layers.tensorflow.TFLayerBase
    @Nonnull
    public List<String> getInputNodes() {
        return Arrays.asList("input");
    }

    /**
     * Returns the input sample shape.
     *
     * <p>The misspelled name is kept for backward compatibility. The returned array is the
     * internal instance, not a copy.
     *
     * @return the input sample shape
     */
    public int[] getIntputDims() {
        return this.inputDims;
    }

    /** @return the output sample shape (the internal array, not a copy) */
    public int[] getOutputDims() {
        return this.outputDims;
    }

    /** @return the output node name, {@code "output"} */
    @Override // com.simiacryptus.mindseye.layers.tensorflow.TFLayerBase
    @Nonnull
    public String getOutputNode() {
        return "output";
    }

    /** @return {@code null}; this layer exposes no summary node */
    @Override // com.simiacryptus.mindseye.layers.tensorflow.TFLayerBase
    @Nullable
    public String getSummaryOut() {
        return null;
    }

    /** @return {@code false} */
    public boolean isSingleBatch() {
        return false;
    }

    /**
     * JSON factory equivalent to {@link #MatMulLayer(JsonObject, Map)}.
     *
     * @param json      serialized layer
     * @param resources binary resources
     * @return the restored layer
     */
    @Nonnull
    public static MatMulLayer fromJson(@Nonnull JsonObject json, Map<CharSequence, byte[]> resources) {
        return new MatMulLayer(json, resources);
    }

    /**
     * Creates the initial state map holding a single {@code "weights"} tensor.
     *
     * <p>The tensor has shape {@code [outputLength, inputLength]}. Each element is drawn as
     * {@code (1 - 2u) * sqrt(6 / (inputLength + outputLength + 1))}, with {@code u} from
     * {@code Util.R}'s {@code nextDouble()}.
     *
     * @param inputDims  shape of one input sample
     * @param outputDims shape of one output sample
     * @return a map owning the new weights tensor
     */
    @Nonnull
    private static RefMap<String, Tensor> defaultStates(int[] inputDims, int[] outputDims) {
        int outputLength = Tensor.length(outputDims);
        int inputLength = Tensor.length(inputDims);
        // Loop-invariant scale, hoisted out of the per-element lambda.
        double scale = Math.sqrt(6.0d / (inputLength + outputLength + 1));
        Tensor weights = new Tensor(new int[]{outputLength, inputLength});
        weights.setByCoord(coordinate -> (1.0d - 2.0d * ((Random) Util.R.get()).nextDouble()) * scale);
        RefHashMap<String, Tensor> states = new RefHashMap<>();
        // put() returns any previous mapping, which must be released.
        RefUtil.freeRef(states.put("weights", weights));
        return states;
    }

    /**
     * Overwrites every element of the {@code "weights"} tensor with values from the supplier.
     *
     * <p>Reference counts obtained here are released even if the supplier throws.
     *
     * @param doubleSupplier source of the new weight values, consumed once per element
     */
    public void set(@Nonnull DoubleSupplier doubleSupplier) {
        RefMap<String, Tensor> weights = getWeights();
        assert weights != null;
        try {
            Tensor tensor = weights.get("weights");
            assert tensor != null;
            try {
                tensor.set(i -> doubleSupplier.getAsDouble());
            } finally {
                tensor.freeRef();
            }
        } finally {
            weights.freeRef();
        }
    }

    /**
     * Serializes the layer, adding {@code "inputDims"} and {@code "outputDims"} to the base JSON.
     *
     * @param resources      binary resources passed through to the base class
     * @param dataSerializer serializer passed through to the base class
     * @return the JSON representation
     */
    @Override // com.simiacryptus.mindseye.layers.tensorflow.TFLayerBase
    public JsonObject getJson(Map<CharSequence, byte[]> resources, @Nonnull DataSerializer dataSerializer) {
        JsonObject json = super.getJson(resources, dataSerializer);
        assert json != null;
        json.add("inputDims", JsonUtil.toIntArray(getIntputDims()));
        json.add("outputDims", JsonUtil.toIntArray(getOutputDims()));
        return json;
    }

    /** Releases resources; no state beyond the base class is held. */
    @Override // com.simiacryptus.mindseye.layers.tensorflow.TFLayerBase
    public void _free() {
        super._free();
    }

    /**
     * Adds a reference and returns this layer with its concrete type.
     * Decompiled name of {@code addRef}; kept to match the base class declaration.
     */
    @Override // com.simiacryptus.mindseye.layers.tensorflow.TFLayerBase
    @Nonnull
    public MatMulLayer mo2addRef() {
        return (MatMulLayer) super.mo2addRef();
    }

    /**
     * Returns the names of the binary data entries this layer serializes.
     *
     * @param json ignored
     * @return a mutable set containing {@code "weights"}
     */
    @Override // com.simiacryptus.mindseye.layers.tensorflow.TFLayerBase
    @Nonnull
    protected Set<String> getDataKeys(JsonObject json) {
        Set<String> keys = new HashSet<>();
        keys.add("weights");
        return keys;
    }

    /**
     * Bridge to {@link #getJson(Map, DataSerializer)}.
     * Decompiled name of the synthetic {@code getJson} bridge; kept to match the base class.
     */
    @Override // com.simiacryptus.mindseye.layers.tensorflow.TFLayerBase
    @SuppressWarnings("unchecked") // the base signature is raw; the map is always Map<CharSequence, byte[]>
    public JsonElement mo4getJson(Map map, @Nonnull DataSerializer dataSerializer) {
        return getJson((Map<CharSequence, byte[]>) map, dataSerializer);
    }
}
