package com.linkedin.coral.vis;

import com.linkedin.coral.hive.hive2rel.rel.HiveUncollect;
import guru.nidi.graphviz.attribute.Label;
import guru.nidi.graphviz.model.Factory;
import guru.nidi.graphviz.model.Link;
import guru.nidi.graphviz.model.LinkTarget;
import guru.nidi.graphviz.model.Node;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.UUID;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelShuttleImpl;
import org.apache.calcite.rel.core.TableFunctionScan;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalCorrelate;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalIntersect;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.logical.LogicalMinus;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalSort;
import org.apache.calcite.rel.logical.LogicalUnion;
import org.apache.calcite.rel.logical.LogicalValues;
import org.apache.calcite.rel.type.RelDataTypeField;

/**
 * A Calcite {@link RelShuttleImpl} that renders a relational plan as a graphviz-java graph.
 *
 * <p>Every visit method first calls {@code super.visit(...)}, which makes {@link RelShuttleImpl}
 * descend into the operator's inputs. As a result, the graph nodes for the inputs are already in
 * {@link #nodeMap} when the parent is handled, and the parent's node is then linked to them. After
 * the shuttle has been applied to a root {@link RelNode}, {@link #getNode(RelNode)} returns the
 * graph node for that root.
 *
 * <p>Each graph node is named with a random UUID, so identical labels never collide. The text
 * shown in the rendered graph comes from the {@link Label} attached to each node or edge.
 */
public class RelNodeVisualizationShuttle extends RelShuttleImpl {
    /** Graph node created for each visited relational expression. */
    private final Map<RelNode, Node> nodeMap = new HashMap<>();

    @Override
    public RelNode visit(LogicalAggregate logicalAggregate) {
        super.visit(logicalAggregate);
        Node node = node("Aggregate: group=" + logicalAggregate.getGroupSet()
                + ", calls=" + logicalAggregate.getAggCallList());
        this.nodeMap.put(logicalAggregate, node.link(nodeOf(logicalAggregate.getInput())));
        return logicalAggregate;
    }

    @Override
    public RelNode visit(TableScan tableScan) {
        super.visit(tableScan);
        this.nodeMap.put(tableScan, node(String.join(".", tableScan.getTable().getQualifiedName())));
        return tableScan;
    }

    @Override
    public RelNode visit(TableFunctionScan tableFunctionScan) {
        super.visit(tableFunctionScan);
        // Link any inputs the table function has; with no inputs this is just the labelled node.
        this.nodeMap.put(tableFunctionScan,
                linkToInputs(node(tableFunctionScan.getCall().toString()), tableFunctionScan));
        return tableFunctionScan;
    }

    @Override
    public RelNode visit(LogicalValues logicalValues) {
        super.visit(logicalValues);
        this.nodeMap.put(logicalValues, node(logicalValues.getRowType().toString() + " values"));
        return logicalValues;
    }

    @Override
    public RelNode visit(LogicalFilter logicalFilter) {
        super.visit(logicalFilter);
        Node node = node("Filter: " + logicalFilter.getCondition().toString());
        this.nodeMap.put(logicalFilter, node.link(nodeOf(logicalFilter.getInput())));
        return logicalFilter;
    }

    @Override
    public RelNode visit(LogicalProject logicalProject) {
        super.visit(logicalProject);
        Node node = node("Project: " + logicalProject.getProjects().toString());
        this.nodeMap.put(logicalProject, node.link(nodeOf(logicalProject.getInput())));
        return logicalProject;
    }

    @Override
    public RelNode visit(LogicalJoin logicalJoin) {
        super.visit(logicalJoin);
        Node node = node("Join: " + logicalJoin.getCondition().toString());
        RelNode left = logicalJoin.getLeft();
        RelNode right = logicalJoin.getRight();
        // Right-side field ordinals continue after the left side's fields in the join row type.
        this.nodeMap.put(logicalJoin, node.link(
                edge(left, getLabel(left, 0, false)),
                edge(right, getLabel(right, left.getRowType().getFieldCount(), false))));
        return logicalJoin;
    }

    @Override
    public RelNode visit(LogicalCorrelate logicalCorrelate) {
        super.visit(logicalCorrelate);
        Node node = node("Correlate: " + logicalCorrelate.getCorrelVariable());
        RelNode left = logicalCorrelate.getLeft();
        RelNode right = logicalCorrelate.getRight();
        // The left edge is labelled with "$cor" prefixes; the right edge uses "$" ordinals
        // that continue after the left side's field count.
        this.nodeMap.put(logicalCorrelate, node.link(
                edge(left, getLabel(left, 0, true)),
                edge(right, getLabel(right, left.getRowType().getFieldCount(), false))));
        return logicalCorrelate;
    }

