package org.deeplearning4j.arbiter;

import java.beans.ConstructorProperties;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import org.deeplearning4j.arbiter.BaseNetworkSpace;
import org.deeplearning4j.arbiter.layers.LayerSpace;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper;
import org.deeplearning4j.arbiter.optimize.serde.jackson.YamlMapper;
import org.deeplearning4j.arbiter.util.CollectionUtils;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

/* loaded from: input_file:org/deeplearning4j/arbiter/MultiLayerSpace.class */
public class MultiLayerSpace extends BaseNetworkSpace<DL4JConfiguration> {

    @Deprecated
    private ParameterSpace<int[]> cnnInputSize;
    private List<LayerConf> layerSpaces;
    private ParameterSpace<InputType> inputType;
    private EarlyStoppingConfiguration<MultiLayerNetwork> earlyStoppingConfiguration;
    private int numParameters;

    /* loaded from: input_file:org/deeplearning4j/arbiter/MultiLayerSpace$Builder.class */
    public static class Builder extends BaseNetworkSpace.Builder<Builder> {

        @Deprecated
        private ParameterSpace<int[]> cnnInputSize;
        private List<LayerConf> layerSpaces = new ArrayList();
        private ParameterSpace<InputType> inputType;
        private EarlyStoppingConfiguration<MultiLayerNetwork> earlyStoppingConfiguration;

        @Deprecated
        public Builder cnnInputSize(int i, int i2, int i3) {
            return cnnInputSize(new FixedValue(new int[]{i, i2, i3}));
        }

        @Deprecated
        public Builder cnnInputSize(ParameterSpace<int[]> parameterSpace) {
            this.cnnInputSize = parameterSpace;
            return this;
        }

        public Builder setInputType(InputType inputType) {
            return setInputType((ParameterSpace<InputType>) new FixedValue(inputType));
        }

        public Builder setInputType(ParameterSpace<InputType> parameterSpace) {
            this.inputType = parameterSpace;
            return this;
        }

        public Builder addLayer(LayerSpace<?> layerSpace) {
            return addLayer(layerSpace, new FixedValue(1), true);
        }

        public Builder addLayer(LayerSpace<? extends Layer> layerSpace, ParameterSpace<Integer> parameterSpace, boolean z) {
            this.layerSpaces.add(new LayerConf(layerSpace, parameterSpace, z));
            return this;
        }

        public Builder earlyStoppingConfiguration(EarlyStoppingConfiguration<MultiLayerNetwork> earlyStoppingConfiguration) {
            this.earlyStoppingConfiguration = earlyStoppingConfiguration;
            return this;
        }

        @Override // org.deeplearning4j.arbiter.BaseNetworkSpace.Builder
        public MultiLayerSpace build() {
            return new MultiLayerSpace(this);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/deeplearning4j/arbiter/MultiLayerSpace$LayerConf.class */
    public static class LayerConf {
        private LayerSpace<?> layerSpace;
        private ParameterSpace<Integer> numLayers;
        private boolean duplicateConfig;

        @ConstructorProperties({"layerSpace", "numLayers", "duplicateConfig"})
        public LayerConf(LayerSpace<?> layerSpace, ParameterSpace<Integer> parameterSpace, boolean z) {
            this.layerSpace = layerSpace;
            this.numLayers = parameterSpace;
            this.duplicateConfig = z;
        }

        public LayerConf() {
        }

        public LayerSpace<?> getLayerSpace() {
            return this.layerSpace;
        }

        public ParameterSpace<Integer> getNumLayers() {
            return this.numLayers;
        }

        public boolean isDuplicateConfig() {
            return this.duplicateConfig;
        }

        public void setLayerSpace(LayerSpace<?> layerSpace) {
            this.layerSpace = layerSpace;
        }

        public void setNumLayers(ParameterSpace<Integer> parameterSpace) {
            this.numLayers = parameterSpace;
        }

        public void setDuplicateConfig(boolean z) {
            this.duplicateConfig = z;
        }

        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof LayerConf)) {
                return false;
            }
            LayerConf layerConf = (LayerConf) obj;
            if (!layerConf.canEqual(this)) {
                return false;
            }
            LayerSpace<?> layerSpace = getLayerSpace();
            LayerSpace<?> layerSpace2 = layerConf.getLayerSpace();
            if (layerSpace == null) {
                if (layerSpace2 != null) {
                    return false;
                }
            } else if (!layerSpace.equals(layerSpace2)) {
                return false;
            }
            ParameterSpace<Integer> numLayers = getNumLayers();
            ParameterSpace<Integer> numLayers2 = layerConf.getNumLayers();
            if (numLayers == null) {
                if (numLayers2 != null) {
                    return false;
                }
            } else if (!numLayers.equals(numLayers2)) {
                return false;
            }
            return isDuplicateConfig() == layerConf.isDuplicateConfig();
        }

