OrToIn.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.expression.ExpressionBottomUpRewriter;
import org.apache.doris.nereids.rules.expression.ExpressionRewrite;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;

import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Infers IN predicates from OR expressions, e.g. {@code (a=1) or (a=2)} becomes {@code a in (1, 2)}.
 *
 * <p>Do NOT use this rule in ExpressionOptimization.
 * Apply this rule to filter expressions in extract mode, and to other expressions in replace mode.
 */
public class OrToIn {
    /**
     * Controls the output when the inferred IN predicate is only implied by (not equivalent to) the original OR.
     *
     * <p>case 1: from (a=1 and b=1) or (a=2), "a in (1, 2)" is inferred,
     * the inferred expr is NOT equivalent to the original expr
     * - replaceMode: output the original expr
     * - extractMode: output a in (1, 2) and ((a=1 and b=1) or (a=2))
     *
     * <p>case 2: from (a=1) or (a=2), "a in (1, 2)" is inferred,
     * the inferred expr is equivalent to the original expr
     * - replaceMode/extractMode: output a in (1, 2)
     *
     * <p>extractMode is only used for filters, where the inferred IN predicate could be pushed down.
     */
    public enum Mode {
        replaceMode,
        extractMode
    }

    public static final OrToIn EXTRACT_MODE_INSTANCE = new OrToIn(Mode.extractMode);
    public static final OrToIn REPLACE_MODE_INSTANCE = new OrToIn(Mode.replaceMode);

    private final Mode mode;

    public OrToIn(Mode mode) {
        this.mode = mode;
    }

    /**
     * Simplifies the tree bottom-up with {@link SimplifyRange}, then applies this rule to the whole tree.
     */
    public Expression rewriteTree(Expression expr, ExpressionRewriteContext context) {
        ExpressionBottomUpRewriter simplify = ExpressionRewrite.bottomUp(SimplifyRange.INSTANCE);
        expr = simplify.rewrite(expr, context);
        return rewriteTree(expr);
    }

    /**
     * Rewrites the tree bottom-up: children are rewritten first, duplicated conjuncts of an AND are removed,
     * and every OR node is then rewritten by this rule.
     */
    public Expression rewriteTree(Expression expr) {
        List<Expression> children = expr.children();
        if (children.isEmpty()) {
            return expr;
        }
        List<Expression> newChildren = children.stream()
                .map(this::rewriteTree).collect(Collectors.toList());
        if (expr instanceof And) {
            // filter out duplicated conjuncts (flattening nested ANDs on the way)
            // example: OrToInTest.testDeDup()
            Set<Expression> dedupSet = new LinkedHashSet<>();
            for (Expression newChild : newChildren) {
                dedupSet.addAll(ExpressionUtils.extractConjunction(newChild));
            }
            newChildren = Lists.newArrayList(dedupSet);
        }
        if (expr instanceof CompoundPredicate && newChildren.size() == 1) {
            // (a=1) and (a=1)
            // after rewrite, newChildren=[(a=1)]
            expr = newChildren.get(0);
        } else {
            expr = expr.withChildren(newChildren);
        }
        if (expr instanceof Or) {
            expr = rewrite((Or) expr);
        }
        return expr;
    }

    /**
     * Factors out conjuncts shared by all disjuncts, then tries to infer IN predicates from the remaining OR.
     */
    private Expression rewrite(Or or) {
        Pair<Expression, Expression> pair = extractCommonConjunct(or);
        Expression result = tryToRewriteIn(pair.second);
        if (pair.first != null) {
            result = new And(pair.first, result);
        }
        return result;
    }

