package com.linkedin.lift.lib;

import com.linkedin.lift.types.Distribution;
import com.linkedin.lift.types.FairnessResult;
import com.linkedin.lift.types.FairnessResult$;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions$;
import scala.Array$;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.collection.immutable.Iterable$;
import scala.collection.immutable.Map;
import scala.collection.immutable.Map$;
import scala.collection.immutable.Set;
import scala.math.Numeric$DoubleIsFractional$;
import scala.math.Ordering$Double$;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;

/* compiled from: DivergenceUtils.scala */
/* loaded from: input_file:com/linkedin/lift/lib/DivergenceUtils$.class */
/**
 * Java view of the Scala singleton object {@code DivergenceUtils} (decompiled from
 * {@code DivergenceUtils.scala}).
 *
 * <p>Provides distance / divergence computations between two {@link Distribution}s and
 * builds {@link FairnessResult}s from them. Most per-entry arithmetic lives in the
 * compiler-generated closure classes {@code DivergenceUtils$$anonfun$N}, which are defined
 * outside this file; comments about what those closures compute are therefore marked as
 * assumptions.
 *
 * <p>Methods named {@code foo$default$N} are the Scala default-argument accessors: they
 * return the default value of the N-th parameter of {@code foo}.
 */
public final class DivergenceUtils$ {
    /**
     * The single instance of this module.
     *
     * <p>Initialized eagerly here. The decompiled form declared the field as
     * {@code final ... = null} and re-assigned it from the constructor, which is not legal
     * Java source (a {@code final} field cannot be assigned outside its initializer).
     */
    public static final DivergenceUtils$ MODULE$ = new DivergenceUtils$();

    /**
     * Computes a Kullback-Leibler style divergence between two distributions, scaled by
     * {@code 1 / ln(2)}.
     *
     * <p>The result is
     * {@code (1/ln 2) * ( sum(terms)/sum1 + ln((sum2 + d * n) / sum1) )}, where
     * {@code terms} are the per-entry values produced by {@code DivergenceUtils$$anonfun$1(d)}
     * over {@code distribution.zip(distribution2)}, {@code n} is the number of zipped entries,
     * and {@code sum1}/{@code sum2} are the totals of the two distributions.
     *
     * @param distribution  first distribution
     * @param distribution2 second distribution
     * @param d             additive constant applied per zipped entry (presumably a
     *                      smoothing pseudo-count; confirm against the Scala source)
     * @return the scaled divergence value
     */
    public double computeKullbackLeiblerDivergence(Distribution distribution, Distribution distribution2, double d) {
        Seq terms = (Seq) distribution.zip(distribution2).map(new DivergenceUtils$$anonfun$1(d), Seq$.MODULE$.canBuildFrom());
        double total = distribution.sum();
        double termSum = BoxesRunTime.unboxToDouble(terms.sum(Numeric$DoubleIsFractional$.MODULE$));
        double logTotalsRatio = package$.MODULE$.log((distribution2.sum() + (d * terms.size())) / total);
        // Dividing by ln(2) rescales a natural-log quantity to base 2.
        return (1.0d / package$.MODULE$.log(2.0d)) * ((termSum / total) + logTotalsRatio);
    }

    /** @return the default for the 3rd parameter of {@code computeKullbackLeiblerDivergence}: {@code 1.0} */
    public double computeKullbackLeiblerDivergence$default$3() {
        return 1.0d;
    }

    /**
     * Computes a Jensen-Shannon style divergence: half the sum of the Kullback-Leibler
     * divergences (with {@code d = 0.0}) of each input against an intermediate distribution.
     *
     * <p>The intermediate distribution is built from the entries produced by
     * {@code DivergenceUtils$$anonfun$2}, which receives both distributions' totals
     * (presumably the average of the two normalized distributions; confirm).
     *
     * @param distribution  first distribution
     * @param distribution2 second distribution
     * @return {@code 0.5 * (KL(distribution, mid, 0) + KL(distribution2, mid, 0))}
     */
    public double computeJensenShannonDivergence(Distribution distribution, Distribution distribution2) {
        TraversableOnce midEntries = (TraversableOnce) distribution.zip(distribution2).map(
                new DivergenceUtils$$anonfun$2(distribution.sum(), distribution2.sum()),
                Seq$.MODULE$.canBuildFrom());
        Distribution mid = new Distribution(midEntries.toMap(Predef$.MODULE$.$conforms()));
        double firstVsMid = computeKullbackLeiblerDivergence(distribution, mid, 0.0d);
        double secondVsMid = computeKullbackLeiblerDivergence(distribution2, mid, 0.0d);
        return 0.5d * (firstVsMid + secondVsMid);
    }

