package com.linkedin.lift.eval;

import com.linkedin.lift.lib.DivergenceUtils$;
import com.linkedin.lift.lib.StatsUtils$;
import com.linkedin.lift.types.Distribution;
import com.linkedin.lift.types.Distribution$;
import com.linkedin.lift.types.FairnessResult;
import com.linkedin.lift.types.FairnessResult$;
import com.linkedin.lift.types.ModelPrediction;
import com.linkedin.lift.types.ModelPrediction$;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.DataFrameReader;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions$;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.collection.immutable.Map;
import scala.collection.immutable.Map$;
import scala.collection.immutable.Set;
import scala.runtime.BoxesRunTime;

/* compiled from: FairnessMetricsUtils.scala */
/* loaded from: input_file:com/linkedin/lift/eval/FairnessMetricsUtils$.class */
/**
 * Singleton holder for the fairness-metric helpers (Java view of the Scala object
 * {@code FairnessMetricsUtils}). All methods are reached through {@link #MODULE$}.
 *
 * <p>Parameter names such as {@code str}, {@code str2} are kept as they appear in the
 * compiled signature; each method's Javadoc states how the parameter is used.
 */
public final class FairnessMetricsUtils$ {
    /** The single instance of this class; created once during class initialisation. */
    public static final FairnessMetricsUtils$ MODULE$ = new FairnessMetricsUtils$();

    /** Not instantiable from outside; use {@link #MODULE$}. */
    private FairnessMetricsUtils$() {
    }

    /**
     * Projects {@code dataset} onto the columns {@code str}, {@code str2} and {@code str3},
     * plus {@code str4} when that name is non-empty and present in the dataset's schema.
     *
     * @param dataset input data
     * @param str     first column to keep
     * @param str2    second column to keep
     * @param str3    third column to keep
     * @param str4    optional fourth column; ignored when empty or absent from the schema
     * @return the projected dataset
     */
    public Dataset<Row> projectIdLabelsAndScores(Dataset<Row> dataset, String str, String str2, String str3, String str4) {
        // The emptiness check comes first so the schema is only inspected when a name was given.
        boolean keepFourthColumn = !str4.isEmpty()
                && Predef$.MODULE$.refArrayOps(dataset.schema().fieldNames()).contains(str4);
        Column[] columns;
        if (keepFourthColumn) {
            columns = new Column[]{
                    functions$.MODULE$.col(str),
                    functions$.MODULE$.col(str2),
                    functions$.MODULE$.col(str3),
                    functions$.MODULE$.col(str4)};
        } else {
            columns = new Column[]{
                    functions$.MODULE$.col(str),
                    functions$.MODULE$.col(str2),
                    functions$.MODULE$.col(str3)};
        }
        return dataset.select(Predef$.MODULE$.wrapRefArray(columns));
    }

