package de.breakpointsec.pushdown;

import com.google.common.base.Joiner;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import de.breakpointsec.pushdown.fsm.Transition;
import de.breakpointsec.pushdown.fsm.WeightedAutomaton;
import de.breakpointsec.pushdown.rules.NormalRule;
import de.breakpointsec.pushdown.rules.PopRule;
import de.breakpointsec.pushdown.rules.PushRule;
import de.breakpointsec.pushdown.rules.Rule;
import de.breakpointsec.pushdown.weights.Semiring;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Set;
import java.util.stream.Collectors;

/* loaded from: input_file:de/breakpointsec/pushdown/WPDS.class */
/**
 * Holds three sets of rules (push, pop and normal) and provides two worklist
 * procedures, {@link #poststar(WeightedAutomaton)} and
 * {@link #prestar(WeightedAutomaton)}, that add weighted transitions to a
 * {@link WeightedAutomaton} until no weight changes any more.
 *
 * <p>NOTE(review): the name presumably stands for "weighted pushdown system";
 * the documentation below only describes what the code in this class does.
 *
 * @param <L> type of transition labels (see {@link #epsilon()})
 * @param <S> type of automaton states
 * @param <W> type of the weights attached to rules and transitions
 */
public abstract class WPDS<L, S, W extends Semiring> {
    /** Push rules added through {@link #addRule(Rule)}. */
    protected final Set<PushRule<L, S, W>> pushRules = Sets.newHashSet();
    /** Pop rules added through {@link #addRule(Rule)}. */
    protected final Set<PopRule<L, S, W>> popRules = Sets.newHashSet();
    /** Not read or written by this class; available to subclasses. */
    protected Set<S> states = Sets.newHashSet();
    /** Normal rules added through {@link #addRule(Rule)}. */
    protected final Set<NormalRule<L, S, W>> normalRules = Sets.newHashSet();
    // NOTE(review): never read or written inside this class; kept unchanged.
    private final Multimap<S, Transition<L, S>> transitionsInto = HashMultimap.create();

    /**
     * Adds a rule to the rule set that matches its concrete type.
     *
     * @param rule the rule to add
     * @return {@code true} if the rule was not already contained in its set
     * @throws IllegalArgumentException if {@code rule} is neither a push, pop nor normal rule
     *         (this includes {@code null})
     */
    public boolean addRule(Rule<L, S, W> rule) {
        return addRuleInternal(rule);
    }

    /** Dispatches on the runtime type of {@code rule} and stores it in the matching set. */
    private boolean addRuleInternal(Rule<L, S, W> rule) {
        if (rule instanceof PushRule) {
            return this.pushRules.add((PushRule<L, S, W>) rule);
        }
        if (rule instanceof PopRule) {
            return this.popRules.add((PopRule<L, S, W>) rule);
        }
        if (rule instanceof NormalRule) {
            return this.normalRules.add((NormalRule<L, S, W>) rule);
        }
        throw new IllegalArgumentException("Try to add a rule of wrong type");
    }

    /** @return the live (not copied) set of normal rules */
    public Set<NormalRule<L, S, W>> getNormalRules() {
        return this.normalRules;
    }

    /** @return the live (not copied) set of pop rules */
    public Set<PopRule<L, S, W>> getPopRules() {
        return this.popRules;
    }

    /** @return the live (not copied) set of push rules */
    public Set<PushRule<L, S, W>> getPushRules() {
        return this.pushRules;
    }

    /** @return a freshly allocated set containing all normal, pop and push rules */
    public Set<Rule<L, S, W>> getAllRules() {
        Set<Rule<L, S, W>> allRules = Sets.newHashSet();
        allRules.addAll(this.normalRules);
        allRules.addAll(this.popRules);
        allRules.addAll(this.pushRules);
        return allRules;
    }

