package com.yahoo.language.sentencepiece;

import com.yahoo.api.annotations.Beta;
import com.yahoo.component.annotation.Inject;
import com.yahoo.language.Language;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.Segmenter;
import com.yahoo.language.sentencepiece.Model;
import com.yahoo.language.sentencepiece.SentencePieceAlgorithm;
import com.yahoo.language.sentencepiece.SentencePieceConfig;
import com.yahoo.language.tools.Embed;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.io.File;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

@Beta
/**
 * A {@link Segmenter} and {@link Embedder} backed by one or more SentencePiece models.
 *
 * <p>Models are registered per {@link Language}. When the only configured model is registered
 * under {@link Language#UNKNOWN} it is used for every language; otherwise the requested language
 * must match a configured model exactly.
 */
public class SentencePieceEmbedder implements Segmenter, Embedder {

    /** U+2581 (LOWER ONE EIGHTH BLOCK): inserted by {@link #normalize} in front of every word. */
    private static final char SPACE_MARKER = '\u2581';

    private final Map<Language, Model> models;
    private final SentencePieceAlgorithm algorithm;

    /**
     * Collects the settings needed to create a {@link SentencePieceEmbedder}.
     *
     * <p>Defaults: unknowns are collapsed and scoring is {@code Scoring.fewestSegments}.
     */
    public static final class Builder {
        private final Map<Language, Path> models = new EnumMap<>(Language.class);
        private boolean collapseUnknowns = true;
        private Scoring scoring = Scoring.fewestSegments;

        /** Creates a builder with default settings and no models. */
        public Builder() {
        }

        /**
         * Creates a builder with default settings and a single default model.
         *
         * @param defaultModelFile file system path of the model to register under
         *                         {@link Language#UNKNOWN}
         */
        public Builder(String defaultModelFile) {
            addDefaultModel(new File(defaultModelFile).toPath());
        }

        /** Creates a builder whose settings and models are copied from the given config. */
        private Builder(SentencePieceConfig config) {
            this.collapseUnknowns = config.collapseUnknowns();
            this.scoring = config.scoring() == SentencePieceConfig.Scoring.fewestSegments
                    ? Scoring.fewestSegments
                    : Scoring.highestScore;
            for (SentencePieceConfig.Model model : config.model()) {
                addModel(Language.fromLanguageTag(model.language()), model.path());
            }
        }

        /**
         * Registers the model file to use for a language, replacing any model previously
         * registered for that language.
         *
         * @param language the language the model applies to
         * @param path     the location of the model file
         */
        public void addModel(Language language, Path path) {
            this.models.put(language, path);
        }

        /**
         * Registers the model file under {@link Language#UNKNOWN}.
         *
         * @param path the location of the model file
         * @return this, for chaining
         */
        public Builder addDefaultModel(Path path) {
            addModel(Language.UNKNOWN, path);
            return this;
        }

        /** Returns the live (mutable) map of model files registered so far, keyed by language. */
        public Map<Language, Path> getModels() {
            return this.models;
        }

        /**
         * Sets the collapse-unknowns flag handed to the segmentation algorithm.
         *
         * @return this, for chaining
         */
        public Builder setCollapseUnknowns(boolean collapseUnknowns) {
            this.collapseUnknowns = collapseUnknowns;
            return this;
        }

        /** Returns the collapse-unknowns flag. */
        public boolean getCollapseUnknowns() {
            return this.collapseUnknowns;
        }

        /**
         * Sets the scoring strategy handed to the segmentation algorithm.
         *
         * @return this, for chaining
         */
        public Builder setScoring(Scoring scoring) {
            this.scoring = scoring;
            return this;
        }

        /** Returns the scoring strategy. */
        public Scoring getScoring() {
            return this.scoring;
        }

        /**
         * Creates the embedder from the current settings.
         *
         * @throws IllegalStateException if no model has been added
         */
        public SentencePieceEmbedder build() {
            if (this.models.isEmpty()) {
                throw new IllegalStateException("At least one model must be supplied");
            }
            return new SentencePieceEmbedder(this);
        }
    }

    /** Creates an embedder from injected configuration. */
    @Inject
    public SentencePieceEmbedder(SentencePieceConfig sentencePieceConfig) {
        this(new Builder(sentencePieceConfig));
    }

    /**
     * Creates an embedder from the given builder, constructing a {@link Model} for every
     * registered model file.
     *
     * @throws IllegalArgumentException if the builder has no models
     */
    public SentencePieceEmbedder(Builder builder) {
        this.algorithm = new SentencePieceAlgorithm(builder.getCollapseUnknowns(), builder.getScoring());
        this.models = builder.getModels().entrySet().stream()
                .map(entry -> new Model(entry.getKey(), entry.getValue()))
                .collect(Collectors.toUnmodifiableMap(model -> model.language, model -> model));
        if (this.models.isEmpty()) {
            throw new IllegalArgumentException("SentencePieceEmbedder requires at least one model configured");
        }
    }

