InferPredicates.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.mysql.MysqlCommand;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalExcept;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalIntersect;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;
import org.apache.doris.nereids.util.PredicateInferUtils;
import org.apache.doris.qe.ConnectContext;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/**
 * Infer additional predicates for `LogicalFilter` and `LogicalJoin`.
 * <pre>
 * The logic is as follows:
 * 1. pull up the bottom predicates, then infer additional predicates
 *   for example:
 *   select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id
 *   1. pull up the bottom predicate
 *      select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id and t.id = 1
 *   2. infer
 *      select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id and t.id = 1 and t2.id = 1
 *   finally transformed sql:
 *      select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id and t2.id = 1
 * 2. put these predicates into `otherJoinConjuncts`; these predicates are processed in the next
 *   round of predicate push-down
 * </pre>
 */
public class InferPredicates extends DefaultPlanRewriter<JobContext> implements CustomRewriter {
    // Predicate puller created with `false` in rewriteRoot; its result is the base set for inference.
    private PullUpPredicates pullUpPredicates;
    // The role of pullUpAllPredicates is to prevent inference of redundant predicates
    private PullUpPredicates pullUpAllPredicates;

    /**
     * Entry point of the rewriter: (re)creates both predicate pullers for this run and rewrites
     * the plan tree by visiting it with this rewriter.
     *
     * @param plan the root plan to rewrite
     * @param jobContext the job context; also supplies the cascades context used by the pullers
     * @return the rewritten plan, or {@code plan} itself when handling a COM_STMT_PREPARE command
     */
    @Override
    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
        // Preparing stmt requires that the predicate cannot be changed, so no predicate inference is performed.
        ConnectContext connectContext = jobContext.getCascadesContext().getConnectContext();
        if (connectContext != null && connectContext.getCommand() == MysqlCommand.COM_STMT_PREPARE) {
            return plan;
        }
        pullUpPredicates = new PullUpPredicates(false, jobContext.getCascadesContext());
        pullUpAllPredicates = new PullUpPredicates(true, jobContext.getCascadesContext());
        return plan.accept(this, jobContext);
    }

    /**
     * Infers predicates from both join children and the on-clause condition, then applies the
     * inferred predicates to the child side(s) selected by the join type.
     * Mark joins and join types not listed in the switch are returned without new predicates.
     *
     * @param join the join to rewrite; its children are rewritten first
     * @param context the job context passed through to the children
     * @return the same join when neither child changed, otherwise the join with the new children
     */
    @Override
    public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, JobContext context) {
        join = visitChildren(this, join, context);
        if (join.isMarkJoin()) {
            return join;
        }
        Plan left = join.left();
        Plan right = join.right();
        Set<Expression> expressions = getAllExpressions(left, right, join.getOnClauseCondition());
        switch (join.getJoinType()) {
            case INNER_JOIN:
            case CROSS_JOIN:
            case LEFT_SEMI_JOIN:
            case RIGHT_SEMI_JOIN:
                // both sides may receive inferred predicates
                left = inferNewPredicate(left, expressions);
                right = inferNewPredicate(right, expressions);
                break;
            case LEFT_OUTER_JOIN:
            case LEFT_ANTI_JOIN:
                // only the right child receives inferred predicates
                right = inferNewPredicate(right, expressions);
                break;
            case RIGHT_OUTER_JOIN:
            case RIGHT_ANTI_JOIN:
                // only the left child receives inferred predicates
                left = inferNewPredicate(left, expressions);
                break;
            default:
                break;
        }
        // Identity comparison: inferNewPredicate hands back the very same child when nothing was added.
        if (left != join.left() || right != join.right()) {
            return join.withChildren(left, right);
        } else {
            return join;
        }
    }

    /**
     * Rebuilds the filter from the predicates pulled up through it, minus every predicate that is
     * already pulled up from its child by {@code pullUpAllPredicates}.
     *
     * <p>When no predicate remains, the (already rewritten) child is returned directly instead of
     * wrapping it in a filter with an empty conjunct set.
     *
     * @param filter the filter to rewrite; its child is rewritten first
     * @param context the job context passed through to the child
     * @return a filter holding the remaining predicates, or the child when none remain
     */
    @Override
    public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, JobContext context) {
        filter = visitChildren(this, filter, context);
        Set<Expression> filterPredicates = pullUpPredicates(filter);
        filterPredicates.removeAll(pullUpAllPredicates(filter.child()));
        if (filterPredicates.isEmpty()) {
            // Every predicate is already reported by the child: do not emit an empty filter node.
            return filter.child();
        }
        return new LogicalFilter<>(ImmutableSet.copyOf(filterPredicates), filter.child());
    }

    /**
     * Applies the predicates pulled up from the except node to every child except the first one,
     * after rewriting them from the except output to the corresponding child output.
     *
     * @param except the except node to rewrite; its children are rewritten first
     * @param context the job context passed through to the children
     * @return the same node when nothing can be pulled up, otherwise the node with new children
     */
    @Override
    public Plan visitLogicalExcept(LogicalExcept except, JobContext context) {
        except = visitChildren(this, except, context);
        Set<Expression> baseExpressions = pullUpPredicates(except);
        if (baseExpressions.isEmpty()) {
            return except;
        }
        ImmutableList.Builder<Plan> builder = ImmutableList.builder();
        // The first child is kept as is; only children 1..arity-1 receive predicates.
        builder.add(except.child(0));
        for (int i = 1; i < except.arity(); ++i) {
            // Map each output of the except node to the output of child i at the same position.
            Map<Expression, Expression> replaceMap = new HashMap<>();
            for (int j = 0; j < except.getOutput().size(); ++j) {
                NamedExpression output = except.getOutput().get(j);
                replaceMap.put(output, except.getRegularChildOutput(i).get(j));
            }
            builder.add(inferNewPredicate(except.child(i), ExpressionUtils.replace(baseExpressions, replaceMap)));
        }
        return except.withChildren(builder.build());
    }

    /**
     * Applies the predicates pulled up from the intersect node to every child, after rewriting
     * them from the intersect output to the corresponding child output.
     *
     * @param intersect the intersect node to rewrite; its children are rewritten first
     * @param context the job context passed through to the children
     * @return the same node when nothing can be pulled up, otherwise the node with new children
     */
    @Override
    public Plan visitLogicalIntersect(LogicalIntersect intersect, JobContext context) {
        intersect = visitChildren(this, intersect, context);
        Set<Expression> baseExpressions = pullUpPredicates(intersect);
        if (baseExpressions.isEmpty()) {
            return intersect;
        }
        ImmutableList.Builder<Plan> builder = ImmutableList.builder();
        for (int i = 0; i < intersect.arity(); ++i) {
            // Map each output of the intersect node to the output of child i at the same position.
            Map<Expression, Expression> replaceMap = new HashMap<>();
            for (int j = 0; j < intersect.getOutput().size(); ++j) {
                NamedExpression output = intersect.getOutput().get(j);
                replaceMap.put(output, intersect.getRegularChildOutput(i).get(j));
            }
            builder.add(inferNewPredicate(intersect.child(i), ExpressionUtils.replace(baseExpressions, replaceMap)));
        }
        return intersect.withChildren(builder.build());
    }

    /**
     * Collects the predicates pulled up from both join children plus the conjuncts of the
     * optional join condition, and runs predicate inference over the combined set.
     *
     * @param left the left child of the join
     * @param right the right child of the join
     * @param condition the on-clause condition of the join, if any
     * @return the result of {@code PredicateInferUtils.inferPredicate} on the combined set
     */
    private Set<Expression> getAllExpressions(Plan left, Plan right, Optional<Expression> condition) {
        Set<Expression> baseExpressions = pullUpPredicates(left);
        baseExpressions.addAll(pullUpPredicates(right));
        condition.ifPresent(on -> baseExpressions.addAll(ExpressionUtils.extractConjunction(on)));
        return PredicateInferUtils.inferPredicate(baseExpressions);
    }

    /** Returns a fresh, mutable, insertion-ordered copy of what {@code pullUpPredicates} yields for the plan. */
    private Set<Expression> pullUpPredicates(Plan plan) {
        return Sets.newLinkedHashSet(plan.accept(pullUpPredicates, null));
    }

    /** Returns a fresh, mutable, insertion-ordered copy of what {@code pullUpAllPredicates} yields for the plan. */
    private Set<Expression> pullUpAllPredicates(Plan plan) {
        return Sets.newLinkedHashSet(plan.accept(pullUpAllPredicates, null));
    }

    /**
     * Selects from {@code expressions} those that reference at least one slot and only slots
     * output by {@code plan}, drops the ones already pulled up from the plan by
     * {@code pullUpAllPredicates}, and hands the rest to {@code PlanUtils.filterOrSelf}.
     *
     * @param plan the plan that should receive the new predicates
     * @param expressions the candidate predicates
     * @return the result of {@code PlanUtils.filterOrSelf} for the selected predicates
     */
    private Plan inferNewPredicate(Plan plan, Set<Expression> expressions) {
        Set<Expression> predicates = new LinkedHashSet<>();
        Set<Slot> planOutputs = plan.getOutputSet();
        for (Expression expr : expressions) {
            Set<Slot> slots = expr.getInputSlots();
            // Skip slot-free expressions and expressions that need slots the plan does not output.
            if (!slots.isEmpty() && planOutputs.containsAll(slots)) {
                predicates.add(expr);
            }
        }
        // The pulled-up set is only read here, so it is used directly without a defensive copy.
        predicates.removeAll(plan.accept(pullUpAllPredicates, null));
        return PlanUtils.filterOrSelf(predicates, plan);
    }
}