package com.yahoo.vespa.model.application.validation.change;

import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.model.api.ApplicationClusterEndpoint;
import com.yahoo.config.model.api.ConfigChangeAction;
import com.yahoo.config.model.api.ContainerEndpoint;
import com.yahoo.config.model.api.OnnxModelCost;
import com.yahoo.config.model.api.OnnxModelOptions;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.config.model.deploy.TestProperties;
import com.yahoo.config.model.provision.InMemoryProvisioner;
import com.yahoo.config.model.test.MockApplicationPackage;
import com.yahoo.config.provision.NodeResources;
import com.yahoo.vespa.model.VespaModel;
import com.yahoo.vespa.model.application.validation.ValidationTester;
import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithMockPkg;
import java.net.URI;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link RestartOnDeployForOnnxModelChangesValidator}.
 *
 * <p>Each test builds a previous and a next {@link VespaModel} using a stubbed {@link OnnxModelCost}.
 * It then checks which config change actions the validator emits when the following differ between the two models:
 * <ul>
 *   <li>the estimated cost,</li>
 *   <li>the model hash,</li>
 *   <li>the ONNX execution mode, or</li>
 *   <li>the set of referenced models.</li>
 * </ul>
 */
public class RestartOnDeployForOnnxModelChangesValidatorTest {

    private static final long DEFAULT_COST = 635241309L;
    private static final long CHANGED_COST = 603479243L;
    private static final long DEFAULT_HASH = 0L;
    private static final long CHANGED_HASH = 123L;

    private static final String DEFAULT_EXECUTION_MODE = "parallel";
    private static final String DEFAULT_MODEL_ID = "e5-base-v2";

    /** Hosted services.xml template; placeholders are (model id, model id, onnx execution mode). */
    private static final String HOSTED_SERVICES = """
            <services version='1.0'>
              <container id='cluster1' version='1.0'>
                <component id="hf-embedder" type="hugging-face-embedder">
                  <transformer-model model-id="%s" url="https://my/url/%s.onnx"/>
                  <tokenizer-model model-id="e5-base-v2-vocab" path="app/tokenizer.json"/>
                  <onnx-execution-mode>%s</onnx-execution-mode>
                </component>
                <nodes count='1'>
                  <resources vcpu='1' memory='2Gb' disk='25Gb'/>
                </nodes>
              </container>
            </services>
            """;

    /** Deployment spec used for the hosted application package. */
    private static final String HOSTED_DEPLOYMENT_SPEC = """
            <deployment version='1.0' empty-host-ttl='1d'>
              <instance id='default'>
                <prod>
                  <region>us-east-1</region>
                  <region empty-host-ttl='0m'>us-north-1</region>
                  <region>us-west-1</region>
                </prod>
              </instance>
            </deployment>
            """;

    /** Self-hosted services.xml template; placeholders are (model id, model id, onnx execution mode). */
    private static final String NON_HOSTED_SERVICES = """
            <services version='1.0'>
              <container id='cluster1' version='1.0'>
                <http>
                  <server id='server1' port='8080'/>
                </http>
                <component id="hf-embedder" type="hugging-face-embedder">
                  <transformer-model model-id="%s" url="https://my/url/%s.onnx"/>
                  <tokenizer-model model-id="e5-base-v2-vocab" path="app/tokenizer.json"/>
                  <onnx-execution-mode>%s</onnx-execution-mode>
                </component>
              </container>
            </services>
            """;

    @Test
    void validate_no_changes() {
        Assertions.assertEquals(0, validateModel(createModel(), createModel()).size());
    }

    @Test
    void validate_changed_estimated_cost() {
        List<ConfigChangeAction> actions = validateModel(createModel(onnxModelCost(DEFAULT_COST, DEFAULT_HASH)),
                                                         createModel(onnxModelCost(CHANGED_COST, DEFAULT_HASH)));
        Assertions.assertEquals(1, actions.size());
        Assertions.assertTrue(actions.get(0).validationId().isEmpty());
        Assertions.assertEquals("Onnx model 'https://data.vespa.oath.cloud/onnx_models/e5-base-v2/model.onnx' has changed (estimated cost), need to restart services in container cluster 'cluster1'",
                                actions.get(0).getMessage());
    }

