package com.thesett.aima.logic.fol.wam.compiler;

import com.thesett.aima.logic.fol.AllTermsVisitor;
import com.thesett.aima.logic.fol.Clause;
import com.thesett.aima.logic.fol.DelegatingAllTermsVisitor;
import com.thesett.aima.logic.fol.Functor;
import com.thesett.aima.logic.fol.FunctorName;
import com.thesett.aima.logic.fol.LogicCompiler;
import com.thesett.aima.logic.fol.LogicCompilerObserver;
import com.thesett.aima.logic.fol.Sentence;
import com.thesett.aima.logic.fol.Term;
import com.thesett.aima.logic.fol.TermUtils;
import com.thesett.aima.logic.fol.Variable;
import com.thesett.aima.logic.fol.VariableAndFunctorInterner;
import com.thesett.aima.logic.fol.compiler.PositionalTermTraverserImpl;
import com.thesett.aima.logic.fol.compiler.TermWalker;
import com.thesett.aima.logic.fol.wam.TermWalkers;
import com.thesett.aima.logic.fol.wam.builtins.BuiltIn;
import com.thesett.aima.logic.fol.wam.builtins.Cut;
import com.thesett.aima.logic.fol.wam.compiler.DefaultBuiltIn;
import com.thesett.aima.logic.fol.wam.compiler.WAMInstruction;
import com.thesett.aima.logic.fol.wam.optimizer.Optimizer;
import com.thesett.aima.logic.fol.wam.optimizer.WAMOptimizer;
import com.thesett.aima.logic.fol.wam.printer.WAMCompiledPredicatePrintingVisitor;
import com.thesett.aima.logic.fol.wam.printer.WAMCompiledQueryPrintingVisitor;
import com.thesett.aima.search.util.Searches;
import com.thesett.aima.search.util.backtracking.DepthFirstBacktrackingSearch;
import com.thesett.aima.search.util.uninformed.BreadthFirstSearch;
import com.thesett.common.parsing.SourceCodeException;
import com.thesett.common.util.SizeableLinkedList;
import com.thesett.common.util.doublemaps.SymbolKey;
import com.thesett.common.util.doublemaps.SymbolTable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;

/* loaded from: input_file:com/thesett/aima/logic/fol/wam/compiler/InstructionCompiler.class */
/**
 * Compiles first-order logic {@link Clause}s into WAM instruction sequences.
 *
 * <p>Program clauses handed to {@link #compile} are buffered per predicate (keyed by the head functor's name) in a
 * per-scope symbol table. {@link #endScope} then compiles each buffered predicate clause by clause, runs the result
 * through the {@link Optimizer}, and publishes it to the registered {@link LogicCompilerObserver}. Queries are
 * compiled and published immediately.
 *
 * <p>Register allocations are stored in the symbol table under {@code SymbolTableKeys.SYMKEY_ALLOCATION} as an
 * {@code int} whose high byte is an addressing-mode tag ({@link #TEMP_REGISTER_TAG} or {@link #PERMANENT_SLOT_TAG})
 * and whose low byte is the register or slot index.
 *
 * <p>Per-compilation state is held in unsynchronized mutable fields, so an instance should not be shared between
 * threads.
 */
