ShuffleKeyPruneUtils.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.properties;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.PlanContext;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.statistics.ColumnStatistic;
import org.apache.doris.statistics.Statistics;
import org.apache.doris.statistics.util.StatisticsUtil;

import com.google.common.collect.ImmutableList;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Utilities for pruning the shuffle keys of global aggregates (and of joins whose children are
 * both global aggregates) down to a single column.
 *
 * <p>Candidate columns are ranked with
 * {@code NDV_WEIGHT * normalizedNdv + SKEW_WEIGHT * skewScore - DATA_TYPE_COST_WEIGHT * dataTypeCost},
 * where the balance check and the skew score are delegated to {@link StatisticsUtil}.
 */
public class ShuffleKeyPruneUtils {
    /** Weight of the normalized NDV ({@code ndv / rowCount}, capped at 1.0) in the key score. */
    private static final double NDV_WEIGHT = 1.0;
    /** Weight of the skew score returned by {@link StatisticsUtil#computeShuffleKeySkewScore}. */
    private static final double SKEW_WEIGHT = 1.0;
    /** Weight of the penalty applied to key columns whose data type is not numeric. */
    private static final double DATA_TYPE_COST_WEIGHT = 0.5;

    /**
     * Pick a representative expression of a group: the first physical expression when the group
     * has any, otherwise the first logical expression.
     */
    private static GroupExpression getGroupExpression(Group group) {
        List<GroupExpression> physicalGroupExpressions = group.getPhysicalExpressions();
        if (!physicalGroupExpressions.isEmpty()) {
            return physicalGroupExpressions.get(0);
        }
        return group.getLogicalExpressions().get(0);
    }

    /**
     * Find the statistics of the input of a global aggregate.
     *
     * @param agg a global aggregate
     * @return the statistics of child 0 of the aggregate; when the representative expression of
     *         that child group is a local aggregate, the statistics of the local aggregate's own
     *         child group instead. Empty when the aggregate has no group expression or the
     *         statistics are null.
     */
    private static Optional<Statistics> getGlobalAggChildStats(PhysicalHashAggregate<? extends Plan> agg) {
        Optional<GroupExpression> groupExpression = agg.getGroupExpression();
        if (!groupExpression.isPresent()) {
            return Optional.empty();
        }
        GroupExpression aggExpression = groupExpression.get();
        Statistics aggChildStats = aggExpression.childStatistics(0);
        GroupExpression childExpression = getGroupExpression(aggExpression.child(0));
        Plan childPlan = childExpression.getPlan();
        if (childPlan instanceof PhysicalHashAggregate
                && ((PhysicalHashAggregate<?>) childPlan).getAggPhase().isLocal()) {
            // A physical plan can only come from the group's first physical expression, so
            // childExpression is that expression; look through the local aggregate to its input.
            aggChildStats = childExpression.child(0).getStatistics();
        }
        return Optional.ofNullable(aggChildStats);
    }

    /**
     * Check whether shuffle key pruning may be applied to the aggregate: the session switch
     * {@code chooseOneAggShuffleKey} is on, the aggregate has more group-by expressions than
     * {@code shuffleKeyPruneThreshold}, and the aggregate has no source repeat.
     */
    private static boolean canAggShuffleKeyOpt(PhysicalHashAggregate<? extends Plan> agg,
            ConnectContext connectContext) {
        if (!connectContext.getSessionVariable().chooseOneAggShuffleKey) {
            return false;
        }
        if (agg.getGroupByExpressions().size() <= connectContext.getSessionVariable().shuffleKeyPruneThreshold) {
            return false;
        }
        return !agg.hasSourceRepeat();
    }

