package io.moderne.ai.research;

import io.moderne.ai.AgentGenerativeModelClient;
import io.moderne.ai.EmbeddingModelClient;
import io.moderne.ai.RelatedModelClient;
import io.moderne.ai.table.CodeSearch;
import io.moderne.ai.table.EmbeddingPerformance;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import org.openrewrite.ExecutionContext;
import org.openrewrite.Option;
import org.openrewrite.Preconditions;
import org.openrewrite.ScanningRecipe;
import org.openrewrite.SourceFile;
import org.openrewrite.TreeVisitor;
import org.openrewrite.internal.lang.Nullable;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.search.UsesMethod;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaSourceFile;
import org.openrewrite.java.tree.JavaType;
import org.openrewrite.marker.SearchResult;

/* loaded from: input_file:io/moderne/ai/research/FindCodeThatResembles.class */
public final class FindCodeThatResembles extends ScanningRecipe<Accumulator> {

    @Option(displayName = "Resembles", description = "The text, either a natural language description or a code sample, that you are looking for.", example = "HTTP request with Content-Type application/json")
    private final String resembles;

    @Option(displayName = "top k methods", description = "Since AI based matching has a higher latency than rules based matching, we do a first pass to find the top k methods using embeddings. To narrow the scope, you can specify the top k methods as method filters.", example = "1000")
    private final int k;
    private final transient CodeSearch codeSearchTable = new CodeSearch(this);
    private final transient EmbeddingPerformance performance = new EmbeddingPerformance(this);

    /* loaded from: input_file:io/moderne/ai/research/FindCodeThatResembles$Accumulator.class */
    public static final class Accumulator {
        private final int k;
        private final PriorityQueue<MethodSignatureWithDistance> methodSignaturesQueue = new PriorityQueue<>(Comparator.comparingDouble((v0) -> {
            return v0.getDistance();
        }));
        private final EmbeddingModelClient embeddingModelClient = EmbeddingModelClient.getInstance();

        @Nullable
        private List<MethodMatcher> topMethodPatterns;

        public void add(String str, String str2, String str3) {
            Iterator<MethodSignatureWithDistance> it = this.methodSignaturesQueue.iterator();
            while (it.hasNext()) {
                if (it.next().methodPattern.equals(str2)) {
                    return;
                }
            }
            this.methodSignaturesQueue.add(new MethodSignatureWithDistance(str, str2, (float) this.embeddingModelClient.getDistance(str3, str)));
        }

        public List<MethodMatcher> getMethodMatchersTopK() {
            if (this.topMethodPatterns != null) {
                return this.topMethodPatterns;
            }
            this.topMethodPatterns = new ArrayList(this.k);
            for (int i = 0; i < this.k && !this.methodSignaturesQueue.isEmpty(); i++) {
                String methodPattern = this.methodSignaturesQueue.poll().getMethodPattern();
                if (!methodPattern.contains("<constructor>")) {
                    methodPattern = methodPattern.replaceAll("<[^>]*>", "");
                }
                this.topMethodPatterns.add(new MethodMatcher(methodPattern, true));
            }
            return this.topMethodPatterns;
        }

        public int getK() {
            return this.k;
        }

        public PriorityQueue<MethodSignatureWithDistance> getMethodSignaturesQueue() {
            return this.methodSignaturesQueue;
        }

        public EmbeddingModelClient getEmbeddingModelClient() {
            return this.embeddingModelClient;
        }

