package marytts.tools.newlanguage;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.regex.Pattern;
import marytts.cart.CART;
import marytts.cart.DecisionNode;
import marytts.cart.io.MaryCARTWriter;
import marytts.exceptions.MaryConfigurationException;
import marytts.features.FeatureDefinition;
import marytts.fst.AlignerTrainer;
import marytts.fst.StringPair;
import marytts.modules.phonemiser.Allophone;
import marytts.modules.phonemiser.AllophoneSet;
import org.apache.log4j.BasicConfigurator;
import weka.classifiers.trees.j48.BinC45ModelSelection;
import weka.classifiers.trees.j48.C45PruneableClassifierTreeWithUnary;
import weka.classifiers.trees.j48.TreeConverter;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instances;

/**
 * Trains letter-to-sound (LTS) rules as decision trees.
 *
 * <p>Lexicon entries are converted into grapheme and allophone sequences and handed to the
 * {@link AlignerTrainer} alignment machinery. {@link #trainTree(int)} then trains one Weka C4.5
 * tree per grapheme. It joins the trees under a single byte decision node on the centre-grapheme
 * feature and returns the result as one {@link CART}.
 */
public class LTSTrainer extends AlignerTrainer {

    /**
     * Characters removed from orthographic strings before they are split into graphemes.
     *
     * <p>This is the explicit form of the historical pattern {@code ['-.]}. In that pattern
     * {@code '-.} is a character <em>range</em> (U+0027..U+002E), so besides {@code ' - .} it also
     * strips {@code ( ) * + ,}. The behavior is kept, but made visible here.
     */
    private static final Pattern IGNORED_GRAPHEME_CHARS = Pattern.compile("['()*+,.\\-]");

    /** Grapheme value used for context positions outside the word and for the closing pair. */
    private static final String NO_GRAPHEME = "null";

    protected AllophoneSet phSet;
    protected int context;
    protected boolean convertToLowercase;
    protected boolean considerStress;

    /**
     * Creates a trainer.
     *
     * @param allophoneSet allophone set used to split transcriptions into allophones; its locale
     *     is used when lowercasing graphemes
     * @param convertToLowercase whether orthographic strings are lowercased before training
     * @param considerStress whether the first vowel of a syllable marked with a leading
     *     {@code '} gets the suffix {@code 1}
     * @param context number of graphemes on each side of the current one used as features
     */
    public LTSTrainer(AllophoneSet allophoneSet, boolean convertToLowercase, boolean considerStress,
            int context) {
        this.phSet = allophoneSet;
        this.convertToLowercase = convertToLowercase;
        this.considerStress = considerStress;
        this.context = context;
        // NOTE(review): BasicConfigurator.configure() adds a console appender to the log4j root
        // logger on each call, so constructing several trainers duplicates log output.
        BasicConfigurator.configure();
    }

    /**
     * Trains one decision tree per grapheme on the current alignments and combines them.
     *
     * <p>Each training example holds the {@code 2 * context + 1} graphemes centred on an aligned
     * grapheme. Positions outside the word use {@code "null"}. The example also holds the quoted
     * output string the centre grapheme was aligned to.
     *
     * @param minLeafData minimum number of instances handed to Weka's
     *     {@link BinC45ModelSelection} (its {@code minNoObj} setting)
     * @return a CART whose root splits on the centre grapheme and whose daughters are the
     *     per-grapheme trees; its properties record the lowercase, stress and context settings
     * @throws IOException if the feature definition cannot be built
     * @throws RuntimeException if Weka fails to train a tree
     */
    public CART trainTree(int minLeafData) throws IOException {
        Set<String> predictedStrings = new HashSet<String>();
        Map<String, List<String[]>> examplesByGrapheme = collectExamples(predictedStrings);

        FeatureDefinition featureDef = graphemeFeatureDef(predictedStrings);
        int centreIndex = featureDef.getFeatureIndex("att" + (this.context + 1));

        List<CART> trees = new ArrayList<CART>(featureDef.getNumberOfValues(centreIndex));
        for (String grapheme : featureDef.getPossibleValues(centreIndex)) {
            System.out.println("      Training decision tree for: " + grapheme);
            this.logger.debug("      Training decision tree for: " + grapheme);
            Instances instances =
                    buildInstances(grapheme, examplesByGrapheme.get(grapheme), featureDef);
            try {
                C45PruneableClassifierTreeWithUnary classifier =
                        new C45PruneableClassifierTreeWithUnary(
                                new BinC45ModelSelection(minLeafData, instances, true),
                                true, 0.25f, true, true, false);
                classifier.buildClassifier(instances);
                trees.add(TreeConverter.c45toStringCART(classifier, featureDef, instances));
            } catch (Exception e) {
                throw new RuntimeException("couldn't train decisiontree using weka: ", e);
            }
        }

        // Daughters are added in the same order as the centre feature's possible values.
        DecisionNode.ByteDecisionNode root =
                new DecisionNode.ByteDecisionNode(centreIndex, trees.size(), featureDef);
        for (CART tree : trees) {
            root.addDaughter(tree.getRootNode());
        }

        Properties properties = new Properties();
        properties.setProperty("lowercase", String.valueOf(this.convertToLowercase));
        properties.setProperty("stress", String.valueOf(this.considerStress));
        properties.setProperty("context", String.valueOf(this.context));
        return new CART(root, featureDef, properties);
    }

