package de.kherud.llama;

import de.kherud.llama.InferenceParameters;
import de.kherud.llama.ModelParameters;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.function.BiConsumer;
import org.jetbrains.annotations.Nullable;

/* loaded from: input_file:de/kherud/llama/LlamaModel.class */
/**
 * Java front end for a llama model whose actual work is performed by {@code native} methods.
 *
 * <p>The static initializer runs {@code LlamaLoader.initialize()} before building the default
 * {@link ModelParameters} and {@link InferenceParameters} that the convenience overloads fall
 * back to. {@link #close()} calls the native {@code delete()}, which presumably releases the
 * native model — so prefer using instances in a try-with-resources statement.
 *
 * <p>NOTE(review): nothing in this class synchronizes access to the native context; treat an
 * instance as not thread-safe unless the native side documents otherwise.
 */
public class LlamaModel implements AutoCloseable {
    private static final ModelParameters defaultModelParams;
    private static final InferenceParameters defaultInferenceParams;

    /**
     * Opaque handle that no Java code in this class reads or writes.
     * NOTE(review): presumably populated by the native {@code loadModel} and looked up by name
     * from native code, so do not rename or retype it — confirm against the native sources.
     */
    private long ctx;

    /**
     * Set by the first {@link #close()} call so later calls do not run the native
     * {@code delete()} again.
     */
    private boolean closed;

    /**
     * Yields generated text piece by piece for the {@code generate} overloads.
     *
     * <p>None of the three fields below is updated by Java code after construction; in
     * particular nothing here ever sets {@code hasNext} to {@code false}. They are therefore
     * presumably maintained by the native {@code getNext} — TODO confirm, and keep their names
     * and types stable in case native code accesses them by name.
     *
     * <p>NOTE(review): the native start calls ({@code newAnswerIterator}/{@code
     * newInfillIterator}) receive no reference to the iterator, so generation state may live on
     * the model itself; running two iterators of the same model at once may interfere — confirm.
     */
    private final class LlamaIterator implements Iterator<String> {
        /** Whether another piece can be requested; starts {@code true}. */
        private boolean hasNext;
        /** Not read from Java; presumably native bookkeeping — confirm. */
        private long generatedCount;
        /** Not read from Java; presumably native bookkeeping — confirm. */
        private long tokenIndex;

        /** Starts a plain completion of {@code str} on the enclosing model. */
        private LlamaIterator(String str, InferenceParameters inferenceParameters) {
            this.hasNext = true;
            this.generatedCount = 0L;
            this.tokenIndex = 0L;
            LlamaModel.this.newAnswerIterator(str, inferenceParameters);
        }

        /** Starts an infill generation from {@code str} and {@code str2} on the enclosing model. */
        private LlamaIterator(String str, String str2, InferenceParameters inferenceParameters) {
            this.hasNext = true;
            this.generatedCount = 0L;
            this.tokenIndex = 0L;
            LlamaModel.this.newInfillIterator(str, str2, inferenceParameters);
        }

        @Override
        public boolean hasNext() {
            return this.hasNext;
        }

        /**
         * Fetches the next generated piece from the native side and decodes it as UTF-8.
         *
         * @throws NoSuchElementException if {@link #hasNext()} is {@code false}
         */
        @Override
        public String next() {
            if (!this.hasNext) {
                throw new NoSuchElementException();
            }
            return new String(LlamaModel.this.getNext(this), StandardCharsets.UTF_8);
        }
    }

    /**
     * Loads a model using the default model parameters.
     *
     * @param str model location handed to the native loader (presumably a file path — confirm)
     */
    public LlamaModel(String str) {
        this(str, defaultModelParams);
    }

    /**
     * Loads a model through the native {@code loadModel}.
     *
     * @param str model location handed to the native loader (presumably a file path — confirm)
     * @param modelParameters parameters forwarded unchanged to the native loader
     * @throws LlamaException as declared by the native {@code loadModel}
     */
    public LlamaModel(String str, ModelParameters modelParameters) {
        loadModel(str, modelParameters);
    }

    /**
     * Runs the native {@code getAnswer} on {@code str} with the default inference parameters.
     *
     * @param str the prompt
     * @return the native result decoded as UTF-8
     */
    public String complete(String str) {
        return complete(str, defaultInferenceParams);
    }

