package edu.umass.cs.mallet.base.classify;

import edu.umass.cs.mallet.base.fst.Transducer;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.util.MalletLogger;
import edu.umass.cs.mallet.base.util.Maths;
import java.util.Random;
import java.util.logging.Logger;

/* loaded from: input_file:edu/umass/cs/mallet/base/classify/AdaBoostTrainer.class */
/**
 * Trains an {@link AdaBoost} ensemble by repeatedly training a weak learner on
 * weighted resamples of the training data.
 *
 * <p>Each round works as follows:
 * <ol>
 *   <li>Draw a resample according to the current instance weights.</li>
 *   <li>Train the weak learner on that resample.</li>
 *   <li>Measure its weighted training error on the full training list.</li>
 *   <li>Give the classifier the weight {@code log((1 - err) / err)}.</li>
 *   <li>Scale the weights of correctly classified instances by {@code err / (1 - err)} and renormalize.</li>
 * </ol>
 *
 * <p>Boosting stops early when the weighted error is (almost) zero or exceeds 0.5.
 */
public class AdaBoostTrainer extends ClassifierTrainer {
    private static final Logger logger = MalletLogger.getLogger(AdaBoostTrainer.class.getName());

    /**
     * Maximum number of resamples drawn within one round while the weak learner
     * keeps achieving (almost) zero weighted error.
     */
    private static final int MAX_NUM_RESAMPLING_ITERATIONS = 10;

    /** Trainer for the weak classifiers; the constructor requires it to be {@code Boostable}. */
    ClassifierTrainer weakLearner;

    /** Maximum number of boosting rounds (always positive when set by a constructor). */
    int numRounds;

    /**
     * Creates a trainer that boosts {@code weakLearner} for up to {@code numRounds} rounds.
     *
     * @param weakLearner trainer for the weak classifiers; must implement {@code Boostable}
     * @param numRounds maximum number of boosting rounds; must be positive
     * @throws IllegalArgumentException if the weak learner is not boostable or
     *     {@code numRounds <= 0}
     */
    public AdaBoostTrainer(ClassifierTrainer weakLearner, int numRounds) {
        if (!(weakLearner instanceof Boostable)) {
            throw new IllegalArgumentException("weak learner not boostable");
        }
        if (numRounds <= 0) {
            throw new IllegalArgumentException("number of rounds must be positive");
        }
        this.weakLearner = weakLearner;
        this.numRounds = numRounds;
    }

    /**
     * Creates a trainer that boosts {@code weakLearner} for up to 100 rounds.
     *
     * @param weakLearner trainer for the weak classifiers; must implement {@code Boostable}
     */
    public AdaBoostTrainer(ClassifierTrainer weakLearner) {
        this(weakLearner, 100);
    }

