package edu.iu.dsc.tws.examples.batch.kmeans;

import edu.iu.dsc.tws.api.comms.messaging.types.MessageTypes;
import edu.iu.dsc.tws.api.compute.IFunction;
import edu.iu.dsc.tws.api.compute.IMessage;
import edu.iu.dsc.tws.api.compute.TaskContext;
import edu.iu.dsc.tws.api.compute.executor.ExecutionPlan;
import edu.iu.dsc.tws.api.compute.graph.ComputeGraph;
import edu.iu.dsc.tws.api.compute.graph.OperationMode;
import edu.iu.dsc.tws.api.compute.modifiers.Collector;
import edu.iu.dsc.tws.api.compute.modifiers.Receptor;
import edu.iu.dsc.tws.api.compute.nodes.BaseSink;
import edu.iu.dsc.tws.api.compute.nodes.BaseSource;
import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.api.dataset.DataObject;
import edu.iu.dsc.tws.api.dataset.DataPartition;
import edu.iu.dsc.tws.api.resource.IPersistentVolume;
import edu.iu.dsc.tws.api.resource.IVolatileVolume;
import edu.iu.dsc.tws.api.resource.IWorker;
import edu.iu.dsc.tws.api.resource.IWorkerController;
import edu.iu.dsc.tws.dataset.partition.EntityPartition;
import edu.iu.dsc.tws.examples.batch.cdfw.CDFConstants;
import edu.iu.dsc.tws.task.ComputeEnvironment;
import edu.iu.dsc.tws.task.dataobjects.DataFileReplicatedReadSource;
import edu.iu.dsc.tws.task.dataobjects.DataObjectSource;
import edu.iu.dsc.tws.task.impl.ComputeConnection;
import edu.iu.dsc.tws.task.impl.ComputeGraphBuilder;
import edu.iu.dsc.tws.task.impl.TaskExecutor;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;

/* loaded from: input_file:edu/iu/dsc/tws/examples/batch/kmeans/KMeansWorker.class */
public class KMeansWorker implements IWorker {
    private static final Logger LOG = Logger.getLogger(KMeansWorker.class.getName());

    /* loaded from: input_file:edu/iu/dsc/tws/examples/batch/kmeans/KMeansWorker$CentroidAggregator.class */
    /**
     * Reduction function that combines two partial results by adding them
     * element by element. It is registered as the reduction function of the
     * "all-reduce" edge in {@code buildKMeansTG}.
     */
    public static class CentroidAggregator implements IFunction {
        private static final long serialVersionUID = -254264120110286748L;

        /**
         * Adds two matrices element-wise and returns the sum in a new matrix.
         * Neither input is modified. The width of the result is taken from the
         * first row of {@code obj}.
         *
         * @param obj  first operand; must be a {@code double[][]}
         * @param obj2 second operand; must be a {@code double[][]} with the same
         *             number of rows as {@code obj}
         * @return a freshly allocated {@code double[][]} holding the sums; an
         *         empty matrix when both inputs have no rows
         * @throws RuntimeException if the two inputs have a different number of rows
         * @throws ArrayIndexOutOfBoundsException if any row of either input is
         *         shorter than the first row of {@code obj}
         */
        public Object onMessage(Object obj, Object obj2) throws ArrayIndexOutOfBoundsException {
            double[][] left = (double[][]) obj;
            double[][] right = (double[][]) obj2;

            // Validate before touching left[0]; otherwise an empty input fails
            // with an index error and the size check is never reached.
            if (left.length != right.length) {
                throw new RuntimeException("Center sizes not equal " + left.length + " != " + right.length);
            }
            // Nothing to add: there is no first row to take a width from.
            if (left.length == 0) {
                return new double[0][0];
            }

            int width = left[0].length;
            double[][] sum = new double[left.length][width];
            for (int row = 0; row < left.length; row++) {
                for (int col = 0; col < width; col++) {
                    sum[row][col] = left[row][col] + right[row][col];
                }
            }
            return sum;
        }
    }

    /* loaded from: input_file:edu/iu/dsc/tws/examples/batch/kmeans/KMeansWorker$KMeansAllReduceTask.class */
    /**
     * Sink of the K-Means graph. It receives the reduced matrix, divides every
     * value of a row by that row's last value, and exposes the result (one
     * column narrower than the input) under the name "centroids".
     */
    public static class KMeansAllReduceTask extends BaseSink implements Collector {
        private static final long serialVersionUID = -5190777711234234L;
        private double[][] centroids;
        private double[][] newCentroids;

