package org.apache.flink.ml.stats.fvaluetest;

import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.math3.distribution.FDistribution;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.api.java.tuple.Tuple5;
import org.apache.flink.ml.api.AlgoOperator;
import org.apache.flink.ml.common.broadcast.BroadcastUtils;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/ml/stats/fvaluetest/FValueTest.class */
public class FValueTest implements AlgoOperator<FValueTest>, FValueTestParams<FValueTest> {
    /** Parameter values of this operator, keyed by their parameter definition. */
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/ml/stats/fvaluetest/FValueTest$CalCovarianceOperator.class */
    /**
     * Accumulates, for one partition of (features, label) samples, the per-feature sum of
     * {@code (label - labelMean) * (feature - featureMean)} and scales it by
     * {@code 1 / (count - 1)}.
     *
     * <p>The count and the means come from the summary tuple registered under the broadcast key:
     * {@code f0} is read as the sample count, {@code f1} as the label mean and {@code f3} as the
     * per-feature means.
     */
    public static class CalCovarianceOperator extends RichMapPartitionFunction<Tuple2<Vector, Double>, DenseVector> {
        /** Name of the broadcast variable that holds the summary tuple. */
        private final String broadcastKey;

        private CalCovarianceOperator(String str) {
            broadcastKey = str;
        }