    /**
     * Collects every rule whose {@code S1} equals {@code s} and whose {@code L1} equals {@code l}.
     *
     * @param s the state compared against {@code getS1()}
     * @param l the label compared against {@code getL1()}
     * @return a freshly allocated set of the matching pop, normal and push rules
     */
    @Deprecated
    public Set<Rule<L, S, W>> getRulesStarting(S s, L l) {
        Set<Rule<L, S, W>> matching = new HashSet<>();
        getRulesStartingWithinSet(s, l, this.popRules, matching);
        getRulesStartingWithinSet(s, l, this.normalRules, matching);
        getRulesStartingWithinSet(s, l, this.pushRules, matching);
        return matching;
    }

    /** Adds to {@code set2} every rule of {@code set} whose (S1, L1) equals ({@code s}, {@code l}). */
    @Deprecated
    private void getRulesStartingWithinSet(S s, L l, Set<? extends Rule<L, S, W>> set, Set<Rule<L, S, W>> set2) {
        for (Rule<L, S, W> rule : set) {
            if (rule.getS1().equals(s) && rule.getL1().equals(l)) {
                set2.add(rule);
            }
        }
    }

    /**
     * Collects the normal rules whose {@code S2} equals {@code s} and whose {@code L2} equals {@code l}.
     *
     * @param s the state compared against {@code getS2()}
     * @param l the label compared against {@code getL2()}
     * @return a freshly allocated set of the matching normal rules
     */
    public Set<NormalRule<L, S, W>> getNormalRulesEnding(S s, L l) {
        Set<NormalRule<L, S, W>> matching = new HashSet<>();
        for (NormalRule<L, S, W> normalRule : getNormalRules()) {
            if (normalRule.getS2().equals(s) && normalRule.getL2().equals(l)) {
                matching.add(normalRule);
            }
        }
        return matching;
    }

    /**
     * Collects the push rules whose {@code S2} equals {@code s} and whose {@code L2} equals {@code l}.
     *
     * @param s the state compared against {@code getS2()}
     * @param l the label compared against {@code getL2()}
     * @return a freshly allocated set of the matching push rules
     */
    public Set<PushRule<L, S, W>> getPushRulesEnding(S s, L l) {
        Set<PushRule<L, S, W>> matching = new HashSet<>();
        for (PushRule<L, S, W> pushRule : getPushRules()) {
            if (pushRule.getS2().equals(s) && pushRule.getL2().equals(l)) {
                matching.add(pushRule);
            }
        }
        return matching;
    }

    /**
     * Computes the states mentioned by the rules; the {@link #states} field is not consulted.
     *
     * @return a freshly allocated set holding {@code S1} and {@code S2} of every rule
     */
    @Deprecated
    public Set<S> getStates() {
        Set<S> ruleStates = Sets.newHashSet();
        for (Rule<L, S, W> rule : getAllRules()) {
            ruleStates.add(rule.getS1());
            ruleStates.add(rule.getS2());
        }
        return ruleStates;
    }

    /**
     * Combines {@code w} into the weight of {@code transition} (adding the transition to the
     * automaton first if it is missing) and appends the transition to {@code linkedList} when
     * its weight changed.
     *
     * @param transition the transition to add or update
     * @param w the weight to combine with the transition's current weight
     * @param rule the rule that caused the update, or {@code null}; not used by this implementation
     * @param weightedAutomaton the automaton that is modified
     * @param linkedList the worklist that receives {@code transition} if its weight changed
     * @return {@code true} if the stored weight changed
     * @throws IllegalTransitionException declared for the automaton operations used here
     */
    // The raw Rule parameter is kept on purpose so that existing overrides stay source-compatible.
    @SuppressWarnings("rawtypes")
    protected boolean updatePostStar(Transition<L, S> transition, W w, Rule rule, WeightedAutomaton<L, S, W> weightedAutomaton, LinkedList<Transition<L, S>> linkedList) throws IllegalTransitionException {
        return combineAndEnqueue(transition, w, weightedAutomaton, linkedList);
    }

