package ivory.cascade.retrieval;

import com.google.common.collect.Maps;
import edu.umd.cloud9.collection.DocnoMapping;
import ivory.bloomir.util.OptionManager;
import ivory.core.ConfigurationException;
import ivory.core.RetrievalEnvironment;
import ivory.core.eval.GradedQrels;
import ivory.core.eval.RankedListEvaluator;
import ivory.core.util.ResultWriter;
import ivory.core.util.XMLTools;
import ivory.smrf.model.builder.MRFBuilder;
import ivory.smrf.model.expander.MRFExpander;
import ivory.smrf.retrieval.Accumulator;
import ivory.smrf.retrieval.BatchQueryRunner;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.log4j.Logger;
import org.w3c.dom.Document;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import org.xml.sax.SAXException;

/* loaded from: input_file:ivory/cascade/retrieval/CascadeBatchQueryRunner.class */
/**
 * Batch query runner for cascade models.
 *
 * <p>On top of {@link BatchQueryRunner} this class supports, per model:
 * <ul>
 *   <li>an optional "internal input file" ({@code qid value value} lines) handed to the
 *       {@link CascadeThreadedQueryRunner};</li>
 *   <li>an optional "internal output file" written as {@code qid docno score} lines;</li>
 *   <li>NDCG evaluation against a graded qrels file chosen from the collection name detected in
 *       the index path, together with per-query cascade cost bookkeeping.</li>
 * </ul>
 */
public class CascadeBatchQueryRunner extends BatchQueryRunner {
    private static final Logger LOG = Logger.getLogger(CascadeBatchQueryRunner.class);

    /** Collection names recognized in the index path, in the order they are tested. */
    private static final String[] KNOWN_COLLECTIONS = {"wt10g", "gov2", "clue"};

    /**
     * Candidate qrels locations, one row per entry of {@link #KNOWN_COLLECTIONS} (same order).
     * The first path that exists on the file system is used.
     */
    private static final String[][] QRELS_CANDIDATES = {
        {
            "/user/lidan/qrels/qrels.wt10g",
            "/umd-lin/lidan/qrels/qrels.wt10g",
            "/fs/clip-trec/trunk_new/docs/data/wt10g/qrels.wt10g",
            "data/wt10g/qrels.wt10g.all"
        },
        {
            "/user/lidan/qrels/qrels.gov2.all",
            "/umd-lin/lidan/qrels/qrels.gov2.all",
            "/fs/clip-trec/trunk_new/docs/data/gov2/qrels.gov2.all",
            "data/gov2/qrels.gov2.all"
        },
        {
            "/user/lidan/qrels/qrels.web09catB.txt",
            "/umd-lin/lidan/qrels/qrels.web09catB.txt",
            "/fs/clip-trec/trunk_new/docs/data/clue/qrels.web09catB.txt",
            "data/clue/qrels.web09catB.txt"
        }
    };

    /** Model id -> cascade cost per query, indexed by the integer query id. Rebuilt by runQueries(). */
    private HashMap<String, float[]> cascadeCosts;
    /** Model id -> last-stage cascade costs as returned by the query runner. Rebuilt by runQueries(). */
    private HashMap cascadeCosts_lastStage;
    /** Model id -> internal output file path (null when not configured). */
    private final Map<String, String> internalOutputFiles = new HashMap<>();
    /** Model id -> internal input file path (null when not configured). */
    private final Map<String, String> internalInputFiles = new HashMap<>();
    /** NDCG values (stored as Strings) of judged queries; parallel to {@link #costKeys}. */
    public LinkedList ndcgValues;
    /** {@code "modelId queryId"} keys (stored as Strings); parallel to {@link #ndcgValues}. */
    public LinkedList costKeys;
    /** "wt10g", "gov2" or "clue", detected from the index path; null if never detected. */
    private String dataCollection;
    /** The current model's "K" attribute; when non-zero it overrides the number of hits. */
    private int K_val;
    /** The current model's "topK" attribute; used as the NDCG cutoff and RetrievalEnvironment.topK. */
    private int kVal;

    /**
     * Creates the runner from one or more XML configuration files.
     *
     * @param strArr paths of the configuration files on {@code fileSystem}
     * @param fileSystem file system used for configuration, qrels, and result files
     * @throws ConfigurationException if a configuration file cannot be parsed or is incomplete
     */
    public CascadeBatchQueryRunner(String[] strArr, FileSystem fileSystem) throws ConfigurationException {
        super(strArr, fileSystem);
        this.cascadeCosts = Maps.newHashMap();
        this.cascadeCosts_lastStage = new HashMap();
        this.ndcgValues = new LinkedList();
        this.costKeys = new LinkedList();
        this.dataCollection = null;
        parseParameters(strArr);
    }

