package de.julielab.jcore.ae.jpos.main;

import cc.mallet.fst.CRF;
import com.google.common.base.Charsets;
import com.google.common.io.Files;
import de.julielab.jcore.ae.jpos.tagger.POSTagger;
import de.julielab.jcore.ae.jpos.tagger.Sentence;
import de.julielab.jcore.ae.jpos.tagger.Unit;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Writer;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.List;
import java.util.Properties;
import java.util.Random;

/* loaded from: input_file:de/julielab/jcore/ae/jpos/main/JPOSApplication.class */
/**
 * Command line front end for the JPOS tagger.
 *
 * <p>The first program argument selects a mode: cross validation ({@code x}), training
 * ({@code t}), prediction ({@code p}), comparison of a prediction against a gold standard
 * ({@code c}), printing a model's feature configuration ({@code oc}) and printing a model's
 * tagset ({@code ts}). The remaining arguments are mode specific; calling a mode with the wrong
 * number of arguments prints its usage line and exits with status {@code -1}.
 */
public class JPOSApplication {

    /**
     * Dispatches to the mode named by the first argument and finally prints the elapsed time in
     * whole minutes.
     *
     * @param strArr the mode followed by its mode-specific parameters
     * @throws Exception whatever the selected mode throws; a non-numeric round or iteration
     *     count surfaces as {@link NumberFormatException}
     */
    public static void main(String[] strArr) throws Exception {
        long startMillis = System.currentTimeMillis();
        if (strArr.length < 1) {
            System.err.println("usage: <mode> <mode-specific-parameters>");
            showModes();
            System.exit(-1);
        }
        switch (strArr[0]) {
            case "x": {
                if (strArr.length < 4) {
                    System.err.println("usage: x <trainData> <x-rounds> <featureConfigFile> [number of iterations]");
                    System.err.println("pred-out format: token pred gold");
                    System.exit(-1);
                }
                File trainData = new File(strArr[1]);
                int rounds = Integer.parseInt(strArr[2]);
                File featureConfig = new File(strArr[3]);
                // The iteration count is optional; 0 is handed to the tagger when it is absent.
                int iterations = strArr.length == 5 ? Integer.parseInt(strArr[4]) : 0;
                evalXVal(trainData, rounds, featureConfig, iterations, true);
                break;
            }
            case "t": {
                if (strArr.length < 4) {
                    System.err.println("usage: t <trainData> <model-out-file> <featureConfigFile> [number of iterations]");
                    System.exit(-1);
                }
                File trainData = new File(strArr[1]);
                File modelOut = new File(strArr[2]);
                File featureConfig = new File(strArr[3]);
                int iterations = strArr.length == 5 ? Integer.parseInt(strArr[4]) : 0;
                train(trainData, modelOut, featureConfig, iterations);
                break;
            }
            case "p":
                if (strArr.length != 4) {
                    System.err.println("usage: p <unlabeled data> <modelFile> <outFile>");
                    System.exit(-1);
                }
                predict(new File(strArr[1]), new File(strArr[2]), new File(strArr[3]));
                break;
            case "c":
                if (strArr.length != 3) {
                    System.err.println("\ncompares the gold standard agains the prediction");
                    System.err.println("\nusage: c <predData> <goldData>");
                    System.exit(-1);
                }
                compare(new File(strArr[1]), new File(strArr[2]));
                break;
            case "oc":
                if (strArr.length != 2) {
                    System.err.println("\nusage: oc <model>");
                    System.exit(-1);
                }
                printFeatureConfig(new File(strArr[1]));
                break;
            case "ts":
                if (strArr.length != 2) {
                    System.err.println("\nusage: ts <model>");
                    System.exit(-1);
                }
                printTagset(new File(strArr[1]));
                break;
            default:
                System.err.println("ERR: unknown mode");
                showModes();
                System.exit(-1);
        }
        long elapsedMinutes = ((System.currentTimeMillis() - startMillis) / 1000) / 60;
        System.out.println("Finished in " + elapsedMinutes + " minutes");
    }