        /**
         * Stores the incoming matrix and computes the normalised matrix from it.
         *
         * @param iMessage message whose content is cast to {@code double[][]}
         * @return always {@code true}
         */
        public boolean execute(IMessage iMessage) {
            this.centroids = (double[][]) iMessage.getContent();

            // The last column acts as the divisor of its row.
            // NOTE(review): presumably the number of points assigned to the
            // centroid - confirm against KMeansCalculator.
            final int rowCount = this.centroids.length;
            final int divisorColumn = this.centroids[0].length - 1;
            this.newCentroids = new double[rowCount][divisorColumn];

            for (int row = 0; row < rowCount; row++) {
                final double[] target = this.newCentroids[row];
                final double[] source = this.centroids[row];
                for (int col = 0; col < divisorColumn; col++) {
                    target[col] = source[col] / source[divisorColumn];
                }
            }
            return true;
        }

        /**
         * Wraps the most recently computed matrix in a partition identified by
         * this task's index.
         *
         * @return a new partition holding the normalised matrix
         */
        public DataPartition<double[][]> get() {
            final int taskIndex = this.context.taskIndex();
            return new EntityPartition(taskIndex, this.newCentroids);
        }

        /**
         * @return a new mutable set containing the single name "centroids"
         */
        public Set<String> getCollectibleNames() {
            return new HashSet<>(Arrays.asList("centroids"));
        }

        /** Delegates to the superclass without any additional set-up. */
        public void prepare(Config config, TaskContext taskContext) {
            super.prepare(config, taskContext);
        }
    }

    /* loaded from: input_file:edu/iu/dsc/tws/examples/batch/kmeans/KMeansWorker$KMeansSourceTask.class */
    /**
     * Source of the K-Means graph. It accepts two inputs named "points" and
     * "centroids", runs {@link KMeansCalculator} over this task's partition of
     * each, and emits the result on the "all-reduce" edge.
     */
    public static class KMeansSourceTask extends BaseSource implements Receptor {
        private static final long serialVersionUID = -254264120110286748L;
        private double[][] centroid = null;
        private double[][] datapoints = null;
        private KMeansCalculator kMeansCalculator = null;
        private DataObject<?> dataPointsObject = null;
        private DataObject<?> centroidsObject = null;

        /**
         * Reads the dimension from the configuration, pulls the first element
         * of this task's partition from both inputs, runs the calculation and
         * writes the result as the final message of the "all-reduce" edge.
         */
        public void execute() {
            final int dimension =
                    Integer.parseInt(this.config.getStringValue(CDFConstants.ARGS_NUMBER_OF_DIMENSIONS));

            final DataPartition<?> pointsPartition =
                    this.dataPointsObject.getPartition(this.context.taskIndex());
            this.datapoints = (double[][]) pointsPartition.getConsumer().next();

            final DataPartition<?> centroidPartition =
                    this.centroidsObject.getPartition(this.context.taskIndex());
            this.centroid = (double[][]) centroidPartition.getConsumer().next();

            this.kMeansCalculator = new KMeansCalculator(this.datapoints, this.centroid, dimension);
            this.context.writeEnd("all-reduce", this.kMeansCalculator.calculate());
        }

        /**
         * Remembers an input data object. Names other than "points" and
         * "centroids" are ignored.
         *
         * @param str        name of the input
         * @param dataObject data object supplied for that name
         */
        public void add(String str, DataObject<?> dataObject) {
            if ("points".equals(str)) {
                this.dataPointsObject = dataObject;
            } else if ("centroids".equals(str)) {
                this.centroidsObject = dataObject;
            }
        }

        /**
         * @return a new mutable set containing "points" and "centroids"
         */
        public Set<String> getReceivableNames() {
            return new HashSet<>(Arrays.asList("points", "centroids"));
        }
    }