    /**
     * Turns every aligned grapheme of every training pair into a context-window example.
     *
     * @param predictedStrings receives every quoted output string that is seen
     * @return examples keyed by the centre grapheme; every grapheme in {@code graphemeSet} has a
     *     (possibly empty) entry
     */
    private Map<String, List<String[]>> collectExamples(Set<String> predictedStrings) {
        Map<String, List<String[]>> examples = new HashMap<String, List<String[]>>();
        for (String grapheme : this.graphemeSet) {
            examples.put(grapheme, new ArrayList<String[]>());
        }

        int window = 2 * this.context + 1;
        for (int pair = 0; pair < this.inSplit.size(); pair++) {
            StringPair[] alignment = getAlignment(pair);
            for (int pos = 0; pos < alignment.length; pos++) {
                // Quoting keeps an empty output string visible in the space-separated
                // feature definition built by graphemeFeatureDef.
                String target = "'" + alignment[pos].getString2() + "'";
                predictedStrings.add(target);

                String[] example = new String[window + 1];
                for (int k = 0; k < window; k++) {
                    int neighbour = pos - this.context + k;
                    example[k] = (neighbour >= 0 && neighbour < alignment.length)
                            ? alignment[neighbour].getString1()
                            : NO_GRAPHEME;
                }
                example[window] = target;
                examples.get(alignment[pos].getString1()).add(example);
            }
        }
        return examples;
    }

    /**
     * Builds the Weka data set for one grapheme.
     *
     * <p>The data set has one nominal attribute per context position, whose values are the
     * feature definition's values for that position. It also has a nominal class attribute
     * {@code predicted-string*} whose values are the output strings observed for this grapheme.
     */
    private Instances buildInstances(String grapheme, List<String[]> examples,
            FeatureDefinition featureDef) {
        int window = 2 * this.context + 1;
        ArrayList<Attribute> attributes = new ArrayList<Attribute>(window + 1);
        for (int att = 1; att <= window; att++) {
            String name = "att" + att;
            ArrayList<String> values = new ArrayList<String>();
            for (String value : featureDef.getPossibleValues(featureDef.getFeatureIndex(name))) {
                values.add(value);
            }
            attributes.add(new Attribute(name, values));
        }

        Set<String> observedTargets = new HashSet<String>();
        for (String[] example : examples) {
            observedTargets.add(example[example.length - 1]);
        }
        attributes.add(new Attribute("predicted-string", new ArrayList<String>(observedTargets)));

        Instances instances = new Instances(grapheme, attributes, examples.size());
        for (String[] example : examples) {
            DenseInstance instance = new DenseInstance(instances.numAttributes());
            instance.setDataset(instances);
            for (int att = 0; att < example.length; att++) {
                instance.setValue(att, example[att]);
            }
            instances.add(instance);
        }
        instances.setClassIndex(instances.numAttributes() - 1);
        return instances;
    }

    /**
     * Writes a trained tree with {@link MaryCARTWriter}.
     *
     * @param cart the tree to write
     * @param str destination passed to {@code MaryCARTWriter.dumpMaryCART}
     * @throws IOException if writing fails
     */
    public void save(CART cart, String str) throws IOException {
        new MaryCARTWriter().dumpMaryCART(cart, str);
    }

    /**
     * Builds the textual feature definition used for training and parses it.
     *
     * <p>The definition has byte-valued features {@code att1..att(2*context+1)}, each listing
     * every grapheme, and one short-valued feature {@code predicted-string} listing the given
     * output strings. It declares no continuous features.
     */
    private FeatureDefinition graphemeFeatureDef(Set<String> set) throws IOException {
        String newline = System.getProperty("line.separator");
        StringBuilder definition = new StringBuilder("ByteValuedFeatureProcessors");
        definition.append(newline);
        for (int att = 1; att <= 2 * this.context + 1; att++) {
            definition.append("att").append(att);
            for (String grapheme : this.graphemeSet) {
                definition.append(" ").append(grapheme);
            }
            definition.append(newline);
        }
        definition.append("ShortValuedFeatureProcessors").append(newline);
        definition.append("predicted-string");
        for (String predicted : set) {
            definition.append(" ").append(predicted);
        }
        definition.append(newline);
        definition.append("ContinuousFeatureProcessors").append(newline);
        return new FeatureDefinition(new BufferedReader(new StringReader(definition.toString())),
                false);
    }

