ShuffleKeyPruneUtils.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.properties;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.PlanContext;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.statistics.ColumnStatistic;
import org.apache.doris.statistics.Statistics;
import org.apache.doris.statistics.util.StatisticsUtil;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * Utilities that prune hash shuffle keys. When an aggregate, or a join whose two children are both
 * global aggregates, would shuffle on many columns, a single column is chosen instead. The column is
 * scored from the input column statistics (normalized NDV, skew score, data type cost).
 */
public class ShuffleKeyPruneUtils {
    /** Weight of the normalized NDV term (ndv / rowCount, capped at 1.0) in the shuffle key score. */
    private static final double NDV_WEIGHT = 1.0;
    /** Weight of the skew score returned by StatisticsUtil.computeShuffleKeySkewScore. */
    private static final double SKEW_WEIGHT = 1.0;
    /** Weight of the data type cost penalty (0 for numeric types, 1 otherwise). */
    private static final double DATA_TYPE_COST_WEIGHT = 0.5;

    /**
     * Returns a representative expression of {@code group}. This is the first physical expression
     * if there is one, otherwise the first logical expression. It is empty when the group has
     * neither.
     */
    private static Optional<GroupExpression> getGroupExpression(Group group) {
        List<GroupExpression> physicalGroupExpressions = group.getPhysicalExpressions();
        if (!physicalGroupExpressions.isEmpty()) {
            return Optional.of(physicalGroupExpressions.get(0));
        }
        List<GroupExpression> logicalGroupExpressions = group.getLogicalExpressions();
        if (!logicalGroupExpressions.isEmpty()) {
            return Optional.of(logicalGroupExpressions.get(0));
        }
        return Optional.empty();
    }

    /**
     * Returns the input statistics of a (presumably global) aggregate.
     *
     * <p>If the agg's child group is represented by a local PhysicalHashAggregate, this returns the
     * statistics of that local agg's child group, i.e. the rows entering the two-phase aggregation.
     * Otherwise it returns the statistics of the agg's own child.
     *
     * @param agg the aggregate whose input statistics are wanted
     * @return the statistics, or empty if {@code agg} has no group expression or no statistics
     *         are recorded
     */
    private static Optional<Statistics> getGlobalAggChildStats(PhysicalHashAggregate<? extends Plan> agg) {
        Optional<GroupExpression> groupExpression = agg.getGroupExpression();
        if (!groupExpression.isPresent()) {
            return Optional.empty();
        }
        Statistics aggChildStats = groupExpression.get().childStatistics(0);
        Optional<GroupExpression> childGroupExpression = getGroupExpression(groupExpression.get().child(0));
        if (childGroupExpression.isPresent()) {
            Plan childPlan = childGroupExpression.get().getPlan();
            if (childPlan instanceof PhysicalHashAggregate
                    && ((PhysicalHashAggregate<?>) childPlan).getAggPhase().isLocal()) {
                // Skip over the local agg: score keys on the rows that feed the local phase.
                aggChildStats = childGroupExpression.get().child(0).getStatistics();
            }
        }
        return Optional.ofNullable(aggChildStats);
    }

    /**
     * Returns whether {@code agg} may have its shuffle keys pruned. All of these must hold: the
     * session switch chooseOneAggShuffleKey is on, the agg has more group-by expressions than
     * shuffleKeyPruneThreshold, and the agg has no source repeat.
     */
    private static boolean canAggShuffleKeyOpt(PhysicalHashAggregate<? extends Plan> agg,
            ConnectContext connectContext) {
        if (!connectContext.getSessionVariable().chooseOneAggShuffleKey) {
            return false;
        }
        if (agg.getGroupByExpressions().size() <= connectContext.getSessionVariable().shuffleKeyPruneThreshold) {
            return false;
        }
        return !agg.hasSourceRepeat();
    }

