package edu.iu.dsc.tws.rsched.schedulers.standalone;

import com.google.protobuf.InvalidProtocolBufferException;
import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.api.config.Context;
import edu.iu.dsc.tws.api.driver.IScalerPerCluster;
import edu.iu.dsc.tws.api.exceptions.Twister2Exception;
import edu.iu.dsc.tws.api.resource.FSPersistentVolume;
import edu.iu.dsc.tws.api.resource.IPersistentVolume;
import edu.iu.dsc.tws.api.resource.IVolatileVolume;
import edu.iu.dsc.tws.api.resource.IWorker;
import edu.iu.dsc.tws.api.resource.IWorkerController;
import edu.iu.dsc.tws.api.scheduler.SchedulerContext;
import edu.iu.dsc.tws.common.config.ConfigLoader;
import edu.iu.dsc.tws.common.logging.LoggingHelper;
import edu.iu.dsc.tws.common.util.JSONUtils;
import edu.iu.dsc.tws.common.util.NetworkUtils;
import edu.iu.dsc.tws.common.util.ReflectionUtils;
import edu.iu.dsc.tws.master.JobMasterContext;
import edu.iu.dsc.tws.master.server.JobMaster;
import edu.iu.dsc.tws.master.worker.JMWorkerAgent;
import edu.iu.dsc.tws.proto.jobmaster.JobMasterAPI;
import edu.iu.dsc.tws.proto.system.JobExecutionState;
import edu.iu.dsc.tws.proto.system.job.JobAPI;
import edu.iu.dsc.tws.proto.utils.NodeInfoUtils;
import edu.iu.dsc.tws.proto.utils.WorkerInfoUtils;
import edu.iu.dsc.tws.rsched.schedulers.k8s.KubernetesContext;
import edu.iu.dsc.tws.rsched.schedulers.mesos.MesosContext;
import edu.iu.dsc.tws.rsched.schedulers.nomad.NomadContext;
import edu.iu.dsc.tws.rsched.schedulers.nomad.NomadTerminator;
import edu.iu.dsc.tws.rsched.utils.JobUtils;
import java.io.File;
import java.io.IOException;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.IntBuffer;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Level;
import java.util.logging.Logger;
import mpi.Intracomm;
import mpi.MPI;
import mpi.MPIException;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;

/* loaded from: input_file:edu/iu/dsc/tws/rsched/schedulers/standalone/MPIWorker.class */
/**
 * Entry point of a Twister2 process that is launched under MPI.
 *
 * <p>The constructor initializes MPI, parses the command line, loads the job configuration and
 * then takes one of three paths:
 * <ul>
 *   <li>job master not used: every rank runs the configured {@link IWorker} with a resource plan
 *       exchanged over MPI;</li>
 *   <li>job master used and running in the client: every rank runs the configured
 *       {@link IWorker} and connects to that job master;</li>
 *   <li>job master used but not running in the client: rank 0 runs the {@link JobMaster}, all
 *       other ranks run the configured {@link IWorker} on a communicator split from
 *       {@code MPI.COMM_WORLD}.</li>
 * </ul>
 */
public final class MPIWorker {
    private static final Logger LOG = Logger.getLogger(MPIWorker.class.getName());

    /** Ensures the shutdown sequence in {@link #finalizeMPI()} is executed at most once. */
    private final AtomicBoolean finalized = new AtomicBoolean(false);

    /** Agent connected to the job master; stays {@code null} until a worker controller is created. */
    private JMWorkerAgent masterClient;

    /** Effective configuration; {@code null} until the command line has been parsed. */
    private Config config;

    /** Worker information of this process; {@code null} until it has been created. */
    private JobMasterAPI.WorkerInfo wInfo;

    /**
     * Shuts this process down: closes the job master connection when a job master is used and
     * then finalizes MPI.
     *
     * <p>Only the first invocation does any work. The uncaught-exception handler and the
     * constructor can both reach this method, and repeating the worker-completed notification or
     * {@code MPI.Finalize()} is not meaningful.
     */
    public void finalizeMPI() {
        if (!this.finalized.compareAndSet(false, true)) {
            return;
        }
        try {
            if (this.config != null && JobMasterContext.isJobMasterUsed(this.config)) {
                closeWorker();
            }
            MPI.Finalize();
        } catch (MPIException e) {
            // Nothing more can be done at shutdown, but the failure must not vanish silently.
            LOG.log(Level.WARNING, "Failed to finalize MPI", e);
        }
    }