    /**
     * Evaluates every metric named in {@code seq2} via the generated function
     * {@code computePermutationTestMetrics$1} and concatenates the results.
     *
     * <p>The generated function also receives the list of all 2-element subsets of the
     * set of values obtained by mapping {@code seq} through {@code $anonfun$1}.
     *
     * @param seq  model predictions
     * @param str  forwarded to the per-metric function (callers in this class pass the
     *             protected-attribute field)
     * @param seq2 names of the metrics to evaluate
     * @param i    forwarded to the per-metric function (callers pass the number of trials)
     * @param j    forwarded to the per-metric function (callers pass the seed)
     * @return the concatenated results for all metrics
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // raw Scala collection types come from the compiled signatures
    public Seq<FairnessResult> computePermutationTestMetrics(Seq<ModelPrediction> seq, String str, Seq<String> seq2, int i, long j) {
        return (Seq) seq2.flatMap(
                new FairnessMetricsUtils$$anonfun$computePermutationTestMetrics$1(
                        seq,
                        str,
                        i,
                        j,
                        ((TraversableOnce) seq.map(new FairnessMetricsUtils$$anonfun$1(), Seq$.MODULE$.canBuildFrom()))
                                .toSet()
                                .subsets(2)
                                .toList()),
                Seq$.MODULE$.canBuildFrom());
    }

    /**
     * Computes model predictions from {@code dataset} and returns the permutation-test
     * results followed by the results derived from the performance-benefit metrics.
     *
     * @param dataset                                data to build the model predictions from
     * @param measureModelFairnessMetricsCmdLineArgs supplies field names, metric lists,
     *                                               number of trials and seed
     * @return permutation-test results concatenated with the benefit-metric results
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // raw Scala collection types come from the compiled signatures
    public Seq<FairnessResult> computeModelPerformanceMetrics(Dataset<Row> dataset, MeasureModelFairnessMetricsCmdLineArgs measureModelFairnessMetricsCmdLineArgs) {
        final MeasureModelFairnessMetricsCmdLineArgs args = measureModelFairnessMetricsCmdLineArgs;

        Seq<ModelPrediction> predictions = ModelPrediction$.MODULE$.compute(
                dataset,
                args.labelField(),
                args.scoreField(),
                args.groupIdField(),
                args.protectedAttributeField());

        Seq<FairnessResult> permutationResults = computePermutationTestMetrics(
                predictions,
                args.protectedAttributeField(),
                args.permutationMetrics(),
                args.numTrials(),
                args.seed());

        // Each performance-benefit metric is first mapped ($anonfun$2), then expanded ($anonfun$3).
        Seq benefitResults = (Seq) ((Seq) args.performanceBenefitMetrics()
                .map(new FairnessMetricsUtils$$anonfun$2(args, predictions), Seq$.MODULE$.canBuildFrom()))
                .flatMap(new FairnessMetricsUtils$$anonfun$3(args), Seq$.MODULE$.canBuildFrom());

        return (Seq) permutationResults.$plus$plus(benefitResults, Seq$.MODULE$.canBuildFrom());
    }

    /**
     * Loads the data at {@code str2}, keeps column {@code str3} (renamed to {@code str})
     * and column {@code str4}, and joins the result with {@code dataset} on column {@code str}.
     *
     * @param dataFrameReader reader used to load {@code str2}
     * @param dataset         dataset to join with; must contain a column named {@code str}
     * @param str             join column name, and the alias given to {@code str3}
     * @param str2            path handed to {@link DataFrameReader#load(String)}
     * @param str3            column of the loaded data that is renamed to {@code str}
     * @param str4            second column kept from the loaded data
     * @return the joined dataset
     */
    public Dataset<Row> computeJoinedDF(DataFrameReader dataFrameReader, Dataset<Row> dataset, String str, String str2, String str3, String str4) {
        Column[] loadedColumns = new Column[]{
                functions$.MODULE$.col(str3).as(str),
                functions$.MODULE$.col(str4)};
        return dataFrameReader.load(str2)
                .select(Predef$.MODULE$.wrapRefArray(loadedColumns))
                .join(dataset, str);
    }

    /**
     * Builds a reference distribution when {@code str} is exactly {@code "UNIFORM"}.
     *
     * <p>In that case every entry of {@code distribution} is mapped through
     * {@code $anonfun$4}, which is constructed with {@code 1.0 / numberOfEntries}.
     * For any other value, including {@code null}, no reference distribution is produced.
     *
     * @param distribution distribution whose entries define the reference distribution's entries
     * @param str          reference distribution name
     * @return the reference distribution, or an empty option when {@code str} is not {@code "UNIFORM"}
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // raw Scala Map cast comes from the compiled signature
    public Option<Distribution> computeReferenceDistributionOpt(Distribution distribution, String str) {
        // Null-safe comparison: a null name is treated like any other non-"UNIFORM" value.
        if (!"UNIFORM".equals(str)) {
            return Option.empty();
        }
        double perEntryValue = 1.0d / distribution.entries().size();
        Map uniformEntries = (Map) distribution.entries()
                .map(new FairnessMetricsUtils$$anonfun$4(perEntryValue), Map$.MODULE$.canBuildFrom());
        return new Some<Distribution>(new Distribution(uniformEntries));
    }

    /**
     * Computes the distance metrics for {@code distribution}, followed by the results
     * derived from the benefit metrics.
     *
     * <p>Both metric lists are evaluated with {@code DivergenceUtils.computeDistanceMetrics};
     * the benefit-metric output is additionally mapped ({@code $anonfun$5}) and expanded
     * ({@code $anonfun$6}) before being appended.
     *
     * @param distribution                             distribution to evaluate
     * @param option                                   optional reference distribution
     * @param measureDatasetFairnessMetricsCmdLineArgs supplies the metric lists and field names
     * @param str                                      forwarded to {@code computeDistanceMetrics};
     *                                                 defaults to {@code ""} (see
     *                                                 {@link #computeDatasetMetrics$default$4()})
     * @return distance-metric results concatenated with the benefit-metric results
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // raw Scala collection types come from the compiled signatures
    public Seq<FairnessResult> computeDatasetMetrics(Distribution distribution, Option<Distribution> option, MeasureDatasetFairnessMetricsCmdLineArgs measureDatasetFairnessMetricsCmdLineArgs, String str) {
        final MeasureDatasetFairnessMetricsCmdLineArgs args = measureDatasetFairnessMetricsCmdLineArgs;
        return (Seq) DivergenceUtils$.MODULE$.computeDistanceMetrics(
                        args.distanceMetrics(),
                        distribution,
                        option,
                        args.labelField(),
                        str,
                        args.protectedAttributeField())
                .$plus$plus(
                        (Seq) ((Seq) DivergenceUtils$.MODULE$.computeDistanceMetrics(
                                        args.benefitMetrics(),
                                        distribution,
                                        option,
                                        args.labelField(),
                                        str,
                                        args.protectedAttributeField())
                                .map(new FairnessMetricsUtils$$anonfun$5(), Seq$.MODULE$.canBuildFrom()))
                                .flatMap(new FairnessMetricsUtils$$anonfun$6(args), Seq$.MODULE$.canBuildFrom()),
                        Seq$.MODULE$.canBuildFrom());
    }

    /**
     * Default value of the fourth parameter of
     * {@link #computeDatasetMetrics(Distribution, Option, MeasureDatasetFairnessMetricsCmdLineArgs, String)}.
     *
     * @return the empty string
     */
    public String computeDatasetMetrics$default$4() {
        return "";
    }

