package de.mirkosertic.bytecoder.relooper;

import de.mirkosertic.bytecoder.core.BytecodeOpcodeAddress;
import de.mirkosertic.bytecoder.ssa.BreakExpression;
import de.mirkosertic.bytecoder.ssa.ContinueExpression;
import de.mirkosertic.bytecoder.ssa.ControlFlowGraph;
import de.mirkosertic.bytecoder.ssa.Expression;
import de.mirkosertic.bytecoder.ssa.ExpressionList;
import de.mirkosertic.bytecoder.ssa.ExpressionListContainer;
import de.mirkosertic.bytecoder.ssa.GotoExpression;
import de.mirkosertic.bytecoder.ssa.Label;
import de.mirkosertic.bytecoder.ssa.RegionNode;
import de.mirkosertic.bytecoder.ssa.ReturnExpression;
import java.io.PrintStream;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.Stack;

/* loaded from: input_file:de/mirkosertic/bytecoder/relooper/Relooper.class */
public class Relooper {

    /**
     * Common base of all nodes in the relooped block tree.
     *
     * <p>A block is identified by the region nodes through which control enters it. It carries a
     * {@link Label} whose name is a type-specific prefix followed by the start addresses of those
     * entry nodes joined with {@code _}. Whether that label is actually referenced is tracked through
     * {@link #requireLabel()} / {@link #isLabelRequired()}.
     */
    public static abstract class Block {
        /** Region nodes through which control enters this block (kept by reference, not copied). */
        private final Set<RegionNode> entries;
        /** Label built from the prefix and the entry start addresses. */
        private final Label label;
        /** How often {@link #requireLabel()} has been called. */
        private int labelRequired;

        /**
         * @param entryNodes  entry nodes of this block; stored as-is and iterated to build the label
         * @param labelPrefix prefix of the label name, e.g. {@code "S_"}, {@code "L_"} or {@code "M_"}
         */
        protected Block(Set<RegionNode> entryNodes, String labelPrefix) {
            this.entries = entryNodes;
            StringBuilder addressSuffix = new StringBuilder();
            for (RegionNode entryNode : entryNodes) {
                boolean first = addressSuffix.length() == 0;
                if (!first) {
                    addressSuffix.append('_');
                }
                addressSuffix.append(entryNode.getStartAddress().getAddress());
            }
            this.labelRequired = 0;
            this.label = new Label(labelPrefix + addressSuffix);
        }

        /** @return {@code true} once {@link #requireLabel()} has been called at least once */
        public boolean isLabelRequired() {
            return labelRequired > 0;
        }

        /** Records that some jump refers to this block's label. */
        public void requireLabel() {
            labelRequired += 1;
        }

        /** @return the (mutable, shared) set of entry nodes passed to the constructor */
        public Set<RegionNode> entries() {
            return entries;
        }

        /** @return the block that follows this one, or {@code null} if there is none */
        public abstract Block next();

        /** @return this block's label */
        public Label label() {
            return label;
        }
    }

    /**
     * Block modelling a loop: {@link #inner()} is the loop body, {@link #next()} is the block that
     * follows the loop. Its label name is prefixed with {@code "L_"}.
     */
    public static class LoopBlock extends Block {
        /** Loop body. */
        private final Block inner;
        /** Block following the loop, may be {@code null}. */
        private final Block next;

        /**
         * @param entries entry nodes of the loop
         * @param body    the loop body
         * @param after   the block following the loop, or {@code null}
         */
        public LoopBlock(Set<RegionNode> entries, Block body, Block after) {
            super(entries, "L_");
            this.next = after;
            this.inner = body;
        }

        /** @return the loop body */
        public Block inner() {
            return inner;
        }

        @Override
        public Block next() {
            return next;
        }
    }

    /**
     * Block with several entries: {@link #handlers()} holds the sub-blocks built for the individual
     * entries, {@link #next()} the block that follows all of them. Its label name is prefixed with
     * {@code "M_"}.
     */
    public static class MultipleBlock extends Block {
        /** Sub-blocks, as built by the relooper for each entry node. */
        private final Set<Block> handlers;
        /** Block following the handlers, may be {@code null}. */
        private final Block next;

        /**
         * @param entries  entry nodes of this block
         * @param branches the handler blocks (stored by reference)
         * @param after    the block following the handlers, or {@code null}
         */
        public MultipleBlock(Set<RegionNode> entries, Set<Block> branches, Block after) {
            super(entries, "M_");
            this.next = after;
            this.handlers = branches;
        }

