package org.deeplearning4j.arbiter.scoring.graph;

import java.util.Map;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/* loaded from: input_file:org/deeplearning4j/arbiter/scoring/graph/BaseGraphTestSetEvaluationScoreFunction.class */
/**
 * Base class for Arbiter score functions that evaluate a single-output {@link ComputationGraph}
 * against the test data supplied by a {@link DataProvider} of {@link MultiDataSetIterator}s.
 *
 * <p>This class provides {@link #getEvaluation} to build a classification {@link Evaluation} over
 * the whole test set. Concrete subclasses supply the remaining {@link ScoreFunction} methods and
 * {@link #toString()}. Presumably the subclass derives its score from the returned
 * {@code Evaluation}; confirm this against the subclasses.
 */
public abstract class BaseGraphTestSetEvaluationScoreFunction implements ScoreFunction<ComputationGraph, MultiDataSetIterator> {

    /**
     * Runs the graph over every test {@link MultiDataSet} and accumulates an {@link Evaluation}.
     *
     * <p>Batches with mask arrays have their feature/label masks applied to the graph for the
     * duration of the forward pass and are evaluated as time series. Batches without masks are
     * evaluated as time series when the labels are rank 3, and as plain classification otherwise.
     *
     * <p>Decompiler note: this method was originally declared {@code protected}; it is kept
     * {@code public} here so existing callers of the decompiled code keep compiling.
     *
     * @param computationGraph the model to evaluate; must have exactly one output array. Its layer
     *     mask arrays are modified temporarily and always cleared again for masked batches.
     * @param dataProvider supplies the test iterator via {@code testData(map)}
     * @param map data parameters forwarded unchanged to {@code dataProvider.testData}
     * @return the evaluation accumulated over all test batches
     * @throws IllegalStateException if the graph does not have exactly one output array
     */
    public Evaluation getEvaluation(ComputationGraph computationGraph, DataProvider<MultiDataSetIterator> dataProvider, Map<String, Object> map) {
        int numOutputs = computationGraph.getNumOutputArrays();
        if (numOutputs != 1) {
            throw new IllegalStateException(getClass().getSimpleName()
                    + " can only be applied to ComputationGraphs with exactly one output. NumOutputs = "
                    + numOutputs);
        }
        MultiDataSetIterator testIterator = (MultiDataSetIterator) dataProvider.testData(map);
        Evaluation evaluation = new Evaluation();
        while (testIterator.hasNext()) {
            MultiDataSet batch = (MultiDataSet) testIterator.next();
            INDArray labels = batch.getLabels(0);
            if (batch.hasMaskArrays()) {
                INDArray[] featuresMaskArrays = batch.getFeaturesMaskArrays();
                INDArray[] labelsMaskArrays = batch.getLabelsMaskArrays();
                computationGraph.setLayerMaskArrays(featuresMaskArrays, labelsMaskArrays);
                try {
                    // Explicit inference mode (train = false), matching the unmasked branch below.
                    INDArray predictions = computationGraph.output(false, batch.getFeatures())[0];
                    if (labelsMaskArrays != null) {
                        evaluation.evalTimeSeries(labels, predictions, labelsMaskArrays[0]);
                    } else {
                        evaluation.evalTimeSeries(labels, predictions);
                    }
                } finally {
                    // Always clear the masks: they are state on the shared graph, and leaving
                    // them set after a failure would corrupt later forward passes.
                    computationGraph.clearLayerMaskArrays();
                }
            } else {
                INDArray predictions = computationGraph.output(false, batch.getFeatures())[0];
                if (labels.rank() == 3) {
                    // Rank-3 labels are treated as time-series data.
                    evaluation.evalTimeSeries(labels, predictions);
                } else {
                    evaluation.eval(labels, predictions);
                }
            }
        }
        return evaluation;
    }

    /**
     * Reports the optimization direction for this score.
     *
     * @return always {@code false}: the score is not to be minimized
     */
    public boolean minimize() {
        return false;
    }

    /**
     * Returns a human-readable description of the concrete score function.
     *
     * @return description supplied by the subclass
     */
    @Override
    public abstract String toString();
}
