package org.tribuo.common.xgboost;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.logging.Logger;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.WeightedExamples;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorIterator;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.SkeletalTrainerProvenance;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/common/xgboost/XGBoostTrainer.class */
/**
 * Base class for trainers which fit models using the XGBoost library (via xgboost4j).
 * <p>
 * Stores the shared XGBoost hyperparameters as configurable fields, translates them into the
 * XGBoost parameter map in {@link #postConfig()}, and provides static helpers which convert
 * Tribuo {@link Example}s and {@link SparseVector}s into CSR-format {@link DMatrix} instances.
 *
 * @param <T> The output type.
 */
public abstract class XGBoostTrainer<T extends Output<T>> implements Trainer<T>, WeightedExamples {
    private static final Logger logger = Logger.getLogger(XGBoostTrainer.class.getName());

    /**
     * The XGBoost parameter map, filled in from the configurable fields by {@link #postConfig()}.
     */
    protected final Map<String, Object> parameters = new HashMap<>();

    @Config(description = "Override for parameters, if used must contain all the relevant parameters, including the objective")
    protected Map<String, String> overrideParameters = new HashMap<>();

    @Config(mandatory = true, description = "The number of trees to build.")
    protected int numTrees;

    @Config(description = "The learning rate, shrinks the new tree output to prevent overfitting.")
    private double eta = 0.3;

    @Config(description = "Minimum loss reduction needed to split a tree node.")
    private double gamma = 0.0;

    @Config(description = "The maximum depth of any tree.")
    private int maxDepth = 6;

    @Config(description = "The minimum weight in each child node before a split is valid.")
    private double minChildWeight = 1.0;

    @Config(description = "Independently subsample the examples for each tree.")
    private double subsample = 1.0;

    @Config(description = "Independently subsample the features available for each node of each tree.")
    private double featureSubsample = 1.0;

    @Config(description = "l2 regularisation term on the weights.")
    private double lambda = 1.0;

    // NOTE(review): this field default (1.0) differs from the 0.0 that the convenience constructors
    // pass for alpha, so config-driven and programmatically built trainers get different L1 penalties.
    // Confirm which value is intended.
    @Config(description = "l1 regularisation term on the weights.")
    private double alpha = 1.0;

    @Config(description = "The number of threads to use at training time.")
    private int nThread = 4;

    /**
     * Legacy quiet flag.
     *
     * @deprecated Superseded by {@code verbosity}. While this is 1 (the field default),
     * {@link #postConfig()} forces XGBoost's verbosity to 0 regardless of {@code verbosity}.
     */
    @Config(description = "Quiesce all the logging output from the XGBoost C library. Deprecated in favour of 'verbosity'.")
    @Deprecated
    private int silent = 1;

    @Config(description = "Logging verbosity, 0 is silent, 3 is debug.")
    private LoggingVerbosity verbosity = LoggingVerbosity.SILENT;

    @Config(description = "Type of the weak learner.")
    private BoosterType booster = BoosterType.GBTREE;

    @Config(description = "The tree building algorithm to use.")
    private TreeMethod treeMethod = TreeMethod.AUTO;

    @Config(description = "The RNG seed.")
    private long seed = 12345L;

    /** Counter exposed through {@link #getInvocationCount()} and {@link #setInvocationCount(int)}. */
    protected int trainInvocationCounter = 0;

    /**
     * The weak learner types, mapped to XGBoost's {@code booster} parameter values.
     */
    public enum BoosterType {
        /** {@code gblinear}. */
        LINEAR("gblinear"),
        /** {@code gbtree}. */
        GBTREE("gbtree"),
        /** {@code dart}. */
        DART("dart");

        /** The value written to XGBoost's {@code booster} parameter. */
        public final String paramName;

        BoosterType(String paramName) {
            this.paramName = paramName;
        }
    }

    /**
     * A {@link DMatrix} bundled with per-row valid feature counts and the examples it was built from.
     *
     * @param <T> The output type.
     */
    public static class DMatrixTuple<T extends Output<T>> {
        /** The converted data. */
        public final DMatrix data;
        /** For each row, the number of distinct feature IDs found in the feature map. */
        public final int[] numValidFeatures;
        /** The source examples, in row order. */
        public final Example<T>[] examples;