    /**
     * Initializes MPI, loads the configuration and runs either a worker or the job master,
     * depending on the configuration and on the rank of this process.
     *
     * @param strArr command line arguments, handed to both MPI and the option parser
     * @throws RuntimeException if the command line cannot be parsed or an MPI call fails
     */
    private MPIWorker(String[] strArr) {
        // Built before the try block so the usage text can still be printed when parsing fails.
        Options options = setupOptions();
        try {
            MPI.InitThread(strArr, MPI.THREAD_MULTIPLE);
            int rank = MPI.COMM_WORLD.getRank();
            Thread.setDefaultUncaughtExceptionHandler(this::handleUncaughtException);
            this.config = loadConfigurations(new DefaultParser().parse(options, strArr), rank);
            LOG.log(Level.FINE, "A worker process is starting...");
            JobAPI.Job job = JobUtils.readJobFile(null, JobUtils.getJobDescriptionFilePath(MPIContext.jobId(this.config), this.config));
            if (!JobMasterContext.isJobMasterUsed(this.config)) {
                this.wInfo = createWorkerInfo(this.config, MPI.COMM_WORLD.getRank(), job);
                startWorkerWithoutMaster(this.config, rank, MPI.COMM_WORLD, job);
            } else if (JobMasterContext.jobMasterRunsInClient(this.config)) {
                this.wInfo = createWorkerInfo(this.config, MPI.COMM_WORLD.getRank(), job);
                startWorker(this.config, rank, MPI.COMM_WORLD, job);
            } else {
                // Rank 0 becomes the job master; all other ranks form the worker communicator.
                Intracomm workerComm = MPI.COMM_WORLD.split(rank == 0 ? 0 : 1, rank);
                if (rank != 0) {
                    this.wInfo = createWorkerInfo(this.config, workerComm.getRank(), job);
                } else {
                    this.wInfo = createWorkerInfo(this.config, -1, job);
                }
                broadCastMasterInformation(rank);
                if (rank != 0) {
                    startWorker(this.config, rank, workerComm, job);
                } else {
                    startMaster(this.config, rank);
                }
            }
        } catch (InvalidProtocolBufferException e) {
            LOG.log(Level.SEVERE, "Protocol buffer exception ", e);
        } catch (ParseException e2) {
            new HelpFormatter().printHelp("SubmitterMain", options);
            throw new RuntimeException("Error parsing command line options: ", e2);
        } catch (MPIException e3) {
            LOG.log(Level.SEVERE, "Failed the MPI process", e3);
            throw new RuntimeException(e3);
        }
        finalizeMPI();
    }

    /**
     * Default uncaught-exception handler of the process. Logs the failure and, when a job master
     * is used, reports it to the driver and shuts the worker down; otherwise rethrows the failure
     * wrapped in a {@link RuntimeException}.
     *
     * @param thread the thread that terminated
     * @param failure the throwable that terminated it
     */
    private void handleUncaughtException(Thread thread, Throwable failure) {
        LOG.log(Level.SEVERE, "Uncaught exception in thread " + thread + ". Finalizing this worker...", failure);
        // The handler is installed before the configuration is loaded, so it may still be null.
        Config currentConfig = this.config;
        if (currentConfig == null || !JobMasterContext.isJobMasterUsed(currentConfig)) {
            throw new RuntimeException("Worker faild with exception", failure);
        }
        // An Error is not an Exception; wrap it instead of casting so the report is still sent.
        Exception reported = failure instanceof Exception ? (Exception) failure : new RuntimeException(failure);
        JMWorkerAgent.getJMWorkerAgent().getSenderToDriver().sendToDriver(JobExecutionState.WorkerJobState.newBuilder().setFailure(true).setJobName(currentConfig.getStringValue("twister2.job.id")).setWorkerMessage(JSONUtils.toJSONString(reported)).build());
        finalizeMPI();
    }