        /** @return the handler blocks */
        public Set<Block> handlers() {
            return handlers;
        }

        @Override
        public Block next() {
            return next;
        }
    }

    /**
     * Leaf block wrapping a single region node ({@link #internalLabel()}) whose expressions form the
     * block's body, optionally followed by {@link #next()}. Its label name is prefixed with
     * {@code "S_"}.
     *
     * <p>The block also records, per jump-target start address, which labeled break/continue a jump
     * from this block resolves to (see {@link #registerJump} / {@link #jumpTo}).
     */
    public static class SimpleBlock extends Block {
        /** The region node whose expressions make up this block. */
        private final RegionNode internalLabel;
        /** Block following this one, may be {@code null}. */
        private final Block next;
        /** Resolved jumps, keyed by the start address of the jump target. */
        private final Map<BytecodeOpcodeAddress, Jump> knownJumps;

        /** A resolved jump: a break or continue to a specific block label. */
        public static class Jump {
            private final JumpType type;
            private final Label label;

            /**
             * @param jumpType whether the jump breaks out of or continues at {@code label}
             * @param label    label of the block the jump refers to
             */
            public Jump(JumpType jumpType, Label label) {
                this.type = jumpType;
                this.label = label;
            }

            /** @return the kind of jump */
            public JumpType getType() {
                return this.type;
            }

            /** @return the label of the block the jump refers to */
            public Label getLabel() {
                return this.label;
            }
        }

        /** Kind of a resolved {@link Jump}. */
        public enum JumpType {
            BREAK,
            CONTINUE
        }

        /**
         * @param set        entry nodes of this block
         * @param regionNode region node whose expressions form the block body
         * @param block      block following this one, or {@code null}
         */
        public SimpleBlock(Set<RegionNode> set, RegionNode regionNode, Block block) {
            super(set, "S_");
            this.internalLabel = regionNode;
            this.next = block;
            // Typed instantiation instead of the former raw HashMap (unchecked conversion).
            this.knownJumps = new HashMap<>();
        }

        /**
         * Records how a jump to {@code bytecodeOpcodeAddress} is resolved; a later registration for the
         * same address replaces the earlier one.
         */
        public void registerJump(BytecodeOpcodeAddress bytecodeOpcodeAddress, Jump jump) {
            this.knownJumps.put(bytecodeOpcodeAddress, jump);
        }

        /** @return the registered jump for the given target address, or {@code null} if none */
        public Jump jumpTo(BytecodeOpcodeAddress bytecodeOpcodeAddress) {
            return this.knownJumps.get(bytecodeOpcodeAddress);
        }

        /** @return the region node whose expressions form this block */
        public RegionNode internalLabel() {
            return this.internalLabel;
        }

        @Override
        public Block next() {
            return this.next;
        }
    }

    /**
     * Builds the block tree for the given control flow graph and rewrites all gotos in it into
     * labeled breaks/continues.
     *
     * <p>The tree is rooted at the graph's start node, with the nodes it dominates as the candidate
     * set.
     *
     * @param controlFlowGraph the graph to reloop
     * @return the root block
     * @throws IllegalStateException if a goto cannot be mapped onto the built structure
     */
    public Block reloop(ControlFlowGraph controlFlowGraph) {
        RegionNode startNode = controlFlowGraph.startNode();
        Set<RegionNode> entries = new HashSet<>();
        entries.add(startNode);
        Block root = reloop(entries, startNode.dominatedNodes());
        replaceGotosIn(root);
        return root;
    }

    /**
     * Entry point of the goto rewriting pass: walks the whole tree below {@code block}, starting with
     * an empty stack of enclosing blocks.
     */
    private void replaceGotosIn(Block block) {
        Stack<Block> enclosingBlocks = new Stack<>();
        replaceGotosIn(enclosingBlocks, block);
    }