    /**
     * When the parent sends a hash shuffle request, choose one optimal key from the intersection
     * of the parent hash columns and the aggregate group-by columns, or fall back to the full
     * intersection.
     *
     * @param agg the global aggregate receiving the request
     * @param intersectIdSet intersection of the parent hash columns and the group-by columns
     * @param context plan context supplying the connect context and session variables
     * @return a single-element list holding the best key, or all ids of {@code intersectIdSet}
     *         when pruning is disabled, the intersection is not larger than the threshold, no
     *         input statistics are available, or no candidate column qualifies
     */
    public static List<ExprId> selectOptimalShuffleKeyForAggWithParentHashRequest(
            PhysicalHashAggregate<? extends Plan> agg, Set<ExprId> intersectIdSet, PlanContext context) {
        List<ExprId> orderedIds = Utils.fastToImmutableList(intersectIdSet);
        ConnectContext connectContext = context.getConnectContext();
        if (!connectContext.getSessionVariable().chooseOneAggShuffleKey
                || intersectIdSet.size() <= connectContext.getSessionVariable().shuffleKeyPruneThreshold) {
            return orderedIds;
        }
        Optional<Statistics> childStats = getGlobalAggChildStats(agg);
        if (!childStats.isPresent()) {
            return orderedIds;
        }
        // Only group-by expressions that are plain slots of the intersection are candidates.
        List<Expression> intersectExprs = new ArrayList<>();
        for (Expression e : agg.getGroupByExpressions()) {
            if (e instanceof SlotReference && intersectIdSet.contains(((SlotReference) e).getExprId())) {
                intersectExprs.add(e);
            }
        }
        if (intersectExprs.isEmpty()) {
            return orderedIds;
        }
        double rowCount = childStats.get().getRowCount();
        int instanceNum = ConnectContext.getTotalInstanceNum(connectContext);
        Optional<Expression> best = chooseBestShuffleKeyFromPartitionExpressions(
                intersectExprs, childStats.get(), rowCount, instanceNum);
        if (best.isPresent()) {
            // chooseBestShuffleKeyFromPartitionExpressions only ever returns a SlotReference.
            return ImmutableList.of(((SlotReference) best.get()).getExprId());
        }
        return orderedIds;
    }

    /**
     * Scenario 4: when the partition expressions of an aggregate are set by rule, try to reduce
     * them to a single shuffle key.
     *
     * @param agg the global aggregate
     * @param partitionExprs candidate partition expressions; only slot references are considered
     * @param context connect context supplying the session variables and the instance number
     * @return the best single shuffle key, or empty when pruning is not applicable to the
     *         aggregate, no input statistics are available, or no candidate column qualifies
     */
    public static Optional<Expression> selectBestShuffleKeyForAgg(
            PhysicalHashAggregate<? extends Plan> agg, List<Expression> partitionExprs, ConnectContext context) {
        if (!canAggShuffleKeyOpt(agg, context)) {
            return Optional.empty();
        }
        Optional<Statistics> childStats = getGlobalAggChildStats(agg);
        if (!childStats.isPresent()) {
            return Optional.empty();
        }
        double rowCount = childStats.get().getRowCount();
        int instanceNum = ConnectContext.getTotalInstanceNum(context);
        return chooseBestShuffleKeyFromPartitionExpressions(
                partitionExprs, childStats.get(), rowCount, instanceNum);
    }

    /**
     * Choose the best shuffle key (one SlotReference) from the partition expressions by score.
     * Expressions that are not slot references, have no column statistics, or are rejected by
     * {@link #computeShuffleKeyScore} are skipped; on equal scores the earliest candidate wins.
     */
    private static Optional<Expression> chooseBestShuffleKeyFromPartitionExpressions(List<Expression> expressions,
            Statistics childStats, double rowCount, int instanceNum) {
        Expression bestExpr = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (Expression expr : expressions) {
            if (!(expr instanceof SlotReference)) {
                continue;
            }
            SlotReference slotRef = (SlotReference) expr;
            ColumnStatistic colStats = childStats.findColumnStatistics(slotRef);
            if (colStats == null) {
                continue;
            }
            double score = computeShuffleKeyScore(colStats, rowCount, instanceNum, slotRef.getDataType());
            // A rejected column scores NEGATIVE_INFINITY and can never beat the initial bestScore.
            if (score > bestScore) {
                bestScore = score;
                bestExpr = slotRef;
            }
        }
        return Optional.ofNullable(bestExpr);
    }

