package com.linkedin.spark.datasources.tfrecord;

import com.linkedin.spark.shaded.org.tensorflow.example.Example;
import com.linkedin.spark.shaded.org.tensorflow.example.Feature;
import com.linkedin.spark.shaded.org.tensorflow.example.FeatureList;
import com.linkedin.spark.shaded.org.tensorflow.example.SequenceExample;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.ArrayType$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.FloatType$;
import org.apache.spark.sql.types.LongType$;
import org.apache.spark.sql.types.NullType$;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.StructType$;
import scala.$less$colon$less$;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Some;
import scala.Tuple2;
import scala.collection.IterableOnceOps;
import scala.collection.mutable.Iterable;
import scala.collection.mutable.Map;
import scala.collection.mutable.Map$;
import scala.jdk.CollectionConverters$;
import scala.reflect.ClassTag$;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.TypeTags;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.reflect.runtime.package$;
import scala.runtime.BoxedUnit;

/* compiled from: TensorFlowInferSchema.scala */
/* loaded from: input_file:com/linkedin/spark/datasources/tfrecord/TensorFlowInferSchema$.class */
/**
 * Infers a Spark SQL {@code StructType} from an RDD of TensorFlow {@code Example} or
 * {@code SequenceExample} records.
 *
 * <p>Every record contributes a {@code DataType} per feature name; per-field types are combined
 * across records with {@code findTightestCommonType}. Throughout this class a {@code null}
 * {@code DataType} means "not yet known" (for example, the feature had no values) and yields to
 * any concrete type it is merged with.
 */
public final class TensorFlowInferSchema$ {
    public static final TensorFlowInferSchema$ MODULE$ = new TensorFlowInferSchema$();

