/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.logical;

import java.util.LinkedList;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.AbstractRelNode;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.rex.RexProgramBuilder;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCorrelate;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
import org.apache.flink.table.planner.plan.rules.logical.ScalarFunctionSplitter;
import org.apache.flink.table.planner.plan.rules.physical.stream.StreamPhysicalCorrelateRule;
import org.apache.flink.table.planner.plan.utils.PythonUtil;
import org.apache.flink.table.planner.plan.utils.RexDefaultVisitor;
import scala.collection.Iterator;
import scala.collection.mutable.ArrayBuffer;

public class PythonCorrelateSplitRule
extends RelOptRule {
    public static final PythonCorrelateSplitRule INSTANCE = new PythonCorrelateSplitRule();

    private PythonCorrelateSplitRule() {
        super(PythonCorrelateSplitRule.operand(FlinkLogicalCorrelate.class, PythonCorrelateSplitRule.any()), "PythonCorrelateSplitRule");
    }

    private FlinkLogicalTableFunctionScan createNewScan(FlinkLogicalTableFunctionScan scan, ScalarFunctionSplitter splitter) {
        RexCall rightRexCall = (RexCall)scan.getCall();
        List<RexNode> rightCalcProjects = rightRexCall.getOperands().stream().map(x -> x.accept(splitter)).collect(Collectors.toList());
        RexCall newRightRexCall = rightRexCall.clone(rightRexCall.getType(), rightCalcProjects);
        return new FlinkLogicalTableFunctionScan(scan.getCluster(), scan.getTraitSet(), scan.getInputs(), newRightRexCall, scan.getElementType(), scan.getRowType(), scan.getColumnMappings());
    }

    @Override
    public boolean matches(RelOptRuleCall call) {
        FlinkLogicalTableFunctionScan tableFunctionScan;
        FlinkLogicalCorrelate correlate = (FlinkLogicalCorrelate)call.rel(0);
        RelNode right = ((HepRelVertex)correlate.getRight()).getCurrentRel();
        if (right instanceof FlinkLogicalTableFunctionScan) {
            tableFunctionScan = (FlinkLogicalTableFunctionScan)right;
        } else if (right instanceof FlinkLogicalCalc) {
            tableFunctionScan = StreamPhysicalCorrelateRule.getTableScan((FlinkLogicalCalc)right);
        } else {
            return false;
        }
        RexNode rexNode = tableFunctionScan.getCall();
        if (rexNode instanceof RexCall) {
            return PythonUtil.isPythonCall(rexNode, null) && PythonUtil.containsNonPythonCall(rexNode) || PythonUtil.isNonPythonCall(rexNode) && PythonUtil.containsPythonCall(rexNode, null) || PythonUtil.isPythonCall(rexNode, null) && RexUtil.containsFieldAccess(rexNode);
        }
        return false;
    }

    private List<String> createNewFieldNames(RelDataType rowType, final RexBuilder rexBuilder, int primitiveFieldCount, ArrayBuffer<RexNode> extractedRexNodes, List<RexNode> calcProjects) {
        for (int i = 0; i < primitiveFieldCount; ++i) {
            calcProjects.add(RexInputRef.of(i, rowType));
        }
        RexDefaultVisitor<RexNode> visitor = new RexDefaultVisitor<RexNode>(){

            @Override
            public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
                RexNode expr = fieldAccess.getReferenceExpr();
                if (expr instanceof RexCorrelVariable) {
                    RelDataTypeField field = fieldAccess.getField();
                    return new RexInputRef(field.getIndex(), field.getType());
                }
                return rexBuilder.makeFieldAccess(expr.accept(this), fieldAccess.getField().getIndex());
            }

            @Override
            public RexNode visitNode(RexNode rexNode) {
                return rexNode;
            }
        };
        for (RexNode rexNode : extractedRexNodes) {
            if (rexNode instanceof RexCall) {
                RexCall rexCall = (RexCall)rexNode;
                List<RexNode> newProjects = rexCall.getOperands().stream().map(x -> (RexNode)x.accept(visitor)).collect(Collectors.toList());
                RexCall newRexCall = rexCall.clone(rexCall.getType(), newProjects);
                calcProjects.add(newRexCall);
                continue;
            }
            calcProjects.add(rexNode);
        }
        LinkedList<String> nameList = new LinkedList<String>();
        for (int i = 0; i < primitiveFieldCount; ++i) {
            nameList.add(rowType.getFieldNames().get(i));
        }
        Iterator indicesIterator = extractedRexNodes.indices().iterator();
        while (indicesIterator.hasNext()) {
            nameList.add("f" + indicesIterator.next());
        }
        return SqlValidatorUtil.uniquify(nameList, rexBuilder.getTypeFactory().getTypeSystem().isSchemaCaseSensitive());
    }

    private FlinkLogicalCalc createNewLeftCalc(RelNode left, RexBuilder rexBuilder, ArrayBuffer<RexNode> extractedRexNodes, FlinkLogicalCorrelate correlate) {
        LinkedList<RexNode> leftCalcProjects = new LinkedList<RexNode>();
        RelDataType leftRowType = left.getRowType();
        List<String> leftCalcCalcFieldNames = this.createNewFieldNames(leftRowType, rexBuilder, leftRowType.getFieldCount(), extractedRexNodes, leftCalcProjects);
        return new FlinkLogicalCalc(correlate.getCluster(), correlate.getTraitSet(), left, RexProgram.create(leftRowType, leftCalcProjects, null, leftCalcCalcFieldNames, rexBuilder));
    }

    private FlinkLogicalCalc createTopCalc(int primitiveLeftFieldCount, RexBuilder rexBuilder, ArrayBuffer<RexNode> extractedRexNodes, RelDataType calcRowType, FlinkLogicalCorrelate newCorrelate) {
        RexProgram rexProgram = new RexProgramBuilder(newCorrelate.getRowType(), rexBuilder).getProgram();
        int offset = extractedRexNodes.size() + primitiveLeftFieldCount;
        List newTopCalcProjects = rexProgram.getExprList().stream().filter(x -> x instanceof RexInputRef).filter(x -> {
            int index = ((RexInputRef)x).getIndex();
            return index < primitiveLeftFieldCount || index >= offset;
        }).collect(Collectors.toList());
        return new FlinkLogicalCalc(newCorrelate.getCluster(), newCorrelate.getTraitSet(), newCorrelate, RexProgram.create(newCorrelate.getRowType(), newTopCalcProjects, null, calcRowType, rexBuilder));
    }

    private ScalarFunctionSplitter createScalarFunctionSplitter(RexProgram program, RexBuilder rexBuilder, int primitiveLeftFieldCount, ArrayBuffer<RexNode> extractedRexNodes, RexNode tableFunctionNode) {
        return new ScalarFunctionSplitter(program, rexBuilder, primitiveLeftFieldCount, extractedRexNodes, node -> {
            if (PythonUtil.isNonPythonCall(tableFunctionNode)) {
                return PythonUtil.isPythonCall(node, null);
            }
            if (PythonUtil.containsNonPythonCall(node)) {
                return PythonUtil.isNonPythonCall(node);
            }
            return node instanceof RexFieldAccess;
        });
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        FlinkLogicalCorrelate newCorrelate;
        AbstractRelNode rightNewInput;
        FlinkLogicalCorrelate correlate = (FlinkLogicalCorrelate)call.rel(0);
        RexBuilder rexBuilder = call.builder().getRexBuilder();
        RelNode left = ((HepRelVertex)correlate.getLeft()).getCurrentRel();
        RelNode right = ((HepRelVertex)correlate.getRight()).getCurrentRel();
        int primitiveLeftFieldCount = left.getRowType().getFieldCount();
        ArrayBuffer extractedRexNodes = new ArrayBuffer();
        if (right instanceof FlinkLogicalTableFunctionScan) {
            FlinkLogicalTableFunctionScan scan = (FlinkLogicalTableFunctionScan)right;
            rightNewInput = this.createNewScan(scan, this.createScalarFunctionSplitter(null, rexBuilder, primitiveLeftFieldCount, (ArrayBuffer<RexNode>)extractedRexNodes, scan.getCall()));
        } else {
            FlinkLogicalCalc calc = (FlinkLogicalCalc)right;
            FlinkLogicalTableFunctionScan scan = StreamPhysicalCorrelateRule.getTableScan(calc);
            FlinkLogicalCalc mergedCalc = StreamPhysicalCorrelateRule.getMergedCalc(calc);
            FlinkLogicalTableFunctionScan newScan = this.createNewScan(scan, this.createScalarFunctionSplitter(null, rexBuilder, primitiveLeftFieldCount, (ArrayBuffer<RexNode>)extractedRexNodes, scan.getCall()));
            rightNewInput = mergedCalc.copy(mergedCalc.getTraitSet(), newScan, mergedCalc.getProgram());
        }
        if (extractedRexNodes.size() > 0) {
            FlinkLogicalCalc leftCalc = this.createNewLeftCalc(left, rexBuilder, (ArrayBuffer<RexNode>)extractedRexNodes, correlate);
            newCorrelate = new FlinkLogicalCorrelate(correlate.getCluster(), correlate.getTraitSet(), leftCalc, rightNewInput, correlate.getCorrelationId(), correlate.getRequiredColumns(), correlate.getJoinType());
        } else {
            newCorrelate = new FlinkLogicalCorrelate(correlate.getCluster(), correlate.getTraitSet(), left, rightNewInput, correlate.getCorrelationId(), correlate.getRequiredColumns(), correlate.getJoinType());
        }
        FlinkLogicalCalc newTopCalc = this.createTopCalc(primitiveLeftFieldCount, rexBuilder, (ArrayBuffer<RexNode>)extractedRexNodes, correlate.getRowType(), newCorrelate);
        call.transformTo(newTopCalc);
    }
}

