package fi.evolver.ai.spring.embedding;

import fi.evolver.ai.spring.ApiConnectionParameters;
import fi.evolver.ai.spring.Model;
import fi.evolver.ai.spring.embedding.entity.EmbeddingVector;
import fi.evolver.basics.spring.lock.LockException;
import fi.evolver.basics.spring.lock.LockHandle;
import fi.evolver.basics.spring.lock.LockService;
import java.time.Duration;
import java.time.OffsetDateTime;
import java.time.temporal.TemporalAmount;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.codec.digest.DigestUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Cache of embedding vectors for a single {@link Model}, keyed by the SHA-256 hex digest of the
 * embedded text.
 *
 * <p>A lookup first checks an in-memory cache, then {@link EmbeddingVectorRepository}. Texts with
 * no stored vector are sent to {@link EmbeddingVectorApi} in one batch. In-memory changes are
 * written back by {@link #persist(Duration)}, and old entries are removed by
 * {@link #clearStale(OffsetDateTime)}. Both maintenance operations run under the same named lock
 * obtained from {@link LockService}.
 */
public class EmbeddingVectors {
    private static final Logger LOG = LoggerFactory.getLogger(EmbeddingVectors.class);

    /**
     * Default connection parameters for the embedding API.
     *
     * <p>NOTE(review): the 5 s / 30 s durations are presumably connect and read timeouts; confirm
     * against {@link ApiConnectionParameters}.
     */
    public static final ApiConnectionParameters PARAMETERS_DEFAULT =
            new ApiConnectionParameters(null, Duration.ofSeconds(5), Duration.ofSeconds(30), Collections.emptyMap());

    private static final String UPDATE_LOCK_NAME = EmbeddingVectors.class.getSimpleName() + "_UpdateLock";

    private final EmbeddingVectorApi embeddingVectorApi;
    private final EmbeddingVectorRepository embeddingVectorRepository;
    private final LockService lockService;
    private final Model model;
    /** In-memory vectors keyed by text hash. */
    private final Map<String, EmbeddingVector> memCache = new ConcurrentHashMap<>();
    /** Last-accessed time of each vector as last loaded from or written to the repository, keyed by hash. */
    private final Map<String, OffsetDateTime> persistedTimestampsByHash = new ConcurrentHashMap<>();
    private final ApiConnectionParameters parameters;

    public EmbeddingVectors(EmbeddingVectorApi embeddingVectorApi, EmbeddingVectorRepository embeddingVectorRepository, LockService lockService, Model model, ApiConnectionParameters apiConnectionParameters) {
        this.embeddingVectorApi = embeddingVectorApi;
        this.embeddingVectorRepository = embeddingVectorRepository;
        this.lockService = lockService;
        this.model = model;
        this.parameters = apiConnectionParameters;
    }

    /**
     * Writes in-memory changes to the repository while holding the update lock.
     *
     * <p>Cached vectors with no recorded persisted timestamp are passed to the repository as new.
     * Vectors whose persisted timestamp plus {@code duration} is before their current last-accessed
     * time are passed as updated. All other vectors are skipped.
     *
     * @param duration how far the last-accessed time must move past the persisted timestamp before
     *     an already-persisted vector is written again
     * @return {@code true} if the changes were handed to the repository, {@code false} if the lock
     *     could not be obtained
     */
    public boolean persist(Duration duration) {
        // try-with-resources: the lock is released even if persistChanges throws.
        try (LockHandle lock = takeUpdateLock()) {
            var created = new ArrayList<EmbeddingVector>();
            var updated = new ArrayList<EmbeddingVector>();
            for (EmbeddingVector vector : memCache.values()) {
                OffsetDateTime persistedAt = persistedTimestampsByHash.get(vector.getHash());
                if (persistedAt == null) {
                    created.add(vector);
                } else if (persistedAt.plus(duration).isBefore(vector.getLastAccessed())) {
                    updated.add(vector);
                }
            }
            embeddingVectorRepository.persistChanges(created, updated);
            created.forEach(this::markPersisted);
            updated.forEach(this::markPersisted);
            return true;
        } catch (LockException e) {
            LOG.info("Unable to get lock for embedding vector update");
            return false;
        }
    }

    /**
     * Removes stale data from the repository and the in-memory cache while holding the update lock.
     *
     * <p>Calls {@code deleteStaleData(offsetDateTime)} on the repository. Presumably this deletes
     * rows last accessed before the cutoff; confirm in the repository. It then evicts every cached
     * vector whose last-accessed time is before {@code offsetDateTime}, together with its persisted
     * timestamp. Dropping the timestamp keeps that map bounded and ensures a vector re-created later
     * for the same text is persisted as new rather than as an update.
     *
     * @param offsetDateTime cutoff; entries last accessed before it are removed
     * @return {@code true} if the clean-up ran, {@code false} if the lock could not be obtained
     */
    public boolean clearStale(OffsetDateTime offsetDateTime) {
        try (LockHandle lock = takeUpdateLock()) {
            embeddingVectorRepository.deleteStaleData(offsetDateTime);
            memCache.values().removeIf(vector -> {
                if (!vector.getLastAccessed().isBefore(offsetDateTime)) {
                    return false;
                }
                persistedTimestampsByHash.remove(vector.getHash());
                return true;
            });
            return true;
        } catch (LockException e) {
            LOG.info("Unable to get lock for embedding vector clean-up");
            return false;
        }
    }

    /**
     * Returns the embedding vector for a single text, creating it through the API if needed.
     *
     * @param str text to embed; must not be {@code null}
     * @return the embedding vector for {@code str}
     */
    public double[] getEmbedding(String str) {
        // Look up by the text itself: the result map is keyed by input text, not by position.
        return getEmbeddings(List.of(str)).get(str);
    }

    /**
     * Returns embedding vectors for the given texts.
     *
     * <p>Cached or persisted vectors are reused. All remaining texts are sent to the embedding API
     * in a single batched request.
     *
     * @param collection texts to embed; duplicates are collapsed
     * @return vectors keyed by text, in the first-occurrence order of {@code collection}
     */
    public Map<String, double[]> getEmbeddings(Collection<String> collection) {
        Map<String, double[]> result = new LinkedHashMap<>();
        for (String text : collection) {
            result.put(text, getCachedVectorOrNull(calculateHash(text)));
        }
        List<String> missing = result.entrySet().stream()
                .filter(entry -> entry.getValue() == null)
                .map(Map.Entry::getKey)
                .toList();
        if (!missing.isEmpty()) {
            // putAll on existing keys keeps the LinkedHashMap's original insertion order.
            result.putAll(createNewEntities(missing));
        }
        return result;
    }

    /** Takes the shared update lock used by {@link #persist} and {@link #clearStale}. */
    private LockHandle takeUpdateLock() throws LockException {
        // NOTE(review): argument semantics (presumably wait and hold times in ms) are defined by
        // LockService.takeLock; confirm there.
        return lockService.takeLock(UPDATE_LOCK_NAME, 60000, 120000);
    }

    /**
     * Creates vectors for the given texts through the embedding API and caches them in memory.
     *
     * @param list texts with no cached or persisted vector
     * @return the new vectors keyed by text
     * @throws IllegalStateException if the API returns a different number of vectors than texts
     */
    private Map<String, double[]> createNewEntities(List<String> list) {
        List<double[]> vectors = embeddingVectorApi.createEmbeddingVectorsInBatches(model, list, parameters);
        // Results are matched to inputs by position, so a count mismatch means misalignment.
        if (vectors.size() != list.size()) {
            throw new IllegalStateException(
                    "Embedding API returned %d vectors for %d texts".formatted(vectors.size(), list.size()));
        }
        Map<String, double[]> created = new HashMap<>();
        for (int i = 0; i < list.size(); i++) {
            String text = list.get(i);
            String hash = calculateHash(text);
            EmbeddingVector vector = new EmbeddingVector(model, hash, vectors.get(i));
            memCache.put(hash, vector);
            created.put(text, vector.getVector());
        }
        return created;
    }

    /** Returns the SHA-256 hex digest used as the cache key for {@code str}. */
    private static String calculateHash(String str) {
        return DigestUtils.sha256Hex(str);
    }

    /** Returns the cached or persisted vector for {@code str} (a text hash), or {@code null} if there is none. */
    private double[] getCachedVectorOrNull(String str) {
        return getCachedVector(str).map(EmbeddingVector::getVector).orElse(null);
    }

    /**
     * Looks up a vector by hash, first in memory, then in the repository, and marks it as accessed now.
     *
     * @param str text hash
     * @return the vector, or empty if neither the cache nor the repository has it
     */
    private Optional<EmbeddingVector> getCachedVector(String str) {
        // NOTE(review): the repository lookup runs inside computeIfAbsent, so concurrent updates to
        // the same ConcurrentHashMap bin wait for it. A null result leaves no mapping, so misses are
        // re-queried.
        Optional<EmbeddingVector> found = Optional.ofNullable(memCache.computeIfAbsent(str, this::findPersistedVector));
        found.ifPresent(vector -> vector.setLastAccessed(OffsetDateTime.now()));
        return found;
    }

    /**
     * Loads a vector for this model from the repository and records its persisted timestamp.
     *
     * @param str text hash
     * @return the persisted vector, or {@code null} if the repository has none
     */
    private EmbeddingVector findPersistedVector(String str) {
        Optional<EmbeddingVector> found = embeddingVectorRepository.findByModelAndHash(model, str);
        found.ifPresent(this::markPersisted);
        return found.orElse(null);
    }

    /** Records the vector's current last-accessed time as the timestamp last stored in the repository. */
    private void markPersisted(EmbeddingVector vector) {
        persistedTimestampsByHash.put(vector.getHash(), vector.getLastAccessed());
    }
}
