/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.rules.ImmutableSemiJoinProjectTransposeRule;
import org.apache.calcite.rel.rules.TransformationRule;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.rex.RexProgramBuilder;
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.Pair;
import org.immutables.value.Value;

/**
 * Planner rule that pushes a semi-join down past a {@link Project} that is
 * the semi-join's left (first) input.
 *
 * <p>Rewrites {@code SemiJoin(Project(X), R)} into
 * {@code Project(SemiJoin(X, R))}. The project expressions are reused
 * unchanged on top of the new semi-join, because a SEMI join emits only the
 * columns of its left input ({@code X}). Only the join condition has to be
 * rewritten, so that it references {@code X} instead of the project output.
 */
@Value.Enclosing
public class SemiJoinProjectTransposeRule
    extends RelRule<SemiJoinProjectTransposeRule.Config>
    implements TransformationRule {

    /** Creates a SemiJoinProjectTransposeRule from its configuration. */
    protected SemiJoinProjectTransposeRule(Config config) {
        super(config);
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        final Join semiJoin = call.rel(0);
        final Project project = call.rel(1);

        // Express the semi-join condition in terms of the project's input
        // (plus the right input) instead of the project's output.
        final RexNode newCondition = adjustCondition(project, semiJoin);

        // The new semi-join is created with an empty hint list; presumably the
        // planner framework propagates hints itself (confirm if hints matter).
        final LogicalJoin newSemiJoin =
            LogicalJoin.create(project.getInput(), semiJoin.getRight(),
                ImmutableList.of(), newCondition, ImmutableSet.of(),
                JoinRelType.SEMI);

        // Re-apply the original projection on top. Its expressions reference
        // the project's input, which is exactly the semi-join's output row.
        final RelBuilder relBuilder = call.builder();
        relBuilder.push(newSemiJoin);
        relBuilder.project(project.getProjects(),
            project.getRowType().getFieldNames());
        call.transformTo(relBuilder.build());
    }

    /**
     * Rewrites the semi-join condition so that it references the project's
     * input followed by the semi-join's right input, rather than the
     * project's output followed by the right input.
     *
     * <p>This is done by merging two {@link RexProgram}s:
     * <ul>
     *   <li>a bottom program over (project input ++ right input) that emits
     *       the project expressions followed by plain references to every
     *       right-input field; and</li>
     *   <li>a top program over (project output ++ right input) that is the
     *       identity plus the original semi-join condition.</li>
     * </ul>
     * The merged program's condition, with local references expanded, is the
     * rewritten condition.
     *
     * @param project  the project that is the semi-join's left input
     * @param semiJoin the semi-join being transposed
     * @return the semi-join condition rewritten against the project's input
     * @throws NullPointerException if the merged program has no condition
     */
    private static RexNode adjustCondition(Project project, Join semiJoin) {
        final RexBuilder rexBuilder = project.getCluster().getRexBuilder();
        final RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory();
        final RelNode rightChild = semiJoin.getRight();
        final RelDataType projectInputRowType = project.getInput().getRowType();

        // Bottom program input: the project's input row concatenated with
        // the right input row.
        final RelDataType bottomInputRowType =
            SqlValidatorUtil.deriveJoinRowType(projectInputRowType,
                rightChild.getRowType(), JoinRelType.INNER, typeFactory, null,
                semiJoin.getSystemFieldList());
        final RexProgramBuilder bottomProgramBuilder =
            new RexProgramBuilder(bottomInputRowType, rexBuilder);

        // First emit the project expressions ...
        for (Pair<RexNode, String> pair : project.getNamedProjects()) {
            bottomProgramBuilder.addProject(pair.left, pair.right);
        }
        // ... then pass every right-input field through. Right fields follow
        // the project-input fields, hence the nLeftFields offset.
        final int nLeftFields = projectInputRowType.getFieldCount();
        final List<RelDataTypeField> rightFields =
            rightChild.getRowType().getFieldList();
        for (int i = 0; i < rightFields.size(); i++) {
            final RelDataTypeField field = rightFields.get(i);
            final RexInputRef inputRef =
                rexBuilder.makeInputRef(field.getType(), i + nLeftFields);
            bottomProgramBuilder.addProject(inputRef, field.getName());
        }
        final RexProgram bottomProgram = bottomProgramBuilder.getProgram();

        // Top program input: the project's output row concatenated with the
        // right input row, i.e. the row the original condition was written
        // against.
        final RelDataType topInputRowType =
            SqlValidatorUtil.deriveJoinRowType(project.getRowType(),
                rightChild.getRowType(), JoinRelType.INNER, typeFactory, null,
                semiJoin.getSystemFieldList());
        final RexProgramBuilder topProgramBuilder =
            new RexProgramBuilder(topInputRowType, rexBuilder);
        topProgramBuilder.addIdentity();
        topProgramBuilder.addCondition(semiJoin.getCondition());
        final RexProgram topProgram = topProgramBuilder.getProgram();

        // Merging feeds the bottom program's output into the top program, so
        // the merged condition references the bottom program's input row.
        final RexProgram mergedProgram =
            RexProgramBuilder.mergePrograms(topProgram, bottomProgram,
                rexBuilder);

        // A condition was added to the top program above, so the merged
        // program is expected to carry one; fail loudly if it does not.
        return mergedProgram.expandLocalRef(
            Objects.requireNonNull(mergedProgram.getCondition(),
                () -> "mergedProgram.getCondition() for " + mergedProgram));
    }

    /** Rule configuration. */
    @Value.Immutable
    public interface Config extends RelRule.Config {
        /** Default configuration: matches a semi-join {@link LogicalJoin} over a {@link LogicalProject}. */
        Config DEFAULT = ImmutableSemiJoinProjectTransposeRule.Config.of()
            .withOperandFor(LogicalJoin.class, LogicalProject.class);

        @Override
        default SemiJoinProjectTransposeRule toRule() {
            return new SemiJoinProjectTransposeRule(this);
        }

        /**
         * Defines the operand tree: a join of class {@code joinClass} for
         * which {@link Join#isSemiJoin()} holds, whose first input is a
         * project of class {@code projectClass}.
         *
         * @param joinClass    class of the semi-join to match
         * @param projectClass class of the project on the semi-join's left
         * @return a copy of this configuration with the new operand supplier
         */
        default Config withOperandFor(Class<? extends Join> joinClass,
                Class<? extends Project> projectClass) {
            return withOperandSupplier(b ->
                b.operand(joinClass).predicate(Join::isSemiJoin).inputs(b2 ->
                    b2.operand(projectClass).anyInputs()))
                .as(Config.class);
        }
    }
}