    /**
     * Saturates {@code weightedAutomaton} in place, driven by the rules' left-hand sides.
     *
     * <p>First, one state is created via {@code createState(S2, L2)} for every push rule. Then a
     * worklist seeded with the automaton's transitions is processed until it is empty:
     * <ul>
     *   <li>For a transition labelled {@link #epsilon()}, each transition leaving its target is
     *       re-attached to the epsilon transition's start state.</li>
     *   <li>Otherwise, for every rule whose (S1, L1) equals the transition's (start, label):
     *       a pop rule yields an epsilon transition from S2, a normal rule yields an L2
     *       transition from S2, and a push rule yields an L2 transition into the state created
     *       for that rule plus a call-site transition from that state to the original target.</li>
     * </ul>
     *
     * @param weightedAutomaton the automaton to extend; modified in place
     * @throws IllegalTransitionException propagated from {@link #updatePostStar}
     */
    public void poststar(WeightedAutomaton<L, S, W> weightedAutomaton) throws IllegalTransitionException {
        HashMap<PushRule<L, S, W>, S> generatedStates = new HashMap<>();
        for (PushRule<L, S, W> pushRule : getPushRules()) {
            generatedStates.put(pushRule, weightedAutomaton.createState(pushRule.getS2(), pushRule.getL2()));
        }
        // pop() takes from the head and add() appends to the tail, so the list is used as a FIFO queue.
        LinkedList<Transition<L, S>> worklist = Lists.newLinkedList(weightedAutomaton.getTransitions());
        while (!worklist.isEmpty()) {
            Transition<L, S> current = worklist.pop();
            if (current.getLabel().equals(epsilon())) {
                for (Transition<L, S> next : weightedAutomaton.getTransitionsOutOf(current.getTarget())) {
                    Transition<L, S> shortcut = new Transition<>(current.getStart(), next.getLabel(), next.getTarget());
                    W weight = asWeight(weightedAutomaton.getWeightFor(next).extendWith(weightedAutomaton.getWeightFor(current)));
                    updatePostStar(shortcut, weight, null, weightedAutomaton, worklist);
                }
            } else {
                for (PopRule<L, S, W> popRule : getPopRules()) {
                    if (popRule.getS1().equals(current.getStart()) && popRule.getL1().equals(current.getLabel())) {
                        Transition<L, S> popped = new Transition<>(popRule.getS2(), epsilon(), current.getTarget());
                        W weight = asWeight(weightedAutomaton.getWeightFor(current).extendWith(popRule.getWeight()));
                        updatePostStar(popped, weight, popRule, weightedAutomaton, worklist);
                    }
                }
                for (NormalRule<L, S, W> normalRule : getNormalRules()) {
                    if (normalRule.getS1().equals(current.getStart()) && normalRule.getL1().equals(current.getLabel())) {
                        Transition<L, S> stepped = new Transition<>(normalRule.getS2(), normalRule.getL2(), current.getTarget());
                        W weight = asWeight(weightedAutomaton.getWeightFor(current).extendWith(normalRule.getWeight()));
                        updatePostStar(stepped, weight, normalRule, weightedAutomaton, worklist);
                    }
                }
                for (PushRule<L, S, W> pushRule : getPushRules()) {
                    if (pushRule.getS1().equals(current.getStart()) && pushRule.getL1().equals(current.getLabel())) {
                        S generated = generatedStates.get(pushRule);
                        if (generated == null) {
                            // NOTE(review): processing continues with a null state after this message.
                            System.out.println("UNEXPECTED: No generated state found for rule " + pushRule);
                        }
                        updatePostStar(new Transition<>(pushRule.getS2(), pushRule.getL2(), generated), weightedAutomaton.getOne(), pushRule, weightedAutomaton, worklist);
                        Transition<L, S> afterCall = new Transition<>(generated, pushRule.getCallSite(), current.getTarget());
                        W weight = asWeight(weightedAutomaton.getWeightFor(current).extendWith(pushRule.getWeight()));
                        if (updatePostStar(afterCall, weight, pushRule, weightedAutomaton, worklist)) {
                            // Collected into a separate set before iterating because the loop body
                            // adds transitions to the automaton.
                            Set<Transition<L, S>> epsilonsIntoGenerated = weightedAutomaton.getTransitionsInto(generated).stream()
                                    .filter(candidate -> candidate.getLabel().equals(weightedAutomaton.epsilon()) && candidate.getTarget().equals(generated))
                                    .collect(Collectors.toSet());
                            for (Transition<L, S> epsilonTransition : epsilonsIntoGenerated) {
                                updatePostStar(new Transition<>(epsilonTransition.getStart(), pushRule.getCallSite(), current.getTarget()), asWeight(weight.extendWith(weightedAutomaton.getWeightFor(epsilonTransition))), pushRule, weightedAutomaton, worklist);
                            }
                        }
                    }
                }
            }
        }
    }