    /**
     * Projects {@code dataset} onto the columns {@code str}, {@code str2}, {@code str3}
     * and, when {@code str4} equals {@code "RAW"}, replaces {@code str2} with
     * {@code 1 / (1 + exp(-str2))} (the logistic function) under the same column name.
     *
     * <p>If {@code option} holds a value it is applied to the projected data by
     * {@code computeProbabilityDF$1}; otherwise {@code computeProbabilityDF$2} supplies
     * the result from the projected data alone.
     *
     * @param dataset input data
     * @param option  optional value handed to the generated function (callers in this
     *                class pass the threshold option)
     * @param str     first column to keep (callers pass the label field)
     * @param str2    score column (callers pass the score field)
     * @param str3    third column to keep (callers pass the protected-attribute field)
     * @param str4    score type; {@code "RAW"} triggers the logistic transform
     * @return the resulting dataset
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // raw Dataset cast comes from the compiled signature
    public Dataset<Row> computeProbabilityDF(Dataset<Row> dataset, Option<Object> option, String str, String str2, String str3, String str4) {
        final Dataset<Row> projected;
        if (str4.equals("RAW")) {
            Column one = functions$.MODULE$.lit(BoxesRunTime.boxToDouble(1.0d));
            Column onePlusExpNegScore = functions$.MODULE$.lit(BoxesRunTime.boxToDouble(1.0d))
                    .$plus(functions$.MODULE$.exp(functions$.MODULE$.col(str2).unary_$minus()));
            Column probability = one.$div(onePlusExpNegScore).as(str2);
            projected = dataset.select(Predef$.MODULE$.wrapRefArray(new Column[]{
                    functions$.MODULE$.col(str),
                    probability,
                    functions$.MODULE$.col(str3)}));
        } else {
            projected = dataset.select(str, Predef$.MODULE$.wrapRefArray(new String[]{str2, str3}));
        }
        return (Dataset) option
                .map(new FairnessMetricsUtils$$anonfun$computeProbabilityDF$1(str, str2, str3, projected))
                .getOrElse(new FairnessMetricsUtils$$anonfun$computeProbabilityDF$2(projected));
    }

    /**
     * Computes all model fairness metrics: the dataset-style metrics over the
     * generalized prediction-count distribution, followed by the model performance
     * metrics over a sample of the data.
     *
     * @param dataset                                scored input data
     * @param option                                 optional reference distribution
     * @param measureModelFairnessMetricsCmdLineArgs supplies field names, metric lists
     *                                               and sampling settings
     * @return distribution-based results concatenated with the performance results
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // raw Scala collection types come from the compiled signatures
    public Seq<FairnessResult> computeModelMetrics(Dataset<Row> dataset, Option<Distribution> option, MeasureModelFairnessMetricsCmdLineArgs measureModelFairnessMetricsCmdLineArgs) {
        final MeasureModelFairnessMetricsCmdLineArgs args = measureModelFairnessMetricsCmdLineArgs;

        Dataset<Row> probabilityDF = computeProbabilityDF(
                dataset,
                args.thresholdOpt(),
                args.labelField(),
                args.scoreField(),
                args.protectedAttributeField(),
                args.scoreType());

        Distribution predictionDistribution = DivergenceUtils$.MODULE$.computeGeneralizedPredictionCountDistribution(
                probabilityDF,
                args.labelField(),
                args.scoreField(),
                args.protectedAttributeField());

        // Re-express the model arguments as dataset arguments so computeDatasetMetrics can be
        // reused; constructor parameters 3, 4 and 8 take their declared default values.
        MeasureDatasetFairnessMetricsCmdLineArgs datasetArgs = new MeasureDatasetFairnessMetricsCmdLineArgs(
                args.datasetPath(),
                args.protectedDatasetPath(),
                MeasureDatasetFairnessMetricsCmdLineArgs$.MODULE$.apply$default$3(),
                MeasureDatasetFairnessMetricsCmdLineArgs$.MODULE$.apply$default$4(),
                args.uidField(),
                args.labelField(),
                args.protectedAttributeField(),
                MeasureDatasetFairnessMetricsCmdLineArgs$.MODULE$.apply$default$8(),
                args.outputPath(),
                args.referenceDistribution(),
                args.distanceMetrics(),
                args.overallMetrics(),
                args.distanceBenefitMetrics());

        Seq<FairnessResult> distributionResults =
                computeDatasetMetrics(predictionDistribution, option, datasetArgs, args.scoreField());

        final Dataset<Row> sampledDF;
        if (args.groupIdField().isEmpty()) {
            sampledDF = StatsUtils$.MODULE$.sampleDataFrame(
                    probabilityDF,
                    args.labelField(),
                    args.approxRows(),
                    args.labelZeroPercentage(),
                    args.seed());
        } else {
            // NOTE(review): group-wise sampling reads the original dataset rather than
            // probabilityDF, presumably because the projection above does not keep the
            // group-id column - confirm against StatsUtils.sampleDataFrameByGroupId.
            sampledDF = StatsUtils$.MODULE$.sampleDataFrameByGroupId(
                    dataset,
                    args.labelField(),
                    args.scoreField(),
                    args.groupIdField(),
                    args.protectedAttributeField(),
                    args.approxRows(),
                    args.seed());
        }

        Seq<FairnessResult> performanceResults = computeModelPerformanceMetrics(sampledDF, args);
        return (Seq) distributionResults.$plus$plus(performanceResults, Seq$.MODULE$.canBuildFrom());
    }

    /**
     * Converts {@code seq} to a DataFrame, repartitions it to a single partition and
     * writes it to {@code str2}, overwriting existing output.
     *
     * @param sparkSession session used to build the DataFrame
     * @param str          output data format
     * @param map          writer options
     * @param str2         output path
     * @param seq          results to write
     */
    public void writeFairnessResults(SparkSession sparkSession, String str, Map<String, String> map, String str2, Seq<FairnessResult> seq) {
        FairnessResult$.MODULE$.toDF(sparkSession, seq)
                .repartition(1)
                .write()
                .mode(SaveMode.Overwrite)
                .format(str)
                .options(map)
                .save(str2);
    }

