package org.apache.flink.ml.feature.featurehasher;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.Transformer;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.shaded.guava30.com.google.common.hash.Hashing;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.table.catalog.ResolvedSchema;
import org.apache.flink.table.types.DataType;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/ml/feature/featurehasher/FeatureHasher.class */
/**
 * Transformer that projects a set of input columns into a single sparse feature vector using the
 * hashing trick.
 *
 * <p>Each non-null numeric input column contributes its value at the index obtained by hashing the
 * column name. Each non-null categorical input column contributes {@code 1.0} at the index obtained
 * by hashing the string {@code "<column>=<value>"}. Contributions that land on the same index are
 * summed. The resulting vector is appended to the input row as the output column.
 */
public class FeatureHasher implements Transformer<FeatureHasher>, FeatureHasherParams<FeatureHasher> {
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    // Fully qualified because the nested mapper class below is also named HashFunction.
    private static final org.apache.flink.shaded.guava30.com.google.common.hash.HashFunction HASH = Hashing.murmur3_32(0);

    /** Row mapper that appends the hashed sparse feature vector to every input row. */
    private static class HashFunction implements MapFunction<Row, Row> {
        private final String[] categoricalCols;
        private final int numFeatures;
        private final String[] numericCols;

        /**
         * @param inputCols all columns that take part in hashing
         * @param categoricalCols the subset of {@code inputCols} handled as categorical
         * @param numFeatures size of the produced vector (the hash space)
         */
        public HashFunction(String[] inputCols, String[] categoricalCols, int numFeatures) {
            this.categoricalCols = categoricalCols;
            this.numFeatures = numFeatures;
            // Every input column that is not categorical is read as a number in map().
            this.numericCols = ArrayUtils.removeElements(inputCols, categoricalCols);
        }

        /**
         * Hashes the configured columns of {@code row} and returns the row with the resulting
         * sparse vector appended as one extra field.
         *
         * @param row input row, accessed by field name
         * @return {@code row} joined with a one-field row holding the feature vector
         */
        public Row map(Row row) {
            // A TreeMap iterates in ascending key order, so the index array built below is sorted.
            TreeMap<Integer, Double> features = new TreeMap<>();
            for (String col : this.numericCols) {
                Object value = row.getField(col);
                if (value != null) {
                    FeatureHasher.updateMap(col, ((Number) value).doubleValue(), features, this.numFeatures);
                }
            }
            for (String col : this.categoricalCols) {
                Object value = row.getField(col);
                if (value != null) {
                    FeatureHasher.updateMap(col + "=" + value, 1.0d, features, this.numFeatures);
                }
            }

            int size = features.size();
            int[] indices = new int[size];
            double[] values = new double[size];
            int pos = 0;
            for (Map.Entry<Integer, Double> entry : features.entrySet()) {
                indices[pos] = entry.getKey();
                values[pos] = entry.getValue();
                pos++;
            }
            return Row.join(row, Row.of(new SparseVector(this.numFeatures, indices, values)));
        }
    }

    /** Creates a feature hasher with every parameter initialized to its default value. */
    public FeatureHasher() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    /**
     * Appends the hashed feature vector column to the given table.
     *
     * @param inputs exactly one input table
     * @return a single-element array holding the input table extended by the output column
     * @throws IllegalArgumentException if {@code inputs} does not contain exactly one table, or if
     *     the configured columns are inconsistent with each other or with the table schema
     */
    public Table[] transform(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);
        // The explicit cast is needed to obtain the stream-bridging environment from the table.
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
        ResolvedSchema schema = inputs[0].getResolvedSchema();
        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(schema);

