package pascal.taie.analysis.pta.toolkit.zipper;

import java.util.ArrayDeque;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import pascal.taie.analysis.graph.flowgraph.InstanceNode;
import pascal.taie.analysis.graph.flowgraph.Node;
import pascal.taie.analysis.graph.flowgraph.ObjectFlowGraph;
import pascal.taie.analysis.graph.flowgraph.VarNode;
import pascal.taie.analysis.pta.PointerAnalysisResult;
import pascal.taie.analysis.pta.toolkit.PointerAnalysisResultEx;
import pascal.taie.analysis.pta.toolkit.PointerAnalysisResultExImpl;
import pascal.taie.ir.exp.Var;
import pascal.taie.ir.stmt.New;
import pascal.taie.language.classes.JMethod;
import pascal.taie.language.type.Type;
import pascal.taie.util.MutableInt;
import pascal.taie.util.Timer;
import pascal.taie.util.collection.Maps;
import pascal.taie.util.collection.Sets;

/**
 * Selects precision-critical methods from a pointer analysis result.
 *
 * <p>The analysis builds an object allocation graph ({@link ObjectAllocationGraph}) and
 * potential context elements ({@link PotentialContextElement}) from the pointer analysis
 * result. For every object type it then builds a {@link PrecisionFlowGraph} and collects the
 * methods that (a) own a node reachable backwards from the graph's out nodes and (b) are
 * among the PCE methods of that type.
 *
 * <p>In "express" mode, a type's method set is discarded when the summed points-to set sizes
 * of its methods exceed {@code pv} times the total points-to size over all variables.
 *
 * <p>{@link #selectPrecisionCriticalMethods()} analyzes types with a parallel stream. Its
 * shared accumulators are atomic or concurrent maps, and {@code methodPts} is only read during
 * that phase.
 */
public class Zipper {

    private static final Logger logger = LogManager.getLogger(Zipper.class);

    /** The {@code pv} value used by the plain {@code "zipper-e"} option. */
    private static final float DEFAULT_PV = 0.05f;

    /** Prefix of the option form that carries an explicit pv, e.g. {@code "zipper-e=0.1"}. */
    private static final String EXPRESS_PV_PREFIX = "zipper-e=";

    private final PointerAnalysisResultEx pta;

    private final boolean isExpress;

    private final float pv;

    private final ObjectAllocationGraph oag;

    private final PotentialContextElement pce;

    private final ObjectFlowGraph ofg;

    // Per-run state, (re)initialized by selectPrecisionCriticalMethods().
    private AtomicInteger totalPFGNodes;

    private AtomicInteger totalPFGEdges;

    private Map<Type, Collection<JMethod>> pcmMap;

    private int pcmThreshold;

    private Map<JMethod, MutableInt> methodPts;

    /**
     * Parses a Zipper option string and runs the selection.
     *
     * @param ptaBase the pointer analysis result to analyze
     * @param arg     one of {@code "zipper"}, {@code "zipper-e"}, or {@code "zipper-e=<float>"}
     * @return the (unmodifiable) set of precision-critical methods
     * @throws IllegalArgumentException if {@code arg} is not one of the accepted forms or its
     *                                  pv value cannot be parsed as a float
     */
    public static Set<JMethod> run(PointerAnalysisResult ptaBase, String arg) {
        boolean isExpress;
        float pv;
        if (arg.equals("zipper")) {
            isExpress = false;
            // pv is only consulted in express mode.
            pv = 1.0f;
        } else if (arg.equals("zipper-e")) {
            isExpress = true;
            pv = DEFAULT_PV;
        } else if (arg.startsWith(EXPRESS_PV_PREFIX)) {
            isExpress = true;
            // Everything after the prefix is the value. This rejects "zipper-e=" and
            // "zipper-e=1=2" instead of indexing past the split array or ignoring the rest.
            String value = arg.substring(EXPRESS_PV_PREFIX.length());
            try {
                pv = Float.parseFloat(value);
            } catch (NumberFormatException e) {
                throw new IllegalArgumentException("Illegal Zipper argument: " + arg, e);
            }
        } else {
            throw new IllegalArgumentException("Illegal Zipper argument: " + arg);
        }
        return new Zipper(ptaBase, isExpress, pv).selectPrecisionCriticalMethods();
    }

    /**
     * Prepares the auxiliary graphs used by the selection.
     *
     * @param ptaBase   the pointer analysis result to analyze
     * @param isExpress whether to apply the express-mode cost threshold
     * @param pv        fraction of the total points-to size used as the express-mode threshold
     */
    public Zipper(PointerAnalysisResult ptaBase, boolean isExpress, float pv) {
        this.pta = new PointerAnalysisResultExImpl(ptaBase, true);
        this.isExpress = isExpress;
        this.pv = pv;
        // The OAG and PCE read pta, so they must be built only after pta is assigned.
        this.oag = Timer.runAndCount(() -> new ObjectAllocationGraph(this.pta),
                "Building OAG", Level.INFO);
        this.pce = Timer.runAndCount(() -> new PotentialContextElement(this.pta, this.oag),
                "Building PCE", Level.INFO);
        this.ofg = ptaBase.getObjectFlowGraph();
        logger.info("{} nodes in OFG", ofg.getNodes().size());
        logger.info("{} edges in OFG",
                ofg.getNodes().stream().mapToInt(ofg::getOutDegreeOf).sum());
    }