    /**
     * Broadcasts the worker information of rank 0 (the job master process) to every rank of
     * {@code MPI.COMM_WORLD} and stores the job master port and IP in {@link #config}.
     *
     * @param rank rank of this process in {@code MPI.COMM_WORLD}
     * @throws MPIException if a broadcast fails
     * @throws InvalidProtocolBufferException if the received bytes cannot be decoded
     */
    private void broadCastMasterInformation(int rank) throws MPIException, InvalidProtocolBufferException {
        boolean isMaster = rank == 0;
        byte[] localInfo = this.wInfo.toByteArray();

        // First broadcast: the length of the serialized master information.
        IntBuffer sizeBuffer = MPI.newIntBuffer(1);
        if (isMaster) {
            sizeBuffer.put(localInfo.length);
        }
        MPI.COMM_WORLD.bcast(sizeBuffer, 1, MPI.INT, 0);
        int masterInfoSize = sizeBuffer.get(0);

        // Second broadcast: the serialized master information itself.
        ByteBuffer infoBuffer = MPI.newByteBuffer(masterInfoSize);
        if (isMaster) {
            infoBuffer.put(localInfo);
        }
        MPI.COMM_WORLD.bcast(infoBuffer, masterInfoSize, MPI.BYTE, 0);

        JobMasterAPI.WorkerInfo masterInfo;
        if (isMaster) {
            masterInfo = this.wInfo;
        } else {
            byte[] received = new byte[masterInfoSize];
            infoBuffer.get(received);
            masterInfo = JobMasterAPI.WorkerInfo.newBuilder().mergeFrom(received).build();
        }
        this.config = Config.newBuilder().putAll(this.config).put("twister2.job.master.port", Integer.valueOf(masterInfo.getPort())).put("twister2.job.master.ip", masterInfo.getNodeInfo().getNodeIP()).build();
    }

    /**
     * Program entry point.
     *
     * @param strArr command line arguments
     */
    public static void main(String[] strArr) {
        new MPIWorker(strArr);
    }

    /**
     * Creates the agent that talks to the job master configured in {@link #config}, remembers it
     * in {@link #masterClient} and returns its worker controller.
     *
     * @param job the job description; supplies the number of workers
     * @return the worker controller of the created agent
     */
    private IWorkerController createWorkerController(JobAPI.Job job) {
        this.masterClient = createMasterAgent(this.config, JobMasterContext.jobMasterIP(this.config), JobMasterContext.jobMasterPort(this.config), this.wInfo, job.getNumberOfWorkers());
        return this.masterClient.getJMWorkerController();
    }

    /**
     * Creates a job master agent and starts it with {@code startThreaded()}.
     *
     * @param cfg configuration handed to the agent
     * @param masterHost job master host
     * @param masterPort job master port
     * @param workerInfo worker information of this process
     * @param numberOfWorkers number of workers in the job
     * @return the started agent
     */
    private JMWorkerAgent createMasterAgent(Config cfg, String masterHost, int masterPort, JobMasterAPI.WorkerInfo workerInfo, int numberOfWorkers) {
        JMWorkerAgent agent = JMWorkerAgent.createJMWorkerAgent(cfg, workerInfo, masterHost, masterPort, numberOfWorkers, JobMasterAPI.WorkerState.STARTED);
        LOG.log(Level.FINE, String.format("Connecting to job master %s:%d", masterHost, Integer.valueOf(masterPort)));
        agent.startThreaded();
        return agent;
    }

    /**
     * Declares the command line options of this process. All of them are required.
     *
     * @return the option set
     */
    private Options setupOptions() {
        Option containerClass = Option.builder("c").desc("The class name of the container to launch").longOpt("container_class").hasArgs().argName("container class").required().build();
        Option configDir = Option.builder("d").desc("The configuration directory").longOpt("config_dir").hasArgs().argName("configuration directory").required().build();
        Option twister2Home = Option.builder("t").desc("The twister2 home directory").longOpt("twister2_home").hasArgs().argName("twister2 home").required().build();
        Option clusterType = Option.builder("n").desc("The clustr type").longOpt("cluster_type").hasArgs().argName("cluster type").required().build();
        Option jobId = Option.builder("j").desc("Job Id").longOpt("job_id").hasArgs().argName("job id").required().build();
        Option jobMasterIp = Option.builder("i").desc("Job master ip").longOpt("job_master_ip").hasArgs().argName("job master ip").required().build();
        Option jobMasterPort = Option.builder("p").desc("Job master port").longOpt("job_master_port").hasArgs().argName("job master port").required().build();

        Options options = new Options();
        options.addOption(twister2Home);
        options.addOption(containerClass);
        options.addOption(configDir);
        options.addOption(clusterType);
        options.addOption(jobId);
        options.addOption(jobMasterIp);
        options.addOption(jobMasterPort);
        return options;
    }

