package org.apache.flink.ml.feature.variancethresholdselector;

import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.IntStream;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.class */
/**
 * An {@link Estimator} that decides which features of a vector column to keep. A feature is
 * retained only when its population variance over the training data (sum of squared deviations
 * divided by the number of rows) is strictly greater than {@code varianceThreshold}; the indices of
 * retained features are handed to the produced {@link VarianceThresholdSelectorModel}.
 */
public class VarianceThresholdSelector
        implements Estimator<VarianceThresholdSelector, VarianceThresholdSelectorModel>,
                VarianceThresholdSelectorParams<VarianceThresholdSelector> {
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /**
     * Aggregates input vectors into per-feature running statistics and finally selects the indices
     * of features whose variance exceeds the threshold.
     *
     * <p>The accumulator is {@code (count, mean, m2)}: {@code mean} holds the running mean of each
     * feature and {@code m2} the running sum of squared deviations from that mean. Updating these
     * incrementally (Welford) and combining partial results with Chan et al.'s pairwise formula
     * avoids the catastrophic cancellation of {@code E[x^2] - E[x]^2}, which can report a tiny
     * positive variance for a constant non-zero feature and therefore fail to drop it when the
     * threshold is 0.
     */
    public static class VarianceThresholdSelectorAggregator
            implements AggregateFunction<
                    Vector,
                    Tuple3<Long, DenseVector, DenseVector>,
                    VarianceThresholdSelectorModelData> {
        private final double varianceThreshold;

        /**
         * @param varianceThreshold features whose variance is not strictly greater than this value
         *     are excluded from the result
         */
        public VarianceThresholdSelectorAggregator(double varianceThreshold) {
            this.varianceThreshold = varianceThreshold;
        }

        @Override
        public Tuple3<Long, DenseVector, DenseVector> createAccumulator() {
            // Placeholder vectors; add() sizes them from the first vector it sees.
            return Tuple3.of(0L, new DenseVector(new double[0]), new DenseVector(new double[0]));
        }

        /** Alias kept for callers that reference the decompiler-generated method name. */
        public Tuple3<Long, DenseVector, DenseVector> m119createAccumulator() {
            return createAccumulator();
        }

        /**
         * Folds one vector into the running per-feature mean and sum of squared deviations.
         *
         * @throws IllegalArgumentException if the vector's size differs from earlier vectors
         */
        @Override
        public Tuple3<Long, DenseVector, DenseVector> add(
                Vector vector, Tuple3<Long, DenseVector, DenseVector> acc) {
            int size = vector.size();
            if (acc.f0 == 0) {
                acc.f1 = new DenseVector(size);
                acc.f2 = new DenseVector(size);
            } else {
                Preconditions.checkArgument(
                        size == acc.f1.size(),
                        "Vector size mismatched: expected %s but got %s.",
                        acc.f1.size(),
                        size);
            }
            long count = acc.f0 + 1;
            double[] mean = acc.f1.values;
            double[] m2 = acc.f2.values;
            for (int i = 0; i < size; i++) {
                double value = vector.get(i);
                double delta = value - mean[i];
                mean[i] += delta / count;
                // Welford's update combines the deviation from the old mean (delta) with the
                // deviation from the updated mean.
                m2[i] += delta * (value - mean[i]);
            }
            acc.f0 = count;
            return acc;
        }

        /**
         * Selects the indices of features whose population variance exceeds the threshold.
         *
         * @throws IllegalStateException if no vector was aggregated
         */
        @Override
        public VarianceThresholdSelectorModelData getResult(
                Tuple3<Long, DenseVector, DenseVector> acc) {
            long count = acc.f0;
            Preconditions.checkState(count > 0, "The training set is empty.");
            int numFeatures = acc.f1.size();
            double[] m2 = acc.f2.values;
            int[] selected =
                    IntStream.range(0, numFeatures)
                            .filter(i -> m2[i] / count > varianceThreshold)
                            .toArray();
            return new VarianceThresholdSelectorModelData(numFeatures, selected);
        }

        /**
         * Combines two partial accumulators into {@code b} using Chan et al.'s pairwise formula.
         * An empty accumulator is simply discarded in favour of the other one.
         *
         * @throws IllegalArgumentException if the two accumulators have different dimensions
         */
        @Override
        public Tuple3<Long, DenseVector, DenseVector> merge(
                Tuple3<Long, DenseVector, DenseVector> a,
                Tuple3<Long, DenseVector, DenseVector> b) {
            if (a.f0 == 0) {
                return b;
            }
            if (b.f0 == 0) {
                return a;
            }
            Preconditions.checkArgument(
                    a.f1.size() == b.f1.size(),
                    "Vector size mismatched: expected %s but got %s.",
                    b.f1.size(),
                    a.f1.size());
            long countA = a.f0;
            long countB = b.f0;
            long total = countA + countB;
            // Cast before multiplying so that countA * countB cannot overflow a long.
            double weightA = (double) countA / total;
            double crossWeight = (double) countA * countB / total;
            double[] meanA = a.f1.values;
            double[] m2A = a.f2.values;
            double[] meanB = b.f1.values;
            double[] m2B = b.f2.values;
            for (int i = 0; i < meanB.length; i++) {
                double delta = meanA[i] - meanB[i];
                meanB[i] += delta * weightA;
                m2B[i] += m2A[i] + delta * delta * crossWeight;
            }
            b.f0 = total;
            return b;
        }
    }

    public VarianceThresholdSelector() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    /**
     * Computes per-feature variances of the {@code inputCol} vectors of the single input table and
     * builds a model that keeps only the features whose variance exceeds the threshold.
     *
     * @param inputs exactly one table containing the vector column named by {@code inputCol}
     * @return a model carrying the selected feature indices and this estimator's parameters
     * @throws IllegalArgumentException if {@code inputs} does not contain exactly one table
     */
    @Override
    public VarianceThresholdSelectorModel fit(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);
        final String inputCol = getInputCol();
        StreamTableEnvironment tEnv = ((TableImpl) inputs[0]).getTableEnvironment();

        Table modelData =
                tEnv.fromDataStream(
                        DataStreamUtils.aggregate(
                                tEnv.toDataStream(inputs[0])
                                        .map(
                                                row -> (Vector) row.getField(inputCol),
                                                VectorTypeInfo.INSTANCE),
                                new VarianceThresholdSelectorAggregator(getVarianceThreshold()),
                                Types.TUPLE(
                                        Types.LONG,
                                        DenseVectorTypeInfo.INSTANCE,
                                        DenseVectorTypeInfo.INSTANCE),
                                TypeInformation.of(VarianceThresholdSelectorModelData.class)));

        // NOTE(review): m120setModelData is the decompiler-generated name of the model's
        // setModelData; keep in sync with VarianceThresholdSelectorModel.
        VarianceThresholdSelectorModel model =
                new VarianceThresholdSelectorModel().m120setModelData(modelData);
        ParamUtils.updateExistingParams(model, getParamMap());
        return model;
    }

    /** Alias kept for callers that reference the decompiler-generated method name. */
    public VarianceThresholdSelectorModel m118fit(Table... inputs) {
        return fit(inputs);
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    /** Restores a selector whose parameters were written by {@link #save(String)}. */
    public static VarianceThresholdSelector load(StreamTableEnvironment tEnv, String path)
            throws IOException {
        return ReadWriteUtils.loadStageParam(path);
    }
}
