package info.debatty.spark.knngraphs.builder;

import info.debatty.java.graphs.Graph;
import info.debatty.java.graphs.Neighbor;
import info.debatty.java.graphs.NeighborList;
import info.debatty.java.graphs.Node;
import info.debatty.java.graphs.SimilarityInterface;
import info.debatty.spark.knngraphs.ApproximateSearch;
import info.debatty.spark.knngraphs.BalancedKMedoidsPartitioner;
import info.debatty.spark.knngraphs.NodePartitioner;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.Set;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import scala.Tuple2;

/* loaded from: input_file:info/debatty/spark/knngraphs/builder/Online.class */
/**
 * Incrementally adds nodes to a distributed k-nn graph.
 *
 * <p>For every new node, {@link #addNode(Node)} runs an approximate k-nn
 * search, assigns the node to a partition (similarity to each medoid,
 * penalized by how full that partition already is), links the node into the
 * neighbor lists of nearby existing nodes, and installs the resulting RDD as
 * the searcher's graph.
 *
 * @param <T> type of the values stored in the nodes
 */
public class Online<T> {
    /** Passed to the {@link ApproximateSearch} constructor. */
    private static final int PARTITIONING_ITERATIONS = 5;
    /** Initial value of {@link #search_speedup}. */
    private static final int DEFAULT_SEARCH_SPEEDUP = 4;
    /** The graph RDD is checkpointed once every this many added nodes. */
    private static final int ITERATIONS_FOR_CHECKPOINT = 20;
    /** New medoids are computed once every this many added nodes. */
    private static final int ITERATIONS_FOR_CENTROIDS = 100;

    private final ApproximateSearch<T> searcher;
    private final int k;
    private final JavaSparkContext sc;
    private final SimilarityInterface<T> similarity;
    /** Number of nodes per partition, indexed by partition id. */
    private final long[] counts;
    /** Recently cached graph RDDs; the oldest ones are unpersisted. */
    private final LinkedList<JavaPairRDD<Node<T>, NeighborList>> previous_rdds;
    private int search_speedup = DEFAULT_SEARCH_SPEEDUP;
    /** Number of nodes added so far through {@link #addNode(Node)}. */
    private int iteration = 0;

    /**
     * Assigns every node of a partition to the medoid maximizing
     * {@code similarity * (1 - partitionSize / capacity)} and stores the
     * chosen index in the node's {@link BalancedKMedoidsPartitioner#PARTITION_KEY}
     * attribute.
     *
     * @param <U> type of the node values
     */
    public static class AssignFunction<U> implements
            PairFlatMapFunction<Iterator<Tuple2<Node<U>, NeighborList>>, Node<U>, NeighborList> {
        /** Allowed partition size relative to a perfectly balanced split. */
        private static final double IMBALANCE = 1.1d;
        private final List<Node<U>> medoids;
        private final long[] counts;
        private final SimilarityInterface<U> similarity;

        /**
         * @param medoids    one medoid per partition
         * @param counts     current size of each partition (indexed like {@code medoids});
         *                   never modified by this function
         * @param similarity similarity used to compare nodes with medoids
         */
        public AssignFunction(
                final List<Node<U>> medoids,
                final long[] counts,
                final SimilarityInterface<U> similarity) {
            this.medoids = medoids;
            this.counts = counts;
            this.similarity = similarity;
        }

        /**
         * Tags each incoming node with its chosen partition and returns all
         * tuples unchanged otherwise.
         *
         * @throws IllegalStateException if there is a node to assign but no medoid
         */
        @Override
        public Iterable<Tuple2<Node<U>, NeighborList>> call(
                final Iterator<Tuple2<Node<U>, NeighborList>> iterator) throws Exception {
            final int partitions = medoids.size();
            if (partitions == 0 && iterator.hasNext()) {
                throw new IllegalStateException("Cannot assign nodes: no medoids available");
            }

            // Private copy: assignments made here must not leak into the array
            // the caller handed in.
            final long[] sizes = counts.clone();

            // Soft capacity per partition. Computed in floating point: with an
            // int capacity, size / capacity was an integer division (penalty
            // stuck at 0 until the partition was full) and the capacity could
            // truncate to 0.
            final double capacity = IMBALANCE * (sum(sizes) + 1) / partitions;

            final ArrayList<Tuple2<Node<U>, NeighborList>> result =
                    new ArrayList<Tuple2<Node<U>, NeighborList>>();
            while (iterator.hasNext()) {
                final Tuple2<Node<U>, NeighborList> tuple = iterator.next();
                result.add(tuple);
                final Node<U> candidate = tuple._1();

                final double[] scores = new double[partitions];
                for (int c = 0; c < partitions; c++) {
                    final double sim = similarity.similarity(medoids.get(c).value, candidate.value);
                    scores[c] = sim * (1.0 - sizes[c] / capacity);
                }

                final int best = argmax(scores);
                sizes[best]++;
                candidate.setAttribute(
                        BalancedKMedoidsPartitioner.PARTITION_KEY, Integer.valueOf(best));
            }
            return result;
        }

        /** Returns the sum of all entries of {@code values}. */
        private static long sum(final long[] values) {
            long total = 0;
            for (final long v : values) {
                total += v;
            }
            return total;
        }

        /**
         * Returns the index of the largest value; ties are broken uniformly at
         * random.
         *
         * @throws IllegalArgumentException if no value is comparable (empty
         *                                  array or all NaN)
         */
        private static int argmax(final double[] values) {
            double best = Double.NEGATIVE_INFINITY;
            final List<Integer> candidates = new ArrayList<Integer>();
            for (int i = 0; i < values.length; i++) {
                if (values[i] > best) {
                    best = values[i];
                    candidates.clear();
                    candidates.add(i);
                } else if (values[i] == best) {
                    candidates.add(i);
                }
            }
            if (candidates.isEmpty()) {
                throw new IllegalArgumentException(
                        "No comparable score among " + values.length + " candidates (all NaN?)");
            }
            if (candidates.size() == 1) {
                return candidates.get(0);
            }
            return candidates.get(new Random().nextInt(candidates.size()));
        }
    }

