package de.julielab.ml.embeddings.client;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import de.julielab.ml.embeddings.client.util.WordEmbeddingClientException;
import de.julielab.ml.embeddings.spi.EmbeddingVectors;
import de.julielab.ml.embeddings.spi.WordEmbedding;
import de.julielab.ml.embeddings.util.WordEmbeddingAccessException;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.UnsupportedEncodingException;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLEncoder;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cpu.nativecpu.NDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/julielab/ml/embeddings/client/WordEmbeddingClient.class */
public class WordEmbeddingClient implements WordEmbedding {
    private static final long serialVersionUID = 1976342456409693833L;
    private static final String UTF_8 = "UTF-8";
    public static final String GET_EMBEDDING = "get_embedding";
    public static final String GET_EMBEDDINGS = "get_embeddings";
    public static final String GET_EMBEDDINGS_MEAN = "get_embeddings_mean";
    public static final String GET_VOCAB_SIZE = "get_vocabulary_size";
    public static final String GET_WORD = "get_word";
    public static final String GET_HAS_WORD = "has_word";
    public static final String GET_EMBEDDING_DIMS = "get_embedding_dimensions";
    public static final String PARAM_WORD = "word";
    public static final String PARAM_INDEX = "index";
    private transient LoadingCache<String, double[]> vectorCache;
    private transient LoadingCache<WordListKey, EmbeddingVectors> vectorsCache;
    private transient LoadingCache<WordListKey, EmbeddingVectors> vectorsMeanCache;
    private String host;
    private int port;
    private static final Logger log = LoggerFactory.getLogger(WordEmbeddingClient.class);
    private static final double[] EMPTY = new double[0];

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:de/julielab/ml/embeddings/client/WordEmbeddingClient$WordListKey.class */
    public class WordListKey {
        private String words;
        private Collection<String> queryWords;

        public WordListKey(Collection<String> collection) {
            this.queryWords = collection;
            this.words = (String) collection.stream().collect(Collectors.joining());
        }

        public int hashCode() {
            return (31 * ((31 * 1) + getOuterType().hashCode())) + (this.words == null ? 0 : this.words.hashCode());
        }

        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            WordListKey wordListKey = (WordListKey) obj;
            if (getOuterType().equals(wordListKey.getOuterType())) {
                return this.words == null ? wordListKey.words == null : this.words.equals(wordListKey.words);
            }
            return false;
        }

        private WordEmbeddingClient getOuterType() {
            return WordEmbeddingClient.this;
        }

