package fi.evolver.ai.spring.prompt.template;

import fi.evolver.ai.spring.chat.prompt.Message;
import fi.evolver.ai.spring.util.Json;
import freemarker.core.Environment;
import freemarker.template.TemplateDirectiveBody;
import freemarker.template.TemplateDirectiveModel;
import freemarker.template.TemplateException;
import freemarker.template.TemplateModel;
import java.io.IOException;
import java.io.Writer;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/* loaded from: input_file:fi/evolver/ai/spring/prompt/template/HistoryTag.class */
public class HistoryTag implements TemplateDirectiveModel {
    private final ThreadLocal<List<Message>> history = new ThreadLocal<>();

    /* loaded from: input_file:fi/evolver/ai/spring/prompt/template/HistoryTag$NonFailingAutoCloseable.class */
    public interface NonFailingAutoCloseable extends AutoCloseable {
        @Override // java.lang.AutoCloseable
        void close();
    }

    public NonFailingAutoCloseable setHistory(List<Message> list) {
        this.history.set(list);
        ThreadLocal<List<Message>> threadLocal = this.history;
        Objects.requireNonNull(threadLocal);
        return threadLocal::remove;
    }

    public void execute(Environment environment, Map map, TemplateModel[] templateModelArr, TemplateDirectiveBody templateDirectiveBody) throws TemplateException, IOException {
        List<Message> findMessages = findMessages(map);
        Writer out = environment.getOut();
        String str = (String) Optional.ofNullable(map.get("format")).map((v0) -> {
            return v0.toString();
        }).map((v0) -> {
            return v0.strip();
        }).map((v0) -> {
            return v0.toLowerCase();
        }).orElseThrow(() -> {
            return new IllegalArgumentException("Found <@history /> tag without format");
        });
        boolean z = -1;
        switch (str.hashCode()) {
            case 101429380:
                if (str.equals("jsonl")) {
                    z = false;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                generateJsonl(out, findMessages);
                return;
            default:
                throw new TemplateException("Unsupported format: %s".formatted(str), environment);
        }
    }

    private List<Message> findMessages(Map map) {
        int intValue = ((Integer) Optional.ofNullable(map.get("count")).map((v0) -> {
            return v0.toString();
        }).map(Integer::parseInt).orElse(10000)).intValue();
        int intValue2 = ((Integer) Optional.ofNullable(map.get("skipFirst")).map((v0) -> {
            return v0.toString();
        }).map(Integer::parseInt).orElse(0)).intValue();
        int intValue3 = ((Integer) Optional.ofNullable(map.get("skipLast")).map((v0) -> {
            return v0.toString();
        }).map(Integer::parseInt).orElse(0)).intValue();
        return findMessages(this.history.get(), (String) Optional.ofNullable(map.get("roles")).map((v0) -> {
            return v0.toString();
        }).orElseThrow(() -> {
            return new IllegalArgumentException("Found <@history /> tag without roles");
        }), intValue, intValue2, intValue3);
    }

    private static void generateJsonl(Writer writer, List<Message> list) throws IOException {
        Iterator<Message> it = list.iterator();
        while (it.hasNext()) {
            Json.OBJECT_MAPPER.writeValue(writer, it.next());
            writer.write(10);
        }
    }

    public static List<Message> findMessages(List<Message> list, String str, int i, int i2, int i3) {
        Set<String> inferRoles = inferRoles(str);
        LinkedList linkedList = new LinkedList();
        ListIterator<Message> listIterator = list.listIterator(list.size());
        while (listIterator.hasPrevious() && linkedList.size() < i + i2) {
            Message previous = listIterator.previous();
            if (inferRoles.contains(previous.getRole()) && i3 <= 0) {
                linkedList.addFirst(previous);
            } else if (i3 > 0) {
                i3--;
            }
        }
        while (!linkedList.isEmpty() && i2 > 0) {
            linkedList.removeFirst();
            i2--;
        }
        return linkedList;
    }

    static Set<String> inferRoles(String str) {
        return (Set) Arrays.stream(str.split(",")).map((v0) -> {
            return v0.strip();
        }).filter(str2 -> {
            return !str2.isEmpty();
        }).collect(Collectors.toSet());
    }
}
