package eu.stratosphere.pact.runtime.iterative.task;

import com.google.common.base.Preconditions;
import eu.stratosphere.api.common.aggregators.Aggregator;
import eu.stratosphere.api.common.aggregators.AggregatorWithName;
import eu.stratosphere.api.common.aggregators.ConvergenceCriterion;
import eu.stratosphere.nephele.event.task.AbstractTaskEvent;
import eu.stratosphere.nephele.execution.librarycache.LibraryCacheManager;
import eu.stratosphere.nephele.template.AbstractOutputTask;
import eu.stratosphere.nephele.types.IntegerRecord;
import eu.stratosphere.pact.runtime.iterative.event.AllWorkersDoneEvent;
import eu.stratosphere.pact.runtime.iterative.event.TerminationEvent;
import eu.stratosphere.pact.runtime.iterative.event.WorkerDoneEvent;
import eu.stratosphere.pact.runtime.task.RegularPactTask;
import eu.stratosphere.pact.runtime.task.util.TaskConfig;
import eu.stratosphere.runtime.io.api.MutableRecordReader;
import eu.stratosphere.types.Value;
import eu.stratosphere.util.InstantiationUtil;
import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/* loaded from: input_file:eu/stratosphere/pact/runtime/iterative/task/IterationSynchronizationSinkTask.class */
/**
 * Output task acting as the synchronization point between supersteps of an iteration.
 *
 * <p>In every superstep the task consumes the {@link WorkerDoneEvent}s arriving on its single input
 * through a {@link SyncEventHandler}. It then evaluates the termination conditions: the configured
 * maximum number of iterations, and an optional {@link ConvergenceCriterion} applied to a named
 * aggregator. Finally it publishes either a {@link TerminationEvent} or an {@link AllWorkersDoneEvent}
 * (carrying the aggregators) back over the same reader.
 *
 * <p>The input must never carry data records; receiving one is reported as an error.
 */
public class IterationSynchronizationSinkTask extends AbstractOutputTask implements Terminable {

    private static final Log log = LogFactory.getLog(IterationSynchronizationSinkTask.class);

    /** Reader on the task's input. It is used only to receive and publish events, never for records. */
    private MutableRecordReader<IntegerRecord> headEventReader;

    /** Class loader of the job's user code, handed to the event handler. */
    private ClassLoader userCodeClassLoader;

    /**
     * Handles the WorkerDoneEvents of a superstep. It receives the aggregator map, presumably to merge
     * the workers' partial aggregates into it (TODO confirm in SyncEventHandler).
     */
    private SyncEventHandler eventHandler;

    /** Convergence criterion; only set when the task config declares one. */
    private ConvergenceCriterion<Value> convergenceCriterion;

    /** Aggregators of the iteration, keyed by their configured name. */
    private Map<String, Aggregator<?>> aggregators;

    /** Name of the aggregator the convergence criterion is evaluated on; null if there is no criterion. */
    private String convergenceAggregatorName;

    /** Upper bound on the number of supersteps, as read from the task config. */
    private int maxNumberOfIterations;

    /** Number of the superstep currently being processed (1-based). */
    private int currentIteration = 1;

    /**
     * Termination flag. It is atomic presumably because requestTermination() may be invoked from a
     * thread other than the one running invoke() (NOTE(review): confirm with Terminable callers).
     */
    private final AtomicBoolean terminated = new AtomicBoolean(false);

    @Override // eu.stratosphere.nephele.template.AbstractInvokable
    public void registerInputOutput() {
        this.headEventReader = new MutableRecordReader<>(this);
    }

    /**
     * Runs supersteps until termination is requested.
     *
     * <p>Per superstep: wait for the workers' events, then either broadcast termination (maximum
     * iterations reached or convergence detected) or broadcast the aggregated values, reset the
     * aggregators and advance the iteration counter.
     */
    @Override // eu.stratosphere.nephele.template.AbstractInvokable
    public void invoke() throws Exception {
        initialize();

        IntegerRecord reuse = new IntegerRecord();
        while (!terminationRequested()) {
            if (log.isInfoEnabled()) {
                log.info(formatLogString("starting iteration [" + this.currentIteration + "]"));
            }

            readHeadEventChannel(reuse);

            if (log.isInfoEnabled()) {
                log.info(formatLogString("finishing iteration [" + this.currentIteration + "]"));
            }

            if (checkForConvergence()) {
                if (log.isInfoEnabled()) {
                    log.info(formatLogString("signaling that all workers are to terminate in iteration [" + this.currentIteration + "]"));
                }
                requestTermination();
                sendToAllWorkers(new TerminationEvent());
            } else {
                if (log.isInfoEnabled()) {
                    log.info(formatLogString("signaling that all workers are done in iteration [" + this.currentIteration + "]"));
                }
                sendToAllWorkers(new AllWorkersDoneEvent(this.aggregators));
                // Reset only after the event carrying the current aggregators has been published.
                resetAggregators();
                this.currentIteration++;
            }
        }
    }