    /**
     * Emits a single value per partition: the number of tuples in it.
     *
     * @param <U> type of the node values
     */
    public static class PartitionCountFunction<U> implements
            FlatMapFunction<Iterator<Tuple2<Node<U>, NeighborList>>, Long> {
        private PartitionCountFunction() {
        }

        @Override
        public Iterable<Long> call(final Iterator<Tuple2<Node<U>, NeighborList>> iterator)
                throws Exception {
            long size = 0;
            while (iterator.hasNext()) {
                iterator.next();
                size++;
            }
            final ArrayList<Long> result = new ArrayList<Long>(1);
            result.add(size);
            return result;
        }
    }

    /**
     * Offers a new node to the neighbor lists of the nodes of a partition that
     * are reachable (within {@link #EXPANSION_LEVELS} hops) from the new
     * node's own neighbors.
     *
     * @param <U> type of the node values
     */
    public static class UpdateFunction<U> implements
            PairFlatMapFunction<Iterator<Tuple2<Node<U>, NeighborList>>, Node<U>, NeighborList> {
        /** Number of breadth-first hops explored from the new node's neighbors. */
        private static final int EXPANSION_LEVELS = 3;
        private final NeighborList neighborlist;
        private final SimilarityInterface<U> similarity;
        private final Node<U> node;

        /**
         * @param node         the node being added
         * @param neighborList neighbors found for {@code node} (expansion seeds)
         * @param similarity   similarity used to score {@code node} against others
         */
        public UpdateFunction(
                final Node<U> node,
                final NeighborList neighborList,
                final SimilarityInterface<U> similarity) {
            this.node = node;
            this.neighborlist = neighborList;
            this.similarity = similarity;
        }

        /**
         * Returns every tuple of the partition; the neighbor lists of visited
         * nodes have been offered a {@link Neighbor} pointing to the new node.
         *
         * <p>NOTE(review): the NeighborList objects come from the parent RDD's
         * iterator and are modified in place; if that RDD is cached as
         * deserialized objects, its cached data changes too. Confirm this is
         * intended.
         */
        @Override
        public Iterable<Tuple2<Node<U>, NeighborList>> call(
                final Iterator<Tuple2<Node<U>, NeighborList>> iterator) throws Exception {
            // Load this partition into a local graph.
            final Graph<U> localGraph = new Graph<U>();
            while (iterator.hasNext()) {
                final Tuple2<Node<U>, NeighborList> tuple = iterator.next();
                localGraph.put(tuple._1(), tuple._2());
            }

            LinkedList<Node<U>> frontier = new LinkedList<Node<U>>();
            for (final Neighbor neighbor : neighborlist) {
                frontier.add(neighbor.node);
            }

            // Nodes already dequeued; each node is processed at most once so
            // the new node is not scored / inserted repeatedly.
            final Set<Node<U>> visited = new HashSet<Node<U>>();

            for (int depth = 0; depth < EXPANSION_LEVELS; depth++) {
                final LinkedList<Node<U>> nextFrontier = new LinkedList<Node<U>>();
                for (final Node<U> other : frontier) {
                    if (!visited.add(other)) {
                        continue;
                    }
                    final NeighborList otherNeighbors = localGraph.get(other);
                    if (otherNeighbors == null) {
                        // Not stored in this partition.
                        continue;
                    }
                    for (final Neighbor neighbor : otherNeighbors) {
                        if (!visited.contains(neighbor.node)) {
                            nextFrontier.add(neighbor.node);
                        }
                    }
                    otherNeighbors.add(
                            new Neighbor(node, similarity.similarity(node.value, other.value)));
                }
                frontier = nextFrontier;
            }

            final ArrayList<Tuple2<Node<U>, NeighborList>> result =
                    new ArrayList<Tuple2<Node<U>, NeighborList>>(localGraph.size());
            for (final Node<U> n : localGraph.getNodes()) {
                result.add(new Tuple2<Node<U>, NeighborList>(n, localGraph.get(n)));
            }
            return result;
        }
    }