    /**
     * Builds the effective configuration from the configuration files, the job description file
     * and the command line values.
     *
     * @param commandLine the parsed command line
     * @param rank rank of this process, stored under {@code twister2.container.id}
     * @return the combined configuration
     * @throws NumberFormatException if the job master port is not an integer
     */
    private Config loadConfigurations(CommandLine commandLine, int rank) {
        String twister2Home = commandLine.getOptionValue("twister2_home");
        String containerClass = commandLine.getOptionValue("container_class");
        String configDir = commandLine.getOptionValue("config_dir");
        String clusterType = commandLine.getOptionValue("cluster_type");
        String jobId = commandLine.getOptionValue("job_id");
        String jobMasterIp = commandLine.getOptionValue("job_master_ip");
        int jobMasterPort = Integer.parseInt(commandLine.getOptionValue("job_master_port"));
        LOG.log(Level.FINE, String.format("Initializing process with twister_home: %s container_class: %s config_dir: %s cluster_type: %s", twister2Home, containerClass, configDir, clusterType));

        Config fileConfig = ConfigLoader.loadConfig(twister2Home, configDir, clusterType);
        // A preliminary configuration is only needed to locate the job description file.
        Config lookupConfig = Config.newBuilder().putAll(fileConfig).put(MPIContext.TWISTER2_HOME.getKey(), twister2Home).put(MesosContext.MESOS_CONTAINER_CLASS, containerClass).put("twister2.container.id", Integer.valueOf(rank)).put("twister2.cluster.type", clusterType).build();
        JobAPI.Job job = JobUtils.readJobFile(null, JobUtils.getJobDescriptionFilePath(jobId, lookupConfig));
        return Config.newBuilder().putAll(JobUtils.overrideConfigs(job, fileConfig)).put(MPIContext.TWISTER2_HOME.getKey(), twister2Home).put(MesosContext.MESOS_CONTAINER_CLASS, containerClass).put("twister2.container.id", Integer.valueOf(rank)).put("twister2.job.id", jobId).put("twister2.job.object", job).put("twister2.cluster.type", clusterType).put("twister2.job.master.ip", jobMasterIp).put("twister2.job.master.port", Integer.valueOf(jobMasterPort)).build();
    }

    /**
     * Starts the job master in this process and waits for its thread to finish.
     *
     * @param cfg configuration; must contain the job under {@code twister2.job.object}
     * @param rank rank of this process (currently unused)
     * @throws RuntimeException if the local host address cannot be resolved or the job master
     *     fails to start
     */
    private void startMaster(Config cfg, int rank) {
        JobAPI.Job job = (JobAPI.Job) cfg.get("twister2.job.object");
        try {
            int masterPort = JobMasterContext.jobMasterPort(cfg);
            String hostAddress = InetAddress.getLocalHost().getHostAddress();
            LOG.log(Level.INFO, String.format("Starting the job manager: %s:%d", hostAddress, Integer.valueOf(masterPort)));
            JobMaster jobMaster = new JobMaster(cfg, hostAddress, masterPort, new NomadTerminator(), job, (JobMasterAPI.NodeInfo) null, (IScalerPerCluster) null, JobMasterAPI.JobMasterState.JM_STARTED);
            jobMaster.addShutdownHook(false);
            Thread masterThread = jobMaster.startJobMasterThreaded();
            if (masterThread != null) {
                try {
                    masterThread.join();
                } catch (InterruptedException e) {
                    // Stop waiting, but keep the interrupt visible to the caller.
                    Thread.currentThread().interrupt();
                }
            }
            LOG.log(Level.INFO, "Master done... ");
        } catch (UnknownHostException e2) {
            LOG.log(Level.SEVERE, "Exception when getting local host address: ", (Throwable) e2);
            throw new RuntimeException(e2);
        } catch (Twister2Exception e3) {
            LOG.log(Level.SEVERE, "Exception when starting Job master: ", e3);
            throw new RuntimeException(e3);
        }
    }