    /**
     * Computes half the sum of the per-entry values produced by
     * {@code DivergenceUtils$$anonfun$3} over the zipped distributions.
     *
     * <p>The closure receives both totals (presumably yielding the absolute difference of
     * normalized probabilities; confirm).
     *
     * @param distribution  first distribution
     * @param distribution2 second distribution
     * @return {@code 0.5 * sum(per-entry values)}
     */
    public double computeTotalVariationDistance(Distribution distribution, Distribution distribution2) {
        TraversableOnce perEntry = (TraversableOnce) distribution.zip(distribution2).map(
                new DivergenceUtils$$anonfun$3(distribution.sum(), distribution2.sum()),
                Seq$.MODULE$.canBuildFrom());
        return 0.5d * BoxesRunTime.unboxToDouble(perEntry.sum(Numeric$DoubleIsFractional$.MODULE$));
    }

    /**
     * Computes the maximum of the per-entry values produced by
     * {@code DivergenceUtils$$anonfun$4} over the zipped distributions.
     *
     * <p>The closure receives both totals (presumably yielding the absolute difference of
     * normalized probabilities; confirm).
     *
     * @param distribution  first distribution
     * @param distribution2 second distribution
     * @return the largest per-entry value
     */
    public double computeInfinityNormDistance(Distribution distribution, Distribution distribution2) {
        TraversableOnce perEntry = (TraversableOnce) distribution.zip(distribution2).map(
                new DivergenceUtils$$anonfun$4(distribution.sum(), distribution2.sum()),
                Seq$.MODULE$.canBuildFrom());
        return BoxesRunTime.unboxToDouble(perEntry.max(Ordering$Double$.MODULE$));
    }

    /**
     * Computes the skew of a single entry between two distributions:
     * {@code ln(v1 + d) - ln(sum1 + d*n) + ln(sum2 + d*n) - ln(v2 + d)}, where {@code v1}/{@code v2}
     * are the values of {@code map} in each distribution, {@code sum1}/{@code sum2} their totals
     * and {@code n} the number of zipped entries.
     *
     * @param distribution  first distribution
     * @param distribution2 second distribution
     * @param map           key of the entry whose skew is computed
     * @param d             additive constant applied to each value and, times {@code n}, to each total
     * @return the skew for the given entry
     */
    public double computeSkew(Distribution distribution, Distribution distribution2, Map<String, String> map, double d) {
        int entryCount = distribution.zip(distribution2).size();
        double logValue = package$.MODULE$.log(distribution.getValue(map) + d);
        double logTotal = package$.MODULE$.log(distribution.sum() + (d * entryCount));
        double logOtherTotal = package$.MODULE$.log(distribution2.sum() + (d * entryCount));
        double logOtherValue = package$.MODULE$.log(distribution2.getValue(map) + d);
        return ((logValue - logTotal) + logOtherTotal) - logOtherValue;
    }

    /** @return the default for the 4th parameter of {@code computeSkew}: {@code 1.0} */
    public double computeSkew$default$4() {
        return 1.0d;
    }

    /**
     * Finds the entry with the smallest value among those produced by
     * {@code DivergenceUtils$$anonfun$5(d)} (ordered via {@code DivergenceUtils$$anonfun$6})
     * and returns its key together with {@code ln(value) + ln(total2 / total1)}, where each
     * total is the distribution's sum plus {@code d} times the number of zipped entries.
     *
     * @param distribution  first distribution
     * @param distribution2 second distribution
     * @param d             additive constant (see {@link #computeSkew})
     * @return pair of (entry key, boxed double skew)
     */
    public Tuple2<Map<String, String>, Object> computeMinSkew(Distribution distribution, Distribution distribution2, double d) {
        Seq candidates = (Seq) distribution.zip(distribution2).map(new DivergenceUtils$$anonfun$5(d), Seq$.MODULE$.canBuildFrom());
        double total = distribution.sum() + (d * candidates.size());
        double otherTotal = distribution2.sum() + (d * candidates.size());
        Tuple2 smallest = (Tuple2) candidates.minBy(new DivergenceUtils$$anonfun$6(), Ordering$Double$.MODULE$);
        return toLogSkew(smallest, total, otherTotal);
    }