    /**
     * Computes the distribution of {@code dataset} over the label and protected-attribute
     * fields, evaluates the dataset metrics on it, and writes the results using the
     * format, options and output path from the arguments.
     *
     * @param dataset                                  input data
     * @param option                                   optional reference distribution
     * @param measureDatasetFairnessMetricsCmdLineArgs supplies field names, metric lists
     *                                                 and output settings
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // raw Scala Set cast comes from the compiled signature
    public void computeAndWriteDatasetMetrics(Dataset<Row> dataset, Option<Distribution> option, MeasureDatasetFairnessMetricsCmdLineArgs measureDatasetFairnessMetricsCmdLineArgs) {
        final MeasureDatasetFairnessMetricsCmdLineArgs args = measureDatasetFairnessMetricsCmdLineArgs;
        writeFairnessResults(
                dataset.sparkSession(),
                args.dataFormat(),
                args.dataOptions(),
                args.outputPath(),
                computeDatasetMetrics(
                        Distribution$.MODULE$.compute(
                                dataset,
                                (Set) Predef$.MODULE$.Set().apply(Predef$.MODULE$.wrapRefArray(
                                        new String[]{args.labelField(), args.protectedAttributeField()}))),
                        option,
                        args,
                        computeDatasetMetrics$default$4()));
    }

    /**
     * Evaluates the model metrics for {@code dataset} and writes the results using the
     * format, options and output path from the arguments.
     *
     * @param dataset                                scored input data
     * @param option                                 optional reference distribution
     * @param measureModelFairnessMetricsCmdLineArgs supplies field names, metric lists
     *                                               and output settings
     */
    public void computeAndWriteModelMetrics(Dataset<Row> dataset, Option<Distribution> option, MeasureModelFairnessMetricsCmdLineArgs measureModelFairnessMetricsCmdLineArgs) {
        final MeasureModelFairnessMetricsCmdLineArgs args = measureModelFairnessMetricsCmdLineArgs;
        writeFairnessResults(
                dataset.sparkSession(),
                args.dataFormat(),
                args.dataOptions(),
                args.outputPath(),
                computeModelMetrics(dataset, option, args));
    }
}
