package com.linkedin.lift.lib;

import com.linkedin.lift.lib.StatsUtils;
import com.linkedin.lift.types.CustomMetric;
import com.linkedin.lift.types.ModelPrediction;
import com.linkedin.lift.types.ModelPrediction$;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions$;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.IterableLike;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.Iterable;
import scala.collection.immutable.Iterable$;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.collection.immutable.StringOps;
import scala.math.Numeric$DoubleIsFractional$;
import scala.math.Ordering$Double$;
import scala.math.package$;
import scala.runtime.BoxesRunTime;

/* compiled from: StatsUtils.scala */
/* loaded from: input_file:com/linkedin/lift/lib/StatsUtils$.class */
/**
 * Java view of the Scala {@code object StatsUtils} (compiled from StatsUtils.scala): sampling
 * helpers for Spark {@link Dataset}s and classification metrics computed over sequences of
 * {@link ModelPrediction}.
 *
 * <p>Follows the Scala module convention: the single instance is exposed as {@link #MODULE$} and
 * the constructor is private.
 *
 * <p>NOTE(review): this source was recovered from bytecode. The {@code StatsUtils$$anonfun$N}
 * function classes instantiated below are defined outside this file; any description of what
 * they do is marked as an assumption — confirm against StatsUtils.scala.
 */
public final class StatsUtils$ {
    /** The singleton instance, created during class initialization. */
    public static final StatsUtils$ MODULE$ = new StatsUtils$();

    private StatsUtils$() {
    }

    /**
     * Rounds {@code d} to {@code i} decimal places using {@code scala.math.round}.
     *
     * @param d value to round
     * @param i number of decimal places (default 5, see {@link #roundDouble$default$2()})
     * @return {@code round(d * 10^i) / 10^i}
     */
    public double roundDouble(double d, int i) {
        double scale = package$.MODULE$.pow(10.0d, i);
        return package$.MODULE$.round(d * scale) / scale;
    }

    /** Default for the {@code i} (decimal places) argument of {@link #roundDouble}. */
    public int roundDouble$default$2() {
        return 5;
    }

    /**
     * Computes sampling fractions for two datasets so that roughly {@code j} rows are drawn in
     * total.
     *
     * <p>When {@code d} lies in [0, 1], the first dataset is allotted {@code (1 - d) * j} rows and
     * the second {@code d * j} rows; each allotment is divided by that dataset's row count.
     * Otherwise, including the default -1.0, both datasets use the same fraction
     * {@code j / (count + count2)}. Every fraction is capped at 1.0. A zero row count yields
     * +Infinity for {@code j > 0}, which the cap turns into 1.0.
     *
     * @param dataset  first dataset (the label == 1.0 rows when called from sampleDataFrame)
     * @param dataset2 second dataset (the label == 0.0 rows when called from sampleDataFrame)
     * @param j        target total number of sampled rows
     * @param d        share of the budget given to {@code dataset2}; outside [0, 1] means both
     *                 datasets are sampled with the same fraction
     * @return (fraction for {@code dataset}, fraction for {@code dataset2}) as a specialized
     *         double tuple
     */
    public Tuple2<Object, Object> computePosNegSamplePercentages(Dataset<Row> dataset, Dataset<Row> dataset2, long j, double d) {
        double firstCount = dataset.count();
        double secondCount = dataset2.count();
        double firstFraction;
        double secondFraction;
        if (d < 0.0d || d > 1.0d) {
            double total = firstCount + secondCount;
            firstFraction = j / total;
            secondFraction = j / total;
        } else {
            firstFraction = (j * (1.0d - d)) / firstCount;
            secondFraction = (j * d) / secondCount;
        }
        return new scala.Tuple2$mcDD$sp(Math.min(1.0d, firstFraction), Math.min(1.0d, secondFraction));
    }

    /** Default for {@code d} in {@link #computePosNegSamplePercentages}: out of range, so uniform. */
    public double computePosNegSamplePercentages$default$4() {
        return -1.0d;
    }

