package org.deeplearning4j.arbiter.data;

import java.io.IOException;
import java.util.Map;
import java.util.Random;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultipleEpochsIterator;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/* loaded from: input_file:org/deeplearning4j/arbiter/data/MnistDataProvider.class */
/**
 * {@link DataProvider} that supplies MNIST iterators for Arbiter hyperparameter searches.
 *
 * <p>Training data is an {@link MnistDataSetIterator} seeded with {@link #rngSeed} and wrapped in a
 * {@link MultipleEpochsIterator} configured with {@link #numEpochs}. Test data is an
 * {@link MnistDataSetIterator} that always uses the fixed seed {@code 12345}, independent of
 * {@link #rngSeed}. The {@code dataParameters} map passed to both factory methods is ignored.
 *
 * <p>The class is mutable: all three fields have public setters, and they take part in
 * {@link #equals(Object)} and {@link #hashCode()}.
 */
public class MnistDataProvider implements DataProvider {
    /** Epoch count handed to the {@link MultipleEpochsIterator} built in {@link #trainData(Map)}. */
    private int numEpochs;
    /** Minibatch size used for both the training and the test iterator. */
    private int batchSize;
    /** Seed passed to the training-data iterator; it is not used for test data. */
    private int rngSeed;

    /**
     * Creates a provider whose training seed is drawn from a freshly constructed {@link Random}.
     *
     * @param numEpochs number of epochs for the training iterator
     * @param batchSize minibatch size for both iterators
     */
    public MnistDataProvider(int numEpochs, int batchSize) {
        this(numEpochs, batchSize, new Random().nextInt());
    }

    /**
     * Creates a provider with explicit settings. The {@code @JsonProperty} names let Jackson map
     * serialized fields onto these parameters.
     *
     * @param numEpochs number of epochs for the training iterator
     * @param batchSize minibatch size for both iterators
     * @param rngSeed   seed for the training-data iterator
     */
    public MnistDataProvider(@JsonProperty("numEpochs") int numEpochs,
                             @JsonProperty("batchSize") int batchSize,
                             @JsonProperty("rngSeed") int rngSeed) {
        this.numEpochs = numEpochs;
        this.batchSize = batchSize;
        this.rngSeed = rngSeed;
    }

    /**
     * Creates a provider with every field left at its Java default ({@code 0}). It is presumably
     * intended for deserialization followed by the setters. NOTE(review): confirm with callers.
     */
    public MnistDataProvider() {
    }

    /**
     * Builds the training iterator.
     *
     * @param map ignored
     * @return a {@link MultipleEpochsIterator} wrapping an {@link MnistDataSetIterator} created
     *         with this provider's batch size, a {@code true} flag (the training split, per the
     *         DL4J constructor signature), and {@link #rngSeed}
     * @throws RuntimeException wrapping any {@link IOException} raised while creating the
     *         MNIST iterator
     */
    public Object trainData(Map<String, Object> map) {
        try {
            MnistDataSetIterator underlying = new MnistDataSetIterator(this.batchSize, true, this.rngSeed);
            return new MultipleEpochsIterator(this.numEpochs, underlying);
        } catch (IOException ioe) {
            throw new RuntimeException(ioe);
        }
    }

    /**
     * Builds the test iterator.
     *
     * @param map ignored
     * @return an {@link MnistDataSetIterator} created with this provider's batch size, a
     *         {@code false} flag (the test split), and the fixed seed {@code 12345}. The fixed
     *         seed keeps evaluation data the same across candidates regardless of
     *         {@link #rngSeed}.
     * @throws RuntimeException wrapping any {@link IOException} raised while creating the
     *         MNIST iterator
     */
    public Object testData(Map<String, Object> map) {
        try {
            return new MnistDataSetIterator(this.batchSize, false, 12345);
        } catch (IOException ioe) {
            throw new RuntimeException(ioe);
        }
    }

    /** @return {@code DataSetIterator.class}, the declared type of the objects this provider yields */
    public Class<?> getDataType() {
        return DataSetIterator.class;
    }

    /** @return the number of training epochs */
    public int getNumEpochs() {
        return this.numEpochs;
    }

    /** @return the minibatch size */
    public int getBatchSize() {
        return this.batchSize;
    }

    /** @return the training-data seed */
    public int getRngSeed() {
        return this.rngSeed;
    }

    /** @param numEpochs the new number of training epochs */
    public void setNumEpochs(int numEpochs) {
        this.numEpochs = numEpochs;
    }

    /** @param batchSize the new minibatch size */
    public void setBatchSize(int batchSize) {
        this.batchSize = batchSize;
    }

    /** @param rngSeed the new training-data seed */
    public void setRngSeed(int rngSeed) {
        this.rngSeed = rngSeed;
    }

    /**
     * Compares by value over {@code numEpochs}, {@code batchSize}, and {@code rngSeed}. It uses
     * the {@link #canEqual(Object)} handshake so that subclasses can opt out of equality with
     * this class.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof MnistDataProvider)) {
            return false;
        }
        MnistDataProvider other = (MnistDataProvider) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        if (getNumEpochs() != other.getNumEpochs()) {
            return false;
        }
        if (getBatchSize() != other.getBatchSize()) {
            return false;
        }
        return getRngSeed() == other.getRngSeed();
    }

    /**
     * Symmetry hook for {@link #equals(Object)}.
     *
     * @param obj the candidate for equality
     * @return whether {@code obj} is a {@code MnistDataProvider}
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof MnistDataProvider;
    }

    /** Polynomial hash with multiplier 59 over the same fields that {@link #equals(Object)} compares. */
    @Override
    public int hashCode() {
        final int multiplier = 59;
        int acc = 1;
        acc = acc * multiplier + getNumEpochs();
        acc = acc * multiplier + getBatchSize();
        acc = acc * multiplier + getRngSeed();
        return acc;
    }

    /** @return a string in the form {@code MnistDataProvider(numEpochs=…, batchSize=…, rngSeed=…)} */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("MnistDataProvider(numEpochs=");
        sb.append(getNumEpochs());
        sb.append(", batchSize=").append(getBatchSize());
        sb.append(", rngSeed=").append(getRngSeed());
        sb.append(")");
        return sb.toString();
    }
}