    /**
     * Runs the complete K-Means job on this worker: loads the data points and
     * the initial centroids with two separate graphs, then executes the
     * K-Means graph once per configured iteration, feeding the centroids
     * produced by one iteration into the next. Timing information is logged by
     * every worker; the first final centroid is logged by worker 0 only.
     *
     * @param config            job configuration
     * @param i                 identifier of this worker; also used to select
     *                          the partition of the final centroids
     * @param iWorkerController passed through to {@code ComputeEnvironment.init}
     * @param iPersistentVolume passed through to {@code ComputeEnvironment.init}
     * @param iVolatileVolume   passed through to {@code ComputeEnvironment.init}
     */
    public void execute(Config config, int i, IWorkerController iWorkerController, IPersistentVolume iPersistentVolume, IVolatileVolume iVolatileVolume) {
        LOG.log(Level.FINE, "Task worker starting: " + i);
        ComputeEnvironment env = ComputeEnvironment.init(config, i, iWorkerController, iPersistentVolume, iVolatileVolume);
        TaskExecutor executor = env.getTaskExecutor();
        KMeansWorkerParameters params = KMeansWorkerParameters.build(config);
        // NOTE(review): the instance is discarded; kept only in case the
        // constructor has side effects - confirm and remove if it has none.
        new KMeansWorkerUtils(config);
        int parallelism = params.getParallelismValue();
        int dimension = params.getDimension();
        // NOTE(review): result unused; kept in case the getter is not pure - confirm.
        params.getNumFiles();
        int dataSize = params.getDsize();
        int centroidSize = params.getCsize();
        int iterations = params.getIterations();
        String dataDirectory = params.getDatapointDirectory();
        String centroidDirectory = params.getCentroidDirectory();

        long startTime = System.currentTimeMillis();
        ComputeGraph dataPointsGraph = buildDataPointsTG(dataDirectory, dataSize, parallelism, dimension, config);
        ComputeGraph centroidsGraph = buildCentroidsTG(centroidDirectory, centroidSize, parallelism, dimension, config);
        ComputeGraph kMeansGraph = buildKMeansTG(parallelism, config);

        // Phase 1: load the data points.
        ExecutionPlan dataPointsPlan = executor.plan(dataPointsGraph);
        executor.execute(dataPointsGraph, dataPointsPlan);
        DataObject dataPoints = executor.getOutput(dataPointsGraph, dataPointsPlan, "datapointsink");

        // Phase 2: load the initial centroids.
        ExecutionPlan centroidsPlan = executor.plan(centroidsGraph);
        executor.execute(centroidsGraph, centroidsPlan);
        DataObject centroids = executor.getOutput(centroidsGraph, centroidsPlan, "centroidsink");
        long dataLoadedTime = System.currentTimeMillis();

        // Phase 3: iterate, replacing the centroids with each iteration's output.
        ExecutionPlan kMeansPlan = executor.plan(kMeansGraph);
        for (int iteration = 0; iteration < iterations; iteration++) {
            executor.addInput(kMeansGraph, kMeansPlan, "kmeanssource", "points", dataPoints);
            executor.addInput(kMeansGraph, kMeansPlan, "kmeanssource", "centroids", centroids);
            executor.itrExecute(kMeansGraph, kMeansPlan);
            centroids = executor.getOutput(kMeansGraph, kMeansPlan, "kmeanssink");
        }
        executor.waitFor(kMeansGraph, kMeansPlan);
        env.close();

        DataPartition partition = centroids.getPartition(i);
        double[][] finalCentroids = partition.getConsumer().hasNext() ? (double[][]) partition.getConsumer().next() : null;
        long endTime = System.currentTimeMillis();
        LOG.info("Total K-Means Execution Time: " + (endTime - startTime) + "\tData Load time : " + (dataLoadedTime - startTime) + "\tCompute Time : " + (endTime - dataLoadedTime));
        if (i == 0) {
            // The partition may be empty (null) or hold no rows; indexing [0]
            // unconditionally would crash after all the work has succeeded.
            if (finalCentroids == null || finalCentroids.length == 0) {
                LOG.warning("No final centroids available after\t" + iterations + "\titerations");
            } else {
                LOG.info("Final Centroids After\t" + iterations + "\titerations\t" + Arrays.toString(finalCentroids[0]));
            }
        }
    }