    /**
     * Runs AdaBoost on {@code trainingList}.
     *
     * @param trainingList instances to boost on; must be non-empty and have no feature selection
     * @param validationList passed through unchanged to the weak learner on every round
     * @param testSet not used by this trainer
     * @param evaluator not used by this trainer
     * @param initialClassifier not used by this trainer
     * @return the boosted ensemble
     * @throws UnsupportedOperationException if {@code trainingList} has a feature selection
     * @throws IllegalArgumentException if {@code trainingList} is empty
     */
    @Override // edu.umass.cs.mallet.base.classify.ClassifierTrainer
    public Classifier train(InstanceList trainingList, InstanceList validationList, InstanceList testSet, ClassifierEvaluating evaluator, Classifier initialClassifier) {
        if (trainingList.getFeatureSelection() != null) {
            throw new UnsupportedOperationException("FeatureSelection not yet implemented.");
        }
        int numInstances = trainingList.size();
        if (numInstances == 0) {
            // Otherwise the uniform weight below would be 1/0 = Infinity.
            throw new IllegalArgumentException("AdaBoostTrainer.train: training list must not be empty");
        }
        Random random = new Random();

        // Start from the uniform distribution over the training instances.
        double initialWeight = 1.0d / numInstances;
        InstanceList weightedInsts = new InstanceList();
        for (int i = 0; i < numInstances; i++) {
            weightedInsts.add(trainingList.getInstance(i), initialWeight);
        }

        boolean[] correct = new boolean[weightedInsts.size()];
        int numClasses = weightedInsts.getTargetAlphabet().size();
        if (numClasses != 2) {
            logger.info("AdaBoostTrainer.train: WARNING: numClasses=" + numClasses + " (AdaBoost expects exactly two classes)");
        }

        Classifier[] weakClassifiers = new Classifier[this.numRounds];
        double[] alphas = new double[this.numRounds];
        // Initialised so the final getPipe() call has a receiver even if numRounds was set to 0.
        InstanceList roundSample = new InstanceList();

        for (int round = 0; round < this.numRounds; round++) {
            logger.info("===========  AdaBoostTrainer round " + (round + 1) + " begin");
            double err;
            int resamplings = 0;
            // Resample while the weak learner fits the data perfectly, up to a cap.
            do {
                roundSample = weightedInsts.sampleWithInstanceWeights(random);
                weakClassifiers[round] = this.weakLearner.train(roundSample, validationList);
                err = weightedError(weakClassifiers[round], weightedInsts, correct);
                resamplings++;
            } while (Maths.almostEquals(err, Transducer.ZERO_COST) && resamplings < MAX_NUM_RESAMPLING_ITERATIONS);

            // Stop when the error is (almost) zero, which would make log((1-err)/err) blow up,
            // or when the weak classifier is worse than 0.5.
            if (Maths.almostEquals(err, Transducer.ZERO_COST) || err > 0.5d) {
                logger.info("AdaBoostTrainer stopped at " + (round + 1) + " / " + this.numRounds + " rounds: numClasses=" + numClasses + " error=" + err);
                return truncatedEnsemble(roundSample, weakClassifiers, alphas, round);
            }

            alphas[round] = Math.log((1.0d - err) / err);
            reweight(weightedInsts, correct, err / (1.0d - err));
            logger.info("===========  AdaBoostTrainer round " + (round + 1) + " finished, weak classifier training error = " + err);
        }
        logWeights(alphas);
        return new AdaBoost(roundSample.getPipe(), weakClassifiers, alphas);
    }

    /**
     * Classifies every instance of {@code insts} with {@code classifier}, records
     * per-instance correctness into {@code correct}, and returns the summed
     * instance weight of the misclassified instances.
     */
    private static double weightedError(Classifier classifier, InstanceList insts, boolean[] correct) {
        double err = 0.0d;
        for (int i = 0; i < insts.size(); i++) {
            correct[i] = classifier.classify(insts.getInstance(i)).bestLabelIsCorrect();
            if (!correct[i]) {
                err += insts.getInstanceWeight(i);
            }
        }
        return err;
    }

    /**
     * Multiplies the weight of every correctly classified instance by {@code factor},
     * then renormalizes all weights so that they sum to one.
     */
    private static void reweight(InstanceList insts, boolean[] correct, double factor) {
        double sum = 0.0d;
        for (int i = 0; i < insts.size(); i++) {
            double w = insts.getInstanceWeight(i);
            if (correct[i]) {
                w *= factor;
            }
            insts.setInstanceWeight(i, w);
            sum += w;
        }
        for (int i = 0; i < insts.size(); i++) {
            insts.setInstanceWeight(i, insts.getInstanceWeight(i) / sum);
        }
    }

    /**
     * Builds the ensemble returned when boosting stops early at {@code round}.
     *
     * <p>The classifier trained in the stopping round is dropped, so only
     * classifiers {@code 0..round-1} are kept. The exception is round 0: the
     * first classifier is kept anyway and given weight 1, because there is
     * nothing else to use.
     */
    private static Classifier truncatedEnsemble(InstanceList sample, Classifier[] weakClassifiers, double[] alphas, int round) {
        int numToUse = round == 0 ? 1 : round;
        if (round == 0) {
            alphas[0] = 1.0d;
        }
        double[] keptAlphas = new double[numToUse];
        Classifier[] keptClassifiers = new Classifier[numToUse];
        System.arraycopy(alphas, 0, keptAlphas, 0, numToUse);
        System.arraycopy(weakClassifiers, 0, keptClassifiers, 0, numToUse);
        logWeights(keptAlphas);
        return new AdaBoost(sample.getPipe(), keptClassifiers, keptAlphas);
    }

    /** Logs the weight assigned to each weak classifier. */
    private static void logWeights(double[] alphas) {
        for (int i = 0; i < alphas.length; i++) {
            logger.info("AdaBoostTrainer weight[weakLearner[" + i + "]]=" + alphas[i]);
        }
    }
}