    /**
     * Get the first global aggregate among the physical expressions of a group together with its
     * input statistics. Empty when the group has no such aggregate or its input statistics are
     * unavailable.
     */
    private static Optional<Pair<PhysicalHashAggregate<? extends Plan>, Statistics>> getGlobalAggInputStatsFromGroup(
            Group group) {
        for (GroupExpression ge : group.getPhysicalExpressions()) {
            Plan plan = ge.getPlan();
            if (!(plan instanceof PhysicalHashAggregate)) {
                continue;
            }
            PhysicalHashAggregate<? extends Plan> agg = (PhysicalHashAggregate<? extends Plan>) plan;
            if (agg.getAggPhase().isGlobal()) {
                // Only the first global aggregate is inspected, even when it has no statistics.
                return getGlobalAggChildStats(agg).map(statistics -> Pair.of(agg, statistics));
            }
        }
        return Optional.empty();
    }

    /** Collect the group-by expressions of an aggregate that are plain slot references, in order. */
    private static List<SlotReference> getGroupBySlots(PhysicalHashAggregate<? extends Plan> agg) {
        return agg.getGroupByExpressions().stream()
                .filter(SlotReference.class::isInstance)
                .map(SlotReference.class::cast)
                .collect(Collectors.toList());
    }

    /** Find the first slot with the given expr id. */
    private static Optional<SlotReference> findSlotByExprId(List<SlotReference> slots, ExprId exprId) {
        return slots.stream()
                .filter(slot -> slot.getExprId().equals(exprId))
                .findFirst();
    }

