package org.tribuo.common.nearest;

import com.oracle.labs.mlrg.olcut.util.Pair;
import com.oracle.labs.mlrg.olcut.util.StreamUtil;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.Future;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.common.nearest.KNNTrainer;
import org.tribuo.ensemble.EnsembleCombiner;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;

/* loaded from: input_file:org/tribuo/common/nearest/KNNModel.class */
/**
 * A k-nearest-neighbour model.
 * <p>
 * The model stores every training vector together with its output. To predict, it computes the
 * distance (L1, L2 or cosine, selected by {@link KNNTrainer.Distance}) from the query to every
 * stored vector and keeps the {@code k} closest. Each neighbour's output is wrapped in a
 * {@link Prediction}, and the neighbours are merged with the configured {@link EnsembleCombiner}.
 * <p>
 * When {@code numThreads > 1}, batch prediction is parallelised with the selected {@link Backend}.
 *
 * @param <T> The output type.
 */
public class KNNModel<T extends Output<T>> extends Model<T> {
    private static final long serialVersionUID = 1L;

    /** Stored vectors paired with their outputs; every prediction scans all of them. */
    private final Pair<SparseVector, T>[] vectors;
    /** Number of nearest neighbours combined into each prediction. */
    private final int k;
    private final KNNTrainer.Distance distance;
    /** Values above 1 enable the parallel code paths. */
    private final int numThreads;
    private final Backend parallelBackend;
    private final EnsembleCombiner<T> combiner;

    private static final Logger logger = Logger.getLogger(KNNModel.class.getName());
    private static final CustomForkJoinWorkerThreadFactory THREAD_FACTORY = new CustomForkJoinWorkerThreadFactory();

    /**
     * The parallelisation strategy used for batch prediction when {@code numThreads > 1}.
     */
    public enum Backend {
        /** Each example is scored with a bounded parallel stream inside a shared {@link ForkJoinPool}. */
        STREAMS,
        /** Each example is submitted as one task to a fixed-size thread pool. */
        THREADPOOL,
        /** Each example's stored vectors are split into chunks, which are scored in a fixed-size thread pool. */
        INNERTHREADPOOL
    }

    /**
     * Creates {@link ForkJoinWorkerThread}s inside {@link AccessController#doPrivileged}. It is used
     * instead of the default factory whenever a security manager is installed.
     */
    public static final class CustomForkJoinWorkerThreadFactory implements ForkJoinPool.ForkJoinWorkerThreadFactory {
        private CustomForkJoinWorkerThreadFactory() {
        }