    /**
     * Segments the given text using the model resolved for the given language.
     *
     * <p>The text is normalized first (see {@link #normalize}) and every returned segment is a
     * substring of the normalized text, so segments contain the U+2581 marker rather than spaces.
     *
     * @param input    the text to segment
     * @param language the language used to select the model
     * @return the segments, in the order produced by reversing the algorithm's output
     * @throws IllegalArgumentException if no model can be resolved for the language
     */
    public List<String> segment(String input, Language language) {
        final String normalized = normalize(input);
        ResultBuilder<List<String>> segments = new ResultBuilder<List<String>>(new ArrayList<>()) {
            @Override
            public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) {
                result().add(normalized.substring(segmentStart, segmentEnd));
            }
        };
        segment(normalized, language, segments);
        // NOTE(review): presumably the algorithm reports segments back-to-front; the collected
        // list is reversed before being returned.
        Collections.reverse(segments.result());
        return segments.result();
    }

    /**
     * Segments the given text and returns the id recorded for each segment.
     *
     * @param rawInput the text to embed; it is normalized before segmentation
     * @param context  supplies the language used to select the model
     * @return the segment ids, in the order produced by reversing the algorithm's output
     * @throws IllegalArgumentException if no model can be resolved for the context language
     */
    public List<Integer> embed(String rawInput, Embedder.Context context) {
        ResultBuilder<List<Integer>> ids = new ResultBuilder<List<Integer>>(new ArrayList<>()) {
            @Override
            public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) {
                result().add(segmentEnds[segmentEnd].id);
            }
        };
        segment(normalize(rawInput), context.getLanguage(), ids);
        Collections.reverse(ids.result());
        return ids.result();
    }

    /**
     * Decodes token ids back to text, keeping control tokens.
     *
     * @see #decode(List, Embedder.Context, boolean)
     */
    public String decode(List<Integer> tokens, Embedder.Context context) {
        return decode(tokens, context, false);
    }

    /**
     * Decodes token ids back to text by concatenating the text of each token and denormalizing
     * the result (see {@link #denormalize}).
     *
     * @param tokens      the token ids to decode; may be empty, which yields an empty string
     * @param context     supplies the language used to select the model
     * @param skipControl if true, tokens of type {@code TokenType.control} are left out
     * @return the decoded text
     * @throws IllegalArgumentException if no model can be resolved for the context language,
     *                                  or if a token id has no entry in the model
     */
    public String decode(List<Integer> tokens, Embedder.Context context, boolean skipControl) {
        Model model = resolveModelFrom(context.getLanguage());
        StringBuilder decoded = new StringBuilder();
        for (Integer tokenId : tokens) {
            Model.Token token = model.tokenId2Token.get(tokenId);
            if (token == null) {
                // Fail with a descriptive error instead of a bare NullPointerException below.
                throw new IllegalArgumentException("Unknown token id " + tokenId +
                                                   " in the SentencePiece model used for language " +
                                                   context.getLanguage());
            }
            if (skipControl && token.type() == TokenType.control) {
                continue;
            }
            decoded.append(token.text());
        }
        return denormalize(decoded.toString());
    }

    /**
     * Embeds the given text as a tensor of the given type by delegating to
     * {@code Embed.asTensor} with this embedder.
     */
    public Tensor embed(String rawInput, Embedder.Context context, TensorType type) {
        return Embed.asTensor(rawInput, this, context, type);
    }

    /** Runs the segmentation algorithm on already normalized input with the model for the language. */
    private <RESULTTYPE> void segment(String normalizedInput, Language language, ResultBuilder<RESULTTYPE> resultBuilder) {
        this.algorithm.segment(normalizedInput, resultBuilder, resolveModelFrom(language));
    }

    /**
     * Returns the model to use for the given language.
     *
     * @throws IllegalArgumentException if there is neither a sole default model nor a model
     *                                  registered for exactly this language
     */
    private Model resolveModelFrom(Language language) {
        // A sole model registered under UNKNOWN acts as the default for every language.
        if (this.models.size() == 1 && this.models.containsKey(Language.UNKNOWN)) {
            return this.models.get(Language.UNKNOWN);
        }
        if (this.models.containsKey(language)) {
            return this.models.get(language);
        }
        throw new IllegalArgumentException("No SentencePiece model for language " + language + " is configured");
    }

    /**
     * Rewrites text to the form passed to the segmentation algorithm: every run of one or more
     * space characters is dropped, and each remaining word (including the first) is prefixed
     * with U+2581. Trailing spaces therefore leave no trace in the output.
     *
     * @param s the text to normalize
     * @return the normalized text; empty if the input is empty or contains only spaces
     */
    public String normalize(String s) {
        StringBuilder normalized = new StringBuilder(s.length() + 1);
        boolean atWordStart = true;
        for (int i = 0; i < s.length(); i++) {
            char c = s.charAt(i);
            if (c == ' ') {
                atWordStart = true;
                continue;
            }
            if (atWordStart) {
                normalized.append(SPACE_MARKER);
                atWordStart = false;
            }
            normalized.append(c);
        }
        return normalized.toString();
    }

    /**
     * Reverses {@link #normalize}: every U+2581 becomes a space and a single leading space,
     * if present, is removed.
     *
     * @param s the normalized text; may be empty
     * @return the text with markers turned back into spaces
     */
    public String denormalize(String s) {
        String spaced = s.replace(SPACE_MARKER, ' ');
        // startsWith is safe on the empty string, unlike charAt(0).
        return spaced.startsWith(" ") ? spaced.substring(1) : spaced;
    }
}