    /**
     * Reads lexicon lines of the form {@code graphemes<sep>transcription} and adds them as
     * training pairs. A closing {@code "null"} pair is appended at the end.
     *
     * <p>Blank lines are skipped. The reader is not closed.
     *
     * @param bufferedReader source of lexicon lines
     * @param str regular expression separating the orthographic and transcription fields
     * @throws IOException if reading fails or a non-blank line has no transcription field
     */
    public void readLexicon(BufferedReader bufferedReader, String str) throws IOException {
        String line;
        int lineNumber = 0;
        while ((line = bufferedReader.readLine()) != null) {
            lineNumber++;
            String trimmed = line.trim();
            if (trimmed.isEmpty()) {
                continue; // a blank line carries no entry (and has no transcription field)
            }
            String[] fields = trimmed.split(str);
            if (fields.length < 2) {
                throw new IOException("Lexicon line " + lineNumber
                        + " has no transcription field (separator pattern '" + str + "'): "
                        + line);
            }
            addLexiconEntry(fields[0], fields[1]);
        }
        addClosingEntry();
    }

    /**
     * Adds every entry of an in-memory lexicon (orthography to transcription) as a training
     * pair, then appends the closing {@code "null"} pair.
     *
     * @param hashMap lexicon mapping orthographic strings to transcriptions
     */
    public void readLexicon(HashMap<String, String> hashMap) {
        for (Map.Entry<String, String> entry : hashMap.entrySet()) {
            addLexiconEntry(entry.getKey(), entry.getValue());
        }
        addClosingEntry();
    }

    /**
     * Converts one lexicon entry into grapheme and allophone sequences and adds it as a pair.
     *
     * @param orthography written form. It is lowercased if {@code convertToLowercase} is set,
     *     then stripped of {@link #IGNORED_GRAPHEME_CHARS}, then split into single characters.
     *     Each character is also recorded in {@code graphemeSet}.
     * @param transcription syllables separated by {@code -}. A leading {@code '} marks a
     *     stressed syllable. Every {@code ,} is removed first.
     */
    private void addLexiconEntry(String orthography, String transcription) {
        String word = this.convertToLowercase
                ? orthography.toLowerCase(this.phSet.getLocale())
                : orthography;
        word = IGNORED_GRAPHEME_CHARS.matcher(word).replaceAll("");

        List<String> phones = new ArrayList<String>();
        for (String syllable : transcription.replace(",", "").split("-")) {
            boolean stressPending = syllable.startsWith("'");
            if (stressPending) {
                syllable = syllable.substring(1);
            }
            for (Allophone allophone : this.phSet.splitIntoAllophones(syllable)) {
                String name = allophone.name();
                // Only the first vowel of a stressed syllable carries the stress mark.
                if (stressPending && this.considerStress && allophone.isVowel()) {
                    name = name + "1";
                    stressPending = false;
                }
                phones.add(name);
            }
        }

        List<String> graphemes = new ArrayList<String>(word.length());
        for (int i = 0; i < word.length(); i++) {
            String grapheme = word.substring(i, i + 1);
            this.graphemeSet.add(grapheme);
            graphemes.add(grapheme);
        }
        addAlreadySplit(graphemes, phones);
    }

    /**
     * Appends a pair of the padding grapheme {@code "null"} with an empty transcription.
     * NOTE(review): presumably this lets the aligner know the padding symbol used by
     * trainTree; confirm against AlignerTrainer.
     */
    private void addClosingEntry() {
        addAlreadySplit(new String[] { NO_GRAPHEME }, new String[] { "" });
    }

    /**
     * Command-line entry point: reads a lexicon, runs five alignment iterations, trains an LTS
     * tree and saves it.
     *
     * <p>Usage: {@code LTSTrainer [allophones.xml [lexicon.txt [output]]]}. Any argument that is
     * omitted falls back to the original hard-coded path. The lexicon is read as ISO-8859-1, with
     * fields separated by a single backslash.
     */
    public static void main(String[] strArr) throws IOException, MaryConfigurationException {
        String allophonesPath = strArr.length > 0
                ? strArr[0]
                : "/Users/benjaminroth/Desktop/mary/english/phone-list-engba.xml";
        String lexiconPath = strArr.length > 1
                ? strArr[1]
                : "/Users/benjaminroth/Desktop/mary/english/sampa-lexicon.txt";
        String outputPath = strArr.length > 2
                ? strArr[2]
                : "/Users/benjaminroth/Desktop/mary/english/trees/";

        LTSTrainer trainer =
                new LTSTrainer(AllophoneSet.getAllophoneSet(allophonesPath), true, true, 2);
        BufferedReader lexicon = new BufferedReader(
                new InputStreamReader(new FileInputStream(lexiconPath), "ISO-8859-1"));
        try {
            trainer.readLexicon(lexicon, "\\\\");
        } finally {
            lexicon.close();
        }

        for (int i = 0; i < 5; i++) {
            System.out.println("iteration " + i);
            trainer.alignIteration();
        }
        trainer.save(trainer.trainTree(100), outputPath);
    }
}
