package com.linkedin.transport.test.hive;

import com.google.common.base.Preconditions;
import com.linkedin.transport.api.StdFactory;
import com.linkedin.transport.api.udf.StdUDF;
import com.linkedin.transport.api.udf.TopLevelStdUDF;
import com.linkedin.transport.hive.HiveFactory;
import com.linkedin.transport.hive.typesystem.HiveBoundVariables;
import com.linkedin.transport.test.hive.udf.MapFromEntriesWrapper;
import com.linkedin.transport.test.spi.SqlFunctionCallGenerator;
import com.linkedin.transport.test.spi.SqlStdTester;
import com.linkedin.transport.test.spi.ToPlatformTestOutputConverter;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.FunctionInfo;
import org.apache.hadoop.hive.ql.exec.Registry;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.session.SessionState;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
import org.apache.hive.service.cli.CLIService;
import org.apache.hive.service.cli.ColumnDescriptor;
import org.apache.hive.service.cli.HiveSQLException;
import org.apache.hive.service.cli.OperationHandle;
import org.apache.hive.service.cli.RowSet;
import org.apache.hive.service.cli.SessionHandle;
import org.apache.hive.service.server.HiveServer2;
import org.testng.Assert;

/* loaded from: input_file:com/linkedin/transport/test/hive/HiveTester.class */
/**
 * {@link SqlStdTester} implementation that evaluates UDF calls by issuing {@code SELECT} statements
 * through the {@link CLIService} of a {@link HiveServer2} instance created and initialized inside
 * this JVM.
 *
 * <p>The constructor builds the server, opens a session and obtains the writable function
 * {@link Registry}; {@link #setup(Map)} then registers the UDFs under test in that registry.
 */
public class HiveTester implements SqlStdTester {
    private CLIService _client;
    private SessionHandle _sessionHandle;
    private Registry _functionRegistry;
    // Handle to Registry#addFunction(String, FunctionInfo), looked up with getDeclaredMethod and
    // forced accessible -- presumably because the method is not public; verify against the Hive
    // version in use.
    private Method _functionRegistryAddFunctionMethod;
    private final StdFactory _stdFactory = new HiveFactory(new HiveBoundVariables());
    private final SqlFunctionCallGenerator _sqlFunctionCallGenerator = new HiveSqlFunctionCallGenerator();
    private final ToPlatformTestOutputConverter _platformOutputDataConverter = new ToHiveTestOutputConverter();

    /** Creates the tester and eagerly initializes the local Hive server and session. */
    public HiveTester() {
        createHiveServer();
    }