        /**
         * Emits exactly one vector with this partition's scaled contribution for every feature.
         *
         * @param iterable samples of the partition as (features, label) pairs
         * @param collector receives the single result vector
         * @throws IllegalArgumentException if a sample's feature count differs from the size of
         *     the broadcast feature-mean vector
         */
        public void mapPartition(Iterable<Tuple2<Vector, Double>> iterable, Collector<DenseVector> collector) {
            @SuppressWarnings("unchecked") // the broadcast variable is handed back untyped
            Tuple5<Long, Double, Double, DenseVector, DenseVector> summary =
                    (Tuple5<Long, Double, Double, DenseVector, DenseVector>)
                            getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
            DenseVector featureMeans = summary.f3;
            int numFeatures = featureMeans.size();
            DenseVector partial = new DenseVector(numFeatures);
            double[] sums = partial.values;

            for (Tuple2<Vector, Double> sample : iterable) {
                Vector features = sample.f0;
                Preconditions.checkArgument(
                        features.size() == numFeatures,
                        "Input %s features, but FValueTest is expecting %s features.",
                        features.size(),
                        numFeatures);

                double labelDeviation = sample.f1 - summary.f1;
                if (labelDeviation == 0.0d) {
                    // A label equal to the mean contributes nothing; skip the inner loop.
                    continue;
                }
                for (int idx = 0; idx < numFeatures; idx++) {
                    sums[idx] += labelDeviation * (features.get(idx) - featureMeans.get(idx));
                }
            }

            // Scale by the total sample count (not the partition size) so that the partial
            // results of several partitions can simply be added together afterwards.
            BLAS.scal(1.0d / (summary.f0 - 1), partial);
            collector.collect(partial);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/ml/stats/fvaluetest/FValueTest$CalFValueOperator.class */
    /**
     * Converts the covariance vector between every feature and the label into one
     * (featureIndex, pValue, degreesOfFreedom, fValue) tuple per feature.
     *
     * <p>The summary tuple registered under the broadcast key supplies the sample count
     * ({@code f0}) and the two scale factors the covariance is divided by: {@code f2} and
     * {@code f4[i]}, which {@code SummaryAggregator.getResult} fills with the label and
     * per-feature standard deviations.
     */
    public static class CalFValueOperator extends RichMapPartitionFunction<DenseVector, Tuple4<Integer, Double, Long, Double>> {
        /** Name of the broadcast variable that holds the summary tuple. */
        private final String broadcastKey;

        private CalFValueOperator(String str) {
            this.broadcastKey = str;
        }

        /**
         * Reads the first vector of the partition and emits one result tuple per feature.
         *
         * <p>An empty partition emits nothing. Previously this case dereferenced a null vector.
         *
         * @param iterable the partition; only its first element is consumed
         * @param collector receives one tuple per feature index
         * @throws IllegalArgumentException if the vector length differs from the size of the
         *     broadcast {@code f4} vector
         */
        public void mapPartition(Iterable<DenseVector> iterable, Collector<Tuple4<Integer, Double, Long, Double>> collector) {
            @SuppressWarnings("unchecked") // the broadcast variable is handed back untyped
            Tuple5<Long, Double, Double, DenseVector, DenseVector> summary =
                    (Tuple5<Long, Double, Double, DenseVector, DenseVector>)
                            getRuntimeContext().getBroadcastVariable(this.broadcastKey).get(0);
            int numFeatures = summary.f4.size();

            // Traverse the iterable once only: a second iterator() call is not guaranteed to
            // restart from the first element for every Iterable implementation.
            DenseVector covariance = null;
            for (DenseVector first : iterable) {
                covariance = first;
                break;
            }
            if (covariance == null) {
                return;
            }
            Preconditions.checkArgument(
                    covariance.size() == numFeatures,
                    "Input %s features, but FValueTest is expecting %s features.",
                    covariance.size(),
                    numFeatures);

            long degreesOfFreedom = summary.f0 - 2;
            FDistribution fDistribution = new FDistribution(1.0d, degreesOfFreedom);
            for (int i = 0; i < numFeatures; i++) {
                double correlation = covariance.get(i) / (summary.f2 * summary.f4.get(i));
                double squared = correlation * correlation;
                double fValue = (squared / (1.0d - squared)) * degreesOfFreedom;
                double pValue = 1.0d - fDistribution.cumulativeProbability(fValue);
                collector.collect(Tuple4.of(i, pValue, degreesOfFreedom, fValue));
            }
        }
    }

    /* loaded from: input_file:org/apache/flink/ml/stats/fvaluetest/FValueTest$SummaryAggregator.class */
    /**
     * Aggregates (features, label) samples into the summary statistics used by the test.
     *
     * <p>Accumulator layout: {@code f0} sample count, {@code f1} sum of labels, {@code f2} sum of
     * squared labels, {@code f3} per-feature sums, {@code f4} per-feature sums of squares.
     *
     * <p>Result layout: {@code f0} sample count, {@code f1} label mean, {@code f2} label standard
     * deviation, {@code f3} per-feature means, {@code f4} per-feature standard deviations. The
     * standard deviations use the {@code count - 1} divisor.
     */
    private static class SummaryAggregator implements AggregateFunction<Tuple2<Vector, Double>, Tuple5<Long, Double, Double, DenseVector, DenseVector>, Tuple5<Long, Double, Double, DenseVector, DenseVector>> {
        private SummaryAggregator() {
        }

        /**
         * Creates an empty accumulator. The two vectors start with length zero and are sized by
         * the first sample passed to {@code add}.
         *
         * <p>This method carries the name used by the other lifecycle methods of this class
         * ({@code add}, {@code getResult}, {@code merge}); the decompiled listing had mangled it
         * to {@code m136createAccumulator}.
         */
        public Tuple5<Long, Double, Double, DenseVector, DenseVector> createAccumulator() {
            return Tuple5.of(0L, 0.0d, 0.0d, new DenseVector(new double[0]), new DenseVector(new double[0]));
        }

        /**
         * Folds one sample into the accumulator, which is modified in place and returned.
         *
         * @param tuple2 the sample as a (features, label) pair
         * @param tuple5 the accumulator to update
         * @return the same accumulator instance
         */
        public Tuple5<Long, Double, Double, DenseVector, DenseVector> add(Tuple2<Vector, Double> tuple2, Tuple5<Long, Double, Double, DenseVector, DenseVector> tuple5) {
            Vector features = tuple2.f0;
            double label = tuple2.f1;
            if (tuple5.f0 == 0) {
                // First sample: size the per-feature accumulators to its dimension.
                tuple5.f3 = new DenseVector(features.size());
                tuple5.f4 = new DenseVector(features.size());
            }
            tuple5.f0 = tuple5.f0 + 1;
            tuple5.f1 = tuple5.f1 + label;
            tuple5.f2 = tuple5.f2 + (label * label);
            BLAS.axpy(1.0d, features, tuple5.f3);
            double[] squareSums = tuple5.f4.values;
            for (int i = 0; i < features.size(); i++) {
                double value = features.get(i);
                squareSums[i] += value * value;
            }
            return tuple5;
        }

        /**
         * Turns the accumulated sums into means and standard deviations.
         *
         * <p>With a single sample the divisor {@code count - 1} is zero, so the standard
         * deviations come out non-finite.
         *
         * @param tuple5 the accumulator; it is not modified
         * @return a new tuple in the result layout described on the class
         * @throws IllegalStateException if no sample was accumulated
         */
        public Tuple5<Long, Double, Double, DenseVector, DenseVector> getResult(Tuple5<Long, Double, Double, DenseVector, DenseVector> tuple5) {
            long count = tuple5.f0;
            Preconditions.checkState(count > 0, "The training set is empty.");
            int numFeatures = tuple5.f3.size();
            double labelMean = tuple5.f1 / count;
            Tuple5<Long, Double, Double, DenseVector, DenseVector> result =
                    Tuple5.of(
                            count,
                            labelMean,
                            sampleStd(tuple5.f2, labelMean, count),
                            new DenseVector(numFeatures),
                            new DenseVector(numFeatures));
            for (int i = 0; i < tuple5.f3.size(); i++) {
                double featureMean = tuple5.f3.get(i) / count;
                result.f3.values[i] = featureMean;
                result.f4.values[i] = sampleStd(tuple5.f4.get(i), featureMean, count);
            }
            return result;
        }

        /**
         * Combines two accumulators. If either is empty the other is returned unchanged;
         * otherwise the first is added into the second, which is returned.
         *
         * @param tuple5 accumulator whose sums are added
         * @param tuple52 accumulator that receives the sums
         * @return the accumulator holding the combined sums
         */
        public Tuple5<Long, Double, Double, DenseVector, DenseVector> merge(Tuple5<Long, Double, Double, DenseVector, DenseVector> tuple5, Tuple5<Long, Double, Double, DenseVector, DenseVector> tuple52) {
            if (tuple5.f0 == 0) {
                return tuple52;
            }
            if (tuple52.f0 == 0) {
                return tuple5;
            }
            tuple52.f0 = tuple52.f0 + tuple5.f0;
            tuple52.f1 = tuple52.f1 + tuple5.f1;
            tuple52.f2 = tuple52.f2 + tuple5.f2;
            BLAS.axpy(1.0d, tuple5.f3, tuple52.f3);
            BLAS.axpy(1.0d, tuple5.f4, tuple52.f4);
            return tuple52;
        }

        /** Standard deviation with the {@code count - 1} divisor, from a sum of squares and a mean. */
        private static double sampleStd(double sumOfSquares, double mean, long count) {
            return Math.sqrt((((sumOfSquares / count) - (mean * mean)) * count) / (count - 1));
        }
    }

    /**
     * Creates the operator and seeds its parameter map through
     * {@code ParamUtils.initializeMapWithDefaultValues}.
     */
    public FValueTest() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    /**
     * Builds the job that computes the test result for every feature of the input table.
     *
     * <p>Stages: (1) rows are mapped to (features, label) pairs; (2) the pairs are aggregated into
     * one summary tuple; (3) with the summary broadcast, every partition computes its covariance
     * contribution; (4) the contributions are summed; (5) with the summary broadcast again, the
     * summed vector is turned into per-feature result tuples; (6) {@code convertToTable} shapes
     * the output according to {@code getFlatten()}.
     *
     * @param tableArr exactly one input table containing the features and label columns
     * @return a one-element array holding the result table
     * @throws IllegalArgumentException if the number of input tables is not one
     */
    public Table[] transform(Table... tableArr) {
        Preconditions.checkArgument(tableArr.length == 1);
        final String broadcastKey = "broadcastSummaryKey";
        String featuresCol = getFeaturesCol();
        String labelCol = getLabelCol();
        StreamTableEnvironment tableEnv = (StreamTableEnvironment) ((TableImpl) tableArr[0]).getTableEnvironment();

        // Explicit element types on every stage keep the lambdas below type-checkable; with raw
        // streams the reduce arguments would no longer be known to be vectors.
        SingleOutputStreamOperator<Tuple2<Vector, Double>> samples =
                tableEnv.toDataStream(tableArr[0])
                        .map(row -> {
                            Number label = (Number) row.getField(labelCol);
                            Preconditions.checkNotNull(label, "Input data must contain label value.");
                            return new Tuple2<Vector, Double>((Vector) row.getField(featuresCol), label.doubleValue());
                        })
                        .returns(Types.TUPLE(VectorTypeInfo.INSTANCE, Types.DOUBLE));

        DataStream<Tuple5<Long, Double, Double, DenseVector, DenseVector>> summaries =
                DataStreamUtils.aggregate(samples, new SummaryAggregator());

        DataStream<DenseVector> partitionCovariances =
                BroadcastUtils.withBroadcastStream(
                        Collections.singletonList(samples),
                        Collections.singletonMap(broadcastKey, summaries),
                        inputs -> {
                            @SuppressWarnings("unchecked") // element 0 is the samples stream passed in above
                            DataStream<Tuple2<Vector, Double>> partition =
                                    (DataStream<Tuple2<Vector, Double>>) inputs.get(0);
                            return DataStreamUtils.mapPartition(partition, new CalCovarianceOperator(broadcastKey));
                        });

        // Sum the per-partition contributions into the second argument, which is returned.
        DataStream<DenseVector> covariances =
                DataStreamUtils.reduce(
                        partitionCovariances,
                        (partial, total) -> {
                            BLAS.axpy(1.0d, partial, total);
                            return total;
                        });

        DataStream<Tuple4<Integer, Double, Long, Double>> testResults =
                BroadcastUtils.withBroadcastStream(
                        Collections.singletonList(covariances),
                        Collections.singletonMap(broadcastKey, summaries),
                        inputs -> {
                            @SuppressWarnings("unchecked") // element 0 is the covariance stream passed in above
                            DataStream<DenseVector> summed = (DataStream<DenseVector>) inputs.get(0);
                            return DataStreamUtils.mapPartition(summed, new CalFValueOperator(broadcastKey));
                        });

        return new Table[] {convertToTable(tableEnv, testResults, getFlatten())};
    }

    /**
     * Shapes the per-feature result tuples into the output table.
     *
     * <p>When {@code z} is true the tuples are exposed as they are, one row per feature, with the
     * columns {@code featureIndex}, {@code pValue}, {@code degreeOfFreedom} and {@code fValue}.
     * Otherwise the tuples of each partition are packed, in arrival order, into a single row with
     * the columns {@code pValues}, {@code degreesOfFreedom} and {@code fValues}.
     *
     * @param streamTableEnvironment environment used to create the table
     * @param dataStream per-feature (index, pValue, degreesOfFreedom, fValue) tuples
     * @param z whether to keep one row per feature
     * @return the result table
     */
    private Table convertToTable(StreamTableEnvironment streamTableEnvironment, DataStream<Tuple4<Integer, Double, Long, Double>> dataStream, boolean z) {
        if (z) {
            return streamTableEnvironment
                    .fromDataStream(dataStream)
                    .as("featureIndex", "pValue", "degreeOfFreedom", "fValue");
        }

        // NOTE(review): kept as an anonymous class rather than a lambda — presumably so the
        // generic output type stays visible to the framework; confirm before changing.
        DataStream<Tuple3<DenseVector, long[], DenseVector>> packed =
                DataStreamUtils.mapPartition(
                        dataStream,
                        new MapPartitionFunction<Tuple4<Integer, Double, Long, Double>, Tuple3<DenseVector, long[], DenseVector>>() {
                            public void mapPartition(Iterable<Tuple4<Integer, Double, Long, Double>> iterable, Collector<Tuple3<DenseVector, long[], DenseVector>> collector) {
                                @SuppressWarnings("unchecked") // IteratorUtils.toList returns an untyped list
                                List<Tuple4<Integer, Double, Long, Double>> rows = IteratorUtils.toList(iterable.iterator());
                                int count = rows.size();
                                DenseVector pValues = new DenseVector(count);
                                long[] degreesOfFreedom = new long[count];
                                DenseVector fValues = new DenseVector(count);

                                // Positions follow arrival order; the feature index in f0 is not consulted.
                                int position = 0;
                                for (Tuple4<Integer, Double, Long, Double> row : rows) {
                                    pValues.set(position, row.f1);
                                    degreesOfFreedom[position] = row.f2;
                                    fValues.set(position, row.f3);
                                    position++;
                                }
                                collector.collect(Tuple3.of(pValues, degreesOfFreedom, fValues));
                            }
                        });
        return streamTableEnvironment
                .fromDataStream(packed)
                .as("pValues", "degreesOfFreedom", "fValues");
    }

    /**
     * Persists this operator by delegating to {@code ReadWriteUtils.saveMetadata}.
     *
     * @param str the target path
     * @throws IOException declared for failures while writing
     */
    public void save(String str) throws IOException {
        ReadWriteUtils.saveMetadata(this, str);
    }

    /**
     * Restores an operator by delegating to {@code ReadWriteUtils.loadStageParam}.
     *
     * @param streamTableEnvironment not used by this method
     * @param str the path to load from
     * @return the restored operator
     * @throws IOException declared for failures while reading
     */
    public static FValueTest load(StreamTableEnvironment streamTableEnvironment, String str) throws IOException {
        return ReadWriteUtils.loadStageParam(str);
    }

    /**
     * Returns the parameter map of this operator.
     *
     * @return the internal map itself, not a copy
     */
    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    /**
     * Deserialization hook for the two lambdas that {@code transform} passes to {@code map} and
     * {@code reduce}: it inspects the serialized form and, when every recorded property matches
     * one of the two known lambdas, rebuilds that lambda.
     *
     * <p>NOTE(review): the declaration is marked synthetic, so this is presumably the decompiled
     * form of a compiler-generated method rather than hand-written code — confirm before
     * maintaining or recompiling it. In particular, {@code boolean z = -1}, the {@code switch}
     * over a boolean, and lambdas returned as plain {@code Object} look like decompiler output
     * rather than valid source.
     *
     * @param serializedLambda the serialized form of the lambda to restore
     * @return the rebuilt lambda
     * @throws IllegalArgumentException if the serialized form matches neither known lambda
     */
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        // Selector for the second switch; stays at its initial value when no name matches.
        boolean z = -1;
        // First pass: pick a candidate by the hash code of the implementation method name, then
        // confirm it with equals.
        switch (implMethodName.hashCode()) {
            case -2075300007:
                if (implMethodName.equals("lambda$transform$5f7d27a5$1")) {
                    z = false;
                    break;
                }
                break;
            case 395302457:
                if (implMethodName.equals("lambda$transform$f675bd29$1")) {
                    z = true;
                    break;
                }
                break;
        }
        // Second pass: verify the remaining recorded properties before rebuilding the lambda.
        // The method kind 6 is presumably REF_invokeStatic — verify against MethodHandleInfo.
        switch (z) {
            case false:
                // Row-to-(features, label) mapper.
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/common/functions/MapFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("map") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/flink/ml/stats/fvaluetest/FValueTest") && serializedLambda.getImplMethodSignature().equals("(Ljava/lang/String;Ljava/lang/String;Lorg/apache/flink/types/Row;)Lorg/apache/flink/api/java/tuple/Tuple2;")) {
                    // Captured argument 0 is used as the label field name, argument 1 as the
                    // features field name.
                    String str = (String) serializedLambda.getCapturedArg(0);
                    String str2 = (String) serializedLambda.getCapturedArg(1);
                    return row -> {
                        Number number = (Number) row.getField(str);
                        Preconditions.checkNotNull(number, "Input data must contain label value.");
                        return new Tuple2((Vector) row.getField(str2), Double.valueOf(number.doubleValue()));
                    };
                }
                break;
            case true:
                // Vector-sum reducer; it captures no arguments.
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/common/functions/ReduceFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("reduce") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/flink/ml/stats/fvaluetest/FValueTest") && serializedLambda.getImplMethodSignature().equals("(Lorg/apache/flink/ml/linalg/DenseVector;Lorg/apache/flink/ml/linalg/DenseVector;)Lorg/apache/flink/ml/linalg/DenseVector;")) {
                    return (denseVector, denseVector2) -> {
                        BLAS.axpy(1.0d, denseVector, denseVector2);
                        return denseVector2;
                    };
                }
                break;
        }
        // Nothing matched: refuse to fabricate a lambda.
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