        protected boolean canEqual(Object obj) {
            return obj instanceof LayerConf;
        }

        public int hashCode() {
            LayerSpace<?> layerSpace = getLayerSpace();
            int hashCode = (1 * 59) + (layerSpace == null ? 43 : layerSpace.hashCode());
            ParameterSpace<Integer> numLayers = getNumLayers();
            return (((hashCode * 59) + (numLayers == null ? 43 : numLayers.hashCode())) * 59) + (isDuplicateConfig() ? 79 : 97);
        }

        public String toString() {
            return "MultiLayerSpace.LayerConf(layerSpace=" + getLayerSpace() + ", numLayers=" + getNumLayers() + ", duplicateConfig=" + isDuplicateConfig() + ")";
        }
    }

    private MultiLayerSpace(Builder builder) {
        super(builder);
        this.layerSpaces = new ArrayList();
        this.cnnInputSize = builder.cnnInputSize;
        this.inputType = builder.inputType;
        this.earlyStoppingConfiguration = builder.earlyStoppingConfiguration;
        this.layerSpaces = builder.layerSpaces;
        Iterator it = CollectionUtils.getUnique(collectLeaves()).iterator();
        while (it.hasNext()) {
            this.numParameters += ((ParameterSpace) it.next()).numParameters();
        }
    }

    private MultiLayerSpace() {
        this.layerSpaces = new ArrayList();
    }

    /* renamed from: getValue, reason: merged with bridge method [inline-methods] */
    public DL4JConfiguration m1getValue(double[] dArr) {
        ArrayList arrayList = new ArrayList();
        for (LayerConf layerConf : this.layerSpaces) {
            int intValue = ((Integer) layerConf.numLayers.getValue(dArr)).intValue();
            if (!layerConf.duplicateConfig) {
                throw new UnsupportedOperationException("Not yet implemented");
            }
            Layer layer = (Layer) layerConf.layerSpace.getValue(dArr);
            for (int i = 0; i < intValue; i++) {
                arrayList.add(layer.clone());
            }
        }
        NeuralNetConfiguration.Builder randomGlobalConf = randomGlobalConf(dArr);
        int nOut = ((FeedForwardLayer) arrayList.get(0)).getNOut();
        for (int i2 = 1; i2 < arrayList.size(); i2++) {
            if (arrayList.get(i2) instanceof FeedForwardLayer) {
                FeedForwardLayer feedForwardLayer = (FeedForwardLayer) arrayList.get(i2);
                feedForwardLayer.setNIn(nOut);
                nOut = feedForwardLayer.getNOut();
            }
        }
        NeuralNetConfiguration.ListBuilder list = randomGlobalConf.list();
        for (int i3 = 0; i3 < arrayList.size(); i3++) {
            list.layer(i3, (Layer) arrayList.get(i3));
        }
        if (this.backprop != null) {
            list.backprop(((Boolean) this.backprop.getValue(dArr)).booleanValue());
        }
        if (this.pretrain != null) {
            list.pretrain(((Boolean) this.pretrain.getValue(dArr)).booleanValue());
        }
        if (this.backpropType != null) {
            list.backpropType((BackpropType) this.backpropType.getValue(dArr));
        }
        if (this.tbpttFwdLength != null) {
            list.tBPTTForwardLength(((Integer) this.tbpttFwdLength.getValue(dArr)).intValue());
        }
        if (this.tbpttBwdLength != null) {
            list.tBPTTBackwardLength(((Integer) this.tbpttBwdLength.getValue(dArr)).intValue());
        }
        if (this.cnnInputSize != null) {
            list.cnnInputSize((int[]) this.cnnInputSize.getValue(dArr));
        }
        if (this.inputType != null) {
            list.setInputType((InputType) this.inputType.getValue(dArr));
        }
        return new DL4JConfiguration(list.build(), this.earlyStoppingConfiguration, Integer.valueOf(this.numEpochs));
    }

    public int numParameters() {
        return this.numParameters;
    }

    @Override // org.deeplearning4j.arbiter.BaseNetworkSpace
    public List<ParameterSpace> collectLeaves() {
        List<ParameterSpace> collectLeaves = super.collectLeaves();
        for (LayerConf layerConf : this.layerSpaces) {
            collectLeaves.addAll(layerConf.numLayers.collectLeaves());
            collectLeaves.addAll(layerConf.layerSpace.collectLeaves());
        }
        if (this.cnnInputSize != null) {
            collectLeaves.addAll(this.cnnInputSize.collectLeaves());
        }
        if (this.inputType != null) {
            collectLeaves.addAll(this.inputType.collectLeaves());
        }
        return collectLeaves;
    }

    @Override // org.deeplearning4j.arbiter.BaseNetworkSpace
    public String toString() {
        StringBuilder sb = new StringBuilder(super.toString());
        int i = 0;
        for (LayerConf layerConf : this.layerSpaces) {
            int i2 = i;
            i++;
            sb.append("Layer config ").append(i2).append(": (Number layers:").append(layerConf.numLayers).append(", duplicate: ").append(layerConf.duplicateConfig).append("), ").append(layerConf.layerSpace.toString()).append("\n");
        }
        if (this.cnnInputSize != null) {
            sb.append("cnnInputSize: ").append(this.cnnInputSize).append("\n");
        }
        if (this.inputType != null) {
            sb.append("inputType: ").append(this.inputType).append("\n");
        }
        if (this.earlyStoppingConfiguration != null) {
            sb.append("Early stopping configuration:").append(this.earlyStoppingConfiguration.toString()).append("\n");
        } else {
            sb.append("Training # epochs:").append(this.numEpochs).append("\n");
        }
        return sb.toString();
    }

    public LayerSpace<?> getLayerSpace(int i) {
        return this.layerSpaces.get(i).getLayerSpace();
    }

    public static MultiLayerSpace fromJson(String str) {
        try {
            return (MultiLayerSpace) JsonMapper.getMapper().readValue(str, MultiLayerSpace.class);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    public static MultiLayerSpace fromYaml(String str) {
        try {
            return (MultiLayerSpace) YamlMapper.getMapper().readValue(str, MultiLayerSpace.class);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    @Override // org.deeplearning4j.arbiter.BaseNetworkSpace
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof MultiLayerSpace)) {
            return false;
        }
        MultiLayerSpace multiLayerSpace = (MultiLayerSpace) obj;
        if (!multiLayerSpace.canEqual(this) || !super.equals(obj)) {
            return false;
        }
        ParameterSpace<int[]> parameterSpace = this.cnnInputSize;
        ParameterSpace<int[]> parameterSpace2 = multiLayerSpace.cnnInputSize;
        if (parameterSpace == null) {
            if (parameterSpace2 != null) {
                return false;
            }
        } else if (!parameterSpace.equals(parameterSpace2)) {
            return false;
        }
        List<LayerConf> list = this.layerSpaces;
        List<LayerConf> list2 = multiLayerSpace.layerSpaces;
        if (list == null) {
            if (list2 != null) {
                return false;
            }
        } else if (!list.equals(list2)) {
            return false;
        }
        ParameterSpace<InputType> parameterSpace3 = this.inputType;
        ParameterSpace<InputType> parameterSpace4 = multiLayerSpace.inputType;
        if (parameterSpace3 == null) {
            if (parameterSpace4 != null) {
                return false;
            }
        } else if (!parameterSpace3.equals(parameterSpace4)) {
            return false;
        }
        EarlyStoppingConfiguration<MultiLayerNetwork> earlyStoppingConfiguration = this.earlyStoppingConfiguration;
        EarlyStoppingConfiguration<MultiLayerNetwork> earlyStoppingConfiguration2 = multiLayerSpace.earlyStoppingConfiguration;
        if (earlyStoppingConfiguration == null) {
            if (earlyStoppingConfiguration2 != null) {
                return false;
            }
        } else if (!earlyStoppingConfiguration.equals(earlyStoppingConfiguration2)) {
            return false;
        }
        return this.numParameters == multiLayerSpace.numParameters;
    }

    @Override // org.deeplearning4j.arbiter.BaseNetworkSpace
    protected boolean canEqual(Object obj) {
        return obj instanceof MultiLayerSpace;
    }

    @Override // org.deeplearning4j.arbiter.BaseNetworkSpace
    public int hashCode() {
        int hashCode = (1 * 59) + super.hashCode();
        ParameterSpace<int[]> parameterSpace = this.cnnInputSize;
        int hashCode2 = (hashCode * 59) + (parameterSpace == null ? 43 : parameterSpace.hashCode());
        List<LayerConf> list = this.layerSpaces;
        int hashCode3 = (hashCode2 * 59) + (list == null ? 43 : list.hashCode());
        ParameterSpace<InputType> parameterSpace2 = this.inputType;
        int hashCode4 = (hashCode3 * 59) + (parameterSpace2 == null ? 43 : parameterSpace2.hashCode());
        EarlyStoppingConfiguration<MultiLayerNetwork> earlyStoppingConfiguration = this.earlyStoppingConfiguration;
        return (((hashCode4 * 59) + (earlyStoppingConfiguration == null ? 43 : earlyStoppingConfiguration.hashCode())) * 59) + this.numParameters;
    }
}