    /**
     * Initializes a {@link HiveServer2} with a default {@link HiveConf}, locates its
     * {@link CLIService}, opens a session, registers the {@code map_from_entries} helper UDF and
     * caches a reflective handle to the registry's {@code addFunction} method.
     *
     * @throws NullPointerException if the server exposes no {@link CLIService}
     * @throws RuntimeException wrapping a {@link HiveSQLException} or {@link NoSuchMethodException}
     *     raised while opening the session or resolving {@code addFunction}
     */
    private void createHiveServer() {
        HiveServer2 hiveServer2 = new HiveServer2();
        hiveServer2.init(new HiveConf());
        // Iterate as Object so the instanceof check is meaningful regardless of the declared
        // element type of getServices(). As before, the last matching service wins.
        for (Object service : hiveServer2.getServices()) {
            if (service instanceof CLIService) {
                this._client = (CLIService) service;
            }
        }
        Preconditions.checkNotNull(this._client, "CLI service not found in local Hive server");
        try {
            this._sessionHandle = this._client.openSession((String) null, (String) null, (Map<String, String>) null);
            this._functionRegistry = SessionState.getRegistryForWrite();
            this._functionRegistry.registerGenericUDF("map_from_entries", MapFromEntriesWrapper.class, new FunctionInfo.FunctionResource[0]);
            this._functionRegistryAddFunctionMethod = this._functionRegistry.getClass().getDeclaredMethod("addFunction", String.class, FunctionInfo.class);
            this._functionRegistryAddFunctionMethod.setAccessible(true);
        } catch (HiveSQLException | NoSuchMethodException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Registers every UDF in {@code map} with the Hive function registry.
     *
     * <p>For each entry a {@link HiveTestStdUDFWrapper} is created from the top-level class and its
     * implementations. The function name is obtained by instantiating the first implementation
     * class through its no-arg constructor and calling {@link TopLevelStdUDF#getFunctionName()}.
     *
     * @param map top-level UDF class to the list of its implementation classes; each list must
     *     contain at least one class with an accessible no-arg constructor
     * @throws RuntimeException if an implementation cannot be instantiated or the registry call fails
     */
    public void setup(Map<Class<? extends TopLevelStdUDF>, List<Class<? extends StdUDF>>> map) {
        map.forEach((topLevelUdfClass, implementations) -> {
            HiveTestStdUDFWrapper wrapper = new HiveTestStdUDFWrapper(topLevelUdfClass, implementations);
            try {
                Class<? extends StdUDF> firstImplementation = implementations.get(0);
                Object udfInstance = firstImplementation.getConstructor().newInstance();
                String functionName = ((TopLevelStdUDF) udfInstance).getFunctionName();
                FunctionInfo functionInfo = new FunctionInfo(false, functionName, wrapper, new FunctionInfo.FunctionResource[0]);
                this._functionRegistryAddFunctionMethod.invoke(this._functionRegistry, functionName, functionInfo);
            } catch (IllegalAccessException | InstantiationException | NoSuchMethodException | InvocationTargetException e) {
                throw new RuntimeException("Error registering UDF " + topLevelUdfClass.getName() + " with Hive Server", e);
            }
        });
    }

    /** Returns the {@link HiveFactory}-backed {@link StdFactory} created with this tester. */
    public StdFactory getStdFactory() {
        return this._stdFactory;
    }

    /** Returns the generator used to render function calls as Hive SQL. */
    public SqlFunctionCallGenerator getSqlFunctionCallGenerator() {
        return this._sqlFunctionCallGenerator;
    }

    /** Returns the converter that maps test output data to its Hive representation. */
    public ToPlatformTestOutputConverter getToPlatformTestOutputConverter() {
        return this._platformOutputDataConverter;
    }

    /**
     * Executes {@code SELECT <str>} and asserts both the value and the type of the single result cell.
     *
     * @param str function call expression appended after {@code SELECT}
     * @param obj expected value of the first column of the single result row
     * @param obj2 expected output type; must be an {@link ObjectInspector}, whose type info is
     *     compared with the type of the result column
     * @throws RuntimeException if the statement has no result set, if the result is not exactly one
     *     row by one column, or (unwrapped via {@link #fetchUnderlyingException(Throwable)}) if Hive
     *     fails to execute the query
     */
    public void assertFunctionCall(String str, Object obj, Object obj2) {
        String query = "SELECT " + str;
        try {
            OperationHandle operation = this._client.executeStatement(this._sessionHandle, query, (Map<String, String>) null);
            if (!operation.hasResultSet()) {
                throw new RuntimeException("Query did not return any rows. Query: \"" + query + "\"");
            }
            RowSet rows = this._client.fetchResults(operation);
            // Require exactly 1x1: an empty result would otherwise fail below with an opaque
            // NoSuchElementException / index error instead of this descriptive message.
            if (rows.numRows() != 1 || rows.numColumns() != 1) {
                throw new RuntimeException("Expected 1 row and 1 column in query output. Received " + rows.numRows() + " rows and " + rows.numColumns() + " columns.\nQuery: \"" + query + "\"");
            }
            Object actualValue = ((Object[]) rows.iterator().next())[0];
            Assert.assertEquals(actualValue, obj, "UDF output does not match");

            ColumnDescriptor column = (ColumnDescriptor) this._client.getResultSetMetadata(operation).getColumnDescriptors().get(0);
            // Locale.ROOT keeps the lower-casing stable under locales with special casing rules
            // (e.g. Turkish dotless i), which would otherwise corrupt the type name.
            String actualTypeName = column.getTypeName().toLowerCase(Locale.ROOT);
            Assert.assertEquals(TypeInfoUtils.getTypeInfoFromTypeString(actualTypeName), TypeInfoUtils.getTypeInfoFromObjectInspector((ObjectInspector) obj2), "UDF output type does not match");
        } catch (HiveSQLException e) {
            throw fetchUnderlyingException(e);
        }
    }

    /**
     * Strips Hive wrapper exceptions to surface the underlying failure.
     *
     * <p>A {@link HiveSQLException} or {@link HiveException} is replaced by its cause, as is an
     * {@link IOException} whose cause is a {@link HiveException}. Unwrapping stops at the first
     * throwable that matches neither rule or that has no cause, so the last available throwable is
     * never discarded in favour of {@code null}.
     *
     * @param th the throwable to unwrap
     * @return the unwrapped throwable itself if it is a {@link RuntimeException}, otherwise a new
     *     {@link RuntimeException} wrapping it
     */
    private RuntimeException fetchUnderlyingException(Throwable th) {
        Throwable current = th;
        while (current != null && current.getCause() != null) {
            boolean hiveWrapper = (current instanceof HiveSQLException) || (current instanceof HiveException);
            boolean ioWrappingHive = (current instanceof IOException) && (current.getCause() instanceof HiveException);
            if (!hiveWrapper && !ioWrappingHive) {
                break;
            }
            current = current.getCause();
        }
        return current instanceof RuntimeException ? (RuntimeException) current : new RuntimeException(current);
    }
}