    /**
     * Recursively walks the block tree. For every {@link SimpleBlock} it rewrites the gotos of its
     * region node into breaks/continues and registers the resolved jumps for the node's successors.
     *
     * @param stack blocks enclosing {@code block}, innermost on top; restored before returning
     * @param block block to process; {@code null} is ignored
     * @throws IllegalStateException if {@code block} is of an unknown type or a goto cannot be resolved
     */
    private void replaceGotosIn(Stack<Block> stack, Block block) {
        if (block == null) {
            return;
        }
        if (block instanceof SimpleBlock) {
            replaceGotosInSimpleBlock(stack, (SimpleBlock) block);
            return;
        }
        if (block instanceof LoopBlock) {
            LoopBlock loopBlock = (LoopBlock) block;
            stack.push(loopBlock);
            replaceGotosIn(stack, loopBlock.inner());
            replaceGotosIn(stack, loopBlock.next());
            stack.pop();
            return;
        }
        if (block instanceof MultipleBlock) {
            MultipleBlock multipleBlock = (MultipleBlock) block;
            stack.push(multipleBlock);
            for (Block handler : multipleBlock.handlers()) {
                replaceGotosIn(stack, handler);
            }
            replaceGotosIn(stack, multipleBlock.next());
            stack.pop();
            return;
        }
        throw new IllegalStateException("Don't know how to handle " + block);
    }

    /**
     * Handles one simple block.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Rewrite its gotos, with the block itself on the stack.</li>
     *   <li>Recurse into its follow-up block.</li>
     *   <li>Silence a trailing break to its own label.</li>
     *   <li>Register the jumps to its region node's successors.</li>
     * </ol>
     */
    private void replaceGotosInSimpleBlock(Stack<Block> stack, SimpleBlock simpleBlock) {
        RegionNode node = simpleBlock.internalLabel();

        stack.push(simpleBlock);
        replaceGotosIn(stack, simpleBlock, node, node.getExpressions());
        replaceGotosIn(stack, simpleBlock.next());
        stack.pop();

        // A trailing break that targets this very block's label is marked silent
        // (presumably so no explicit break statement is emitted for it — confirm in the backends).
        Expression lastExpression = node.getExpressions().lastExpression();
        if (lastExpression instanceof BreakExpression) {
            BreakExpression breakExpression = (BreakExpression) lastExpression;
            if (Objects.equals(breakExpression.blockToBreak().name(), simpleBlock.label().name())) {
                breakExpression.silent();
            }
        }

        // NOTE(review): simpleBlock has already been popped here, so unlike the goto rewriting above it
        // is not itself a candidate target during registration. This mirrors the original ordering;
        // confirm it is intended.
        registerJumps(stack, simpleBlock, node);
    }

    /**
     * Records on {@code simpleBlock} how each successor of {@code node} is reached.
     *
     * <ul>
     *   <li>NORMAL edges search the stack from the innermost block outwards. They resolve to a BREAK of
     *       the first block whose follow-up is entered by the target, or a CONTINUE of the first block
     *       that is itself entered by the target.</li>
     *   <li>Other edges resolve to a CONTINUE of the first block, in {@link Stack} iteration order
     *       (bottom to top, i.e. outermost first), that is entered by the target.</li>
     * </ul>
     *
     * <p>A successor without a matching block is left unregistered. Previously this case looped
     * forever.
     */
    private void registerJumps(Stack<Block> stack, SimpleBlock simpleBlock, RegionNode node) {
        for (Map.Entry<RegionNode.Edge, RegionNode> entry : node.getSuccessors().entrySet()) {
            RegionNode target = entry.getValue();
            if (entry.getKey().getType() == RegionNode.EdgeType.NORMAL) {
                for (int i = stack.size() - 1; i >= 0; i--) {
                    Block candidate = stack.get(i);
                    Block follower = candidate.next();
                    if (follower != null && follower.entries().contains(target)) {
                        candidate.requireLabel();
                        simpleBlock.registerJump(target.getStartAddress(),
                                new SimpleBlock.Jump(SimpleBlock.JumpType.BREAK, candidate.label()));
                        break;
                    }
                    if (candidate.entries().contains(target)) {
                        candidate.requireLabel();
                        simpleBlock.registerJump(target.getStartAddress(),
                                new SimpleBlock.Jump(SimpleBlock.JumpType.CONTINUE, candidate.label()));
                        break;
                    }
                }
            } else {
                for (Block candidate : stack) {
                    if (candidate.entries().contains(target)) {
                        candidate.requireLabel();
                        simpleBlock.registerJump(target.getStartAddress(),
                                new SimpleBlock.Jump(SimpleBlock.JumpType.CONTINUE, candidate.label()));
                        break;
                    }
                }
            }
        }
    }

