NestedCaseWhenCondToLiteral.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.ExpressionRuleType;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;

import java.util.List;
import java.util.Map;

/**
 * For nested CaseWhen/IF expressions, replace an inner CaseWhen/IF condition with a TRUE/FALSE literal
 * when the same condition also exists in the outer CaseWhen/IF conditions.
 *
 * On a nested CASE/IF path, a condition may appear in multiple CASE/IF branches.
 * For any inner CASE WHEN or IF condition, its boolean value is determined by the outermost CASE/IF branch,
 * that is, by the first occurrence of the condition on the nested CASE/IF path.
 *
 * <br>
 *  1. if it exists in the outer case's current branch condition, it is replaced with TRUE
 *    e.g.
 *      case when A then
 *                  (case when A then 1 else 2 end)
 *          ...
 *      end
 *     then the inner case condition A is replaced with TRUE:
 *      case when A then
 *                  (case when TRUE then 1 else 2 end)
 *          ...
 *      end
 * <br>
 *  2. if it exists in one of the outer case's previous branch conditions, it is replaced with FALSE
 *    e.g.
 *      case when A then ...
 *           when B then
 *                  (case when A then 1 else 2 end)
 *          ...
 *      end
 *     then the inner case condition A is replaced with FALSE:
 *      case when A then ...
 *           when B then
 *                  (case when FALSE then 1 else 2 end)
 *          ...
 *      end
 * <br>
 */
public class NestedCaseWhenCondToLiteral implements ExpressionPatternRuleFactory {

    public static final NestedCaseWhenCondToLiteral INSTANCE = new NestedCaseWhenCondToLiteral();

