package com.facebook.presto.ml;

import com.facebook.presto.ml.type.ModelType;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.UnmodifiableIterator;
import io.airlift.slice.BasicSliceInput;
import io.airlift.slice.SliceOutput;
import io.airlift.slice.Slices;
import it.unimi.dsi.fastutil.ints.Int2DoubleMap;
import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntIterator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

/* loaded from: input_file:com/facebook/presto/ml/FeatureUnitNormalizer.class */
/**
 * Feature transformation that rescales each feature into the unit interval [0, 1] using the
 * per-feature minimum and maximum observed during {@link #train(Dataset)}.
 *
 * <p>Features whose observed minimum equals their observed maximum (zero variance) are dropped
 * from the model at the end of training. {@link #transform(FeatureVector)} maps dropped features,
 * and features never seen during training, to {@code 0.0}. Scaled values that fall outside the
 * trained range are clamped to [0, 1].
 *
 * <p>Instances hold mutable state and perform no synchronization.
 */
public class FeatureUnitNormalizer extends AbstractFeatureTransformation {
    /** Bytes per serialized feature record: int key (4) + double min (8) + double max (8). */
    private static final int SERIALIZED_ENTRY_SIZE = 4 + 8 + 8;

    // Per-feature minimum. Absent keys read as +Infinity, so the first observed value wins.
    private final Int2DoubleMap mins = new Int2DoubleOpenHashMap();
    // Per-feature maximum. Absent keys read as -Infinity, so the first observed value wins.
    private final Int2DoubleMap maxs = new Int2DoubleOpenHashMap();

    public FeatureUnitNormalizer() {
        this.mins.defaultReturnValue(Double.POSITIVE_INFINITY);
        this.maxs.defaultReturnValue(Double.NEGATIVE_INFINITY);
    }

    @Override // com.facebook.presto.ml.Model
    public ModelType getType() {
        return ModelType.MODEL;
    }

    /**
     * Serializes the trained ranges as a packed sequence of
     * {@code (int key, double min, double max)} records, one per retained feature.
     *
     * @return the serialized model bytes; empty when no features are retained
     */
    @Override // com.facebook.presto.ml.Model
    public byte[] getSerializedData() {
        SliceOutput output = Slices.allocate(SERIALIZED_ENTRY_SIZE * this.mins.size()).getOutput();
        IntIterator keys = this.mins.keySet().iterator();
        while (keys.hasNext()) {
            // nextInt() reads the primitive key directly instead of boxing it through next().
            int key = keys.nextInt();
            output.appendInt(key);
            output.appendDouble(this.mins.get(key));
            output.appendDouble(this.maxs.get(key));
        }
        return output.slice().getBytes();
    }

    /**
     * Rebuilds a normalizer from bytes produced by {@link #getSerializedData()}.
     *
     * @param modelData packed {@code (int key, double min, double max)} records
     * @return a normalizer holding the deserialized ranges
     * @throws IllegalArgumentException if the data length is not a whole number of records,
     *     which indicates truncated or corrupted input
     */
    public static FeatureUnitNormalizer deserialize(byte[] modelData) {
        // Reject partial records up front rather than failing mid-parse with an index error.
        if (modelData.length % SERIALIZED_ENTRY_SIZE != 0) {
            throw new IllegalArgumentException("Serialized FeatureUnitNormalizer length "
                    + modelData.length + " is not a multiple of " + SERIALIZED_ENTRY_SIZE);
        }
        BasicSliceInput input = Slices.wrappedBuffer(modelData).getInput();
        FeatureUnitNormalizer model = new FeatureUnitNormalizer();
        while (input.isReadable()) {
            int key = input.readInt();
            model.mins.put(key, input.readDouble());
            model.maxs.put(key, input.readDouble());
        }
        return model;
    }

    /**
     * Records the minimum and maximum value of every feature in {@code dataset}, then drops
     * features whose minimum equals their maximum.
     *
     * @param dataset training data whose datapoints' features are scanned
     */
    @Override // com.facebook.presto.ml.Model
    public void train(Dataset dataset) {
        for (FeatureVector vector : dataset.getDatapoints()) {
            for (Map.Entry<Integer, Double> feature : vector.getFeatures().entrySet()) {
                int key = feature.getKey();
                double value = feature.getValue();
                if (value < this.mins.get(key)) {
                    this.mins.put(key, value);
                }
                if (value > this.maxs.get(key)) {
                    this.maxs.put(key, value);
                }
            }
        }

        // Zero-variance features cannot be scaled (max - min == 0), so remove them.
        // Iterate over a snapshot of the keys because entries are removed inside the loop.
        for (int key : ImmutableSet.copyOf(this.mins.keySet())) {
            if (this.mins.get(key) == this.maxs.get(key)) {
                this.mins.remove(key);
                this.maxs.remove(key);
            }
        }
    }

    /**
     * Scales each feature to {@code (value - min) / (max - min)}, clamped to [0, 1].
     * Features without a trained range map to {@code 0.0}.
     *
     * @param featureVector the features to normalize
     * @return a new vector with the same keys and normalized values
     */
    @Override // com.facebook.presto.ml.FeatureTransformation
    public FeatureVector transform(FeatureVector featureVector) {
        Map<Integer, Double> normalized = new HashMap<>();
        for (Map.Entry<Integer, Double> entry : featureVector.getFeatures().entrySet()) {
            int key = entry.getKey();
            double scaled;
            if (this.mins.containsKey(key)) {
                double min = this.mins.get(key);
                scaled = (entry.getValue() - min) / (this.maxs.get(key) - min);
            } else {
                scaled = 0.0d;
            }
            normalized.put(entry.getKey(), Math.min(1.0d, Math.max(0.0d, scaled)));
        }
        return new FeatureVector(normalized);
    }
}
