ExtractSingleTableExpressionFromDisjunction.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;

import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Paper: Quantifying TPC-H Choke Points and Their Optimizations
 * 4.4 Join-Dependent Predicate Duplication
 * Example:
 * Two queries, Q7 and Q19, include predicates that operate
 * on multiple tables without being a join predicate. In Q7,
 * (n1.name = 'NATION1' AND n2.name = 'NATION2') OR (n1.name = 'NATION2' AND n2.name = 'NATION1')
 * =>
 * (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE')
 * and (n1.n_name = 'FRANCE' or n1.n_name='GERMANY') and (n2.n_name='GERMANY' or n2.n_name='FRANCE')
 * <p>
 * new generated expr is redundant, but they could be pushed down to reduce the cardinality of children output tuples.
 * <p>
 * Implementation note:
 * 1. This rule should only be applied ONCE to avoid generating the same redundant expression repeatedly.
 * 2. A redundant expression only contains slots from a single table.
 * 3. In the old optimizer, `InferFilterRule` generates redundant expressions. Its Nereids counterpart also needs
 * `RemoveRedundantExpression`.
 * <p>
 */
public class ExtractSingleTableExpressionFromDisjunction implements RewriteRuleFactory {
    private static final ImmutableSet<JoinType> ALLOW_JOIN_TYPE = ImmutableSet.of(JoinType.INNER_JOIN,
            JoinType.LEFT_OUTER_JOIN, JoinType.RIGHT_OUTER_JOIN, JoinType.LEFT_SEMI_JOIN, JoinType.RIGHT_SEMI_JOIN,
            JoinType.LEFT_ANTI_JOIN, JoinType.RIGHT_ANTI_JOIN, JoinType.CROSS_JOIN, JoinType.FULL_OUTER_JOIN);

    /**
     * Builds two rewrite rules.
     * <ul>
     *   <li>The filter rule appends derived single-table predicates to a filter's conjuncts.</li>
     *   <li>The join rule appends them to the other join conjuncts of a join whose type is in
     *   {@code ALLOW_JOIN_TYPE}. Hash and mark join conjuncts are passed through unchanged.</li>
     * </ul>
     * Each rule returns {@code null} (no rewrite) when no predicate that is not already present was derived.
     */
    @Override
    public List<Rule> buildRules() {
        return ImmutableList.of(
                logicalFilter().then(filter -> {
                    List<Expression> dependentPredicates = extractDependentConjuncts(filter.getConjuncts());
                    if (dependentPredicates.isEmpty()) {
                        return null;
                    }
                    Set<Expression> newPredicates = ImmutableSet.<Expression>builder()
                            .addAll(filter.getConjuncts())
                            .addAll(dependentPredicates).build();
                    // every derived predicate was already present: nothing to rewrite
                    if (newPredicates.size() == filter.getConjuncts().size()) {
                        return null;
                    }
                    return new LogicalFilter<>(newPredicates, filter.child());
                }).toRule(RuleType.EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION),
                logicalJoin().when(join -> ALLOW_JOIN_TYPE.contains(join.getJoinType())).then(join -> {
                    // The other join conjuncts are a List and may contain duplicates. Deduplicate up front so
                    // that the "anything new?" check below compares a Set size with a Set size; comparing
                    // against the raw List size misses new predicates when the List has duplicates.
                    Set<Expression> otherConjuncts = ImmutableSet.copyOf(join.getOtherJoinConjuncts());
                    List<Expression> dependentOtherPredicates = extractDependentConjuncts(otherConjuncts);
                    if (dependentOtherPredicates.isEmpty()) {
                        return null;
                    }
                    Set<Expression> newOtherPredicates = ImmutableSet.<Expression>builder()
                            .addAll(otherConjuncts)
                            .addAll(dependentOtherPredicates).build();
                    // every derived predicate was already present: nothing to rewrite
                    if (newOtherPredicates.size() == otherConjuncts.size()) {
                        return null;
                    }
                    return join.withJoinConjuncts(join.getHashJoinConjuncts(),
                            ImmutableList.copyOf(newOtherPredicates),
                            join.getMarkJoinConjuncts(), join.getJoinReorderContext());
                }).toRule(RuleType.EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION));
    }

    /**
     * Derives redundant single-table predicates from the disjunctive conjuncts in {@code conjuncts}.
     *
     * <p>For a conjunct {@code (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY')
     * or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE')} this yields
     * {@code n1.n_name = 'FRANCE' or n1.n_name = 'GERMANY'} and
     * {@code n2.n_name = 'GERMANY' or n2.n_name = 'FRANCE'}.
     *
     * @param conjuncts the conjuncts to inspect; they are not modified
     * @return the derived predicates (possibly empty, possibly containing predicates already in
     *     {@code conjuncts}; callers deduplicate)
     */
    private List<Expression> extractDependentConjuncts(Set<Expression> conjuncts) {
        List<Expression> dependentPredicates = Lists.newArrayList();
        for (Expression conjunct : conjuncts) {
            List<Expression> disjuncts = ExpressionUtils.extractDisjunction(conjunct);
            if (disjuncts.size() == 1) {
                // not a disjunction
                continue;
            }
            if (collectQualifiers(conjunct).size() <= 1) {
                // The whole disjunction already references at most one table. Extracting from it would
                // only rebuild an equivalent (possibly differently shaped) copy of the same predicate.
                continue;
            }
            // Only tables referenced by the first disjunct are tried. A table missing from some disjunct
            // cannot be extracted from every disjunct anyway, unless that disjunct has slot-free conjuncts.
            for (String qualifier : collectQualifiers(disjuncts.get(0))) {
                extractFromAllDisjuncts(disjuncts, qualifier).ifPresent(dependentPredicates::add);
            }
        }
        return dependentPredicates;
    }

    /**
     * Extracts, from every disjunct, the part that references only the table named by {@code qualifier},
     * and ORs those parts together.
     *
     * <p>The result is implied by {@code d1 or d2 or ...}: each extracted part is implied by its disjunct.
     * If any disjunct yields nothing, no implied single-table predicate can be formed.
     *
     * @return the OR of the extracted parts (nested disjunctions flattened), or empty if any disjunct
     *     yields nothing for {@code qualifier}
     */
    private Optional<Expression> extractFromAllDisjuncts(List<Expression> disjuncts, String qualifier) {
        List<Expression> extracted = Lists.newArrayListWithExpectedSize(disjuncts.size());
        for (Expression disjunct : disjuncts) {
            Optional<Expression> part = extractSingleTableExpression(disjunct, qualifier);
            if (!part.isPresent()) {
                return Optional.empty();
            }
            extracted.addAll(ExpressionUtils.extractDisjunction(part.get()));
        }
        return Optional.of(ExpressionUtils.or(extracted));
    }

    // Extracts the conjuncts of expr whose slots all come from the table referred to by qualifier.
    // Example: expr = (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY'), qualifier = "n1"
    // output: n1.n_name = 'FRANCE'
    private Optional<Expression> extractSingleTableExpression(Expression expr, String qualifier) {
        // Suppose the qualifier is table T. The process steps are as follows:
        // 1. Split the expression into conjunctions: c1 and c2 and c3 and ...
        // 2. For each conjunction ci, suppose its extract is Ei:
        //    a) if all of ci's slots come from T, the whole ci is extracted, so Ei = ci;
        //    b) if ci is an OR expression, split ci into disjunctions: ci => d1 or d2 or d3 or ...
        //       Extract each dj recursively, giving ej.
        //       If every dj yields an ej, extracting ci succeeds with Ei = e1 or e2 or e3 or ...
        //       If any dj fails, extracting ci fails and ci is dropped.
        //       Dropping a conjunct only weakens the result, so the result stays implied by expr.
        // 3. Collect all successfully extracted Ei; the result for table T is `E1 and E2 and E3 and ...`.
        //
        // For example:
        // suppose expr = (t1.a = 1 or (t2.b = 2 and t1.c = 3)) and (t1.d = 4 or t2.e = 5), qualifier = t1.
        // c1 = (t1.a = 1 or (t2.b = 2 and t1.c = 3)).
        // c1 contains slot t2.b, which does not belong to t1, so the whole c1 cannot be extracted.
        // But c1 is an OR expression, so split c1 into disjunctions:
        // d1 => t1.a = 1, d2 => (t2.b = 2 and t1.c = 3).
        // Extracting d1 gives e1 = t1.a = 1; extracting d2 gives t1.c = 3.
        // So the extract E1 for c1 is: t1.a = 1 or t1.c = 3.
        // c2 = (t1.d = 4 or t2.e = 5): extracting its disjunct t2.e = 5 fails, so c2 is dropped.
        List<Expression> output = Lists.newArrayList();
        for (Expression conjunct : ExpressionUtils.extractConjunction(expr)) {
            if (isSingleTableExpression(conjunct, qualifier)) {
                output.add(conjunct);
            } else if (conjunct instanceof Or) {
                extractFromAllDisjuncts(ExpressionUtils.extractDisjunction(conjunct), qualifier)
                        .ifPresent(output::add);
            }
        }
        if (output.isEmpty()) {
            return Optional.empty();
        }
        return Optional.of(ExpressionUtils.and(output));
    }

    // True when every input slot of expr carries exactly `qualifier`. A slot-free expression counts as
    // single-table.
    private boolean isSingleTableExpression(Expression expr, String qualifier) {
        for (Slot slot : expr.getInputSlots()) {
            if (!qualifierOf(slot).equals(qualifier)) {
                return false;
            }
        }
        return true;
    }

    // Distinct qualifiers of expr's input slots, in iteration order of getInputSlots().
    private static Set<String> collectQualifiers(Expression expr) {
        return expr.getInputSlots().stream()
                .map(ExtractSingleTableExpressionFromDisjunction::qualifierOf)
                .collect(Collectors.toCollection(Sets::newLinkedHashSet));
    }

    // The slot's qualifier parts joined with ".", used as the table identity for extraction.
    // NOTE(review): slots with an empty qualifier all map to "" and are therefore treated as one table;
    // confirm that unqualified slots cannot reach this rule from different sources.
    private static String qualifierOf(Slot slot) {
        return String.join(".", slot.getQualifier());
    }
}