        protected DMatrixTuple(DMatrix data, int[] numValidFeatures, Example<T>[] examples) {
            this.data = data;
            this.numValidFeatures = numValidFeatures;
            this.examples = examples;
        }
    }

    /**
     * XGBoost logging levels, mapped to the integer {@code verbosity} parameter.
     */
    public enum LoggingVerbosity {
        SILENT(0),
        WARNING(1),
        INFO(2),
        DEBUG(3);

        /** The value written to XGBoost's {@code verbosity} parameter. */
        public final int value;

        LoggingVerbosity(int value) {
            this.value = value;
        }
    }

    /**
     * Tree construction algorithms, mapped to XGBoost's {@code tree_method} parameter values.
     */
    public enum TreeMethod {
        AUTO("auto"),
        EXACT("exact"),
        APPROX("approx"),
        HIST("hist"),
        GPU_HIST("gpu_hist");

        /** The value written to XGBoost's {@code tree_method} parameter. */
        public final String paramName;

        TreeMethod(String paramName) {
            this.paramName = paramName;
        }
    }

    /**
     * Provenance for {@link XGBoostTrainer}.
     *
     * @deprecated Marked deprecated in this codebase; NOTE(review): confirm the intended replacement
     * provenance type before using this in new code.
     */
    @Deprecated
    protected static class XGBoostTrainerProvenance extends SkeletalTrainerProvenance {
        private static final long serialVersionUID = 1;

        protected <T extends Output<T>> XGBoostTrainerProvenance(XGBoostTrainer<T> host) {
            super(host);
        }

        protected XGBoostTrainerProvenance(Map<String, Provenance> map) {
            super(map);
        }
    }

    /**
     * Creates a trainer with the given number of trees and default values for everything else.
     * The defaults include 4 threads, silent logging and seed 12345.
     *
     * @param numTrees The number of trees to build, must be positive.
     */
    protected XGBoostTrainer(int numTrees) {
        this(numTrees, 0.3, 0.0, 6, 1.0, 1.0, 1.0, 1.0, 0.0, 4, true, 12345L);
    }

    /**
     * Creates a trainer with the given number of trees, thread count and quiet flag.
     * All other parameters use their defaults.
     *
     * @param numTrees   The number of trees to build, must be positive.
     * @param numThreads The number of training threads.
     * @param silent     If true, logging is {@link LoggingVerbosity#SILENT}, otherwise {@link LoggingVerbosity#INFO}.
     */
    protected XGBoostTrainer(int numTrees, int numThreads, boolean silent) {
        this(numTrees, 0.3, 0.0, 6, 1.0, 1.0, 1.0, 1.0, 0.0, numThreads, silent, 12345L);
    }

    /**
     * Creates a {@link BoosterType#GBTREE} trainer using {@link TreeMethod#AUTO}.
     *
     * @param numTrees         The number of trees to build, must be positive.
     * @param eta              The learning rate.
     * @param gamma            Minimum loss reduction needed to split a node.
     * @param maxDepth         Maximum tree depth.
     * @param minChildWeight   Minimum child weight for a split.
     * @param subsample        Example subsampling rate per tree.
     * @param featureSubsample Feature subsampling rate (written as {@code colsample_bytree}).
     * @param lambda           L2 regularisation.
     * @param alpha            L1 regularisation.
     * @param nThread          The number of training threads.
     * @param silent           If true, logging is {@link LoggingVerbosity#SILENT}, otherwise {@link LoggingVerbosity#INFO}.
     * @param seed             The RNG seed.
     */
    protected XGBoostTrainer(int numTrees, double eta, double gamma, int maxDepth, double minChildWeight,
                             double subsample, double featureSubsample, double lambda, double alpha,
                             int nThread, boolean silent, long seed) {
        this(BoosterType.GBTREE, TreeMethod.AUTO, numTrees, eta, gamma, maxDepth, minChildWeight, subsample,
                featureSubsample, lambda, alpha, nThread,
                silent ? LoggingVerbosity.SILENT : LoggingVerbosity.INFO, seed);
    }

