ExpressionBottomUpRewriter.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.expression;

import org.apache.doris.nereids.pattern.ExpressionPatternRules;
import org.apache.doris.nereids.pattern.ExpressionPatternTraverseListeners;
import org.apache.doris.nereids.pattern.ExpressionPatternTraverseListeners.CombinedListener;
import org.apache.doris.nereids.trees.expressions.Expression;

import com.google.common.collect.ImmutableList;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.Nullable;

/** ExpressionBottomUpRewriter */
public class ExpressionBottomUpRewriter implements ExpressionRewriteRule<ExpressionRewriteContext> {
    public static final String BATCH_ID_KEY = "batch_id";
    private static final Logger LOG = LogManager.getLogger(ExpressionBottomUpRewriter.class);
    private static final AtomicInteger rewriteBatchId = new AtomicInteger();
    private final ExpressionPatternRules rules;
    private final ExpressionPatternTraverseListeners listeners;

    public ExpressionBottomUpRewriter(ExpressionPatternRules rules, ExpressionPatternTraverseListeners listeners) {
        this.rules = rules;
        this.listeners = listeners;
    }

    // entrance
    @Override
    public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
        int currentBatch = rewriteBatchId.incrementAndGet();
        return rewriteBottomUp(expr, ctx, currentBatch, null, rules, listeners);
    }

    private static Expression rewriteBottomUp(
            Expression expression, ExpressionRewriteContext context, int currentBatch, @Nullable Expression parent,
            ExpressionPatternRules rules, ExpressionPatternTraverseListeners listeners) {

        Optional<Integer> rewriteState = expression.getMutableState(BATCH_ID_KEY);
        if (!rewriteState.isPresent() || rewriteState.get() != currentBatch) {
            CombinedListener listener = null;
            boolean hasChildren = expression.arity() > 0;
            if (hasChildren) {
                listener = listeners.matchesAndCombineListeners(expression, context, parent);
                if (listener != null) {
                    listener.onEnter();
                }
            }

            Expression afterRewrite = expression;
            try {
                Expression beforeRewrite;
                afterRewrite = rewriteChildren(expression, context, currentBatch, rules, listeners);
                // use rewriteTimes to avoid dead loop
                int rewriteTimes = 0;
                boolean changed;
                do {
                    beforeRewrite = afterRewrite;

                    // rewrite this
                    Optional<Expression> applied = rules.matchesAndApply(beforeRewrite, context, parent);

                    changed = applied.isPresent();
                    if (changed) {
                        afterRewrite = applied.get();
                        // ensure children are rewritten
                        afterRewrite = rewriteChildren(afterRewrite, context, currentBatch, rules, listeners);
                    }
                    rewriteTimes++;
                } while (changed && rewriteTimes < 100);

                // set rewritten
                afterRewrite.setMutableState(BATCH_ID_KEY, currentBatch);
            } finally {
                if (hasChildren && listener != null) {
                    listener.onExit(afterRewrite);
                }
            }

            return afterRewrite;
        }

        // already rewritten
        return expression;
    }

    private static Expression rewriteChildren(Expression parent, ExpressionRewriteContext context, int currentBatch,
            ExpressionPatternRules rules, ExpressionPatternTraverseListeners listeners) {
        boolean changed = false;
        ImmutableList.Builder<Expression> newChildren = ImmutableList.builderWithExpectedSize(parent.arity());
        for (Expression child : parent.children()) {
            Expression newChild = rewriteBottomUp(child, context, currentBatch, parent, rules, listeners);
            changed |= !child.equals(newChild);
            newChildren.add(newChild);
        }

        Expression result = parent;
        if (changed) {
            result = parent.withChildren(newChildren.build());
        }
        if (changed && context.cascadesContext.isEnableExprTrace()) {
            LOG.info("WithChildren: \nbefore: " + parent + "\nafter: " + result);
        }
        return result;
    }
}