        @Override
        public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
            // A bare lambda matches both doPrivileged(PrivilegedAction) and
            // doPrivileged(PrivilegedExceptionAction); the cast picks one. The empty anonymous
            // subclass is needed because ForkJoinWorkerThread's constructor is protected.
            return AccessController.doPrivileged(
                    (PrivilegedAction<ForkJoinWorkerThread>) () -> new ForkJoinWorkerThread(pool) { });
        }
    }

    /**
     * A mutable (output, distance) pair.
     * <p>
     * Natural ordering is by ascending {@code value} only, so {@code compareTo} is not consistent
     * with {@code equals}.
     *
     * @param <T> The output type.
     */
    public static final class OutputDoublePair<T extends Output<T>> implements Comparable<OutputDoublePair<T>> {
        T output;
        double value;

        public OutputDoublePair(T output, double value) {
            this.output = output;
            this.value = value;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            OutputDoublePair<?> other = (OutputDoublePair<?>) obj;
            return Double.compare(other.value, this.value) == 0 && this.output.equals(other.output);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.output, Double.valueOf(this.value));
        }

        @Override
        public int compareTo(OutputDoublePair<T> other) {
            return Double.compare(this.value, other.value);
        }
    }

    /**
     * Constructs a KNN model. The model takes ownership of {@code vectors} without copying it.
     */
    KNNModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap,
             ImmutableOutputInfo<T> outputIDInfo, boolean generatesProbabilities, int k,
             KNNTrainer.Distance distance, int numThreads, EnsembleCombiner<T> combiner,
             Pair<SparseVector, T>[] vectors, Backend backend) {
        super(name, provenance, featureIDMap, outputIDInfo, generatesProbabilities);
        this.k = k;
        this.distance = distance;
        this.numThreads = numThreads;
        this.combiner = combiner;
        this.parallelBackend = backend;
        this.vectors = vectors;
    }

    /**
     * Predicts a single example by combining the outputs of its {@code k} nearest stored vectors.
     * <p>
     * If {@code numThreads > 1}, distances are computed on a bounded parallel stream. The stream
     * runs inside a {@link ForkJoinPool} that is created for this call and shut down before
     * returning. Otherwise a sequential stream is used.
     *
     * @param example The example to predict.
     * @return The combined prediction.
     * @throws IllegalStateException If the parallel computation fails or is interrupted.
     */
    public Prediction<T> predict(Example<T> example) {
        SparseVector input = SparseVector.createSparseVector(example, featureIDMap, false);
        BiFunction<SparseVector, SparseVector, Double> distanceFunc = distanceFunction();
        List<Prediction<T>> neighbours;
        if (numThreads > 1) {
            ForkJoinPool pool = newForkJoinPool();
            try {
                neighbours = nearestInPool(pool, input, example, distanceFunc);
            } finally {
                // The pool is private to this call; release its workers instead of leaving them to idle out.
                pool.shutdown();
            }
        } else {
            neighbours = nearestByStream(Stream.of(vectors), input, example, distanceFunc);
        }
        return combiner.combine(outputIDInfo, neighbours);
    }

    /**
     * Predicts a batch of examples.
     * <p>
     * If {@code numThreads > 1}, the work is delegated to the configured parallel backend.
     * Otherwise each example is scored sequentially using one reusable bounded max-heap.
     *
     * @param examples The examples to predict.
     * @return One combined prediction per example, in iteration order.
     */
    protected List<Prediction<T>> innerPredict(Iterable<Example<T>> examples) {
        if (numThreads > 1) {
            return innerPredictMultithreaded(examples);
        }
        BiFunction<SparseVector, SparseVector, Double> distanceFunc = distanceFunction();
        PriorityQueue<OutputDoublePair<T>> queue = newNeighbourQueue();
        List<Prediction<T>> predictions = new ArrayList<>();
        for (Example<T> example : examples) {
            queue.clear();
            SparseVector input = SparseVector.createSparseVector(example, featureIDMap, false);
            for (Pair<SparseVector, T> pair : vectors) {
                offerNeighbour(queue, k, pair.getB(), distanceFunc.apply(input, pair.getA()));
            }
            predictions.add(combineNeighbours(combiner, outputIDInfo, queue, input.numActiveElements(), example));
        }
        return predictions;
    }

    /**
     * Dispatches batch prediction to the method that implements {@link #parallelBackend}.
     *
     * @throws IllegalArgumentException If the backend is not recognised.
     */
    private List<Prediction<T>> innerPredictMultithreaded(Iterable<Example<T>> examples) {
        switch (parallelBackend) {
            case STREAMS:
                logger.log(Level.FINE, "Parallel backend - streams");
                return innerPredictStreams(examples);
            case THREADPOOL:
                logger.log(Level.FINE, "Parallel backend - threadpool");
                return innerPredictThreadPool(examples);
            case INNERTHREADPOOL:
                logger.log(Level.FINE, "Parallel backend - within example threadpool");
                return innerPredictWithinExampleThreadPool(examples);
            default:
                throw new IllegalArgumentException("Unknown backend " + parallelBackend);
        }
    }

    /**
     * {@link Backend#STREAMS}: scores each example with a bounded parallel stream. All examples
     * share one {@link ForkJoinPool}, which is shut down when the batch finishes or fails.
     *
     * @throws IllegalStateException If any example fails or the thread is interrupted. Previously
     *     the failure was logged and the prior example's neighbours were silently reused.
     */
    private List<Prediction<T>> innerPredictStreams(Iterable<Example<T>> examples) {
        BiFunction<SparseVector, SparseVector, Double> distanceFunc = distanceFunction();
        List<Prediction<T>> predictions = new ArrayList<>();
        ForkJoinPool pool = newForkJoinPool();
        try {
            for (Example<T> example : examples) {
                SparseVector input = SparseVector.createSparseVector(example, featureIDMap, false);
                List<Prediction<T>> neighbours = nearestInPool(pool, input, example, distanceFunc);
                predictions.add(combiner.combine(outputIDInfo, neighbours));
            }
        } finally {
            pool.shutdown();
        }
        return predictions;
    }

    /**
     * {@link Backend#THREADPOOL}: submits one task per example to a fixed-size pool. Each worker
     * thread reuses its own heap. The pool is shut down on both the success and failure paths.
     *
     * @throws IllegalStateException If any task fails or the thread is interrupted.
     */
    private List<Prediction<T>> innerPredictThreadPool(Iterable<Example<T>> examples) {
        BiFunction<SparseVector, SparseVector, Double> distanceFunc = distanceFunction();
        ThreadLocal<PriorityQueue<OutputDoublePair<T>>> queuePool = ThreadLocal.withInitial(this::newNeighbourQueue);
        ExecutorService pool = Executors.newFixedThreadPool(numThreads);
        try {
            List<Future<Prediction<T>>> futures = new ArrayList<>();
            for (Example<T> example : examples) {
                futures.add(pool.submit(() -> innerPredictOne(queuePool, vectors, combiner, distanceFunc,
                        featureIDMap, outputIDInfo, k, example)));
            }
            List<Prediction<T>> predictions = new ArrayList<>(futures.size());
            for (Future<Prediction<T>> future : futures) {
                predictions.add(future.get());
            }
            return predictions;
        } catch (InterruptedException | ExecutionException e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            throw new IllegalStateException("Thread pool went bang", e);
        } finally {
            pool.shutdown();
        }
    }

    /**
     * {@link Backend#INNERTHREADPOOL}: processes examples one at a time. Each example's stored
     * vectors are split into {@code numThreads} chunks that are scored concurrently. The pool is
     * shut down on both the success and failure paths.
     */
    private List<Prediction<T>> innerPredictWithinExampleThreadPool(Iterable<Example<T>> examples) {
        BiFunction<SparseVector, SparseVector, Double> distanceFunc = distanceFunction();
        ThreadLocal<PriorityQueue<OutputDoublePair<T>>> queuePool = ThreadLocal.withInitial(this::newNeighbourQueue);
        ExecutorService pool = Executors.newFixedThreadPool(numThreads);
        try {
            List<Prediction<T>> predictions = new ArrayList<>();
            for (Example<T> example : examples) {
                predictions.add(innerPredictThreadPool(pool, queuePool, distanceFunc, example));
            }
            return predictions;
        } finally {
            pool.shutdown();
        }
    }

    /**
     * Scores one example by splitting the stored vectors into {@code numThreads} contiguous chunks.
     * Each chunk is a task returning its local top-k; the partial results are merged into the
     * global top-k.
     *
     * @throws IllegalStateException If any chunk fails or the thread is interrupted.
     */
    private Prediction<T> innerPredictThreadPool(ExecutorService pool,
                                                 ThreadLocal<PriorityQueue<OutputDoublePair<T>>> queuePool,
                                                 BiFunction<SparseVector, SparseVector, Double> distanceFunc,
                                                 Example<T> example) {
        SparseVector input = SparseVector.createSparseVector(example, featureIDMap, false);
        // Ceiling division: with floor division the trailing vectors.length % numThreads vectors were
        // never scored (and none at all when vectors.length < numThreads).
        int chunkSize = vectors.length / numThreads + (vectors.length % numThreads == 0 ? 0 : 1);
        List<Future<List<OutputDoublePair<T>>>> futures = new ArrayList<>(numThreads);
        for (int i = 0; i < numThreads; i++) {
            // Widen to long before clamping so that i * chunkSize cannot overflow.
            int start = (int) Math.min((long) i * chunkSize, vectors.length);
            int end = (int) Math.min((long) (i + 1) * chunkSize, vectors.length);
            futures.add(pool.submit(() -> innerPredictChunk(queuePool, vectors, start, end, distanceFunc, k, input)));
        }
        PriorityQueue<OutputDoublePair<T>> merged = newNeighbourQueue();
        try {
            for (Future<List<OutputDoublePair<T>>> future : futures) {
                for (OutputDoublePair<T> candidate : future.get()) {
                    if (merged.size() < k) {
                        merged.offer(candidate);
                    } else if (Double.compare(candidate.value, merged.peek().value) < 0) {
                        merged.poll();
                        merged.offer(candidate);
                    }
                }
            }
        } catch (InterruptedException | ExecutionException e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            throw new IllegalStateException("Thread pool went bang", e);
        }
        return combineNeighbours(combiner, outputIDInfo, merged, input.numActiveElements(), example);
    }

    /**
     * Computes the top-{@code numNeighbours} of {@code stored[start, min(end, stored.length))}
     * against {@code input}, using the calling thread's reusable heap.
     *
     * @return A fresh list holding the chunk's neighbours, in heap iteration order.
     */
    private static <T extends Output<T>> List<OutputDoublePair<T>> innerPredictChunk(
            ThreadLocal<PriorityQueue<OutputDoublePair<T>>> queuePool, Pair<SparseVector, T>[] stored,
            int start, int end, BiFunction<SparseVector, SparseVector, Double> distanceFunc,
            int numNeighbours, SparseVector input) {
        PriorityQueue<OutputDoublePair<T>> queue = queuePool.get();
        queue.clear();
        int limit = Math.min(end, stored.length);
        for (int i = start; i < limit; i++) {
            offerNeighbour(queue, numNeighbours, stored[i].getB(), distanceFunc.apply(input, stored[i].getA()));
        }
        // Copy so the caller is unaffected when this thread's heap is cleared for its next chunk.
        return new ArrayList<>(queue);
    }

    /**
     * Scores a single example against every stored vector using the calling thread's reusable heap,
     * then combines the neighbours into one prediction.
     */
    private static <T extends Output<T>> Prediction<T> innerPredictOne(
            ThreadLocal<PriorityQueue<OutputDoublePair<T>>> queuePool, Pair<SparseVector, T>[] stored,
            EnsembleCombiner<T> ensemble, BiFunction<SparseVector, SparseVector, Double> distanceFunc,
            ImmutableFeatureMap featureMap, ImmutableOutputInfo<T> outputInfo, int numNeighbours,
            Example<T> example) {
        SparseVector input = SparseVector.createSparseVector(example, featureMap, false);
        PriorityQueue<OutputDoublePair<T>> queue = queuePool.get();
        queue.clear();
        for (Pair<SparseVector, T> pair : stored) {
            offerNeighbour(queue, numNeighbours, pair.getB(), distanceFunc.apply(input, pair.getA()));
        }
        return combineNeighbours(ensemble, outputInfo, queue, input.numActiveElements(), example);
    }

    /**
     * Returns the distance function for {@link #distance}.
     * <p>
     * The function is applied as {@code (query, stored)} and evaluates {@code stored.xDistance(query)}.
     *
     * @throws IllegalStateException If the distance is not recognised.
     */
    private BiFunction<SparseVector, SparseVector, Double> distanceFunction() {
        switch (distance) {
            case L1:
                return (query, stored) -> stored.l1Distance(query);
            case L2:
                return (query, stored) -> stored.l2Distance(query);
            case COSINE:
                return (query, stored) -> stored.cosineDistance(query);
            default:
                throw new IllegalStateException("Unknown distance function " + distance);
        }
    }

    /**
     * Creates an empty max-heap ordered by distance. Its head is the farthest of the current
     * candidates, which is the one to evict.
     */
    private PriorityQueue<OutputDoublePair<T>> newNeighbourQueue() {
        return new PriorityQueue<OutputDoublePair<T>>(k, (a, b) -> Double.compare(b.value, a.value));
    }

    /**
     * Offers a candidate to a max-heap that holds at most {@code numNeighbours} entries.
     * <p>
     * When the heap is full and the candidate is strictly closer than the head, the head is
     * evicted and its holder object is reused for the candidate, to avoid allocation.
     */
    private static <T extends Output<T>> void offerNeighbour(PriorityQueue<OutputDoublePair<T>> queue,
                                                             int numNeighbours, T output, double value) {
        if (queue.size() < numNeighbours) {
            queue.offer(new OutputDoublePair<>(output, value));
        } else if (Double.compare(value, queue.peek().value) < 0) {
            OutputDoublePair<T> evicted = queue.poll();
            evicted.output = output;
            evicted.value = value;
            queue.offer(evicted);
        }
    }

    /**
     * Wraps each neighbour's output in a {@link Prediction} and merges them with the ensemble
     * combiner. The neighbours are visited in the iterable's order, which for a heap is not
     * sorted by distance.
     */
    private static <T extends Output<T>> Prediction<T> combineNeighbours(EnsembleCombiner<T> ensemble,
                                                                        ImmutableOutputInfo<T> outputInfo,
                                                                        Iterable<OutputDoublePair<T>> neighbours,
                                                                        int numUsed, Example<T> example) {
        List<Prediction<T>> predictions = new ArrayList<>();
        for (OutputDoublePair<T> neighbour : neighbours) {
            predictions.add(new Prediction<T>(neighbour.output, numUsed, example));
        }
        return ensemble.combine(outputInfo, predictions);
    }

    /**
     * Streams {@code candidates} through the distance function, keeps the {@code k} closest in
     * ascending order of distance, and wraps each neighbour's output in a {@link Prediction}.
     */
    private List<Prediction<T>> nearestByStream(Stream<Pair<SparseVector, T>> candidates, SparseVector input,
                                                Example<T> example,
                                                BiFunction<SparseVector, SparseVector, Double> distanceFunc) {
        int numUsed = input.numActiveElements();
        return candidates
                .map(pair -> new OutputDoublePair<T>(pair.getB(), distanceFunc.apply(input, pair.getA())))
                .sorted()
                .limit(k)
                .map(neighbour -> new Prediction<T>(neighbour.output, numUsed, example))
                .collect(Collectors.toList());
    }

    /**
     * Runs {@link #nearestByStream} on a bounded parallel stream inside {@code pool} and waits for
     * the result.
     *
     * @throws IllegalStateException If the computation fails or the thread is interrupted. The
     *     interrupt flag is restored in the latter case.
     */
    private List<Prediction<T>> nearestInPool(ForkJoinPool pool, SparseVector input, Example<T> example,
                                              BiFunction<SparseVector, SparseVector, Double> distanceFunc) {
        try {
            return pool.submit(() -> nearestByStream(StreamUtil.boundParallelism(Stream.of(vectors).parallel()),
                    input, example, distanceFunc)).get();
        } catch (InterruptedException | ExecutionException e) {
            logger.log(Level.SEVERE, "Exception when predicting in KNNModel", e);
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            throw new IllegalStateException("Failed to process example in parallel", e);
        }
    }

    /**
     * Creates a pool with {@code numThreads} workers. The privileged thread factory is used when a
     * security manager is present.
     */
    private ForkJoinPool newForkJoinPool() {
        if (System.getSecurityManager() == null) {
            return new ForkJoinPool(numThreads);
        }
        return new ForkJoinPool(numThreads, THREAD_FACTORY, null, false);
    }

    /**
     * KNN has no per-feature weights; this always returns an empty map.
     */
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        return Collections.emptyMap();
    }

    /**
     * Explanations are not supported; this always returns {@link Optional#empty()}.
     */
    public Optional<Excuse<T>> getExcuse(Example<T> example) {
        return Optional.empty();
    }

    /**
     * Returns a copy of this model with a new name and provenance.
     * <p>
     * The stored vectors and outputs are deep-copied. The feature map, output info and combiner
     * are shared with this model.
     * <p>
     * NOTE(review): the decompiler renamed this method from {@code copy}. Confirm that it still
     * overrides {@code Model.copy(String, ModelProvenance)}.
     */
    public KNNModel<T> m4copy(String newName, ModelProvenance newProvenance) {
        @SuppressWarnings("unchecked") // generic array creation; every element is a Pair<SparseVector, T>
        Pair<SparseVector, T>[] copies = (Pair<SparseVector, T>[]) new Pair[vectors.length];
        for (int i = 0; i < vectors.length; i++) {
            copies[i] = new Pair<>(vectors[i].getA().copy(), vectors[i].getB().copy());
        }
        return new KNNModel<>(newName, newProvenance, featureIDMap, outputIDInfo, generatesProbabilities, k,
                distance, numThreads, combiner, copies, parallelBackend);
    }

    /**
     * Restores fields using the default serialized form. No extra validation is performed.
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
    }
}