    /**
     * Infers "expr in (literals)" / "expr = literal" predicates that every disjunct of {@code or} implies.
     *
     * @return the inferred predicate when it is equivalent to {@code or}; otherwise {@code or} itself
     *         (replace mode) or the inferred predicate AND-ed with {@code or} (extract mode).
     *         {@code or} is returned unchanged when nothing can be inferred.
     */
    private Expression tryToRewriteIn(Expression or) {
        List<Expression> disjuncts = ExpressionUtils.extractDisjunction(or);
        for (Expression disjunct : disjuncts) {
            if (!hasInOrEqualChildren(disjunct)) {
                return or;
            }
        }

        // start with the candidates of the first disjunct, then keep only the compare expressions that every
        // other disjunct also constrains, unioning their literal sets
        Map<Expression, Set<Literal>> candidates = getCandidates(disjuncts.get(0));
        for (int i = 1; i < disjuncts.size() && !candidates.isEmpty(); i++) {
            candidates = mergeCandidates(candidates, getCandidates(disjuncts.get(i)));
        }
        if (candidates.isEmpty()) {
            return or;
        }

        Expression conjunct = candidatesToFinalResult(candidates);
        if (!keepOriginalOrExpression(disjuncts)) {
            // every disjunct is a single predicate on the candidate expression: the inferred predicate
            // fully replaces the original disjunction
            return conjunct;
        }
        // the inferred predicate is weaker than the original disjunction, so the original must be kept
        return mode == Mode.extractMode ? new And(conjunct, or) : or;
    }

    /**
     * Returns true if any disjunct has more than one conjunct, i.e. the inferred predicate alone would lose
     * information and the original OR has to be kept.
     */
    private boolean keepOriginalOrExpression(List<Expression> disjuncts) {
        return disjuncts.stream()
                .anyMatch(disjunct -> ExpressionUtils.extractConjunction(disjunct).size() > 1);
    }

    /**
     * Keeps the compare expressions present in both maps and unions their literal sets.
     * Neither input map nor its sets are modified; the result preserves the iteration order of {@code a},
     * followed by the new literals from {@code b}.
     */
    private Map<Expression, Set<Literal>> mergeCandidates(
            Map<Expression, Set<Literal>> a,
            Map<Expression, Set<Literal>> b) {
        Map<Expression, Set<Literal>> result = new LinkedHashMap<>();
        for (Map.Entry<Expression, Set<Literal>> entry : a.entrySet()) {
            Set<Literal> otherLiterals = b.get(entry.getKey());
            if (otherLiterals == null) {
                continue;
            }
            Set<Literal> literals = new LinkedHashSet<>(entry.getValue());
            literals.addAll(otherLiterals);
            if (!literals.isEmpty()) {
                result.put(entry.getKey(), literals);
            }
        }
        return result;
    }

    /**
     * Converts each candidate into an IN (or EqualTo) predicate and AND-s them together.
     */
    private Expression candidatesToFinalResult(Map<Expression, Set<Literal>> candidates) {
        return ExpressionUtils.and(candidates.entrySet().stream()
                .map(entry -> ExpressionUtils.toInPredicateOrEqualTo(entry.getKey(), entry.getValue()))
                .collect(Collectors.toList()));
    }

