package pascal.taie.analysis.pta.toolkit.scaler;

import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import pascal.taie.analysis.pta.PointerAnalysisResult;
import pascal.taie.analysis.pta.core.heap.Obj;
import pascal.taie.analysis.pta.toolkit.PointerAnalysisResultEx;
import pascal.taie.analysis.pta.toolkit.PointerAnalysisResultExImpl;
import pascal.taie.analysis.pta.toolkit.util.OAGs;
import pascal.taie.ir.exp.Var;
import pascal.taie.language.classes.JMethod;
import pascal.taie.language.type.NullType;
import pascal.taie.language.type.ReferenceType;
import pascal.taie.language.type.Type;
import pascal.taie.util.collection.Maps;
import pascal.taie.util.graph.Graph;

/**
 * Selects a context-sensitivity variant for every reachable instance method,
 * given the result of a pre-analysis.
 *
 * <p>For a method {@code m} and a {@link ContextComputer} {@code cc}, the
 * <em>weight</em> of {@code m} is {@code cc.contextNumberOf(m)} multiplied by
 * the total context-insensitive points-to size of {@code m}'s reference-typed
 * variables. Scaler binary-searches a per-method weight threshold {@code st}
 * such that the accumulated weight over all (non-special) methods stays
 * within the total scalability threshold (TST). Each method then receives the
 * first candidate computer whose weight does not exceed {@code st}, or the
 * context-insensitive bottom line if none does.
 */
public class Scaler {

    private static final Logger logger = LogManager.getLogger(Scaler.class);

    /** TST used when the argument is just {@code "scaler"}. */
    private static final long DEFAULT_TST = 30000000;

    private static final String ARG_PREFIX = "scaler=";

    private final PointerAnalysisResultEx pta;

    /** Total scalability threshold. */
    private final long tst;

    /** Fallback computer used when no candidate in {@link #ctxComputers} fits. */
    private final ContextComputer bottomLine;

    /**
     * Candidate computers, tried in list order (2-object, 2-type, 1-type).
     * The first entry is also used for special methods and to bound the
     * binary search.
     */
    private final List<ContextComputer> ctxComputers;

    /** Cache of the context-insensitive points-to size of each method. */
    private final Map<JMethod, Integer> ptsSize = Maps.newMap();

    /**
     * Parses the Scaler option and runs Scaler on the given analysis result.
     *
     * @param pointerAnalysisResult the pre-analysis result
     * @param str either {@code "scaler"} (use {@link #DEFAULT_TST}) or
     *            {@code "scaler=<TST>"} with an explicit threshold
     * @return the selected variant name for each reachable non-static method
     * @throws IllegalArgumentException if {@code str} is not one of the
     *         accepted forms or the threshold is not a valid {@code long}
     */
    public static Map<JMethod, String> run(PointerAnalysisResult pointerAnalysisResult, String str) {
        long tst;
        if (str.equals("scaler")) {
            tst = DEFAULT_TST;
        } else if (str.startsWith(ARG_PREFIX)) {
            String value = str.substring(ARG_PREFIX.length());
            try {
                // TST is a long; parse as long so large thresholds are accepted.
                tst = Long.parseLong(value);
            } catch (NumberFormatException e) {
                throw new IllegalArgumentException("Illegal Scaler argument: " + str, e);
            }
        } else {
            throw new IllegalArgumentException("Illegal Scaler argument: " + str);
        }
        return new Scaler(pointerAnalysisResult, tst).selectContext();
    }

    /**
     * @param pointerAnalysisResult the pre-analysis result
     * @param tst the total scalability threshold
     */
    public Scaler(PointerAnalysisResult pointerAnalysisResult, long tst) {
        this.pta = new PointerAnalysisResultExImpl(pointerAnalysisResult, true);
        this.tst = tst;
        this.bottomLine = new _InsensitiveContextComputer(this.pta);
        Graph<Obj> oag = OAGs.build(this.pta);
        this.ctxComputers = List.of(
                new _2ObjContextComputer(this.pta, oag),
                new _2TypeContextComputer(this.pta, oag),
                new _1TypeContextComputer(this.pta));
    }

