package com.aliasi.classify;

import com.aliasi.corpus.ObjectHandler;
import com.aliasi.stats.MultivariateEstimator;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Counter;
import com.aliasi.util.FeatureExtractor;
import com.aliasi.util.Math;
import com.aliasi.util.ObjectToCounterMap;
import com.aliasi.util.ObjectToDoubleMap;
import com.aliasi.util.ScoredObject;
import com.aliasi.util.Strings;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.Serializable;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * A naive Bayes classifier over binary (Bernoulli) features.
 *
 * <p>Each input is converted to a feature map by the configured
 * {@link FeatureExtractor}. A feature is <i>active</i> for an input when its
 * value is strictly greater than the activation threshold.
 *
 * <p>Training through {@link #handle(Classified)} records two things for the
 * best category of each training example:
 * <ul>
 *   <li>one more observation of that category in the category estimator;</li>
 *   <li>one more occurrence of each active feature in that category's
 *       feature counts.</li>
 * </ul>
 * Every active feature is also added to the set of known features.
 *
 * <p>Classification through {@link #classify(Object)} scores each category
 * {@code c} as the log2 of the estimator's probability for {@code c}. To
 * that, it adds one term for every feature {@code f} known from training:
 * <ul>
 *   <li>if {@code f} is active: {@code log2((count(f,c) + 1) / (count(c) + 2))};</li>
 *   <li>otherwise: {@code log2((count(c) - count(f,c) + 1) / (count(c) + 2))}.</li>
 * </ul>
 * Active features never seen during training are ignored for all categories.
 *
 * <p>Instances are serialized through the nested {@link Serializer} proxy
 * (see {@link #writeReplace()}). No synchronization is performed; training
 * mutates internal hash-based collections.
 *
 * @param <E> the type of objects being classified
 */
public class BernoulliClassifier<E>
        implements JointClassifier<E>, ObjectHandler<Classified<E>>, Serializable {

    static final long serialVersionUID = -7761909693358968780L;

    /** Counts of training instances per category; also supplies category labels. */
    private final MultivariateEstimator mCategoryDistribution;

    /** Converts inputs into feature-to-value maps. */
    private final FeatureExtractor<E> mFeatureExtractor;

    /** A feature is active when its value is strictly greater than this threshold. */
    private final double mActivationThreshold;

    /** Every feature that has been active in at least one training instance. */
    private final Set<String> mFeatureSet;

    /** For each category, the number of its training instances in which each feature was active. */
    private final Map<String, ObjectToCounterMap<String>> mFeatureDistributionMap;

    /**
     * Externalizable serialization proxy for {@link BernoulliClassifier}.
     *
     * <p>Wire format, in order:
     * <ol>
     *   <li>the category estimator (object);</li>
     *   <li>the feature extractor (object);</li>
     *   <li>the activation threshold (double);</li>
     *   <li>the number of known features (int), followed by each feature (UTF);</li>
     *   <li>the number of categories (int), followed by, for each category:
     *       its name (UTF), its number of feature counts (int), and that many
     *       (feature UTF, count int) pairs.</li>
     * </ol>
     *
     * @param <F> the type of objects classified by the serialized classifier
     */
    static class Serializer<F> extends AbstractExternalizable {
        static final long serialVersionUID = 4803666611627400222L;

        /** The classifier to write; {@code null} when this proxy is created for reading. */
        final BernoulliClassifier<F> mClassifier;

        /**
         * Constructs a proxy that writes the specified classifier.
         *
         * @param classifier classifier to serialize
         */
        public Serializer(BernoulliClassifier<F> classifier) {
            this.mClassifier = classifier;
        }

        /** No-argument constructor required for {@code Externalizable} reads. */
        public Serializer() {
            this(null);
        }

        @Override
        public void writeExternal(ObjectOutput out) throws IOException {
            out.writeObject(mClassifier.mCategoryDistribution);
            out.writeObject(mClassifier.mFeatureExtractor);
            out.writeDouble(mClassifier.mActivationThreshold);

            out.writeInt(mClassifier.mFeatureSet.size());
            for (String feature : mClassifier.mFeatureSet) {
                out.writeUTF(feature);
            }

            out.writeInt(mClassifier.mFeatureDistributionMap.size());
            for (Map.Entry<String, ObjectToCounterMap<String>> categoryEntry
                     : mClassifier.mFeatureDistributionMap.entrySet()) {
                out.writeUTF(categoryEntry.getKey());
                ObjectToCounterMap<String> featureCounts = categoryEntry.getValue();
                out.writeInt(featureCounts.size());
                for (Map.Entry<String, Counter> countEntry : featureCounts.entrySet()) {
                    out.writeUTF(countEntry.getKey());
                    out.writeInt(countEntry.getValue().intValue());
                }
            }
        }

        @Override
        public Object read(ObjectInput in) throws ClassNotFoundException, IOException {
            MultivariateEstimator categoryDistribution = (MultivariateEstimator) in.readObject();
            // Unchecked: the object was written by writeExternal from a FeatureExtractor<F>.
            @SuppressWarnings("unchecked")
            FeatureExtractor<F> featureExtractor = (FeatureExtractor<F>) in.readObject();
            double activationThreshold = in.readDouble();

            int numFeatures = in.readInt();
            Set<String> featureSet = new HashSet<String>(2 * numFeatures);
            for (int i = 0; i < numFeatures; ++i) {
                featureSet.add(in.readUTF());
            }

            int numCategories = in.readInt();
            Map<String, ObjectToCounterMap<String>> featureDistributionMap
                = new HashMap<String, ObjectToCounterMap<String>>(2 * numCategories);
            for (int i = 0; i < numCategories; ++i) {
                String category = in.readUTF();
                int numEntries = in.readInt();
                ObjectToCounterMap<String> featureCounts = new ObjectToCounterMap<String>();
                featureDistributionMap.put(category, featureCounts);
                for (int j = 0; j < numEntries; ++j) {
                    // Read the key before the count, mirroring writeExternal.
                    String feature = in.readUTF();
                    int count = in.readInt();
                    featureCounts.set(feature, count);
                }
            }
            return new BernoulliClassifier<F>(categoryDistribution, featureExtractor,
                                              activationThreshold, featureSet,
                                              featureDistributionMap);
        }
    }

    /**
     * Constructs an untrained classifier with activation threshold {@code 0.0}.
     *
     * @param featureExtractor converts inputs into feature maps
     */
    public BernoulliClassifier(FeatureExtractor<E> featureExtractor) {
        this(featureExtractor, 0.0d);
    }

    /**
     * Constructs an untrained classifier with the specified activation threshold.
     *
     * @param featureExtractor converts inputs into feature maps
     * @param featureActivationThreshold features whose value is strictly greater
     *     than this are treated as active
     */
    public BernoulliClassifier(FeatureExtractor<E> featureExtractor,
                               double featureActivationThreshold) {
        this(new MultivariateEstimator(), featureExtractor, featureActivationThreshold,
             new HashSet<String>(), new HashMap<String, ObjectToCounterMap<String>>());
    }

    /** Constructs a classifier over existing state; used by deserialization. */
    BernoulliClassifier(MultivariateEstimator categoryDistribution,
                        FeatureExtractor<E> featureExtractor,
                        double activationThreshold,
                        Set<String> featureSet,
                        Map<String, ObjectToCounterMap<String>> featureDistributionMap) {
        this.mCategoryDistribution = categoryDistribution;
        this.mFeatureExtractor = featureExtractor;
        this.mActivationThreshold = activationThreshold;
        this.mFeatureSet = featureSet;
        this.mFeatureDistributionMap = featureDistributionMap;
    }

    /**
     * Returns the feature activation threshold.
     *
     * @return values strictly greater than this make a feature active
     */
    public double featureActivationThreshold() {
        return mActivationThreshold;
    }

    /**
     * Returns the feature extractor used for training and classification.
     *
     * @return the feature extractor
     */
    public FeatureExtractor<E> featureExtractor() {
        return mFeatureExtractor;
    }

    /**
     * Returns the labels of the categories seen so far, in the category
     * estimator's dimension order.
     *
     * @return a fresh array of category labels
     */
    public String[] categories() {
        int numCategories = mCategoryDistribution.numDimensions();
        String[] categories = new String[numCategories];
        for (int i = 0; i < numCategories; ++i) {
            categories[i] = mCategoryDistribution.label(i);
        }
        return categories;
    }

    /**
     * Trains on the specified classified object, using its best category.
     *
     * @param classified training example
     */
    @Override
    public void handle(Classified<E> classified) {
        handle(classified.getObject(), classified.getClassification());
    }

    /**
     * Trains on the input under the classification's best category.
     *
     * <p>The update has three parts:
     * <ul>
     *   <li>the category count is incremented by one;</li>
     *   <li>each active feature's count for that category is incremented;</li>
     *   <li>each active feature is added to the set of known features.</li>
     * </ul>
     */
    void handle(E input, Classification classification) {
        String category = classification.bestCategory();
        mCategoryDistribution.train(category, 1L);
        ObjectToCounterMap<String> featureCounts = mFeatureDistributionMap.get(category);
        if (featureCounts == null) {
            featureCounts = new ObjectToCounterMap<>();
            mFeatureDistributionMap.put(category, featureCounts);
        }
        for (String feature : activeFeatureSet(input)) {
            featureCounts.increment(feature);
            mFeatureSet.add(feature);
        }
    }

    /**
     * Classifies the input, scoring every known category by its log2 prior
     * plus add-one smoothed Bernoulli feature log2 probabilities.
     *
     * <p>Every feature seen in training contributes to every category. Active
     * features use {@code (count(f,c) + 1) / (count(c) + 2)}; inactive ones use
     * {@code (count(c) - count(f,c) + 1) / (count(c) + 2)}. Active features
     * never seen in training are ignored.
     *
     * @param input object to classify
     * @return the categories with their log2 scores, in the order produced by
     *     {@code ObjectToDoubleMap.scoredObjectsOrderedByValueList()}
     */
    @Override
    public JointClassification classify(E input) {
        Set<String> activeFeatures = activeFeatureSet(input);
        // Drop features unknown to the model so that every category is scored
        // over exactly the same set of features. (Previously a feature was
        // skipped per category whenever that category's count was 0, which
        // rewarded categories that had never seen the feature.)
        activeFeatures.retainAll(mFeatureSet);
        Set<String> inactiveFeatures = new HashSet<>(mFeatureSet);
        inactiveFeatures.removeAll(activeFeatures);

        String[] activeArray = activeFeatures.toArray(Strings.EMPTY_STRING_ARRAY);
        String[] inactiveArray = inactiveFeatures.toArray(Strings.EMPTY_STRING_ARRAY);

        ObjectToDoubleMap<String> categoryToLog2P = new ObjectToDoubleMap<>();
        int numCategories = mCategoryDistribution.numDimensions();
        for (long i = 0; i < numCategories; ++i) {
            String category = mCategoryDistribution.label(i);
            double log2P = Math.log2(mCategoryDistribution.probability(i));
            double categoryCount = mCategoryDistribution.getCount(i);
            ObjectToCounterMap<String> featureCounts = mFeatureDistributionMap.get(category);

            for (String feature : activeArray) {
                double featureCount = featureCounts.getCount(feature);
                log2P += Math.log2((featureCount + 1.0d) / (categoryCount + 2.0d));
            }
            for (String feature : inactiveArray) {
                double notFeatureCount = categoryCount - featureCounts.getCount(feature);
                log2P += Math.log2((notFeatureCount + 1.0d) / (categoryCount + 2.0d));
            }
            categoryToLog2P.set(category, log2P);
        }

        String[] categories = new String[numCategories];
        double[] log2Scores = new double[numCategories];
        List<ScoredObject<String>> ranked = categoryToLog2P.scoredObjectsOrderedByValueList();
        for (int i = 0; i < numCategories; ++i) {
            ScoredObject<String> scored = ranked.get(i);
            categories[i] = scored.getObject();
            log2Scores[i] = scored.score();
        }
        return new JointClassification(categories, log2Scores);
    }

    /** Replaces this classifier with a {@link Serializer} proxy during serialization. */
    Object writeReplace() {
        return new Serializer<E>(this);
    }

    /**
     * Returns a fresh, mutable set of the features whose extracted value is
     * strictly greater than the activation threshold.
     */
    private Set<String> activeFeatureSet(E input) {
        Set<String> active = new HashSet<>();
        for (Map.Entry<String, ? extends Number> entry
                 : mFeatureExtractor.features(input).entrySet()) {
            if (entry.getValue().doubleValue() > mActivationThreshold) {
                active.add(entry.getKey());
            }
        }
        return active;
    }

    // The decompiler-emitted synthetic bridge methods classify(Object) were
    // removed: they are not legal Java source (they differ only in return
    // type), and javac generates them automatically from classify(E).
}
