ShuffleKeyPruneUtils.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.properties;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.PlanContext;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.stats.StatsCalculator;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.util.AggregateUtils;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.statistics.ColumnStatistic;
import org.apache.doris.statistics.Statistics;
import org.apache.doris.statistics.util.StatisticsUtil;

import com.google.common.collect.ImmutableList;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**AggShuffleKeyOptimize*/
public class ShuffleKeyPruneUtils {
    private static GroupExpression getGroupExpression(Group group) {
        List<GroupExpression> physicalGroupExpressions = group.getPhysicalExpressions();
        if (!physicalGroupExpressions.isEmpty()) {
            return physicalGroupExpressions.get(0);
        } else {
            return group.getLogicalExpressions().get(0);
        }
    }

    /*
     * @param agg is a global aggregate
     * @return the Statistics of the children of the local aggregate corresponding to the global aggregate.
     */
    private static Optional<Statistics> getGlobalAggChildStats(PhysicalHashAggregate<? extends Plan> agg) {
        Optional<GroupExpression> groupExpression = agg.getGroupExpression();
        if (!groupExpression.isPresent()) {
            return Optional.empty();
        }
        Statistics aggChildStats = groupExpression.get().childStatistics(0);
        Group childGroup = groupExpression.get().child(0);
        Plan childExpression = getGroupExpression(childGroup).getPlan();
        if (childExpression instanceof PhysicalHashAggregate
                && ((PhysicalHashAggregate) childExpression).getAggPhase().isLocal()) {
            childGroup = childGroup.getPhysicalExpressions().get(0).child(0);
            aggChildStats = childGroup.getStatistics();
        }
        return Optional.ofNullable(aggChildStats);
    }

    private static boolean canAggShuffleKeyOpt(PhysicalHashAggregate<? extends Plan> agg,
            List<? extends Expression> partitionExprs, ConnectContext connectContext) {
        if (!connectContext.getSessionVariable().enableAggShuffleKeyPrune) {
            return false;
        }
        if (agg.hasSourceRepeat()) {
            return false;
        }
        return true;
    }

    /**
     * When parent sends shuffle request, choose one optimal key from intersection of parent hash
     * columns and agg group-by columns, or use full intersection. Returns list of ExprIds as
     * shuffle keys.
     */
    public static List<ExprId> selectOptimalShuffleKeyForAggWithParentHashRequest(
            PhysicalHashAggregate<? extends Plan> agg, Set<ExprId> intersectIdSet, PlanContext context) {
        List<ExprId> orderedIds = Utils.fastToImmutableList(intersectIdSet);
        if (!context.getConnectContext().getSessionVariable().enableAggShuffleKeyPrune) {
            return orderedIds;
        }
        Optional<Statistics> childStats = getGlobalAggChildStats(agg);
        if (!childStats.isPresent()) {
            return orderedIds;
        }
        List<Expression> intersectExprs = new ArrayList<>();
        for (Expression e : agg.getGroupByExpressions()) {
            if (e instanceof SlotReference) {
                SlotReference slot = (SlotReference) e;
                if (intersectIdSet.contains(slot.getExprId())) {
                    intersectExprs.add(e);
                }
            }
        }
        if (intersectExprs.isEmpty()) {
            return orderedIds;
        }
        double rowCount = childStats.get().getRowCount();
        int instanceNum = ConnectContext.getTotalInstanceNum(context.getConnectContext());
        Optional<List<Expression>> optimalKeys = selectOptimalShuffleKeys(
                intersectExprs, childStats.get(), rowCount, instanceNum);
        if (optimalKeys.isPresent()) {
            return optimalKeys.get().stream()
                    .filter(SlotReference.class::isInstance)
                    .map(SlotReference.class::cast)
                    .map(SlotReference::getExprId)
                    .collect(Collectors.toList());
        }
        return orderedIds;
    }

    /**
     * Scenario 4: When partition expressions are set by rule, optionally reduce shuffle keys.
     * Strategy: 1) Try single key (isBalanced); 2) Try numeric+date keys (remove strings);
     * 3) Fall back to full partitionExprs.
     * Returns the list of expressions to use as shuffle keys, or empty to use full partitionExprs.
     */
    public static Optional<List<Expression>> selectBestShuffleKeyForAgg(
            PhysicalHashAggregate<? extends Plan> agg, List<Expression> partitionExprs, ConnectContext context) {
        if (!canAggShuffleKeyOpt(agg, partitionExprs, context)) {
            return Optional.empty();
        }
        Optional<Statistics> childStats = getGlobalAggChildStats(agg);
        if (!childStats.isPresent()) {
            return Optional.empty();
        }
        double rowCount = childStats.get().getRowCount();
        int instanceNum = ConnectContext.getTotalInstanceNum(context);
        return selectOptimalShuffleKeys(partitionExprs, childStats.get(), rowCount, instanceNum);
    }