    /**
     * Samples roughly {@code j} rows from {@code dataset}, treating the label column {@code str}
     * as binary.
     *
     * <p>Rows whose label equals 1.0 and rows whose label equals 0.0 are sampled separately, using
     * the fractions from {@link #computePosNegSamplePercentages}, and then unioned. Rows with any
     * other label value are dropped.
     *
     * @param dataset input rows
     * @param str     name of the label column
     * @param j       target total number of sampled rows
     * @param d       share of the budget given to the label == 0.0 rows, or out of [0, 1] for
     *                uniform sampling (default -1.0)
     * @param j2      random seed; 0 (the default) means Spark's unseeded {@code sample}
     * @return union of the sampled 1.0-label rows and the sampled 0.0-label rows
     */
    public Dataset<Row> sampleDataFrame(Dataset<Row> dataset, String str, long j, double d, long j2) {
        Dataset<Row> ones = dataset.filter(functions$.MODULE$.col(str).$eq$eq$eq(BoxesRunTime.boxToDouble(1.0d)));
        Dataset<Row> zeros = dataset.filter(functions$.MODULE$.col(str).$eq$eq$eq(BoxesRunTime.boxToDouble(0.0d)));
        Tuple2<Object, Object> fractions = computePosNegSamplePercentages(ones, zeros, j, d);
        double onesFraction = fractions._1$mcD$sp();
        double zerosFraction = fractions._2$mcD$sp();
        Dataset<Row> sampledOnes;
        Dataset<Row> sampledZeros;
        if (j2 == 0) {
            sampledOnes = ones.sample(onesFraction);
            sampledZeros = zeros.sample(zerosFraction);
        } else {
            sampledOnes = ones.sample(onesFraction, j2);
            sampledZeros = zeros.sample(zerosFraction, j2);
        }
        return sampledOnes.union(sampledZeros);
    }

    /** Default for {@code d} in {@link #sampleDataFrame}: out of range, so uniform sampling. */
    public double sampleDataFrame$default$4() {
        return -1.0d;
    }

    /** Default for the seed {@code j2} in {@link #sampleDataFrame}: 0, meaning unseeded. */
    public long sampleDataFrame$default$5() {
        return 0L;
    }

    /**
     * Samples whole groups: roughly {@code j} distinct group ids are drawn, and every row
     * belonging to a drawn group is kept.
     *
     * <p>The input is first converted with {@code ModelPrediction.getModelPredictionDS}. The
     * distinct {@code groupId}s are sampled with fraction {@code min(1, j / #groups)} and joined
     * back on {@code groupId}. The columns are then renamed back to the caller's names.
     *
     * @param dataset input rows
     * @param str     label column name
     * @param str2    prediction column name
     * @param str3    group id column name
     * @param str4    dimension value column name
     * @param j       target number of groups to keep
     * @param j2      random seed; 0 (the default) means Spark's unseeded {@code sample}
     * @return rows of the sampled groups, with columns (str, str2, str3, str4)
     */
    public Dataset<Row> sampleDataFrameByGroupId(Dataset<Row> dataset, String str, String str2, String str3, String str4, long j, long j2) {
        Dataset<Row> predictions = ModelPrediction$.MODULE$.getModelPredictionDS(dataset, str, str2, str3, str4).toDF();
        Dataset<Row> groupIds = predictions.select("groupId", Predef$.MODULE$.wrapRefArray(new String[0])).distinct();
        // Floating-point division: a long/long quotient would truncate the fraction to 0
        // whenever fewer samples than groups are requested.
        double fraction = package$.MODULE$.min(1.0d, (double) j / groupIds.count());
        Dataset<Row> sampledGroupIds = j2 == 0 ? groupIds.sample(fraction) : groupIds.sample(fraction, j2);
        return predictions.join(sampledGroupIds, "groupId").select(Predef$.MODULE$.wrapRefArray(new Column[]{
                functions$.MODULE$.col("label").as(str),
                functions$.MODULE$.col("prediction").as(str2),
                functions$.MODULE$.col("groupId").as(str3),
                functions$.MODULE$.col("dimensionValue").as(str4)}));
    }

