/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.rules.ImmutableFilterIntoJoinRuleConfig;
import org.apache.calcite.rel.rules.ImmutableJoinConditionPushRuleConfig;
import org.apache.calcite.rel.rules.TransformationRule;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.immutables.value.Value;

/**
 * Rule base that pushes conditions from a {@link Filter} above a {@link Join}
 * and from the join's own condition into the join and/or onto its inputs.
 *
 * @param <C> configuration type
 */
public abstract class FilterJoinRule<C extends FilterJoinRule.Config>
    extends RelRule<C>
    implements TransformationRule {
    /**
     * Predicate that accepts every condition, so {@link #validateJoinFilters}
     * never moves a join condition back above the join.
     */
    @Deprecated
    public static final Predicate TRUE_PREDICATE = (join, joinType, exp) -> true;

    /** Creates a FilterJoinRule. */
    protected FilterJoinRule(C config) {
        super(config);
    }

    /**
     * Redistributes the conditions of {@code filter} (if present) and of
     * {@code join}'s own condition between the area above the join, the join
     * condition, and the left/right inputs, then registers the rewritten tree
     * with {@code call}.
     *
     * <p>Returns without registering anything when there is nothing to push or
     * when the rewrite would reproduce the original join.
     *
     * @param call   rule call that supplies the {@link RelBuilder} and receives
     *               the transformed tree
     * @param filter filter directly above {@code join}, or {@code null} when only
     *               the join's ON condition is processed
     * @param join   join whose conditions are redistributed
     */
    protected void perform(RelOptRuleCall call, @Nullable Filter filter, Join join) {
        List<RexNode> joinFilters = RelOptUtil.conjunctions(join.getCondition());
        final ImmutableList<RexNode> origJoinFilters = ImmutableList.copyOf(joinFilters);

        // With no filter above and a TRUE join condition there is nothing to
        // push; returning here also keeps the rule from re-firing on its output.
        if (filter == null && joinFilters.isEmpty()) {
            return;
        }

        final List<RexNode> aboveFilters =
            filter != null ? getConjunctions(filter) : new ArrayList<>();
        final ImmutableList<RexNode> origAboveFilters = ImmutableList.copyOf(aboveFilters);

        // In "smart" mode, let the filters above a non-inner join simplify its type.
        JoinRelType joinType = join.getJoinType();
        if (config.isSmart()
            && !origAboveFilters.isEmpty()
            && join.getJoinType() != JoinRelType.INNER) {
            joinType = RelOptUtil.simplifyJoin(join, origAboveFilters, joinType);
        }

        final List<RexNode> leftFilters = new ArrayList<>();
        final List<RexNode> rightFilters = new ArrayList<>();

        // Push the filters above the join into the join condition and/or onto
        // either input, as far as the join type's canPush*FromAbove flags allow.
        boolean filterPushed = RelOptUtil.classifyFilters(join, aboveFilters,
            joinType.canPushIntoFromAbove(), joinType.canPushLeftFromAbove(),
            joinType.canPushRightFromAbove(), joinFilters, leftFilters, rightFilters);

        // Move join conditions rejected by the configured predicate back above.
        validateJoinFilters(aboveFilters, joinFilters, join, joinType);

        // If validation undid everything the classification above did, nothing
        // was effectively pushed.
        if (leftFilters.isEmpty()
            && rightFilters.isEmpty()
            && joinFilters.size() == origJoinFilters.size()
            && aboveFilters.size() == origAboveFilters.size()
            && Sets.newHashSet(joinFilters).equals(Sets.newHashSet(origJoinFilters))) {
            filterPushed = false;
        }

        // Derive extra equalities from chains of column equalities (skipped for
        // FULL joins).
        if (joinType != JoinRelType.FULL) {
            joinFilters = inferJoinEqualConditions(joinFilters, join);
        }

        // Push ON-clause conditions onto the inputs where the join type's
        // canPush*FromWithin flags allow it.
        if (RelOptUtil.classifyFilters(join, joinFilters, false,
            joinType.canPushLeftFromWithin(), joinType.canPushRightFromWithin(),
            joinFilters, leftFilters, rightFilters)) {
            filterPushed = true;
        }

        // Nothing moved and the join type is unchanged, or nothing is left at
        // all: the rewrite would be a no-op.
        if ((!filterPushed && joinType == join.getJoinType())
            || (joinFilters.isEmpty() && leftFilters.isEmpty() && rightFilters.isEmpty())) {
            return;
        }

        final RexBuilder rexBuilder = join.getCluster().getRexBuilder();
        final RelBuilder relBuilder = call.builder();
        final RelNode leftRel = relBuilder.push(join.getLeft()).filter(leftFilters).build();
        final RelNode rightRel = relBuilder.push(join.getRight()).filter(rightFilters).build();

        // Re-type the remaining join conditions against the new inputs' fields.
        final List<RelDataType> fieldTypes = ImmutableList.<RelDataType>builder()
            .addAll(RelOptUtil.getFieldTypeList(leftRel.getRowType()))
            .addAll(RelOptUtil.getFieldTypeList(rightRel.getRowType()))
            .build();
        final RexNode joinFilter = RexUtil.composeConjunction(rexBuilder,
            RexUtil.fixUp(rexBuilder, joinFilters, fieldTypes));

        if (joinFilter.isAlwaysTrue()
            && leftFilters.isEmpty()
            && rightFilters.isEmpty()
            && joinType == join.getJoinType()) {
            return;
        }

        final Join newJoinRel = join.copy(join.getTraitSet(), joinFilter, leftRel,
            rightRel, joinType, join.isSemiJoinDone());
        call.getPlanner().onCopy(join, newJoinRel);
        if (!leftFilters.isEmpty() && filter != null) {
            call.getPlanner().onCopy(filter, leftRel);
        }
        if (!rightFilters.isEmpty() && filter != null) {
            call.getPlanner().onCopy(filter, rightRel);
        }

        relBuilder.push(newJoinRel);
        // Convert back to the original join's row type (which may differ if the
        // join type changed).
        relBuilder.convert(join.getRowType(), false);
        // Re-apply whatever could not be pushed as a filter above the join.
        relBuilder.filter(RexUtil.fixUp(rexBuilder, aboveFilters,
            RelOptUtil.getFieldTypeList(relBuilder.peek().getRowType())));
        call.transformTo(relBuilder.build());
    }

    /**
     * Rewrites column-to-column equalities in {@code rexNodes} as equivalence
     * classes and regenerates them so that equalities between columns of the
     * same input become explicit (and therefore pushable).
     *
     * @param rexNodes conjuncts of a join condition
     * @param join     join the conjuncts belong to
     * @return {@code rexNodes} itself when no equivalence class has more than
     *         two columns; otherwise a new list of the non-equality conjuncts
     *         followed by the regenerated equalities
     */
    protected List<RexNode> inferJoinEqualConditions(List<RexNode> rexNodes, Join join) {
        final List<RexNode> result = new ArrayList<>(rexNodes.size());
        final List<Set<RexInputRef>> equalSets = splitEqualSets(rexNodes, result);
        // A two-column class is just the original equality; only larger classes
        // can yield different conditions.
        if (equalSets.stream().allMatch(set -> set.size() <= 2)) {
            return rexNodes;
        }
        result.addAll(constructConditionFromEqualSets(join, equalSets));
        return result;
    }

    /**
     * Groups {@code inputRef = inputRef} conjuncts into disjoint equivalence
     * classes; every other conjunct is appended to {@code leftNodes}.
     *
     * <p>When an equality links two existing classes, they are merged, so each
     * column ends up in exactly one class.
     */
    private static List<Set<RexInputRef>> splitEqualSets(List<RexNode> rexNodes, List<RexNode> leftNodes) {
        final List<Set<RexInputRef>> equalSets = new ArrayList<>();
        for (RexNode rexNode : rexNodes) {
            if (!rexNode.isA(SqlKind.EQUALS)) {
                leftNodes.add(rexNode);
                continue;
            }
            final List<RexNode> operands = ((RexCall) rexNode).getOperands();
            final RexNode op1 = operands.get(0);
            final RexNode op2 = operands.get(1);
            if (!(op1 instanceof RexInputRef) || !(op2 instanceof RexInputRef)) {
                leftNodes.add(rexNode);
                continue;
            }
            final RexInputRef in1 = (RexInputRef) op1;
            final RexInputRef in2 = (RexInputRef) op2;

            // Classes are disjoint, so at most two can match (one per operand);
            // if both match, this equality links them and they are merged.
            Set<RexInputRef> target = null;
            for (Iterator<Set<RexInputRef>> it = equalSets.iterator(); it.hasNext(); ) {
                final Set<RexInputRef> candidate = it.next();
                if (!candidate.contains(in1) && !candidate.contains(in2)) {
                    continue;
                }
                if (target == null) {
                    target = candidate;
                } else {
                    target.addAll(candidate);
                    it.remove();
                }
            }
            if (target == null) {
                // Insertion order keeps the regenerated conditions deterministic.
                target = new LinkedHashSet<>();
                equalSets.add(target);
            }
            target.add(in1);
            target.add(in2);
        }
        return equalSets;
    }

    /**
     * Regenerates equalities for each class: every left-input column equals the
     * first left column, every right-input column equals the first right column,
     * and, when both sides are present, the first left equals the first right.
     */
    private static List<RexNode> constructConditionFromEqualSets(Join join, List<Set<RexInputRef>> equalSets) {
        final RexBuilder rexBuilder = join.getCluster().getRexBuilder();
        final int leftFieldCount = join.getLeft().getRowType().getFieldCount();
        final List<RexNode> result = new ArrayList<>();
        for (Set<RexInputRef> set : equalSets) {
            final List<RexInputRef> leftSet = new ArrayList<>();
            final List<RexInputRef> rightSet = new ArrayList<>();
            for (RexInputRef ref : set) {
                (ref.getIndex() < leftFieldCount ? leftSet : rightSet).add(ref);
            }
            addEqualitiesToFirst(rexBuilder, leftSet, result);
            addEqualitiesToFirst(rexBuilder, rightSet, result);
            if (!leftSet.isEmpty() && !rightSet.isEmpty()) {
                result.add(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, leftSet.get(0), rightSet.get(0)));
            }
        }
        return result;
    }

    /** Appends {@code refs[0] = refs[i]} to {@code sink} for every i &gt; 0. */
    private static void addEqualitiesToFirst(RexBuilder rexBuilder, List<RexInputRef> refs, List<RexNode> sink) {
        for (int i = 1; i < refs.size(); ++i) {
            sink.add(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, refs.get(0), refs.get(i)));
        }
    }

    /**
     * Returns the conjuncts of {@code filter}'s condition as a mutable list, with
     * each call passed through
     * {@link RelOptUtil#collapseExpandedIsNotDistinctFromExpr}.
     */
    private static List<RexNode> getConjunctions(Filter filter) {
        final List<RexNode> conjunctions = RelOptUtil.conjunctions(filter.getCondition());
        final RexBuilder rexBuilder = filter.getCluster().getRexBuilder();
        conjunctions.replaceAll(node -> node instanceof RexCall
            ? RelOptUtil.collapseExpandedIsNotDistinctFromExpr((RexCall) node, rexBuilder)
            : node);
        return conjunctions;
    }

    /**
     * Moves every join condition that the configured {@link Predicate} rejects
     * from {@code joinFilters} to {@code aboveFilters}. Nothing is moved when
     * {@code joinType} does not project the right input.
     *
     * @param aboveFilters conditions to apply above the join; appended to
     * @param joinFilters  conditions of the join; rejected ones are removed
     * @param join         the join being rewritten
     * @param joinType     join type in effect for the rewrite
     */
    protected void validateJoinFilters(List<RexNode> aboveFilters, List<RexNode> joinFilters, Join join, JoinRelType joinType) {
        final Iterator<RexNode> filterIter = joinFilters.iterator();
        while (filterIter.hasNext()) {
            final RexNode exp = filterIter.next();
            if (!config.getPredicate().apply(join, joinType, exp) && joinType.projectsRight()) {
                aboveFilters.add(exp);
                filterIter.remove();
            }
        }
    }

    /** Rule configuration shared by the filter/join push-down rules. */
    public interface Config extends RelRule.Config {
        /**
         * Whether {@code perform} may use the filters above a non-inner join to
         * simplify the join type via {@link RelOptUtil#simplifyJoin}; defaults
         * to {@code false}.
         */
        @Value.Default
        default boolean isSmart() {
            return false;
        }

        /** Sets {@link #isSmart()}. */
        Config withSmart(boolean smart);

        /**
         * Decides, per condition, whether it may stay in the join's ON clause;
         * rejected conditions are moved above the join.
         */
        @Value.Parameter
        Predicate getPredicate();

        /** Sets {@link #getPredicate()}. */
        Config withPredicate(Predicate predicate);
    }

    /** Decides whether a condition may remain in the ON clause of a join. */
    @FunctionalInterface
    public interface Predicate {
        /**
         * Returns whether {@code exp} may stay in the condition of {@code join}
         * when the join has type {@code joinType}.
         */
        boolean apply(Join join, JoinRelType joinType, RexNode exp);
    }

    /** Rule that pushes conditions of a {@link Filter} into the {@link Join} below it. */
    public static class FilterIntoJoinRule
        extends FilterJoinRule<FilterIntoJoinRule.FilterIntoJoinRuleConfig> {
        /** Creates a FilterIntoJoinRule. */
        protected FilterIntoJoinRule(FilterIntoJoinRuleConfig config) {
            super(config);
        }

        @Deprecated
        public FilterIntoJoinRule(boolean smart, RelBuilderFactory relBuilderFactory, Predicate predicate) {
            this(ImmutableFilterIntoJoinRuleConfig.of(predicate)
                .withRelBuilderFactory(relBuilderFactory)
                .withOperandSupplier(b0 ->
                    b0.operand(Filter.class).oneInput(b1 ->
                        b1.operand(Join.class).anyInputs()))
                .withDescription("FilterJoinRule:filter")
                .withSmart(smart));
        }

        @Deprecated
        public FilterIntoJoinRule(boolean smart, RelFactories.FilterFactory filterFactory,
            RelFactories.ProjectFactory projectFactory, Predicate predicate) {
            this(smart, RelBuilder.proto(filterFactory, projectFactory), predicate);
        }

        @Override
        public void onMatch(RelOptRuleCall call) {
            final Filter filter = call.rel(0);
            final Join join = call.rel(1);
            perform(call, filter, join);
        }

        /** Configuration for {@link FilterIntoJoinRule}. */
        @Value.Immutable(singleton = false)
        public interface FilterIntoJoinRuleConfig extends FilterJoinRule.Config {
            /** Matches a Filter on top of a Join, with smart mode enabled. */
            FilterIntoJoinRuleConfig DEFAULT = ImmutableFilterIntoJoinRuleConfig.of((join, joinType, exp) -> true)
                .withOperandSupplier(b0 ->
                    b0.operand(Filter.class).oneInput(b1 ->
                        b1.operand(Join.class).anyInputs()))
                .withSmart(true);

            /** Same operands as {@link #DEFAULT}, with smart mode disabled. */
            FilterIntoJoinRuleConfig SMART_FALSE = ImmutableFilterIntoJoinRuleConfig.of((join, joinType, exp) -> true)
                .withOperandSupplier(b0 ->
                    b0.operand(Filter.class).oneInput(b1 ->
                        b1.operand(Join.class).anyInputs()))
                .withSmart(false);

            @Override
            default FilterIntoJoinRule toRule() {
                return new FilterIntoJoinRule(this);
            }
        }
    }

    /** Rule that pushes parts of a {@link Join}'s own condition onto its inputs. */
    public static class JoinConditionPushRule
        extends FilterJoinRule<JoinConditionPushRule.JoinConditionPushRuleConfig> {
        /** Creates a JoinConditionPushRule. */
        protected JoinConditionPushRule(JoinConditionPushRuleConfig config) {
            super(config);
        }

        @Deprecated
        public JoinConditionPushRule(RelBuilderFactory relBuilderFactory, Predicate predicate) {
            this(ImmutableJoinConditionPushRuleConfig.of(predicate)
                .withRelBuilderFactory(relBuilderFactory)
                .withOperandSupplier(b -> b.operand(Join.class).anyInputs())
                .withDescription("FilterJoinRule:no-filter")
                .withSmart(true));
        }

        @Deprecated
        public JoinConditionPushRule(RelFactories.FilterFactory filterFactory,
            RelFactories.ProjectFactory projectFactory, Predicate predicate) {
            this(RelBuilder.proto(filterFactory, projectFactory), predicate);
        }

        @Override
        public void onMatch(RelOptRuleCall call) {
            final Join join = call.rel(0);
            perform(call, null, join);
        }

        /** Configuration for {@link JoinConditionPushRule}. */
        @Value.Immutable(singleton = false)
        public interface JoinConditionPushRuleConfig extends FilterJoinRule.Config {
            /** Matches any Join, with smart mode enabled. */
            JoinConditionPushRuleConfig DEFAULT = ImmutableJoinConditionPushRuleConfig.of((join, joinType, exp) -> true)
                .withOperandSupplier(b -> b.operand(Join.class).anyInputs())
                .withSmart(true);

            @Override
            default JoinConditionPushRule toRule() {
                return new JoinConditionPushRule(this);
            }
        }
    }
}