    /**
     * Runs the configured {@link IWorker} with a controller that is backed by the job master.
     *
     * @param cfg configuration handed to the worker
     * @param rank rank of this process in {@code MPI.COMM_WORLD}; used for the persistent volume
     *     and in log messages
     * @param intracomm communicator of the workers; its rank is used as the worker id
     * @param job the job description
     * @throws RuntimeException if the worker class cannot be loaded, is not an {@link IWorker},
     *     or an MPI call fails
     */
    private void startWorker(Config cfg, int rank, Intracomm intracomm, JobAPI.Job job) {
        try {
            initLogger(cfg, intracomm.getRank(), Context.twister2Home(cfg));
            MPIJobWorkerController controller = new MPIJobWorkerController(createWorkerController(job));
            IPersistentVolume persistentVolume = initPersistenceVolume(cfg, job.getJobName(), rank);
            controller.add("comm", intracomm);
            String workerClass = MPIContext.workerClass(cfg);
            try {
                Object instance = ReflectionUtils.newInstance(workerClass);
                if (!(instance instanceof IWorker)) {
                    throw new RuntimeException("Cannot instantiate class: " + instance.getClass());
                }
                ((IWorker) instance).execute(cfg, intracomm.getRank(), controller, persistentVolume, (IVolatileVolume) null);
                LOG.log(Level.FINE, "loaded worker class: " + workerClass);
                LOG.log(Level.FINE, String.format("Worker %d: the cluster is ready...", Integer.valueOf(rank)));
            } catch (ClassNotFoundException | IllegalAccessException | InstantiationException e) {
                LOG.log(Level.SEVERE, String.format("failed to load the worker class %s", workerClass), e);
                throw new RuntimeException(e);
            }
        } catch (MPIException e2) {
            LOG.log(Level.SEVERE, "Failed to synchronize the workers at the start", e2);
            throw new RuntimeException((Throwable) e2);
        }
    }

    /**
     * Runs the configured {@link IWorker} without a job master; the controller is built from a
     * resource plan that the workers exchange over MPI.
     *
     * @param cfg configuration handed to the worker
     * @param rank rank of this process in {@code MPI.COMM_WORLD}; used for the persistent volume
     *     and in log messages
     * @param intracomm communicator of the workers; its rank is used as the worker id
     * @param job the job description
     * @throws RuntimeException if the worker class cannot be loaded, is not an {@link IWorker},
     *     or an MPI call fails
     */
    private void startWorkerWithoutMaster(Config cfg, int rank, Intracomm intracomm, JobAPI.Job job) {
        try {
            initLogger(cfg, intracomm.getRank(), Context.twister2Home(cfg));
            MPIWorkerController controller = new MPIWorkerController(intracomm.getRank(), createResourcePlan(cfg, intracomm, job));
            IPersistentVolume persistentVolume = initPersistenceVolume(cfg, job.getJobName(), rank);
            controller.add("comm", intracomm);
            String workerClass = MPIContext.workerClass(cfg);
            try {
                Object instance = ReflectionUtils.newInstance(workerClass);
                if (!(instance instanceof IWorker)) {
                    throw new RuntimeException("Cannot instantiate class: " + instance.getClass());
                }
                ((IWorker) instance).execute(cfg, intracomm.getRank(), controller, persistentVolume, (IVolatileVolume) null);
                LOG.log(Level.FINE, "loaded worker class: " + workerClass);
                LOG.log(Level.FINE, String.format("Worker %d: the cluster is ready...", Integer.valueOf(rank)));
            } catch (ClassNotFoundException | IllegalAccessException | InstantiationException e) {
                LOG.log(Level.SEVERE, String.format("failed to load the worker class %s", workerClass), e);
                throw new RuntimeException(e);
            }
        } catch (MPIException e2) {
            LOG.log(Level.SEVERE, "Failed to synchronize the workers at the start", e2);
            throw new RuntimeException((Throwable) e2);
        }
    }