    /** Default for the seed {@code j2} in {@link #sampleDataFrameByGroupId}: 0, meaning unseeded. */
    public long sampleDataFrameByGroupId$default$7() {
        return 0L;
    }

    /**
     * Averages per-group precision-at-k over all groups in {@code seq}.
     *
     * <p>Predictions are grouped with {@code StatsUtils$$anonfun$10} (presumably by group id —
     * confirm). Each group is scored by {@link #com$linkedin$lift$lib$StatsUtils$$singleQueryPrecisionAtK$1}
     * via {@code StatsUtils$$anonfun$11}, and the scores are averaged. An empty {@code seq}
     * produces 0.0 / 0, i.e. NaN.
     *
     * @param d   threshold-like parameter forwarded to the per-group computation
     * @param i   k, forwarded to the per-group computation
     * @param seq predictions to evaluate
     * @return mean per-group precision
     */
    public double computePrecisionAtK(double d, int i, Seq<ModelPrediction> seq) {
        Iterable perGroup = (Iterable) seq.groupBy(new StatsUtils$$anonfun$10())
                .map(new StatsUtils$$anonfun$11(d, i), Iterable$.MODULE$.canBuildFrom());
        double total = BoxesRunTime.unboxToDouble(perGroup.sum(Numeric$DoubleIsFractional$.MODULE$));
        return total / perGroup.size();
    }

    /**
     * Resolves a metric name to a function from predictions to a (boxed double) metric value.
     *
     * <p>Built-in names are "AUC", "FNR", "FPR", "TNR", "PRECISION" and "RECALL".
     * {@code "PRECISION/<double>@<int>"} (for example {@code "PRECISION/0.5@10"}) selects
     * precision-at-k, with the double and int passed on in that order. Any other string is
     * treated as the fully-qualified name of a {@link CustomMetric} class and instantiated
     * through its no-arg constructor.
     *
     * <p>NOTE(review): the reflective path loads arbitrary classes by name, so {@code str} should
     * come from trusted configuration. Its checked reflection exceptions are undeclared here;
     * this is legal for Scala-compiled bytecode but not for javac.
     *
     * @param str metric name or CustomMetric class name
     * @return the metric function
     */
    public Function1<Seq<ModelPrediction>, Object> getMetricFn(String str) {
        switch (str) {
            case "AUC":
                return new StatsUtils$$anonfun$getMetricFn$1();
            case "FNR":
                return new StatsUtils$$anonfun$getMetricFn$2();
            case "FPR":
                return new StatsUtils$$anonfun$getMetricFn$3();
            case "TNR":
                return new StatsUtils$$anonfun$getMetricFn$4();
            case "PRECISION":
                return new StatsUtils$$anonfun$getMetricFn$5();
            case "RECALL":
                return new StatsUtils$$anonfun$getMetricFn$6();
            default:
                break;
        }
        if (str.matches("PRECISION/\\d*\\.*\\d+@\\d+")) {
            String[] slashParts = str.split("/");
            String spec = slashParts[slashParts.length - 1];
            String[] atParts = spec.split("@");
            double threshold = Double.parseDouble(atParts[0]);
            int k = Integer.parseInt(atParts[atParts.length - 1]);
            return new StatsUtils$$anonfun$getMetricFn$7(threshold, k);
        }
        return new StatsUtils$$anonfun$getMetricFn$8((CustomMetric) Class.forName(str).newInstance());
    }

    /**
     * Sample standard deviation (divisor {@code n - 1}) of the values that survive
     * {@code StatsUtils$$anonfun$1}, which is applied with {@code filterNot}. The filter
     * presumably drops NaN values — confirm.
     *
     * @param seq boxed doubles
     * @return the sample standard deviation, or 0.0 when fewer than two values remain
     */
    public double computeStdDev(Seq<Object> seq) {
        Seq values = (Seq) seq.filterNot(new StatsUtils$$anonfun$1());
        int n = values.length();
        if (n < 2) {
            return 0.0d;
        }
        double mean = BoxesRunTime.unboxToDouble(values.sum(Numeric$DoubleIsFractional$.MODULE$)) / n;
        // StatsUtils$$anonfun$2(mean) presumably maps each value to its squared deviation — confirm.
        double sumOfSquares = BoxesRunTime.unboxToDouble(((TraversableOnce) values.map(
                new StatsUtils$$anonfun$2(mean), Seq$.MODULE$.canBuildFrom())).sum(Numeric$DoubleIsFractional$.MODULE$));
        return package$.MODULE$.sqrt(sumOfSquares / (n - 1));
    }