    /**
     * Prints the list of available modes to standard error and terminates the JVM with status
     * {@code -1}; this method does not return.
     */
    static void showModes() {
        System.err.println("\nAvailable modes:");
        System.err.println("x: cross validation ");
        System.err.println("c: compare goldstandard and prediction");
        System.err.println("t: train ");
        System.err.println("p: predict ");
        System.err.println("oc: output model configuration ");
        System.err.println("ts: output model tagset");
        System.exit(-1);
    }

    /**
     * Trains a tagger on the given data and writes the resulting model to disk.
     *
     * @param file training data, read as UTF-8; every line is converted with
     *     {@code POSTagger.PPDtoUnits}
     * @param file2 destination of the trained model
     * @param file3 feature configuration file, or {@code null} to use the tagger's no-argument
     *     constructor
     * @param i number of iterations handed to the tagger
     * @throws IOException if the training data cannot be read
     */
    static void train(File file, File file2, File file3, int i) throws IOException {
        List<String> lines = Files.readLines(file, Charsets.UTF_8);
        POSTagger tagger = file3 != null ? new POSTagger(file3) : new POSTagger();
        tagger.set_Number_Iterations(i);
        ArrayList<Sentence> sentences = new ArrayList<>();
        for (String line : lines) {
            sentences.add(tagger.PPDtoUnits(line));
        }
        tagger.train(sentences);
        tagger.writeModel(file2.toString());
    }

    /**
     * Runs a cross validation over the lines of the given file and prints per-round accuracy as
     * well as the average and standard deviation over all rounds.
     *
     * <p>The lines are shuffled with a fixed seed, so the partition is reproducible. Every round
     * tests on one contiguous fold of {@code lines / i} sentences and trains on the rest; the
     * last round's test fold additionally takes the remainder of the division.
     *
     * @param file data set, read as UTF-8, one sentence per line
     * @param i number of cross-validation rounds; must be at least 1
     * @param file2 feature configuration file, or {@code null} for the tagger's defaults
     * @param i2 number of iterations handed to the tagger
     * @param z currently unused
     * @throws IOException if the data set cannot be read
     * @throws IllegalArgumentException if {@code i} is smaller than 1
     */
    public static void evalXVal(File file, int i, File file2, int i2, boolean z) throws IOException {
        if (i < 1) {
            // Fail with a clear message instead of a division by zero or negative array size.
            throw new IllegalArgumentException("number of cross-validation rounds must be at least 1 but was " + i);
        }
        List<String> lines = Files.readLines(file, Charsets.UTF_8);
        DecimalFormat decimalFormat = new DecimalFormat("0.000");
        // Fixed seed: every run produces the same folds.
        Collections.shuffle(lines, new Random(1L));
        int total = lines.size();
        int foldSize = total / i;
        int lastFoldSize = foldSize + (total % i);
        System.out.println(" * number of sentences: " + total);
        System.out.println(" * size of each/last round: " + foldSize + "/" + lastFoldSize);
        System.out.println();
        double[] accuracies = new double[i];
        int foldStart = 0;
        for (int round = 0; round < i; round++) {
            // The last fold runs to the end of the data so the division remainder is not lost.
            int foldEnd = (round == i - 1) ? total : foldStart + foldSize;
            ArrayList<String> trainLines = new ArrayList<>(lines.subList(0, foldStart));
            trainLines.addAll(lines.subList(foldEnd, total));
            ArrayList<String> testLines = new ArrayList<>(lines.subList(foldStart, foldEnd));
            foldStart = foldEnd;

            System.out.println(" * training on: " + trainLines.size() + " -- testing on: " + testLines.size());
            double accuracy = eval(trainLines, testLines, file2, i2, round);
            accuracies[round] = accuracy;
            System.out.println("\n** round " + (round + 1) + "\tAccuracy: " + decimalFormat.format(accuracy));
        }
        double average = getAverage(accuracies);
        double standardDeviation = getStandardDeviation(accuracies, average);
        StringBuilder summary = new StringBuilder();
        summary.append("Cross-validation results:\n");
        summary.append("Number of sentences in evaluation data set: " + total + "\n");
        // NOTE(review): the two numbers printed here are the test fold sizes, although the label
        // says "for training" -- confirm the intended wording before changing the output.
        summary.append("Number of sentences for training in each/last round: " + foldSize + "/" + lastFoldSize + "\n\n");
        summary.append("Overall performance: avg (standard deviation)\n");
        summary.append("Accuracy: " + decimalFormat.format(average) + "(" + decimalFormat.format(standardDeviation) + ")\n");
        System.out.println("\n\nCross-validation finished");
        System.out.println(summary);
    }

