package edu.iu.dsc.tws.examples.task.batch;

import edu.iu.dsc.tws.api.comms.Op;
import edu.iu.dsc.tws.api.comms.messaging.types.MessageTypes;
import edu.iu.dsc.tws.api.compute.TaskContext;
import edu.iu.dsc.tws.api.compute.nodes.ISink;
import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.examples.task.BenchTaskWorker;
import edu.iu.dsc.tws.examples.task.batch.verifiers.ReduceVerifier;
import edu.iu.dsc.tws.examples.utils.bench.BenchmarkConstants;
import edu.iu.dsc.tws.examples.utils.bench.BenchmarkUtils;
import edu.iu.dsc.tws.examples.utils.bench.Timing;
import edu.iu.dsc.tws.examples.verification.ResultsVerifier;
import edu.iu.dsc.tws.task.impl.ComputeGraphBuilder;
import edu.iu.dsc.tws.task.typed.AllReduceCompute;
import java.util.List;
import java.util.logging.Logger;

/**
 * Batch benchmark that connects a parallel source stage to a parallel sink stage through a
 * single all-reduce edge.
 *
 * <p>The parallelism of the "source" stage and the "sink" stage is taken from the first and
 * second entries of {@code jobParameters.getTaskStages()}. The connection is configured as an
 * all-reduce from "source" over the edge named "edge", using {@link Op#SUM} with
 * {@link MessageTypes#INTEGER_ARRAY} as the data type.
 */
public class BTAllReduceExample extends BenchTaskWorker {
    private static final Logger LOG = Logger.getLogger(BTAllReduceExample.class.getName());

    /**
     * Sink task that is handed the all-reduced {@code int[]} value. It records benchmark timing,
     * writes the results recorder to CSV and runs the result through a {@link ReduceVerifier}.
     */
    protected static class AllReduceSinkTask extends AllReduceCompute<int[]> implements ISink {
        private static final long serialVersionUID = -254264903510284798L;

        // Decides whether Timing/BenchmarkUtils marks are recorded for this task
        // (computed once in prepare via getTimingCondition("sink", ...)).
        private boolean timingCondition;

        private ResultsVerifier<int[], int[]> resultsVerifier;

        // Running verification status; fed back into verifyResults on every delivery.
        private boolean verified = true;

        protected AllReduceSinkTask() {
        }

        /**
         * Initializes the base compute task, then computes the timing condition for the "sink"
         * stage and builds the verifier from the shared input data and job parameters.
         *
         * @param config      job configuration, forwarded to the superclass
         * @param taskContext context of this task instance, also handed to the verifier
         */
        public void prepare(Config config, TaskContext taskContext) {
            super.prepare(config, taskContext);

            timingCondition = BenchTaskWorker.getTimingCondition("sink", this.context);
            resultsVerifier = new ReduceVerifier(
                    BTAllReduceExample.inputDataArray,
                    taskContext,
                    "source",
                    BTAllReduceExample.jobParameters);
        }

        /**
         * Handles one all-reduced array. It marks the receive time, logs the worker and global
         * task ids, marks the total time, flushes the results recorder to CSV and updates the
         * running verification status.
         *
         * @param reduced the combined array delivered by the all-reduce operation
         * @return always {@code true}
         */
        public boolean allReduce(int[] reduced) {
            Timing.mark(BenchmarkConstants.TIMING_ALL_RECV, timingCondition);

            int workerId = this.context.getWorkerId();
            int globalTaskId = this.context.globalTaskId();
            BTAllReduceExample.LOG.info(
                    String.format("%d received allreduce %d", workerId, globalTaskId));

            BenchmarkUtils.markTotalTime(BTAllReduceExample.resultsRecorder, timingCondition);
            BTAllReduceExample.resultsRecorder.writeToCSV();

            verified = BTAllReduceExample.verifyResults(resultsVerifier, reduced, null, verified);
            return true;
        }
    }

    /**
     * Adds the "source" and "sink" stages to this worker's graph builder and joins them with an
     * all-reduce (SUM over integer arrays) on the edge named "edge".
     *
     * @return this worker's {@code computeGraphBuilder}, now holding both stages
     */
    @Override // edu.iu.dsc.tws.examples.task.BenchTaskWorker
    public ComputeGraphBuilder buildTaskGraph() {
        List<Integer> stages = jobParameters.getTaskStages();
        int sourceParallelism = stages.get(0);
        int sinkParallelism = stages.get(1);

        BenchTaskWorker.SourceTask source = new BenchTaskWorker.SourceTask("edge");
        AllReduceSinkTask sink = new AllReduceSinkTask();

        computeGraphBuilder.addSource("source", source, sourceParallelism);
        computeConnection = computeGraphBuilder.addSink("sink", sink, sinkParallelism);
        computeConnection
                .allreduce("source")
                .viaEdge("edge")
                .withOperation(Op.SUM, MessageTypes.INTEGER_ARRAY);

        return computeGraphBuilder;
    }
}