    /**
     * Select optimal shuffle keys with three-step strategy:
     * 1. Try single key: sort by type (numeric/date first, string sorted by avg_size), pick first isBalanced key.
     * 2. Try remove strings: filter numeric+date keys, if combinedNDV > instanceNum*512 return that list.
     * 3. Fall back: return empty (caller uses full partitionExprs).
     */
    private static Optional<List<Expression>> selectOptimalShuffleKeys(List<Expression> partitionExprs,
            Statistics childStats, double rowCount, int instanceNum) {
        List<SlotReference> slotRefs = partitionExprs.stream()
                .filter(SlotReference.class::isInstance)
                .map(SlotReference.class::cast)
                .collect(Collectors.toList());
        if (slotRefs.isEmpty()) {
            return Optional.empty();
        }
        // If any partition slot lacks column stats, skip optimization and use original partitionExprs.
        for (SlotReference slotRef : slotRefs) {
            if (childStats.findColumnStatistics(slotRef) == null) {
                return Optional.empty();
            }
        }

        // Step 1: Try single key - sort by type priority, pick first isBalanced
        List<SlotReference> sortedByType = sortShuffleKeysByTypePriority(slotRefs, childStats);
        for (SlotReference slotRef : sortedByType) {
            ColumnStatistic colStats = childStats.findColumnStatistics(slotRef);
            if (StatisticsUtil.isBalanced(colStats, rowCount, instanceNum)) {
                return Optional.of(ImmutableList.of(slotRef));
            }
        }

        // Step 2: Try remove string types - filter numeric+date, check combined NDV
        List<Expression> numericAndDateExprs = slotRefs.stream()
                .filter(s -> s.getDataType().isNumericType() || s.getDataType().isDateLikeType())
                .collect(Collectors.toList());
        if (!numericAndDateExprs.isEmpty()) {
            double combinedNdv = StatsCalculator.estimateGroupByRowCount(numericAndDateExprs, childStats);
            long ndvThreshold = (long) instanceNum * AggregateUtils.NDV_INSTANCE_BALANCE_MULTIPLIER;
            if (combinedNdv > ndvThreshold) {
                return Optional.of(ImmutableList.copyOf(numericAndDateExprs));
            }
        }

        // Step 3: Fall back - return empty, caller uses full partitionExprs
        return Optional.empty();
    }

    /**
     * Sort shuffle keys: numeric and date first, then string types.
     * String types are sorted by column statistics avg size (avgSizeByte) ascending.
     */
    private static List<SlotReference> sortShuffleKeysByTypePriority(List<SlotReference> slotRefs,
            Statistics childStats) {
        List<SlotReference> result = new ArrayList<>(slotRefs);
        result.sort(Comparator
                .comparingInt((SlotReference s) -> getTypeSortPriority(s.getDataType()))
                .thenComparingDouble((SlotReference s) -> getStringAvgSizeForSort(s, childStats)));
        return result;
    }

    /** 0=numeric/date first, 1=string last. */
    private static int getTypeSortPriority(DataType dataType) {
        if (dataType.isNumericType() || dataType.isDateLikeType()) {
            return 0;
        }
        return 1;
    }

    /** For string types return avg size from stats; for others return 0 (no secondary sort). */
    private static double getStringAvgSizeForSort(SlotReference slotRef, Statistics childStats) {
        DataType dataType = slotRef.getDataType();
        if (dataType instanceof CharacterType) {
            ColumnStatistic colStats = childStats.findColumnStatistics(slotRef);
            if (colStats != null && !colStats.isUnKnown && colStats.avgSizeByte > 0) {
                return colStats.avgSizeByte;
            }
            return ((CharacterType) dataType).getLen();
        }
        return 0;
    }