    /**
     * Saturates {@code weightedAutomaton} in place, driven by the rules' right-hand sides.
     *
     * <p>Every transition already present is combined with {@code getOne()}, and every pop rule
     * contributes a transition (S1, L1, S2) carrying the rule's weight. Then a worklist is
     * processed until it is empty; for each transition taken from it:
     * <ul>
     *   <li>each normal rule ending in the transition's (start, label) yields a transition
     *       (S1, L1, target);</li>
     *   <li>each push rule ending in the transition's (start, label) is paired with every
     *       call-site-labelled transition leaving the transition's target;</li>
     *   <li>each push rule whose L2 equals the transition's label is paired with every
     *       call-site-labelled transition from the rule's S2 into the transition's start.</li>
     * </ul>
     *
     * @param weightedAutomaton the automaton to extend; modified in place
     * @return the same {@code weightedAutomaton} instance
     * @throws IllegalTransitionException propagated from {@link #updatePrestar}
     */
    public WeightedAutomaton<L, S, W> prestar(WeightedAutomaton<L, S, W> weightedAutomaton) throws IllegalTransitionException {
        LinkedList<Transition<L, S>> worklist = Lists.newLinkedList(weightedAutomaton.getTransitions());
        // Iterate over a copy: the loop body changes weights stored in the automaton.
        for (Transition<L, S> initial : Sets.newHashSet(weightedAutomaton.getTransitions())) {
            weightedAutomaton.combineWeightForTransition(initial, weightedAutomaton.getOne());
        }
        for (PopRule<L, S, W> popRule : getPopRules()) {
            updatePrestar(worklist, new Transition<>(popRule.getS1(), popRule.getL1(), popRule.getS2()), popRule.getWeight(), weightedAutomaton);
        }
        while (!worklist.isEmpty()) {
            Transition<L, S> current = worklist.pop();
            for (NormalRule<L, S, W> normalRule : getNormalRulesEnding(current.getStart(), current.getLabel())) {
                Transition<L, S> derived = new Transition<>(normalRule.getS1(), normalRule.getL1(), current.getTarget());
                W weight = asWeight(normalRule.getWeight().extendWith(weightedAutomaton.getWeightFor(current)));
                updatePrestar(worklist, derived, weight, weightedAutomaton);
            }
            for (PushRule<L, S, W> pushRule : getPushRules()) {
                if (pushRule.getS2().equals(current.getStart()) && pushRule.getL2().equals(current.getLabel())) {
                    // Snapshot: updatePrestar may add transitions while we iterate.
                    for (Transition<L, S> following : Sets.newHashSet(weightedAutomaton.getTransitionsOutOf(current.getTarget()))) {
                        if (following.getLabel().equals(pushRule.getCallSite())) {
                            Transition<L, S> derived = new Transition<>(pushRule.getS1(), pushRule.getL1(), following.getTarget());
                            W weight = asWeight(pushRule.getWeight().extendWith(weightedAutomaton.getWeightFor(current)).extendWith(weightedAutomaton.getWeightFor(following)));
                            updatePrestar(worklist, derived, weight, weightedAutomaton);
                        }
                    }
                }
            }
            for (PushRule<L, S, W> pushRule : getPushRules()) {
                if (pushRule.getL2().equals(current.getLabel())) {
                    // Snapshot: updatePrestar may add transitions while we iterate.
                    for (Transition<L, S> preceding : Sets.newHashSet(weightedAutomaton.getTransitionsOutOf(pushRule.getS2()))) {
                        if (preceding.getLabel().equals(pushRule.getCallSite()) && preceding.getTarget().equals(current.getStart())) {
                            Transition<L, S> derived = new Transition<>(pushRule.getS1(), pushRule.getL1(), current.getTarget());
                            W weight = asWeight(pushRule.getWeight().extendWith(weightedAutomaton.getWeightFor(preceding)).extendWith(weightedAutomaton.getWeightFor(current)));
                            updatePrestar(worklist, derived, weight, weightedAutomaton);
                        }
                    }
                }
            }
        }
        return weightedAutomaton;
    }