    /**
     * Reads the task configuration.
     *
     * <p>It instantiates the aggregators and the optional convergence criterion, records the iteration
     * limit, and subscribes the event handler to WorkerDoneEvents.
     */
    private void initialize() throws Exception {
        this.userCodeClassLoader = LibraryCacheManager.getClassLoader(getEnvironment().getJobID());
        TaskConfig taskConfig = new TaskConfig(getTaskConfiguration());

        this.aggregators = new HashMap<>();
        for (AggregatorWithName<?> aggregatorWithName : taskConfig.getIterationAggregators()) {
            Aggregator<?> aggregator = (Aggregator<?>) InstantiationUtil.instantiate(aggregatorWithName.getAggregator(), Aggregator.class);
            this.aggregators.put(aggregatorWithName.getName(), aggregator);
        }

        if (taskConfig.usesConvergenceCriterion()) {
            // The config exposes the criterion class without a usable type argument.
            @SuppressWarnings({"unchecked", "rawtypes"})
            ConvergenceCriterion<Value> criterion = (ConvergenceCriterion) InstantiationUtil.instantiate(taskConfig.getConvergenceCriterion(), ConvergenceCriterion.class);
            this.convergenceCriterion = criterion;
            this.convergenceAggregatorName = taskConfig.getConvergenceCriterionAggregatorName();
            Preconditions.checkNotNull(this.convergenceAggregatorName, "A convergence criterion is configured, but no aggregator name is set for it.");
        }

        this.maxNumberOfIterations = taskConfig.getNumberOfIterations();

        // NOTE(review): gate 0's "events until interrupt" presumably equals the number of
        // WorkerDoneEvents expected per superstep -- confirm in TaskConfig/SyncEventHandler.
        this.eventHandler = new SyncEventHandler(taskConfig.getNumberOfEventsUntilInterruptInIterativeGate(0), this.aggregators, this.userCodeClassLoader);
        this.headEventReader.subscribeToEvent(this.eventHandler, WorkerDoneEvent.class);
    }

    /** Resets every aggregator so the next superstep starts from a clean state. */
    private void resetAggregators() {
        for (Aggregator<?> aggregator : this.aggregators.values()) {
            aggregator.reset();
        }
    }

    /**
     * Decides whether the iteration must stop after the current superstep.
     *
     * @return true if the iteration limit has been reached, or if a convergence criterion is
     *     configured and reports convergence for the current aggregate
     * @throws RuntimeException if the aggregator named for the convergence criterion is missing
     */
    private boolean checkForConvergence() {
        // ">=" rather than "==" so an out-of-range limit can never make the loop skip past it.
        if (this.currentIteration >= this.maxNumberOfIterations) {
            if (log.isInfoEnabled()) {
                log.info(formatLogString("maximum number of iterations [" + this.currentIteration + "] reached, terminating..."));
            }
            return true;
        }

        if (this.convergenceAggregatorName == null) {
            // No convergence criterion configured; only the iteration limit applies.
            return false;
        }

        Aggregator<?> aggregator = this.aggregators.get(this.convergenceAggregatorName);
        if (aggregator == null) {
            throw new RuntimeException("Error: Aggregator for convergence criterion was null.");
        }

        boolean converged = this.convergenceCriterion.isConverged(this.currentIteration, aggregator.getAggregate());
        if (converged && log.isInfoEnabled()) {
            log.info(formatLogString("convergence reached after [" + this.currentIteration + "] iterations, terminating..."));
        }
        return converged;
    }

    /**
     * Reads the input so that this superstep's events are delivered to the subscribed event handler.
     *
     * @param integerRecord reusable record passed to the reader; no record is expected to arrive
     * @throws IOException if reading from the input fails
     * @throws RuntimeException if a data record arrives, or if the read is interrupted before the event
     *     handler has flagged the end of the superstep
     */
    private void readHeadEventChannel(IntegerRecord integerRecord) throws IOException {
        this.eventHandler.resetEndOfSuperstep();
        try {
            if (this.headEventReader.next(integerRecord)) {
                throw new RuntimeException("Synchronization task must not see any records!");
            }
        } catch (InterruptedException e) {
            // The interrupt appears to be the end-of-superstep signal (inferred from the check below;
            // confirm in SyncEventHandler). The interrupt status is deliberately not restored here,
            // so the next superstep's read is not cut short.
            if (!this.eventHandler.isEndOfSuperstep()) {
                throw new RuntimeException("Event handler interrupted without reaching end-of-superstep.", e);
            }
        }
    }

    /** Publishes the given event over the head event reader. */
    private void sendToAllWorkers(AbstractTaskEvent abstractTaskEvent) throws IOException, InterruptedException {
        this.headEventReader.publishEvent(abstractTaskEvent);
    }

    /** Decorates a log message with the task name via RegularPactTask's formatter. */
    private String formatLogString(String str) {
        return RegularPactTask.constructLogString(str, getEnvironment().getTaskName(), this);
    }

    @Override // eu.stratosphere.pact.runtime.iterative.task.Terminable
    public boolean terminationRequested() {
        return this.terminated.get();
    }

    @Override // eu.stratosphere.pact.runtime.iterative.task.Terminable
    public void requestTermination() {
        this.terminated.set(true);
    }
}