public class InstructionCompiler extends DefaultBuiltIn
    implements LogicCompiler<Clause, WAMCompiledPredicate, WAMCompiledQuery> {

    /** Allocation tag (high byte {@code 0x01}) for temporary registers, numbered by {@code lastAllocatedTempReg}. */
    private static final int TEMP_REGISTER_TAG = 0x100;

    /** Allocation tag (high byte {@code 0x02}) for permanent variable slots, numbered by {@link #numPermanentVars}. */
    private static final int PERMANENT_SLOT_TAG = 0x200;

    /** Orders (variable, index-of-last-goal-using-it) entries so that the latest-used variables come first. */
    private static final Comparator<Map.Entry<Variable, Integer>> LATEST_LAST_USE_FIRST =
        new Comparator<Map.Entry<Variable, Integer>>() {
            @Override
            public int compare(Map.Entry<Variable, Integer> a, Map.Entry<Variable, Integer> b) {
                return b.getValue().compareTo(a.getValue());
            }
        };

    /** Keys of the predicates seen in the current scope, in first-seen order; drained by {@link #endScope}. */
    protected Queue<SymbolKey> predicatesInScope;

    /** Receives compiled predicates and queries. Must be set before anything is compiled. */
    private LogicCompilerObserver<WAMCompiledPredicate, WAMCompiledQuery> observer;

    /** Number of permanent variable slots allocated so far for the clause or query being compiled. */
    protected int numPermanentVars;

    /** Permanent slot holding the cut-level variable, or -1 when the current clause or query has none. */
    protected int cutLevelVarSlot;

    /** Counter used as the key of each new symbol-table scope; advanced by {@link #endScope}. */
    protected int scope;

    /** Symbol table of the current scope; created lazily by {@link #compile} and dropped by {@link #endScope}. */
    private SymbolTable<Integer, String, Object> scopeTable;

    /** Optimizer applied to every compiled predicate and query before it is published. */
    private final Optimizer optimizer;

    /**
     * Term visitor that assigns a location to each query variable the first time it is visited.
     *
     * <p>Anonymous variables get the next temporary register; all other variables get the next permanent slot. Each
     * allocation's index byte is recorded in {@code varNames} against the variable's interned name. The visitor
     * mutates the enclosing compiler's allocation counters, so it must be used on the compiler that created it.
     */
    public class QueryRegisterAllocatingVisitor extends DelegatingAllTermsVisitor {
        /** Symbol table receiving the allocations. */
        private final SymbolTable<Integer, String, Object> symbolTable;

        /** Map from allocation index byte to the interned variable name held there. */
        private final Map<Byte, Integer> varNames;

        /**
         * Creates the visitor.
         *
         * @param symbolTable     symbol table to record allocations in.
         * @param map             map to record (index byte, variable name) pairs in.
         * @param allTermsVisitor visitor to delegate to after allocating; may be {@code null} when the superclass
         *                        permits it (this class's own caller passes {@code null}).
         */
        public QueryRegisterAllocatingVisitor(SymbolTable<Integer, String, Object> symbolTable, Map<Byte, Integer> map,
            AllTermsVisitor allTermsVisitor) {
            super(allTermsVisitor);
            this.symbolTable = symbolTable;
            this.varNames = map;
        }

        /**
         * Allocates a location for {@code variable} if it has none yet, then delegates to the superclass.
         *
         * @param variable the variable being visited.
         */
        public void visit(Variable variable) {
            if (this.symbolTable.get(variable.getSymbolKey(), SymbolTableKeys.SYMKEY_ALLOCATION) == null) {
                InstructionCompiler compiler = InstructionCompiler.this;
                int allocation;

                if (variable.isAnonymous()) {
                    allocation = (compiler.lastAllocatedTempReg++ & 0xFF) | TEMP_REGISTER_TAG;
                } else {
                    int slot = compiler.numPermanentVars++;
                    allocation = (slot & 0xFF) | PERMANENT_SLOT_TAG;

                    // Record the cut-level slot, as allocatePermanentProgramRegisters does for program clauses;
                    // compileQuery emits GetLevel only when this is set.
                    if (variable instanceof Cut.CutLevelVariable) {
                        compiler.cutLevelVarSlot = slot;
                    }
                }

                this.symbolTable.put(variable.getSymbolKey(), SymbolTableKeys.SYMKEY_ALLOCATION,
                    Integer.valueOf(allocation));
                this.varNames.put(Byte.valueOf((byte) allocation), Integer.valueOf(variable.getName()));
            }

            super.visit(variable);
        }
    }

    /**
     * Creates a compiler over the given symbol table and interner, with a {@link WAMOptimizer} for post-processing.
     *
     * @param symbolTable                symbol table used for scopes, allocations and occurrence info.
     * @param variableAndFunctorInterner interner used to resolve functor names.
     */
    public InstructionCompiler(SymbolTable<Integer, String, Object> symbolTable,
        VariableAndFunctorInterner variableAndFunctorInterner) {
        super(symbolTable, variableAndFunctorInterner);
        this.predicatesInScope = new LinkedList<SymbolKey>();
        this.cutLevelVarSlot = -1;
        this.optimizer = new WAMOptimizer(symbolTable, variableAndFunctorInterner);
    }

    /**
     * Sets the observer that receives compiled predicates and queries.
     *
     * @param logicCompilerObserver the observer; compiling without one set fails with a NullPointerException.
     */
    public void setCompilerObserver(LogicCompilerObserver<WAMCompiledPredicate, WAMCompiledQuery> logicCompilerObserver) {
        this.observer = logicCompilerObserver;
    }

    /**
     * Compiles every predicate buffered in the current scope and publishes each one to the observer.
     *
     * <p>After all predicates are compiled, the scope's symbol table is discarded, predicate entries below the low
     * mark are cleared, and the scope counter is advanced.
     *
     * @throws SourceCodeException propagated from clause compilation.
     */
    public void endScope() throws SourceCodeException {
        SymbolKey predicateKey;

        while ((predicateKey = this.predicatesInScope.poll()) != null) {
            @SuppressWarnings("unchecked")
            List<Clause> clauses = (List<Clause>) this.scopeTable.get(predicateKey, SymbolTableKeys.SYMKEY_PREDICATES);

            int clauseCount = clauses.size();
            boolean multipleClauses = clauseCount > 1;
            WAMCompiledPredicate compiledPredicate = null;
            int clauseNumber = 0;

            for (Iterator<Clause> it = clauses.iterator(); it.hasNext(); clauseNumber++) {
                Clause clause = it.next();

                if (compiledPredicate == null) {
                    compiledPredicate = new WAMCompiledPredicate(clause.getHead().getName());
                }

                boolean isFirst = clauseNumber == 0;
                boolean isLast = clauseNumber >= clauseCount - 1;
                compileClause(clause, compiledPredicate, isFirst, isLast, multipleClauses, clauseNumber);

                // Drop each clause from the buffered list once compiled.
                it.remove();
            }

            WAMCompiledPredicate optimized = (WAMCompiledPredicate) this.optimizer.apply(compiledPredicate);
            displayCompiledPredicate(optimized);
            this.observer.onCompilation(optimized);
            this.symbolTable.setLowMark(predicateKey, SymbolTableKeys.SYMKEY_PREDICATES);
        }

        this.symbolTable.clearUpToLowMark(SymbolTableKeys.SYMKEY_PREDICATES);
        this.scopeTable = null;
        this.scope++;
    }

    /**
     * Accepts one clause for compilation.
     *
     * <p>Queries are compiled and published immediately. Program clauses are appended to their predicate's clause
     * list in the current scope and compiled later by {@link #endScope}.
     *
     * @param sentence wrapper holding the clause.
     * @throws SourceCodeException propagated from query compilation.
     */
    public void compile(Sentence<Clause> sentence) throws SourceCodeException {
        Clause clause = (Clause) sentence.getT();

        if (clause.isQuery()) {
            compileQuery(clause);
            return;
        }

        if (this.scopeTable == null) {
            this.scopeTable = this.symbolTable.enterScope(Integer.valueOf(this.scope));
        }

        SymbolKey predicateKey = this.scopeTable.getSymbolKey(Integer.valueOf(clause.getHead().getName()));

        @SuppressWarnings("unchecked")
        List<Clause> clauses = (List<Clause>) this.scopeTable.get(predicateKey, SymbolTableKeys.SYMKEY_PREDICATES);

        if (clauses == null) {
            clauses = new LinkedList<Clause>();
            this.scopeTable.put(predicateKey, SymbolTableKeys.SYMKEY_PREDICATES, clauses);
            this.predicatesInScope.offer(predicateKey);
        }

        clauses.add(clause);
    }

    /**
     * Compiles one program clause and appends its instructions to {@code compiledPredicate}.
     *
     * @param clause            the clause to compile.
     * @param compiledPredicate predicate the compiled clause belongs to.
     * @param isFirst           whether this is the predicate's first clause.
     * @param isLast            whether this is the predicate's last clause.
     * @param multipleClauses   whether the predicate has more than one clause (choice-point code is emitted only then).
     * @param clauseNumber      zero-based position of the clause, used to label choice-point entries.
     * @throws SourceCodeException declared for compatibility with body-compilation callers.
     */
    private void compileClause(Clause clause, WAMCompiledPredicate compiledPredicate, boolean isFirst, boolean isLast,
        boolean multipleClauses, int clauseNumber) throws SourceCodeException {
        WAMCompiledClause compiledClause = new WAMCompiledClause(compiledPredicate);
        Functor[] body = clause.getBody();

        // An empty body is compiled as a fact: it makes no calls, so it needs Proceed rather than Allocate.
        boolean isFact = body == null || body.length == 0;
        boolean isChainRule = body != null && body.length == 1;

        resetAllocationState(clause);
        allocatePermanentProgramRegisters(clause);
        gatherPositionAndOccurrenceInfo(clause);

        FunctorName functorName = this.interner.getFunctorFunctorName(clause.getHead());
        WAMLabel entryLabel = new WAMLabel(functorName, clauseNumber);
        WAMLabel retryLabel = new WAMLabel(functorName, clauseNumber + 1);

        SizeableLinkedList<WAMInstruction> prefix = new SizeableLinkedList<WAMInstruction>();

        if (multipleClauses) {
            if (isLast) {
                prefix.add(new WAMInstruction(entryLabel, WAMInstruction.WAMInstructionSet.TrustMe));
            } else if (isFirst) {
                prefix.add(new WAMInstruction(entryLabel, WAMInstruction.WAMInstructionSet.TryMeElse, retryLabel));
            } else {
                prefix.add(new WAMInstruction(entryLabel, WAMInstruction.WAMInstructionSet.RetryMeElse, retryLabel));
            }
        }

        // Only rules making more than one call get a stack frame.
        if (!isFact && !isChainRule) {
            prefix.add(new WAMInstruction(WAMInstruction.WAMInstructionSet.Allocate));
        }

        if (this.cutLevelVarSlot >= 0) {
            prefix.add(new WAMInstruction(WAMInstruction.WAMInstructionSet.GetLevel, (byte) this.cutLevelVarSlot));
        }

        compiledClause.addInstructions(prefix);

        Functor head = clause.getHead();
        compiledClause.addInstructions(head, compileHead(head));

        if (!isFact) {
            for (int goalIndex = 0; goalIndex < body.length; goalIndex++) {
                Functor goal = body[goalIndex];
                boolean isFirstBody = goalIndex == 0;
                boolean isLastBody = goalIndex == body.length - 1;
                Integer permVarsRemaining =
                    (Integer) this.symbolTable.get(goal.getSymbolKey(), SymbolTableKeys.SYMKEY_PERM_VARS_REMAINING);
                BuiltIn builtIn = selectBuiltIn(goal);

                compiledClause.addInstructions(goal,
                    builtIn.compileBodyArguments(goal, isFirstBody, functorName, goalIndex));
                compiledClause.addInstructions(goal,
                    builtIn.compileBodyCall(goal, isFirstBody, isLastBody, isChainRule, permVarsRemaining.intValue()));
            }
        }

        SizeableLinkedList<WAMInstruction> postfix = new SizeableLinkedList<WAMInstruction>();

        if (isFact) {
            postfix.add(new WAMInstruction(WAMInstruction.WAMInstructionSet.Proceed));
        }

        compiledClause.addInstructions(postfix);
    }

    /**
     * Compiles a query, optimizes it and publishes it to the observer.
     *
     * @param clause the query clause.
     * @throws SourceCodeException declared for compatibility with body-compilation callers.
     */
    private void compileQuery(Clause clause) throws SourceCodeException {
        // Index byte of each allocated location -> interned name of the variable held there.
        Map<Byte, Integer> varNames = new TreeMap<Byte, Integer>();
        resetAllocationState(clause);

        Set<Variable> freeVars = TermUtils.findFreeNonAnonymousVariables(clause);
        Set<Integer> freeVarNames = new TreeSet<Integer>();

        for (Variable freeVar : freeVars) {
            freeVarNames.add(Integer.valueOf(freeVar.getName()));
        }

        allocatePermanentQueryRegisters(clause, varNames);
        gatherPositionAndOccurrenceInfo(clause);

        WAMCompiledQuery compiledQuery = new WAMCompiledQuery(varNames, freeVarNames);

        // NOTE(review): operand bytes 1 and 2 equal the high bytes of TEMP_REGISTER_TAG / PERMANENT_SLOT_TAG;
        // presumably they are addressing-mode tags expected by WAMInstruction -- confirm.
        SizeableLinkedList<WAMInstruction> prefix = new SizeableLinkedList<WAMInstruction>();
        prefix.add(new WAMInstruction(WAMInstruction.WAMInstructionSet.AllocateN, (byte) 1,
            (byte) (this.numPermanentVars & 0xFF)));

        if (this.cutLevelVarSlot >= 0) {
            prefix.add(new WAMInstruction(WAMInstruction.WAMInstructionSet.GetLevel, (byte) 2,
                (byte) this.cutLevelVarSlot));
        }

        compiledQuery.addInstructions(prefix);

        Functor[] body = clause.getBody();
        FunctorName queryName = new FunctorName("tq", 0);

        if (body != null) {
            for (int goalIndex = 0; goalIndex < body.length; goalIndex++) {
                Functor goal = body[goalIndex];
                BuiltIn builtIn = selectBuiltIn(goal);

                compiledQuery.addInstructions(goal, builtIn.compileBodyArguments(goal, false, queryName, goalIndex));
                compiledQuery.addInstructions(goal,
                    builtIn.compileBodyCall(goal, goalIndex == 0, false, false, this.numPermanentVars));
            }
        }

        SizeableLinkedList<WAMInstruction> postfix = new SizeableLinkedList<WAMInstruction>();
        postfix.add(new WAMInstruction(WAMInstruction.WAMInstructionSet.Suspend));
        postfix.add(new WAMInstruction(WAMInstruction.WAMInstructionSet.Deallocate));
        compiledQuery.addInstructions(postfix);

        WAMCompiledQuery optimized = (WAMCompiledQuery) this.optimizer.apply(compiledQuery);
        displayCompiledQuery(optimized);
        this.observer.onQueryCompilation(optimized);
    }

    /**
     * Resets the per-clause allocation state before compiling a clause or query.
     *
     * <p>Temporary registers are numbered from just above the largest arity in the clause.
     *
     * @param clause the clause about to be compiled.
     */
    private void resetAllocationState(Clause clause) {
        this.seenRegisters = new TreeSet<Integer>();
        this.lastAllocatedTempReg = findMaxArgumentsInClause(clause);
        this.numPermanentVars = 0;
        this.cutLevelVarSlot = -1;
    }

    /**
     * Chooses the object that compiles a body goal: the goal itself when it is a {@link BuiltIn}, else this compiler.
     *
     * @param goal the body goal.
     * @return the built-in to compile {@code goal} with.
     */
    private BuiltIn selectBuiltIn(Functor goal) {
        if (goal instanceof BuiltIn) {
            return (BuiltIn) goal;
        }

        return this;
    }

    /**
     * Finds the largest arity among the clause head and its body goals.
     *
     * @param clause the clause to inspect; head and body may each be {@code null}.
     * @return the maximum arity, or 0 when there is no head and no body.
     */
    private int findMaxArgumentsInClause(Clause clause) {
        Functor head = clause.getHead();
        int maxArity = (head != null) ? head.getArity() : 0;
        Functor[] body = clause.getBody();

        if (body != null) {
            for (Functor goal : body) {
                maxArity = Math.max(maxArity, goal.getArity());
            }
        }

        return maxArity;
    }

    /**
     * Compiles a clause head into Get/Unify instructions.
     *
     * <p>The head is walked breadth-first. Every sub-structure yields a GetStruc followed by one Unify instruction per
     * argument. Each variable in a top-level argument position yields GetVar on first sight and GetVal afterwards.
     *
     * @param functor the clause head.
     * @return the head instructions.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private SizeableLinkedList<WAMInstruction> compileHead(Functor functor) {
        SizeableLinkedList<WAMInstruction> instructions = new SizeableLinkedList<WAMInstruction>();

        allocateArgumentRegisters(functor);
        allocateTemporaryRegisters(functor);

        BreadthFirstSearch search = new BreadthFirstSearch();
        search.reset();
        search.addStartState(functor);

        Iterator<?> terms = Searches.allSolutions(search);

        // The first term produced is the head itself; only its sub-terms are compiled.
        terms.next();

        int arity = functor.getArity();

        for (int position = 0; terms.hasNext(); position++) {
            Term term = (Term) terms.next();

            if (term.isFunctor()) {
                compileHeadStructure((Functor) term, instructions);
            } else if (position < arity) {
                compileHeadArgumentVariable((Variable) term, position, instructions);
            }
        }

        return instructions;
    }

    /**
     * Emits GetStruc for a head sub-structure, followed by UnifyVar/UnifyLocalVal/UnifyVal for each of its arguments.
     *
     * @param structure    the sub-structure.
     * @param instructions list to append to.
     */
    private void compileHeadStructure(Functor structure, SizeableLinkedList<WAMInstruction> instructions) {
        int allocation = allocationOf(structure);
        instructions.add(new WAMInstruction(WAMInstruction.WAMInstructionSet.GetStruc, modeOf(allocation),
            indexOf(allocation), this.interner.getFunctorFunctorName(structure), structure));

        int arity = structure.getArity();

        for (int argIndex = 0; argIndex < arity; argIndex++) {
            Term argument = structure.getArgument(argIndex);
            int argAllocation = allocationOf(argument);
            byte mode = modeOf(argAllocation);
            byte index = indexOf(argAllocation);
            WAMInstruction instruction;

            if (!this.seenRegisters.contains(Integer.valueOf(argAllocation))) {
                this.seenRegisters.add(Integer.valueOf(argAllocation));
                instruction = new WAMInstruction(WAMInstruction.WAMInstructionSet.UnifyVar, mode, index, argument);
                this.symbolTable.put(argument.getSymbolKey(), SymbolTableKeys.SYMKEY_VARIABLE_INTRO,
                    DefaultBuiltIn.VarIntroduction.Unify);
            } else if (isLocalVariable((DefaultBuiltIn.VarIntroduction) this.symbolTable.get(argument.getSymbolKey(),
                    SymbolTableKeys.SYMKEY_VARIABLE_INTRO), mode)) {
                instruction = new WAMInstruction(WAMInstruction.WAMInstructionSet.UnifyLocalVal, mode, index, argument);
                this.symbolTable.put(argument.getSymbolKey(), SymbolTableKeys.SYMKEY_VARIABLE_INTRO, (Object) null);
            } else {
                instruction = new WAMInstruction(WAMInstruction.WAMInstructionSet.UnifyVal, mode, index, argument);
            }

            instructions.add(instruction);
        }
    }

    /**
     * Emits GetVar (first sight) or GetVal (later sights) for a variable in a top-level head argument position.
     *
     * @param variable     the argument variable.
     * @param position     the argument position within the head.
     * @param instructions list to append to.
     */
    private void compileHeadArgumentVariable(Variable variable, int position,
        SizeableLinkedList<WAMInstruction> instructions) {
        int allocation = allocationOf(variable);
        byte mode = modeOf(allocation);
        byte index = indexOf(allocation);
        byte argumentRegister = (byte) (position & 0xFF);
        WAMInstruction instruction;

        if (this.seenRegisters.contains(Integer.valueOf(allocation))) {
            instruction = new WAMInstruction(WAMInstruction.WAMInstructionSet.GetVal, mode, index, argumentRegister);
        } else {
            this.seenRegisters.add(Integer.valueOf(allocation));
            instruction = new WAMInstruction(WAMInstruction.WAMInstructionSet.GetVar, mode, index, argumentRegister);
            this.symbolTable.put(variable.getSymbolKey(), SymbolTableKeys.SYMKEY_VARIABLE_INTRO,
                DefaultBuiltIn.VarIntroduction.Get);
        }

        instructions.add(instruction);
    }

    /**
     * Reads the allocation recorded for a term; the term must already have one.
     *
     * @param term the term.
     * @return the allocation (mode tag in the high byte, index in the low byte).
     */
    private int allocationOf(Term term) {
        return ((Integer) this.symbolTable.get(term.getSymbolKey(), SymbolTableKeys.SYMKEY_ALLOCATION)).intValue();
    }

    /** Extracts the addressing-mode byte (bits 8-15) of an allocation. */
    private static byte modeOf(int allocation) {
        return (byte) ((allocation & 0xFF00) >> 8);
    }

    /** Extracts the register or slot index byte (bits 0-7) of an allocation. */
    private static byte indexOf(int allocation) {
        return (byte) (allocation & 0xFF);
    }

    /**
     * Allocates permanent slots for the variables of a program clause that occur in more than one chunk.
     *
     * <p>The head and the first body goal count as one chunk; each later goal is its own chunk. Cut-level variables
     * are also counted in the first chunk, which forces them to be permanent. Slots go to the most recently used
     * variables first. Each body goal is then annotated with {@code SYMKEY_PERM_VARS_REMAINING}: the number of
     * permanent variables whose last use is in a later goal.
     *
     * @param clause the program clause.
     */
    private void allocatePermanentProgramRegisters(Clause clause) {
        Map<Variable, Integer> chunkCounts = new HashMap<Variable, Integer>();
        Map<Variable, Integer> lastGoalIndex = new HashMap<Variable, Integer>();
        Set<Variable> firstChunkVars = new HashSet<Variable>();
        Functor[] body = clause.getBody();

        if (body != null) {
            // Walk backwards so the first index recorded for a variable is its last use.
            for (int goal = body.length - 1; goal >= 1; goal--) {
                for (Variable var : TermUtils.findFreeVariables(body[goal])) {
                    incrementCount(chunkCounts, var);

                    if (!lastGoalIndex.containsKey(var)) {
                        lastGoalIndex.put(var, Integer.valueOf(goal));
                    }

                    if (var instanceof Cut.CutLevelVariable) {
                        firstChunkVars.add(var);
                    }
                }
            }
        }

        if (clause.getHead() != null) {
            firstChunkVars.addAll(TermUtils.findFreeVariables(clause.getHead()));
        }

        if (body != null && body.length > 0) {
            firstChunkVars.addAll(TermUtils.findFreeVariables(body[0]));
        }

        for (Variable var : firstChunkVars) {
            incrementCount(chunkCounts, var);

            if (!lastGoalIndex.containsKey(var)) {
                lastGoalIndex.put(var, Integer.valueOf(0));
            }
        }

        List<Map.Entry<Variable, Integer>> byLastUse =
            new ArrayList<Map.Entry<Variable, Integer>>(lastGoalIndex.entrySet());
        Collections.sort(byLastUse, LATEST_LAST_USE_FIRST);

        int[] permVarsLastUsedAt = new int[(body != null) ? body.length : 0];

        for (Map.Entry<Variable, Integer> entry : byLastUse) {
            Variable var = entry.getKey();
            Integer count = chunkCounts.get(var);
            int lastGoal = entry.getValue().intValue();

            if (count != null && count.intValue() > 1) {
                int slot = this.numPermanentVars++;
                this.symbolTable.put(var.getSymbolKey(), SymbolTableKeys.SYMKEY_ALLOCATION,
                    Integer.valueOf((slot & 0xFF) | PERMANENT_SLOT_TAG));

                if (var instanceof Cut.CutLevelVariable) {
                    this.cutLevelVarSlot = slot;
                }

                permVarsLastUsedAt[lastGoal]++;
            }
        }

        int remaining = 0;

        for (int goal = permVarsLastUsedAt.length - 1; goal >= 0; goal--) {
            this.symbolTable.put(body[goal].getSymbolKey(), SymbolTableKeys.SYMKEY_PERM_VARS_REMAINING,
                Integer.valueOf(remaining));
            remaining += permVarsLastUsedAt[goal];
        }
    }

    /** Increments the count stored for {@code var}, treating a missing entry as zero. */
    private static void incrementCount(Map<Variable, Integer> counts, Variable var) {
        Integer current = counts.get(var);
        counts.put(var, Integer.valueOf((current == null) ? 1 : current.intValue() + 1));
    }

    /**
     * Allocates locations for every variable of a query by walking it with a {@link QueryRegisterAllocatingVisitor}.
     *
     * @param term the query.
     * @param map  receives (index byte, variable name) pairs for each allocation.
     */
    private void allocatePermanentQueryRegisters(Term term, Map<Byte, Integer> map) {
        QueryRegisterAllocatingVisitor allocator = new QueryRegisterAllocatingVisitor(this.symbolTable, map, null);
        PositionalTermTraverserImpl traverser = new PositionalTermTraverserImpl();
        traverser.setContextChangeVisitor(allocator);
        new TermWalker(new DepthFirstBacktrackingSearch(), traverser, allocator).walk(term);
    }

    /**
     * Walks a term with a {@code PositionAndOccurrenceVisitor}, which records position and occurrence information
     * in the symbol table.
     *
     * @param term the clause or query.
     */
    private void gatherPositionAndOccurrenceInfo(Term term) {
        PositionalTermTraverserImpl traverser = new PositionalTermTraverserImpl();
        PositionAndOccurrenceVisitor visitor = new PositionAndOccurrenceVisitor(this.interner, this.symbolTable, traverser);
        traverser.setContextChangeVisitor(visitor);
        new TermWalker(new DepthFirstBacktrackingSearch(), traverser, visitor).walk(term);
    }

    /**
     * Walks a compiled predicate with a printing visitor.
     *
     * <p>NOTE(review): the output buffer is discarded; this looks like leftover debug output. Confirm that the walk
     * has no required side effects before removing it.
     *
     * @param term the compiled predicate.
     */
    private void displayCompiledPredicate(Term term) {
        TermWalkers.positionalWalker(
            new WAMCompiledPredicatePrintingVisitor(this.interner, this.symbolTable, new StringBuffer())).walk(term);
    }

    /**
     * Walks a compiled query with a printing visitor.
     *
     * <p>NOTE(review): the output buffer is discarded; this looks like leftover debug output. Confirm that the walk
     * has no required side effects before removing it.
     *
     * @param term the compiled query.
     */
    private void displayCompiledQuery(Term term) {
        TermWalkers.positionalWalker(
            new WAMCompiledQueryPrintingVisitor(this.interner, this.symbolTable, new StringBuffer())).walk(term);
    }
}