    /**
     * Chooses shuffle keys for {@code agg} when its parent sends a hash distribution request.
     *
     * <p>This applies when chooseOneAggShuffleKey is on and {@code intersectIdSet} is larger than
     * shuffleKeyPruneThreshold. In that case it picks the single best-scoring group-by slot whose id
     * is in {@code intersectIdSet}. Otherwise, or if no slot qualifies, it returns the full
     * intersection.
     *
     * @param agg the aggregate being planned (its input statistics are read via getGlobalAggChildStats)
     * @param intersectIdSet ids common to the parent's hash columns and the agg's group-by columns
     * @param context planning context supplying the session variables and instance count
     * @return a one-element list holding the chosen key, or every id of {@code intersectIdSet}
     */
    public static List<ExprId> selectOptimalShuffleKeyForAggWithParentHashRequest(
            PhysicalHashAggregate<? extends Plan> agg, Set<ExprId> intersectIdSet, PlanContext context) {
        ConnectContext connectContext = context.getConnectContext();
        List<ExprId> orderedIds = Utils.fastToImmutableList(intersectIdSet);
        if (!connectContext.getSessionVariable().chooseOneAggShuffleKey
                || intersectIdSet.size() <= connectContext.getSessionVariable().shuffleKeyPruneThreshold) {
            return orderedIds;
        }
        // NOTE(review): unlike canAggShuffleKeyOpt, agg.hasSourceRepeat() is not checked here — confirm intended.
        Optional<Statistics> childStats = getGlobalAggChildStats(agg);
        if (!childStats.isPresent()) {
            return orderedIds;
        }
        // Candidate keys: group-by slots (in group-by order) that the parent also hashes on.
        List<Expression> intersectExprs = new ArrayList<>();
        for (Expression e : agg.getGroupByExpressions()) {
            if (e instanceof SlotReference && intersectIdSet.contains(((SlotReference) e).getExprId())) {
                intersectExprs.add(e);
            }
        }
        if (intersectExprs.isEmpty()) {
            return orderedIds;
        }
        double rowCount = childStats.get().getRowCount();
        int instanceNum = ConnectContext.getTotalInstanceNum(connectContext);
        Optional<Expression> best = chooseBestShuffleKeyFromPartitionExpressions(
                intersectExprs, childStats.get(), rowCount, instanceNum);
        if (best.isPresent()) {
            // chooseBestShuffleKeyFromPartitionExpressions only ever returns SlotReferences.
            return ImmutableList.of(((SlotReference) best.get()).getExprId());
        }
        return orderedIds;
    }

    /**
     * Scenario 4: when the partition expressions were set by a rule, picks the single best-scoring
     * slot among {@code partitionExprs} to shuffle {@code agg} on.
     *
     * @param agg the aggregate being planned
     * @param partitionExprs candidate partition expressions; only SlotReferences are considered
     * @param context session context supplying the switches and the instance count
     * @return the chosen slot. Empty when canAggShuffleKeyOpt rejects {@code agg}, when its input
     *         statistics are unavailable, or when no expression qualifies.
     */
    public static Optional<Expression> selectBestShuffleKeyForAgg(
            PhysicalHashAggregate<? extends Plan> agg, List<Expression> partitionExprs, ConnectContext context) {
        if (!canAggShuffleKeyOpt(agg, context)) {
            return Optional.empty();
        }
        Optional<Statistics> childStats = getGlobalAggChildStats(agg);
        if (!childStats.isPresent()) {
            return Optional.empty();
        }
        double rowCount = childStats.get().getRowCount();
        int instanceNum = ConnectContext.getTotalInstanceNum(context);
        return chooseBestShuffleKeyFromPartitionExpressions(
                partitionExprs, childStats.get(), rowCount, instanceNum);
    }

