ExpressionRewrite.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.expression;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.pattern.ExpressionPatternRules;
import org.apache.doris.nereids.pattern.ExpressionPatternTraverseListeners;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapTableSink;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;

import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/**
 * expression of plan rewrite rule.
 */
public class ExpressionRewrite implements RewriteRuleFactory {
    protected final ExpressionRuleExecutor rewriter;

    public ExpressionRewrite(ExpressionRewriteRule... rules) {
        this.rewriter = new ExpressionRuleExecutor(ImmutableList.copyOf(rules));
    }

    public ExpressionRewrite(ExpressionRuleExecutor rewriter) {
        this.rewriter = Objects.requireNonNull(rewriter, "rewriter is null");
    }

    public Expression rewrite(Expression expression, ExpressionRewriteContext expressionRewriteContext) {
        return rewriter.rewrite(expression, expressionRewriteContext);
    }

    @Override
    public List<Rule> buildRules() {
        return ImmutableList.of(
                new GenerateExpressionRewrite().build(),
                new OneRowRelationExpressionRewrite().build(),
                new ProjectExpressionRewrite().build(),
                new AggExpressionRewrite().build(),
                new FilterExpressionRewrite().build(),
                new JoinExpressionRewrite().build(),
                new SortExpressionRewrite().build(),
                new LogicalRepeatRewrite().build(),
                new HavingExpressionRewrite().build(),
                new OlapTableSinkExpressionRewrite().build());
    }

    /** GenerateExpressionRewrite */
    public class GenerateExpressionRewrite extends OneRewriteRuleFactory {
        @Override
        public Rule build() {
            return logicalGenerate().thenApply(ctx -> {
                LogicalGenerate<Plan> generate = ctx.root;
                ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
                List<Function> generators = generate.getGenerators();
                List<Function> newGenerators = generators.stream()
                        .map(func -> (Function) rewriter.rewrite(func, context))
                        .collect(ImmutableList.toImmutableList());
                if (generators.equals(newGenerators)) {
                    return generate;
                }
                return generate.withGenerators(newGenerators);
            }).toRule(RuleType.REWRITE_GENERATE_EXPRESSION);
        }
    }