        public Collection<String> getQueryWords() {
            return this.queryWords;
        }
    }

    public WordEmbeddingClient(String str, int i) {
        this.host = str;
        this.port = i;
        setupCaches();
    }

    public void setupCaches() {
        this.vectorCache = CacheBuilder.newBuilder().maximumSize(10000L).expireAfterAccess(5L, TimeUnit.MINUTES).build(new CacheLoader<String, double[]>() { // from class: de.julielab.ml.embeddings.client.WordEmbeddingClient.1
            public double[] load(String str) throws Exception {
                return WordEmbeddingClient.this.loadWordVector(str);
            }
        });
        this.vectorsCache = CacheBuilder.newBuilder().maximumSize(10000L).expireAfterAccess(5L, TimeUnit.MINUTES).build(new CacheLoader<WordListKey, EmbeddingVectors>() { // from class: de.julielab.ml.embeddings.client.WordEmbeddingClient.2
            public EmbeddingVectors load(WordListKey wordListKey) throws Exception {
                return WordEmbeddingClient.this.loadWordVectors(wordListKey.getQueryWords() instanceof List ? (List) wordListKey.getQueryWords() : new ArrayList(wordListKey.getQueryWords()));
            }
        });
        this.vectorsMeanCache = CacheBuilder.newBuilder().maximumSize(10000L).expireAfterAccess(5L, TimeUnit.MINUTES).build(new CacheLoader<WordListKey, EmbeddingVectors>() { // from class: de.julielab.ml.embeddings.client.WordEmbeddingClient.3
            public EmbeddingVectors load(WordListKey wordListKey) throws Exception {
                return WordEmbeddingClient.this.loadWordVectorsMean(wordListKey.getQueryWords() instanceof List ? (List) wordListKey.getQueryWords() : new ArrayList(wordListKey.getQueryWords()));
            }
        });
    }

    private URL getUrl(String str, String str2) throws WordEmbeddingClientException {
        try {
            String str3 = "http://" + this.host + ":" + this.port + "/" + str;
            if (str2 != null) {
                str3 = str3 + "?" + str2;
            }
            return new URL(str3);
        } catch (MalformedURLException e) {
            throw new WordEmbeddingClientException(e);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    public double[] loadWordVector(String str) {
        try {
            URL url = getUrl(GET_EMBEDDING, "word=" + URLEncoder.encode(str, UTF_8));
            if (log.isTraceEnabled()) {
                log.trace("Sending query {}", url.toString());
            }
            InputStream openStream = url.openStream();
            Throwable th = null;
            try {
                ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
                byte[] bArr = new byte[128];
                int i = 0;
                while (true) {
                    int read = openStream.read(bArr);
                    if (read == -1) {
                        break;
                    }
                    byteArrayOutputStream.write(bArr, 0, read);
                    i += read;
                }
                log.trace("Read {} bytes", Integer.valueOf(i));
                ByteBuffer wrap = ByteBuffer.wrap(byteArrayOutputStream.toByteArray());
                if (wrap.capacity() == 1) {
                    double[] dArr = EMPTY;
                    if (openStream != null) {
                        if (0 != 0) {
                            try {
                                openStream.close();
                            } catch (Throwable th2) {
                                th.addSuppressed(th2);
                            }
                        } else {
                            openStream.close();
                        }
                    }
                    return dArr;
                }
                double[] readEmbeddingVector = readEmbeddingVector(wrap);
                if (openStream != null) {
                    if (0 != 0) {
                        try {
                            openStream.close();
                        } catch (Throwable th3) {
                            th.addSuppressed(th3);
                        }
                    } else {
                        openStream.close();
                    }
                }
                return readEmbeddingVector;
            } finally {
            }
        } catch (WordEmbeddingClientException | IOException e) {
            throw new WordEmbeddingAccessException(e);
        }
    }

    public double[] getWordVector(String str) throws WordEmbeddingAccessException {
        try {
            return (double[]) this.vectorCache.get(str);
        } catch (ExecutionException e) {
            throw new WordEmbeddingAccessException(e);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* JADX WARN: Type inference failed for: r0v40, types: [double[], double[][]] */
    public EmbeddingVectors loadWordVectors(List<String> list) {
        try {
            URL url = getUrl(GET_EMBEDDINGS, (String) list.stream().map(str -> {
                try {
                    return "word=" + URLEncoder.encode(str, UTF_8);
                } catch (UnsupportedEncodingException e) {
                    log.error("{}", e);
                    return null;
                }
            }).collect(Collectors.joining("&")));
            if (log.isTraceEnabled()) {
                log.trace("Sending query {}", url.toString());
            }
            InputStream openStream = url.openStream();
            Throwable th = null;
            try {
                try {
                    ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
                    byte[] bArr = new byte[128];
                    int i = 0;
                    while (true) {
                        int read = openStream.read(bArr);
                        if (read == -1) {
                            break;
                        }
                        byteArrayOutputStream.write(bArr, 0, read);
                        i += read;
                    }
                    log.trace("Read {} bytes", Integer.valueOf(i));
                    ByteBuffer wrap = ByteBuffer.wrap(byteArrayOutputStream.toByteArray());
                    int i2 = wrap.getInt();
                    log.trace("Got embedding vector dimensions {}", Integer.valueOf(i2));
                    List<Integer> readFoundIndices = readFoundIndices(list, wrap);
                    log.trace("Got {} found word indices: {}", Integer.valueOf(readFoundIndices.size()), readFoundIndices);
                    int size = readFoundIndices.size();
                    ?? r0 = new double[size];
                    for (int i3 = 0; i3 < size; i3++) {
                        r0[i3] = readEmbeddingVector(wrap);
                    }
                    EmbeddingVectors embeddingVectors = new EmbeddingVectors(size > 0 ? new NDArray((double[][]) r0) : null, list, readFoundIndices, i2, EmbeddingVectors.StreamType.CONCATENATION);
                    if (openStream != null) {
                        if (0 != 0) {
                            try {
                                openStream.close();
                            } catch (Throwable th2) {
                                th.addSuppressed(th2);
                            }
                        } else {
                            openStream.close();
                        }
                    }
                    return embeddingVectors;
                } finally {
                }
            } catch (Throwable th3) {
                if (openStream != null) {
                    if (th != null) {
                        try {
                            openStream.close();
                        } catch (Throwable th4) {
                            th.addSuppressed(th4);
                        }
                    } else {
                        openStream.close();
                    }
                }
                throw th3;
            }
        } catch (WordEmbeddingClientException | IOException e) {
            throw new WordEmbeddingAccessException(e);
        }
    }

    public EmbeddingVectors getWordVectors(List<String> list) {
        try {
            return (EmbeddingVectors) this.vectorsCache.get(new WordListKey(list));
        } catch (ExecutionException e) {
            throw new WordEmbeddingAccessException(e);
        }
    }

    public List<Integer> readFoundIndices(List<String> list, ByteBuffer byteBuffer) {
        return (List) IntStream.range(0, list.size()).filter(i -> {
            return byteBuffer.get() == 1;
        }).mapToObj(Integer::new).collect(Collectors.toList());
    }

    /* JADX INFO: Access modifiers changed from: private */
    public EmbeddingVectors loadWordVectorsMean(List<String> list) {
        try {
            URL url = getUrl(GET_EMBEDDINGS_MEAN, (String) list.stream().map(str -> {
                try {
                    return "word=" + URLEncoder.encode(str, UTF_8);
                } catch (UnsupportedEncodingException e) {
                    log.error("{}", e);
                    return null;
                }
            }).collect(Collectors.joining("&")));
            if (log.isTraceEnabled()) {
                log.trace("Sending query {}", url.toString());
            }
            InputStream openStream = url.openStream();
            Throwable th = null;
            try {
                ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
                byte[] bArr = new byte[128];
                int i = 0;
                while (true) {
                    int read = openStream.read(bArr);
                    if (read == -1) {
                        break;
                    }
                    byteArrayOutputStream.write(bArr, 0, read);
                    i += read;
                }
                log.trace("Read {} bytes", Integer.valueOf(i));
                ByteBuffer wrap = ByteBuffer.wrap(byteArrayOutputStream.toByteArray());
                int i2 = wrap.getInt();
                List<Integer> readFoundIndices = readFoundIndices(list, wrap);
                INDArray iNDArray = null;
                if (!readFoundIndices.isEmpty()) {
                    iNDArray = Nd4j.create(readEmbeddingVector(wrap));
                }
                EmbeddingVectors embeddingVectors = new EmbeddingVectors(iNDArray, list, readFoundIndices, i2, EmbeddingVectors.StreamType.AGGREGATION);
                if (openStream != null) {
                    if (0 != 0) {
                        try {
                            openStream.close();
                        } catch (Throwable th2) {
                            th.addSuppressed(th2);
                        }
                    } else {
                        openStream.close();
                    }
                }
                return embeddingVectors;
            } catch (Throwable th3) {
                if (openStream != null) {
                    if (0 != 0) {
                        try {
                            openStream.close();
                        } catch (Throwable th4) {
                            th.addSuppressed(th4);
                        }
                    } else {
                        openStream.close();
                    }
                }
                throw th3;
            }
        } catch (WordEmbeddingClientException | IOException e) {
            throw new WordEmbeddingAccessException(e);
        }
    }

    public EmbeddingVectors getWordVectorsMean(List<String> list) {
        try {
            return (EmbeddingVectors) this.vectorsMeanCache.get(new WordListKey(list));
        } catch (ExecutionException e) {
            throw new WordEmbeddingAccessException(e);
        }
    }

    public int getVocabularySize() {
        try {
            URL url = getUrl(GET_VOCAB_SIZE, null);
            if (log.isTraceEnabled()) {
                log.trace("Sending query {}", url.toString());
            }
            InputStream openStream = url.openStream();
            Throwable th = null;
            try {
                try {
                    ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
                    byte[] bArr = new byte[4];
                    int i = 0;
                    while (true) {
                        int read = openStream.read(bArr);
                        if (read == -1) {
                            break;
                        }
                        byteArrayOutputStream.write(bArr, 0, read);
                        i += read;
                    }
                    log.trace("Read {} bytes", Integer.valueOf(i));
                    int parseInt = Integer.parseInt(new String(byteArrayOutputStream.toByteArray(), Charset.forName(UTF_8)));
                    if (openStream != null) {
                        if (0 != 0) {
                            try {
                                openStream.close();
                            } catch (Throwable th2) {
                                th.addSuppressed(th2);
                            }
                        } else {
                            openStream.close();
                        }
                    }
                    return parseInt;
                } finally {
                }
            } catch (Throwable th3) {
                if (openStream != null) {
                    if (th != null) {
                        try {
                            openStream.close();
                        } catch (Throwable th4) {
                            th.addSuppressed(th4);
                        }
                    } else {
                        openStream.close();
                    }
                }
                throw th3;
            }
        } catch (WordEmbeddingClientException | IOException e) {
            throw new WordEmbeddingAccessException(e);
        }
    }

    public String getWord(int i) {
        try {
            URL url = getUrl(GET_WORD, "index=" + i);
            if (log.isTraceEnabled()) {
                log.trace("Sending query {}", url.toString());
            }
            InputStream openStream = url.openStream();
            Throwable th = null;
            try {
                try {
                    ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
                    byte[] bArr = new byte[128];
                    int i2 = 0;
                    while (true) {
                        int read = openStream.read(bArr);
                        if (read == -1) {
                            break;
                        }
                        byteArrayOutputStream.write(bArr, 0, read);
                        i2 += read;
                    }
                    log.trace("Read {} bytes", Integer.valueOf(i2));
                    String str = new String(byteArrayOutputStream.toByteArray(), Charset.forName(UTF_8));
                    if (openStream != null) {
                        if (0 != 0) {
                            try {
                                openStream.close();
                            } catch (Throwable th2) {
                                th.addSuppressed(th2);
                            }
                        } else {
                            openStream.close();
                        }
                    }
                    return str;
                } finally {
                }
            } finally {
            }
        } catch (WordEmbeddingClientException | IOException e) {
            throw new WordEmbeddingAccessException(e);
        }
    }

    public boolean hasWord(String str) {
        try {
            URL url = getUrl(GET_HAS_WORD, "word=" + URLEncoder.encode(str, UTF_8));
            if (log.isTraceEnabled()) {
                log.trace("Sending query {}", url.toString());
            }
            InputStream openStream = url.openStream();
            Throwable th = null;
            try {
                try {
                    ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
                    byte[] bArr = new byte[1];
                    int i = 0;
                    while (true) {
                        int read = openStream.read(bArr);
                        if (read == -1) {
                            break;
                        }
                        byteArrayOutputStream.write(bArr, 0, read);
                        i += read;
                    }
                    log.trace("Read {} bytes", Integer.valueOf(i));
                    boolean parseBoolean = Boolean.parseBoolean(new String(byteArrayOutputStream.toByteArray(), Charset.forName(UTF_8)));
                    if (openStream != null) {
                        if (0 != 0) {
                            try {
                                openStream.close();
                            } catch (Throwable th2) {
                                th.addSuppressed(th2);
                            }
                        } else {
                            openStream.close();
                        }
                    }
                    return parseBoolean;
                } finally {
                }
            } finally {
            }
        } catch (WordEmbeddingClientException | IOException e) {
            throw new WordEmbeddingAccessException(e);
        }
    }

    public double[] readEmbeddingVector(ByteBuffer byteBuffer) {
        int i = byteBuffer.getInt() / 8;
        double[] dArr = new double[i];
        for (int i2 = 0; i2 < i; i2++) {
            dArr[i2] = byteBuffer.getDouble();
        }
        return dArr;
    }

    public int getEmbeddingDimensions() {
        try {
            URL url = getUrl(GET_EMBEDDING_DIMS, null);
            if (log.isTraceEnabled()) {
                log.trace("Sending query {}", url.toString());
            }
            InputStream openStream = url.openStream();
            Throwable th = null;
            try {
                try {
                    ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
                    byte[] bArr = new byte[4];
                    int i = 0;
                    while (true) {
                        int read = openStream.read(bArr);
                        if (read == -1) {
                            break;
                        }
                        byteArrayOutputStream.write(bArr, 0, read);
                        i += read;
                    }
                    log.trace("Read {} bytes", Integer.valueOf(i));
                    int parseInt = Integer.parseInt(new String(byteArrayOutputStream.toByteArray(), Charset.forName(UTF_8)));
                    if (openStream != null) {
                        if (0 != 0) {
                            try {
                                openStream.close();
                            } catch (Throwable th2) {
                                th.addSuppressed(th2);
                            }
                        } else {
                            openStream.close();
                        }
                    }
                    return parseInt;
                } finally {
                }
            } catch (Throwable th3) {
                if (openStream != null) {
                    if (th != null) {
                        try {
                            openStream.close();
                        } catch (Throwable th4) {
                            th.addSuppressed(th4);
                        }
                    } else {
                        openStream.close();
                    }
                }
                throw th3;
            }
        } catch (WordEmbeddingClientException | IOException e) {
            throw new WordEmbeddingAccessException(e);
        }
    }

    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        setupCaches();
    }
}