    /** @return the default for the 3rd parameter of {@code computeMinSkew}: {@code 1.0} */
    public double computeMinSkew$default$3() {
        return 1.0d;
    }

    /**
     * Finds the entry with the largest value among those produced by
     * {@code DivergenceUtils$$anonfun$7(d)} (ordered via {@code DivergenceUtils$$anonfun$8})
     * and returns its key together with {@code ln(value) + ln(total2 / total1)}, where each
     * total is the distribution's sum plus {@code d} times the number of zipped entries.
     *
     * @param distribution  first distribution
     * @param distribution2 second distribution
     * @param d             additive constant (see {@link #computeSkew})
     * @return pair of (entry key, boxed double skew)
     */
    public Tuple2<Map<String, String>, Object> computeMaxSkew(Distribution distribution, Distribution distribution2, double d) {
        Seq candidates = (Seq) distribution.zip(distribution2).map(new DivergenceUtils$$anonfun$7(d), Seq$.MODULE$.canBuildFrom());
        double total = distribution.sum() + (d * candidates.size());
        double otherTotal = distribution2.sum() + (d * candidates.size());
        Tuple2 largest = (Tuple2) candidates.maxBy(new DivergenceUtils$$anonfun$8(), Ordering$Double$.MODULE$);
        return toLogSkew(largest, total, otherTotal);
    }

    /** @return the default for the 3rd parameter of {@code computeMaxSkew}: {@code 1.0} */
    public double computeMaxSkew$default$3() {
        return 1.0d;
    }

    /**
     * Computes a value for every zipped entry via
     * {@code DivergenceUtils$$anonfun$computeAllSkews$1} and collects them into a map.
     *
     * <p>The closure receives {@code d} and the constant
     * {@code ln(sum1 + d*n) - ln(sum2 + d*n)} (presumably so that each entry's skew matches
     * {@link #computeSkew}; confirm against the closure).
     *
     * @param distribution  first distribution
     * @param distribution2 second distribution
     * @param d             additive constant (see {@link #computeSkew})
     * @return map from entry key to boxed double
     */
    public Map<Map<String, String>, Object> computeAllSkews(Distribution distribution, Distribution distribution2, double d) {
        Seq<Tuple3<Map<String, String>, Object, Object>> zipped = distribution.zip(distribution2);
        int entryCount = zipped.size();
        double logTotalsDifference = package$.MODULE$.log(distribution.sum() + (d * entryCount))
                - package$.MODULE$.log(distribution2.sum() + (d * entryCount));
        TraversableOnce skews = (TraversableOnce) zipped.map(
                new DivergenceUtils$$anonfun$computeAllSkews$1(d, logTotalsDifference),
                Seq$.MODULE$.canBuildFrom());
        return skews.toMap(Predef$.MODULE$.$conforms());
    }

    /** @return the default for the 3rd parameter of {@code computeAllSkews}: {@code 1.0} */
    public double computeAllSkews$default$3() {
        return 1.0d;
    }