    /**
     * Computes the sample standard deviation (divisor {@code n - 1}) of the given values.
     *
     * @param dArr the values
     * @param d the mean of {@code dArr}
     * @return the standard deviation; {@code NaN} when {@code dArr} holds exactly one value,
     *     because the divisor is then zero
     */
    public static double getStandardDeviation(double[] dArr, double d) {
        double squaredDeviations = 0.0d;
        for (double value : dArr) {
            squaredDeviations += Math.pow(value - d, 2.0d);
        }
        return Math.sqrt(squaredDeviations / (dArr.length - 1.0d));
    }

    /**
     * Computes the arithmetic mean of the given values.
     *
     * @param dArr the values
     * @return the mean; {@code NaN} for an empty array
     */
    public static double getAverage(double[] dArr) {
        double sum = 0.0d;
        for (double value : dArr) {
            sum += value;
        }
        return sum / dArr.length;
    }

    /**
     * Tags every line of the input file with a stored model and writes the tagger's CLI output
     * to the output file.
     *
     * <p>Input and output are both UTF-8. Failures while predicting or writing are printed to
     * standard error and not rethrown; failures while reading the input or the model propagate.
     *
     * @param file unlabeled input, one sentence per line
     * @param file2 model file to load
     * @param file3 output file; created or overwritten
     * @throws Exception if the input or the model cannot be read
     */
    static void predict(File file, File file2, File file3) throws Exception {
        List<String> lines = Files.readLines(file, Charsets.UTF_8);
        ArrayList<Sentence> sentences = new ArrayList<>();
        POSTagger tagger = POSTagger.readModel(file2);
        System.out.println("  * predicting...");
        long startMillis = System.currentTimeMillis();
        // try-with-resources closes the writer even when prediction or writing fails.
        try (Writer out = Files.newWriter(file3, Charsets.UTF_8)) {
            for (String line : lines) {
                sentences.add(tagger.textToUnits(line));
            }
            for (String taggedSentence : tagger.predictForCLI(sentences)) {
                out.write(taggedSentence);
            }
            System.out.println("prediction took: " + (System.currentTimeMillis() - startMillis));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains a fresh tagger on the training lines, tags the test lines and returns the token
     * accuracy. Every mismatching token is printed to standard output.
     *
     * @param arrayList training lines, converted with {@code POSTagger.PPDtoUnits}
     * @param arrayList2 test lines, converted with {@code POSTagger.PPDtoUnits}
     * @param file feature configuration file, or {@code null} for the tagger's defaults
     * @param i number of iterations handed to the tagger
     * @param i2 index of the cross-validation round; currently unused
     * @return the fraction of test tokens whose predicted label equals the gold label;
     *     {@code NaN} if there are no test tokens
     * @throws IllegalStateException if the number of predicted tokens differs from the number of
     *     gold tokens
     */
    static double eval(ArrayList<String> arrayList, ArrayList<String> arrayList2, File file, int i, int i2) {
        POSTagger tagger = file != null ? new POSTagger(file) : new POSTagger();
        tagger.set_Number_Iterations(i);
        ArrayList<Sentence> trainSentences = new ArrayList<>();
        for (String line : arrayList) {
            trainSentences.add(tagger.PPDtoUnits(line));
        }
        ArrayList<Sentence> testSentences = new ArrayList<>();
        for (String line : arrayList2) {
            testSentences.add(tagger.PPDtoUnits(line));
        }
        tagger.train(trainSentences);

        // Gold tokens are collected before predicting, in the "representation|label" form.
        List<String> goldTokens = new ArrayList<>();
        for (Sentence sentence : testSentences) {
            for (Unit unit : sentence.getUnits()) {
                goldTokens.add(unit.getRep() + "|" + unit.getLabel());
            }
        }
        List<String> predictedTokens = new ArrayList<>();
        for (String taggedSentence : tagger.predictForCLI(testSentences)) {
            for (String token : taggedSentence.trim().split(" ")) {
                predictedTokens.add(token);
            }
        }
        if (predictedTokens.size() != goldTokens.size()) {
            throw new IllegalStateException("number of predicted tokens (" + predictedTokens.size()
                    + ") differs from number of gold tokens (" + goldTokens.size() + ")");
        }
        double correct = 0.0d;
        for (int idx = 0; idx < goldTokens.size(); idx++) {
            String predicted = predictedTokens.get(idx);
            String gold = goldTokens.get(idx);
            if (labelOf(predicted).equals(labelOf(gold))) {
                correct += 1.0d;
            } else {
                System.out.println("Predicted:\t" + predicted + "\tCorrect: " + gold);
            }
        }
        return correct / goldTokens.size();
    }

    /**
     * Compares a prediction file against a gold standard file token by token and prints the
     * number of correct tokens, the number of tokens seen and the accuracy.
     *
     * <p>Both files are read as UTF-8 and must have the same number of lines and, per line, the
     * same number of tokens separated by one or more spaces; otherwise an error is printed and
     * the JVM exits with status {@code -1}. Tokens are compared by the text after their last
     * {@code |}.
     *
     * @param file the prediction
     * @param file2 the gold standard
     * @throws IOException if either file cannot be read
     */
    static void compare(File file, File file2) throws IOException {
        List<String> goldLines = Files.readLines(file2, Charsets.UTF_8);
        List<String> predLines = Files.readLines(file, Charsets.UTF_8);
        if (goldLines.size() != predLines.size()) {
            System.err.println("ERR: number of lines in gold standard is different from prediction... please check!");
            System.exit(-1);
        }
        int correct = 0;
        int seen = 0;
        for (int lineIdx = 0; lineIdx < goldLines.size(); lineIdx++) {
            String goldLine = goldLines.get(lineIdx);
            String predLine = predLines.get(lineIdx);
            String[] goldTokens = goldLine.split(" +");
            String[] predTokens = predLine.split(" +");
            if (goldTokens.length != predTokens.length) {
                // Report the lines themselves; concatenating the arrays would only print their
                // identity hashes.
                System.err.println("ERR: number of tokens in gold standard is different from prediction for\n" + goldLine + "\n" + predLine);
                System.exit(-1);
            }
            for (int tokenIdx = 0; tokenIdx < goldTokens.length; tokenIdx++) {
                seen++;
                if (labelOf(goldTokens[tokenIdx]).equals(labelOf(predTokens[tokenIdx]))) {
                    correct++;
                }
            }
        }
        System.out.println("Correct: " + correct);
        System.out.println("Seen: " + seen);
        // Floating-point division: integer division would print only 0 or 1.
        System.out.println("Accuracy: " + ((double) correct / seen));
    }

    /**
     * Loads a model and prints every entry of its feature configuration as {@code key = value}.
     *
     * @param file the model file
     * @throws FileNotFoundException declared for {@code POSTagger.readModel}
     * @throws ClassNotFoundException declared for {@code POSTagger.readModel}
     * @throws IOException declared for {@code POSTagger.readModel}
     */
    public static void printFeatureConfig(File file) throws FileNotFoundException, ClassNotFoundException, IOException {
        Properties featureConfig = POSTagger.readModel(file).getFeatureConfig();
        Enumeration<?> propertyNames = featureConfig.propertyNames();
        while (propertyNames.hasMoreElements()) {
            String key = (String) propertyNames.nextElement();
            System.out.printf("%s = %s\n", key, featureConfig.getProperty(key));
        }
    }

    /**
     * Loads a model, treats it as a Mallet {@link CRF} and prints every entry of its output
     * alphabet, one per line.
     *
     * @param file the model file
     * @throws FileNotFoundException declared for {@code POSTagger.readModel}
     * @throws ClassNotFoundException declared for {@code POSTagger.readModel}
     * @throws IOException declared for {@code POSTagger.readModel}
     */
    public static void printTagset(File file) throws FileNotFoundException, ClassNotFoundException, IOException {
        CRF crf = (CRF) POSTagger.readModel(file).getModel();
        for (Object tag : crf.getOutputAlphabet().toArray()) {
            System.out.println(tag);
        }
    }

    /** Returns the part of a token after its last {@code |}, or the whole token if it has none. */
    private static String labelOf(String token) {
        return token.replaceAll(".*\\|", "");
    }
}
