package au.csiro.pathling.fhirpath.function;

import au.csiro.pathling.errors.InvalidUserInputError;
import au.csiro.pathling.fhirpath.element.DecimalPath;
import au.csiro.pathling.fhirpath.element.ElementPath;
import au.csiro.pathling.fhirpath.element.IntegerPath;
import au.csiro.pathling.fhirpath.parser.ParserContext;
import au.csiro.pathling.test.SpringBootUnitTest;
import au.csiro.pathling.test.assertions.Assertions;
import au.csiro.pathling.test.builders.DatasetBuilder;
import au.csiro.pathling.test.builders.ElementPathBuilder;
import au.csiro.pathling.test.builders.ParserContextBuilder;
import ca.uhn.fhir.context.FhirContext;
import java.math.BigDecimal;
import java.util.Collections;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.hl7.fhir.r4.model.Enumerations;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;

@SpringBootUnitTest
class SumFunctionTest {

    /**
     * Decimal column type used for the decimal test data. Spark's no-argument
     * {@code DataTypes.createDecimalType()} is {@code decimal(10,0)}, which would silently round
     * values such as 5.5 and -2.50 to integers before they are summed, so the test would never
     * exercise fractional addition. Precision 32 / scale 6 is presumed to match Pathling's decimal
     * encoding — NOTE(review): confirm against DecimalPath.
     */
    private static final int DECIMAL_PRECISION = 32;
    private static final int DECIMAL_SCALE = 6;

    @Autowired
    SparkSession spark;

    @Autowired
    FhirContext fhirContext;

    /**
     * Sums a non-singular integer input per resource: null elements are ignored, and a resource
     * whose only row is null yields a null total.
     */
    @Test
    void returnsCorrectIntegerResult() {
        final ElementPath input = new ElementPathBuilder(spark)
                .fhirType(Enumerations.FHIRDefinedType.INTEGER)
                .dataset(new DatasetBuilder(spark)
                        .withIdColumn()
                        .withEidColumn()
                        .withColumn(DataTypes.IntegerType)
                        .withRow("observation-1", DatasetBuilder.makeEid(0), 3)
                        .withRow("observation-1", DatasetBuilder.makeEid(1), 5)
                        .withRow("observation-1", DatasetBuilder.makeEid(2), 7)
                        .withRow("observation-2", null, null)
                        .withRow("observation-3", DatasetBuilder.makeEid(0), -1)
                        .withRow("observation-3", DatasetBuilder.makeEid(1), null)
                        .build())
                .idAndEidAndValueColumns()
                .expression("valueInteger")
                .singular(false)
                .build();
        final ParserContext parserContext = parserContextGroupedBy(input);

        Assertions.assertThat(NamedFunction.getInstance("sum").invoke(sumInput(parserContext, input)))
                .hasExpression("valueInteger.sum()")
                .isSingular()
                .isElementPath(IntegerPath.class)
                .selectResult()
                .hasRows(new DatasetBuilder(spark)
                        .withIdColumn()
                        .withColumn(DataTypes.IntegerType)
                        .withRow("observation-1", 15)
                        .withRow("observation-2", null)
                        .withRow("observation-3", -1)
                        .build());
    }

    /**
     * Sums a non-singular decimal input per resource, including fractional values, so that the
     * result (15.5) could not be produced by summing rounded integers.
     */
    @Test
    void returnsCorrectDecimalResult() {
        final ElementPath input = new ElementPathBuilder(spark)
                .fhirType(Enumerations.FHIRDefinedType.DECIMAL)
                .dataset(new DatasetBuilder(spark)
                        .withIdColumn()
                        .withEidColumn()
                        .withColumn(DataTypes.createDecimalType(DECIMAL_PRECISION, DECIMAL_SCALE))
                        .withRow("observation-1", DatasetBuilder.makeEid(0), new BigDecimal("3.0"))
                        .withRow("observation-1", DatasetBuilder.makeEid(1), new BigDecimal("5.5"))
                        .withRow("observation-1", DatasetBuilder.makeEid(2), new BigDecimal("7"))
                        .withRow("observation-2", null, null)
                        .withRow("observation-3", DatasetBuilder.makeEid(0), new BigDecimal("-2.50"))
                        .build())
                .idAndEidAndValueColumns()
                .expression("valueDecimal")
                .singular(false)
                .build();
        final ParserContext parserContext = parserContextGroupedBy(input);

        Assertions.assertThat(NamedFunction.getInstance("sum").invoke(sumInput(parserContext, input)))
                .hasExpression("valueDecimal.sum()")
                .isSingular()
                .isElementPath(DecimalPath.class)
                .selectResult()
                .hasRows(new DatasetBuilder(spark)
                        .withIdColumn()
                        .withColumn(DataTypes.createDecimalType(DECIMAL_PRECISION, DECIMAL_SCALE))
                        .withRow("observation-1", new BigDecimal("15.5"))
                        .withRow("observation-2", null)
                        .withRow("observation-3", new BigDecimal("-2.5"))
                        .build());
    }

    /**
     * A non-numeric input must be rejected with an {@link InvalidUserInputError}. A real parser
     * context is supplied so the assertion does not depend on the function validating its input
     * before it touches the context (JUnit creates a fresh instance per test, so no context from
     * another test is available here).
     */
    @Test
    void throwsErrorIfInputNotNumeric() {
        final ElementPath input = new ElementPathBuilder(spark)
                .fhirType(Enumerations.FHIRDefinedType.STRING)
                .expression("valueString")
                .build();
        final ParserContext parserContext = new ParserContextBuilder(spark, fhirContext).build();
        final NamedFunctionInput functionInput = sumInput(parserContext, input);
        final NamedFunction sumFunction = NamedFunction.getInstance("sum");

        final InvalidUserInputError error = org.junit.jupiter.api.Assertions.assertThrows(
                InvalidUserInputError.class, () -> sumFunction.invoke(functionInput));
        org.junit.jupiter.api.Assertions.assertEquals(
                "Input to sum function must be numeric: valueString", error.getMessage());
    }

    /** Builds a parser context that groups aggregation results by the input's ID column. */
    private ParserContext parserContextGroupedBy(final ElementPath input) {
        return new ParserContextBuilder(spark, fhirContext)
                .groupingColumns(Collections.singletonList(input.getIdColumn()))
                .build();
    }

    /** Wraps {@code input} as the argument-less input to a {@code sum()} invocation. */
    private static NamedFunctionInput sumInput(final ParserContext parserContext,
            final ElementPath input) {
        return new NamedFunctionInput(parserContext, input, Collections.emptyList());
    }
}