    /** OneRowRelationExpressionRewrite */
    public class OneRowRelationExpressionRewrite extends OneRewriteRuleFactory {
        @Override
        public Rule build() {
            return logicalOneRowRelation().thenApply(ctx -> {
                LogicalOneRowRelation oneRowRelation = ctx.root;
                List<NamedExpression> projects = oneRowRelation.getProjects();
                ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);

                Builder<NamedExpression> rewrittenExprs
                        = ImmutableList.builderWithExpectedSize(projects.size());
                boolean changed = false;
                for (NamedExpression project : projects) {
                    NamedExpression newProject = (NamedExpression) rewriter.rewrite(project, context);
                    if (!changed && !project.deepEquals(newProject)) {
                        changed = true;
                    }
                    rewrittenExprs.add(newProject);
                }
                return changed
                        ? new LogicalOneRowRelation(oneRowRelation.getRelationId(), rewrittenExprs.build())
                        : oneRowRelation;
            }).toRule(RuleType.REWRITE_ONE_ROW_RELATION_EXPRESSION);
        }
    }

    /** ProjectExpressionRewrite */
    public class ProjectExpressionRewrite extends OneRewriteRuleFactory {
        @Override
        public Rule build() {
            return logicalProject().thenApply(ctx -> {
                LogicalProject<Plan> project = ctx.root;
                ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
                List<NamedExpression> projects = project.getProjects();
                List<NamedExpression> newProjects = rewriteAll(projects, rewriter, context);
                if (projects.equals(newProjects)) {
                    return project;
                }
                return project.withProjectsAndChild(newProjects, project.child());
            }).toRule(RuleType.REWRITE_PROJECT_EXPRESSION);
        }
    }

    /** FilterExpressionRewrite */
    public class FilterExpressionRewrite extends OneRewriteRuleFactory {
        @Override
        public Rule build() {
            return logicalFilter().thenApply(ctx -> {
                LogicalFilter<Plan> filter = ctx.root;
                ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
                Set<Expression> newConjuncts = ImmutableSet.copyOf(ExpressionUtils.extractConjunction(
                        rewriter.rewrite(filter.getPredicate(), context)));
                if (newConjuncts.equals(filter.getConjuncts())) {
                    return filter;
                }
                return new LogicalFilter<>(newConjuncts, filter.child());
            }).toRule(RuleType.REWRITE_FILTER_EXPRESSION);
        }
    }

    /** OlapTableSinkExpressionRewrite */
    public class OlapTableSinkExpressionRewrite extends OneRewriteRuleFactory {
        @Override
        public Rule build() {
            return logicalOlapTableSink().thenApply(ctx -> {
                LogicalOlapTableSink<Plan> olapTableSink = ctx.root;
                ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
                List<Expression> partitionExprList = olapTableSink.getPartitionExprList();
                List<Expression> newPartitionExprList = rewriteAll(partitionExprList, rewriter, context);
                Map<Long, Expression> syncMvWhereClauses = olapTableSink.getSyncMvWhereClauses();
                Map<Long, Expression> newSyncMvWhereClauses = new HashMap<>();
                for (Map.Entry<Long, Expression> entry : syncMvWhereClauses.entrySet()) {
                    newSyncMvWhereClauses.put(entry.getKey(), rewriter.rewrite(entry.getValue(), context));
                }
                if (partitionExprList.equals(newPartitionExprList)
                        && syncMvWhereClauses.equals(newSyncMvWhereClauses)) {
                    return olapTableSink;
                }
                return olapTableSink.withPartitionExprAndMvWhereClause(newPartitionExprList, newSyncMvWhereClauses);
            }).toRule(RuleType.REWRITE_OLAP_TABLE_SINK_EXPRESSION);
        }
    }

    /** AggExpressionRewrite */
    public class AggExpressionRewrite extends OneRewriteRuleFactory {
        @Override
        public Rule build() {
            return logicalAggregate().thenApply(ctx -> {
                LogicalAggregate<Plan> agg = ctx.root;
                List<Expression> groupByExprs = agg.getGroupByExpressions();
                ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
                List<Expression> newGroupByExprs = rewriter.rewrite(groupByExprs, context);

                List<NamedExpression> outputExpressions = agg.getOutputExpressions();
                List<NamedExpression> newOutputExpressions = rewriteAll(outputExpressions, rewriter, context);
                if (outputExpressions.equals(newOutputExpressions)) {
                    return agg;
                }
                return new LogicalAggregate<>(newGroupByExprs, newOutputExpressions,
                        agg.isNormalized(), agg.getSourceRepeat(), agg.child());
            }).toRule(RuleType.REWRITE_AGG_EXPRESSION);
        }
    }

    /** JoinExpressionRewrite */
    public class JoinExpressionRewrite extends OneRewriteRuleFactory {
        @Override
        public Rule build() {
            return logicalJoin().thenApply(ctx -> {
                LogicalJoin<Plan, Plan> join = ctx.root;
                List<Expression> hashJoinConjuncts = join.getHashJoinConjuncts();
                List<Expression> otherJoinConjuncts = join.getOtherJoinConjuncts();
                List<Expression> markJoinConjuncts = join.getMarkJoinConjuncts();
                if (otherJoinConjuncts.isEmpty() && hashJoinConjuncts.isEmpty()
                        && markJoinConjuncts.isEmpty()) {
                    return join;
                }

                ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
                Pair<Boolean, List<Expression>> newHashJoinConjuncts = rewriteConjuncts(hashJoinConjuncts, context);
                Pair<Boolean, List<Expression>> newOtherJoinConjuncts = rewriteConjuncts(otherJoinConjuncts, context);
                Pair<Boolean, List<Expression>> newMarkJoinConjuncts = rewriteConjuncts(markJoinConjuncts, context);

                if (!newHashJoinConjuncts.first && !newOtherJoinConjuncts.first
                        && !newMarkJoinConjuncts.first) {
                    return join;
                }

                return new LogicalJoin<>(join.getJoinType(), newHashJoinConjuncts.second,
                        newOtherJoinConjuncts.second, newMarkJoinConjuncts.second,
                        join.getDistributeHint(), join.getMarkJoinSlotReference(), join.children(),
                        join.getJoinReorderContext());
            }).toRule(RuleType.REWRITE_JOIN_EXPRESSION);
        }

        private Pair<Boolean, List<Expression>> rewriteConjuncts(List<Expression> conjuncts,
                ExpressionRewriteContext context) {
            boolean isChanged = false;
            // some rules will append new conjunct, we need to distinct it
            // for example:
            //   pk = 2 or pk < 0
            // after AddMinMax rule:
            //   (pk = 2 or pk < 0) and pk <= 2
            //
            // if not distinct it, the pk <= 2 will generate repeat forever
            ImmutableSet.Builder<Expression> rewrittenConjuncts = new ImmutableSet.Builder<>();
            for (Expression expr : conjuncts) {
                Expression newExpr = rewriter.rewrite(expr, context);
                newExpr = newExpr.isNullLiteral() && expr instanceof EqualPredicate
                                ? expr.withChildren(rewriter.rewrite(expr.child(0), context),
                                        rewriter.rewrite(expr.child(1), context))
                                : newExpr;
                isChanged = isChanged || !newExpr.equals(expr);
                rewrittenConjuncts.addAll(ExpressionUtils.extractConjunction(newExpr));
            }
            ImmutableList<Expression> newConjuncts = Utils.fastToImmutableList(rewrittenConjuncts.build());
            return Pair.of(isChanged && !newConjuncts.equals(conjuncts), newConjuncts);
        }
    }

    /** SortExpressionRewrite */
    public class SortExpressionRewrite extends OneRewriteRuleFactory {

        @Override
        public Rule build() {
            return logicalSort().thenApply(ctx -> {
                LogicalSort<Plan> sort = ctx.root;
                List<OrderKey> orderKeys = sort.getOrderKeys();
                ImmutableList.Builder<OrderKey> rewrittenOrderKeys
                        = ImmutableList.builderWithExpectedSize(orderKeys.size());
                ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
                boolean changed = false;
                for (OrderKey k : orderKeys) {
                    Expression expression = rewriter.rewrite(k.getExpr(), context);
                    changed |= expression != k.getExpr();
                    rewrittenOrderKeys.add(new OrderKey(expression, k.isAsc(), k.isNullFirst()));
                }
                return changed ? sort.withOrderKeys(rewrittenOrderKeys.build()) : sort;
            }).toRule(RuleType.REWRITE_SORT_EXPRESSION);
        }
    }

    /** HavingExpressionRewrite */
    public class HavingExpressionRewrite extends OneRewriteRuleFactory {
        @Override
        public Rule build() {
            return logicalHaving().thenApply(ctx -> {
                LogicalHaving<Plan> having = ctx.root;
                ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
                Set<Expression> newConjuncts = ImmutableSet.copyOf(ExpressionUtils.extractConjunction(
                        rewriter.rewrite(having.getPredicate(), context)));
                if (newConjuncts.equals(having.getConjuncts())) {
                    return having;
                }
                return having.withConjuncts(newConjuncts);
            }).toRule(RuleType.REWRITE_HAVING_EXPRESSION);
        }
    }

    /** LogicalRepeatRewrite */
    public class LogicalRepeatRewrite extends OneRewriteRuleFactory {
        @Override
        public Rule build() {
            return logicalRepeat().thenApply(ctx -> {
                LogicalRepeat<Plan> repeat = ctx.root;
                ImmutableList.Builder<List<Expression>> groupingExprs = ImmutableList.builder();
                ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
                for (List<Expression> expressions : repeat.getGroupingSets()) {
                    groupingExprs.add(expressions.stream()
                            .map(expr -> rewriter.rewrite(expr, context))
                            .collect(ImmutableList.toImmutableList())
                    );
                }
                return repeat.withGroupSetsAndOutput(groupingExprs.build(),
                        repeat.getOutputExpressions().stream()
                                .map(output -> rewriter.rewrite(output, context))
                                .map(e -> (NamedExpression) e)
                                .collect(ImmutableList.toImmutableList()));
            }).toRule(RuleType.REWRITE_REPEAT_EXPRESSION);
        }
    }

    /** bottomUp */
    public static ExpressionBottomUpRewriter bottomUp(ExpressionPatternRuleFactory... ruleFactories) {
        ImmutableList.Builder<ExpressionPatternMatchRule> rules = ImmutableList.builder();
        ImmutableList.Builder<ExpressionTraverseListenerMapping> listeners = ImmutableList.builder();
        for (ExpressionPatternRuleFactory ruleFactory : ruleFactories) {
            if (ruleFactory instanceof ExpressionTraverseListenerFactory) {
                List<ExpressionListenerMatcher<? extends Expression>> listenersMatcher
                        = ((ExpressionTraverseListenerFactory) ruleFactory).buildListeners();
                for (ExpressionListenerMatcher<? extends Expression> listenerMatcher : listenersMatcher) {
                    listeners.add(new ExpressionTraverseListenerMapping(listenerMatcher));
                }
            }
            for (ExpressionPatternMatcher<? extends Expression> patternMatcher : ruleFactory.buildRules()) {
                rules.add(new ExpressionPatternMatchRule(patternMatcher));
            }
        }

        return new ExpressionBottomUpRewriter(
                new ExpressionPatternRules(rules.build()),
                new ExpressionPatternTraverseListeners(listeners.build())
        );
    }

    public static <E extends Expression> List<E> rewriteAll(
            Collection<E> exprs, ExpressionRuleExecutor rewriter, ExpressionRewriteContext context) {
        ImmutableList.Builder<E> result = ImmutableList.builderWithExpectedSize(exprs.size());
        for (E expr : exprs) {
            result.add((E) rewriter.rewrite(expr, context));
        }
        return result.build();
    }
}