package cc.mallet.cluster.neighbor_evaluator;

import cc.mallet.classify.Classifier;
import cc.mallet.cluster.Clustering;
import cc.mallet.cluster.util.PairwiseMatrix;
import cc.mallet.types.MatrixOps;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;

/* loaded from: input_file:cc/mallet/cluster/neighbor_evaluator/PairwiseEvaluator.class */
/**
 * Scores a candidate agglomerative merge by classifying pairs of instances and reducing the
 * resulting per-pair scores to one number with a {@link CombiningStrategy}.
 *
 * <p>Each pair score is the classifier's label-vector value for {@code scoringLabel} on a
 * two-instance {@link AgglomerativeNeighbor}. Pair scores are memoized in {@link #scoreCache}
 * until {@link #reset()} is called.
 */
public class PairwiseEvaluator extends ClassifyingNeighborEvaluator {
    private static final long serialVersionUID = 1;

    /** Reduces the array of pairwise scores for one neighbor to a single score. */
    CombiningStrategy combiningStrategy;

    /**
     * When true, pairs inside each of the two old clusters are scored in addition to the pairs
     * that cross between them.
     */
    boolean mergeFirst;

    /**
     * Lazily allocated memo of pair scores keyed by the two instance indices. It is sized from the
     * first neighbor's original clustering and discarded by {@link #reset()}.
     */
    PairwiseMatrix scoreCache;

    /** Combines pairwise scores with {@link MatrixOps#mean}. */
    public static class Average implements CombiningStrategy, Serializable {
        private static final long serialVersionUID = 1;
        private static final int CURRENT_SERIAL_VERSION = 1;

        @Override
        public double combine(double[] scores) {
            return MatrixOps.mean(scores);
        }

