OrToIn.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.nereids.rules.expression.ExpressionBottomUpRewriter;
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.rules.expression.ExpressionRewrite;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Used to convert multi equalTo which has same slot and compare to a literal of disjunction to a InPredicate so that
 * it could be push down to storage engine.
 * example:
 * col1 = 1 or col1 = 2 or col1 = 3 and (col2 = 4)
 * col1 = 1 and col1 = 3 and col2 = 3 or col2 = 4
 * (col1 = 1 or col1 = 2) and  (col2 = 3 or col2 = 4)
 * <p>
 * would be converted to:
 * col1 in (1, 2) or col1 = 3 and (col2 = 4)
 * col1 = 1 and col1 = 3 and col2 = 3 or col2 = 4
 * (col1 in (1, 2) and (col2 in (3, 4)))
 * The generic type declaration and the overridden 'rewrite' function in this class may appear unconventional
 * because we need to maintain a map passed between methods in this class. But the owner of this module prohibits
 * adding any additional rule-specific fields to the default ExpressionRewriteContext. However, the entire expression
 * rewrite framework always passes an ExpressionRewriteContext of type context to all rules.
 */
public class OrToIn implements ExpressionPatternRuleFactory {

    public static final OrToIn INSTANCE = new OrToIn();

    public static final int REWRITE_OR_TO_IN_PREDICATE_THRESHOLD = 2;

    @Override
    public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
        return ImmutableList.of(
                matchesTopType(Or.class).then(OrToIn::rewrite)
        );
    }

    public Expression rewriteTree(Expression expr, ExpressionRewriteContext context) {
        ExpressionBottomUpRewriter bottomUpRewriter = ExpressionRewrite.bottomUp(this);
        return bottomUpRewriter.rewrite(expr, context);
    }

    private static Expression rewrite(Or or) {
        // NOTICE: use linked hash map to avoid unstable order or entry.
        //  unstable order entry lead to dead loop since return expression always un-equals to original one.
        Map<NamedExpression, Set<Literal>> slotNameToLiteral = Maps.newLinkedHashMap();
        Map<Expression, NamedExpression> disConjunctToSlot = Maps.newLinkedHashMap();
        List<Expression> expressions = ExpressionUtils.extractDisjunction(or);
        for (Expression expression : expressions) {
            if (expression instanceof EqualTo) {
                handleEqualTo((EqualTo) expression, slotNameToLiteral, disConjunctToSlot);
            } else if (expression instanceof InPredicate) {
                handleInPredicate((InPredicate) expression, slotNameToLiteral, disConjunctToSlot);
            }
        }
        if (disConjunctToSlot.isEmpty()) {
            return or;
        }

        List<Expression> rewrittenOr = new ArrayList<>();
        for (Map.Entry<NamedExpression, Set<Literal>> entry : slotNameToLiteral.entrySet()) {
            Set<Literal> literals = entry.getValue();
            if (literals.size() >= REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) {
                InPredicate inPredicate = new InPredicate(entry.getKey(), ImmutableList.copyOf(entry.getValue()));
                rewrittenOr.add(inPredicate);
            }
        }
        for (Expression expression : expressions) {
            if (disConjunctToSlot.get(expression) == null) {
                rewrittenOr.add(expression);
            } else {
                Set<Literal> literals = slotNameToLiteral.get(disConjunctToSlot.get(expression));
                if (literals.size() < REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) {
                    rewrittenOr.add(expression);
                }
            }
        }

        return ExpressionUtils.or(rewrittenOr);
    }

    private static void handleEqualTo(EqualTo equal, Map<NamedExpression, Set<Literal>> slotNameToLiteral,
                               Map<Expression, NamedExpression> disConjunctToSlot) {
        Expression left = equal.left();
        Expression right = equal.right();
        if (left instanceof NamedExpression && right instanceof Literal) {
            addSlotToLiteral((NamedExpression) left, (Literal) right, slotNameToLiteral);
            disConjunctToSlot.put(equal, (NamedExpression) left);
        } else if (right instanceof NamedExpression && left instanceof Literal) {
            addSlotToLiteral((NamedExpression) right, (Literal) left, slotNameToLiteral);
            disConjunctToSlot.put(equal, (NamedExpression) right);
        }
    }

    private static void handleInPredicate(InPredicate inPredicate, Map<NamedExpression, Set<Literal>> slotNameToLiteral,
                                   Map<Expression, NamedExpression> disConjunctToSlot) {
        // TODO a+b in (1,2,3...) is not supported now
        if (inPredicate.getCompareExpr() instanceof NamedExpression
                && inPredicate.getOptions().stream().allMatch(opt -> opt instanceof Literal)) {
            for (Expression opt : inPredicate.getOptions()) {
                addSlotToLiteral((NamedExpression) inPredicate.getCompareExpr(), (Literal) opt, slotNameToLiteral);
            }
            disConjunctToSlot.put(inPredicate, (NamedExpression) inPredicate.getCompareExpr());
        }
    }

    private static void addSlotToLiteral(NamedExpression namedExpression, Literal literal,
            Map<NamedExpression, Set<Literal>> slotNameToLiteral) {
        Set<Literal> literals = slotNameToLiteral.computeIfAbsent(namedExpression, k -> new LinkedHashSet<>());
        literals.add(literal);
    }
}