    /**
     * Replaces every {@link GotoExpression} in {@code expressionList}, including those in nested lists
     * of {@link ExpressionListContainer}s, with a labeled break or continue.
     *
     * @param stack          enclosing blocks, innermost on top ({@code simpleBlock} is on top when
     *                       called from the block-tree walk)
     * @param simpleBlock    the simple block owning the expressions (used in error messages)
     * @param regionNode     the region node whose successors are the legal goto targets
     * @param expressionList the list to rewrite in place
     * @throws IllegalStateException if a goto target is not a successor of {@code regionNode}, or no
     *                               enclosing block can be broken out of / continued at
     */
    private void replaceGotosIn(Stack<Block> stack, SimpleBlock simpleBlock, RegionNode regionNode, ExpressionList expressionList) {
        for (Expression expression : expressionList.toList()) {
            if (expression instanceof ExpressionListContainer) {
                for (ExpressionList nested : ((ExpressionListContainer) expression).getExpressionLists()) {
                    replaceGotosIn(stack, simpleBlock, regionNode, nested);
                }
            }
            if (expression instanceof GotoExpression) {
                replaceGoto(stack, simpleBlock, regionNode, expressionList, (GotoExpression) expression);
            }
        }
    }

    /**
     * Rewrites a single goto.
     *
     * <ul>
     *   <li>A NORMAL edge becomes a break of the innermost block whose follow-up is entered by the
     *       target, or a continue of the innermost block entered by the target.</li>
     *   <li>Any other edge becomes a continue of the first block in {@link Stack} iteration order
     *       (outermost first) that is entered by the target.</li>
     * </ul>
     */
    private void replaceGoto(Stack<Block> stack, SimpleBlock simpleBlock, RegionNode regionNode,
            ExpressionList expressionList, GotoExpression gotoExpression) {
        boolean successorFound = false;
        for (Map.Entry<RegionNode.Edge, RegionNode> entry : regionNode.getSuccessors().entrySet()) {
            RegionNode target = entry.getValue();
            if (!Objects.equals(target.getStartAddress(), gotoExpression.getJumpTarget())) {
                continue;
            }
            successorFound = true;
            boolean replaced = false;
            if (entry.getKey().getType() == RegionNode.EdgeType.NORMAL) {
                for (int i = stack.size() - 1; i >= 0; i--) {
                    Block candidate = stack.get(i);
                    Block follower = candidate.next();
                    if (follower != null && follower.entries().contains(target)) {
                        candidate.requireLabel();
                        BreakExpression breakExpression = new BreakExpression(candidate.label(), target.getStartAddress());
                        expressionList.replace(gotoExpression, breakExpression);
                        if (follower instanceof SimpleBlock && follower.entries().size() == 1) {
                            breakExpression.noSetRequired();
                        }
                        replaced = true;
                        // Previously missing: without this break the loop revisited the same block forever.
                        break;
                    }
                    if (candidate.entries().contains(target)) {
                        candidate.requireLabel();
                        expressionList.replace(gotoExpression, new ContinueExpression(candidate.label(), target.getStartAddress()));
                        replaced = true;
                        break;
                    }
                }
                if (!replaced) {
                    throw new IllegalStateException("Failed to jump to " + target.getStartAddress().getAddress() + " from " + simpleBlock.label().name() + " : no matching entry found!");
                }
            } else {
                for (Block candidate : stack) {
                    if (candidate.entries().contains(target)) {
                        candidate.requireLabel();
                        expressionList.replace(gotoExpression, new ContinueExpression(candidate.label(), target.getStartAddress()));
                        replaced = true;
                        break;
                    }
                }
                if (!replaced) {
                    throw new IllegalStateException("No back edge target found for " + target.getStartAddress().getAddress());
                }
            }
        }
        if (!successorFound) {
            throw new IllegalStateException("No GOTO possible for " + gotoExpression.getJumpTarget().getAddress() + " in label " + simpleBlock.label().name());
        }
    }