    /**
     * Builds a confusion matrix for {@code seq}. Each prediction is mapped with
     * {@code ...$computeGeneralizedConfusionMatrix$1} and the results are reduced with
     * {@code ...$computeGeneralizedConfusionMatrix$2}. These are presumably a per-prediction
     * matrix and an element-wise sum — confirm.
     *
     * @param seq predictions
     * @return the combined matrix, or an all-zero matrix for empty input
     */
    public StatsUtils.ConfusionMatrix computeGeneralizedConfusionMatrix(Seq<ModelPrediction> seq) {
        if (seq.isEmpty()) {
            return new StatsUtils.ConfusionMatrix(0.0d, 0.0d, 0.0d, 0.0d);
        }
        return (StatsUtils.ConfusionMatrix) ((TraversableOnce) seq.map(
                new StatsUtils$$anonfun$computeGeneralizedConfusionMatrix$1(), Seq$.MODULE$.canBuildFrom()))
                .reduce(new StatsUtils$$anonfun$computeGeneralizedConfusionMatrix$2());
    }

    /**
     * Returns {@code numerator / (numerator + rest)}, or 0.0 when {@code numerator} is zero.
     * The zero guard also avoids 0/0 when both counts are zero.
     */
    private static double ratioOrZero(double numerator, double rest) {
        return numerator == 0.0d ? 0.0d : numerator / (numerator + rest);
    }

    /** Precision = TP / (TP + FP); 0.0 when TP is zero. */
    public double computePrecision(Seq<ModelPrediction> seq) {
        StatsUtils.ConfusionMatrix cm = computeGeneralizedConfusionMatrix(seq);
        return ratioOrZero(cm.truePositive(), cm.falsePositive());
    }

    /** False positive rate = FP / (FP + TN); 0.0 when FP is zero. */
    public double computeFalsePositiveRate(Seq<ModelPrediction> seq) {
        StatsUtils.ConfusionMatrix cm = computeGeneralizedConfusionMatrix(seq);
        return ratioOrZero(cm.falsePositive(), cm.trueNegative());
    }

    /** False negative rate = FN / (TP + FN); 0.0 when FN is zero. */
    public double computeFalseNegativeRate(Seq<ModelPrediction> seq) {
        StatsUtils.ConfusionMatrix cm = computeGeneralizedConfusionMatrix(seq);
        return ratioOrZero(cm.falseNegative(), cm.truePositive());
    }

    /** Recall = TP / (TP + FN); 0.0 when TP is zero. */
    public double computeRecall(Seq<ModelPrediction> seq) {
        StatsUtils.ConfusionMatrix cm = computeGeneralizedConfusionMatrix(seq);
        return ratioOrZero(cm.truePositive(), cm.falseNegative());
    }

    /** True negative rate = TN / (TN + FP); 0.0 when TN is zero. */
    public double computeTrueNegativeRate(Seq<ModelPrediction> seq) {
        StatsUtils.ConfusionMatrix cm = computeGeneralizedConfusionMatrix(seq);
        return ratioOrZero(cm.trueNegative(), cm.falsePositive());
    }

