package org.nd4j.linalg.benchmark.app;

import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;
import org.nd4j.linalg.benchmark.api.BenchMarkPerformer;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.reflections.Reflections;

/* loaded from: input_file:org/nd4j/linalg/benchmark/app/BenchmarkRunnerApp.class */
/**
 * Command-line entry point that runs every concrete {@link BenchMarkPerformer} subtype found by
 * {@link Reflections} against every {@link Nd4jBackend} registered through {@link ServiceLoader},
 * printing the average time reported by each run.
 *
 * <p>Options: {@code --nTrials}/{@code -n} sets the trial count passed to each benchmark's
 * {@code (int)} constructor (default 1000); {@code --run}/{@code -r} takes a comma-separated list
 * of fully-qualified benchmark class names to restrict the run (default: run all).
 */
public class BenchmarkRunnerApp {

    /** Number of trials handed to each benchmark's {@code (int)} constructor. */
    @Option(name = "--nTrials", usage = "Number of trials to run", aliases = {"-n"})
    private int nTrials = 1000;

    /** Comma-separated fully-qualified benchmark class names; {@code null} means run all. */
    @Option(name = "--run", usage = "Trials to run", aliases = {"-r"})
    private String benchmarksToRun;

    /**
     * Parses the command line, then runs each selected, non-abstract benchmark on every backend.
     *
     * <p>If the arguments cannot be parsed, the parser's message and usage are printed to
     * {@code System.err} and nothing is run.
     *
     * @param strArr command-line arguments
     * @throws Exception if scanning, backend initialization, benchmark construction (a public
     *     {@code (int)} constructor is required) or a benchmark run fails
     */
    public void doMain(String[] strArr) throws Exception {
        CmdLineParser parser = new CmdLineParser(this);
        try {
            parser.parseArgument(strArr);
        } catch (CmdLineException e) {
            System.err.println(e.getMessage());
            parser.printUsage(System.err);
            return;
        }

        Set<String> selected = parseBenchmarkNames(benchmarksToRun);
        List<Nd4jBackend> backends = loadBackends();

        // Build the scanner only once the arguments are known to be valid; the scan may be costly.
        Reflections reflections = new Reflections(new Object[0]);
        for (Class<? extends BenchMarkPerformer> cls : reflections.getSubTypesOf(BenchMarkPerformer.class)) {
            // Abstract classes (and interfaces) cannot be instantiated as benchmarks.
            if (Modifier.isAbstract(cls.getModifiers())) {
                continue;
            }
            if (!selected.isEmpty() && !selected.contains(cls.getName())) {
                continue;
            }
            runBenchmark(cls, backends);
        }
    }

    /**
     * Splits a comma-separated list of class names.
     *
     * @param csv the raw option value, may be {@code null}
     * @return the trimmed, non-empty names; empty when {@code csv} is {@code null} or blank
     */
    private static Set<String> parseBenchmarkNames(String csv) {
        Set<String> names = new HashSet<>();
        if (csv == null) {
            return names;
        }
        for (String token : csv.split(",")) {
            String name = token.trim();
            if (!name.isEmpty()) {
                names.add(name);
            }
        }
        return names;
    }

    /**
     * Instantiates every {@link Nd4jBackend} provider visible to {@link ServiceLoader}.
     *
     * @return the backends in discovery order
     */
    private static List<Nd4jBackend> loadBackends() {
        List<Nd4jBackend> backends = new ArrayList<>();
        for (Nd4jBackend backend : ServiceLoader.load(Nd4jBackend.class)) {
            backends.add(backend);
        }
        return backends;
    }

    /**
     * Runs one benchmark class on each backend, constructing a fresh benchmark instance per
     * backend after initializing {@link Nd4j} with that backend.
     *
     * @param cls the concrete benchmark class; must expose a public {@code (int)} constructor
     * @param backends the backends to run against
     * @throws Exception if construction, initialization or the run itself fails
     */
    private void runBenchmark(Class<? extends BenchMarkPerformer> cls, List<Nd4jBackend> backends)
            throws Exception {
        System.out.println("========================= Benchmark: " + cls.getName() + " ===========================");
        for (Nd4jBackend backend : backends) {
            new Nd4j().initWithBackend(backend);
            BenchMarkPerformer performer = cls.getConstructor(int.class).newInstance(nTrials);
            String backendName = backend.getClass().getName();
            System.out.println("Running " + backendName);
            performer.run(backend);
            long averageNanos = performer.averageTime();
            System.out.println("Backend " + backendName + " took (in nanoseconds) " + averageNanos
                    + " (in milliseconds) " + TimeUnit.MILLISECONDS.convert(averageNanos, TimeUnit.NANOSECONDS));
        }
        System.out.println("====================================================");
    }

    /**
     * Program entry point; delegates to {@link #doMain(String[])}.
     *
     * @param strArr command-line arguments
     * @throws Exception propagated from {@link #doMain(String[])}
     */
    public static void main(String[] strArr) throws Exception {
        new BenchmarkRunnerApp().doMain(strArr);
    }
}