    /**
     * @param k          number of neighbors searched for each added node
     * @param similarity similarity between node values
     * @param sc         Spark context (its checkpoint dir is set to /tmp/checkpoints)
     * @param initial    initial graph, handed to {@link ApproximateSearch}
     * @param partitions forwarded to the {@link ApproximateSearch} constructor
     *                   (presumably the number of partitions/medoids — confirm)
     */
    public Online(
            final int k,
            final SimilarityInterface<T> similarity,
            final JavaSparkContext sc,
            final JavaPairRDD<Node<T>, NeighborList> initial,
            final int partitions) {
        this.similarity = similarity;
        this.k = k;
        this.sc = sc;
        this.searcher = new ApproximateSearch<T>(
                initial, PARTITIONING_ITERATIONS, partitions, similarity);
        // NOTE(review): overrides any checkpoint directory configured by the caller.
        sc.setCheckpointDir("/tmp/checkpoints");
        this.counts = getCounts();
        this.previous_rdds = new LinkedList<JavaPairRDD<Node<T>, NeighborList>>();
    }

    /** Sets the speedup passed to {@link ApproximateSearch#search} for future additions. */
    public final void setSearchSpeedup(final int search_speedup) {
        this.search_speedup = search_speedup;
    }

    /**
     * Adds {@code node} to the graph.
     *
     * @param node node to add
     */
    public final void addNode(final Node<T> node) {
        // 1. Approximate k-nn search for the new node.
        final NeighborList neighbors = searcher.search(node, k, search_speedup);

        // 2. Assign the node to a partition. The RDD below is evaluated twice
        // (by collect() here and again later through the union), and
        // this.counts is incremented in between. Passing a snapshot keeps both
        // evaluations on the same counts so the node lands in the partition
        // that is counted. (Exact score ties are still broken randomly.)
        final List<Tuple2<Node<T>, NeighborList>> single =
                new LinkedList<Tuple2<Node<T>, NeighborList>>();
        single.add(new Tuple2<Node<T>, NeighborList>(node, neighbors));
        final JavaPairRDD<Node<T>, NeighborList> assigned = partition(
                sc.parallelizePairs(single),
                searcher.getMedoids(),
                counts.clone(),
                searcher.getPartitioner().getInternalPartitioner());

        final Node<T> assignedNode = assigned.collect().get(0)._1();
        final int partitionId =
                (Integer) assignedNode.getAttribute(BalancedKMedoidsPartitioner.PARTITION_KEY);
        counts[partitionId]++;

        // 3. Link the new node into the existing graph and install the result.
        final JavaPairRDD<Node<T>, NeighborList> graph =
                update(searcher.getGraph(), node, neighbors).union(assigned);
        graph.cache();
        searcher.setGraph(graph);

        if (iteration % ITERATIONS_FOR_CHECKPOINT == 0) {
            graph.checkpoint();
        }

        // Keep only the most recent cached graphs.
        previous_rdds.add(graph);
        if (iteration > 2) {
            previous_rdds.pop().unpersist();
        }

        if (iteration % ITERATIONS_FOR_CENTROIDS == 0) {
            searcher.getPartitioner().computeNewMedoids(graph);
        }
        iteration++;
    }

    /** Returns the current graph held by the searcher. */
    public final JavaPairRDD<Node<T>, NeighborList> getGraph() {
        return searcher.getGraph();
    }

    /** Applies {@link UpdateFunction} to every partition of {@code graph}. */
    private JavaPairRDD<Node<T>, NeighborList> update(
            final JavaPairRDD<Node<T>, NeighborList> graph,
            final Node<T> node,
            final NeighborList neighborList) {
        return graph.mapPartitionsToPair(
                new UpdateFunction<T>(node, neighborList, similarity), true);
    }

    /**
     * Tags the nodes of {@code nodes} with a partition via
     * {@link AssignFunction} and repartitions them with {@code nodePartitioner}.
     */
    private JavaPairRDD<Node<T>, NeighborList> partition(
            final JavaPairRDD<Node<T>, NeighborList> nodes,
            final List<Node<T>> medoids,
            final long[] partitionCounts,
            final NodePartitioner nodePartitioner) {
        return nodes
                .mapPartitionsToPair(
                        new AssignFunction<T>(medoids, partitionCounts, similarity), true)
                .partitionBy(nodePartitioner);
    }

    /** Returns the number of tuples in each partition of the searcher's graph. */
    private long[] getCounts() {
        final List<Long> sizes = searcher.getGraph()
                .mapPartitions(new PartitionCountFunction<T>(), true)
                .collect();
        final long[] result = new long[sizes.size()];
        for (int i = 0; i < result.length; i++) {
            result[i] = sizes.get(i);
        }
        return result;
    }
}