    /** In a self-hosted deployment, a changed estimated cost is expected to produce no actions. */
    @Test
    void validate_changed_estimated_cost_non_hosted() {
        List<ConfigChangeAction> actions = validateModel(createModel(onnxModelCost(DEFAULT_COST, DEFAULT_HASH), false),
                                                         createModel(onnxModelCost(CHANGED_COST, DEFAULT_HASH), false),
                                                         false);
        Assertions.assertEquals(0, actions.size());
    }

    @Test
    void validate_changed_hash() {
        List<ConfigChangeAction> actions = validateModel(createModel(),
                                                         createModel(onnxModelCost(DEFAULT_COST, CHANGED_HASH)));
        Assertions.assertEquals(1, actions.size());
        assertStartsWith("Onnx model 'https://data.vespa.oath.cloud/onnx_models/e5-base-v2/model.onnx' has changed (model hash)", actions);
    }

    @Test
    void validate_changed_option() {
        List<ConfigChangeAction> actions = validateModel(createModel(),
                                                         createModel(onnxModelCost(), true, "sequential"));
        Assertions.assertEquals(1, actions.size());
        assertStartsWith("Onnx model 'https://data.vespa.oath.cloud/onnx_models/e5-base-v2/model.onnx' has changed (model option(s))", actions);
    }

    @Test
    void validate_changed_model_set() {
        List<ConfigChangeAction> actions = validateModel(createModel(),
                                                         createModel(onnxModelCost(), true, DEFAULT_EXECUTION_MODE, "e5-small-v2"));
        Assertions.assertEquals(1, actions.size());
        assertStartsWith("Onnx model set has changed from [https://data.vespa.oath.cloud/onnx_models/e5-base-v2/model.onnx] to [https://data.vespa.oath.cloud/onnx_models/e5-small-v2/model.onnx", actions);
    }

    /** Runs the validator on the change from {@code previous} to {@code next} in a hosted deployment. */
    private static List<ConfigChangeAction> validateModel(VespaModel previous, VespaModel next) {
        return validateModel(previous, next, true);
    }

    /**
     * Runs the validator on the change from {@code previous} to {@code next}.
     *
     * @param previous the model registered as the previous model of the deploy state
     * @param next     the model being validated
     * @param hosted   whether the deploy state is built for hosted Vespa
     * @return the config change actions reported by the validator
     */
    private static List<ConfigChangeAction> validateModel(VespaModel previous, VespaModel next, boolean hosted) {
        return ValidationTester.validateChanges(new RestartOnDeployForOnnxModelChangesValidator(),
                                                next,
                                                deployStateBuilder(hosted).previousModel(previous).build());
    }

    /** Returns a stub cost calculator with the default cost and hash. */
    private static OnnxModelCost onnxModelCost() {
        return onnxModelCost(DEFAULT_COST, DEFAULT_HASH);
    }

    /**
     * Returns a stub {@link OnnxModelCost} whose calculators behave as follows:
     * <ul>
     *   <li>The aggregated cost is always {@code estimatedCost}.</li>
     *   <li>Every model registered by URI is recorded with the given {@code estimatedCost} and {@code hash}.</li>
     *   <li>Models registered as application files are ignored.</li>
     * </ul>
     */
    private static OnnxModelCost onnxModelCost(long estimatedCost, long hash) {
        return (applicationPackage, applicationId, clusterId) -> new OnnxModelCost.Calculator() {

            private final Map<String, OnnxModelCost.ModelInfo> models = new HashMap<>();
            private boolean restartOnDeploy = false;

            @Override
            public long aggregatedModelCostInBytes() {
                return estimatedCost;
            }

            // Application-file models are intentionally not recorded by this stub.
            @Override
            public void registerModel(ApplicationFile applicationFile, OnnxModelOptions onnxModelOptions) {
            }

            @Override
            public void registerModel(URI uri, OnnxModelOptions onnxModelOptions) {
                String key = uri.toString();
                models.put(key, new OnnxModelCost.ModelInfo(key, estimatedCost, hash, onnxModelOptions));
            }

            @Override
            public Map<String, OnnxModelCost.ModelInfo> models() {
                return models;
            }

            @Override
            public void setRestartOnDeploy() {
                restartOnDeploy = true;
            }

            @Override
            public boolean restartOnDeploy() {
                return restartOnDeploy;
            }

            // Nothing to persist in the stub.
            @Override
            public void store() {
            }
        };
    }