    /**
     * Builds the batch graph "datapointsTG": source "datapointsource", compute
     * "datapointcompute" and sink "datapointsink", chained by "direct" edges
     * that carry {@code MessageTypes.OBJECT} data. The sink is created with
     * the name "points".
     *
     * @param str    location handed to the {@link DataObjectSource}; the worker
     *               passes the data point directory
     * @param i      size value handed to the source and the compute task; the
     *               worker passes the configured data size
     * @param i2     parallelism of all three tasks, also handed to the compute task
     * @param i3     dimension value handed to the compute task
     * @param config configuration used to create the graph builder
     * @return the built compute graph
     */
    public static ComputeGraph buildDataPointsTG(String str, int i, int i2, int i3, Config config) {
        final String dataDirectory = str;
        final int dataSize = i;
        final int parallelism = i2;
        final int dimension = i3;

        DataObjectSource source = new DataObjectSource("direct", dataDirectory, dataSize);
        KMeansDataObjectCompute compute =
                new KMeansDataObjectCompute("direct", dataSize, parallelism, dimension);
        KMeansDataObjectDirectSink sink = new KMeansDataObjectDirectSink("points");

        ComputeGraphBuilder graphBuilder = ComputeGraphBuilder.newBuilder(config);
        graphBuilder.addSource("datapointsource", source, parallelism);
        ComputeConnection computeConnection =
                graphBuilder.addCompute("datapointcompute", compute, parallelism);
        ComputeConnection sinkConnection =
                graphBuilder.addSink("datapointsink", sink, parallelism);

        // source -> compute
        computeConnection
                .direct("datapointsource")
                .viaEdge("direct")
                .withDataType(MessageTypes.OBJECT);
        // compute -> sink
        sinkConnection
                .direct("datapointcompute")
                .viaEdge("direct")
                .withDataType(MessageTypes.OBJECT);

        graphBuilder.setMode(OperationMode.BATCH);
        graphBuilder.setTaskGraphName("datapointsTG");
        return graphBuilder.build();
    }

    /**
     * Builds the batch graph "centTG": source "centroidsource", compute
     * "centroidcompute" and sink "centroidsink", chained by "direct" edges
     * that carry {@code MessageTypes.OBJECT} data. The sink is created with
     * the name "centroids".
     *
     * @param str    location handed to the {@link DataFileReplicatedReadSource};
     *               the worker passes the centroid directory
     * @param i      size value handed to the compute task; the worker passes
     *               the configured centroid size
     * @param i2     parallelism of all three tasks
     * @param i3     dimension value handed to the compute task
     * @param config configuration used to create the graph builder
     * @return the built compute graph
     */
    public static ComputeGraph buildCentroidsTG(String str, int i, int i2, int i3, Config config) {
        final String centroidDirectory = str;
        final int centroidSize = i;
        final int parallelism = i2;
        final int dimension = i3;

        DataFileReplicatedReadSource source =
                new DataFileReplicatedReadSource("direct", centroidDirectory);
        KMeansDataObjectCompute compute =
                new KMeansDataObjectCompute("direct", centroidSize, dimension);
        KMeansDataObjectDirectSink sink = new KMeansDataObjectDirectSink("centroids");

        ComputeGraphBuilder graphBuilder = ComputeGraphBuilder.newBuilder(config);
        graphBuilder.addSource("centroidsource", source, parallelism);
        ComputeConnection computeConnection =
                graphBuilder.addCompute("centroidcompute", compute, parallelism);
        ComputeConnection sinkConnection =
                graphBuilder.addSink("centroidsink", sink, parallelism);

        // source -> compute
        computeConnection
                .direct("centroidsource")
                .viaEdge("direct")
                .withDataType(MessageTypes.OBJECT);
        // compute -> sink
        sinkConnection
                .direct("centroidcompute")
                .viaEdge("direct")
                .withDataType(MessageTypes.OBJECT);

        graphBuilder.setMode(OperationMode.BATCH);
        graphBuilder.setTaskGraphName("centTG");
        return graphBuilder.build();
    }

    /**
     * Builds the batch graph "kmeansTG": a {@link KMeansSourceTask} named
     * "kmeanssource" connected to a {@link KMeansAllReduceTask} named
     * "kmeanssink" through an all-reduce edge called "all-reduce". The edge
     * uses {@link CentroidAggregator} as its reduction function and carries
     * {@code MessageTypes.OBJECT} data.
     *
     * @param i      parallelism of both the source and the sink
     * @param config configuration used to create the graph builder
     * @return the built compute graph
     */
    public static ComputeGraph buildKMeansTG(int i, Config config) {
        final int parallelism = i;

        KMeansSourceTask source = new KMeansSourceTask();
        KMeansAllReduceTask sink = new KMeansAllReduceTask();

        ComputeGraphBuilder graphBuilder = ComputeGraphBuilder.newBuilder(config);
        graphBuilder.addSource("kmeanssource", source, parallelism);

        ComputeConnection sinkConnection = graphBuilder.addSink("kmeanssink", sink, parallelism);
        sinkConnection
                .allreduce("kmeanssource")
                .viaEdge("all-reduce")
                .withReductionFunction(new CentroidAggregator())
                .withDataType(MessageTypes.OBJECT);

        graphBuilder.setMode(OperationMode.BATCH);
        graphBuilder.setTaskGraphName("kmeansTG");
        return graphBuilder.build();
    }
}
