PartitionPruner.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rules;
import org.apache.doris.catalog.ListPartitionItem;
import org.apache.doris.catalog.PartitionItem;
import org.apache.doris.catalog.RangePartitionItem;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rules.SortedPartitionRanges.PartitionItemAndId;
import org.apache.doris.nereids.rules.expression.rules.SortedPartitionRanges.PartitionItemAndRange;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Range;
import com.google.common.collect.RangeSet;
import com.google.common.collect.Sets;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
/**
* PartitionPruner
*/
public class PartitionPruner extends DefaultExpressionRewriter<Void> {
private final List<OnePartitionEvaluator<?>> partitions;
private final Expression partitionPredicate;
/** Different type of table may have different partition prune behavior. */
public enum PartitionTableType {
OLAP,
EXTERNAL
}
private PartitionPruner(List<OnePartitionEvaluator<?>> partitions, Expression partitionPredicate) {
this.partitions = Objects.requireNonNull(partitions, "partitions cannot be null");
this.partitionPredicate = Objects.requireNonNull(partitionPredicate.accept(this, null),
"partitionPredicate cannot be null");
}
@Override
public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) {
// Date cp Date is not supported in BE storage engine. So cast to DateTime in SimplifyComparisonPredicate
// for easy process partition prune, we convert back to date compare date here
// see more info in SimplifyComparisonPredicate
Expression left = cp.left();
Expression right = cp.right();
if (left.getDataType() != DateTimeType.INSTANCE || right.getDataType() != DateTimeType.INSTANCE) {
return cp;
}
if (!(left instanceof DateTimeLiteral) && !(right instanceof DateTimeLiteral)) {
return cp;
}
if (left instanceof DateTimeLiteral && ((DateTimeLiteral) left).isMidnight()
&& right instanceof Cast
&& ((Cast) right).child() instanceof SlotReference
&& ((Cast) right).child().getDataType().isDateType()) {
DateTimeLiteral dt = (DateTimeLiteral) left;
Cast cast = (Cast) right;
return cp.withChildren(
ImmutableList.of(new DateLiteral(dt.getYear(), dt.getMonth(), dt.getDay()), cast.child())
);
} else if (right instanceof DateTimeLiteral && ((DateTimeLiteral) right).isMidnight()
&& left instanceof Cast
&& ((Cast) left).child() instanceof SlotReference
&& ((Cast) left).child().getDataType().isDateType()) {
DateTimeLiteral dt = (DateTimeLiteral) right;
Cast cast = (Cast) left;
return cp.withChildren(ImmutableList.of(
cast.child(),
new DateLiteral(dt.getYear(), dt.getMonth(), dt.getDay()))
);
} else {
return cp;
}
}
/** prune */
public <K extends Comparable<K>> List<K> prune() {
Builder<K> scanPartitionIdents = ImmutableList.builder();
for (OnePartitionEvaluator partition : partitions) {
if (!canBePrunedOut(partitionPredicate, partition)) {
scanPartitionIdents.add((K) partition.getPartitionIdent());
}
}
return scanPartitionIdents.build();
}
public static <K extends Comparable<K>> List<K> prune(List<Slot> partitionSlots, Expression partitionPredicate,
Map<K, PartitionItem> idToPartitions, CascadesContext cascadesContext,
PartitionTableType partitionTableType) {
return prune(partitionSlots, partitionPredicate, idToPartitions,
cascadesContext, partitionTableType, Optional.empty());
}
/**
* prune partition with `idToPartitions` as parameter.
*/
public static <K extends Comparable<K>> List<K> prune(List<Slot> partitionSlots, Expression partitionPredicate,
Map<K, PartitionItem> idToPartitions, CascadesContext cascadesContext,
PartitionTableType partitionTableType, Optional<SortedPartitionRanges<K>> sortedPartitionRanges) {
partitionPredicate = PartitionPruneExpressionExtractor.extract(
partitionPredicate, ImmutableSet.copyOf(partitionSlots), cascadesContext);
partitionPredicate = PredicateRewriteForPartitionPrune.rewrite(partitionPredicate, cascadesContext);
int expandThreshold = cascadesContext.getAndCacheSessionVariable(
"partitionPruningExpandThreshold",
10, sessionVariable -> sessionVariable.partitionPruningExpandThreshold);
partitionPredicate = OrToIn.EXTRACT_MODE_INSTANCE.rewriteTree(
partitionPredicate, new ExpressionRewriteContext(cascadesContext));
if (BooleanLiteral.TRUE.equals(partitionPredicate)) {
return Utils.fastToImmutableList(idToPartitions.keySet());
} else if (BooleanLiteral.FALSE.equals(partitionPredicate) || partitionPredicate.isNullLiteral()) {
return ImmutableList.of();
}
if (sortedPartitionRanges.isPresent()) {
RangeSet<MultiColumnBound> predicateRanges = partitionPredicate.accept(
new PartitionPredicateToRange(partitionSlots), null);
if (predicateRanges != null) {
return binarySearchFiltering(
sortedPartitionRanges.get(), partitionSlots, partitionPredicate, cascadesContext,
expandThreshold, predicateRanges
);
}
}
return sequentialFiltering(
idToPartitions, partitionSlots, partitionPredicate, cascadesContext, expandThreshold
);
}
/**
* convert partition item to partition evaluator
*/
public static <K> OnePartitionEvaluator<K> toPartitionEvaluator(K id, PartitionItem partitionItem,
List<Slot> partitionSlots, CascadesContext cascadesContext, int expandThreshold) {
if (partitionItem instanceof ListPartitionItem) {
return new OneListPartitionEvaluator<>(
id, partitionSlots, (ListPartitionItem) partitionItem, cascadesContext);
} else if (partitionItem instanceof RangePartitionItem) {
return new OneRangePartitionEvaluator<>(
id, partitionSlots, (RangePartitionItem) partitionItem, cascadesContext, expandThreshold);
} else {
return new UnknownPartitionEvaluator<>(id, partitionItem);
}
}
private static <K extends Comparable<K>> List<K> binarySearchFiltering(
SortedPartitionRanges<K> sortedPartitionRanges, List<Slot> partitionSlots,
Expression partitionPredicate, CascadesContext cascadesContext, int expandThreshold,
RangeSet<MultiColumnBound> predicateRanges) {
List<PartitionItemAndRange<K>> sortedPartitions = sortedPartitionRanges.sortedPartitions;
Set<K> selectedIdSets = Sets.newTreeSet();
int leftIndex = 0;
for (Range<MultiColumnBound> predicateRange : predicateRanges.asRanges()) {
int rightIndex = sortedPartitions.size();
if (leftIndex >= rightIndex) {
break;
}
int midIndex;
MultiColumnBound predicateUpperBound = predicateRange.upperEndpoint();
MultiColumnBound predicateLowerBound = predicateRange.lowerEndpoint();
while (leftIndex + 1 < rightIndex) {
midIndex = (leftIndex + rightIndex) / 2;
PartitionItemAndRange<K> partition = sortedPartitions.get(midIndex);
Range<MultiColumnBound> partitionSpan = partition.range;
if (predicateUpperBound.compareTo(partitionSpan.lowerEndpoint()) < 0) {
rightIndex = midIndex;
} else if (predicateLowerBound.compareTo(partitionSpan.upperEndpoint()) > 0) {
leftIndex = midIndex;
} else {
break;
}
}
for (; leftIndex < sortedPartitions.size(); leftIndex++) {
PartitionItemAndRange<K> partition = sortedPartitions.get(leftIndex);
K partitionId = partition.id;
// list partition will expand to multiple PartitionItemAndRange, we should skip evaluate it again
if (selectedIdSets.contains(partitionId)) {
continue;
}
Range<MultiColumnBound> partitionSpan = partition.range;
if (predicateUpperBound.compareTo(partitionSpan.lowerEndpoint()) < 0) {
break;
}
OnePartitionEvaluator<K> partitionEvaluator = toPartitionEvaluator(
partitionId, partition.partitionItem, partitionSlots, cascadesContext, expandThreshold);
if (!canBePrunedOut(partitionPredicate, partitionEvaluator)) {
selectedIdSets.add(partitionId);
}
}
}
for (PartitionItemAndId<K> defaultPartition : sortedPartitionRanges.defaultPartitions) {
K partitionId = defaultPartition.id;
OnePartitionEvaluator<K> partitionEvaluator = toPartitionEvaluator(
partitionId, defaultPartition.partitionItem, partitionSlots, cascadesContext, expandThreshold);
if (!canBePrunedOut(partitionPredicate, partitionEvaluator)) {
selectedIdSets.add(partitionId);
}
}
return Utils.fastToImmutableList(selectedIdSets);
}
private static <K extends Comparable<K>> List<K> sequentialFiltering(
Map<K, PartitionItem> idToPartitions, List<Slot> partitionSlots,
Expression partitionPredicate, CascadesContext cascadesContext, int expandThreshold) {
List<OnePartitionEvaluator<?>> evaluators = Lists.newArrayListWithCapacity(idToPartitions.size());
for (Entry<K, PartitionItem> kv : idToPartitions.entrySet()) {
evaluators.add(toPartitionEvaluator(
kv.getKey(), kv.getValue(), partitionSlots, cascadesContext, expandThreshold));
}
PartitionPruner partitionPruner = new PartitionPruner(evaluators, partitionPredicate);
//TODO: we keep default partition because it's too hard to prune it, we return false in canPrune().
return partitionPruner.prune();
}
/**
* return true if partition is not qualified. that is, can be pruned out.
*/
private static <K> boolean canBePrunedOut(Expression partitionPredicate, OnePartitionEvaluator<K> evaluator) {
List<Map<Slot, PartitionSlotInput>> onePartitionInputs = evaluator.getOnePartitionInputs();
for (Map<Slot, PartitionSlotInput> currentInputs : onePartitionInputs) {
// evaluate whether there's possible for this partition to accept this predicate
Expression result = evaluator.evaluateWithDefaultPartition(partitionPredicate, currentInputs);
if (!result.equals(BooleanLiteral.FALSE) && !(result instanceof NullLiteral)) {
return false;
}
}
// only have false result: Can be pruned out. have other exprs: CanNot be pruned out
return true;
}
}