    /**
     * Creates a trainer with every parameter specified explicitly.
     *
     * @param boosterType      The weak learner type.
     * @param treeMethod       The tree construction algorithm.
     * @param numTrees         The number of trees to build, must be positive.
     * @param eta              The learning rate.
     * @param gamma            Minimum loss reduction needed to split a node.
     * @param maxDepth         Maximum tree depth.
     * @param minChildWeight   Minimum child weight for a split.
     * @param subsample        Example subsampling rate per tree.
     * @param featureSubsample Feature subsampling rate (written as {@code colsample_bytree}).
     * @param lambda           L2 regularisation.
     * @param alpha            L1 regularisation.
     * @param nThread          The number of training threads.
     * @param verbosity        The XGBoost logging level.
     * @param seed             The RNG seed.
     * @throws IllegalArgumentException If {@code numTrees < 1}.
     */
    protected XGBoostTrainer(BoosterType boosterType, TreeMethod treeMethod, int numTrees, double eta, double gamma,
                             int maxDepth, double minChildWeight, double subsample, double featureSubsample,
                             double lambda, double alpha, int nThread, LoggingVerbosity verbosity, long seed) {
        if (numTrees < 1) {
            throw new IllegalArgumentException("Must supply a positive number of trees. Received " + numTrees);
        }
        this.booster = boosterType;
        this.treeMethod = treeMethod;
        this.numTrees = numTrees;
        this.eta = eta;
        this.gamma = gamma;
        this.maxDepth = maxDepth;
        this.minChildWeight = minChildWeight;
        this.subsample = subsample;
        this.featureSubsample = featureSubsample;
        this.lambda = lambda;
        this.alpha = alpha;
        this.nThread = nThread;
        this.verbosity = verbosity;
        // Clear the deprecated flag so postConfig() honours the explicit verbosity argument.
        this.silent = 0;
        this.seed = seed;
    }

    /**
     * Creates a trainer whose XGBoost parameters come entirely from {@code overrides}.
     * Each value is stored via {@code toString()} in {@link #overrideParameters}.
     *
     * @param numTrees  The number of trees to build, must be positive.
     * @param overrides The XGBoost parameters; {@link #postConfig()} requires an {@code objective} key.
     * @throws IllegalArgumentException If {@code numTrees < 1}.
     */
    protected XGBoostTrainer(int numTrees, Map<String, Object> overrides) {
        if (numTrees < 1) {
            throw new IllegalArgumentException("Must supply a positive number of trees. Received " + numTrees);
        }
        this.numTrees = numTrees;
        for (Map.Entry<String, Object> entry : overrides.entrySet()) {
            this.overrideParameters.put(entry.getKey(), entry.getValue().toString());
        }
    }

    /**
     * No-argument constructor; every field keeps its declared default until configured.
     */
    protected XGBoostTrainer() { }

    /**
     * Copies the configurable fields into {@link #parameters} using XGBoost's parameter names.
     *
     * @throws PropertyException If {@link #overrideParameters} is non-empty but has no {@code objective}.
     */
    public void postConfig() {
        parameters.put("eta", eta);
        parameters.put("gamma", gamma);
        parameters.put("max_depth", maxDepth);
        parameters.put("min_child_weight", minChildWeight);
        parameters.put("subsample", subsample);
        parameters.put("colsample_bytree", featureSubsample);
        parameters.put("lambda", lambda);
        parameters.put("alpha", alpha);
        parameters.put("nthread", nThread);
        parameters.put("seed", seed);
        // The deprecated 'silent' flag wins while it is 1 (its field default). The no-arg constructor
        // leaves it at 1, so a configured 'verbosity' only takes effect if 'silent' is also set to 0.
        if (silent == 1) {
            parameters.put("verbosity", 0);
        } else {
            parameters.put("verbosity", verbosity.value);
        }
        parameters.put("booster", booster.paramName);
        parameters.put("tree_method", treeMethod.paramName);
        if (!overrideParameters.isEmpty() && !overrideParameters.containsKey("objective")) {
            throw new PropertyException("", "overrideParameters", "When using the override parameters must supply an objective");
        }
    }

    @Override
    public String toString() {
        return "XGBoostTrainer(numTrees=" + numTrees + ",parameters=" + parameters.toString() + ")";
    }