    /**
     * Builds an edge label listing the fields of {@code relNode}, one per line, as
     * {@code <prefix><ordinal> = <fieldName>}.
     *
     * @param relNode the input whose row type is listed
     * @param offset ordinal assigned to the first field
     * @param correlated if true, the prefix is {@code "$cor"}; otherwise {@code "$"}
     * @return the newline-terminated label text (empty if the row type has no fields)
     */
    private String getLabel(RelNode relNode, int offset, boolean correlated) {
        String prefix = correlated ? "$cor" : "$";
        StringBuilder label = new StringBuilder();
        int ordinal = offset;
        for (RelDataTypeField field : relNode.getRowType().getFieldList()) {
            label.append(prefix).append(ordinal++).append(" = ").append(field.getName()).append('\n');
        }
        return label.toString();
    }

    /**
     * Handles operators without a dedicated visit method. {@link HiveUncollect} gets an
     * "Uncollect" node linked to its single input. Any other operator gets a node labelled with
     * its {@link RelNode#getRelTypeName()} and linked to all of its inputs, so that its parent
     * never links to a missing node.
     */
    @Override
    public RelNode visit(RelNode relNode) {
        super.visit(relNode);
        if (relNode instanceof HiveUncollect) {
            this.nodeMap.put(relNode, node("Uncollect").link(nodeOf(relNode.getInput(0))));
        } else {
            this.nodeMap.put(relNode, linkToInputs(node(relNode.getRelTypeName()), relNode));
        }
        return relNode;
    }

    @Override
    public RelNode visit(LogicalUnion logicalUnion) {
        super.visit(logicalUnion);
        this.nodeMap.put(logicalUnion, setOpNode("Union", logicalUnion));
        return logicalUnion;
    }

    @Override
    public RelNode visit(LogicalIntersect logicalIntersect) {
        super.visit(logicalIntersect);
        this.nodeMap.put(logicalIntersect, setOpNode("Intersect", logicalIntersect));
        return logicalIntersect;
    }

    @Override
    public RelNode visit(LogicalMinus logicalMinus) {
        super.visit(logicalMinus);
        this.nodeMap.put(logicalMinus, setOpNode("Minus", logicalMinus));
        return logicalMinus;
    }

    @Override
    public RelNode visit(LogicalSort logicalSort) {
        super.visit(logicalSort);
        StringBuilder label = new StringBuilder("Sort: ").append(logicalSort.getCollation());
        if (logicalSort.offset != null) {
            label.append(", offset=").append(logicalSort.offset);
        }
        if (logicalSort.fetch != null) {
            label.append(", fetch=").append(logicalSort.fetch);
        }
        this.nodeMap.put(logicalSort, node(label.toString()).link(nodeOf(logicalSort.getInput())));
        return logicalSort;
    }

    /**
     * Returns the graph node built for {@code relNode}.
     *
     * @param relNode a relational expression previously visited by this shuttle
     * @return its graph node, or {@code null} if this shuttle has not seen it
     */
    public Node getNode(RelNode relNode) {
        return this.nodeMap.get(relNode);
    }

    /** Creates a node labelled {@code name} with an unlabelled ({@code ""}) edge to each input. */
    private Node setOpNode(String name, RelNode setOp) {
        LinkTarget[] targets = new LinkTarget[setOp.getInputs().size()];
        for (int i = 0; i < targets.length; i++) {
            targets[i] = edge(setOp.getInput(i), "");
        }
        return node(name).link(targets);
    }

    /** Links {@code node} directly to the graph node of every input of {@code relNode}, if any. */
    private Node linkToInputs(Node node, RelNode relNode) {
        int inputCount = relNode.getInputs().size();
        if (inputCount == 0) {
            return node;
        }
        LinkTarget[] targets = new LinkTarget[inputCount];
        for (int i = 0; i < inputCount; i++) {
            targets[i] = nodeOf(relNode.getInput(i));
        }
        return node.link(targets);
    }

    /**
     * Returns the graph node for {@code relNode}. If none was recorded, for example because the
     * operator's visit method in {@link RelShuttleImpl} does not dispatch to
     * {@link #visit(RelNode)}, a placeholder node labelled with the operator's type name is
     * created and recorded, so links never point at {@code null}.
     */
    private Node nodeOf(RelNode relNode) {
        return this.nodeMap.computeIfAbsent(relNode, missing -> node(missing.getRelTypeName()));
    }

    /** Creates a labelled edge pointing at the graph node of {@code relNode}. */
    private Link edge(RelNode relNode, String str) {
        return Factory.to(nodeOf(relNode)).with(Label.of(str));
    }

    /** Creates a graph node with a unique random name and the given display label. */
    private static Node node(String str) {
        return (Node) Factory.node(UUID.randomUUID().toString()).with(Label.of(str));
    }
}