    /** Hosted model using the default cost stub, execution mode and model id. */
    private static VespaModel createModel() {
        return createModel(onnxModelCost(), true);
    }

    /** Hosted model using the given cost stub with the default execution mode and model id. */
    private static VespaModel createModel(OnnxModelCost onnxModelCost) {
        return createModel(onnxModelCost, true);
    }

    private static VespaModel createModel(OnnxModelCost onnxModelCost, boolean hosted) {
        return createModel(onnxModelCost, hosted, DEFAULT_EXECUTION_MODE);
    }

    private static VespaModel createModel(OnnxModelCost onnxModelCost, boolean hosted, String executionMode) {
        return createModel(onnxModelCost, hosted, executionMode, DEFAULT_MODEL_ID);
    }

    /**
     * Builds a model with a single container cluster 'cluster1' that contains a hugging-face embedder.
     *
     * @param onnxModelCost the cost stub installed in the deploy state
     * @param hosted        whether to build a hosted or a self-hosted model
     * @param executionMode value of the embedder's {@code onnx-execution-mode} element
     * @param modelId       model id of the embedder's transformer model
     */
    private static VespaModel createModel(OnnxModelCost onnxModelCost, boolean hosted, String executionMode, String modelId) {
        DeployState.Builder builder = deployStateBuilder(hosted).onnxModelCost(onnxModelCost);
        return hosted
               ? hostedModel(builder, executionMode, modelId)
               : nonHostedModel(builder, executionMode, modelId);
    }

    private static VespaModel hostedModel(DeployState.Builder builder, String executionMode, String modelId) {
        var applicationPackage = new MockApplicationPackage.Builder()
                .withServices(HOSTED_SERVICES.formatted(modelId, modelId, executionMode))
                .withDeploymentSpec(HOSTED_DEPLOYMENT_SPEC)
                .build();
        return new VespaModelCreatorWithMockPkg(applicationPackage).create(builder);
    }

    private static VespaModel nonHostedModel(DeployState.Builder builder, String executionMode, String modelId) {
        // NOTE(review): first argument is presumably the hosts XML; null means none is supplied.
        return new VespaModelCreatorWithMockPkg(null, NON_HOSTED_SERVICES.formatted(modelId, modelId, executionMode))
                .create(builder);
    }

    /**
     * Returns a deploy state builder whose properties have hosted Vespa set to {@code hosted}.
     * For hosted deployments it also configures:
     * <ul>
     *   <li>a zone-scoped endpoint for 'cluster1', and</li>
     *   <li>an in-memory host provisioner.</li>
     * </ul>
     */
    private static DeployState.Builder deployStateBuilder(boolean hosted) {
        DeployState.Builder builder = new DeployState.Builder()
                .properties(new TestProperties().setHostedVespa(hosted));
        if (hosted) {
            builder.endpoints(Set.of(new ContainerEndpoint("cluster1",
                                                           ApplicationClusterEndpoint.Scope.zone,
                                                           List.of("tc.example.com"))))
                   .modelHostProvisioner(new InMemoryProvisioner(5, new NodeResources(1.0, 2.0, 25.0, 0.3), true));
        }
        return builder;
    }

    /**
     * Asserts that the first action's message starts with {@code expectedPrefix}.
     * On failure, the actual message is reported.
     */
    private static void assertStartsWith(String expectedPrefix, List<ConfigChangeAction> actions) {
        Assertions.assertFalse(actions.isEmpty(), "Expected at least one config change action");
        String message = actions.get(0).getMessage();
        Assertions.assertTrue(message.startsWith(expectedPrefix),
                              () -> "Expected message starting with '" + expectedPrefix + "', but was '" + message + "'");
    }
}