    /**
     * Recursively builds the block tree for the given entry nodes.
     *
     * <ul>
     *   <li>A single entry that is not a jump target among the nodes of {@code set2} (any edge type)
     *       becomes a {@link SimpleBlock}.</li>
     *   <li>A single entry that is such a jump target becomes a {@link LoopBlock}. Its body is a simple
     *       block for the entry over the entry's dominated nodes. Its follow-up is relooped from the
     *       nodes of {@code set2} that the entry does not dominate; the follow-up's entries are those of
     *       them forward reachable from the entry.</li>
     *   <li>Several entries become a {@link MultipleBlock}. It has one handler per entry, relooped over
     *       that entry's dominated nodes. Its follow-up is relooped from the nodes of {@code set2}
     *       dominated by none of the entries; the follow-up's entries are those of them that are NORMAL
     *       successors of any dominated node.</li>
     * </ul>
     *
     * @param set  entry nodes; an empty set yields {@code null}
     * @param set2 candidate nodes still to be placed in this part of the tree (presumably the nodes
     *             dominated by the enclosing entry — confirm against callers); not modified
     * @return the built block, or {@code null} if {@code set} is empty
     */
    private Block reloop(Set<RegionNode> set, Set<RegionNode> set2) {
        if (set.isEmpty()) {
            return null;
        }
        Set<RegionNode> jumpTargets = jumpTargetsOf(set2);
        if (set.size() == 1) {
            RegionNode entry = set.iterator().next();
            if (!jumpTargets.contains(entry)) {
                return createSimpleBlock(set, set2, entry);
            }
            return createLoopBlock(set, set2, entry);
        }

        // set.size() >= 2 from here on: the empty and single-entry cases have all returned above.
        Set<RegionNode> remaining = new HashSet<>(set2);
        Set<Block> handlers = new HashSet<>();
        Set<RegionNode> forwardTargets = new HashSet<>();
        for (RegionNode entry : set) {
            Set<RegionNode> dominated = entry.dominatedNodes();
            remaining.removeAll(dominated);
            Set<RegionNode> handlerEntries = new HashSet<>();
            handlerEntries.add(entry);
            handlers.add(reloop(handlerEntries, dominated));
            forwardTargets.addAll(allForwardJumpTargetsOf(dominated));
        }
        // Only targets outside every handler's dominated set can start the follow-up block.
        forwardTargets.retainAll(remaining);
        return new MultipleBlock(set, handlers, reloop(forwardTargets, remaining));
    }

    /**
     * Builds the {@link LoopBlock} for a single entry {@code loopEntry} that is a jump target within
     * {@code candidates}.
     */
    private Block createLoopBlock(Set<RegionNode> entries, Set<RegionNode> candidates, RegionNode loopEntry) {
        Set<RegionNode> dominated = loopEntry.dominatedNodes();
        Set<RegionNode> remaining = new HashSet<>(candidates);
        remaining.removeAll(dominated);
        Set<RegionNode> nextEntries = new HashSet<>();
        for (RegionNode reachable : loopEntry.forwardReachableNodes()) {
            if (remaining.contains(reachable)) {
                nextEntries.add(reachable);
            }
        }
        return new LoopBlock(entries, createSimpleBlock(entries, dominated, loopEntry), reloop(nextEntries, remaining));
    }

    /**
     * Creates a {@link SimpleBlock} for {@code regionNode}.
     *
     * <p>The follow-up block's entries are the node's NORMAL successors that it dominates. The
     * follow-up is relooped over {@code set2} without {@code regionNode}.
     *
     * @param set        entry nodes of the new block
     * @param set2       candidate nodes; copied, not modified
     * @param regionNode the node whose expressions form the block body
     * @return the new simple block
     */
    private Block createSimpleBlock(Set<RegionNode> set, Set<RegionNode> set2, RegionNode regionNode) {
        Set<RegionNode> dominatedNodes = regionNode.dominatedNodes();
        Set<RegionNode> nextEntries = new HashSet<>();
        for (Map.Entry<RegionNode.Edge, RegionNode> successor : regionNode.getSuccessors().entrySet()) {
            RegionNode target = successor.getValue();
            if (successor.getKey().getType() == RegionNode.EdgeType.NORMAL && dominatedNodes.contains(target)) {
                nextEntries.add(target);
            }
        }
        Set<RegionNode> remaining = new HashSet<>(set2);
        remaining.remove(regionNode);
        return new SimpleBlock(set, regionNode, reloop(nextEntries, remaining));
    }

    /**
     * Returns the nodes of {@code collection} that are the target of an edge (of any type) from some
     * node in {@code collection}.
     *
     * @param collection the nodes to inspect; not modified
     * @return a new mutable set of jump targets inside {@code collection}
     */
    private Set<RegionNode> jumpTargetsOf(Collection<RegionNode> collection) {
        Set<RegionNode> targets = new HashSet<>();
        for (RegionNode node : collection) {
            for (RegionNode successor : node.getSuccessors().values()) {
                if (collection.contains(successor)) {
                    targets.add(successor);
                }
            }
        }
        return targets;
    }

