package org.apache.flink.ml.feature.polynomialexpansion;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.math3.util.ArithmeticUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.Transformer;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/ml/feature/polynomialexpansion/PolynomialExpansion.class */
public class PolynomialExpansion implements Transformer<PolynomialExpansion>, PolynomialExpansionParams<PolynomialExpansion> {
    private final Map<Param<?>, Object> paramMap = new HashMap();

    /* loaded from: input_file:org/apache/flink/ml/feature/polynomialexpansion/PolynomialExpansion$PolynomialExpansionFunction.class */
    private static class PolynomialExpansionFunction implements MapFunction<Row, Row> {
        private final int degree;
        private final String inputCol;

        public PolynomialExpansionFunction(int i, String str) {
            this.degree = i;
            this.inputCol = str;
        }

        public Row map(Row row) throws Exception {
            DenseVector sparseVector;
            DenseVector denseVector = (Vector) row.getFieldAs(this.inputCol);
            if (denseVector == null) {
                throw new IllegalArgumentException("The vector must not be null.");
            }
            if (denseVector instanceof DenseVector) {
                int size = denseVector.size();
                double[] dArr = new double[getResultVectorSize(size, this.degree) - 1];
                expandDenseVector(denseVector.values, size - 1, this.degree, 1.0d, dArr, -1);
                sparseVector = new DenseVector(dArr);
            } else {
                if (!(denseVector instanceof SparseVector)) {
                    throw new UnsupportedOperationException("Only supports DenseVector or SparseVector.");
                }
                SparseVector sparseVector2 = (SparseVector) denseVector;
                int[] iArr = sparseVector2.indices;
                double[] dArr2 = sparseVector2.values;
                int size2 = sparseVector2.size();
                int length = sparseVector2.values.length;
                int resultVectorSize = getResultVectorSize(length, this.degree);
                Tuple2 of = Tuple2.of(0, new int[resultVectorSize - 1]);
                Tuple2 of2 = Tuple2.of(0, new double[resultVectorSize - 1]);
                expandSparseVector(iArr, dArr2, length - 1, size2 - 1, this.degree, 1.0d, of, of2, -1);
                sparseVector = new SparseVector(getResultVectorSize(size2, this.degree) - 1, (int[]) of.f1, (double[]) of2.f1);
            }
            return Row.join(row, new Row[]{Row.of(new Object[]{sparseVector})});
        }

        private static int getResultVectorSize(int i, int i2) {
            if (i == 0) {
                return 1;
            }
            if (i == 1 || i2 == 1) {
                return i + i2;
            }
            if (i2 > i) {
                return getResultVectorSize(i2, i);
            }
            long j = 1;
            int i3 = i + 1;
            if (i + i2 < 61) {
                for (int i4 = 1; i4 <= i2; i4++) {
                    j = (j * i3) / i4;
                    i3++;
                }
            } else {
                for (int i5 = 1; i5 <= i2; i5++) {
                    int gcd = ArithmeticUtils.gcd(i3, i5);
                    j = ArithmeticUtils.mulAndCheck(j / (i5 / gcd), i3 / gcd);
                    i3++;
                }
            }
            if (j > 2147483647L) {
                throw new RuntimeException("The expended polynomial size is too large.");
            }
            return (int) j;
        }

        private static int expandDenseVector(double[] dArr, int i, int i2, double d, double[] dArr2, int i3) {
            if (!Double.valueOf(d).equals(Double.valueOf(0.0d))) {
                if (i2 != 0 && i >= 0) {
                    double d2 = dArr[i];
                    int i4 = i - 1;
                    int i5 = 0;
                    int i6 = i3;
                    for (double d3 = d; i5 <= i2 && Math.abs(d3) > 0.0d; d3 *= d2) {
                        i6 = expandDenseVector(dArr, i4, i2 - i5, d3, dArr2, i6);
                        i5++;
                    }
                } else if (i3 >= 0) {
                    dArr2[i3] = d;
                }
            }
            return i3 + getResultVectorSize(i + 1, i2);
        }

        private static int expandSparseVector(int[] iArr, double[] dArr, int i, int i2, int i3, double d, Tuple2<Integer, int[]> tuple2, Tuple2<Integer, double[]> tuple22, int i4) {
            if (!Double.valueOf(d).equals(Double.valueOf(0.0d))) {
                if (i3 != 0 && i >= 0) {
                    double d2 = dArr[i];
                    int i5 = i - 1;
                    int i6 = iArr[i] - 1;
                    int i7 = i4;
                    int i8 = 0;
                    for (double d3 = d; i8 <= i3 && Math.abs(d3) > 0.0d; d3 *= d2) {
                        i7 = expandSparseVector(iArr, dArr, i5, i6, i3 - i8, d3, tuple2, tuple22, i7);
                        i8++;
                    }
                } else if (i4 >= 0) {
                    ((int[]) tuple2.f1)[((Integer) tuple2.f0).intValue()] = i4;
                    ((double[]) tuple22.f1)[((Integer) tuple22.f0).intValue()] = d;
                    tuple2.f0 = Integer.valueOf(((Integer) tuple2.f0).intValue() + 1);
                    tuple22.f0 = Integer.valueOf(((Integer) tuple22.f0).intValue() + 1);
                }
            }
            return i4 + getResultVectorSize(i2 + 1, i3);
        }
    }

    public PolynomialExpansion() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    public Table[] transform(Table... tableArr) {
        Preconditions.checkArgument(tableArr.length == 1);
        StreamTableEnvironment tableEnvironment = ((TableImpl) tableArr[0]).getTableEnvironment();
        RowTypeInfo rowTypeInfo = TableUtils.getRowTypeInfo(tableArr[0].getResolvedSchema());
        return new Table[]{tableEnvironment.fromDataStream(tableEnvironment.toDataStream(tableArr[0]).map(new PolynomialExpansionFunction(getDegree(), getInputCol()), new RowTypeInfo((TypeInformation[]) ArrayUtils.addAll(rowTypeInfo.getFieldTypes(), new TypeInformation[]{VectorTypeInfo.INSTANCE}), (String[]) ArrayUtils.addAll(rowTypeInfo.getFieldNames(), new String[]{getOutputCol()}))))};
    }

    public void save(String str) throws IOException {
        ReadWriteUtils.saveMetadata(this, str);
    }

    public static PolynomialExpansion load(StreamTableEnvironment streamTableEnvironment, String str) throws IOException {
        return ReadWriteUtils.loadStageParam(str);
    }

    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }
}