    /**
     * Returns the SlotReference in {@code expressions} with the highest computeShuffleKeyScore.
     *
     * <p>Non-slot expressions are skipped. So are slots without column statistics and slots the
     * score rejects (NEGATIVE_INFINITY). When scores tie, the earliest slot wins.
     */
    private static Optional<Expression> chooseBestShuffleKeyFromPartitionExpressions(List<Expression> expressions,
            Statistics childStats, double rowCount, int instanceNum) {
        Expression bestExpr = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (Expression expr : expressions) {
            if (!(expr instanceof SlotReference)) {
                continue;
            }
            SlotReference slotRef = (SlotReference) expr;
            ColumnStatistic colStats = childStats.findColumnStatistics(slotRef);
            if (colStats == null) {
                continue;
            }
            double score = computeShuffleKeyScore(colStats, rowCount, instanceNum, slotRef.getDataType());
            // A rejected column scores NEGATIVE_INFINITY, which can never exceed bestScore.
            if (score > bestScore) {
                bestScore = score;
                bestExpr = slotRef;
            }
        }
        return Optional.ofNullable(bestExpr);
    }

    /**
     * Finds the first global PhysicalHashAggregate among the physical expressions of {@code group}
     * and pairs it with its input statistics (see getGlobalAggChildStats).
     *
     * <p>The result is the first global aggregate found, not necessarily the group's best plan. If
     * that aggregate's input statistics are unavailable, this returns empty and does not look at
     * later expressions.
     */
    private static Optional<Pair<PhysicalHashAggregate<? extends Plan>, Statistics>> getGlobalAggInputStatsFromGroup(
            Group group) {
        for (GroupExpression ge : group.getPhysicalExpressions()) {
            Plan plan = ge.getPlan();
            if (plan instanceof PhysicalHashAggregate
                    && ((PhysicalHashAggregate<?>) plan).getAggPhase().isGlobal()) {
                PhysicalHashAggregate<? extends Plan> globalAgg = (PhysicalHashAggregate<? extends Plan>) plan;
                return getGlobalAggChildStats(globalAgg).map(stats -> Pair.of(globalAgg, stats));
            }
        }
        return Optional.empty();
    }

    /**
     * Maps the ExprId of every SlotReference group-by expression of {@code agg} to that slot.
     * If an id occurs more than once, the first occurrence is kept.
     */
    private static Map<ExprId, SlotReference> collectGroupBySlotsById(PhysicalHashAggregate<? extends Plan> agg) {
        return agg.getGroupByExpressions().stream()
                .filter(SlotReference.class::isInstance)
                .map(SlotReference.class::cast)
                .collect(Collectors.toMap(SlotReference::getExprId, slot -> slot, (first, second) -> first));
    }