        // Output row type = input fields followed by the vector column.
        RowTypeInfo outputTypeInfo =
                new RowTypeInfo(
                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE),
                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));

        String[] inputCols = getInputCols();
        String[] categoricalCols = generateCategoricalCols(schema, inputCols, getCategoricalCols());
        HashFunction mapper = new HashFunction(inputCols, categoricalCols, getNumFeatures());

        return new Table[] {
            tEnv.fromDataStream(tEnv.toDataStream(inputs[0]).map(mapper, outputTypeInfo))
        };
    }

    /**
     * Determines which input columns are treated as categorical: those listed explicitly in
     * {@code categoricalCols} plus every input column whose data type equals
     * {@code DataTypes.BOOLEAN()} or {@code DataTypes.STRING()}.
     *
     * @param schema schema of the input table
     * @param inputCols columns to hash; when {@code null}, {@code categoricalCols} is returned as is
     * @param categoricalCols explicitly configured categorical columns
     * @return the categorical columns, in {@code inputCols} order
     * @throws IllegalArgumentException if {@code categoricalCols} is not a subset of
     *     {@code inputCols}, or if an input column is missing from {@code schema}
     */
    private String[] generateCategoricalCols(
            ResolvedSchema schema, String[] inputCols, String[] categoricalCols) {
        if (null == inputCols) {
            return categoricalCols;
        }
        List<String> categoricalList = Arrays.asList(categoricalCols);
        List<String> inputList = Arrays.asList(inputCols);
        if (categoricalCols.length > 0 && !inputList.containsAll(categoricalList)) {
            throw new IllegalArgumentException("CategoricalCols must be included in inputCols!");
        }

        List<DataType> schemaTypes = schema.getColumnDataTypes();
        List<String> schemaNames = schema.getColumnNames();

        // Resolve one data type per input column. A column that is missing from the schema must be
        // rejected here: silently skipping it would shift the types of all following columns and
        // make the index-based lookup below read the wrong type or run past the end of the list.
        List<DataType> inputColTypes = new ArrayList<>(inputCols.length);
        for (String col : inputCols) {
            int idx = schemaNames.indexOf(col);
            if (idx < 0) {
                throw new IllegalArgumentException(
                        "Input column '" + col + "' does not exist in the input table.");
            }
            inputColTypes.add(schemaTypes.get(idx));
        }

        // NOTE(review): the type comparison is exact DataType equality; whether e.g. NOT NULL
        // string columns should also be auto-detected is worth confirming.
        List<String> result = new ArrayList<>();
        for (int i = 0; i < inputCols.length; i++) {
            DataType type = inputColTypes.get(i);
            if (categoricalList.contains(inputCols[i])
                    || DataTypes.BOOLEAN().equals(type)
                    || DataTypes.STRING().equals(type)) {
                result.add(inputCols[i]);
            }
        }
        return result.toArray(new String[0]);
    }

    /**
     * Adds {@code d} to the entry of {@code treeMap} whose key is the hash bucket of {@code str}.
     *
     * @param str text to hash
     * @param d value to accumulate into the bucket
     * @param treeMap accumulator from bucket index to summed value; modified in place
     * @param i number of buckets
     */
    public static void updateMap(String str, double d, TreeMap<Integer, Double> treeMap, int i) {
        // Math.abs(Integer.MIN_VALUE) is still negative; floorMod keeps the bucket non-negative
        // for a positive bucket count even in that case.
        int bucket = Math.floorMod(Math.abs(HASH.hashUnencodedChars(str).asInt()), i);
        treeMap.merge(bucket, d, Double::sum);
    }

    /**
     * Saves the metadata of this stage to the given path.
     *
     * @param str target path
     * @throws IOException if saving fails
     */
    public void save(String str) throws IOException {
        ReadWriteUtils.saveMetadata(this, str);
    }

    /**
     * Loads a feature hasher from the given path.
     *
     * @param streamTableEnvironment not used by this method
     * @param str path to load from
     * @return the loaded stage
     * @throws IOException if loading fails
     */
    public static FeatureHasher load(StreamTableEnvironment streamTableEnvironment, String str) throws IOException {
        return ReadWriteUtils.loadStageParam(str);
    }

    /** Returns the map that backs this stage's parameters. */
    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }
}