    /**
     * Returns every target of a NORMAL edge leaving any node of {@code collection}, whether or not the
     * target itself is in {@code collection}.
     *
     * @param collection the source nodes; not modified
     * @return a new mutable set of NORMAL-edge targets
     */
    private Set<RegionNode> allForwardJumpTargetsOf(Collection<RegionNode> collection) {
        Set<RegionNode> targets = new HashSet<>();
        for (RegionNode node : collection) {
            for (Map.Entry<RegionNode.Edge, RegionNode> successor : node.getSuccessors().entrySet()) {
                if (successor.getKey().getType() == RegionNode.EdgeType.NORMAL) {
                    targets.add(successor.getValue());
                }
            }
        }
        return targets;
    }

    /**
     * Prints an indented outline of the block tree rooted at {@code block}.
     *
     * @param printStream destination of the output
     * @param block       root of the tree; {@code null} prints a {@code NULL} marker
     * @throws IllegalStateException if an unknown block type or a leftover goto is encountered
     */
    public void debugPrint(PrintStream printStream, Block block) {
        final int rootDepth = 0;
        debugPrint(printStream, block, rootDepth);
    }

    /**
     * Prints {@code block} at indentation {@code i}, followed by its children at {@code i + 1}.
     */
    private void debugPrint(PrintStream printStream, Block block, int i) {
        printInset(printStream, i);
        if (block == null) {
            printStream.println(" NULL");
            return;
        }
        int childDepth = i + 1;
        if (block instanceof SimpleBlock) {
            SimpleBlock simple = (SimpleBlock) block;
            printStream.println("SimpleBlock " + simple.label().name());
            debugPrint(printStream, childDepth, simple.internalLabel().getExpressions());
            debugPrint(printStream, simple.next(), childDepth);
        } else if (block instanceof LoopBlock) {
            LoopBlock loop = (LoopBlock) block;
            printStream.println("Loop " + loop.label().name());
            debugPrint(printStream, loop.inner(), childDepth);
            debugPrint(printStream, loop.next(), childDepth);
        } else if (block instanceof MultipleBlock) {
            MultipleBlock multiple = (MultipleBlock) block;
            printStream.println("Multiple " + multiple.label().name());
            for (Block handler : multiple.handlers()) {
                debugPrint(printStream, handler, childDepth);
            }
            debugPrint(printStream, multiple.next, childDepth);
        } else {
            throw new IllegalStateException("No handler for " + block);
        }
    }

    /** Writes {@code i} single spaces to {@code printStream}; nothing for {@code i <= 0}. */
    private void printInset(PrintStream printStream, int i) {
        int remaining = i;
        while (remaining > 0) {
            printStream.print(" ");
            remaining--;
        }
    }

    /**
     * Prints the control-flow relevant expressions (break, continue, return) of {@code expressionList}
     * at indentation {@code i}.
     *
     * <p>Nested lists of {@link ExpressionListContainer}s are printed one level deeper. This mirrors
     * the goto rewriting pass, which also descends into them, so nested jumps become visible and
     * leftover nested gotos are detected too.
     *
     * @throws IllegalStateException if a {@link GotoExpression} is still present
     */
    private void debugPrint(PrintStream printStream, int i, ExpressionList expressionList) {
        for (Expression expression : expressionList.toList()) {
            if (expression instanceof BreakExpression) {
                BreakExpression breakExpression = (BreakExpression) expression;
                printInset(printStream, i);
                printStream.println("Break " + breakExpression.blockToBreak().name() + " and jump to " + breakExpression.jumpTarget().getAddress());
            } else if (expression instanceof ContinueExpression) {
                printInset(printStream, i);
                printStream.println("Continue at " + ((ContinueExpression) expression).labelToReturnTo().name());
            } else if (expression instanceof ReturnExpression) {
                printInset(printStream, i);
                printStream.println("Return");
            } else if (expression instanceof GotoExpression) {
                throw new IllegalStateException("Goto should have been removed!");
            }
            if (expression instanceof ExpressionListContainer) {
                for (ExpressionList nested : ((ExpressionListContainer) expression).getExpressionLists()) {
                    debugPrint(printStream, i + 1, nested);
                }
            }
        }
    }
}