    /**
     * Analyzes every object type and returns the union of their precision-critical methods.
     * Resets all per-run state, so repeated calls start fresh.
     *
     * @return an unmodifiable set of precision-critical methods
     */
    public Set<JMethod> selectPrecisionCriticalMethods() {
        totalPFGNodes = new AtomicInteger(0);
        totalPFGEdges = new AtomicInteger(0);
        pcmMap = Maps.newConcurrentMap(1024);
        if (isExpress) {
            initExpressThreshold();
        }
        Set<Type> objectTypes = pta.getObjectTypes();
        Timer.runAndCount(() -> objectTypes.parallelStream().forEach(this::analyze),
                "Building and analyzing PFG", Level.INFO);
        int numTypes = objectTypes.size();
        logger.info("#types: {}", numTypes);
        // Guard against division by zero when no object types exist.
        logger.info("#avg. nodes in PFG: {}",
                numTypes == 0 ? 0 : totalPFGNodes.get() / numTypes);
        logger.info("#avg. edges in PFG: {}",
                numTypes == 0 ? 0 : totalPFGEdges.get() / numTypes);
        Set<JMethod> pcms = pcmMap.values().stream()
                .flatMap(Collection::stream)
                .collect(Collectors.toUnmodifiableSet());
        logger.info("#precision-critical methods: {}", pcms.size());
        return pcms;
    }

    /**
     * Sets up express mode. Records, per method, the sum of the non-empty points-to set sizes
     * of its variables, and sets the threshold to {@code pv} times the total over all variables.
     */
    private void initExpressThreshold() {
        PointerAnalysisResult base = pta.getBase();
        methodPts = Maps.newMap(base.getCallGraph().getNumberOfMethods());
        // Accumulate in a long: the sum over all variables can exceed Integer.MAX_VALUE.
        long totalPts = 0;
        for (Var var : base.getVars()) {
            int size = base.getPointsToSet(var).size();
            if (size > 0) {
                totalPts += size;
                methodPts.computeIfAbsent(var.getMethod(), m -> new MutableInt(0)).add(size);
            }
        }
        // float * long yields a float; the float-to-int narrowing saturates instead of wrapping.
        pcmThreshold = (int) (pv * totalPts);
    }

    /**
     * Builds the precision flow graph for {@code type}, updates the node/edge counters, and
     * records the type's precision-critical methods when there are any.
     * Invoked concurrently from a parallel stream.
     */
    private void analyze(Type type) {
        PrecisionFlowGraph pfg = new PFGBuilder(pta, ofg, oag, pce, type).build();
        totalPFGNodes.addAndGet(pfg.getNumberOfNodes());
        totalPFGEdges.addAndGet(pfg.getNodes().stream().mapToInt(pfg::getOutDegreeOf).sum());
        Set<JMethod> pcms = getPrecisionCriticalMethods(pfg);
        if (!pcms.isEmpty()) {
            pcmMap.put(type, pcms);
        }
    }

    /**
     * Returns the methods owning the flow nodes of {@code pfg} that are also PCE methods of
     * the graph's type. In express mode, returns an empty set if those methods' accumulated
     * points-to sizes exceed the threshold.
     */
    private Set<JMethod> getPrecisionCriticalMethods(PrecisionFlowGraph pfg) {
        Set<JMethod> pceMethods = pce.pceMethodsOf(pfg.getType());
        Set<JMethod> pcms = getFlowNodes(pfg).stream()
                .map(Zipper::node2Method)
                .filter(Objects::nonNull)
                .filter(pceMethods::contains)
                .collect(Collectors.toUnmodifiableSet());
        if (isExpress) {
            long accPts = 0;
            for (JMethod method : pcms) {
                // Methods with no non-empty points-to variable have no entry; count them as 0
                // rather than failing with a NullPointerException.
                MutableInt pts = methodPts.get(method);
                if (pts != null) {
                    accPts += pts.intValue();
                }
            }
            if (accPts > pcmThreshold) {
                return Set.of();
            }
        }
        return pcms;
    }

    /**
     * Collects every node that reaches some out node of {@code pfg}, including the out nodes
     * themselves. The search walks predecessor edges backwards from each out node.
     */
    private static Set<Node> getFlowNodes(PrecisionFlowGraph pfg) {
        Set<Node> visited = Sets.newSet();
        // The work list is always empty when an out-node search finishes, so it can be reused.
        ArrayDeque<Node> workList = new ArrayDeque<>();
        for (VarNode outNode : pfg.getOutNodes()) {
            workList.add(outNode);
            while (!workList.isEmpty()) {
                Node node = workList.poll();
                if (visited.add(node)) {
                    for (Node pred : pfg.getPredsOf(node)) {
                        if (!visited.contains(pred)) {
                            workList.add(pred);
                        }
                    }
                }
            }
        }
        return visited;
    }

    /**
     * Maps a flow node to the method it belongs to. A variable node maps to its variable's
     * method. An instance node maps to the container method of its base object's allocation,
     * when that allocation is a {@link New} statement.
     *
     * @return the owning method, or {@code null} if none can be determined
     */
    @Nullable
    private static JMethod node2Method(Node node) {
        if (node instanceof VarNode) {
            return ((VarNode) node).getVar().getMethod();
        }
        // Other node kinds (if any reach here) have no owning method, so return null
        // instead of failing with a ClassCastException.
        if (node instanceof InstanceNode) {
            Object allocation = ((InstanceNode) node).getBase().getAllocation();
            if (allocation instanceof New) {
                return ((New) allocation).getContainer();
            }
        }
        return null;
    }
}