    /**
     * Computes the (x, y) coordinate lists of an ROC-style curve.
     *
     * <p>Predictions are sorted ascending with {@code StatsUtils$$anonfun$12} (presumably the
     * score — confirm). Cut indices are then collected from adjacent pairs with
     * {@code StatsUtils$$anonfun$7}, and the last index is always appended. At each cut, y is a
     * running sum computed by {@code StatsUtils$$anonfun$13} (presumably cumulative positives),
     * and x is derived from (index, y) by {@code StatsUtils$$anonfun$14}. Both lists are then
     * normalized by their last element, with fallbacks supplied by
     * {@code StatsUtils$$anonfun$4}/{@code $3} when a list is empty.
     *
     * @param seq predictions
     * @return (normalized x values, normalized y values)
     */
    public Tuple2<Seq<Object>, Seq<Object>> computeROCCurve(Seq<ModelPrediction> seq) {
        Seq sorted = (Seq) seq.sortBy(new StatsUtils$$anonfun$12(), Ordering$Double$.MODULE$);
        List cutIndices = (List) sorted.sliding(2).zipWithIndex().collect(new StatsUtils$$anonfun$7()).toList()
                .$colon$plus(BoxesRunTime.boxToInteger(sorted.length() - 1), List$.MODULE$.canBuildFrom());
        // Running sums with the leading 0.0 seed dropped, so position p holds the sum through p.
        Seq runningSums = (Seq) ((TraversableLike) sorted.scanLeft(
                BoxesRunTime.boxToDouble(0.0d), new StatsUtils$$anonfun$13(), Seq$.MODULE$.canBuildFrom())).tail();
        // A Seq is a PartialFunction from index to element, so collect() looks up each cut index.
        List yCounts = (List) cutIndices.collect(runningSums, List$.MODULE$.canBuildFrom());
        List xCounts = (List) ((List) cutIndices.zip(yCounts, List$.MODULE$.canBuildFrom()))
                .map(new StatsUtils$$anonfun$14(), List$.MODULE$.canBuildFrom());
        double xTotal = BoxesRunTime.unboxToDouble(xCounts.lastOption().getOrElse(new StatsUtils$$anonfun$4()));
        List xs = (List) xCounts.map(new StatsUtils$$anonfun$5(xTotal), List$.MODULE$.canBuildFrom());
        double yTotal = BoxesRunTime.unboxToDouble(yCounts.lastOption().getOrElse(new StatsUtils$$anonfun$3()));
        List ys = (List) yCounts.map(new StatsUtils$$anonfun$6(yTotal), List$.MODULE$.canBuildFrom());
        return new Tuple2<>(xs, ys);
    }

    /**
     * Area under the curve returned by {@link #computeROCCurve}. The points are paired, walked
     * pairwise with {@code sliding(2)}, and the segment contributions are accumulated by
     * {@code ...$computeAUC$1} (presumably trapezoids — confirm).
     *
     * @param seq predictions
     * @return the area, or 0.0 when the curve has a single point
     */
    public double computeAUC(Seq<ModelPrediction> seq) {
        Tuple2<Seq<Object>, Seq<Object>> curve = computeROCCurve(seq);
        Seq xs = curve._1();
        Seq ys = curve._2();
        if (xs.length() == 1 && ys.length() == 1) {
            return 0.0d;
        }
        return BoxesRunTime.unboxToDouble(((IterableLike) xs.zip(ys, Seq$.MODULE$.canBuildFrom()))
                .sliding(2).foldLeft(BoxesRunTime.boxToDouble(0.0d), new StatsUtils$$anonfun$computeAUC$1()));
    }

    /**
     * Precision-at-k for a single group. This is the lifted local function called from
     * {@code StatsUtils$$anonfun$11}, so its name and signature must not change.
     *
     * <p>Keeps the predictions accepted by {@code StatsUtils$$anonfun$8(i)} (presumably the top
     * {@code i} by rank — confirm). It returns 0.0 when none remain. Otherwise it returns the
     * fraction of the kept predictions accepted by {@code StatsUtils$$anonfun$9(d)}.
     */
    public final double com$linkedin$lift$lib$StatsUtils$$singleQueryPrecisionAtK$1(Seq seq, double d, int i) {
        Seq topK = (Seq) seq.filter(new StatsUtils$$anonfun$8(i));
        if (topK.isEmpty()) {
            return 0.0d;
        }
        // Cast before dividing: count() and length() are both ints, and integer division would
        // collapse every precision to 0 or 1.
        return (double) topK.count(new StatsUtils$$anonfun$9(d)) / topK.length();
    }
}