    /**
     * Returns true if at least one conjunct of {@code disjunct} is an EqualTo or an InPredicate.
     *
     * <p>It is not necessary to rewrite "a like 'xyz' or a=1 or a=2" to "a like 'xyz' or a in (1, 2)",
     * because "a in (1, 2)" cannot be pushed into the storage layer in that shape.
     */
    private boolean hasInOrEqualChildren(Expression disjunct) {
        List<Expression> conjuncts = ExpressionUtils.extractConjunction(disjunct);
        for (Expression conjunct : conjuncts) {
            if (conjunct instanceof EqualTo || conjunct instanceof InPredicate) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns true if {@code conjuncts.get(idx)} shares no input slot with any other conjunct in the list.
     */
    private boolean independentConjunct(int idx, List<Expression> conjuncts) {
        if (conjuncts.size() == 1) {
            return true;
        }
        Set<Slot> targetSlots = conjuncts.get(idx).getInputSlots();
        for (int i = 0; i < conjuncts.size(); i++) {
            if (i != idx && !Sets.intersection(targetSlots, conjuncts.get(i).getInputSlots()).isEmpty()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Collects IN candidates from a single disjunct.
     *
     * <p>Only conjuncts that are independent of the other conjuncts (no shared input slot) and have the form
     * {@code expr = literal} / {@code literal = expr} / {@code expr in (literal, ...)} are considered.
     *
     * @return compare expression -> literal set, in encounter order; empty if nothing qualifies or the same
     *         compare expression is constrained more than once
     */
    private Map<Expression, Set<Literal>> getCandidates(Expression disjunct) {
        List<Expression> conjuncts = ExpressionUtils.extractConjunction(disjunct);
        Map<Expression, Set<Literal>> candidates = new LinkedHashMap<>();
        for (int idx = 0; idx < conjuncts.size(); idx++) {
            if (!independentConjunct(idx, conjuncts)) {
                continue;
            }
            // find pattern: A=1 / A in (1, 2, 3 ...)
            // candidates: A -> [1] / A -> [1, 2, 3, ...]
            Expression conjunct = conjuncts.get(idx);
            Expression compareExpr = null;
            Set<Literal> literals = null;
            if (conjunct instanceof EqualTo) {
                EqualTo eq = (EqualTo) conjunct;
                if (!(eq.left() instanceof Literal) && eq.right() instanceof Literal) {
                    compareExpr = eq.left();
                    literals = new LinkedHashSet<>();
                    literals.add((Literal) eq.right());
                } else if (!(eq.right() instanceof Literal) && eq.left() instanceof Literal) {
                    compareExpr = eq.right();
                    literals = new LinkedHashSet<>();
                    literals.add((Literal) eq.left());
                }
            } else if (conjunct instanceof InPredicate) {
                InPredicate inPredicate = (InPredicate) conjunct;
                Set<Literal> literalOptions = new LinkedHashSet<>();
                boolean allLiteralOpts = true;
                for (Expression opt : inPredicate.getOptions()) {
                    if (opt instanceof Literal) {
                        literalOptions.add((Literal) opt);
                    } else {
                        allLiteralOpts = false;
                        break;
                    }
                }
                if (allLiteralOpts) {
                    compareExpr = inPredicate.getCompareExpr();
                    literals = literalOptions;
                }
            }
            if (compareExpr == null) {
                continue;
            }
            if (candidates.containsKey(compareExpr)) {
                // patterns like (A=1 and A=2) or (A=1 and A in (1, 2)) should be processed by SimplifyRange;
                // OrToIn does not apply to this disjunct
                candidates.clear();
                break;
            }
            candidates.put(compareExpr, literals);
        }
        return candidates;
    }

    /**
     * (a and b and ...) or (a and c and ...)
     * =>
     * a and [(b and ...) or (c and ...)]
     * extracts the common part: a
     * and the remaining part: (b and ...) or (c and ...)
     *
     * @return Pair (common, remaining); common is null when the disjuncts share no conjunct,
     *         in which case remaining is {@code or} itself
     */
    private Pair<Expression, Expression> extractCommonConjunct(Or or) {
        List<Expression> disjuncts = ExpressionUtils.extractDisjunction(or);
        List<List<Expression>> conjunctsList = Lists.newArrayList();
        for (Expression disjunct : disjuncts) {
            // defensive copy: these lists are modified by removeAll below
            conjunctsList.add(Lists.newArrayList(ExpressionUtils.extractConjunction(disjunct)));
        }
        List<Expression> commons = Lists.newArrayList();
        for (Expression a : conjunctsList.get(0)) {
            boolean isCommon = true;
            for (int i = 1; i < disjuncts.size(); i++) {
                if (!conjunctsList.get(i).contains(a)) {
                    isCommon = false;
                    break;
                }
            }
            if (isCommon) {
                commons.add(a);
            }
        }
        if (commons.isEmpty()) {
            return Pair.of(null, or);
        }
        List<Expression> remainPart = Lists.newArrayList();
        for (List<Expression> conjuncts : conjunctsList) {
            conjuncts.removeAll(commons);
            remainPart.add(ExpressionUtils.and(conjuncts));
        }
        Expression remainOr = ExpressionUtils.or(remainPart);
        return Pair.of(ExpressionUtils.and(commons), remainOr);
    }
}