package com.aliasi.test.unit.classify;

import com.aliasi.classify.Classification;
import com.aliasi.classify.Classified;
import com.aliasi.classify.KnnClassifier;
import com.aliasi.classify.ScoredClassification;
import com.aliasi.matrix.EuclideanDistance;
import com.aliasi.matrix.Vector;
import com.aliasi.tokenizer.IndoEuropeanTokenizerFactory;
import com.aliasi.tokenizer.TokenFeatureExtractor;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Distance;
import com.aliasi.util.Proximity;
import java.io.IOException;
import java.io.Serializable;
import junit.framework.Assert;
import org.junit.Test;

/* loaded from: input_file:com/aliasi/test/unit/classify/KnnClassifierTest.class */
/**
 * Unit tests for {@link KnnClassifier} over character-sequence inputs.
 *
 * <p>Inputs are turned into feature vectors by {@link #FEATURE_EXTRACTOR}, a token-based
 * extractor; {@link #DISTANCE} is Euclidean distance between such vectors.
 */
public class KnnClassifierTest {
    static final TokenFeatureExtractor FEATURE_EXTRACTOR = new TokenFeatureExtractor(IndoEuropeanTokenizerFactory.INSTANCE);
    static final Distance<Vector> DISTANCE = EuclideanDistance.DISTANCE;

    /** Tolerance for score comparisons; the expected scores in these tests are exact vote sums. */
    private static final double SCORE_TOLERANCE = 0.0001;

    /**
     * Proximity derived from {@link #DISTANCE} as {@code 1 / (1 + distance)}, so identical vectors
     * have proximity {@code 1.0} and proximity decreases as distance grows.
     */
    static class TestProximity implements Proximity<Vector>, Serializable {
        private static final long serialVersionUID = 1L;

        TestProximity() {
        }

        @Override
        public double proximity(Vector vector, Vector vector2) {
            return 1.0d / (1.0d + KnnClassifierTest.DISTANCE.distance(vector, vector2));
        }
    }

    /**
     * Adds one labeled training example to the classifier.
     *
     * @param knnClassifier classifier to train
     * @param str training input
     * @param classification reference classification for {@code str}
     */
    static void handle(KnnClassifier<CharSequence> knnClassifier, String str, Classification classification) {
        knnClassifier.handle(new Classified<CharSequence>(str, classification));
    }

    /** Trains {@code classifier} on parallel arrays of inputs and category labels. */
    private static void train(KnnClassifier<CharSequence> classifier, String[] inputs, String[] categories) {
        for (int i = 0; i < inputs.length; i++) {
            handle(classifier, inputs[i], new Classification(categories[i]));
        }
    }

    /** Serializes and deserializes {@code classifier}, returning the reconstituted copy. */
    @SuppressWarnings("unchecked") // serializeDeserialize returns the same runtime class it was given
    private static KnnClassifier<CharSequence> roundTrip(KnnClassifier<CharSequence> classifier)
            throws ClassNotFoundException, IOException {
        return (KnnClassifier<CharSequence>) AbstractExternalizable.serializeDeserialize(classifier);
    }

    /**
     * Asserts that {@code classification} ranks {@code first} then {@code second}, with the given
     * scores at ranks 0 and 1.
     */
    private static void assertRanking(ScoredClassification classification,
                                      String first, double firstScore,
                                      String second, double secondScore) {
        Assert.assertEquals(first, classification.bestCategory());
        Assert.assertEquals(first, classification.category(0));
        Assert.assertEquals(second, classification.category(1));
        Assert.assertEquals(firstScore, classification.score(0), SCORE_TOLERANCE);
        Assert.assertEquals(secondScore, classification.score(1), SCORE_TOLERANCE);
    }

    /** With k = 1 the single nearest neighbor's category gets the only vote, before and after serialization. */
    @Test
    public void testOne() throws ClassNotFoundException, IOException {
        String[] inputs = {"a a b", "a b b"};
        String[] categories = {"A", "B"};
        KnnClassifier<CharSequence> classifier = new KnnClassifier<CharSequence>(FEATURE_EXTRACTOR, 1);
        train(classifier, inputs, categories);

        assertRanking(classifier.classify("a a a a b b"), "A", 1.0d, "B", 0.0d);

        // Previously the deserialized classifier's result was discarded; verify it explicitly.
        assertRanking(roundTrip(classifier).classify("a a a a b b"), "A", 1.0d, "B", 0.0d);
    }

    /** With k = 3 the votes of the three nearest neighbors are counted, before and after serialization. */
    @Test
    public void testTwo() throws ClassNotFoundException, IOException {
        String[] inputs = {"a a b", "a b b", "a a a", "a a a a a b", "a a b b"};
        String[] categories = {"A", "B", "A", "A", "B"};
        KnnClassifier<CharSequence> classifier = new KnnClassifier<CharSequence>(FEATURE_EXTRACTOR, 3);
        train(classifier, inputs, categories);

        assertRanking(classifier.classify("a a b"), "A", 2.0d, "B", 1.0d);
        assertRanking(roundTrip(classifier).classify("a a b"), "A", 2.0d, "B", 1.0d);
    }

    /**
     * With every training example as a neighbor and votes weighted by {@link TestProximity},
     * each training input is classified into its own category.
     */
    @Test
    public void testThree() {
        String[] inputs = {"a a b", "a b b", "a a a", "b b b"};
        String[] categories = {"A", "B", "A", "B"};
        KnnClassifier<CharSequence> classifier = new KnnClassifier<CharSequence>(
                FEATURE_EXTRACTOR, Integer.MAX_VALUE, new TestProximity(), true);
        train(classifier, inputs, categories);

        for (int i = 0; i < inputs.length; i++) {
            Assert.assertEquals(categories[i], classifier.classify(inputs[i]).bestCategory());
        }
    }

    /** Returns {@code (d - d2)^2}. */
    static double sqrDiff(double d, double d2) {
        double diff = d - d2;
        return diff * diff;
    }
}
