package com.yahoo.language.huggingface;

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.huggingface.tokenizers.jni.LibUtils;
import ai.djl.huggingface.tokenizers.jni.TokenizersLibrary;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.io.IOUtils;
import com.yahoo.language.Language;
import com.yahoo.language.huggingface.ModelInfo;
import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.Segmenter;
import com.yahoo.language.tools.Embed;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.yolean.Exceptions;
import java.nio.file.CopyOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.Arrays;
import java.util.Collection;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;

@Beta
/* loaded from: input_file:com/yahoo/language/huggingface/HuggingFaceTokenizer.class */
public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, Segmenter, AutoCloseable {
    /**
     * Scratch directory created in the constructor (prefix {@code hf-tokenizer-}).
     * Model files that are not directories are copied into a per-language sub-directory of it,
     * and the whole directory is deleted again in {@link #close()}.
     */
    private final Path tmpDirectory;
    /** The underlying DJL tokenizer for each configured language; looked up by {@code resolve}. */
    private final Map<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer> models;

    /* renamed from: com.yahoo.language.huggingface.HuggingFaceTokenizer$1, reason: invalid class name */
    /* loaded from: input_file:com/yahoo/language/huggingface/HuggingFaceTokenizer$1.class */
    /**
     * Ordinal-indexed lookup tables for switching over the {@code Truncation} and {@code Padding}
     * config enums. In each table the slot for {@code ON} holds 1 and the slot for {@code OFF}
     * holds 2; slots of any other constant keep the array default of 0.
     */
    static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$com$yahoo$language$huggingface$config$HuggingFaceTokenizerConfig$Truncation$Enum;
        static final /* synthetic */ int[] $SwitchMap$com$yahoo$language$huggingface$config$HuggingFaceTokenizerConfig$Padding$Enum;

        static {
            // Padding table first, then truncation table (same order as the original initializer).
            int[] paddingCases = new int[HuggingFaceTokenizerConfig.Padding.Enum.values().length];
            try {
                paddingCases[HuggingFaceTokenizerConfig.Padding.Enum.ON.ordinal()] = 1;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave its slot unmapped
            }
            try {
                paddingCases[HuggingFaceTokenizerConfig.Padding.Enum.OFF.ordinal()] = 2;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave its slot unmapped
            }
            $SwitchMap$com$yahoo$language$huggingface$config$HuggingFaceTokenizerConfig$Padding$Enum = paddingCases;

            int[] truncationCases = new int[HuggingFaceTokenizerConfig.Truncation.Enum.values().length];
            try {
                truncationCases[HuggingFaceTokenizerConfig.Truncation.Enum.ON.ordinal()] = 1;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave its slot unmapped
            }
            try {
                truncationCases[HuggingFaceTokenizerConfig.Truncation.Enum.OFF.ordinal()] = 2;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave its slot unmapped
            }
            $SwitchMap$com$yahoo$language$huggingface$config$HuggingFaceTokenizerConfig$Truncation$Enum = truncationCases;
        }
    }

    /* loaded from: input_file:com/yahoo/language/huggingface/HuggingFaceTokenizer$Builder.class */
    /**
     * Fluent builder for {@link HuggingFaceTokenizer}.
     *
     * <p>Models are registered per {@link Language}; the remaining options are optional and are
     * only forwarded to the underlying tokenizer when they have been set explicitly
     * (a {@code null} field means "not set").</p>
     */
    public static final class Builder {
        private final Map<Language, Path> models = new EnumMap<>(Language.class);
        private Boolean addSpecialTokens;
        private Integer maxLength;
        private Boolean truncation;
        private Boolean padding;

        /** Creates an empty builder with no models and no options set. */
        public Builder() {
        }

        /**
         * Creates a builder pre-populated from config.
         *
         * <p>Every configured model is registered under the language parsed from its language tag.
         * A max length of {@code -1} is treated as "not configured". Truncation and padding are
         * only set when the config value is {@code ON} or {@code OFF}; any other value leaves the
         * option unset.</p>
         *
         * @param huggingFaceTokenizerConfig the config to read models and options from
         */
        public Builder(HuggingFaceTokenizerConfig huggingFaceTokenizerConfig) {
            for (HuggingFaceTokenizerConfig.Model model : huggingFaceTokenizerConfig.model()) {
                addModel(Language.fromLanguageTag(model.language()), model.path());
            }
            addSpecialTokens(huggingFaceTokenizerConfig.addSpecialTokens());
            if (huggingFaceTokenizerConfig.maxLength() != -1) {
                setMaxLength(huggingFaceTokenizerConfig.maxLength());
            }
            // Switch on the enums directly rather than through ordinal-indexed lookup tables.
            switch (huggingFaceTokenizerConfig.truncation()) {
                case ON -> setTruncation(true);
                case OFF -> setTruncation(false);
                default -> { /* leave truncation unset */ }
            }
            switch (huggingFaceTokenizerConfig.padding()) {
                case ON -> setPadding(true);
                case OFF -> setPadding(false);
                default -> { /* leave padding unset */ }
            }
        }

        /**
         * Registers the model at {@code path} for {@code language}, replacing any earlier
         * registration for the same language.
         *
         * @return this builder
         */
        public Builder addModel(Language language, Path path) {
            this.models.put(language, path);
            return this;
        }

        /**
         * Registers the model at {@code path} under {@link Language#UNKNOWN}.
         *
         * @return this builder
         */
        public Builder addDefaultModel(Path path) {
            return addModel(Language.UNKNOWN, path);
        }

        /**
         * Sets whether special tokens should be added by the tokenizer.
         *
         * @return this builder
         */
        public Builder addSpecialTokens(boolean z) {
            this.addSpecialTokens = z;
            return this;
        }

        /**
         * Sets the max length passed on to the tokenizer.
         *
         * @return this builder
         */
        public Builder setMaxLength(int i) {
            this.maxLength = i;
            return this;
        }

        /**
         * Explicitly enables or disables truncation.
         *
         * @return this builder
         */
        public Builder setTruncation(boolean z) {
            this.truncation = z;
            return this;
        }

        /**
         * Explicitly enables or disables padding.
         *
         * @return this builder
         */
        public Builder setPadding(boolean z) {
            this.padding = z;
            return this;
        }

        /** Creates the tokenizer from the current state of this builder. */
        public HuggingFaceTokenizer build() {
            return new HuggingFaceTokenizer(this);
        }
    }

    /**
     * Injection constructor: builds the tokenizer from config by delegating to
     * {@link Builder#Builder(HuggingFaceTokenizerConfig)}.
     *
     * @param huggingFaceTokenizerConfig the config listing models and tokenizer options
     */
    @Inject
    public HuggingFaceTokenizer(HuggingFaceTokenizerConfig huggingFaceTokenizerConfig) {
        this(new Builder(huggingFaceTokenizerConfig));
    }

    private HuggingFaceTokenizer(Builder builder) {
        this.tmpDirectory = (Path) Exceptions.uncheck(() -> {
            return Files.createTempDirectory("hf-tokenizer-", new FileAttribute[0]);
        });
        this.models = (Map) withContextClassloader(() -> {
            EnumMap enumMap = new EnumMap(Language.class);
            builder.models.forEach((language, path) -> {
                enumMap.put((EnumMap) language, (Language) Exceptions.uncheck(() -> {
                    Path createDirectory;
                    if (Files.isDirectory(path, new LinkOption[0])) {
                        createDirectory = path;
                    } else {
                        createDirectory = Files.createDirectory(this.tmpDirectory.resolve(language.languageCode()), new FileAttribute[0]);
                        Files.copy(path, createDirectory.resolve("tokenizer.json"), new CopyOption[0]);
                    }
                    HuggingFaceTokenizer.Builder optAddSpecialTokens = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder().optTokenizerPath(createDirectory).optAddSpecialTokens(builder.addSpecialTokens != null ? builder.addSpecialTokens.booleanValue() : true);
                    if (builder.maxLength != null) {
                        optAddSpecialTokens.optMaxLength(builder.maxLength.intValue());
                        optAddSpecialTokens.configure(Map.of("modelMaxLength", Integer.valueOf(builder.maxLength.intValue() > 0 ? builder.maxLength.intValue() : Integer.MAX_VALUE)));
                    }
                    if (builder.padding != null) {
                        if (builder.padding.booleanValue()) {
                            optAddSpecialTokens.optPadToMaxLength();
                        } else {
                            optAddSpecialTokens.optPadding(false);
                        }
                    }
                    if (builder.truncation != null) {
                        optAddSpecialTokens.optTruncation(builder.truncation.booleanValue());
                    }
                    return optAddSpecialTokens.build();
                }));
            });
            return enumMap;
        });
    }

    /**
     * Tokenizes {@code str} with the model resolved for the context's language and returns the
     * token ids as an unmodifiable list.
     *
     * @throws ArithmeticException if a token id does not fit in an {@code int}
     * @throws IllegalArgumentException if no model is available for the context's language
     */
    public List<Integer> embed(String str, Embedder.Context context) {
        long[] tokenIds = resolve(context.getLanguage()).encode(str).getIds();
        return Arrays.stream(tokenIds)
                .mapToObj(Math::toIntExact)
                .toList();
    }

    /**
     * Embeds {@code str} as a tensor of the requested type by delegating to
     * {@code Embed.asTensor} with this instance as the embedder.
     */
    public Tensor embed(String str, Embedder.Context context, TensorType tensorType) {
        Tensor embedding = Embed.asTensor(str, this, context, tensorType);
        return embedding;
    }

    /**
     * Tokenizes {@code str} with the model resolved for {@code language} and returns the
     * token strings as an unmodifiable list.
     *
     * @throws IllegalArgumentException if no model is available for {@code language}
     */
    public List<String> segment(String str, Language language) {
        // Pass the String[] straight to the varargs factory so the element type stays String.
        String[] tokens = resolve(language).encode(str).getTokens();
        return List.of(tokens);
    }

    /**
     * Decodes the given token ids back to text using the model resolved for the context's language.
     *
     * @throws IllegalArgumentException if no model is available for the context's language
     */
    public String decode(List<Integer> list, Embedder.Context context) {
        long[] tokenIds = toArray(list);
        Language language = context.getLanguage();
        return resolve(language).decode(tokenIds);
    }

    /** Encodes {@code str} using {@link Language#UNKNOWN} as the language. */
    public Encoding encode(String str) {
        Language defaultLanguage = Language.UNKNOWN;
        return encode(str, defaultLanguage);
    }

    /**
     * Encodes {@code str} with the model resolved for {@code language} and wraps the result
     * via {@code Encoding.from}.
     *
     * @throws IllegalArgumentException if no model is available for {@code language}
     */
    public Encoding encode(String str, Language language) {
        ai.djl.huggingface.tokenizers.HuggingFaceTokenizer tokenizer = resolve(language);
        return Encoding.from(tokenizer.encode(str));
    }

    /** Decodes the given token ids using {@link Language#UNKNOWN} as the language. */
    public String decode(long[] jArr) {
        Language defaultLanguage = Language.UNKNOWN;
        return decode(jArr, defaultLanguage);
    }

    /**
     * Decodes the given token ids back to text with the model resolved for {@code language}.
     *
     * @throws IllegalArgumentException if no model is available for {@code language}
     */
    public String decode(long[] jArr, Language language) {
        ai.djl.huggingface.tokenizers.HuggingFaceTokenizer tokenizer = resolve(language);
        return tokenizer.decode(jArr);
    }

    /**
     * Closes every underlying tokenizer and removes the scratch directory.
     * The directory is removed even when closing a tokenizer throws.
     */
    @Override // java.lang.AutoCloseable
    public void close() {
        try {
            for (ai.djl.huggingface.tokenizers.HuggingFaceTokenizer tokenizer : this.models.values()) {
                tokenizer.close();
            }
        } finally {
            IOUtils.recursiveDeleteDir(this.tmpDirectory.toFile());
        }
    }

    /** Component teardown hook; releases the same resources as {@link #close()}. */
    public void deconstruct() {
        this.close();
    }

    /**
     * Reads the tokenizer definition at {@code path} and reports its truncation strategy,
     * padding strategy, max length, stride and pad-to-multiple-of setting.
     *
     * <p>A native tokenizer is created from the file contents for the duration of the call
     * and is deleted exactly once before returning, whether or not reading the settings succeeds.</p>
     *
     * @param path the tokenizer file whose content is read as a string
     * @return the settings reported by the native tokenizer library
     */
    public static ModelInfo getModelInfo(Path path) {
        return withContextClassloader(() -> {
            LibUtils.checkStatus();
            String tokenizerJson = Exceptions.uncheck(() -> Files.readString(path));
            long handle = TokenizersLibrary.LIB.createTokenizerFromString(tokenizerJson);
            try {
                return new ModelInfo(
                        ModelInfo.TruncationStrategy.fromString(TokenizersLibrary.LIB.getTruncationStrategy(handle)),
                        ModelInfo.PaddingStrategy.fromString(TokenizersLibrary.LIB.getPaddingStrategy(handle)),
                        TokenizersLibrary.LIB.getMaxLength(handle),
                        TokenizersLibrary.LIB.getStride(handle),
                        TokenizersLibrary.LIB.getPadToMultipleOf(handle));
            } finally {
                // finally (rather than catch-and-rethrow) so the handle can never be deleted twice
                TokenizersLibrary.LIB.deleteTokenizer(handle);
            }
        });
    }

    /**
     * Picks the tokenizer to use for {@code language}.
     *
     * <p>When the only registered model is the one for {@link Language#UNKNOWN}, it is used
     * for every language; otherwise an exact match on {@code language} is required.</p>
     *
     * @throws IllegalArgumentException if no model matches
     */
    private ai.djl.huggingface.tokenizers.HuggingFaceTokenizer resolve(Language language) {
        boolean onlyDefaultModel = this.models.size() == 1 && this.models.containsKey(Language.UNKNOWN);
        Language key = onlyDefaultModel ? Language.UNKNOWN : language;
        ai.djl.huggingface.tokenizers.HuggingFaceTokenizer tokenizer = this.models.get(key);
        if (tokenizer == null) {
            throw new IllegalArgumentException("No model for language " + language);
        }
        return tokenizer;
    }

    /**
     * Runs {@code supplier} with the current thread's context class loader set to the loader of
     * this class, restoring the previous loader afterwards regardless of outcome.
     *
     * @return whatever {@code supplier} returns
     */
    private static <R> R withContextClassloader(Supplier<R> supplier) {
        Thread current = Thread.currentThread();
        ClassLoader previous = current.getContextClassLoader();
        current.setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader());
        try {
            return supplier.get();
        } finally {
            current.setContextClassLoader(previous);
        }
    }

    /** Converts the numbers to a {@code long[]} in iteration order, using {@link Number#longValue()}. */
    private static long[] toArray(Collection<? extends Number> collection) {
        long[] values = new long[collection.size()];
        int next = 0;
        for (Number number : collection) {
            values[next++] = number.longValue();
        }
        return values;
    }

    // Set once when this class is initialized.
    // NOTE(review): presumably read by the DJL library to disable its usage tracking - confirm against DJL docs.
    static {
        System.setProperty("OPT_OUT_TRACKING", "true");
    }
}