    @Override
    public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
        return ImmutableList.of(
                root(Expression.class)
                        .when(NestedCaseWhenCondToLiteral::needRewrite)
                        .thenApply(ctx -> rewrite(ctx.expr, ctx.rewriteContext))
                        .toRule(ExpressionRuleType.NESTED_CASE_WHEN_COND_TO_LITERAL)
        );
    }

    // Only expression trees that contain at least one CaseWhen or If can be affected by this rule.
    private static boolean needRewrite(Expression expression) {
        return expression.containsType(CaseWhen.class, If.class);
    }

    // A fresh replacer is created per rewrite because it carries mutable per-traversal state.
    // The rewrite context is currently not needed by the replacer.
    private Expression rewrite(Expression expression, ExpressionRewriteContext context) {
        return expression.accept(new NestedCondReplacer(), null);
    }

    /**
     * NestedCondReplacer: rewriter that records the boolean value of the enclosing CASE/IF conditions
     * while descending, and substitutes a repeated condition with the recorded literal.
     */
    @VisibleForTesting
    public static class NestedCondReplacer extends DefaultExpressionRewriter<Void> {

        // conditionLiterals records the boolean literal that currently stands for a condition expression:
        // 1. if a condition already exists in the outer case/if conditions, it is replaced with that literal.
        // 2. otherwise this is its first occurrence, then:
        //    a) when entering its case/if branch, the condition is set to the TRUE literal
        //    b) when leaving its case/if branch, the condition is set to the FALSE literal
        //    c) when leaving the whole case/if expression, the condition literal is removed
        // NOTE(review): a branch is also skipped when its condition evaluates to NULL, so FALSE here stands
        // for "not TRUE". Substitution only happens in condition position or as a child of a
        // CompoundPredicate; this presumably keeps the rewrite NULL-safe -- confirm that CompoundPredicate
        // only covers AND/OR.
        protected final Map<Expression, BooleanLiteral> conditionLiterals = Maps.newHashMap();

        /**
         * Delegate to the default rewriter only for sub-trees that contain a CaseWhen or If;
         * any other expression is returned as is.
         */
        @Override
        public Expression visit(Expression expr, Void context) {
            if (NestedCaseWhenCondToLiteral.needRewrite(expr)) {
                return super.visit(expr, context);
            } else {
                return expr;
            }
        }

        /**
         * Rewrite a CASE WHEN: each when-condition is replaced if it is already known, each branch result is
         * visited with its own condition recorded as TRUE, and later branches plus the default value are
         * visited with the earlier first-occurrence conditions recorded as FALSE.
         *
         * @return the same caseWhen instance when nothing changed, otherwise a new CaseWhen
         */
        @Override
        public Expression visitCaseWhen(CaseWhen caseWhen, Void context) {
            List<WhenClause> oldWhenClauses = caseWhen.getWhenClauses();
            ImmutableList.Builder<WhenClause> newWhenClausesBuilder
                    = ImmutableList.builderWithExpectedSize(oldWhenClauses.size());
            // conditions first recorded by this CASE WHEN; they must be removed again before returning
            List<Expression> firstOccurConds = Lists.newArrayListWithExpectedSize(oldWhenClauses.size());
            boolean changed = false;
            for (WhenClause whenClause : oldWhenClauses) {
                Expression oldCondition = whenClause.getOperand();
                Expression oldResult = whenClause.getResult();
                Pair<Expression, Boolean> replaceResult = replaceCondition(oldCondition, context);
                Expression newCondition = replaceResult.first;
                boolean condFirstOccur = replaceResult.second;
                if (condFirstOccur) {
                    firstOccurConds.add(oldCondition);
                    // inside this branch's result the condition holds
                    conditionLiterals.put(oldCondition, BooleanLiteral.TRUE);
                }
                Expression newResult = oldResult.accept(this, context);
                if (condFirstOccur) {
                    // the following branches and the default value are only reached when it did not hold
                    conditionLiterals.put(oldCondition, BooleanLiteral.FALSE);
                }
                if (newCondition != oldCondition || newResult != oldResult) {
                    newWhenClausesBuilder.add(new WhenClause(newCondition, newResult));
                    changed = true;
                } else {
                    newWhenClausesBuilder.add(whenClause);
                }
            }
            Expression oldDefaultValue = caseWhen.getDefaultValue().orElse(null);
            Expression newDefaultValue = oldDefaultValue == null ? null : oldDefaultValue.accept(this, context);
            for (Expression cond : firstOccurConds) {
                conditionLiterals.remove(cond);
            }
            if (!changed && newDefaultValue == oldDefaultValue) {
                return caseWhen;
            }
            List<WhenClause> newWhenClauses = newWhenClausesBuilder.build();
            return newDefaultValue != null
                    ? new CaseWhen(newWhenClauses, newDefaultValue)
                    : new CaseWhen(newWhenClauses);
        }

        /**
         * Rewrite an IF: the condition is replaced if it is already known, the true value is visited with
         * the condition recorded as TRUE and the false value with the condition recorded as FALSE.
         *
         * @return the same ifExpr instance when nothing changed, otherwise a new If
         */
        @Override
        public Expression visitIf(If ifExpr, Void context) {
            Expression oldCondition = ifExpr.getCondition();
            Expression oldTrueValue = ifExpr.getTrueValue();
            Expression oldFalseValue = ifExpr.getFalseValue();
            Pair<Expression, Boolean> replaceResult = replaceCondition(oldCondition, context);
            Expression newCondition = replaceResult.first;
            boolean condFirstOccur = replaceResult.second;
            if (condFirstOccur) {
                conditionLiterals.put(oldCondition, BooleanLiteral.TRUE);
            }
            Expression newTrueValue = oldTrueValue.accept(this, context);
            if (condFirstOccur) {
                conditionLiterals.put(oldCondition, BooleanLiteral.FALSE);
            }
            Expression newFalseValue = oldFalseValue.accept(this, context);
            if (condFirstOccur) {
                conditionLiterals.remove(oldCondition);
            }
            if (newCondition != oldCondition
                    || newTrueValue != oldTrueValue
                    || newFalseValue != oldFalseValue) {
                return new If(newCondition, newTrueValue, newFalseValue);
            } else {
                return ifExpr;
            }
        }

        // Returns (newCondition, firstOccurFlag).
        // The flag is true only when the whole condition is neither a literal nor already recorded,
        // i.e. the caller has to record it for the branch it guards.
        private Pair<Expression, Boolean> replaceCondition(Expression condition, Void context) {
            if (condition.isLiteral()) {
                // a literal condition does not need to be replaced, and is not recorded
                return Pair.of(condition, false);
            }
            // only non-null literals are ever stored, so a null lookup result means "not recorded"
            BooleanLiteral knownLiteral = conditionLiterals.get(condition);
            if (knownLiteral != null) {
                return Pair.of(knownLiteral, false);
            }
            if (condition instanceof CompoundPredicate) {
                // the compound itself is new, but some of its children may already be known
                ImmutableList.Builder<Expression> newChildrenBuilder
                        = ImmutableList.builderWithExpectedSize(condition.arity());
                boolean hasNewChildren = false;
                for (Expression child : condition.children()) {
                    Expression newChild = replaceCondition(child, context).first;
                    hasNewChildren |= newChild != child;
                    newChildrenBuilder.add(newChild);
                }
                Expression newCondition = hasNewChildren
                        ? condition.withChildren(newChildrenBuilder.build()) : condition;
                return Pair.of(newCondition, true);
            }
            // any other condition may still contain nested case/if expressions, so rewrite it recursively
            return Pair.of(condition.accept(this, context), true);
        }
    }
}