// ExtractCommonFactorRule.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.rules.expression.ExpressionRuleType;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimaps;
import com.google.common.collect.SetMultimap;
import com.google.common.collect.Sets;

import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map.Entry;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Extract common expr for `CompoundPredicate`.
 * for example:
 * transform (a or b) and (a or c) to a or (b and c)
 * transform (a and b) or (a and c) to a and (b or c)
 *
 * <p>The rewrite works on "partitions": each child of the flattened predicate is split into the list of
 * its own operands, operands shared by several partitions are treated as common factors, and every group
 * of partitions is rebuilt as {@code commonFactors <op'> (rest_1 <op> rest_2 ...)} where {@code <op>} is
 * the operator of the outer predicate and {@code <op'>} is the opposite one.
 */
@Developing
public class ExtractCommonFactorRule implements ExpressionPatternRuleFactory {
    public static final ExtractCommonFactorRule INSTANCE = new ExtractCommonFactorRule();

    /**
     * Builds the single rule of this factory: every top-level {@link CompoundPredicate} is handed to
     * {@link #extractCommonFactor(CompoundPredicate)}.
     */
    @Override
    public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
        return ImmutableList.of(
                matchesTopType(CompoundPredicate.class).then(ExtractCommonFactorRule::extractCommonFactor)
                        .toRule(ExpressionRuleType.EXTRACT_COMMON_FACTOR)
        );
    }

    /**
     * Entry point of the rewrite for one compound predicate.
     *
     * @param originExpr the predicate to rewrite
     * @return {@code originExpr} itself when nothing can be changed, otherwise the rewritten expression
     */
    private static Expression extractCommonFactor(CompoundPredicate originExpr) {
        // fast return: without a nested CompoundPredicate or a BooleanLiteral child there is nothing to
        // extract, the only possible simplification is removing duplicated children
        boolean canExtract = false;
        Set<Expression> childrenSet = new LinkedHashSet<>();
        for (Expression child : originExpr.children()) {
            if (child instanceof CompoundPredicate || child instanceof BooleanLiteral) {
                canExtract = true;
            }
            childrenSet.add(child);
        }
        if (!canExtract) {
            if (childrenSet.size() == originExpr.children().size()) {
                // no duplicated child
                return originExpr;
            }
            if (childrenSet.size() == 1) {
                // e.g. (a and a) -> a
                return childrenSet.iterator().next();
            }
            return originExpr.withChildren(childrenSet.stream().collect(Collectors.toList()));
        }
        // flatten same type to a list
        // e.g. ((a and (b or c)) and c) -> [a, (b or c), c]
        List<Expression> flatten = ExpressionUtils.extract(originExpr);

        // combine and delete some boolean literal predicate
        // e.g. (a and true) -> a
        Expression simplified = ExpressionUtils.compound(originExpr instanceof And, flatten);
        if (!(simplified instanceof CompoundPredicate)) {
            return simplified;
        }

        // separate two levels CompoundPredicate to partitions
        // e.g. ((a and (b or c)) and c) -> [[a], [b, c], [c]]
        // an ImmutableSet is used so that identical partitions are kept only once, in encounter order
        ImmutableSet.Builder<List<Expression>> partitionsBuilder
                = ImmutableSet.builderWithExpectedSize(simplified.children().size());
        for (Expression onPartition : simplified.children()) {
            if (onPartition instanceof CompoundPredicate) {
                partitionsBuilder.add(ExpressionUtils.extract((CompoundPredicate) onPartition));
            } else {
                partitionsBuilder.add(ImmutableList.of(onPartition));
            }
        }
        Set<List<Expression>> partitions = partitionsBuilder.build();

        return extractCommonFactors(simplified instanceof And, Utils.fastToImmutableList(partitions));
    }

    /**
     * Groups the partitions by the factors they share and rebuilds one expression per group.
     *
     * @param isAnd whether the outer predicate that owns the partitions is an {@link And}
     * @param initPartitions operands of every child of the outer predicate, one list per child
     * @return the extracted expressions of all groups combined with the outer operator
     */
    private static Expression extractCommonFactors(boolean isAnd, List<List<Expression>> initPartitions) {
        // extract factor and fill into commonFactorToPartIds
        // e.g.
        //      originPredicate:         (a and (b and c)) and (b or c)
        //      initPartitions: [[a], [b], [c], [b, c]]
        //
        //   -> commonFactorToPartIds = {a: [0], b: [1, 3], c: [2, 3]}.
        //      so we can know `b` and `c` are common factors
        SetMultimap<Expression, Integer> commonFactorToPartIds = Multimaps.newSetMultimap(
                Maps.newLinkedHashMap(), LinkedHashSet::new
        );
        for (int i = 0; i < initPartitions.size(); i++) {
            List<Expression> partition = initPartitions.get(i);
            for (Expression expression : partition) {
                commonFactorToPartIds.put(expression, i);
            }
        }

        //     commonFactorToPartIds = {a: [0], b: [1, 3], c: [2, 3]}
        //
        //  -> reverse key value of commonFactorToPartIds and remove intersecting partition
        //
        //  -> 1. reverse: {[0]: [a], [1, 3]: [b], [2, 3]: [c]}
        //  -> 2. sort by key size desc: {[1, 3]: [b], [2, 3]: [c], [0]: [a]}
        //  -> 3. remove intersection partition: {[1, 3]: [b], [2]: [c], [0]: [a]},
        //        because first part and second part intersect by partition 3
        LinkedHashMap<Set<Integer>, Set<Expression>> commonFactorPartitions
                = partitionByMostCommonFactors(commonFactorToPartIds);

        // now we can do extract common factors for each part:
        //    originPredicate:         (a and (b and c)) and (b or c)
        //    initPartitions:          [[a], [b], [c], [b, c]]
        //    commonFactorPartitions:  {[1, 3]: [b], [2]: [c], [0]: [a]}
        //
        // -> extractedExprs (given that ExpressionUtils.compound eliminates boolean literals): [
        //                       b or (false and c) = b,
        //                       c,
        //                       a
        //                    ]
        //
        // -> result: b and c and a
        ImmutableList.Builder<Expression> extractedExprs
                = ImmutableList.builderWithExpectedSize(commonFactorPartitions.size());
        for (Entry<Set<Integer>, Set<Expression>> kv : commonFactorPartitions.entrySet()) {
            Expression extracted = doExtractCommonFactors(isAnd, initPartitions, kv.getKey(), kv.getValue());
            extractedExprs.add(extracted);
        }

        // combine and eliminate some boolean literal predicate
        return ExpressionUtils.compound(isAnd, extractedExprs.build());
    }

    /**
     * Rebuilds one group of partitions as
     * {@code commonFactors <flip of isAnd> (rest_1 <isAnd> rest_2 <isAnd> ...)}, where {@code rest_i} is
     * partition {@code i} with the common factors removed.
     *
     * @param isAnd whether the outer predicate is an {@link And}
     * @param partitions all partitions of the outer predicate
     * @param partitionIds indexes into {@code partitions} of the partitions that belong to this group
     * @param commonFactors the factors to pull out of the partitions of this group
     * @return the rebuilt expression for this group
     */
    private static Expression doExtractCommonFactors(boolean isAnd,
            List<List<Expression>> partitions, Set<Integer> partitionIds, Set<Expression> commonFactors) {
        ImmutableList.Builder<Expression> uncorrelatedExprPartitionsBuilder
                = ImmutableList.builderWithExpectedSize(partitionIds.size());
        for (Integer partitionId : partitionIds) {
            List<Expression> partition = partitions.get(partitionId);
            // what is left of this partition once the common factors are taken away
            ImmutableSet.Builder<Expression> uncorrelatedBuilder
                    = ImmutableSet.builderWithExpectedSize(partition.size());
            for (Expression exprOfPart : partition) {
                if (!commonFactors.contains(exprOfPart)) {
                    uncorrelatedBuilder.add(exprOfPart);
                }
            }

            Set<Expression> uncorrelated = uncorrelatedBuilder.build();
            // a partition is combined with the flip of isAnd, the remainder may contain common factors
            // of its own, so rewrite it recursively
            Expression partitionWithoutCommonFactor = ExpressionUtils.compound(!isAnd, uncorrelated);
            if (partitionWithoutCommonFactor instanceof CompoundPredicate) {
                partitionWithoutCommonFactor = extractCommonFactor((CompoundPredicate) partitionWithoutCommonFactor);
            }
            uncorrelatedExprPartitionsBuilder.add(partitionWithoutCommonFactor);
        }

        // common factors should be flip of isAnd
        Expression combineCommonFactor = ExpressionUtils.compound(!isAnd, commonFactors);
        if (combineCommonFactor instanceof CompoundPredicate) {
            combineCommonFactor = extractCommonFactor((CompoundPredicate) combineCommonFactor);
        }
        // split the rewritten common factors again so that they are siblings of the uncorrelated part
        List<Expression> rewriteCommonFactors = isAnd ? ExpressionUtils.extractDisjunction(combineCommonFactor)
                : ExpressionUtils.extractConjunction(combineCommonFactor);

        ImmutableList<Expression> uncorrelatedExprPartitions = uncorrelatedExprPartitionsBuilder.build();
        Expression combineUncorrelatedExpr = ExpressionUtils.compound(isAnd, uncorrelatedExprPartitions);

        ImmutableList.Builder<Expression> allExprs = ImmutableList.builderWithExpectedSize(
                rewriteCommonFactors.size() + 1);
        allExprs.addAll(rewriteCommonFactors);
        allExprs.add(combineUncorrelatedExpr);

        return ExpressionUtils.compound(!isAnd, allExprs.build());
    }

    /**
     * Reverses {@code commonFactorToPartIds} into "set of partition ids -> factors shared by exactly these
     * partitions" and assigns every partition id to one group only, preferring the groups that cover the
     * most partitions.
     *
     * @param commonFactorToPartIds for each factor, the ids of the partitions that contain it
     * @return the groups to extract in extraction order; the keys are pairwise disjoint and non-empty,
     *         each value holds the factors of the original (untrimmed) partition id set
     */
    private static LinkedHashMap<Set<Integer>, Set<Expression>> partitionByMostCommonFactors(
            SetMultimap<Expression, Integer> commonFactorToPartIds) {
        SetMultimap<Set<Integer>, Expression> partWithCommonFactors = Multimaps.newSetMultimap(
                Maps.newLinkedHashMap(), LinkedHashSet::new
        );

        for (Entry<Expression, Collection<Integer>> factorToId : commonFactorToPartIds.asMap().entrySet()) {
            // the values of SetMultimap.asMap() are sets, only the declared type is Collection
            partWithCommonFactors.put((Set<Integer>) factorToId.getValue(), factorToId.getKey());
        }

        List<Set<Integer>> sortedPartitionIdHasCommonFactor = Lists.newArrayList(partWithCommonFactors.keySet());
        // place the most common factor at the head of this list; the sort is stable, so id sets of the
        // same size keep the order in which their factors were first seen
        sortedPartitionIdHasCommonFactor.sort((p1, p2) -> Integer.compare(p2.size(), p1.size()));

        LinkedHashMap<Set<Integer>, Set<Expression>> shouldExtractFactors = Maps.newLinkedHashMap();

        Set<Integer> allocatedPartitions = Sets.newLinkedHashSet();
        for (Set<Integer> originMostCommonFactorPartitions : sortedPartitionIdHasCommonFactor) {
            // keep only the partition ids which are not claimed by a previous (larger or earlier) group
            ImmutableSet.Builder<Integer> notAllocatePartitions = ImmutableSet.builderWithExpectedSize(
                    originMostCommonFactorPartitions.size());
            for (Integer partId : originMostCommonFactorPartitions) {
                if (allocatedPartitions.add(partId)) {
                    notAllocatePartitions.add(partId);
                }
            }

            Set<Integer> mostCommonFactorPartitions = notAllocatePartitions.build();
            if (!mostCommonFactorPartitions.isEmpty()) {
                // the factors are looked up by the original id set, the trimmed id set is the new key
                Set<Expression> commonFactors = partWithCommonFactors.get(originMostCommonFactorPartitions);
                shouldExtractFactors.put(mostCommonFactorPartitions, commonFactors);
            }
        }

        return shouldExtractFactors;
    }
}