ExtractCommonFactorRule.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rules;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimaps;
import com.google.common.collect.SetMultimap;
import com.google.common.collect.Sets;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map.Entry;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * Extract common expr for `CompoundPredicate`.
 * for example:
 * transform (a or b) and (a or c) to a or (b and c)
 * transform (a and b) or (a and c) to a and (b or c)
 *
 * <p>The rule is registered for every top-level {@link CompoundPredicate} and works in three steps:
 * flatten the predicate and split its operands into partitions, find the factors shared by several
 * partitions, then pull each shared factor out of the partitions it was assigned to.
 */
@Developing
public class ExtractCommonFactorRule implements ExpressionPatternRuleFactory {
    public static final ExtractCommonFactorRule INSTANCE = new ExtractCommonFactorRule();

    /**
     * Builds the single pattern rule of this factory: every {@link CompoundPredicate} matched by
     * {@code matchesTopType} is rewritten by {@link #extractCommonFactor(CompoundPredicate)}.
     *
     * @return an immutable list holding that one matcher
     */
    @Override
    public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
        return ImmutableList.of(
                matchesTopType(CompoundPredicate.class).then(ExtractCommonFactorRule::extractCommonFactor)
        );
    }

    /**
     * Entry point of the rewrite (also called recursively from
     * {@link #doExtractCommonFactors} on the remainder of a partition).
     *
     * @param originExpr the compound predicate to rewrite
     * @return {@code originExpr} itself when nothing changes, otherwise the rewritten expression,
     *         which is not necessarily a {@link CompoundPredicate} any more
     */
    private static Expression extractCommonFactor(CompoundPredicate originExpr) {
        // fast return: factors can only be extracted (or literals folded) when at least one
        // direct child is itself a CompoundPredicate or a BooleanLiteral
        boolean canExtract = false;
        // insertion-ordered set, so duplicated children are dropped while keeping their order
        Set<Expression> childrenSet = new LinkedHashSet<>();
        for (Expression child : originExpr.children()) {
            if (child instanceof CompoundPredicate || child instanceof BooleanLiteral) {
                canExtract = true;
            }
            childrenSet.add(child);
        }
        if (!canExtract) {
            if (childrenSet.size() == originExpr.children().size()) {
                // no duplicated child: nothing to do
                return originExpr;
            }
            if (childrenSet.size() == 1) {
                // all children are equal, e.g. (a and a) -> a
                return childrenSet.iterator().next();
            }
            return originExpr.withChildren(childrenSet.stream().collect(Collectors.toList()));
        }

        // flatten same type to a list
        // e.g. ((a and (b or c)) and c) -> [a, (b or c), c]
        List<Expression> flatten = ExpressionUtils.extract(originExpr);

        // combine the operands again; the helper is expected to eliminate boolean literals
        // e.g. (a and true) -> a, so the result may no longer be a CompoundPredicate
        Expression simplified = ExpressionUtils.combineAsLeftDeepTree(originExpr.getClass(), flatten);
        if (!(simplified instanceof CompoundPredicate)) {
            return simplified;
        }

        // separate two levels CompoundPredicate to partitions
        // e.g. ((a and (b or c)) and c) -> [[a], [b, c], [c]]
        CompoundPredicate leftDeepTree = (CompoundPredicate) simplified;
        // a set builder is used, so partitions with equal content are kept only once
        ImmutableSet.Builder<List<Expression>> partitionsBuilder
                = ImmutableSet.builderWithExpectedSize(flatten.size());
        for (Expression onePartition : ExpressionUtils.extract(leftDeepTree)) {
            if (onePartition instanceof CompoundPredicate) {
                partitionsBuilder.add(ExpressionUtils.extract((CompoundPredicate) onePartition));
            } else {
                partitionsBuilder.add(ImmutableList.of(onePartition));
            }
        }
        Set<List<Expression>> partitions = partitionsBuilder.build();

        return extractCommonFactors(originExpr, leftDeepTree, Utils.fastToImmutableList(partitions));
    }

    /**
     * Finds the factors shared between partitions and rewrites the predicate accordingly.
     *
     * @param originPredicate the predicate as it was handed to the rule; returned unchanged when
     *                        the rewrite would bring no benefit
     * @param leftDeepTreePredicate the simplified left deep tree built from {@code originPredicate}
     * @param initPartitions operands of {@code leftDeepTreePredicate}, each split into its own operands;
     *                       the list index of a partition is used as its partition id
     * @return {@code originPredicate} or the predicate with the common factors extracted
     */
    private static Expression extractCommonFactors(CompoundPredicate originPredicate,
            CompoundPredicate leftDeepTreePredicate, List<List<Expression>> initPartitions) {
        // extract factor and fill into commonFactorToPartIds
        // e.g.
        //      originPredicate:         (a and (b and c)) and (b or c)
        //      leftDeepTreePredicate:   ((a and b) and c) and (b or c)
        //      initPartitions:          [[a], [b], [c], [b, c]]
        //
        //   -> commonFactorToPartIds = {a: [0], b: [1, 3], c: [2, 3]}.
        //      so we can know `b` and `c` are common factors
        SetMultimap<Expression, Integer> commonFactorToPartIds = Multimaps.newSetMultimap(
                Maps.newLinkedHashMap(), LinkedHashSet::new
        );
        int originExpressionNum = 0;
        int partId = 0;
        for (List<Expression> partition : initPartitions) {
            for (Expression expression : partition) {
                commonFactorToPartIds.put(expression, partId);
                originExpressionNum++;
            }
            partId++;
        }

        //  commonFactorToPartIds = {a: [0], b: [1, 3], c: [2, 3]}
        //
        //  -> reverse key value of commonFactorToPartIds and remove intersecting partition
        //
        //  -> 1. reverse: {[0]: [a], [1, 3]: [b], [2, 3]: [c]}
        //  -> 2. sort by key size desc: {[1, 3]: [b], [2, 3]: [c], [0]: [a]}
        //  -> 3. remove intersection partition: {[1, 3]: [b], [2]: [c], [0]: [a]},
        //        because first part and second part intersect by partition 3
        LinkedHashMap<Set<Integer>, Set<Expression>> commonFactorPartitions
                = partitionByMostCommonFactors(commonFactorToPartIds);

        int extractedExpressionNum = 0;
        for (Set<Expression> exprs : commonFactorPartitions.values()) {
            extractedExpressionNum += exprs.size();
        }

        // no any common factor: the first entry has the most partitions, so if it covers at most
        // one partition then no factor is shared at all
        if (commonFactorPartitions.entrySet().iterator().next().getKey().size() <= 1
                && !(originPredicate.getWidth() > leftDeepTreePredicate.getWidth())
                && originExpressionNum <= extractedExpressionNum) {
            // this condition is important because it can avoid dead loop:
            // origin originExpr:               A = 1 and (B > 0 and B < 10)
            // after ExtractCommonFactorRule:   (A = 1 and B > 0) and (B < 10)     (left deep tree)
            // after SimplifyRange:             A = 1 and (B > 0 and B < 10)       (right deep tree)
            return originPredicate;
        }

        // now we can do extract common factors for each part:
        //    originPredicate:         (a and (b and c)) and (b or c)
        //    leftDeepTreePredicate:   ((a and b) and c) and (b or c)
        //    initPartitions:          [[a], [b], [c], [b, c]]
        //    commonFactorPartitions:  {[1, 3]: [b], [2]: [c], [0]: [a]}
        //
        // -> extractedExprs: [
        //                       b or (false and c) = b,
        //                       c,
        //                       a
        //                    ]
        //
        // -> result: b and c and a
        //    (this example assumes combineAsLeftDeepTree eliminates the boolean literals)
        ImmutableList.Builder<Expression> extractedExprs
                = ImmutableList.builderWithExpectedSize(commonFactorPartitions.size());
        for (Entry<Set<Integer>, Set<Expression>> kv : commonFactorPartitions.entrySet()) {
            Expression extracted = doExtractCommonFactors(
                    leftDeepTreePredicate, initPartitions, kv.getKey(), kv.getValue()
            );
            extractedExprs.add(extracted);
        }

        // combine and eliminate some boolean literal predicate
        return ExpressionUtils.combineAsLeftDeepTree(leftDeepTreePredicate.getClass(), extractedExprs.build());
    }

    /**
     * Pulls {@code commonFactors} out of the partitions named by {@code partitionIds}.
     *
     * <p>With F the flipped type of {@code originPredicate} and T its own type, the result is
     * {@code F(commonFactors..., T(rest of partition 1, rest of partition 2, ...))}, where each
     * "rest" is the F-combination of the partition's expressions that are not common factors.
     *
     * @param originPredicate predicate whose class and flipped class decide how parts are combined
     * @param partitions all partitions, indexed by partition id
     * @param partitionIds ids of the partitions handled by this call
     * @param commonFactors the factors to extract from those partitions
     * @return the combined expression
     */
    private static Expression doExtractCommonFactors(
            CompoundPredicate originPredicate,
            List<List<Expression>> partitions, Set<Integer> partitionIds, Set<Expression> commonFactors) {
        ImmutableList.Builder<Expression> uncorrelatedExprPartitionsBuilder
                = ImmutableList.builderWithExpectedSize(partitionIds.size());
        for (Integer partitionId : partitionIds) {
            List<Expression> partition = partitions.get(partitionId);
            // what is left of this partition once the common factors are taken away
            ImmutableSet.Builder<Expression> uncorrelatedBuilder
                    = ImmutableSet.builderWithExpectedSize(partition.size());
            for (Expression exprOfPart : partition) {
                if (!commonFactors.contains(exprOfPart)) {
                    uncorrelatedBuilder.add(exprOfPart);
                }
            }

            Set<Expression> uncorrelated = uncorrelatedBuilder.build();
            Expression partitionWithoutCommonFactor
                    = ExpressionUtils.combineAsLeftDeepTree(originPredicate.flipType(), uncorrelated);
            if (partitionWithoutCommonFactor instanceof CompoundPredicate) {
                // the remainder may contain further common factors of its own
                partitionWithoutCommonFactor = extractCommonFactor((CompoundPredicate) partitionWithoutCommonFactor);
            }
            uncorrelatedExprPartitionsBuilder.add(partitionWithoutCommonFactor);
        }

        ImmutableList<Expression> uncorrelatedExprPartitions = uncorrelatedExprPartitionsBuilder.build();
        ImmutableList.Builder<Expression> allExprs = ImmutableList.builderWithExpectedSize(commonFactors.size() + 1);
        allExprs.addAll(commonFactors);

        Expression combineUncorrelatedExpr = ExpressionUtils.combineAsLeftDeepTree(
                originPredicate.getClass(), uncorrelatedExprPartitions);
        allExprs.add(combineUncorrelatedExpr);

        return ExpressionUtils.combineAsLeftDeepTree(originPredicate.flipType(), allExprs.build());
    }

    /**
     * Groups factors by the set of partitions containing them and assigns every partition to
     * exactly one group, favouring the groups that span the most partitions.
     *
     * @param commonFactorToPartIds for each factor, the ids of the partitions that contain it
     * @return insertion-ordered map from the partition ids assigned to a group to the factors of
     *         that group; the widest groups come first and groups left without any partition are omitted
     */
    private static LinkedHashMap<Set<Integer>, Set<Expression>> partitionByMostCommonFactors(
            SetMultimap<Expression, Integer> commonFactorToPartIds) {
        // reverse the mapping: factors contained in exactly the same partitions share one key
        SetMultimap<Set<Integer>, Expression> partWithCommonFactors = Multimaps.newSetMultimap(
                Maps.newLinkedHashMap(), LinkedHashSet::new
        );

        for (Entry<Expression, Collection<Integer>> factorToId : commonFactorToPartIds.asMap().entrySet()) {
            // NOTE(review): the cast relies on the values of SetMultimap#asMap() being Sets
            partWithCommonFactors.put((Set<Integer>) factorToId.getValue(), factorToId.getKey());
        }

        List<Set<Integer>> sortedPartitionIdHasCommonFactor = Lists.newArrayList(partWithCommonFactors.keySet());
        // place the most common factor at the head of this list
        // (Integer.compare instead of subtracting the sizes; descending order)
        sortedPartitionIdHasCommonFactor.sort((p1, p2) -> Integer.compare(p2.size(), p1.size()));

        LinkedHashMap<Set<Integer>, Set<Expression>> shouldExtractFactors = Maps.newLinkedHashMap();

        Set<Integer> allocatedPartitions = Sets.newLinkedHashSet();
        for (Set<Integer> originMostCommonFactorPartitions : sortedPartitionIdHasCommonFactor) {
            ImmutableSet.Builder<Integer> notAllocatePartitions = ImmutableSet.builderWithExpectedSize(
                    originMostCommonFactorPartitions.size());
            for (Integer partId : originMostCommonFactorPartitions) {
                // add() returns true only the first time a partition id is seen,
                // so a partition is never given to more than one group
                if (allocatedPartitions.add(partId)) {
                    notAllocatePartitions.add(partId);
                }
            }

            Set<Integer> mostCommonFactorPartitions = notAllocatePartitions.build();
            if (!mostCommonFactorPartitions.isEmpty()) {
                // the factors are looked up by the original (unreduced) partition id set
                Set<Expression> commonFactors = partWithCommonFactors.get(originMostCommonFactorPartitions);
                shouldExtractFactors.put(mostCommonFactorPartitions, commonFactors);
            }
        }

        return shouldExtractFactors;
    }
}