    /**
     * Runs the native {@code getAnswer} on {@code str} in a single call.
     *
     * @param str the prompt
     * @param inferenceParameters parameters forwarded unchanged to the native call
     * @return the native result decoded as UTF-8
     */
    public String complete(String str, InferenceParameters inferenceParameters) {
        return new String(getAnswer(str, inferenceParameters), StandardCharsets.UTF_8);
    }

    /**
     * Runs the native {@code getInfill} with the default inference parameters.
     * NOTE(review): presumably {@code str} is the text before the gap and {@code str2} the text
     * after it — confirm against the native implementation.
     *
     * @return the native result decoded as UTF-8
     */
    public String complete(String str, String str2) {
        return complete(str, str2, defaultInferenceParams);
    }

    /**
     * Runs the native {@code getInfill} in a single call.
     * NOTE(review): presumably {@code str} is the text before the gap and {@code str2} the text
     * after it — confirm against the native implementation.
     *
     * @param inferenceParameters parameters forwarded unchanged to the native call
     * @return the native result decoded as UTF-8
     */
    public String complete(String str, String str2, InferenceParameters inferenceParameters) {
        return new String(getInfill(str, str2, inferenceParameters), StandardCharsets.UTF_8);
    }

    /**
     * Streams a completion of {@code str} using the default inference parameters.
     *
     * @return an iterable whose every {@code iterator()} call starts a new native generation
     */
    public Iterable<String> generate(String str) {
        return generate(str, defaultInferenceParams);
    }

    /**
     * Streams a completion of {@code str}. Nothing is started until {@code iterator()} is called
     * on the returned iterable; each such call starts a new native generation.
     */
    public Iterable<String> generate(String str, InferenceParameters inferenceParameters) {
        return () -> new LlamaIterator(str, inferenceParameters);
    }

    /**
     * Streams an infill generation using the default inference parameters.
     *
     * @return an iterable whose every {@code iterator()} call starts a new native generation
     */
    public Iterable<String> generate(String str, String str2) {
        return generate(str, str2, defaultInferenceParams);
    }

    /**
     * Streams an infill generation. Nothing is started until {@code iterator()} is called on the
     * returned iterable; each such call starts a new native generation.
     */
    public Iterable<String> generate(String str, String str2, InferenceParameters inferenceParameters) {
        return () -> new LlamaIterator(str, str2, inferenceParameters);
    }

    /** Native embedding call; presumably returns the embedding vector for {@code str} — confirm. */
    public native float[] embed(String str);

    /** Native tokenizer call; presumably returns the token ids for {@code str} — confirm. */
    public native int[] encode(String str);

    /**
     * Converts token ids back to text via the native {@code decodeBytes}, decoding the returned
     * bytes as UTF-8.
     */
    public String decode(int[] iArr) {
        return new String(decodeBytes(iArr), StandardCharsets.UTF_8);
    }

    /**
     * Installs a native-side log callback. The argument may be {@code null}.
     * NOTE(review): presumably {@code null} restores the default logging — confirm.
     */
    public static native void setLogger(@Nullable BiConsumer<LogLevel, String> biConsumer);

    /**
     * Releases the model by calling the native {@code delete()}. Only the first call has an
     * effect; later calls return immediately, as {@link AutoCloseable#close()} recommends.
     * NOTE(review): the guard is not synchronized — do not close one instance from several
     * threads concurrently.
     */
    @Override
    public void close() {
        if (closed) {
            return;
        }
        // Mark first: if delete() fails part-way, a retry could touch half-released native state.
        closed = true;
        delete();
    }

    private native void loadModel(String str, ModelParameters modelParameters) throws LlamaException;

    // The decompiler widened the next three from private (see JADX notes in the original); they
    // stay public for binary compatibility but are internal hooks used only by LlamaIterator.
    public native void newAnswerIterator(String str, InferenceParameters inferenceParameters);

    public native void newInfillIterator(String str, String str2, InferenceParameters inferenceParameters);

    private native byte[] getAnswer(String str, InferenceParameters inferenceParameters);

    private native byte[] getInfill(String str, String str2, InferenceParameters inferenceParameters);

    public native byte[] getNext(LlamaIterator llamaIterator);

    private native byte[] decodeBytes(int[] iArr);

    private native void delete();

    static {
        // Must run before the defaults are built (presumably loads the native library — confirm).
        LlamaLoader.initialize();
        defaultModelParams = new ModelParameters.Builder().build();
        defaultInferenceParams = new InferenceParameters.Builder().build();
    }
}