    /**
     * Notifies the job master that this worker has completed and closes the connection to it.
     * Does nothing beyond logging when no agent was created.
     */
    private void closeWorker() {
        // Worker info may be missing when shutdown is triggered by an early failure.
        if (this.wInfo != null) {
            LOG.log(Level.INFO, String.format("Worker finished executing - %d", Integer.valueOf(this.wInfo.getWorkerID())));
        }
        if (this.masterClient != null) {
            this.masterClient.sendWorkerCompletedMessage();
            this.masterClient.close();
        }
    }

    /**
     * Gathers the worker information of every process in {@code intracomm} and returns it keyed by
     * the rank of the process in that communicator. This is a collective operation: every process
     * of the communicator has to call it.
     *
     * @param config configuration used to create the local worker information
     * @param intracomm communicator whose members exchange their information
     * @param job the job description
     * @return worker information of all processes, keyed by rank
     * @throws RuntimeException if MPI communication fails or received data cannot be decoded
     */
    public Map<Integer, JobMasterAPI.WorkerInfo> createResourcePlan(Config config, Intracomm intracomm, JobAPI.Job job) {
        try {
            byte[] localInfo = createWorkerInfo(config, intracomm.getRank(), job).toByteArray();
            int localSize = localInfo.length;
            int worldSize = intracomm.getSize();

            // Step 1: every process learns the serialized size of every other process.
            IntBuffer sendSize = MPI.newIntBuffer(1);
            IntBuffer allSizes = MPI.newIntBuffer(worldSize);
            sendSize.put(localSize);
            intracomm.allGather(sendSize, 1, MPI.INT, allSizes, 1, MPI.INT);

            int[] sizes = new int[worldSize];
            int[] displacements = new int[worldSize];
            int totalSize = 0;
            for (int r = 0; r < worldSize; r++) {
                sizes[r] = allSizes.get(r);
                displacements[r] = totalSize;
                totalSize += sizes[r];
            }

            // Step 2: gather the variable-length serialized worker information.
            ByteBuffer sendInfo = MPI.newByteBuffer(localSize);
            ByteBuffer allInfo = MPI.newByteBuffer(totalSize);
            sendInfo.put(localInfo);
            intracomm.allGatherv(sendInfo, localSize, MPI.BYTE, allInfo, sizes, displacements, MPI.BYTE);

            // The entries are packed back to back, so sequential reads line up with the ranks.
            Map<Integer, JobMasterAPI.WorkerInfo> plan = new HashMap<>();
            for (int r = 0; r < sizes.length; r++) {
                byte[] encoded = new byte[sizes[r]];
                allInfo.get(encoded);
                plan.put(Integer.valueOf(r), JobMasterAPI.WorkerInfo.newBuilder().mergeFrom(encoded).build());
                LOG.log(Level.FINE, String.format("Process %d name: %s", Integer.valueOf(r), plan.get(Integer.valueOf(r))));
            }
            return plan;
        } catch (InvalidProtocolBufferException e) {
            throw new RuntimeException("Failed to create worker info", e);
        } catch (MPIException e2) {
            throw new RuntimeException("Failed to communicate", e2);
        }
    }

