package org.apache.paimon.spark.extensions;

import java.math.BigDecimal;
import java.sql.Timestamp;
import java.time.Instant;
import java.util.Arrays;
import org.apache.paimon.spark.catalyst.plans.logical.PaimonCallArgument;
import org.apache.paimon.spark.catalyst.plans.logical.PaimonCallStatement;
import org.apache.paimon.spark.catalyst.plans.logical.PaimonNamedArgument;
import org.apache.paimon.spark.catalyst.plans.logical.PaimonPositionalArgument;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.expressions.Literal$;
import org.apache.spark.sql.catalyst.parser.ParseException;
import org.apache.spark.sql.catalyst.parser.ParserInterface;
import org.apache.spark.sql.catalyst.parser.extensions.PaimonParseException;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import scala.Option;
import scala.collection.JavaConverters;

/* loaded from: input_file:org/apache/paimon/spark/extensions/CallStatementParserTest.class */
/**
 * Parser tests for the Paimon {@code CALL} statement SQL extension.
 *
 * <p>Each test runs against a fresh local {@link SparkSession} configured with {@link
 * PaimonSparkSessionExtensions}. It parses SQL through that session's parser and checks the
 * resulting {@link PaimonCallStatement}:
 *
 * <ul>
 *   <li>the procedure name parts;
 *   <li>the number of arguments;
 *   <li>for each argument, whether it is named or positional, its name, and the literal expression
 *       it parsed to.
 * </ul>
 */
public class CallStatementParserTest {

    private SparkSession spark = null;
    private ParserInterface parser = null;

    /**
     * Stops any active or default session left over from an earlier test. Then starts a new local
     * session with the Paimon SQL extensions installed and captures its SQL parser.
     */
    @BeforeEach
    public void startSparkSession() {
        // Typed Option (not the raw type) so get() needs no unchecked cast.
        Option<SparkSession> existing =
                SparkSession.getActiveSession().orElse(SparkSession::getDefaultSession);
        if (existing.isDefined()) {
            existing.get().stop();
        }
        SparkSession.clearActiveSession();
        spark =
                SparkSession.builder()
                        .master("local[2]")
                        .config(
                                "spark.sql.extensions",
                                PaimonSparkSessionExtensions.class.getName())
                        .getOrCreate();
        parser = spark.sessionState().sqlParser();
    }

    /** Stops the session started by {@link #startSparkSession()}, if any, and drops references. */
    @AfterEach
    public void stopSparkSession() {
        if (spark == null) {
            return;
        }
        try {
            spark.stop();
        } finally {
            // Clear the fields even if stop() throws, so no stale session/parser is left behind.
            spark = null;
            parser = null;
        }
    }

    @Test
    public void testCallWithNamedArguments() throws ParseException {
        PaimonCallStatement call =
                (PaimonCallStatement)
                        parser.parsePlan(
                                "CALL catalog.system.named_args_func(arg1 => 1, arg2 => 'test', arg3 => true)");
        Assertions.assertThat(JavaConverters.seqAsJavaList(call.name()))
                .isEqualTo(Arrays.asList("catalog", "system", "named_args_func"));
        Assertions.assertThat(call.args().size()).isEqualTo(3);
        assertArgument(call, 0, "arg1", 1, DataTypes.IntegerType);
        assertArgument(call, 1, "arg2", "test", DataTypes.StringType);
        assertArgument(call, 2, "arg3", true, DataTypes.BooleanType);
    }

    @Test
    public void testCallWithPositionalArguments() throws ParseException {
        PaimonCallStatement call =
                (PaimonCallStatement)
                        parser.parsePlan(
                                "CALL catalog.system.positional_args_func(1, '${spark.sql.extensions}', 2L, true, 3.0D, 4.0e1,500e-1BD, TIMESTAMP '2017-02-03T10:37:30.00Z')");
        Assertions.assertThat(JavaConverters.seqAsJavaList(call.name()))
                .isEqualTo(Arrays.asList("catalog", "system", "positional_args_func"));
        Assertions.assertThat(call.args().size()).isEqualTo(8);
        assertArgument(call, 0, 1, DataTypes.IntegerType);
        // '${spark.sql.extensions}' is expected to be substituted with the configured value.
        assertArgument(
                call, 1, PaimonSparkSessionExtensions.class.getName(), DataTypes.StringType);
        assertArgument(call, 2, 2L, DataTypes.LongType);
        assertArgument(call, 3, true, DataTypes.BooleanType);
        assertArgument(call, 4, 3.0d, DataTypes.DoubleType);
        assertArgument(call, 5, 40.0d, DataTypes.DoubleType);
        // 500e-1BD has unscaled value 500 and scale 1, i.e. DECIMAL(3, 1).
        assertArgument(
                call, 6, new BigDecimal("500e-1"), DataTypes.createDecimalType(3, 1));
        assertArgument(
                call,
                7,
                Timestamp.from(Instant.parse("2017-02-03T10:37:30.00Z")),
                DataTypes.TimestampType);
    }

    @Test
    public void testCallWithMixedArguments() throws ParseException {
        PaimonCallStatement call =
                (PaimonCallStatement)
                        parser.parsePlan("CALL catalog.system.mixed_function(arg1 => 1, 'test')");
        Assertions.assertThat(JavaConverters.seqAsJavaList(call.name()))
                .isEqualTo(Arrays.asList("catalog", "system", "mixed_function"));
        Assertions.assertThat(call.args().size()).isEqualTo(2);
        assertArgument(call, 0, "arg1", 1, DataTypes.IntegerType);
        assertArgument(call, 1, "test", DataTypes.StringType);
    }

    @Test
    public void testCallWithParseException() {
        Assertions.assertThatThrownBy(() -> parser.parsePlan("CALL catalog.system func abc"))
                .isInstanceOf(PaimonParseException.class)
                .hasMessageContaining("missing '(' at 'func'");
    }

    /**
     * Asserts that argument {@code index} of {@code call} is positional and equals the literal
     * built from {@code expectedValue} and {@code expectedType}.
     */
    private void assertArgument(
            PaimonCallStatement call, int index, Object expectedValue, DataType expectedType) {
        assertArgument(call, index, null, expectedValue, expectedType);
    }

    /**
     * Asserts the shape and value of argument {@code index} of {@code call}.
     *
     * @param call the parsed statement
     * @param index zero-based argument position
     * @param expectedName expected argument name, or {@code null} to require a positional argument
     * @param expectedValue value passed to {@code Literal.create} to build the expected expression
     * @param expectedType Spark type passed to {@code Literal.create}
     */
    private void assertArgument(
            PaimonCallStatement call,
            int index,
            String expectedName,
            Object expectedValue,
            DataType expectedType) {
        PaimonCallArgument arg = call.args().apply(index);
        if (expectedName == null) {
            Assertions.assertThat(arg).isInstanceOf(PaimonPositionalArgument.class);
        } else {
            PaimonNamedArgument named = assertCast(arg, PaimonNamedArgument.class);
            Assertions.assertThat(named.name()).isEqualTo(expectedName);
        }
        Assertions.assertThat(arg.expr())
                .isEqualTo(Literal$.MODULE$.create(expectedValue, expectedType));
    }

    /** Asserts that {@code obj} is an instance of {@code cls} and returns it cast to that type. */
    private <T> T assertCast(Object obj, Class<T> cls) {
        Assertions.assertThat(obj).isInstanceOf(cls);
        return cls.cast(obj);
    }
}
