package edu.iu.dsc.tws.examples.ml.svm.data;

import edu.iu.dsc.tws.api.comms.messaging.types.MessageTypes;
import edu.iu.dsc.tws.api.compute.IFunction;
import edu.iu.dsc.tws.api.compute.IMessage;
import edu.iu.dsc.tws.api.compute.executor.ExecutionPlan;
import edu.iu.dsc.tws.api.compute.graph.ComputeGraph;
import edu.iu.dsc.tws.api.compute.graph.OperationMode;
import edu.iu.dsc.tws.api.compute.nodes.BaseSink;
import edu.iu.dsc.tws.api.compute.nodes.BaseSource;
import edu.iu.dsc.tws.api.dataset.DataObject;
import edu.iu.dsc.tws.dataset.partition.EntityPartition;
import edu.iu.dsc.tws.examples.batch.cdfw.CDFConstants;
import edu.iu.dsc.tws.task.impl.ComputeGraphBuilder;
import edu.iu.dsc.tws.task.impl.TaskWorker;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.logging.Logger;

/* loaded from: input_file:edu/iu/dsc/tws/examples/ml/svm/data/SourceTaskDataLoader.class */
/**
 * Example worker that connects one source task to one sink task through an all-reduce edge
 * named {@code "all-reduce"} and runs the resulting compute graph in batch mode.
 *
 * <p>Each source instance writes the single value {@code "s"} as its final message. The
 * reduction function keeps the first operand. Each sink instance logs the runtime class of the
 * message content it receives. Task parallelism is read from the job configuration in
 * {@link #getParams()}.
 */
public class SourceTaskDataLoader extends TaskWorker {
    private static final Logger LOG = Logger.getLogger(SourceTaskDataLoader.class.getName());

    // Per-instance configuration, populated by getParams(). These used to be static, which made
    // every SourceTaskDataLoader instance in the JVM share (and overwrite) a single copy.
    private int workers = 1;
    private int parallelism = 4;
    // NOTE(review): workers and dataSource are read from the config but not used in this class.
    private String dataSource = "";

    /**
     * Returns the runtime class name of {@code value}, or {@code "null"} for a null reference,
     * so that log statements cannot throw on a null payload.
     */
    private static String typeName(Object value) {
        return value == null ? "null" : value.getClass().getName();
    }

    /**
     * Source task that writes the single value {@code "s"} as its final message on the
     * {@code "all-reduce"} edge.
     *
     * <p>Declared {@code static} so that task instances do not hold a hidden reference to the
     * enclosing worker; the task uses no outer instance state.
     */
    private static class DataSourceTask extends BaseSource {
        private static final long serialVersionUID = -1836625523925581215L;

        /** Last value produced by {@link #getTaskIndexDataPoints(int)}; initially null. */
        private Object object;

        /** Data object whose partitions are read; NOTE(review): never assigned in this file. */
        private DataObject<?> dataPointsObject;

        private DataSourceTask() {
            this.object = null;
            this.dataPointsObject = null;
        }

        /** Emits {@code "s"} on the {@code "all-reduce"} edge via {@code writeEnd}. */
        public void execute() {
            this.context.writeEnd("all-reduce", "s");
        }

        /**
         * Reads partition {@code i} of {@link #dataPointsObject}, treats the next value of its
         * consumer as a nested {@link DataObject}, and caches the first element of that nested
         * object's partition {@code i}.
         *
         * @param i partition index
         * @return the cached value, which stays unchanged (initially null) when no data object is
         *     attached or the partition is null
         */
        private Object getTaskIndexDataPoints(int i) {
            if (this.dataPointsObject == null) {
                // No data object has been attached; previously this dereferenced null.
                return this.object;
            }
            EntityPartition partition = this.dataPointsObject.getPartition(i);
            if (partition != null) {
                this.object = getDataObjects(i, (DataObject<?>) partition.getConsumer().next());
            }
            return this.object;
        }

        /**
         * Returns the first element of the iterator produced by the consumer of partition
         * {@code i} of {@code dataObject}.
         *
         * @param i partition index
         * @param dataObject object whose partition is read; its consumer's next value is assumed
         *     to be an {@link Iterator} (cast, as in the original code)
         * @return the first element, or {@code null} if the iterator is empty
         */
        public Object getDataObjects(int i, DataObject<?> dataObject) {
            Iterator<?> it = (Iterator<?>) dataObject.getPartition(i).getConsumer().next();
            // Only the first element is needed; there is no need to drain the whole iterator.
            return it.hasNext() ? it.next() : null;
        }
    }

    /**
     * Reduction function for the all-reduce edge. It returns the first operand unchanged,
     * discards the second, and logs the runtime classes of both.
     *
     * <p>Declared {@code static} so that the function object does not capture the enclosing
     * worker.
     */
    public static class SimpleDataAggregator implements IFunction {
        private static final long serialVersionUID = 4948225063068433511L;
        private Object object;

        public SimpleDataAggregator() {
        }

        /**
         * Keeps {@code obj} as the reduction result.
         *
         * @param obj first operand, returned as the result
         * @param obj2 second operand, only logged
         * @return {@code obj}
         */
        public Object onMessage(Object obj, Object obj2) {
            this.object = obj;
            LOG.info(String.format("Object Types : %s, %s", typeName(obj), typeName(obj2)));
            return this.object;
        }
    }

    /** Sink task that stores the received message content and logs its runtime class. */
    private static class SimpleDataAllReduceTask extends BaseSink {
        private static final long serialVersionUID = 5705351508072337994L;
        private Object object;

        private SimpleDataAllReduceTask() {
        }

        /**
         * Records the message content and logs its type.
         *
         * @param iMessage incoming message
         * @return always {@code true}
         */
        public boolean execute(IMessage iMessage) {
            this.object = iMessage.getContent();
            LOG.info(String.format("Object Instance : %s", typeName(this.object)));
            return true;
        }
    }

    /**
     * Reads the job parameters, then builds a batch compute graph of {@code "kmeanssource"} ->
     * {@code "kmeanssink"} joined by an all-reduce edge using {@link SimpleDataAggregator}.
     * Finally it plans and executes that graph and requests the output of the sink.
     */
    public void execute() {
        getParams();
        ComputeGraphBuilder graphBuilder = ComputeGraphBuilder.newBuilder(this.config);
        graphBuilder.addSource("kmeanssource", new DataSourceTask(), parallelism);
        graphBuilder.addSink("kmeanssink", new SimpleDataAllReduceTask(), parallelism)
            .allreduce("kmeanssource")
            .viaEdge("all-reduce")
            .withReductionFunction(new SimpleDataAggregator())
            .withDataType(MessageTypes.OBJECT);
        graphBuilder.setMode(OperationMode.BATCH);
        ComputeGraph graph = graphBuilder.build();
        ExecutionPlan plan = this.taskExecutor.plan(graph);
        this.taskExecutor.execute(graph, plan);
        // NOTE(review): the returned output is discarded; kept to preserve the original calls.
        this.taskExecutor.getOutput(graph, plan, "kmeanssink");
    }

    /**
     * Loads {@code workers} (default 1), parallelism (default 4) and
     * {@code training_data_dir} (default empty) from the job configuration into this instance.
     */
    public void getParams() {
        this.workers = this.config.getIntegerValue("workers", 1).intValue();
        this.parallelism =
            this.config.getIntegerValue(CDFConstants.ARGS_PARALLELISM_VALUE, 4).intValue();
        this.dataSource = this.config.getStringValue("training_data_dir", "");
    }
}