    /**
     * Combines {@code w} into the weight of {@code transition} (adding the transition to the
     * automaton first if it is missing) and appends the transition to {@code linkedList} when
     * its weight changed.
     *
     * @param linkedList the worklist that receives {@code transition} if its weight changed
     * @param transition the transition to add or update
     * @param w the weight to combine with the transition's current weight
     * @param weightedAutomaton the automaton that is modified
     * @throws IllegalTransitionException declared for the automaton operations used here
     */
    protected void updatePrestar(LinkedList<Transition<L, S>> linkedList, Transition<L, S> transition, W w, WeightedAutomaton<L, S, W> weightedAutomaton) throws IllegalTransitionException {
        combineAndEnqueue(transition, w, weightedAutomaton, linkedList);
    }

    /**
     * Shared implementation of {@link #updatePostStar} and {@link #updatePrestar}: ensures the
     * transition exists, combines {@code weight} with its current weight (or takes
     * {@code weight} as is when no weight is stored), and enqueues the transition if the
     * result differs from the stored weight.
     *
     * @return {@code true} if the stored weight changed
     */
    private boolean combineAndEnqueue(Transition<L, S> transition, W weight, WeightedAutomaton<L, S, W> automaton, LinkedList<Transition<L, S>> worklist) throws IllegalTransitionException {
        if (!automaton.getTransitions().contains(transition)) {
            automaton.addTransition(transition, null);
        }
        W current = automaton.getWeightFor(transition);
        W combined = current == null ? weight : asWeight(current.combineWith(weight));
        boolean changed = !combined.equals(current);
        if (changed) {
            automaton.setWeightForTransition(transition, combined);
            worklist.add(transition);
        }
        return changed;
    }

    /**
     * Narrows the result of a {@link Semiring} operation to {@code W}. The cast is unchecked
     * and erases to {@code Semiring} at runtime, so it never fails here; it relies on the
     * weight operations returning values of the same weight type they were invoked on.
     */
    @SuppressWarnings("unchecked")
    private W asWeight(Semiring weight) {
        return (W) weight;
    }

    /** @return a multi-line listing of the rule count followed by the normal, pop and push rules */
    @Override
    public String toString() {
        Joiner ruleJoiner = Joiner.on("\n\t\t");
        StringBuilder text = new StringBuilder();
        text.append("WPDS (#Rules: ").append(getAllRules().size()).append(")\n");
        text.append("\tNormalRules:\n\t\t").append(ruleJoiner.join(this.normalRules)).append("\n");
        text.append("\tPopRules:\n\t\t").append(ruleJoiner.join(this.popRules)).append("\n");
        text.append("\tPushRules:\n\t\t").append(ruleJoiner.join(this.pushRules));
        return text.toString();
    }

    /** @return the label that {@link #poststar(WeightedAutomaton)} treats as the epsilon label */
    public abstract L epsilon();
}