    /**
     * Creates the worker information of this process: local host address, a free port for the
     * worker itself and one free port per additional port name from the configuration.
     *
     * <p>The method contains a barrier on {@code MPI.COMM_WORLD}, so all processes have to call it
     * together.
     *
     * @param cfg configuration; supplies the additional port names
     * @param workerId id to store in the worker information
     * @param job the job description; supplies the compute resource
     * @return the worker information
     * @throws MPIException if the barrier fails
     * @throws IllegalStateException if a reserved socket cannot be closed
     * @throws RuntimeException if the local host address cannot be resolved
     */
    private JobMasterAPI.WorkerInfo createWorkerInfo(Config cfg, int workerId, JobAPI.Job job) throws MPIException {
        try {
            String hostAddress = InetAddress.getLocalHost().getHostAddress();
            JobMasterAPI.NodeInfo nodeInfo = NodeInfoUtils.createNodeInfo(hostAddress, KubernetesContext.KUBERNETES_NAMESPACE_DEFAULT, KubernetesContext.KUBERNETES_NAMESPACE_DEFAULT);
            JobAPI.ComputeResource computeResource = JobUtils.getComputeResource(job, workerId);

            // Work on a private copy: the list returned by the configuration is not ours to
            // modify, and this method can run more than once per process.
            List<String> portNames = new ArrayList<>();
            List<String> configuredPorts = SchedulerContext.additionalPorts(cfg);
            if (configuredPorts != null) {
                portNames.addAll(configuredPorts);
            }
            portNames.add("__worker__");

            Map<String, java.net.ServerSocket> reserved = NetworkUtils.findFreePorts(portNames);
            // The sockets are closed only after all processes passed this barrier; presumably
            // this keeps processes on the same host from picking the same port.
            MPI.COMM_WORLD.barrier();

            Map<String, Integer> ports = new HashMap<>();
            AtomicBoolean allReleased = new AtomicBoolean(true);
            reserved.forEach((portName, serverSocket) -> {
                ports.put(portName, Integer.valueOf(serverSocket.getLocalPort()));
                try {
                    serverSocket.close();
                } catch (IOException e) {
                    LOG.log(Level.SEVERE, e, () -> {
                        return "Couldn't close opened server socket : " + portName;
                    });
                    allReleased.set(false);
                }
            });
            if (!allReleased.get()) {
                throw new IllegalStateException("Could not release one or more free TCP/IP ports");
            }

            // The worker's own port is passed separately, not among the additional ports.
            Integer workerPort = ports.remove("__worker__");
            LOG.fine("Worker info host:" + hostAddress + ":" + workerPort);
            return WorkerInfoUtils.createWorkerInfo(workerId, hostAddress, workerPort.intValue(), nodeInfo, computeResource, ports);
        } catch (UnknownHostException e) {
            throw new RuntimeException("Failed to get ip address", e);
        }
    }

    /**
     * Sets up file logging for this worker under {@code <base>/logs/worker-<id>}. The base is the
     * Nomad job sandbox when the logging sandbox is enabled, otherwise {@code twister2Home}.
     * Nothing is set up when the base directory is {@code null}.
     *
     * @param cfg configuration
     * @param workerId id used in the directory and log file names
     * @param twister2Home fallback base directory, may be {@code null}
     * @throws RuntimeException if the log directory cannot be created
     */
    private void initLogger(Config cfg, int workerId, String twister2Home) {
        String baseDir;
        if (NomadContext.getLoggingSandbox(cfg)) {
            baseDir = Paths.get(NomadContext.workingDirectory(cfg), NomadContext.jobId(cfg)).toString();
        } else {
            baseDir = twister2Home;
        }
        if (baseDir == null) {
            return;
        }
        String logDir = baseDir + "/logs/worker-" + workerId;
        File directory = new File(logDir);
        if (!directory.exists() && !directory.mkdirs()) {
            throw new RuntimeException("Failed to create log directory: " + logDir);
        }
        LoggingHelper.setupLogging(cfg, logDir, "worker-" + workerId);
        LOG.fine(String.format("Logging is setup with file %s", logDir));
    }

    /**
     * Makes sure the configured file system mount directory exists and wraps it in a persistent
     * volume for this worker.
     *
     * @param cfg configuration; supplies the mount directory
     * @param jobName name of the job (currently unused)
     * @param workerId id handed to the persistent volume
     * @return the persistent volume rooted at the mount directory
     * @throws RuntimeException if the thread is interrupted while waiting for the directory
     */
    private IPersistentVolume initPersistenceVolume(Config cfg, String jobName, int workerId) {
        File mountDir = new File(MPIContext.fileSystemMount(cfg));
        // NOTE(review): retries without an upper bound until the directory exists or is created.
        while (!mountDir.exists() && !mountDir.mkdirs()) {
            try {
                Thread.sleep(100L);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException("Thread interrupted", e);
            }
        }
        return new FSPersistentVolume(mountDir.getAbsolutePath(), workerId);
    }
}