    /**
     * Infers the schema of {@code rdd}.
     *
     * @param rdd     records to scan; the element type must be {@code Example} or {@code SequenceExample}
     * @param typeTag Scala type tag of the element type, used to choose the inference strategy
     * @return one {@code StructField} per feature name seen in any record; fields for which no
     *         record supplied a value are typed {@code NullType}
     * @throws IllegalArgumentException if the element type is neither {@code Example} nor
     *         {@code SequenceExample}
     */
    public <T> StructType apply(RDD<T> rdd, TypeTags.TypeTag<T> typeTag) {
        Map map;
        // Shared zero value for rdd.aggregate; seqOp/combOp are allowed to mutate it.
        Map map2 = (Map) Map$.MODULE$.empty();
        Types.TypeApi typeOf = package$.MODULE$.universe().typeOf(typeTag);
        TypeTags universe = package$.MODULE$.universe();
        TypeTags universe2 = package$.MODULE$.universe();
        if (typeOf.$eq$colon$eq(universe.typeOf(universe2.TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: com.linkedin.spark.datasources.tfrecord.TensorFlowInferSchema$$typecreator1$1
            public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                mirror.universe();
                return mirror.staticClass("com.linkedin.spark.shaded.org.tensorflow.example.Example").asType().toTypeConstructor();
            }
        })))) {
            map = (Map) rdd.aggregate(map2, (map3, example) -> {
                return MODULE$.inferExampleRowType(map3, example);
            }, (map4, map5) -> {
                return MODULE$.mergeFieldTypes(map4, map5);
            }, ClassTag$.MODULE$.apply(Map.class));
        } else {
            TypeTags universe3 = package$.MODULE$.universe();
            TypeTags universe4 = package$.MODULE$.universe();
            if (!typeOf.$eq$colon$eq(universe3.typeOf(universe4.TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: com.linkedin.spark.datasources.tfrecord.TensorFlowInferSchema$$typecreator2$1
                public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                    mirror.universe();
                    return mirror.staticClass("com.linkedin.spark.shaded.org.tensorflow.example.SequenceExample").asType().toTypeConstructor();
                }
            })))) {
                throw new IllegalArgumentException("Unsupported recordType: recordType can be Example or SequenceExample");
            }
            map = (Map) rdd.aggregate(map2, (map6, sequenceExample) -> {
                return MODULE$.inferSequenceExampleRowType(map6, sequenceExample);
            }, (map7, map8) -> {
                return MODULE$.mergeFieldTypes(map7, map8);
            }, ClassTag$.MODULE$.apply(Map.class));
        }
        return StructType$.MODULE$.apply(((Iterable) map.map(tuple2 -> {
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            String fieldName = (String) tuple2._1();
            DataType fieldType = (DataType) tuple2._2();
            // A null type means no record ever supplied a value for this field.
            DataType resolvedType = fieldType == null ? NullType$.MODULE$ : fieldType;
            return new StructField(fieldName, resolvedType, StructField$.MODULE$.apply$default$3(), StructField$.MODULE$.apply$default$4());
        })).toSeq());
    }

    /**
     * Folds the context features and the feature lists of one {@code SequenceExample} into
     * {@code map}.
     *
     * @param map             accumulated field types; mutated in place and returned
     * @param sequenceExample record to inspect
     * @return {@code map}
     */
    public Map<String, DataType> inferSequenceExampleRowType(Map<String, DataType> map, SequenceExample sequenceExample) {
        Map<String, DataType> withContext = inferFeatureTypes(map, CollectionConverters$.MODULE$.MapHasAsScala(sequenceExample.getContext().getFeatureMap()).asScala());
        return inferFeatureListTypes(withContext, CollectionConverters$.MODULE$.MapHasAsScala(sequenceExample.getFeatureLists().getFeatureListMap()).asScala());
    }

    /**
     * Folds the features of one {@code Example} into {@code map}.
     *
     * @param map     accumulated field types; mutated in place and returned
     * @param example record to inspect
     * @return {@code map}
     */
    public Map<String, DataType> inferExampleRowType(Map<String, DataType> map, Example example) {
        return inferFeatureTypes(map, CollectionConverters$.MODULE$.MapHasAsScala(example.getFeatures().getFeatureMap()).asScala());
    }

    /**
     * Infers the type of every feature in {@code map2} and merges it into {@code map}.
     *
     * @param map  accumulated field types; mutated in place and returned
     * @param map2 feature name to feature for one record
     * @return {@code map}
     */
    private Map<String, DataType> inferFeatureTypes(Map<String, DataType> map, Map<String, Feature> map2) {
        map2.foreach(tuple2 -> {
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            MODULE$.recordFieldType(map, (String) tuple2._1(), MODULE$.inferField((Feature) tuple2._2()));
            return BoxedUnit.UNIT;
        });
        return map;
    }

    /**
     * Infers the type of every feature list in {@code map2} and merges it into {@code map}.
     *
     * <p>The per-step feature types of a list are combined with {@code findTightestCommonType}.
     * The column type is then one array level deeper than the step type, and always at least
     * {@code array<array<scalar>>}. An array step type {@code array<x>} becomes
     * {@code array<array<x>>}. A scalar step type {@code x} is wrapped twice, giving
     * {@code array<array<x>>}. An empty list, or a list whose steps all have no values, contributes
     * an unknown ({@code null}) type instead of throwing.
     *
     * @param map  accumulated field types; mutated in place and returned
     * @param map2 feature-list name to feature list for one record
     * @return {@code map}
     */
    public Map<String, DataType> inferFeatureListTypes(Map<String, DataType> map, Map<String, FeatureList> map2) {
        map2.foreach(tuple2 -> {
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            String fieldName = (String) tuple2._1();
            // Fold from "unknown" so an empty list yields null rather than failing like reduceLeft.
            DataType stepType = null;
            for (Feature feature : ((FeatureList) tuple2._2()).getFeatureList()) {
                stepType = MODULE$.findTightestCommonType(stepType, MODULE$.inferField(feature));
            }
            DataType columnType;
            if (stepType == null) {
                // Never wrap an unknown type: ArrayType(ArrayType(null)) is not a usable schema.
                columnType = null;
            } else if (stepType instanceof ArrayType) {
                columnType = ArrayType$.MODULE$.apply(stepType);
            } else {
                columnType = ArrayType$.MODULE$.apply(ArrayType$.MODULE$.apply(stepType));
            }
            MODULE$.recordFieldType(map, fieldName, columnType);
            return BoxedUnit.UNIT;
        });
        return map;
    }

    /**
     * Merges {@code fieldType} into the entry for {@code fieldName} in {@code schema}, inserting
     * the entry if absent. Because a missing or {@code null} current type yields to any other type,
     * insert and update collapse into one call.
     */
    private void recordFieldType(Map<String, DataType> schema, String fieldName, DataType fieldType) {
        DataType current = schema.contains(fieldName) ? schema.apply(fieldName) : null;
        schema.update(fieldName, findTightestCommonType(current, fieldType));
    }

    /**
     * Combines two partial schemas (the {@code combOp} of {@code rdd.aggregate}).
     *
     * @return a new map over the union of both key sets, each key mapped to the tightest common
     *         type of its two (possibly absent/{@code null}) types
     */
    public Map<String, DataType> mergeFieldTypes(Map<String, DataType> map, Map<String, DataType> map2) {
        return (Map) Map$.MODULE$.apply(((IterableOnceOps) map.keySet().$plus$plus(map2.keySet()).map(str -> {
            DataType left = (DataType) map.getOrElse(str, () -> {
                return null;
            });
            DataType right = (DataType) map2.getOrElse(str, () -> {
                return null;
            });
            return new Tuple2(str, MODULE$.findTightestCommonType(left, right));
        })).toSeq());
    }

    /**
     * Infers the Spark type of a single feature from its oneof kind and value count.
     *
     * @return the inferred type, or {@code null} if the feature holds no values
     * @throws RuntimeException if the feature has no kind set
     */
    public DataType inferField(Feature feature) {
        switch (feature.getKindCase()) {
            case BYTES_LIST:
                return parseBytesList(feature);
            case FLOAT_LIST:
                return parseFloatList(feature);
            case INT64_LIST:
                return parseInt64List(feature);
            default:
                throw new RuntimeException("unsupported type ...");
        }
    }

    /** Returns null for no values, {@code StringType} for one, {@code array<string>} for more. */
    private DataType parseBytesList(Feature feature) {
        return typeForCount(feature.getBytesList().getValueCount(), StringType$.MODULE$);
    }

    /** Returns null for no values, {@code LongType} for one, {@code array<bigint>} for more. */
    private DataType parseInt64List(Feature feature) {
        return typeForCount(feature.getInt64List().getValueCount(), LongType$.MODULE$);
    }

    /** Returns null for no values, {@code FloatType} for one, {@code array<float>} for more. */
    private DataType parseFloatList(Feature feature) {
        return typeForCount(feature.getFloatList().getValueCount(), FloatType$.MODULE$);
    }

    /**
     * Maps a value count to a type: 0 gives null (unknown), 1 gives {@code scalarType}, and more
     * than 1 gives {@code array<scalarType>}.
     */
    private DataType typeForCount(int valueCount, DataType scalarType) {
        if (valueCount == 0) {
            return null;
        }
        return valueCount > 1 ? ArrayType$.MODULE$.apply(scalarType) : scalarType;
    }

    /**
     * Ranks supported types for widening; a higher number wins in {@code findTightestCommonType}.
     *
     * @throws RuntimeException for any type outside the nine supported scalar/array/nested-array forms
     */
    private int getNumericPrecedence(DataType dataType) {
        if (LongType$.MODULE$.equals(dataType)) {
            return 1;
        }
        if (FloatType$.MODULE$.equals(dataType)) {
            return 2;
        }
        if (StringType$.MODULE$.equals(dataType)) {
            return 3;
        }
        if (dataType instanceof ArrayType) {
            DataType elementType = ((ArrayType) dataType).elementType();
            if (LongType$.MODULE$.equals(elementType)) {
                return 4;
            }
            if (FloatType$.MODULE$.equals(elementType)) {
                return 5;
            }
            if (StringType$.MODULE$.equals(elementType)) {
                return 6;
            }
            if (elementType instanceof ArrayType) {
                DataType innerType = ((ArrayType) elementType).elementType();
                if (LongType$.MODULE$.equals(innerType)) {
                    return 7;
                }
                if (FloatType$.MODULE$.equals(innerType)) {
                    return 8;
                }
                if (StringType$.MODULE$.equals(innerType)) {
                    return 9;
                }
            }
        }
        throw new RuntimeException("Unable to get the precedence for given datatype...");
    }

    /**
     * Returns the type that can hold values of both arguments.
     *
     * <p>Equal types are returned unchanged. A {@code null} (unknown) argument yields the other
     * argument. Otherwise the type with the higher {@code getNumericPrecedence} is returned.
     *
     * @return the combined type; {@code null} only when both arguments are {@code null}
     * @throws RuntimeException if two distinct, non-null types are not ranked by
     *         {@code getNumericPrecedence}
     */
    private DataType findTightestCommonType(DataType dataType, DataType dataType2) {
        if (dataType == null ? dataType2 == null : dataType.equals(dataType2)) {
            return dataType;
        }
        if (dataType == null) {
            return dataType2;
        }
        if (dataType2 == null) {
            return dataType;
        }
        return getNumericPrecedence(dataType) > getNumericPrecedence(dataType2) ? dataType : dataType2;
    }

    private TensorFlowInferSchema$() {
    }
}
