package com.yahoo.language.wordpiece;

import com.yahoo.component.annotation.Inject;
import com.yahoo.language.Language;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.Segmenter;
import com.yahoo.language.process.Tokenizer;
import com.yahoo.language.simple.SimpleLinguistics;
import com.yahoo.language.tools.Embed;
import com.yahoo.language.wordpiece.WordPieceConfig;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.io.File;
import java.nio.file.Path;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * An {@link Embedder} and {@link Segmenter} which delegates to one WordPiece {@code Model} per
 * configured {@link Language}.
 *
 * <p>Input text is tokenized with the tokenizer obtained from {@link SimpleLinguistics} before it is
 * handed to the model. Instances are created either by dependency injection from a
 * {@link WordPieceConfig}, or programmatically through {@link Builder}.
 */
public class WordPieceEmbedder implements Embedder, Segmenter {
    /** The configured models, keyed by the language each of them was registered for. */
    private final Map<Language, Model> models;

    /** The tokenizer passed to the selected model together with the text to process. */
    private final Tokenizer tokenizer;

    /**
     * Collects the subword prefix and the per-language model paths needed to create a
     * {@link WordPieceEmbedder}.
     */
    public static final class Builder {
        /** The subword prefix used unless another one is set explicitly or read from config. */
        private static final String DEFAULT_SUBWORD_PREFIX = "##";

        private String subwordPrefix = DEFAULT_SUBWORD_PREFIX;
        private final Map<Language, Path> models = new EnumMap<>(Language.class);

        /** Creates a builder with the default subword prefix and no models. */
        public Builder() {
        }

        /**
         * Creates a builder with the default subword prefix and a single default model.
         *
         * @param str the file system path of the model to register for {@link Language#UNKNOWN}
         */
        public Builder(String str) {
            addDefaultModel(new File(str).toPath());
        }

        /**
         * Creates a builder holding the subword prefix and all the models listed in the given config.
         *
         * @param wordPieceConfig the config to read the subword prefix and the models from
         */
        private Builder(WordPieceConfig wordPieceConfig) {
            this.subwordPrefix = wordPieceConfig.subwordPrefix();
            for (WordPieceConfig.Model model : wordPieceConfig.model()) {
                addModel(Language.fromLanguageTag(model.language()), model.path());
            }
        }

        /**
         * Sets the subword prefix handed to every model created by {@link #build()}.
         *
         * @param str the subword prefix to use
         * @return this, for chaining
         */
        public Builder setSubwordPrefix(String str) {
            // Store the argument: assigning the field to itself would silently discard the new prefix.
            this.subwordPrefix = str;
            return this;
        }

        /** Returns the subword prefix currently held by this builder. */
        public String getSubwordPrefix() {
            return this.subwordPrefix;
        }

        /**
         * Registers a model for a language, replacing any model already registered for it.
         *
         * @param language the language the model should be used for
         * @param path the path of the model
         */
        public void addModel(Language language, Path path) {
            this.models.put(language, path);
        }

        /**
         * Registers a model for {@link Language#UNKNOWN}.
         *
         * @param path the path of the model
         * @return this, for chaining
         */
        public Builder addDefaultModel(Path path) {
            addModel(Language.UNKNOWN, path);
            return this;
        }

        /**
         * Returns the models registered so far. This is the builder's own map, not a copy,
         * so changes made to it are visible to this builder.
         */
        public Map<Language, Path> getModels() {
            return this.models;
        }

        /**
         * Creates an embedder from the current content of this builder.
         *
         * @return a new embedder
         * @throws IllegalStateException if no model has been added
         */
        public WordPieceEmbedder build() {
            if (this.models.isEmpty()) {
                throw new IllegalStateException("At least one model must be supplied");
            }
            return new WordPieceEmbedder(this);
        }
    }

    /**
     * Creates an embedder from config.
     *
     * @param wordPieceConfig the config listing the subword prefix and the models to use
     * @throws IllegalArgumentException if the config lists no models
     */
    @Inject
    public WordPieceEmbedder(WordPieceConfig wordPieceConfig) {
        this(new Builder(wordPieceConfig));
    }

    /**
     * Creates an embedder with one model for each entry registered in the given builder.
     *
     * @param builder the builder supplying the subword prefix and the model paths
     * @throws IllegalArgumentException if the builder holds no models
     */
    private WordPieceEmbedder(Builder builder) {
        // The config-injected constructor does not go through Builder.build(), so check here as well.
        if (builder.getModels().isEmpty()) {
            throw new IllegalArgumentException("WordPieceEmbedder requires at least one model configured");
        }
        this.tokenizer = new SimpleLinguistics().getTokenizer();
        String subwordPrefix = builder.getSubwordPrefix();
        this.models = builder.getModels().entrySet().stream()
                .map(entry -> new Model(subwordPrefix, entry.getKey(), entry.getValue()))
                .collect(Collectors.toUnmodifiableMap(Model::language, model -> model));
    }

    /**
     * Segments the given text using the model resolved for the given language.
     *
     * @param str the text to segment
     * @param language the language used to select the model
     * @return the segments returned by the selected model
     * @throws IllegalArgumentException if no model can be resolved for the language
     */
    public List<String> segment(String str, Language language) {
        return resolveModelFrom(language).segment(str, this.tokenizer);
    }

    /**
     * Embeds the given text using the model resolved for the language of the given context.
     *
     * @param str the text to embed
     * @param context the context supplying the language used to select the model
     * @return the integers returned by the selected model
     * @throws IllegalArgumentException if no model can be resolved for the context language
     */
    public List<Integer> embed(String str, Embedder.Context context) {
        return resolveModelFrom(context.getLanguage()).embed(str, this.tokenizer);
    }

    /**
     * Embeds the given text as a tensor of the given type, by delegating to
     * {@link Embed#asTensor} with this as the embedder.
     *
     * @param str the text to embed
     * @param context the context of the embedding request
     * @param tensorType the type of the tensor to produce
     * @return the tensor produced by {@link Embed#asTensor}
     */
    public Tensor embed(String str, Embedder.Context context, TensorType tensorType) {
        return Embed.asTensor(str, this, context, tensorType);
    }

    /**
     * Selects the model to use for a language. If the only configured model is registered for
     * {@link Language#UNKNOWN} it is used regardless of the requested language; otherwise the
     * model registered for exactly the requested language is used.
     *
     * @param language the requested language
     * @return the model to use
     * @throws IllegalArgumentException if no model matches
     */
    private Model resolveModelFrom(Language language) {
        // The map is built by Collectors.toUnmodifiableMap and so never holds null values:
        // a null from get() always means that the key is absent.
        if (this.models.size() == 1) {
            Model defaultModel = this.models.get(Language.UNKNOWN);
            if (defaultModel != null) {
                return defaultModel;
            }
        }
        Model model = this.models.get(language);
        if (model == null) {
            throw new IllegalArgumentException("No WordPiece model for language " + language + " is configured");
        }
        return model;
    }
}