    /**
     * Reads an internal input file made of whitespace-separated {@code qid value1 value2} lines.
     *
     * <p>Reading stops at end of file or at the first blank line. Rows that share a query id are
     * collected in file order. On any read or parse error, or when rows were read for at most one
     * query, a message is printed and the JVM exits with status -1.
     *
     * @param str path of the file on {@link #fs}, or {@code null} for none
     * @return query id -> {@code rows x 2} array of the two parsed values per line (empty when
     *     {@code str} is null)
     */
    public HashMap<String, float[][]> readInternalInputFile(String str) {
        HashMap<String, float[][]> result = new HashMap<>();
        if (str == null) {
            return result;
        }

        HashMap<String, List<float[]>> rowsByQuery = new HashMap<>();
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(this.fs.open(new Path(str))))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.length() == 0) {
                    break; // a blank line terminates the input
                }
                String[] fields = trimmed.split("\\s+");
                List<float[]> rows = rowsByQuery.get(fields[0]);
                if (rows == null) {
                    rows = new ArrayList<>();
                    rowsByQuery.put(fields[0], rows);
                }
                rows.add(new float[] {(float) Double.parseDouble(fields[1]), (float) Double.parseDouble(fields[2])});
            }
        } catch (Exception e) {
            // Deliberately broad: malformed lines (missing fields, bad numbers) are fatal as well.
            LOG.error("Problem reading " + str, e);
            System.out.println("Problem reading " + str);
            System.exit(-1);
        }
        if (rowsByQuery.size() <= 1) {
            System.out.println("Should have results for more queries.");
            System.exit(-1);
        }

        for (Map.Entry<String, List<float[]>> entry : rowsByQuery.entrySet()) {
            List<float[]> rows = entry.getValue();
            result.put(entry.getKey(), rows.toArray(new float[rows.size()][]));
        }
        return result;
    }

    /**
     * Runs every query against every configured model, writes the results, and prints the NDCG sum
     * and total cascade cost over all judged queries. State from a previous call is discarded.
     */
    @Override // ivory.smrf.retrieval.BatchQueryRunner
    public void runQueries() {
        this.cascadeCosts = new HashMap<>();
        this.cascadeCosts_lastStage = new HashMap();
        // Without this, a second call would add its evaluations on top of the previous run's.
        this.ndcgValues.clear();
        this.costKeys.clear();

        for (String modelId : this.models.keySet()) {
            HashMap internalInput = readInternalInputFile(this.internalInputFiles.get(modelId));
            Node modelNode = this.models.get(modelId);
            this.K_val = XMLTools.getAttributeValue(modelNode, "K", 0);
            this.kVal = XMLTools.getAttributeValue(modelNode, "topK", 0);
            if (this.kVal == 0) {
                System.out.println("Should not be 0!");
                System.exit(-1);
            }
            RetrievalEnvironment.topK = this.kVal;

            CascadeThreadedQueryRunner runner = createRunner(modelId, modelNode, internalInput);
            for (String queryId : this.queries.keySet()) {
                String queryText = this.queries.get(queryId);
                String[] tokens = this.env.tokenize(queryText);
                LOG.info(String.format("query id: %s, query: \"%s\"", queryId, queryText));
                runner.runQuery(queryId, tokens);
            }
            writeResults(modelId, modelNode, runner);
        }

        if (this.costKeys.size() != this.ndcgValues.size()) {
            System.out.println("They should be equal " + this.costKeys.size() + " " + this.ndcgValues.size());
            System.exit(-1);
        }
        float ndcgSum = 0.0f;
        float totalCost = 0.0f;
        // Snapshot to an array: indexed LinkedList access inside the loop would be quadratic.
        Object[] ndcgStrings = this.ndcgValues.toArray();
        int index = 0;
        for (Object key : this.costKeys) {
            String[] parts = ((String) key).split("\\s+");
            float cascadeCost = getCascadeCost(parts[0], parts[1]);
            ndcgSum += Float.parseFloat((String) ndcgStrings[index++]);
            totalCost += cascadeCost;
        }
        System.out.println("Evaluation results... NDCG Sum " + ndcgSum + " TotalCost " + totalCost + " # queries with results " + this.costKeys.size() + " dataCollection " + this.dataCollection + " kVal " + this.kVal);
    }

    /**
     * Builds the cascade query runner for one model and registers it in {@code queryRunners}.
     *
     * @throws RuntimeException wrapping the cause if the model or its expander cannot be built
     */
    private CascadeThreadedQueryRunner createRunner(String modelId, Node modelNode, HashMap internalInput) {
        Node expanderNode = this.expanders.get(modelId);
        try {
            MRFBuilder builder = MRFBuilder.get(this.env, modelNode.cloneNode(true));
            MRFExpander expander = expanderNode != null ? MRFExpander.getExpander(this.env, expanderNode.cloneNode(true)) : null;
            // Stopwords only matter when an expander exists; without this guard a null expander throws.
            if (expander != null && this.stopwords != null && this.stopwords.size() != 0) {
                expander.setStopwordList(this.stopwords);
            }
            int hits = XMLTools.getAttributeValue(modelNode, OptionManager.HITS, 1000);
            if (this.K_val != 0) {
                hits = this.K_val;
            }
            LOG.info("number of hits: " + hits);
            CascadeThreadedQueryRunner runner = new CascadeThreadedQueryRunner(builder, expander, 1, hits, internalInput, this.K_val);
            this.queryRunners.put(modelId, runner);
            return runner;
        } catch (Exception e) {
            throw new RuntimeException("Error: Unable to initialize query runner for model " + modelId, e);
        }
    }

    /**
     * Writes a model's results (and, if configured, its internal output file), then records the
     * runner's cascade costs for the model.
     */
    private void writeResults(String modelId, Node modelNode, CascadeThreadedQueryRunner runner) {
        String outputPath = XMLTools.getAttributeValue(modelNode, OptionManager.OUTPUT_PATH, (String) null);
        boolean compress = XMLTools.getAttributeValue(modelNode, "compress", false);
        String internalOutput = this.internalOutputFiles.get(modelId);
        try {
            ResultWriter resultWriter = new ResultWriter(outputPath, compress, this.fs);
            if (internalOutput != null) {
                ResultWriter internalWriter = new ResultWriter(internalOutput, compress, this.fs);
                printResults(modelId, runner, internalWriter, true);
                internalWriter.flush();
            }
            printResults(modelId, runner, resultWriter, false);
            resultWriter.flush();
            this.cascadeCosts.put(modelId, runner.getCascadeCostAllQueries());
            this.cascadeCosts_lastStage.put(modelId, runner.getCascadeCostAllQueries_lastStage());
        } catch (IOException e) {
            throw new RuntimeException("Error: Unable to write results!", e);
        }
    }

    /**
     * Returns the cascade cost recorded for one query of one model.
     *
     * @param str model id
     * @param str2 query id; must parse as an int index into the model's cost array
     */
    public float getCascadeCost(String str, String str2) {
        return this.cascadeCosts.get(str)[Integer.parseInt(str2)];
    }

    /**
     * Returns the last-stage cascade cost recorded for one query of one model.
     *
     * @param str model id
     * @param str2 query id; must parse as an int index into the model's cost array
     */
    public float getCascadeCost_lastStage(String str, String str2) {
        return ((float[]) this.cascadeCosts_lastStage.get(str))[Integer.parseInt(str2)];
    }

    /**
     * Prints one model's results and, in TREC mode, records NDCG and cost keys for judged queries.
     *
     * @param modelId model id, also used as the TREC run tag
     * @param runner runner holding the results
     * @param writer destination of the results
     * @param internalFormat true: {@code qid docno score} lines; false: TREC run format plus
     *     NDCG bookkeeping
     */
    private void printResults(String modelId, CascadeQueryRunner runner, ResultWriter writer, boolean internalFormat) throws IOException {
        String qrelsPath = findQrelsPath();
        if (qrelsPath == null) {
            System.out.println("Should have set qrelsPath!");
            System.exit(-1);
        }
        GradedQrels qrels = new GradedQrels(qrelsPath);
        DocnoMapping mapping = getDocnoMapping();

        for (String queryId : this.queries.keySet()) {
            Accumulator[] results = runner.getResults(queryId);
            if (results == null) {
                LOG.info("null results for: " + queryId);
                continue;
            }
            // Only queries with at least one judged-relevant document take part in the evaluation.
            if (!internalFormat && qrels.getReldocsForQid(queryId, true).size() > 0) {
                float ndcg = (float) RankedListEvaluator.computeNDCG(this.kVal, results, mapping, qrels.getReldocsForQid(queryId, true));
                this.ndcgValues.add(String.valueOf(ndcg));
                this.costKeys.add(modelId + " " + queryId);
            }

            if (internalFormat) {
                for (Accumulator result : results) {
                    writer.println(queryId + " " + result.docno + " " + result.score);
                }
            } else if (this.docnoMapping == null) {
                for (int rank = 0; rank < results.length; rank++) {
                    writer.println(queryId + " Q0 " + results[rank].docno + " " + (rank + 1) + " " + results[rank].score + " " + modelId);
                }
            } else {
                for (int rank = 0; rank < results.length; rank++) {
                    writer.println(queryId + " Q0 " + this.docnoMapping.getDocid(results[rank].docno) + " " + (rank + 1) + " " + results[rank].score + " " + modelId);
                }
            }
        }
    }

    /**
     * Returns the first existing qrels file for the detected collection. Returns null if no
     * collection was detected or none of its candidate files exists.
     */
    private String findQrelsPath() throws IOException {
        if (this.dataCollection == null) {
            return null;
        }
        for (int c = 0; c < KNOWN_COLLECTIONS.length; c++) {
            if (this.dataCollection.indexOf(KNOWN_COLLECTIONS[c]) != -1) {
                for (String candidate : QRELS_CANDIDATES[c]) {
                    if (this.fs.exists(new Path(candidate))) {
                        return candidate;
                    }
                }
                return null;
            }
        }
        return null;
    }

    /**
     * Parses the cascade-specific settings from each configuration file and validates that
     * queries, models, and an index were configured.
     */
    private void parseParameters(String[] strArr) throws ConfigurationException {
        for (String configPath : strArr) {
            try (InputStream in = this.fs.open(new Path(configPath))) {
                Document document = DocumentBuilderFactory.newInstance().newDocumentBuilder().parse(in);
                parseModels(document);
                parseIndexLocation(document);
            } catch (IOException | ParserConfigurationException | SAXException e) {
                LOG.error("Unable to parse configuration " + configPath, e);
                throw new ConfigurationException(e.getMessage());
            }
        }
        if (this.queries.isEmpty()) {
            throw new ConfigurationException("Must specify at least one query!");
        }
        if (this.models.isEmpty()) {
            throw new ConfigurationException("Must specify at least one model!");
        }
        if (this.indexPath == null) {
            throw new ConfigurationException("Must specify an index!");
        }
    }

    /**
     * Records each model's internal input/output files, keyed by model id, and its expander node.
     *
     * @throws ConfigurationException if a model has no id or more than one expander
     */
    private void parseModels(Document document) throws ConfigurationException {
        NodeList modelNodes = document.getElementsByTagName("model");
        for (int i = 0; i < modelNodes.getLength(); i++) {
            Node modelNode = modelNodes.item(i);
            String modelId = XMLTools.getAttributeValue(modelNode, "id", (String) null);
            if (modelId == null) {
                throw new ConfigurationException("Must specify a model id for every model!");
            }
            // Keyed by id, not position, so lookups stay correct across multiple config files.
            this.internalInputFiles.put(modelId, blankToNull(XMLTools.getAttributeValue(modelNode, "internalInputFile", (String) null)));
            this.internalOutputFiles.put(modelId, blankToNull(XMLTools.getAttributeValue(modelNode, "internalOutputFile", (String) null)));

            Node expanderNode = null;
            NodeList children = modelNode.getChildNodes();
            for (int j = 0; j < children.getLength(); j++) {
                Node child = children.item(j);
                if ("expander".equals(child.getNodeName())) {
                    if (expanderNode != null) {
                        throw new ConfigurationException("Only one expander allowed per model!");
                    }
                    expanderNode = child;
                }
            }
            // NOTE(review): the base class presumably registers expanders while parsing the same
            // files, so only fill an empty slot instead of treating the re-parse as a duplicate.
            if (expanderNode != null && !this.expanders.containsKey(modelId)) {
                this.expanders.put(modelId, expanderNode);
            }
        }
    }

    /** Returns null for null or whitespace-only values; otherwise returns the value unchanged. */
    private static String blankToNull(String value) {
        return (value == null || value.trim().length() == 0) ? null : value;
    }

    /**
     * Detects the collection name from the index path when the document declares an index.
     * Exits the JVM if the path matches no known collection.
     */
    private void parseIndexLocation(Document document) throws ConfigurationException {
        if (document.getElementsByTagName("index").getLength() <= 0 || this.indexPath == null) {
            return;
        }
        String lowerIndexPath = this.indexPath.toLowerCase();
        for (String collection : KNOWN_COLLECTIONS) {
            if (lowerIndexPath.indexOf(collection) != -1) {
                this.dataCollection = collection;
                RetrievalEnvironment.dataCollection = collection;
                return;
            }
        }
        System.out.println("Invalid data collection " + this.indexPath);
        System.exit(-1);
    }
}