        private void writeObject(ObjectOutputStream out) throws IOException {
            out.defaultWriteObject();
            out.writeInt(CURRENT_SERIAL_VERSION);
        }

        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            in.defaultReadObject();
            in.readInt(); // serial version; only one version exists, so it is read and ignored
        }
    }

    /** Strategy that reduces an array of pairwise scores to a single score. */
    public interface CombiningStrategy {
        /**
         * Combines pairwise scores.
         *
         * @param scores the per-pair scores; never empty when called by this evaluator
         * @return the combined score
         */
        double combine(double[] scores);
    }

    /** Combines pairwise scores with {@link MatrixOps#max}. */
    public static class Maximum implements CombiningStrategy, Serializable {
        private static final long serialVersionUID = 1;
        private static final int CURRENT_SERIAL_VERSION = 1;

        @Override
        public double combine(double[] scores) {
            return MatrixOps.max(scores);
        }

        private void writeObject(ObjectOutputStream out) throws IOException {
            out.defaultWriteObject();
            out.writeInt(CURRENT_SERIAL_VERSION);
        }

        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            in.defaultReadObject();
            in.readInt(); // serial version; only one version exists, so it is read and ignored
        }
    }

    /** Combines pairwise scores with {@link MatrixOps#min}. */
    public static class Minimum implements CombiningStrategy, Serializable {
        private static final long serialVersionUID = 1;
        private static final int CURRENT_SERIAL_VERSION = 1;

        @Override
        public double combine(double[] scores) {
            return MatrixOps.min(scores);
        }

        private void writeObject(ObjectOutputStream out) throws IOException {
            out.defaultWriteObject();
            out.writeInt(CURRENT_SERIAL_VERSION);
        }

        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            in.defaultReadObject();
            in.readInt(); // serial version; only one version exists, so it is read and ignored
        }
    }

    /**
     * Creates a pairwise evaluator.
     *
     * @param classifier classifier applied to each two-instance neighbor
     * @param scoringLabel passed to the superclass constructor; presumably the label whose
     *     probability is used as the pair score (TODO confirm against ClassifyingNeighborEvaluator)
     * @param combiningStrategy reduces the pair scores of one neighbor to a single value
     * @param mergeFirst if true, also score pairs within each of the two old clusters
     */
    public PairwiseEvaluator(Classifier classifier, String scoringLabel, CombiningStrategy combiningStrategy, boolean mergeFirst) {
        super(classifier, scoringLabel);
        this.combiningStrategy = combiningStrategy;
        this.mergeFirst = mergeFirst;
    }

    /**
     * Scores each neighbor independently.
     *
     * @param neighbors neighbors to score; each must be an {@link AgglomerativeNeighbor}
     * @return one score per neighbor, in the same order
     */
    @Override
    public double[] evaluate(Neighbor[] neighbors) {
        double[] scores = new double[neighbors.length];
        for (int i = 0; i < neighbors.length; i++) {
            scores[i] = evaluate(neighbors[i]);
        }
        return scores;
    }

    /**
     * Scores a proposed merge of two clusters.
     *
     * <p>Every instance pair that crosses the two old clusters is scored. When {@link #mergeFirst}
     * is set, every pair within each old cluster is scored as well. The scores are then combined
     * with {@link #combiningStrategy}.
     *
     * @param neighbor the proposed merge; must be an {@link AgglomerativeNeighbor}
     * @return the combined pairwise score
     * @throws IllegalArgumentException if {@code neighbor} is not an AgglomerativeNeighbor
     * @throws IllegalStateException if there are no pairs to score
     */
    @Override
    public double evaluate(Neighbor neighbor) {
        if (!(neighbor instanceof AgglomerativeNeighbor)) {
            throw new IllegalArgumentException("Expect AgglomerativeNeighbor not " + neighbor.getClass().getName());
        }
        AgglomerativeNeighbor merge = (AgglomerativeNeighbor) neighbor;
        Clustering original = neighbor.getOriginal();
        int[][] oldClusters = merge.getOldClusters();
        int[] first = oldClusters[0];
        int[] second = oldClusters[1];

        // The pair count is known up front, so scores go straight into a primitive array
        // (no boxing, no intermediate list).
        int pairCount = first.length * second.length;
        if (this.mergeFirst) {
            pairCount += withinPairCount(first.length) + withinPairCount(second.length);
        }
        if (pairCount < 1) {
            throw new IllegalStateException("No pairs of Instances were scored.");
        }

        double[] scores = new double[pairCount];
        int next = 0;
        for (int a : first) {
            for (int b : second) {
                scores[next++] = getScore(new AgglomerativeNeighbor(original, original, a, b));
            }
        }
        if (this.mergeFirst) {
            next = scoreWithinCluster(original, first, scores, next);
            scoreWithinCluster(original, second, scores, next);
        }
        return this.combiningStrategy.combine(scores);
    }

    /** Number of unordered pairs among {@code n} items. */
    private static int withinPairCount(int n) {
        return n * (n - 1) / 2;
    }

    /**
     * Scores every unordered pair within {@code cluster}, writing scores into {@code scores}
     * starting at {@code offset}.
     *
     * @return the index just past the last score written
     */
    private int scoreWithinCluster(Clustering original, int[] cluster, double[] scores, int offset) {
        int next = offset;
        for (int i = 0; i < cluster.length; i++) {
            for (int j = i + 1; j < cluster.length; j++) {
                scores[next++] = getScore(new AgglomerativeNeighbor(original, original, cluster[i], cluster[j]));
            }
        }
        return next;
    }

    /** Discards all memoized pair scores. */
    @Override
    public void reset() {
        this.scoreCache = null;
    }

    @Override
    public String toString() {
        return "class=" + getClass().getName() + " classifier=" + this.classifier.getClass().getName();
    }

    /**
     * Returns the memoized score for a two-instance neighbor, classifying it on a cache miss.
     *
     * @param pair a neighbor whose new cluster holds exactly the two instances being scored
     * @return the classifier's value for {@code scoringLabel} on {@code pair}
     */
    private double getScore(AgglomerativeNeighbor pair) {
        if (this.scoreCache == null) {
            this.scoreCache = new PairwiseMatrix(pair.getOriginal().getNumInstances());
        }
        int[] instances = pair.getNewCluster();
        int a = instances[0];
        int b = instances[1];
        // NOTE(review): 0.0 doubles as the "not cached" marker (presumably a fresh PairwiseMatrix
        // reads 0.0 everywhere). A pair whose true score is exactly 0.0 is therefore re-classified
        // on every lookup. Assuming a deterministic classifier, this costs time but does not
        // change the result.
        if (this.scoreCache.get(a, b) == 0.0d) {
            this.scoreCache.set(a, b, this.classifier.classify(pair).getLabelVector().value(this.scoringLabel));
        }
        return this.scoreCache.get(a, b);
    }
}