    /**
     * Scenario 3.3: when both join children are Global AGG, find one unified shuffle key from
     * join key ∩ left_agg.gby ∩ right_agg.gby with the best combined score, i.e. the highest
     * average of the left and right column scores.
     *
     * <p>NOTE(review): unlike the aggregate entry points of this class, this method does not
     * consult {@code chooseOneAggShuffleKey}; presumably the caller gates on it -- confirm.
     *
     * @param hashJoin the hash join whose children are inspected
     * @param context plan context supplying the join's group expression and the connect context
     * @return (leftKey, rightKey) of the best join key pair, or empty when the join does not have
     *         more hash conjuncts than {@code shuffleKeyPruneThreshold}, either child has no
     *         global aggregate with input statistics, or no key pair qualifies
     */
    public static Optional<Pair<ExprId, ExprId>> tryFindOptimalShuffleKeyForBothAggChildren(
            PhysicalHashJoin<? extends Plan, ? extends Plan> hashJoin, PlanContext context) {
        if (hashJoin.getHashJoinConjuncts().size()
                <= context.getConnectContext().getSessionVariable().shuffleKeyPruneThreshold) {
            return Optional.empty();
        }
        GroupExpression joinGroupExpr = context.getGroupExpression();
        if (joinGroupExpr == null) {
            return Optional.empty();
        }
        Group leftGroup = joinGroupExpr.child(0);
        Group rightGroup = joinGroupExpr.child(1);
        Optional<Pair<PhysicalHashAggregate<? extends Plan>, Statistics>> leftOpt =
                getGlobalAggInputStatsFromGroup(leftGroup);
        Optional<Pair<PhysicalHashAggregate<? extends Plan>, Statistics>> rightOpt =
                getGlobalAggInputStatsFromGroup(rightGroup);
        if (!leftOpt.isPresent() || !rightOpt.isPresent()) {
            return Optional.empty();
        }

        Pair<List<ExprId>, List<ExprId>> joinKeys = hashJoin.getHashConjunctsExprIds();
        if (joinKeys.first.isEmpty() || joinKeys.second.size() != joinKeys.first.size()) {
            return Optional.empty();
        }

        Statistics leftStats = leftOpt.get().second;
        Statistics rightStats = rightOpt.get().second;
        List<SlotReference> leftGbySlots = getGroupBySlots(leftOpt.get().first);
        List<SlotReference> rightGbySlots = getGroupBySlots(rightOpt.get().first);

        double leftRows = leftStats.getRowCount();
        double rightRows = rightStats.getRowCount();
        int instanceNum = ConnectContext.getTotalInstanceNum(context.getConnectContext());

        ExprId bestLeftKey = null;
        ExprId bestRightKey = null;
        double bestScore = Double.NEGATIVE_INFINITY;

        for (int i = 0; i < joinKeys.first.size(); i++) {
            ExprId leftId = joinKeys.first.get(i);
            ExprId rightId = joinKeys.second.get(i);
            // Both sides of the join key must be group-by slots of their aggregate.
            Optional<SlotReference> leftSlotRef = findSlotByExprId(leftGbySlots, leftId);
            Optional<SlotReference> rightSlotRef = findSlotByExprId(rightGbySlots, rightId);
            if (!leftSlotRef.isPresent() || !rightSlotRef.isPresent()) {
                continue;
            }
            ColumnStatistic leftColStats = leftStats.findColumnStatistics(leftSlotRef.get());
            ColumnStatistic rightColStats = rightStats.findColumnStatistics(rightSlotRef.get());
            if (leftColStats == null || rightColStats == null) {
                continue;
            }
            // computeShuffleKeyScore already rejects unbalanced columns, so no separate
            // isBalanced check is needed here.
            double leftScore = computeShuffleKeyScore(leftColStats, leftRows, instanceNum,
                    leftSlotRef.get().getDataType());
            if (leftScore == Double.NEGATIVE_INFINITY) {
                continue;
            }
            double rightScore = computeShuffleKeyScore(rightColStats, rightRows, instanceNum,
                    rightSlotRef.get().getDataType());
            if (rightScore == Double.NEGATIVE_INFINITY) {
                continue;
            }
            double avgScore = (leftScore + rightScore) / 2.0;
            if (avgScore > bestScore) {
                bestScore = avgScore;
                bestLeftKey = leftId;
                bestRightKey = rightId;
            }
        }
        if (bestLeftKey == null || bestRightKey == null) {
            return Optional.empty();
        }
        return Optional.of(Pair.of(bestLeftKey, bestRightKey));
    }

    /**
     * Compute the shuffle key score for one column:
     * {@code NDV_WEIGHT * normalizeNdv + SKEW_WEIGHT * skewScore - DATA_TYPE_COST_WEIGHT * dataTypeCost}.
     *
     * @param colStats statistics of the candidate column
     * @param rowCount row count of the aggregate input; a non-positive value yields a normalized
     *         NDV of 0
     * @param instanceNum total instance number passed through to {@link StatisticsUtil}
     * @param dataType data type of the candidate column; numeric types cost 0.0, all others 1.0
     * @return the score, or NEGATIVE_INFINITY if {@link StatisticsUtil#isBalanced} rejects the
     *         column or the skew score itself is NEGATIVE_INFINITY
     */
    private static double computeShuffleKeyScore(ColumnStatistic colStats, double rowCount, int instanceNum,
            DataType dataType) {
        if (!StatisticsUtil.isBalanced(colStats, rowCount, instanceNum)) {
            return Double.NEGATIVE_INFINITY;
        }
        double skewScore = StatisticsUtil.computeShuffleKeySkewScore(colStats, rowCount, instanceNum);
        if (skewScore == Double.NEGATIVE_INFINITY) {
            return Double.NEGATIVE_INFINITY;
        }
        double normalizeNdv = rowCount <= 0 ? 0 : Math.min(1.0, colStats.ndv / rowCount);
        double normalizeDataTypeCost = dataType.isNumericType() ? 0.0 : 1.0;
        return NDV_WEIGHT * normalizeNdv + SKEW_WEIGHT * skewScore - DATA_TYPE_COST_WEIGHT * normalizeDataTypeCost;
    }
}