    /**
     * Builds a {@link Distribution} from a Spark dataset by grouping on
     * ({@code str3}, {@code str}) and aggregating {@code sum(str2)} and {@code sum(1.0 - str2)}.
     *
     * <p>The aggregated rows are collected to the driver and expanded into distribution
     * entries by {@code DivergenceUtils$$anonfun$9}.
     *
     * @param dataset input rows; must contain the three named columns
     * @param str     name of the second grouping column
     * @param str2    name of the numeric column that is summed (and whose complement is summed)
     * @param str3    name of the first grouping column
     * @return the resulting distribution
     */
    public Distribution computeGeneralizedPredictionCountDistribution(Dataset<Row> dataset, String str, String str2, String str3) {
        Column valueSum = functions$.MODULE$.sum(functions$.MODULE$.col(str2));
        Column complementSum = functions$.MODULE$.sum(
                functions$.MODULE$.lit(BoxesRunTime.boxToDouble(1.0d)).$minus(functions$.MODULE$.col(str2)));
        Dataset<Row> aggregated = dataset
                .select(str3, Predef$.MODULE$.wrapRefArray(new String[]{str, str2}))
                .groupBy(str3, Predef$.MODULE$.wrapRefArray(new String[]{str}))
                .agg(valueSum, Predef$.MODULE$.wrapRefArray(new Column[]{complementSum}));
        Object[] rows = (Object[]) aggregated.collect();
        return new Distribution(
                Predef$.MODULE$.refArrayOps(
                        (Object[]) Predef$.MODULE$.refArrayOps(rows).flatMap(
                                new DivergenceUtils$$anonfun$9(str, str2, str3),
                                Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class))))
                        .toMap(Predef$.MODULE$.$conforms()));
    }

    /**
     * Builds a {@code "DEMOGRAPHIC_PARITY"} {@link FairnessResult}.
     *
     * <p>Steps: (1) pick the string used for the value one, {@code "1"} if it appears among
     * the values extracted by {@code DivergenceUtils$$anonfun$11(str)}, otherwise
     * {@code "1.0"}; (2) map each entry of the marginal over {@code str2} through
     * {@code DivergenceUtils$$anonfun$12}; (3) for every 2-element subset of the resulting
     * keys, compute a pairwise entry via {@code DivergenceUtils$$anonfun$13}.
     *
     * @param distribution joint distribution
     * @param str          dimension whose values are inspected for {@code "1"} / {@code "1.0"}
     * @param str2         dimension the marginal is computed over
     * @return the fairness result holding the pairwise map and the per-key map
     */
    public FairnessResult computeDemographicParity(Distribution distribution, String str, String str2) {
        Distribution marginal = distribution.computeMarginal(
                (Set) Predef$.MODULE$.Set().apply(Predef$.MODULE$.wrapRefArray(new String[]{str2})));
        Set observedValues = ((TraversableOnce) distribution.entries().map(
                new DivergenceUtils$$anonfun$11(str), Iterable$.MODULE$.canBuildFrom())).toSet();
        // The data may encode the value one either as an integer or as a double string.
        String oneLabel = observedValues.contains("1") ? "1" : "1.0";
        Map perKey = (Map) marginal.entries().map(
                new DivergenceUtils$$anonfun$12(distribution, str, oneLabel), Map$.MODULE$.canBuildFrom());
        return new FairnessResult(
                "DEMOGRAPHIC_PARITY",
                FairnessResult$.MODULE$.apply$default$2(),
                None$.MODULE$,
                perKey.keys().toSet().subsets(2).map(new DivergenceUtils$$anonfun$13(str2, perKey)).toMap(Predef$.MODULE$.$conforms()),
                perKey);
    }

    /**
     * Builds an {@code "EQUALIZED_ODDS"} {@link FairnessResult}.
     *
     * <p>Steps: (1) pick the string used for the value one, {@code "1"} if it appears among
     * the values extracted by {@code DivergenceUtils$$anonfun$14(str2)}, otherwise
     * {@code "1.0"}; (2) map each entry of the marginal over ({@code str}, {@code str3})
     * through {@code DivergenceUtils$$anonfun$15}; (3) group that map with
     * {@code DivergenceUtils$$anonfun$16(str)} and flat-map the groups with
     * {@code DivergenceUtils$$anonfun$17(str, str3)}; (4) transform the per-key map with
     * {@code DivergenceUtils$$anonfun$18}.
     *
     * @param distribution joint distribution
     * @param str          first marginal dimension, also used for grouping
     * @param str2         dimension whose values are inspected for {@code "1"} / {@code "1.0"}
     * @param str3         second marginal dimension
     * @return the fairness result holding the grouped map and the transformed per-key map
     */
    public FairnessResult computeEqualizedOdds(Distribution distribution, String str, String str2, String str3) {
        Distribution marginal = distribution.computeMarginal(
                (Set) Predef$.MODULE$.Set().apply(Predef$.MODULE$.wrapRefArray(new String[]{str, str3})));
        Set observedValues = ((TraversableOnce) distribution.entries().map(
                new DivergenceUtils$$anonfun$14(str2), Iterable$.MODULE$.canBuildFrom())).toSet();
        // The data may encode the value one either as an integer or as a double string.
        String oneLabel = observedValues.contains("1") ? "1" : "1.0";
        Map perKey = (Map) marginal.entries().map(
                new DivergenceUtils$$anonfun$15(distribution, str2, oneLabel), Map$.MODULE$.canBuildFrom());
        return new FairnessResult(
                "EQUALIZED_ODDS",
                FairnessResult$.MODULE$.apply$default$2(),
                None$.MODULE$,
                (Map) perKey.groupBy(new DivergenceUtils$$anonfun$16(str)).flatMap(new DivergenceUtils$$anonfun$17(str, str3), Map$.MODULE$.canBuildFrom()),
                (Map) perKey.map(new DivergenceUtils$$anonfun$18(), Map$.MODULE$.canBuildFrom()));
    }

    /**
     * Produces zero or more {@link FairnessResult}s per requested metric by flat-mapping
     * {@code seq} through {@code DivergenceUtils$$anonfun$computeDatasetDistanceMetrics$1}.
     *
     * @param seq          metric names to evaluate
     * @param distribution distribution under evaluation
     * @param option       optional second distribution handed to the closure
     * @param str          passed through to the closure
     * @param str2         passed through to the closure
     * @return the concatenated results for all metrics
     */
    public Seq<FairnessResult> computeDatasetDistanceMetrics(Seq<String> seq, Distribution distribution, Option<Distribution> option, String str, String str2) {
        return (Seq) seq.flatMap(new DivergenceUtils$$anonfun$computeDatasetDistanceMetrics$1(distribution, option, str, str2), Seq$.MODULE$.canBuildFrom());
    }

    /**
     * Computes distance metrics, optionally adding results derived from the full
     * distribution.
     *
     * <p>When {@code str2} is empty this delegates to
     * {@link #computeDatasetDistanceMetrics} with {@code (str, str3)}. Otherwise it runs
     * {@link #computeDatasetDistanceMetrics} on the marginal over ({@code str2}, {@code str3})
     * with {@code (str2, str3)} and appends the results of flat-mapping {@code seq} through
     * {@code DivergenceUtils$$anonfun$19(distribution, str, str2, str3)}.
     *
     * @param seq          metric names to evaluate
     * @param distribution distribution under evaluation
     * @param option       optional second distribution
     * @param str          dimension name used when {@code str2} is empty, and by the closure otherwise
     * @param str2         dimension name; emptiness selects the branch
     * @param str3         dimension name used in both branches
     * @return the resulting fairness results
     */
    public Seq<FairnessResult> computeDistanceMetrics(Seq<String> seq, Distribution distribution, Option<Distribution> option, String str, String str2, String str3) {
        if (str2.isEmpty()) {
            return computeDatasetDistanceMetrics(seq, distribution, option, str, str3);
        }
        Distribution marginal = distribution.computeMarginal(
                (Set) Predef$.MODULE$.Set().apply(Predef$.MODULE$.wrapRefArray(new String[]{str2, str3})));
        Seq<FairnessResult> marginalResults = computeDatasetDistanceMetrics(seq, marginal, option, str2, str3);
        Seq jointResults = (Seq) seq.flatMap(new DivergenceUtils$$anonfun$19(distribution, str, str2, str3), Seq$.MODULE$.canBuildFrom());
        return (Seq) marginalResults.$plus$plus(jointResults, Seq$.MODULE$.canBuildFrom());
    }

    /**
     * Converts the selected (key, value) pair into the (key, skew) result shared by
     * {@code computeMinSkew} and {@code computeMaxSkew}.
     *
     * @param extreme    the pair chosen by {@code minBy}/{@code maxBy}
     * @param total      adjusted total of the first distribution
     * @param otherTotal adjusted total of the second distribution
     * @return pair of (key, {@code ln(value) + ln(otherTotal / total)})
     * @throws MatchError if {@code extreme} is {@code null}, mirroring the Scala pattern match
     */
    private static Tuple2<Map<String, String>, Object> toLogSkew(Tuple2 extreme, double total, double otherTotal) {
        if (extreme == null) {
            throw new MatchError(extreme);
        }
        Map key = (Map) extreme._1();
        double value = extreme._2$mcD$sp();
        return new Tuple2<>(key, BoxesRunTime.boxToDouble(package$.MODULE$.log(value) + package$.MODULE$.log(otherTotal / total)));
    }

    /** Singleton: instantiated only through {@link #MODULE$}. */
    private DivergenceUtils$() {
    }
}