    /**
     * Get Global AGG plan and its input statistics from a Group (if the group's best plan is Global AGG).
     */
    private static Optional<Pair<PhysicalHashAggregate<? extends Plan>, Statistics>> getGlobalAggInputStatsFromGroup(
            Group group) {
        for (GroupExpression ge : group.getPhysicalExpressions()) {
            Plan p = ge.getPlan();
            if (p instanceof PhysicalHashAggregate && ((PhysicalHashAggregate<?>) p).getAggPhase().isGlobal()) {
                Optional<Statistics> inputStats = getGlobalAggChildStats((PhysicalHashAggregate<? extends Plan>) p);
                return inputStats.map(statistics -> Pair.of((PhysicalHashAggregate<? extends Plan>) p, statistics));
            }
        }
        return Optional.empty();
    }

    /**
     * Scenario 3.3: when both join children are Global AGG, find optimal shuffle keys from
     * join key ��� left_agg.gby ��� right_agg.gby. Same three-step strategy as agg:
     * 1) Try single key (isBalanced); 2) Try numeric+date keys (remove strings);
     * 3) Fall back. Returns (leftKeys, rightKeys) or empty.
     */
    public static Optional<Pair<List<ExprId>, List<ExprId>>> tryFindOptimalShuffleKeyForBothAggChildren(
            PhysicalHashJoin<? extends Plan, ? extends Plan> hashJoin, PlanContext context) {
        GroupExpression joinGroupExpr = context.getGroupExpression();
        if (joinGroupExpr == null) {
            return Optional.empty();
        }
        Group leftGroup = joinGroupExpr.child(0);
        Group rightGroup = joinGroupExpr.child(1);
        Optional<Pair<PhysicalHashAggregate<? extends Plan>, Statistics>> leftOpt =
                getGlobalAggInputStatsFromGroup(leftGroup);
        Optional<Pair<PhysicalHashAggregate<? extends Plan>, Statistics>> rightOpt =
                getGlobalAggInputStatsFromGroup(rightGroup);
        if (!leftOpt.isPresent() || !rightOpt.isPresent()) {
            return Optional.empty();
        }

        PhysicalHashAggregate<? extends Plan> leftAgg = leftOpt.get().first;
        PhysicalHashAggregate<? extends Plan> rightAgg = rightOpt.get().first;
        if (leftAgg.hasSourceRepeat() || rightAgg.hasSourceRepeat()) {
            return Optional.empty();
        }
        Statistics leftStats = leftOpt.get().second;
        Statistics rightStats = rightOpt.get().second;

        Pair<List<ExprId>, List<ExprId>> joinKeys = hashJoin.getHashConjunctsExprIds();
        if (joinKeys.first.isEmpty() || joinKeys.second.size() != joinKeys.first.size()) {
            return Optional.empty();
        }

        Set<ExprId> leftGbyIds = leftAgg.getGroupByExpressions().stream()
                .filter(SlotReference.class::isInstance)
                .map(SlotReference.class::cast)
                .map(SlotReference::getExprId)
                .collect(Collectors.toSet());
        Set<ExprId> rightGbyIds = rightAgg.getGroupByExpressions().stream()
                .filter(SlotReference.class::isInstance)
                .map(SlotReference.class::cast)
                .map(SlotReference::getExprId)
                .collect(Collectors.toSet());

        double leftRows = leftStats.getRowCount();
        double rightRows = rightStats.getRowCount();
        int instanceNum = ConnectContext.getTotalInstanceNum(context.getConnectContext());

        // Build (leftSlotRef, rightSlotRef) pairs for join keys in both gby sets
        List<Pair<SlotReference, SlotReference>> validPairs = new ArrayList<>();
        for (int i = 0; i < joinKeys.first.size(); i++) {
            ExprId leftId = joinKeys.first.get(i);
            ExprId rightId = joinKeys.second.get(i);
            if (!leftGbyIds.contains(leftId) || !rightGbyIds.contains(rightId)) {
                continue;
            }
            SlotReference leftSlotRef = leftAgg.getGroupByExpressions().stream()
                    .filter(e -> e instanceof SlotReference && ((SlotReference) e).getExprId().equals(leftId))
                    .map(SlotReference.class::cast)
                    .findFirst()
                    .orElse(null);
            SlotReference rightSlotRef = rightAgg.getGroupByExpressions().stream()
                    .filter(e -> e instanceof SlotReference && ((SlotReference) e).getExprId().equals(rightId))
                    .map(SlotReference.class::cast)
                    .findFirst()
                    .orElse(null);
            if (leftSlotRef != null && rightSlotRef != null) {
                validPairs.add(Pair.of(leftSlotRef, rightSlotRef));
            }
        }
        if (validPairs.isEmpty()) {
            return Optional.empty();
        }
        // If any join key pair lacks column stats on either side, skip optimization.
        for (Pair<SlotReference, SlotReference> pair : validPairs) {
            if (leftStats.findColumnStatistics(pair.first) == null
                    || rightStats.findColumnStatistics(pair.second) == null) {
                return Optional.empty();
            }
        }

        // Step 1: Try single key - sort by type, pick first where both isBalanced
        List<Pair<SlotReference, SlotReference>> sortedPairs =
                sortJoinKeyPairsByTypePriority(validPairs, leftStats, rightStats);
        for (Pair<SlotReference, SlotReference> pair : sortedPairs) {
            SlotReference leftSlotRef = pair.first;
            SlotReference rightSlotRef = pair.second;
            ColumnStatistic leftColStats = leftStats.findColumnStatistics(leftSlotRef);
            ColumnStatistic rightColStats = rightStats.findColumnStatistics(rightSlotRef);
            if (StatisticsUtil.isBalanced(leftColStats, leftRows, instanceNum)
                    && StatisticsUtil.isBalanced(rightColStats, rightRows, instanceNum)) {
                return Optional.of(Pair.of(
                        ImmutableList.of(leftSlotRef.getExprId()),
                        ImmutableList.of(rightSlotRef.getExprId())));
            }
        }

        // Step 2: Try remove string types - filter numeric+date pairs, check combined NDV
        List<SlotReference> numericDateLeftSlots = new ArrayList<>();
        List<SlotReference> numericDateRightSlots = new ArrayList<>();
        for (Pair<SlotReference, SlotReference> pair : validPairs) {
            if ((pair.first.getDataType().isNumericType() || pair.first.getDataType().isDateLikeType())
                    && (pair.second.getDataType().isNumericType() || pair.second.getDataType().isDateLikeType())) {
                numericDateLeftSlots.add(pair.first);
                numericDateRightSlots.add(pair.second);
            }
        }
        if (!numericDateLeftSlots.isEmpty()) {
            double leftCombinedNdv = StatsCalculator.estimateGroupByRowCount(
                    new ArrayList<>(numericDateLeftSlots), leftStats);
            double rightCombinedNdv = StatsCalculator.estimateGroupByRowCount(
                    new ArrayList<>(numericDateRightSlots), rightStats);
            long ndvThreshold = (long) instanceNum * AggregateUtils.NDV_INSTANCE_BALANCE_MULTIPLIER;
            if (leftCombinedNdv > ndvThreshold && rightCombinedNdv > ndvThreshold) {
                List<ExprId> leftIds = numericDateLeftSlots.stream()
                        .map(SlotReference::getExprId)
                        .collect(Collectors.toList());
                List<ExprId> rightIds = numericDateRightSlots.stream()
                        .map(SlotReference::getExprId)
                        .collect(Collectors.toList());
                return Optional.of(Pair.of(leftIds, rightIds));
            }
        }

        // Step 3: Fall back
        return Optional.empty();
    }

    /** Sort join key pairs by type priority (numeric/date first, string by avg_size). */
    private static List<Pair<SlotReference, SlotReference>> sortJoinKeyPairsByTypePriority(
            List<Pair<SlotReference, SlotReference>> pairs, Statistics leftStats, Statistics rightStats) {
        List<Pair<SlotReference, SlotReference>> result = new ArrayList<>(pairs);
        result.sort(Comparator
                .comparingInt((Pair<SlotReference, SlotReference> p) ->
                        getTypeSortPriority(p.first.getDataType()))
                .thenComparingDouble((Pair<SlotReference, SlotReference> p) ->
                        getJoinPairStringAvgSizeForSort(p, leftStats, rightStats)));
        return result;
    }

    /** For string join-key pairs, use avg size of both sides for sorting; for others return 0. */
    private static double getJoinPairStringAvgSizeForSort(Pair<SlotReference, SlotReference> pair,
            Statistics leftStats, Statistics rightStats) {
        if (pair.first.getDataType() instanceof CharacterType && pair.second.getDataType() instanceof CharacterType) {
            return (getStringAvgSizeForSort(pair.first, leftStats) + getStringAvgSizeForSort(pair.second, rightStats));
        }
        return 0;
    }
}