    /**
     * Computes the variant for every reachable non-static method.
     *
     * @return a map from each reachable non-static method to its variant name
     */
    public Map<JMethod, String> selectContext() {
        logger.info("Scaler TST: {}", tst);
        Set<JMethod> methods = pta.getBase()
                .getCallGraph()
                .reachableMethods()
                .filter(m -> !m.isStatic())
                .collect(Collectors.toUnmodifiableSet());
        long st = binarySearch(methods, tst);
        Map<JMethod, String> csMap = methods.stream()
                .collect(Collectors.toMap(m -> m, m -> selectVariantFor(m, st)));
        logCSMap(csMap);
        return csMap;
    }

    /**
     * Binary-searches the largest per-method threshold whose accumulated
     * weight does not exceed {@code tst}.
     *
     * @return the selected threshold, or 0 if {@code methods} is empty
     */
    private long binarySearch(Set<JMethod> methods, long tst) {
        // Upper bound of the search: the maximum weight of any method under
        // the first candidate computer. An empty set yields 0 instead of
        // throwing NoSuchElementException.
        long end = methods.stream()
                .mapToLong(m -> getWeight(m, ctxComputers.get(0)))
                .max()
                .orElse(0);
        long start = 0;
        long st = 0;
        while (start <= end) {
            long mid = start + (end - start) / 2; // avoids overflow of start + end
            long total = getTotalAccumulativePTS(methods, mid);
            if (total < tst) {
                st = mid;
                start = mid + 1;
            } else if (total > tst) {
                end = mid - 1;
            } else {
                st = mid;
                break;
            }
        }
        return st;
    }

    /**
     * Sums the weights of all non-special methods under the computers
     * selected for threshold {@code st}.
     */
    private long getTotalAccumulativePTS(Set<JMethod> methods, long st) {
        long total = 0;
        for (JMethod method : methods) {
            if (!isSpecialMethod(method)) {
                total += getWeight(method, selectContextComputer(method, st));
            }
        }
        return total;
    }

    /**
     * Chooses the computer for {@code method} under threshold {@code st}.
     *
     * <p>Special methods always get the first candidate. Other methods get
     * the first candidate whose weight is at most {@code st}, or
     * {@link #bottomLine} if none qualifies.
     */
    private ContextComputer selectContextComputer(JMethod method, long st) {
        if (isSpecialMethod(method)) {
            return ctxComputers.get(0);
        }
        for (ContextComputer cc : ctxComputers) {
            if (getWeight(method, cc) <= st) {
                return cc;
            }
        }
        return bottomLine;
    }

    /** Returns whether the method's declaring class is in a {@code java.util.} package. */
    private static boolean isSpecialMethod(JMethod method) {
        return method.getDeclaringClass().getName().startsWith("java.util.");
    }

    /** Returns the number of contexts under {@code cc} times the CI points-to size. */
    private long getWeight(JMethod method, ContextComputer cc) {
        // Widen to long before multiplying; int * int would overflow silently.
        return (long) cc.contextNumberOf(method) * getCIPTSSizeOf(method);
    }

    /**
     * Returns the summed points-to-set size of the method's non-null
     * reference-typed variables. The result is cached per method.
     */
    private int getCIPTSSizeOf(JMethod method) {
        return ptsSize.computeIfAbsent(method, m -> m.getIR()
                .getVars()
                .stream()
                .filter(Scaler::isConcerned)
                .mapToInt(v -> pta.getBase().getPointsToSet(v).size())
                .sum());
    }

    /** Returns whether {@code var} has a reference type other than the null type. */
    private static boolean isConcerned(Var var) {
        Type type = var.getType();
        return type instanceof ReferenceType && !(type instanceof NullType);
    }

    /** Selects the computer for {@code method} and returns its variant name. */
    private String selectVariantFor(JMethod method, long st) {
        ContextComputer cc = selectContextComputer(method, st);
        logger.debug("{}, {}, {}", method, cc.getVariantName(), cc.contextNumberOf(method));
        return cc.getVariantName();
    }

    /** At debug level, logs every (method, variant) entry sorted by variant, then method. */
    private static void logCSMap(Map<JMethod, String> csMap) {
        if (logger.isDebugEnabled()) {
            csMap.entrySet()
                    .stream()
                    .sorted(Map.Entry.<JMethod, String>comparingByValue()
                            .thenComparing(e -> e.getKey().toString()))
                    .forEach(logger::debug);
        }
    }
}
