/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelShuttle;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rex.LogicVisitor;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexSubQuery;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlQuantifyOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.Pair;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.ErrorMsg;
import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSubqueryRuntimeException;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelShuttleImpl;
import org.apache.hadoop.hive.ql.optimizer.calcite.correlation.HiveCorrelationInfo;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveFilter;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSortLimit;

public class HiveSubQueryRemoveRule
extends RelOptRule {
    private final HiveConf conf;

    public static RelOptRule forProject(HiveConf conf) {
        return new HiveSubQueryRemoveRule(RelOptRule.operandJ(HiveProject.class, null, RexUtil.SubQueryFinder::containsSubQuery, (RelOptRuleOperandChildren)HiveSubQueryRemoveRule.any()), "SubQueryRemoveRule:Project", conf);
    }

    public static RelOptRule forFilter(HiveConf conf) {
        return new HiveSubQueryRemoveRule(RelOptRule.operandJ(HiveFilter.class, null, RexUtil.SubQueryFinder::containsSubQuery, (RelOptRuleOperandChildren)HiveSubQueryRemoveRule.any()), "SubQueryRemoveRule:Filter", conf);
    }

    private HiveSubQueryRemoveRule(RelOptRuleOperand operand, String description, HiveConf conf) {
        super(operand, HiveRelFactories.HIVE_BUILDER, description);
        this.conf = conf;
    }

    public void onMatch(RelOptRuleCall call) {
        RelNode relNode = call.rel(0);
        RelBuilder builder = call.builder();
        if (relNode instanceof HiveFilter) {
            HiveFilter filter = (HiveFilter)call.rel(0);
            Preconditions.checkState((!filter.getCorrelationInfos().isEmpty() ? 1 : 0) != 0);
            HiveCorrelationInfo correlationInfo = filter.getCorrelationInfos().get(0);
            RelOptUtil.Logic logic = LogicVisitor.find((RelOptUtil.Logic)RelOptUtil.Logic.TRUE, (List)ImmutableList.of((Object)filter.getCondition()), (RexNode)correlationInfo.rexSubQuery);
            builder.push(filter.getInput());
            int fieldCount = builder.peek().getRowType().getFieldCount();
            boolean isCorrScalarQuery = correlationInfo.isCorrScalarQuery();
            RexNode target = this.apply(call.getMetadataQuery(), correlationInfo.rexSubQuery, correlationInfo.correlationIds, logic, builder, 1, fieldCount, isCorrScalarQuery);
            ReplaceSubQueryShuttle shuttle = new ReplaceSubQueryShuttle(correlationInfo.rexSubQuery, target);
            builder.filter(new RexNode[]{shuttle.apply(filter.getCondition())});
            builder.project(HiveSubQueryRemoveRule.fields(builder, filter.getRowType().getFieldCount()));
            RelNode newRel = builder.build();
            call.transformTo(newRel);
        } else if (relNode instanceof HiveProject) {
            HiveProject project = (HiveProject)call.rel(0);
            Preconditions.checkState((!project.getCorrelationInfos().isEmpty() ? 1 : 0) != 0);
            HiveCorrelationInfo correlationInfo = project.getCorrelationInfos().get(0);
            RelOptUtil.Logic logic = LogicVisitor.find((RelOptUtil.Logic)RelOptUtil.Logic.TRUE_FALSE_UNKNOWN, (List)project.getProjects(), (RexNode)correlationInfo.rexSubQuery);
            builder.push(project.getInput());
            int fieldCount = builder.peek().getRowType().getFieldCount();
            boolean isCorrScalarQuery = correlationInfo.isCorrScalarQuery();
            RexNode target = this.apply(call.getMetadataQuery(), correlationInfo.rexSubQuery, correlationInfo.correlationIds, logic, builder, 1, fieldCount, isCorrScalarQuery);
            ReplaceSubQueryShuttle shuttle = new ReplaceSubQueryShuttle(correlationInfo.rexSubQuery, target);
            builder.project((Iterable)shuttle.apply(project.getProjects()), (Iterable)project.getRowType().getFieldNames());
            call.transformTo(builder.build());
        }
    }

    private boolean isAggZeroOnEmpty(RexSubQuery e) {
        assert (e.getKind() == SqlKind.SCALAR_QUERY);
        assert (e.rel.getInputs().size() == 1);
        Aggregate relAgg = (Aggregate)e.rel.getInput(0);
        assert (relAgg.getAggCallList().size() == 1);
        return ((AggregateCall)relAgg.getAggCallList().get(0)).getAggregation().getKind() == SqlKind.COUNT;
    }

    private SqlTypeName getAggTypeForScalarSub(RexSubQuery e) {
        assert (e.getKind() == SqlKind.SCALAR_QUERY);
        assert (e.rel.getInputs().size() == 1);
        Aggregate relAgg = (Aggregate)e.rel.getInput(0);
        assert (relAgg.getAggCallList().size() == 1);
        return ((AggregateCall)relAgg.getAggCallList().get(0)).getType().getSqlTypeName();
    }

    private RexNode rewriteScalar(RelMetadataQuery mq, RexSubQuery e, Set<CorrelationId> variablesSet, RelBuilder builder, int offset, int inputCount, boolean isCorrScalarAgg) {
        boolean shouldIntroSQCountCheck;
        Double maxRowCount = mq.getMaxRowCount(e.rel);
        boolean bl = shouldIntroSQCountCheck = maxRowCount == null || maxRowCount > 1.0;
        if (shouldIntroSQCountCheck) {
            builder.push(e.rel);
            builder.aggregate(builder.groupKey(), new RelBuilder.AggCall[]{builder.count(false, "cnt", new RexNode[0])});
            SqlFunction countCheck = new SqlFunction("sq_count_check", SqlKind.OTHER_FUNCTION, ReturnTypes.BOOLEAN, InferTypes.RETURN_TYPE, (SqlOperandTypeChecker)OperandTypes.NUMERIC, SqlFunctionCategory.USER_DEFINED_FUNCTION);
            builder.filter(new RexNode[]{builder.call((SqlOperator)countCheck, new RexNode[]{builder.field("cnt")})});
            if (!variablesSet.isEmpty()) {
                builder.join(JoinRelType.LEFT, (RexNode)builder.literal((Object)true), variablesSet);
            } else {
                builder.join(JoinRelType.INNER, (RexNode)builder.literal((Object)true), variablesSet);
            }
            ++offset;
        }
        if (isCorrScalarAgg) {
            builder.push(e.rel);
            ArrayList<RexNode> parentQueryFields = new ArrayList<RexNode>();
            parentQueryFields.addAll((Collection<RexNode>)builder.fields());
            String indicator = "trueLiteral";
            parentQueryFields.add(builder.alias((RexNode)builder.literal((Object)true), indicator));
            builder.project(parentQueryFields);
            builder.join(JoinRelType.LEFT, (RexNode)builder.literal((Object)true), variablesSet);
            ImmutableList.Builder operands = ImmutableList.builder();
            Object literal = this.isAggZeroOnEmpty(e) ? e.rel.getCluster().getRexBuilder().makeBigintLiteral(new BigDecimal(0)) : e.rel.getCluster().getRexBuilder().makeNullLiteral(this.getAggTypeForScalarSub(e));
            operands.add((Object[])new RexNode[]{builder.isNull((RexNode)builder.field(indicator)), literal});
            operands.add((Object)this.field(builder, 1, builder.fields().size() - 2));
            return builder.call((SqlOperator)SqlStdOperatorTable.CASE, (Iterable)operands.build());
        }
        builder.push(e.rel);
        builder.join(JoinRelType.LEFT, (RexNode)builder.literal((Object)true), variablesSet);
        return this.field(builder, inputCount, offset);
    }

    private RexNode rewriteSomeAll(RexSubQuery e, Set<CorrelationId> variablesSet, RelBuilder builder) {
        SqlQuantifyOperator op = (SqlQuantifyOperator)e.op;
        assert (op == SqlStdOperatorTable.SOME_GE || op == SqlStdOperatorTable.SOME_LE || op == SqlStdOperatorTable.SOME_LT || op == SqlStdOperatorTable.SOME_GT);
        if (variablesSet.isEmpty()) {
            builder.push(e.rel).aggregate(builder.groupKey(), new RelBuilder.AggCall[]{op.comparisonKind == SqlKind.GREATER_THAN || op.comparisonKind == SqlKind.GREATER_THAN_OR_EQUAL ? builder.min("m", (RexNode)builder.field(0)) : builder.max("m", (RexNode)builder.field(0)), builder.count(false, "c", new RexNode[0]), builder.count(false, "d", new RexNode[]{builder.field(0)})}).as("q").join(JoinRelType.INNER, new String[0]);
            return builder.call((SqlOperator)SqlStdOperatorTable.CASE, new RexNode[]{builder.call((SqlOperator)SqlStdOperatorTable.EQUALS, new RexNode[]{builder.field("q", "c"), builder.literal((Object)0)}), builder.literal((Object)false), builder.call((SqlOperator)SqlStdOperatorTable.IS_TRUE, new RexNode[]{builder.call(RelOptUtil.op((SqlKind)op.comparisonKind, null), new RexNode[]{(RexNode)e.operands.get(0), builder.field("q", "m")})}), builder.literal((Object)true), builder.call((SqlOperator)SqlStdOperatorTable.GREATER_THAN, new RexNode[]{builder.field("q", "c"), builder.field("q", "d")}), e.rel.getCluster().getRexBuilder().makeNullLiteral(SqlTypeName.BOOLEAN), builder.call(RelOptUtil.op((SqlKind)op.comparisonKind, null), new RexNode[]{(RexNode)e.operands.get(0), builder.field("q", "m")})});
        }
        HiveSubQueryRemoveRule.subqueryRestriction(e.rel);
        builder.push(e.rel);
        builder.aggregate(builder.groupKey(), new RelBuilder.AggCall[]{op.comparisonKind == SqlKind.GREATER_THAN || op.comparisonKind == SqlKind.GREATER_THAN_OR_EQUAL ? builder.min("m", (RexNode)builder.field(0)) : builder.max("m", (RexNode)builder.field(0)), builder.count(false, "c", new RexNode[0]), builder.count(false, "d", new RexNode[]{builder.field(0)})});
        ArrayList<RexNode> parentQueryFields = new ArrayList<RexNode>();
        parentQueryFields.addAll((Collection<RexNode>)builder.fields());
        String indicator = "trueLiteral";
        parentQueryFields.add(builder.alias((RexNode)builder.literal((Object)true), indicator));
        builder.project(parentQueryFields).as("q");
        builder.join(JoinRelType.LEFT, (RexNode)builder.literal((Object)true), variablesSet);
        return builder.call((SqlOperator)SqlStdOperatorTable.CASE, new RexNode[]{builder.call((SqlOperator)SqlStdOperatorTable.IS_NULL, new RexNode[]{builder.field(indicator)}), builder.literal((Object)false), builder.call((SqlOperator)SqlStdOperatorTable.EQUALS, new RexNode[]{builder.field("q", "c"), builder.literal((Object)0)}), builder.literal((Object)false), builder.call((SqlOperator)SqlStdOperatorTable.IS_TRUE, new RexNode[]{builder.call(RelOptUtil.op((SqlKind)op.comparisonKind, null), new RexNode[]{(RexNode)e.operands.get(0), builder.field("q", "m")})}), builder.literal((Object)true), builder.call((SqlOperator)SqlStdOperatorTable.GREATER_THAN, new RexNode[]{builder.field("q", "c"), builder.field("q", "d")}), e.rel.getCluster().getRexBuilder().makeNullLiteral(SqlTypeName.BOOLEAN), builder.call(RelOptUtil.op((SqlKind)op.comparisonKind, null), new RexNode[]{(RexNode)e.operands.get(0), builder.field("q", "m")})});
    }

    private RexNode rewriteInExists(RexSubQuery e, Set<CorrelationId> variablesSet, RelOptUtil.Logic logic, RelBuilder builder, int offset, boolean isCorrScalarAgg) {
        ArrayList<RexNode> fields = new ArrayList<RexNode>();
        if (e.getKind() == SqlKind.IN) {
            builder.push(e.rel);
            fields.addAll((Collection<RexNode>)builder.fields());
            if (isCorrScalarAgg) {
                builder.aggregate(builder.groupKey(), new RelBuilder.AggCall[]{builder.count(false, "cnt_in", new RexNode[0])});
                if (!variablesSet.isEmpty()) {
                    builder.join(JoinRelType.LEFT, (RexNode)builder.literal((Object)true), variablesSet);
                } else {
                    builder.join(JoinRelType.INNER, (RexNode)builder.literal((Object)true), variablesSet);
                }
                SqlFunction inCountCheck = new SqlFunction("sq_count_check", SqlKind.OTHER_FUNCTION, ReturnTypes.BOOLEAN, InferTypes.RETURN_TYPE, (SqlOperandTypeChecker)OperandTypes.NUMERIC, SqlFunctionCategory.USER_DEFINED_FUNCTION);
                builder.filter(new RexNode[]{builder.call((SqlOperator)inCountCheck, new RexNode[]{builder.field("cnt_in"), builder.literal((Object)true)})});
                ++offset;
                builder.push(e.rel);
            }
        } else if (e.getKind() == SqlKind.EXISTS && !variablesSet.isEmpty()) {
            builder.push(e.rel.accept((RelShuttle)new HiveSortLimitRemover()));
        } else {
            builder.push(e.rel);
        }
        boolean isCandidateForAntiJoin = false;
        switch (logic) {
            case TRUE_FALSE_UNKNOWN: 
            case UNKNOWN_AS_TRUE: {
                if (e.getKind() == SqlKind.EXISTS) {
                    logic = RelOptUtil.Logic.TRUE_FALSE;
                    if (!this.conf.getBoolVar(HiveConf.ConfVars.HIVE_CONVERT_ANTI_JOIN)) break;
                    isCandidateForAntiJoin = true;
                    break;
                }
                builder.aggregate(builder.groupKey(), new RelBuilder.AggCall[]{builder.count(false, "c", new RexNode[0]), builder.aggregateCall(SqlStdOperatorTable.COUNT, false, null, "ck", (Iterable)builder.fields())});
                builder.as("ct");
                if (!variablesSet.isEmpty()) {
                    builder.join(JoinRelType.LEFT, (RexNode)builder.literal((Object)true), variablesSet);
                } else {
                    builder.join(JoinRelType.INNER, (RexNode)builder.literal((Object)true), variablesSet);
                }
                offset += 2;
                builder.push(e.rel);
            }
        }
        String trueLiteral = "literalTrue";
        switch (logic) {
            case TRUE: 
            case FALSE: {
                if (fields.isEmpty()) {
                    if (logic == RelOptUtil.Logic.TRUE) {
                        builder.project(new RexNode[]{builder.alias((RexNode)builder.literal((Object)true), trueLiteral)});
                    } else {
                        builder.project(new RexNode[]{builder.alias((RexNode)builder.literal((Object)false), "literalFalse")});
                    }
                    if (!variablesSet.isEmpty() && (e.getKind() == SqlKind.EXISTS || e.getKind() == SqlKind.IN)) break;
                    builder.aggregate(builder.groupKey(new int[]{0}), new RelBuilder.AggCall[0]);
                    break;
                }
                if (!variablesSet.isEmpty() && (e.getKind() == SqlKind.EXISTS || e.getKind() == SqlKind.IN)) break;
                builder.aggregate(builder.groupKey(fields), new RelBuilder.AggCall[0]);
                break;
            }
            default: {
                fields.add(builder.alias((RexNode)builder.literal((Object)true), trueLiteral));
                builder.project(fields);
                if (isCandidateForAntiJoin && !variablesSet.isEmpty()) break;
                builder.distinct();
            }
        }
        builder.as("dt");
        ArrayList<RexNode> conditions = new ArrayList<RexNode>();
        for (Object pair : Pair.zip((List)e.getOperands(), (List)builder.fields())) {
            conditions.add(builder.equals((RexNode)((Pair)pair).left, RexUtil.shift((RexNode)((RexNode)((Pair)pair).right), (int)offset)));
        }
        switch (logic) {
            case TRUE: {
                builder.join(JoinRelType.SEMI, builder.and(conditions), variablesSet);
                return builder.literal((Object)true);
            }
            case FALSE: {
                builder.join(JoinRelType.ANTI, builder.and(conditions), variablesSet);
                return builder.literal((Object)false);
            }
        }
        builder.join(JoinRelType.LEFT, builder.and(conditions), variablesSet);
        ArrayList<RexNode> keyIsNulls = new ArrayList<RexNode>();
        for (RexNode operand : e.getOperands()) {
            if (!operand.getType().isNullable()) continue;
            keyIsNulls.add(builder.isNull(operand));
        }
        ImmutableList.Builder operands = ImmutableList.builder();
        switch (logic) {
            case TRUE_FALSE_UNKNOWN: 
            case UNKNOWN_AS_TRUE: {
                operands.add((Object[])new RexNode[]{builder.equals(builder.field("ct", "c"), (RexNode)builder.literal((Object)0)), builder.literal((Object)false)});
                operands.add((Object[])new RexNode[]{builder.isNull(builder.field("ct", "c")), builder.literal((Object)false)});
            }
        }
        operands.add((Object[])new RexNode[]{builder.isNotNull(builder.field("dt", trueLiteral)), builder.literal((Object)true)});
        if (!keyIsNulls.isEmpty()) {
            operands.add((Object[])new RexNode[]{builder.or(keyIsNulls), e.rel.getCluster().getRexBuilder().makeNullLiteral(SqlTypeName.BOOLEAN)});
        }
        RexLiteral b = builder.literal((Object)true);
        switch (logic) {
            case TRUE_FALSE_UNKNOWN: {
                b = e.rel.getCluster().getRexBuilder().makeNullLiteral(SqlTypeName.BOOLEAN);
            }
            case UNKNOWN_AS_TRUE: {
                operands.add((Object[])new RexNode[]{builder.call((SqlOperator)SqlStdOperatorTable.LESS_THAN, new RexNode[]{builder.field("ct", "ck"), builder.field("ct", "c")}), b});
            }
        }
        operands.add((Object)builder.literal((Object)false));
        return builder.call((SqlOperator)SqlStdOperatorTable.CASE, (Iterable)operands.build());
    }

    protected RexNode apply(RelMetadataQuery mq, RexSubQuery e, Set<CorrelationId> variablesSet, RelOptUtil.Logic logic, RelBuilder builder, int inputCount, int offset, boolean isCorrScalarAgg) {
        switch (e.getKind()) {
            case SCALAR_QUERY: {
                return this.rewriteScalar(mq, e, variablesSet, builder, offset, inputCount, isCorrScalarAgg);
            }
            case SOME: {
                return this.rewriteSomeAll(e, variablesSet, builder);
            }
            case IN: 
            case EXISTS: {
                return this.rewriteInExists(e, variablesSet, logic, builder, offset, isCorrScalarAgg);
            }
        }
        throw new AssertionError(e.getKind());
    }

    private RexInputRef field(RelBuilder builder, int inputCount, int offset) {
        int inputOrdinal = 0;
        RelNode r;
        while (offset >= (r = builder.peek(inputCount, inputOrdinal)).getRowType().getFieldCount()) {
            ++inputOrdinal;
            offset -= r.getRowType().getFieldCount();
        }
        return builder.field(inputCount, inputOrdinal, offset);
    }

    private static List<RexNode> fields(RelBuilder builder, int fieldCount) {
        ArrayList<RexNode> projects = new ArrayList<RexNode>();
        for (int i = 0; i < fieldCount; ++i) {
            projects.add((RexNode)builder.field(i));
        }
        return projects;
    }

    public static void subqueryRestriction(RelNode relNode) {
        if (relNode instanceof HiveAggregate) {
            HiveAggregate aggregate = (HiveAggregate)relNode;
            if (!aggregate.getAggCallList().isEmpty() && aggregate.getGroupSet().isEmpty()) {
                throw new CalciteSubqueryRuntimeException("Subquery rewrite: Aggregate without group by is not allowed");
            }
        } else if (relNode instanceof HiveProject || relNode instanceof HiveFilter) {
            HiveSubQueryRemoveRule.subqueryRestriction(relNode.getInput(0));
        }
    }

    private static class ReplaceSubQueryShuttle
    extends RexShuttle {
        private final RexSubQuery subQuery;
        private final RexNode replacement;

        ReplaceSubQueryShuttle(RexSubQuery subQuery, RexNode replacement) {
            this.subQuery = subQuery;
            this.replacement = replacement;
        }

        public RexNode visitSubQuery(RexSubQuery subQuery) {
            return subQuery.equals((Object)this.subQuery) ? this.replacement : subQuery;
        }
    }

    public static class HiveSortLimitRemover
    extends HiveRelShuttleImpl {
        @Override
        public RelNode visit(HiveSortLimit sort) {
            RexLiteral fetchExpr;
            RexLiteral offsetExpr;
            RexNode rexNode = sort.getOffsetExpr();
            if (rexNode != null && rexNode.getKind() == SqlKind.LITERAL && !BigDecimal.ZERO.equals((offsetExpr = (RexLiteral)rexNode).getValue())) {
                throw new RuntimeException(ErrorMsg.OFFSET_NOT_SUPPORTED_IN_SUBQUERY.getMsg());
            }
            rexNode = sort.getFetchExpr();
            if (rexNode != null && rexNode.getKind() == SqlKind.LITERAL && BigDecimal.ZERO.equals((fetchExpr = (RexLiteral)rexNode).getValue())) {
                return super.visit(sort);
            }
            return super.visit(sort.getInput());
        }
    }
}