        public List<MethodMatcher> getTopMethodPatterns() {
            return this.topMethodPatterns;
        }

        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof Accumulator)) {
                return false;
            }
            Accumulator accumulator = (Accumulator) obj;
            if (getK() != accumulator.getK()) {
                return false;
            }
            PriorityQueue<MethodSignatureWithDistance> methodSignaturesQueue = getMethodSignaturesQueue();
            PriorityQueue<MethodSignatureWithDistance> methodSignaturesQueue2 = accumulator.getMethodSignaturesQueue();
            if (methodSignaturesQueue == null) {
                if (methodSignaturesQueue2 != null) {
                    return false;
                }
            } else if (!methodSignaturesQueue.equals(methodSignaturesQueue2)) {
                return false;
            }
            EmbeddingModelClient embeddingModelClient = getEmbeddingModelClient();
            EmbeddingModelClient embeddingModelClient2 = accumulator.getEmbeddingModelClient();
            if (embeddingModelClient == null) {
                if (embeddingModelClient2 != null) {
                    return false;
                }
            } else if (!embeddingModelClient.equals(embeddingModelClient2)) {
                return false;
            }
            List<MethodMatcher> topMethodPatterns = getTopMethodPatterns();
            List<MethodMatcher> topMethodPatterns2 = accumulator.getTopMethodPatterns();
            return topMethodPatterns == null ? topMethodPatterns2 == null : topMethodPatterns.equals(topMethodPatterns2);
        }

        public int hashCode() {
            int k = (1 * 59) + getK();
            PriorityQueue<MethodSignatureWithDistance> methodSignaturesQueue = getMethodSignaturesQueue();
            int hashCode = (k * 59) + (methodSignaturesQueue == null ? 43 : methodSignaturesQueue.hashCode());
            EmbeddingModelClient embeddingModelClient = getEmbeddingModelClient();
            int hashCode2 = (hashCode * 59) + (embeddingModelClient == null ? 43 : embeddingModelClient.hashCode());
            List<MethodMatcher> topMethodPatterns = getTopMethodPatterns();
            return (hashCode2 * 59) + (topMethodPatterns == null ? 43 : topMethodPatterns.hashCode());
        }

        public String toString() {
            return "FindCodeThatResembles.Accumulator(k=" + getK() + ", methodSignaturesQueue=" + getMethodSignaturesQueue() + ", embeddingModelClient=" + getEmbeddingModelClient() + ", topMethodPatterns=" + getTopMethodPatterns() + ")";
        }

        public Accumulator(int i) {
            this.k = i;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:io/moderne/ai/research/FindCodeThatResembles$MethodSignatureWithDistance.class */
    public static final class MethodSignatureWithDistance {
        private final String methodSignature;
        private final String methodPattern;
        private final double distance;

        public MethodSignatureWithDistance(String str, String str2, double d) {
            this.methodSignature = str;
            this.methodPattern = str2;
            this.distance = d;
        }

        public String getMethodSignature() {
            return this.methodSignature;
        }

        public String getMethodPattern() {
            return this.methodPattern;
        }

        public double getDistance() {
            return this.distance;
        }

        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof MethodSignatureWithDistance)) {
                return false;
            }
            MethodSignatureWithDistance methodSignatureWithDistance = (MethodSignatureWithDistance) obj;
            if (Double.compare(getDistance(), methodSignatureWithDistance.getDistance()) != 0) {
                return false;
            }
            String methodSignature = getMethodSignature();
            String methodSignature2 = methodSignatureWithDistance.getMethodSignature();
            if (methodSignature == null) {
                if (methodSignature2 != null) {
                    return false;
                }
            } else if (!methodSignature.equals(methodSignature2)) {
                return false;
            }
            String methodPattern = getMethodPattern();
            String methodPattern2 = methodSignatureWithDistance.getMethodPattern();
            return methodPattern == null ? methodPattern2 == null : methodPattern.equals(methodPattern2);
        }

        public int hashCode() {
            long doubleToLongBits = Double.doubleToLongBits(getDistance());
            int i = (1 * 59) + ((int) ((doubleToLongBits >>> 32) ^ doubleToLongBits));
            String methodSignature = getMethodSignature();
            int hashCode = (i * 59) + (methodSignature == null ? 43 : methodSignature.hashCode());
            String methodPattern = getMethodPattern();
            return (hashCode * 59) + (methodPattern == null ? 43 : methodPattern.hashCode());
        }

        public String toString() {
            return "FindCodeThatResembles.MethodSignatureWithDistance(methodSignature=" + getMethodSignature() + ", methodPattern=" + getMethodPattern() + ", distance=" + getDistance() + ")";
        }
    }

    public String getDisplayName() {
        return "Find method invocations that resemble a pattern";
    }

    public String getDescription() {
        return "This recipe uses two phase AI approach to find a method invocation that resembles a search string.";
    }

    /* renamed from: getInitialValue, reason: merged with bridge method [inline-methods] */
    public Accumulator m8getInitialValue(ExecutionContext executionContext) {
        return new Accumulator(this.k);
    }

    public TreeVisitor<?, ExecutionContext> getScanner(final Accumulator accumulator) {
        return new JavaIsoVisitor<ExecutionContext>() { // from class: io.moderne.ai.research.FindCodeThatResembles.1
            private String extractTypeName(String str) {
                return str.replace("<.*>", "").substring(str.lastIndexOf(46) + 1);
            }

            /* renamed from: visitCompilationUnit, reason: merged with bridge method [inline-methods] */
            public J.CompilationUnit m9visitCompilationUnit(J.CompilationUnit compilationUnit, ExecutionContext executionContext) {
                Set usedMethods = compilationUnit.getTypesInUse().getUsedMethods();
                Accumulator accumulator2 = accumulator;
                usedMethods.forEach(method -> {
                    String str = extractTypeName((String) Optional.ofNullable(method.getReturnType()).map((v0) -> {
                        return v0.toString();
                    }).orElse("")) + " " + method.getName();
                    String[] strArr = new String[method.getParameterTypes().size()];
                    for (int i = 0; i < method.getParameterTypes().size(); i++) {
                        strArr[i] = extractTypeName(((JavaType) method.getParameterTypes().get(i)).toString()) + " " + ((String) method.getParameterNames().get(i));
                    }
                    accumulator2.add(str + "(" + String.join(", ", strArr) + ")", ((String) Optional.ofNullable(method.getDeclaringType()).map((v0) -> {
                        return v0.toString();
                    }).orElse("")) + " " + method.getName() + "(..)", FindCodeThatResembles.this.resembles);
                });
                return super.visitCompilationUnit(compilationUnit, executionContext);
            }
        };
    }

    public TreeVisitor<?, ExecutionContext> getVisitor(Accumulator accumulator) {
        final List<MethodMatcher> methodMatchersTopK = accumulator.getMethodMatchersTopK();
        ArrayList arrayList = new ArrayList(methodMatchersTopK.size());
        Iterator<MethodMatcher> it = methodMatchersTopK.iterator();
        while (it.hasNext()) {
            arrayList.add(new UsesMethod(it.next()));
        }
        return Preconditions.check(Preconditions.or((TreeVisitor[]) arrayList.toArray(new TreeVisitor[0])), new JavaIsoVisitor<ExecutionContext>() { // from class: io.moderne.ai.research.FindCodeThatResembles.2
            public boolean isAcceptable(SourceFile sourceFile, ExecutionContext executionContext) {
                return sourceFile instanceof J.CompilationUnit;
            }

            /* renamed from: visitCompilationUnit, reason: merged with bridge method [inline-methods] */
            public J.CompilationUnit m11visitCompilationUnit(J.CompilationUnit compilationUnit, ExecutionContext executionContext) {
                getCursor().putMessage("count", new AtomicInteger());
                getCursor().putMessage("max", new AtomicLong());
                getCursor().putMessage("histogram", new EmbeddingPerformance.Histogram());
                try {
                    J.CompilationUnit visitCompilationUnit = super.visitCompilationUnit(compilationUnit, executionContext);
                    if (((AtomicInteger) getCursor().getMessage("count", new AtomicInteger())).get() > 0) {
                        FindCodeThatResembles.this.performance.insertRow(executionContext, new EmbeddingPerformance.Row(compilationUnit.getSourcePath().toString(), ((AtomicInteger) Objects.requireNonNull((AtomicInteger) getCursor().getMessage("count"))).get(), ((EmbeddingPerformance.Histogram) Objects.requireNonNull((EmbeddingPerformance.Histogram) getCursor().getMessage("histogram"))).getBuckets(), Duration.ofNanos(((AtomicLong) Objects.requireNonNull((AtomicLong) getCursor().getMessage("max"))).get())));
                    }
                    return visitCompilationUnit;
                } catch (Throwable th) {
                    if (((AtomicInteger) getCursor().getMessage("count", new AtomicInteger())).get() > 0) {
                        FindCodeThatResembles.this.performance.insertRow(executionContext, new EmbeddingPerformance.Row(compilationUnit.getSourcePath().toString(), ((AtomicInteger) Objects.requireNonNull((AtomicInteger) getCursor().getMessage("count"))).get(), ((EmbeddingPerformance.Histogram) Objects.requireNonNull((EmbeddingPerformance.Histogram) getCursor().getMessage("histogram"))).getBuckets(), Duration.ofNanos(((AtomicLong) Objects.requireNonNull((AtomicLong) getCursor().getMessage("max"))).get())));
                    }
                    throw th;
                }
            }

            /* renamed from: visitMethodInvocation, reason: merged with bridge method [inline-methods] */
            public J.MethodInvocation m10visitMethodInvocation(J.MethodInvocation methodInvocation, ExecutionContext executionContext) {
                boolean z;
                boolean z2 = false;
                Iterator it2 = methodMatchersTopK.iterator();
                while (true) {
                    if (!it2.hasNext()) {
                        break;
                    }
                    if (((MethodMatcher) it2.next()).matches(methodInvocation)) {
                        z2 = true;
                        break;
                    }
                }
                if (!z2) {
                    return super.visitMethodInvocation(methodInvocation, executionContext);
                }
                RelatedModelClient.Relatedness relatedness = RelatedModelClient.getInstance().getRelatedness(FindCodeThatResembles.this.resembles, methodInvocation.printTrimmed(getCursor()));
                for (Duration duration : relatedness.getEmbeddingTimings()) {
                    ((AtomicInteger) Objects.requireNonNull((AtomicInteger) getCursor().getNearestMessage("count"))).incrementAndGet();
                    ((EmbeddingPerformance.Histogram) Objects.requireNonNull((EmbeddingPerformance.Histogram) getCursor().getNearestMessage("histogram"))).add(duration);
                    AtomicLong atomicLong = (AtomicLong) getCursor().getNearestMessage("max");
                    if (((AtomicLong) Objects.requireNonNull(atomicLong)).get() < duration.toNanos()) {
                        atomicLong.set(duration.toNanos());
                    }
                }
                int isRelated = relatedness.isRelated();
                boolean z3 = false;
                if (isRelated == 0) {
                    z = AgentGenerativeModelClient.getInstance().isRelated(FindCodeThatResembles.this.resembles, methodInvocation.printTrimmed(getCursor()), 0.5932d);
                    z3 = true;
                } else {
                    z = isRelated == 1;
                }
                String path = ((JavaSourceFile) getCursor().firstEnclosing(JavaSourceFile.class)).getSourcePath().toString();
                if (z || z3) {
                    FindCodeThatResembles.this.codeSearchTable.insertRow(executionContext, new CodeSearch.Row(path, methodInvocation.printTrimmed(getCursor()), FindCodeThatResembles.this.resembles, isRelated, z3 ? z ? 1 : -1 : 0));
                }
                return z ? SearchResult.found(methodInvocation) : super.visitMethodInvocation(methodInvocation, executionContext);
            }
        });
    }

    public FindCodeThatResembles(String str, int i) {
        this.resembles = str;
        this.k = i;
    }

    public String getResembles() {
        return this.resembles;
    }

    public int getK() {
        return this.k;
    }

    public CodeSearch getCodeSearchTable() {
        return this.codeSearchTable;
    }

    public EmbeddingPerformance getPerformance() {
        return this.performance;
    }

    public String toString() {
        return "FindCodeThatResembles(resembles=" + getResembles() + ", k=" + getK() + ", codeSearchTable=" + getCodeSearchTable() + ", performance=" + getPerformance() + ")";
    }

    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof FindCodeThatResembles)) {
            return false;
        }
        FindCodeThatResembles findCodeThatResembles = (FindCodeThatResembles) obj;
        if (!findCodeThatResembles.canEqual(this) || getK() != findCodeThatResembles.getK()) {
            return false;
        }
        String resembles = getResembles();
        String resembles2 = findCodeThatResembles.getResembles();
        return resembles == null ? resembles2 == null : resembles.equals(resembles2);
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof FindCodeThatResembles;
    }

    public int hashCode() {
        int k = (1 * 59) + getK();
        String resembles = getResembles();
        return (k * 59) + (resembles == null ? 43 : resembles.hashCode());
    }
}