    /**
     * Scenario 3.3: when both join children are global aggregates, finds one unified shuffle key.
     *
     * <p>A candidate is a pair of equi-join keys (left, right) where the left key is a group-by slot
     * of the left agg and the right key is a group-by slot of the right agg. The winning pair is the
     * one with the highest average computeShuffleKeyScore across the two sides. When scores tie, the
     * earliest join key wins.
     *
     * @param hashJoin the join being planned
     * @param context planning context; its group expression supplies the join's child groups
     * @return (leftKey, rightKey), or empty if the join has no more hash conjuncts than
     *         shuffleKeyPruneThreshold, if either child is not a global agg with statistics, or if
     *         no pair qualifies
     */
    public static Optional<Pair<ExprId, ExprId>> tryFindOptimalShuffleKeyForBothAggChildren(
            PhysicalHashJoin<? extends Plan, ? extends Plan> hashJoin, PlanContext context) {
        // NOTE(review): unlike the aggregate entry points, chooseOneAggShuffleKey and hasSourceRepeat()
        // are not consulted here; presumably the caller gates this — confirm.
        if (hashJoin.getHashJoinConjuncts().size()
                <= context.getConnectContext().getSessionVariable().shuffleKeyPruneThreshold) {
            return Optional.empty();
        }
        GroupExpression joinGroupExpr = context.getGroupExpression();
        if (joinGroupExpr == null) {
            return Optional.empty();
        }
        Optional<Pair<PhysicalHashAggregate<? extends Plan>, Statistics>> leftOpt =
                getGlobalAggInputStatsFromGroup(joinGroupExpr.child(0));
        Optional<Pair<PhysicalHashAggregate<? extends Plan>, Statistics>> rightOpt =
                getGlobalAggInputStatsFromGroup(joinGroupExpr.child(1));
        if (!leftOpt.isPresent() || !rightOpt.isPresent()) {
            return Optional.empty();
        }
        Pair<List<ExprId>, List<ExprId>> joinKeys = hashJoin.getHashConjunctsExprIds();
        if (joinKeys.first.isEmpty() || joinKeys.second.size() != joinKeys.first.size()) {
            return Optional.empty();
        }
        Map<ExprId, SlotReference> leftGbySlots = collectGroupBySlotsById(leftOpt.get().first);
        Map<ExprId, SlotReference> rightGbySlots = collectGroupBySlotsById(rightOpt.get().first);
        Statistics leftStats = leftOpt.get().second;
        Statistics rightStats = rightOpt.get().second;
        double leftRows = leftStats.getRowCount();
        double rightRows = rightStats.getRowCount();
        int instanceNum = ConnectContext.getTotalInstanceNum(context.getConnectContext());

        ExprId bestLeftKey = null;
        ExprId bestRightKey = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < joinKeys.first.size(); i++) {
            ExprId leftId = joinKeys.first.get(i);
            ExprId rightId = joinKeys.second.get(i);
            // Both sides of the equi-join key must be group-by slots of their respective aggs.
            SlotReference leftSlotRef = leftGbySlots.get(leftId);
            SlotReference rightSlotRef = rightGbySlots.get(rightId);
            if (leftSlotRef == null || rightSlotRef == null) {
                continue;
            }
            ColumnStatistic leftColStats = leftStats.findColumnStatistics(leftSlotRef);
            ColumnStatistic rightColStats = rightStats.findColumnStatistics(rightSlotRef);
            if (leftColStats == null || rightColStats == null) {
                continue;
            }
            // computeShuffleKeyScore already rejects unbalanced / too-skewed columns.
            double leftScore = computeShuffleKeyScore(leftColStats, leftRows, instanceNum, leftSlotRef.getDataType());
            if (leftScore == Double.NEGATIVE_INFINITY) {
                continue;
            }
            double rightScore = computeShuffleKeyScore(rightColStats, rightRows, instanceNum,
                    rightSlotRef.getDataType());
            if (rightScore == Double.NEGATIVE_INFINITY) {
                continue;
            }
            double avgScore = (leftScore + rightScore) / 2.0;
            if (avgScore > bestScore) {
                bestScore = avgScore;
                bestLeftKey = leftId;
                bestRightKey = rightId;
            }
        }
        if (bestLeftKey == null || bestRightKey == null) {
            return Optional.empty();
        }
        return Optional.of(Pair.of(bestLeftKey, bestRightKey));
    }

    /**
     * Computes the shuffle key score of one column:
     * NDV_WEIGHT * normalizeNdv + SKEW_WEIGHT * skewScore - DATA_TYPE_COST_WEIGHT * dataTypeCost.
     *
     * @return the score, or NEGATIVE_INFINITY if StatisticsUtil.isBalanced rejects the column or the
     *         skew score is NEGATIVE_INFINITY
     */
    private static double computeShuffleKeyScore(ColumnStatistic colStats, double rowCount, int instanceNum,
            DataType dataType) {
        if (!StatisticsUtil.isBalanced(colStats, rowCount, instanceNum)) {
            return Double.NEGATIVE_INFINITY;
        }
        double skewScore = StatisticsUtil.computeShuffleKeySkewScore(colStats, rowCount, instanceNum);
        if (skewScore == Double.NEGATIVE_INFINITY) {
            return Double.NEGATIVE_INFINITY;
        }
        // ndv / rowCount capped at 1.0; 0 when rowCount <= 0 to avoid dividing by zero.
        double normalizeNdv = rowCount <= 0 ? 0 : Math.min(1.0, colStats.ndv / rowCount);
        // Non-numeric keys get a fixed penalty (presumably more expensive to hash — TODO confirm).
        double normalizeDataTypeCost = dataType.isNumericType() ? 0.0 : 1.0;
        return NDV_WEIGHT * normalizeNdv + SKEW_WEIGHT * skewScore - DATA_TYPE_COST_WEIGHT * normalizeDataTypeCost;
    }
}