ShuffleKeyPruneUtils.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.properties;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.stats.StatsCalculator;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.util.AggregateUtils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.statistics.ColumnStatistic;
import org.apache.doris.statistics.Statistics;
import org.apache.doris.statistics.util.StatisticsUtil;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
/**
 * Utilities for pruning the shuffle keys of hash aggregates and hash joins based on column statistics.
 */
public class ShuffleKeyPruneUtils {
/**
 * Hot-value threshold handed to {@code StatisticsUtil.isBalanced} when checking whether a single
 * column is balanced enough to be used as the only shuffle key.
 */
public static final double shuffleKeyHotValueThreshold = 0.05;
/**
 * Wraps the optimized aggregate shuffle keys in an {@link Optional}.
 *
 * @param originalKeys the shuffle keys before optimization
 * @param optimizedKeys the candidate shuffle keys after optimization
 * @return {@code optimizedKeys} when it is not equal to {@code originalKeys}; empty otherwise
 */
private static Optional<List<Expression>> toOptionalIfChanged(
        List<? extends Expression> originalKeys, List<Expression> optimizedKeys) {
    boolean unchanged = optimizedKeys.equals(originalKeys);
    return unchanged ? Optional.empty() : Optional.of(optimizedKeys);
}
/**
 * Wraps the optimized join shuffle keys in an {@link Optional}.
 *
 * <p>Only the sizes of the left-hand id lists are compared; the right-hand lists and the list
 * contents are not inspected.
 * NOTE(review): this presumably relies on the optimized keys always being an order-preserving
 * subset of the original keys, so that equal size implies equal content - confirm against callers.
 *
 * @param originalKeys the (left, right) shuffle key ids before optimization
 * @param optimizedKeys the candidate (left, right) shuffle key ids after optimization
 * @return {@code optimizedKeys} when its left list has a different size than the original's;
 *         empty otherwise
 */
private static Optional<Pair<List<ExprId>, List<ExprId>>> toOptionalIfChanged(
        Pair<List<ExprId>, List<ExprId>> originalKeys, Pair<List<ExprId>, List<ExprId>> optimizedKeys) {
    int originalLeftCount = originalKeys.first.size();
    int optimizedLeftCount = optimizedKeys.first.size();
    return originalLeftCount != optimizedLeftCount ? Optional.of(optimizedKeys) : Optional.empty();
}
/**
 * Scenario 4: when partition expressions are set by rule, optionally reduce the shuffle keys
 * of an aggregate.
 *
 * <p>Strategy: 1) try a single balanced key; 2) try the numeric+date keys only (strings removed);
 * 3) fall back to the full {@code partitionExprs}.
 *
 * <p>Missing inputs ({@code partitionExprs} or {@code childStats} being {@code null}) are treated
 * as "no optimization possible" and yield empty, mirroring how the join entry point of this class
 * handles missing statistics.
 *
 * @param agg the aggregate being planned; not read by this method
 * @param partitionExprs the full list of partition expressions chosen by the rule
 * @param childStats statistics of the aggregate's child; supplies the row count and column stats
 * @param context connect context; supplies the total instance number
 * @return the reduced list of expressions to use as shuffle keys, or empty when the caller
 *         should keep using the full {@code partitionExprs}
 */
public static Optional<List<Expression>> selectBestShuffleKeyForAgg(
        PhysicalHashAggregate<? extends Plan> agg, List<Expression> partitionExprs, Statistics childStats,
        ConnectContext context) {
    // Without statistics or keys there is nothing to base a pruning decision on.
    if (partitionExprs == null || childStats == null) {
        return Optional.empty();
    }
    double rowCount = childStats.getRowCount();
    int instanceNum = context.getTotalInstanceNum();
    return selectOptimalShuffleKeys(partitionExprs, childStats, rowCount, instanceNum);
}
/**
 * Selects shuffle keys for an aggregate with a three-step strategy.
 * <ol>
 *   <li>Single key: order the slot keys by type priority and take the first one that
 *       {@code StatisticsUtil.isBalanced} accepts.</li>
 *   <li>Drop strings: keep only the numeric and date-like keys and use them when their combined
 *       NDV estimate exceeds {@code instanceNum * AggregateUtils.NDV_INSTANCE_BALANCE_MULTIPLIER}.</li>
 *   <li>Fall back: return empty so the caller keeps the full {@code partitionExprs}.</li>
 * </ol>
 * Only {@link SlotReference} entries of {@code partitionExprs} are considered as candidates.
 * Empty is also returned when the selected keys equal {@code partitionExprs}.
 *
 * @param partitionExprs the full list of partition expressions
 * @param childStats statistics used to look up per-column statistics
 * @param rowCount row count passed to the balance check
 * @param instanceNum instance count used by the balance check and the NDV threshold
 * @return the reduced shuffle keys, or empty if no reduction applies
 */
private static Optional<List<Expression>> selectOptimalShuffleKeys(List<Expression> partitionExprs,
        Statistics childStats, double rowCount, int instanceNum) {
    List<SlotReference> candidates = new ArrayList<>();
    for (Expression expr : partitionExprs) {
        if (expr instanceof SlotReference) {
            candidates.add((SlotReference) expr);
        }
    }
    if (candidates.isEmpty()) {
        return Optional.empty();
    }

    // Every candidate needs known column stats with hot-value info; otherwise keep the original keys.
    boolean statsIncomplete = candidates.stream()
            .map(slot -> childStats.findColumnStatistics(slot))
            .anyMatch(stats -> stats == null || stats.isUnKnown || stats.hotValues == null);
    if (statsIncomplete) {
        return Optional.empty();
    }

    // Step 1: the first balanced key in type-priority order becomes the sole shuffle key.
    Optional<SlotReference> balancedKey = sortShuffleKeysByTypePriority(candidates, childStats).stream()
            .filter(slot -> StatisticsUtil.isBalanced(
                    childStats.findColumnStatistics(slot), instanceNum, shuffleKeyHotValueThreshold, rowCount))
            .findFirst();
    if (balancedKey.isPresent()) {
        return toOptionalIfChanged(partitionExprs, ImmutableList.of(balancedKey.get()));
    }

    // Step 2: keep only numeric and date-like keys and check their combined NDV.
    List<Expression> nonStringKeys = new ArrayList<>();
    for (SlotReference slot : candidates) {
        DataType type = slot.getDataType();
        if (type.isNumericType() || type.isDateLikeType()) {
            nonStringKeys.add(slot);
        }
    }
    if (nonStringKeys.isEmpty()) {
        // Step 3: nothing left to try.
        return Optional.empty();
    }
    double combinedNdv = StatsCalculator.estimateGroupByRowCount(nonStringKeys, childStats);
    long ndvThreshold = (long) instanceNum * AggregateUtils.NDV_INSTANCE_BALANCE_MULTIPLIER;
    // Step 3 (fall back) is the empty branch: the caller then uses the full partitionExprs.
    return combinedNdv > ndvThreshold
            ? toOptionalIfChanged(partitionExprs, ImmutableList.copyOf(nonStringKeys))
            : Optional.empty();
}
/**
 * Returns a sorted copy of the shuffle keys; the input list is not modified.
 *
 * <p>Primary order is {@code getTypeSortPriority} (numeric and date-like first); ties are broken
 * by {@code getStringAvgSizeForSort} ascending, which is 0 for non-character types.
 *
 * @param slotRefs the candidate shuffle keys
 * @param childStats statistics used to look up average column sizes
 * @return a new list holding {@code slotRefs} in priority order
 */
private static List<SlotReference> sortShuffleKeysByTypePriority(List<SlotReference> slotRefs,
        Statistics childStats) {
    Comparator<SlotReference> byTypePriority =
            Comparator.comparingInt(slot -> getTypeSortPriority(slot.getDataType()));
    Comparator<SlotReference> byTypeThenAvgSize =
            byTypePriority.thenComparingDouble(slot -> getStringAvgSizeForSort(slot, childStats));
    List<SlotReference> ordered = new ArrayList<>(slotRefs);
    ordered.sort(byTypeThenAvgSize);
    return ordered;
}
/**
 * Sort priority of a key type: 0 for numeric and date-like types (sorted first),
 * 1 for every other type, strings included (sorted last).
 */
private static int getTypeSortPriority(DataType dataType) {
    boolean numericOrDate = dataType.isNumericType() || dataType.isDateLikeType();
    return numericOrDate ? 0 : 1;
}
/**
 * Secondary sort key for shuffle-key ordering.
 *
 * <p>For character types this is the column's {@code avgSizeByte} when statistics are present,
 * known and positive; otherwise the declared length of the character type. Every other type
 * yields 0, i.e. no secondary ordering.
 * NOTE(review): the declared length is used as-is; if {@code getLen()} can be non-positive for
 * unbounded strings, such keys would sort ahead of the others - confirm.
 *
 * @param slotRef the key whose size is wanted
 * @param childStats statistics used to look up the column's average size
 * @return the size used for sorting
 */
private static double getStringAvgSizeForSort(Slot slotRef, Statistics childStats) {
    DataType type = slotRef.getDataType();
    if (!(type instanceof CharacterType)) {
        return 0;
    }
    ColumnStatistic stats = childStats.findColumnStatistics(slotRef);
    boolean hasUsableAvgSize = stats != null && !stats.isUnKnown && stats.avgSizeByte > 0;
    if (hasUsableAvgSize) {
        return stats.avgSizeByte;
    }
    return ((CharacterType) type).getLen();
}
/**
 * Picks reduced shuffle keys for a hash join.
 *
 * <p>Applies the same three-step strategy as the aggregate shuffle-key pruning:
 * 1) a single key that is balanced on both sides; 2) the numeric+date keys only (strings removed);
 * 3) fall back (empty).
 *
 * <p>Empty is returned right away when either side's statistics are {@code null}, or when the
 * left and right column lists (or the left and right id lists) differ in size.
 *
 * @param context connect context; supplies the total instance number
 * @param leftOrderedShuffledColumns left-side shuffle columns, positionally matched with the right
 * @param rightOrderedShuffledColumns right-side shuffle columns
 * @param leftOrderedShuffledColumnId left-side shuffle column ids, used as the baseline for change
 * @param rightOrderedShuffledColumnId right-side shuffle column ids, used as the baseline for change
 * @param leftStats statistics of the left child
 * @param rightStats statistics of the right child
 * @return the reduced (left ids, right ids) pair, or empty when the original keys should be kept
 */
public static Optional<Pair<List<ExprId>, List<ExprId>>> tryFindOptimalShuffleKeyForJoinWithDistributeColumns(
        ConnectContext context, List<Slot> leftOrderedShuffledColumns, List<Slot> rightOrderedShuffledColumns,
        List<ExprId> leftOrderedShuffledColumnId, List<ExprId> rightOrderedShuffledColumnId,
        Statistics leftStats, Statistics rightStats) {
    boolean statsMissing = leftStats == null || rightStats == null;
    if (statsMissing
            || leftOrderedShuffledColumns.size() != rightOrderedShuffledColumns.size()
            || leftOrderedShuffledColumnId.size() != rightOrderedShuffledColumnId.size()) {
        return Optional.empty();
    }

    double leftRows = leftStats.getRowCount();
    double rightRows = rightStats.getRowCount();
    int instanceNum = context.getTotalInstanceNum();

    // Zip the two column lists into positional (left, right) key pairs.
    int keyCount = leftOrderedShuffledColumns.size();
    List<Pair<Slot, Slot>> keyPairs = new ArrayList<>(keyCount);
    for (int idx = 0; idx < keyCount; idx++) {
        keyPairs.add(Pair.of(leftOrderedShuffledColumns.get(idx), rightOrderedShuffledColumns.get(idx)));
    }

    Pair<List<ExprId>, List<ExprId>> originalIds =
            Pair.of(leftOrderedShuffledColumnId, rightOrderedShuffledColumnId);
    return selectOptimalJoinShuffleKeysFromPairs(keyPairs, originalIds,
            leftStats, rightStats, leftRows, rightRows, instanceNum);
}
/**
 * Three-step join shuffle-key optimization; the outcome is compared against
 * {@code baselineForChange} before being returned.
 * <ol>
 *   <li>Single key: in type-priority order, take the first pair whose left and right columns
 *       are both accepted by {@code StatisticsUtil.isBalanced}.</li>
 *   <li>Drop strings: keep the pairs where both sides are numeric or date-like, and use them when
 *       the combined NDV estimate of each side exceeds
 *       {@code instanceNum * AggregateUtils.NDV_INSTANCE_BALANCE_MULTIPLIER}.</li>
 *   <li>Fall back: return empty.</li>
 * </ol>
 * Empty is returned up front if any pair lacks known column statistics or hot-value info on
 * either side.
 *
 * @param validPairs positional (left, right) join key pairs
 * @param baselineForChange the original (left ids, right ids) the result is compared with
 * @param leftStats statistics of the left child
 * @param rightStats statistics of the right child
 * @param leftRows left row count passed to the balance check
 * @param rightRows right row count passed to the balance check
 * @param instanceNum instance count used by the balance check and the NDV threshold
 * @return the reduced (left ids, right ids) pair, or empty when nothing should change
 */
private static Optional<Pair<List<ExprId>, List<ExprId>>> selectOptimalJoinShuffleKeysFromPairs(
        List<Pair<Slot, Slot>> validPairs,
        Pair<List<ExprId>, List<ExprId>> baselineForChange,
        Statistics leftStats, Statistics rightStats,
        double leftRows, double rightRows, int instanceNum) {
    // Bail out unless both sides of every pair have usable statistics.
    for (Pair<Slot, Slot> keyPair : validPairs) {
        ColumnStatistic leftCol = leftStats.findColumnStatistics(keyPair.first);
        ColumnStatistic rightCol = rightStats.findColumnStatistics(keyPair.second);
        boolean usable = leftCol != null && rightCol != null
                && !leftCol.isUnKnown && !rightCol.isUnKnown
                && leftCol.hotValues != null && rightCol.hotValues != null;
        if (!usable) {
            return Optional.empty();
        }
    }

    // Step 1: first pair, in type-priority order, that is balanced on both sides.
    for (Pair<Slot, Slot> candidate : sortJoinKeyPairsByTypePriority(validPairs, leftStats, rightStats)) {
        Slot leftKey = candidate.first;
        Slot rightKey = candidate.second;
        ColumnStatistic leftCol = leftStats.findColumnStatistics(leftKey);
        ColumnStatistic rightCol = rightStats.findColumnStatistics(rightKey);
        if (!StatisticsUtil.isBalanced(leftCol, instanceNum, shuffleKeyHotValueThreshold, leftRows)) {
            continue;
        }
        if (StatisticsUtil.isBalanced(rightCol, instanceNum, shuffleKeyHotValueThreshold, rightRows)) {
            Pair<List<ExprId>, List<ExprId>> singleKey = Pair.of(
                    ImmutableList.of(leftKey.getExprId()),
                    ImmutableList.of(rightKey.getExprId()));
            return toOptionalIfChanged(baselineForChange, singleKey);
        }
    }

    // Step 2: keep only the pairs whose both sides are numeric or date-like.
    List<Slot> leftNonStringSlots = new ArrayList<>();
    List<Slot> rightNonStringSlots = new ArrayList<>();
    for (Pair<Slot, Slot> keyPair : validPairs) {
        DataType leftType = keyPair.first.getDataType();
        if (!(leftType.isNumericType() || leftType.isDateLikeType())) {
            continue;
        }
        DataType rightType = keyPair.second.getDataType();
        if (!(rightType.isNumericType() || rightType.isDateLikeType())) {
            continue;
        }
        leftNonStringSlots.add(keyPair.first);
        rightNonStringSlots.add(keyPair.second);
    }
    if (leftNonStringSlots.isEmpty()) {
        // Step 3: nothing left to try.
        return Optional.empty();
    }

    // Copies are handed to the estimator so the slot lists used below stay untouched.
    double leftCombinedNdv = StatsCalculator.estimateGroupByRowCount(
            new ArrayList<>(leftNonStringSlots), leftStats);
    double rightCombinedNdv = StatsCalculator.estimateGroupByRowCount(
            new ArrayList<>(rightNonStringSlots), rightStats);
    long ndvThreshold = (long) instanceNum * AggregateUtils.NDV_INSTANCE_BALANCE_MULTIPLIER;
    boolean bothSidesSpread = leftCombinedNdv > ndvThreshold && rightCombinedNdv > ndvThreshold;
    if (!bothSidesSpread) {
        // Step 3: fall back.
        return Optional.empty();
    }

    List<ExprId> leftIds = new ArrayList<>(leftNonStringSlots.size());
    for (Slot slot : leftNonStringSlots) {
        leftIds.add(slot.getExprId());
    }
    List<ExprId> rightIds = new ArrayList<>(rightNonStringSlots.size());
    for (Slot slot : rightNonStringSlots) {
        rightIds.add(slot.getExprId());
    }
    return toOptionalIfChanged(baselineForChange, Pair.of(leftIds, rightIds));
}
/**
 * Returns a sorted copy of the join key pairs; the input list is not modified.
 *
 * <p>Primary order is {@code getTypeSortPriority} of the left slot's type only (numeric and
 * date-like first); ties are broken by {@code getJoinPairStringAvgSizeForSort} ascending.
 *
 * @param pairs positional (left, right) join key pairs
 * @param leftStats statistics of the left child
 * @param rightStats statistics of the right child
 * @return a new list holding {@code pairs} in priority order
 */
private static List<Pair<Slot, Slot>> sortJoinKeyPairsByTypePriority(
        List<Pair<Slot, Slot>> pairs, Statistics leftStats, Statistics rightStats) {
    Comparator<Pair<Slot, Slot>> byLeftTypePriority =
            Comparator.comparingInt(pair -> getTypeSortPriority(pair.first.getDataType()));
    Comparator<Pair<Slot, Slot>> byTypeThenAvgSize = byLeftTypePriority.thenComparingDouble(
            pair -> getJoinPairStringAvgSizeForSort(pair, leftStats, rightStats));
    List<Pair<Slot, Slot>> ordered = new ArrayList<>(pairs);
    ordered.sort(byTypeThenAvgSize);
    return ordered;
}
/**
 * Secondary sort key for a join key pair.
 *
 * <p>When both sides are character types, this is the sum of the two sides' sizes as given by
 * {@code getStringAvgSizeForSort}; for any other pair it is 0.
 *
 * @param pair the (left, right) join key pair
 * @param leftStats statistics of the left child
 * @param rightStats statistics of the right child
 * @return the combined size used for sorting
 */
private static double getJoinPairStringAvgSizeForSort(Pair<Slot, Slot> pair,
        Statistics leftStats, Statistics rightStats) {
    boolean bothStrings = pair.first.getDataType() instanceof CharacterType
            && pair.second.getDataType() instanceof CharacterType;
    if (!bothStrings) {
        return 0;
    }
    double leftSize = getStringAvgSizeForSort(pair.first, leftStats);
    double rightSize = getStringAvgSizeForSort(pair.second, rightStats);
    return leftSize + rightSize;
}
}