    /**
     * Constructs the {@link XGBoostModel} produced by training.
     *
     * @param name         The model name.
     * @param provenance   The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The output domain.
     * @param models       The trained boosters.
     * @param converter    Converts raw XGBoost predictions into outputs.
     * @return A new model.
     */
    protected XGBoostModel<T> createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap,
                                          ImmutableOutputInfo<T> outputIDInfo, List<Booster> models,
                                          XGBoostOutputConverter<T> converter) {
        return new XGBoostModel<>(name, provenance, featureIDMap, outputIDInfo, models, converter);
    }

    /**
     * Makes a mutable shallow copy of a parameter map.
     *
     * @param params The map to copy.
     * @return A new {@link HashMap} with the same entries.
     */
    protected Map<String, Object> copyParams(Map<String, ?> params) {
        return new HashMap<>(params);
    }

    /**
     * @return The current value of {@link #trainInvocationCounter}.
     */
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

    /**
     * Sets {@link #trainInvocationCounter}.
     *
     * @param invocationCount The new count, must be non-negative.
     * @throws IllegalArgumentException If {@code invocationCount < 0}.
     */
    public synchronized void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        this.trainInvocationCounter = invocationCount;
    }

    /**
     * Converts a dataset into a labelled DMatrix.
     *
     * @param dataset           The dataset.
     * @param responseExtractor Maps each output to its float label.
     * @param <T>               The output type.
     * @return The converted data.
     * @throws XGBoostError If XGBoost fails to build the matrix.
     */
    protected static <T extends Output<T>> DMatrixTuple<T> convertDataset(Dataset<T> dataset, Function<T, Float> responseExtractor) throws XGBoostError {
        return convertExamples(dataset.getData(), dataset.getFeatureIDMap(), responseExtractor);
    }

    /**
     * Converts a dataset into an unlabelled DMatrix.
     *
     * @param dataset The dataset.
     * @param <T>     The output type.
     * @return The converted data.
     * @throws XGBoostError If XGBoost fails to build the matrix.
     */
    protected static <T extends Output<T>> DMatrixTuple<T> convertDataset(Dataset<T> dataset) throws XGBoostError {
        return convertExamples(dataset.getData(), dataset.getFeatureIDMap(), null);
    }

    /**
     * Converts examples into an unlabelled DMatrix.
     *
     * @param examples   The examples.
     * @param featureMap The feature domain.
     * @param <T>        The output type.
     * @return The converted data.
     * @throws XGBoostError If XGBoost fails to build the matrix.
     */
    public static <T extends Output<T>> DMatrixTuple<T> convertExamples(Iterable<Example<T>> examples, ImmutableFeatureMap featureMap) throws XGBoostError {
        return convertExamples(examples, featureMap, null);
    }

    /**
     * Converts examples into a CSR-format DMatrix, one row per example.
     * <p>
     * If {@code responseExtractor} is non-null, labels and example weights are attached as well.
     *
     * @param examples          The examples.
     * @param featureMap        The feature domain; its size is the column count.
     * @param responseExtractor Maps each output to its float label, or {@code null} for no labels.
     * @param <T>               The output type.
     * @return The converted data, per-row valid feature counts, and the examples in row order.
     * @throws XGBoostError If XGBoost fails to build or label the matrix.
     * @throws IllegalArgumentException If an example has no features in {@code featureMap}.
     */
    protected static <T extends Output<T>> DMatrixTuple<T> convertExamples(Iterable<Example<T>> examples, ImmutableFeatureMap featureMap, Function<T, Float> responseExtractor) throws XGBoostError {
        final boolean hasLabels = responseExtractor != null;
        List<Float> labels = new ArrayList<>();
        List<Float> weights = new ArrayList<>();
        ArrayList<Float> values = new ArrayList<>();
        ArrayList<Integer> indices = new ArrayList<>();
        ArrayList<Long> rowHeaders = new ArrayList<>();
        List<Integer> numValidFeatures = new ArrayList<>();
        List<Example<T>> exampleList = new ArrayList<>();
        long numElements = 0;
        rowHeaders.add(0L);
        for (Example<T> example : examples) {
            if (hasLabels) {
                labels.add(responseExtractor.apply(example.getOutput()));
                weights.add(example.getWeight());
            }
            exampleList.add(example);
            long newNumElements = convertSingleExample(example, featureMap, values, indices, rowHeaders, numElements);
            numValidFeatures.add((int) (newNumElements - numElements));
            numElements = newNumElements;
        }
        DMatrix dMatrix = new DMatrix(Util.toPrimitiveLong(rowHeaders), Util.toPrimitiveInt(indices),
                Util.toPrimitiveFloat(values), DMatrix.SparseType.CSR, featureMap.size());
        if (hasLabels) {
            try {
                dMatrix.setLabel(Util.toPrimitiveFloat(labels));
                dMatrix.setWeight(Util.toPrimitiveFloat(weights));
            } catch (XGBoostError | RuntimeException e) {
                // Release the native matrix instead of leaking it when labelling fails.
                dMatrix.dispose();
                throw e;
            }
        }
        @SuppressWarnings("unchecked") // Generic arrays cannot be created directly; the elements are all Example<T>.
        Example<T>[] exampleArray = (Example<T>[]) exampleList.toArray(new Example[0]);
        return new DMatrixTuple<>(dMatrix, Util.toPrimitiveInt(numValidFeatures), exampleArray);
    }

    /**
     * Converts a single example into an unlabelled single-row DMatrix.
     *
     * @param example    The example.
     * @param featureMap The feature domain.
     * @param <T>        The output type.
     * @return The converted data.
     * @throws XGBoostError If XGBoost fails to build the matrix.
     */
    public static <T extends Output<T>> DMatrixTuple<T> convertExample(Example<T> example, ImmutableFeatureMap featureMap) throws XGBoostError {
        return convertExample(example, featureMap, null);
    }

    /**
     * Converts a single example into a single-row CSR-format DMatrix.
     *
     * @param example           The example.
     * @param featureMap        The feature domain; its size is the column count.
     * @param responseExtractor Maps the output to its float label, or {@code null} for no label.
     * @param <T>               The output type.
     * @return The converted data.
     * @throws XGBoostError If XGBoost fails to build or label the matrix.
     * @throws IllegalArgumentException If the example has no features in {@code featureMap}.
     */
    protected static <T extends Output<T>> DMatrixTuple<T> convertExample(Example<T> example, ImmutableFeatureMap featureMap, Function<T, Float> responseExtractor) throws XGBoostError {
        final boolean hasLabels = responseExtractor != null;
        ArrayList<Float> values = new ArrayList<>();
        ArrayList<Integer> indices = new ArrayList<>();
        ArrayList<Long> rowHeaders = new ArrayList<>();
        rowHeaders.add(0L);
        // The row starts at offset 0, so the returned end offset equals its number of valid features.
        long numValidFeatures = convertSingleExample(example, featureMap, values, indices, rowHeaders, 0L);
        DMatrix dMatrix = new DMatrix(Util.toPrimitiveLong(rowHeaders), Util.toPrimitiveInt(indices),
                Util.toPrimitiveFloat(values), DMatrix.SparseType.CSR, featureMap.size());
        if (hasLabels) {
            try {
                dMatrix.setLabel(new float[]{responseExtractor.apply(example.getOutput())});
                dMatrix.setWeight(new float[]{example.getWeight()});
            } catch (XGBoostError | RuntimeException e) {
                // Release the native matrix instead of leaking it when labelling fails.
                dMatrix.dispose();
                throw e;
            }
        }
        @SuppressWarnings("unchecked") // Generic arrays cannot be created directly; the only element is an Example<T>.
        Example<T>[] exampleArray = (Example<T>[]) new Example[]{example};
        return new DMatrixTuple<>(dMatrix, new int[]{(int) numValidFeatures}, exampleArray);
    }

    /**
     * Appends one example as a CSR row to the supplied buffers.
     * <p>
     * Feature IDs are kept sorted within the row. Features with negative IDs are skipped; these are
     * presumably names absent from {@code featureMap}. Repeated IDs have their values summed.
     *
     * @param example     The example to convert.
     * @param featureMap  Maps feature names to column indices.
     * @param dataList    Output buffer of feature values.
     * @param indicesList Output buffer of column indices, parallel to {@code dataList}.
     * @param headersList Output buffer of row end offsets; this row's end offset is appended.
     * @param header      The start offset of this row in {@code dataList}.
     * @param <T>         The output type.
     * @return The end offset of this row (start offset plus the number of stored features).
     * @throws IllegalArgumentException If no feature of the example is in {@code featureMap}.
     */
    protected static <T extends Output<T>> long convertSingleExample(Example<T> example, ImmutableFeatureMap featureMap, ArrayList<Float> dataList, ArrayList<Integer> indicesList, ArrayList<Long> headersList, long header) {
        int numValidFeatures = 0;
        int lastID = -1;
        // Elements already in the buffers belong to earlier rows; this row starts here.
        int rowStart = indicesList.size();
        for (Feature feature : example) {
            int id = featureMap.getID(feature.getName());
            if (id > lastID) {
                // In-order feature: append at the end of the row.
                lastID = id;
                dataList.add((float) feature.getValue());
                indicesList.add(id);
                numValidFeatures++;
            } else if (id > -1) {
                // Out-of-order or repeated feature: search this row's (sorted) span for it.
                int searchIndex = Util.binarySearch(indicesList, id, rowStart, numValidFeatures + rowStart);
                if (searchIndex < 0) {
                    int insertionPoint = -(searchIndex + 1);
                    indicesList.add(insertionPoint, id);
                    dataList.add(insertionPoint, (float) feature.getValue());
                    numValidFeatures++;
                } else {
                    dataList.set(searchIndex, dataList.get(searchIndex) + (float) feature.getValue());
                }
            }
        }
        if (numValidFeatures == 0) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }
        long rowEnd = header + numValidFeatures;
        headersList.add(rowEnd);
        return rowEnd;
    }

    /**
     * Appends every stored element of a sparse vector as a CSR row to the supplied buffers.
     *
     * @param vector      The vector to convert.
     * @param dataList    Output buffer of values.
     * @param indicesList Output buffer of indices, parallel to {@code dataList}.
     * @param headersList Output buffer of row end offsets; this row's end offset is appended.
     * @param header      The start offset of this row in {@code dataList}.
     * @return The end offset of this row.
     */
    static long convertSingleExample(SparseVector vector, ArrayList<Float> dataList, ArrayList<Integer> indicesList, ArrayList<Long> headersList, long header) {
        int numElements = 0;
        VectorIterator itr = vector.iterator();
        while (itr.hasNext()) {
            VectorTuple tuple = itr.next();
            dataList.add((float) tuple.value);
            indicesList.add(tuple.index);
            numElements++;
        }
        long rowEnd = header + numElements;
        headersList.add(rowEnd);
        return rowEnd;
    }

    /**
     * Converts a sparse vector into a single-row CSR-format DMatrix with {@code vector.size()} columns.
     *
     * @param vector The vector.
     * @return The DMatrix.
     * @throws XGBoostError If XGBoost fails to build the matrix.
     */
    public static DMatrix convertSparseVector(SparseVector vector) throws XGBoostError {
        ArrayList<Float> values = new ArrayList<>();
        ArrayList<Long> rowHeaders = new ArrayList<>();
        ArrayList<Integer> indices = new ArrayList<>();
        rowHeaders.add(0L);
        convertSingleExample(vector, values, indices, rowHeaders, 0L);
        return new DMatrix(Util.toPrimitiveLong(rowHeaders), Util.toPrimitiveInt(indices),
                Util.toPrimitiveFloat(values), DMatrix.SparseType.CSR, vector.size());
    }

    /**
     * Converts sparse vectors into a CSR-format DMatrix, one row per vector.
     * <p>
     * The column count is taken from the last vector's size (0 for an empty list), so all vectors
     * are assumed to share one dimension.
     *
     * @param vectors The vectors.
     * @return The DMatrix.
     * @throws XGBoostError If XGBoost fails to build the matrix.
     */
    public static DMatrix convertSparseVectors(List<SparseVector> vectors) throws XGBoostError {
        ArrayList<Float> values = new ArrayList<>();
        ArrayList<Long> rowHeaders = new ArrayList<>();
        ArrayList<Integer> indices = new ArrayList<>();
        int numColumns = 0;
        long numElements = 0;
        rowHeaders.add(0L);
        for (SparseVector vector : vectors) {
            numElements = convertSingleExample(vector, values, indices, rowHeaders, numElements);
            numColumns = vector.size();
        }
        return new DMatrix(Util.toPrimitiveLong(rowHeaders), Util.toPrimitiveInt(indices),
                Util.toPrimitiveFloat(values), DMatrix.SparseType.CSR, numColumns);
    }
}
