// ShuffleKeyPruner.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.processor.post;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.properties.DistributionSpecHash;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.properties.ShuffleKeyPruneUtils;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.AbstractPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalPlan;
import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalSort;
import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows;
import org.apache.doris.nereids.trees.plans.physical.PhysicalBlackholeSink;
import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEAnchor;
import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEProducer;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeResultSink;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDictionarySink;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter;
import org.apache.doris.nereids.trees.plans.physical.PhysicalGenerate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHiveTableSink;
import org.apache.doris.nereids.trees.plans.physical.PhysicalIcebergTableSink;
import org.apache.doris.nereids.trees.plans.physical.PhysicalJdbcTableSink;
import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit;
import org.apache.doris.nereids.trees.plans.physical.PhysicalMaxComputeTableSink;
import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapTableSink;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPartitionTopN;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalRecursiveUnion;
import org.apache.doris.nereids.trees.plans.physical.PhysicalRepeat;
import org.apache.doris.nereids.trees.plans.physical.PhysicalResultSink;
import org.apache.doris.nereids.trees.plans.physical.PhysicalSetOperation;
import org.apache.doris.nereids.trees.plans.physical.PhysicalTVFTableSink;
import org.apache.doris.nereids.trees.plans.physical.PhysicalUnary;
import org.apache.doris.nereids.trees.plans.physical.PhysicalWindow;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.statistics.Statistics;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
/**
* Post-process shuffle key pruning on final physical plan. {@link PruneCtx#allowShuffleKeyPrune} marks
* subtrees where shuffle key pruning is allowed; when false, pruning is skipped (e.g. join/setop alignment,
* sinks).
*/
public class ShuffleKeyPruner extends PlanPostProcessor {
@Override
public Plan processRoot(Plan plan, CascadesContext ctx) {
if (!ctx.getConnectContext().getSessionVariable().enableAggShuffleKeyPrune) {
return plan;
}
return plan.accept(new ShuffleKeyPruneRewriter(), new PruneCtx(true, ctx));
}
private static final class PruneCtx {
/** When true, shuffle key pruning may run on descendants; when false, skip pruning. */
final boolean allowShuffleKeyPrune;
final CascadesContext cascadesContext;
private PruneCtx(boolean allowShuffleKeyPrune, CascadesContext cascadesContext) {
this.allowShuffleKeyPrune = allowShuffleKeyPrune;
this.cascadesContext = cascadesContext;
}
PruneCtx withAllowShuffleKeyPrune(boolean allowShuffleKeyPrune) {
return new PruneCtx(allowShuffleKeyPrune, cascadesContext);
}
}
    /**
     * Plan rewriter that threads a {@link PruneCtx} down the tree. Each visit method decides,
     * per child, whether shuffle key pruning is allowed in that subtree, recurses with the new
     * permission, and on the way back up performs the actual pruning for shuffle hash joins
     * ({@code maybePruneShuffleJoin}) and GLOBAL hash aggregates ({@code tryPruneGlobalAgg}).
     * Nodes rebuilt with new children copy statistics and group id from the original node.
     */
    private static class ShuffleKeyPruneRewriter extends DefaultPlanRewriter<PruneCtx> {

        // An exchange introduces a fresh distribution below it, so pruning is re-enabled there.
        @Override
        public Plan visitPhysicalDistribute(PhysicalDistribute<? extends Plan> distribute, PruneCtx ctx) {
            return rewriteUnary(distribute, ctx.withAllowShuffleKeyPrune(true));
        }

        @Override
        public Plan visitPhysicalHashJoin(PhysicalHashJoin<? extends Plan, ? extends Plan> join, PruneCtx ctx) {
            // Each join child gets its own permission depending on the join's distribution shape.
            Pair<Boolean, Boolean> lr = deriveHashJoinChildAllowShuffleKeyPrune(join, ctx.allowShuffleKeyPrune);
            PhysicalHashJoin<? extends Plan, ? extends Plan> current = rewriteChildren(join,
                    ImmutableList.of(ctx.withAllowShuffleKeyPrune(lr.first),
                            ctx.withAllowShuffleKeyPrune(lr.second)));
            if (ctx.allowShuffleKeyPrune) {
                // Try to shrink the shuffle keys feeding this join; keep the join unchanged on failure.
                return maybePruneShuffleJoin(current, ctx.cascadesContext).orElse(current);
            }
            return current;
        }

        @Override
        public Plan visitPhysicalHashAggregate(PhysicalHashAggregate<? extends Plan> agg, PruneCtx ctx) {
            PhysicalHashAggregate<? extends Plan> current;
            if (agg.getAggPhase().isLocal()) {
                // Below a LOCAL agg a new exchange will follow, so descendants may prune freely.
                current = rewriteUnary(agg, ctx.withAllowShuffleKeyPrune(true));
            } else {
                current = rewriteUnary(agg, ctx);
            }
            if (ctx.allowShuffleKeyPrune && agg.getAggPhase().isGlobal()) {
                // A GLOBAL agg over a hash exchange is the main pruning target.
                return tryPruneGlobalAgg(current, ctx.cascadesContext);
            }
            return current;
        }

        // CTE anchor: the producer side (child 0) is a fresh subtree and may prune; the
        // consumer side (child 1) inherits the current permission.
        @Override
        public Plan visitPhysicalCTEAnchor(PhysicalCTEAnchor<? extends Plan, ? extends Plan> anchor, PruneCtx ctx) {
            return rewriteChildren(anchor, ImmutableList.of(ctx.withAllowShuffleKeyPrune(true), ctx));
        }

        // Window keeps the current permission unchanged.
        @Override
        public Plan visitPhysicalWindow(PhysicalWindow<? extends Plan> window, PruneCtx ctx) {
            return rewriteUnary(window, ctx);
        }

        // Set operations need all inputs aligned on the same distribution, so pruning is
        // disabled for every child. visitChildren is used because a setop has variable arity.
        // NOTE(review): unlike rewriteChildren, the inherited visitChildren presumably does not
        // copy stats/group id onto a rebuilt node — confirm in DefaultPlanRewriter.
        @Override
        public Plan visitPhysicalSetOperation(PhysicalSetOperation setOperation, PruneCtx ctx) {
            return visitChildren(this, setOperation, ctx.withAllowShuffleKeyPrune(false));
        }

        @Override
        public Plan visitPhysicalAssertNumRows(PhysicalAssertNumRows<? extends Plan> assertNumRows, PruneCtx ctx) {
            return rewriteUnary(assertNumRows, ctx.withAllowShuffleKeyPrune(false));
        }

        // Only a LOCAL limit lets its input prune; a GLOBAL limit's input distribution is kept.
        @Override
        public Plan visitPhysicalLimit(PhysicalLimit<? extends Plan> limit, PruneCtx ctx) {
            boolean childAllowShuffleKeyPrune = !limit.isGlobal();
            return rewriteUnary(limit, ctx.withAllowShuffleKeyPrune(childAllowShuffleKeyPrune));
        }

        // Only a LOCAL sort phase lets its input prune.
        @Override
        public Plan visitAbstractPhysicalSort(AbstractPhysicalSort<? extends Plan> sort, PruneCtx ctx) {
            boolean childAllowShuffleKeyPrune = sort.getSortPhase().isLocal();
            return rewriteUnary(sort, ctx.withAllowShuffleKeyPrune(childAllowShuffleKeyPrune));
        }

        // Nested loop join: the probe (left) side may prune only for cross/inner/left joins;
        // the build (right) side never may.
        @Override
        public Plan visitPhysicalNestedLoopJoin(
                PhysicalNestedLoopJoin<? extends Plan, ? extends Plan> nestedLoopJoin, PruneCtx ctx) {
            boolean leftAllowShuffleKeyPrune;
            if (nestedLoopJoin.getJoinType().isCrossJoin() || nestedLoopJoin.getJoinType().isInnerJoin()
                    || nestedLoopJoin.getJoinType().isLeftJoin()) {
                leftAllowShuffleKeyPrune = true;
            } else {
                leftAllowShuffleKeyPrune = false;
            }
            return rewriteChildren(nestedLoopJoin,
                    ImmutableList.of(ctx.withAllowShuffleKeyPrune(leftAllowShuffleKeyPrune),
                            ctx.withAllowShuffleKeyPrune(false)));
        }

        // Two-phase local partition-topN is followed by an exchange, so its input may prune;
        // other phases just inherit the current permission.
        @Override
        public Plan visitPhysicalPartitionTopN(PhysicalPartitionTopN<? extends Plan> partitionTopN, PruneCtx ctx) {
            if (partitionTopN.getPhase().isTwoPhaseLocal()) {
                return rewriteUnary(partitionTopN, ctx.withAllowShuffleKeyPrune(true));
            } else {
                return rewriteUnary(partitionTopN, ctx);
            }
        }

        // Recursive union: both the anchor and the recursive branch keep their distribution.
        @Override
        public Plan visitPhysicalRecursiveUnion(
                PhysicalRecursiveUnion<? extends Plan, ? extends Plan> recursiveUnion, PruneCtx ctx) {
            return rewriteChildren(recursiveUnion, ImmutableList.of(ctx.withAllowShuffleKeyPrune(false),
                    ctx.withAllowShuffleKeyPrune(false)));
        }

        @Override
        public Plan visitPhysicalCTEProducer(PhysicalCTEProducer<? extends Plan> producer, PruneCtx ctx) {
            return rewriteUnary(producer, ctx.withAllowShuffleKeyPrune(true));
        }

        @Override
        public Plan visitPhysicalGenerate(PhysicalGenerate<? extends Plan> generate, PruneCtx ctx) {
            return rewriteUnary(generate, ctx.withAllowShuffleKeyPrune(true));
        }

        @Override
        public Plan visitPhysicalBlackholeSink(PhysicalBlackholeSink<? extends Plan> sink, PruneCtx ctx) {
            return rewriteUnary(sink, ctx.withAllowShuffleKeyPrune(true));
        }

        @Override
        public Plan visitPhysicalRepeat(PhysicalRepeat<? extends Plan> repeat, PruneCtx ctx) {
            return rewriteUnary(repeat, ctx.withAllowShuffleKeyPrune(true));
        }

        // Table sinks: pruning below the sink is allowed when strict-consistency DML is off,
        // or when the sink does not require a specific input distribution (ANY).
        @Override
        public Plan visitPhysicalOlapTableSink(PhysicalOlapTableSink<? extends Plan> sink, PruneCtx ctx) {
            boolean childAllowShuffleKeyPrune;
            if (ctx.cascadesContext.getConnectContext() != null
                    && !ctx.cascadesContext.getConnectContext().getSessionVariable().enableStrictConsistencyDml) {
                childAllowShuffleKeyPrune = true;
            } else {
                childAllowShuffleKeyPrune = sink.getRequirePhysicalProperties().equals(PhysicalProperties.ANY);
            }
            return rewriteUnary(sink, ctx.withAllowShuffleKeyPrune(childAllowShuffleKeyPrune));
        }

        // Result sink: pruning below is only allowed when the result sink runs in parallel and
        // the query is not a short-circuit (point-query) plan.
        @Override
        public Plan visitPhysicalResultSink(PhysicalResultSink<? extends Plan> sink, PruneCtx ctx) {
            boolean childAllowShuffleKeyPrune = false;
            if (ctx.cascadesContext.getConnectContext() != null
                    && ctx.cascadesContext.getConnectContext().getSessionVariable().enableParallelResultSink()
                    && !ctx.cascadesContext.getStatementContext().isShortCircuitQuery()) {
                childAllowShuffleKeyPrune = true;
            }
            return rewriteUnary(sink, ctx.withAllowShuffleKeyPrune(childAllowShuffleKeyPrune));
        }

        // Same rule as the OLAP table sink, for Hive.
        @Override
        public Plan visitPhysicalHiveTableSink(PhysicalHiveTableSink<? extends Plan> hiveTableSink, PruneCtx ctx) {
            boolean childAllowShuffleKeyPrune;
            if (ctx.cascadesContext.getConnectContext() != null
                    && !ctx.cascadesContext.getConnectContext().getSessionVariable().enableStrictConsistencyDml) {
                childAllowShuffleKeyPrune = true;
            } else {
                childAllowShuffleKeyPrune =
                        hiveTableSink.getRequirePhysicalProperties().equals(PhysicalProperties.ANY);
            }
            return rewriteUnary(hiveTableSink, ctx.withAllowShuffleKeyPrune(childAllowShuffleKeyPrune));
        }

        // Same rule as the OLAP table sink, for Iceberg.
        @Override
        public Plan visitPhysicalIcebergTableSink(
                PhysicalIcebergTableSink<? extends Plan> icebergTableSink, PruneCtx ctx) {
            boolean childAllowShuffleKeyPrune;
            if (ctx.cascadesContext.getConnectContext() != null
                    && !ctx.cascadesContext.getConnectContext().getSessionVariable().enableStrictConsistencyDml) {
                childAllowShuffleKeyPrune = true;
            } else {
                childAllowShuffleKeyPrune =
                        icebergTableSink.getRequirePhysicalProperties().equals(PhysicalProperties.ANY);
            }
            return rewriteUnary(icebergTableSink, ctx.withAllowShuffleKeyPrune(childAllowShuffleKeyPrune));
        }

        // Same rule as the OLAP table sink, for MaxCompute.
        @Override
        public Plan visitPhysicalMaxComputeTableSink(
                PhysicalMaxComputeTableSink<? extends Plan> mcTableSink, PruneCtx ctx) {
            boolean childAllowShuffleKeyPrune;
            if (ctx.cascadesContext.getConnectContext() != null
                    && !ctx.cascadesContext.getConnectContext().getSessionVariable().enableStrictConsistencyDml) {
                childAllowShuffleKeyPrune = true;
            } else {
                childAllowShuffleKeyPrune = mcTableSink.getRequirePhysicalProperties().equals(
                        PhysicalProperties.ANY);
            }
            return rewriteUnary(mcTableSink, ctx.withAllowShuffleKeyPrune(childAllowShuffleKeyPrune));
        }

        @Override
        public Plan visitPhysicalJdbcTableSink(
                PhysicalJdbcTableSink<? extends Plan> jdbcTableSink, PruneCtx ctx) {
            return rewriteUnary(jdbcTableSink, ctx.withAllowShuffleKeyPrune(false));
        }

        @Override
        public Plan visitPhysicalTVFTableSink(
                PhysicalTVFTableSink<? extends Plan> tvfTableSink, PruneCtx ctx) {
            return rewriteUnary(tvfTableSink, ctx.withAllowShuffleKeyPrune(false));
        }

        // Dictionary sink: pruning allowed only when the sink has no distribution requirement.
        @Override
        public Plan visitPhysicalDictionarySink(PhysicalDictionarySink<? extends Plan> dictionarySink,
                PruneCtx ctx) {
            boolean childAllowShuffleKeyPrune = dictionarySink.getRequirePhysicalProperties()
                    .equals(PhysicalProperties.ANY);
            return rewriteUnary(dictionarySink, ctx.withAllowShuffleKeyPrune(childAllowShuffleKeyPrune));
        }

        @Override
        public Plan visitPhysicalDeferMaterializeResultSink(
                PhysicalDeferMaterializeResultSink<? extends Plan> sink,
                PruneCtx ctx) {
            return rewriteUnary(sink, ctx.withAllowShuffleKeyPrune(false));
        }

        /**
         * Recurses into the single child of a unary node. If the child is unchanged the node is
         * returned as-is; otherwise the node is rebuilt and the original node's statistics and
         * group id are copied onto it.
         */
        private <P extends PhysicalUnary<?>> P rewriteUnary(P plan, PruneCtx ctx) {
            Plan oldChild = plan.child();
            Plan newChild = oldChild.accept(this, ctx);
            if (newChild == oldChild) {
                return plan;
            }
            AbstractPhysicalPlan rewritten = (AbstractPhysicalPlan) plan.withChildren(ImmutableList.of(newChild));
            // Unchecked cast: withChildren is expected to return the same concrete node type.
            return (P) rewritten.copyStatsAndGroupIdFrom((AbstractPhysicalPlan) plan);
        }

        /**
         * Recurses into each child with its own per-child {@link PruneCtx} (the list is
         * positional and must match the node's arity). Rebuilds the node only when at least one
         * child changed, copying statistics and group id from the original node.
         */
        private <P extends Plan> P rewriteChildren(P plan, List<PruneCtx> context) {
            ImmutableList.Builder<Plan> newChildren = ImmutableList.builderWithExpectedSize(plan.arity());
            boolean hasNewChildren = false;
            for (int i = 0; i < plan.arity(); ++i) {
                Plan child = plan.child(i);
                Plan newChild = child.accept(this, context.get(i));
                if (newChild != child) {
                    hasNewChildren = true;
                }
                newChildren.add(newChild);
            }
            if (hasNewChildren) {
                P originPlan = plan;
                plan = (P) plan.withChildren(newChildren.build());
                plan = (P) ((AbstractPhysicalPlan) plan).copyStatsAndGroupIdFrom((AbstractPhysicalPlan) originPlan);
            }
            return plan;
        }
    }
private static Pair<Boolean, Boolean> deriveHashJoinChildAllowShuffleKeyPrune(
PhysicalHashJoin<? extends Plan, ? extends Plan> join,
boolean parentAllowShuffleKeyPrune) {
if (join.isBroadCastJoin()) {
return Pair.of(parentAllowShuffleKeyPrune, false);
}
if (join.shuffleType() == org.apache.doris.nereids.trees.plans.algebra.ShuffleType.shuffle) {
return Pair.of(false, false);
}
return Pair.of(false, false);
}
    /**
     * Attempts to prune the shuffle keys of the two exchanges feeding a shuffle hash join.
     * Both sides must hash-shuffle on the same ordered key lists; when a cheaper strict subset
     * of those keys still balances the data (as judged by
     * {@link ShuffleKeyPruneUtils#tryFindOptimalShuffleKeyForJoinWithDistributeColumns}), both
     * exchanges are rebuilt with the sliced key lists and spliced back under the join.
     *
     * @return the rewritten join, or empty when pruning is not applicable
     */
    private static Optional<PhysicalHashJoin<? extends Plan, ? extends Plan>> maybePruneShuffleJoin(
            PhysicalHashJoin<? extends Plan, ? extends Plan> join, CascadesContext cascadesContext) {
        // Mark join without hash conjuncts: no equi-keys drive the shuffle, skip.
        if (join.isMarkJoin() && join.getHashJoinConjuncts().isEmpty()) {
            return Optional.empty();
        }
        // Broadcast joins have no shuffle keys on either side to prune.
        if (join.isBroadCastJoin()) {
            return Optional.empty();
        }
        // Both children must expose a distribute node (possibly under project/filter/global agg).
        Optional<PhysicalDistribute<Plan>> leftDistOpt = findHashDistributeUnderJoinChild(join.left());
        Optional<PhysicalDistribute<Plan>> rightDistOpt = findHashDistributeUnderJoinChild(join.right());
        if (!leftDistOpt.isPresent() || !rightDistOpt.isPresent()) {
            return Optional.empty();
        }
        // A skew hint pins the shuffle layout; do not interfere with it.
        if (join.getDistributeHint().getSkewInfo() != null) {
            return Optional.empty();
        }
        PhysicalDistribute<Plan> leftDist = leftDistOpt.get();
        PhysicalDistribute<Plan> rightDist = rightDistOpt.get();
        // Only hash distributions carry prunable ordered shuffle columns.
        if (!(leftDist.getDistributionSpec() instanceof DistributionSpecHash)
                || !(rightDist.getDistributionSpec() instanceof DistributionSpecHash)) {
            return Optional.empty();
        }
        // Statistics are taken below the exchange (and below a LOCAL agg if present) so the
        // NDV/balance estimate reflects the rows actually being shuffled.
        Statistics leftStats = statisticsForShuffleKeyPruneBelowDistribute(leftDist);
        Statistics rightStats = statisticsForShuffleKeyPruneBelowDistribute(rightDist);
        DistributionSpecHash leftSpec = (DistributionSpecHash) leftDist.getDistributionSpec();
        DistributionSpecHash rightSpec = (DistributionSpecHash) rightDist.getDistributionSpec();
        Optional<Pair<List<ExprId>, List<ExprId>>> optimal =
                ShuffleKeyPruneUtils.tryFindOptimalShuffleKeyForJoinWithDistributeColumns(
                        cascadesContext.getConnectContext(),
                        leftDist.getOrderedShuffledSlots(), rightDist.getOrderedShuffledSlots(),
                        leftSpec.getOrderedShuffledColumns(), rightSpec.getOrderedShuffledColumns(),
                        leftStats, rightStats);
        if (!optimal.isPresent()) {
            return Optional.empty();
        }
        // Rebuild both exchanges with the pruned key lists (same plan ids) and splice them back
        // in place of the originals under each join child.
        Pair<List<ExprId>, List<ExprId>> keys = optimal.get();
        DistributionSpecHash newLeftSpec = sliceHashSpec(leftSpec, keys.first);
        DistributionSpecHash newRightSpec = sliceHashSpec(rightSpec, keys.second);
        Plan rebuiltLeftDist = rebuildDistribute(leftDist, newLeftSpec, leftDist.child());
        Plan rebuiltRightDist = rebuildDistribute(rightDist, newRightSpec, rightDist.child());
        Plan replacedLeft = replaceDistributeUnderJoinChild(join.left(), rebuiltLeftDist);
        Plan replacedRight = replaceDistributeUnderJoinChild(join.right(), rebuiltRightDist);
        PhysicalHashJoin<Plan, Plan> rewritten =
                asHashJoin(join.withChildren(ImmutableList.of(replacedLeft, replacedRight)));
        return Optional.of((PhysicalHashJoin<? extends Plan, ? extends Plan>) rewritten.copyStatsAndGroupIdFrom(join));
    }
/**
* Join child is either {@link PhysicalDistribute}, or {@link PhysicalHashAggregate} (GLOBAL) whose
* child is {@link PhysicalDistribute}, optionally wrapped by {@link PhysicalProject} and/or
* {@link PhysicalFilter}; otherwise empty.
*/
static Optional<PhysicalDistribute<Plan>> findHashDistributeUnderJoinChild(Plan joinChild) {
if (joinChild instanceof PhysicalDistribute) {
return Optional.of((PhysicalDistribute<Plan>) joinChild);
}
if (joinChild instanceof PhysicalProject || joinChild instanceof PhysicalFilter) {
return findHashDistributeUnderJoinChild(joinChild.child(0));
}
if (joinChild instanceof PhysicalHashAggregate) {
PhysicalHashAggregate<?> agg = (PhysicalHashAggregate<?>) joinChild;
if (agg.getAggPhase().isGlobal() && agg.child() instanceof PhysicalDistribute) {
return Optional.of((PhysicalDistribute<Plan>) agg.child());
}
}
return Optional.empty();
}
/**
* Replaces the target {@link PhysicalDistribute} under a join child with {@code newDistributeRoot}
* (typically from {@link #rebuildDistribute}). If the join child is the distribute itself, returns
* {@code newDistributeRoot}; if the join child is GLOBAL agg over that distribute, returns agg with
* updated child. {@link PhysicalProject} / {@link PhysicalFilter} wrappers are preserved.
*/
static Plan replaceDistributeUnderJoinChild(Plan joinChild, Plan newDistributeRoot) {
if (joinChild instanceof PhysicalDistribute) {
return newDistributeRoot;
}
if (joinChild instanceof PhysicalProject) {
PhysicalProject<?> project = (PhysicalProject<?>) joinChild;
Plan innerNew = replaceDistributeUnderJoinChild(project.child(), newDistributeRoot);
if (innerNew == project.child()) {
return project;
}
PhysicalProject<Plan> rewritten = (PhysicalProject<Plan>) project.withChildren(
ImmutableList.of(innerNew));
return rewritten.copyStatsAndGroupIdFrom((AbstractPhysicalPlan) project);
}
if (joinChild instanceof PhysicalFilter) {
PhysicalFilter<?> filter = (PhysicalFilter<?>) joinChild;
Plan innerNew = replaceDistributeUnderJoinChild(filter.child(), newDistributeRoot);
if (innerNew == filter.child()) {
return filter;
}
PhysicalFilter<Plan> rewritten = (PhysicalFilter<Plan>) filter.withChildren(
ImmutableList.of(innerNew));
return rewritten.copyStatsAndGroupIdFrom((AbstractPhysicalPlan) filter);
}
PhysicalHashAggregate<?> agg = (PhysicalHashAggregate<?>) joinChild;
PhysicalHashAggregate<Plan> rewritten = (PhysicalHashAggregate<Plan>) agg.withChildren(
ImmutableList.of(newDistributeRoot));
return rewritten.copyStatsAndGroupIdFrom((AbstractPhysicalPlan) agg);
}
/**
* Statistics below a shuffle {@link PhysicalDistribute}, aligned with {@link #tryPruneGlobalAgg} for
* column balance / NDV when the child is local agg or otherwise.
*/
private static Statistics statisticsForShuffleKeyPruneBelowDistribute(PhysicalDistribute<Plan> dist) {
if (dist.child() instanceof PhysicalHashAggregate) {
PhysicalHashAggregate<Plan> childAgg = (PhysicalHashAggregate<Plan>) dist.child();
if (childAgg.getAggPhase().isLocal()) {
return childAgg.child().getStats();
} else {
return childAgg.getStats();
}
} else {
return dist.child().getStats();
}
}
    /**
     * Attempts to prune the shuffle keys of the hash exchange directly below a GLOBAL hash
     * aggregate. A strict subset of the current shuffle columns is adopted only when
     * {@link ShuffleKeyPruneUtils#selectBestShuffleKeyForAgg} finds a cheaper key set; on any
     * precondition failure the aggregate is returned unchanged.
     */
    private static PhysicalHashAggregate<? extends Plan> tryPruneGlobalAgg(PhysicalHashAggregate<? extends Plan> agg,
            CascadesContext cascadesContext) {
        // The pruning target is a hash-shuffle exchange immediately under the agg.
        if (!(agg.child() instanceof PhysicalDistribute)) {
            return agg;
        }
        PhysicalDistribute<Plan> dist = (PhysicalDistribute<Plan>) agg.child();
        if (!(dist.getDistributionSpec() instanceof DistributionSpecHash)) {
            return agg;
        }
        // Grouping-set (repeat) aggregates are skipped.
        if (agg.hasSourceRepeat()) {
            return agg;
        }
        DistributionSpecHash hashSpec = (DistributionSpecHash) dist.getDistributionSpec();
        // Stats below the exchange (and below a LOCAL agg if present) drive the key selection.
        Statistics childStats = statisticsForShuffleKeyPruneBelowDistribute(dist);
        // Prefer explicit partition expressions; fall back to the group-by keys.
        List<Expression> evalExprs;
        if (agg.getPartitionExpressions().isPresent() && !agg.getPartitionExpressions().get().isEmpty()) {
            evalExprs = agg.getPartitionExpressions().get();
        } else {
            evalExprs = agg.getGroupByExpressions();
        }
        if (evalExprs.isEmpty()) {
            return agg;
        }
        // With fewer than two shuffle columns there is nothing to prune.
        List<ExprId> shuffleExprIds = hashSpec.getOrderedShuffledColumns();
        if (shuffleExprIds.size() < 2) {
            return agg;
        }
        // Resolve each shuffle column id to the agg's output slot; bail out if any id is not
        // visible in the output (the pruned key must be expressible over the agg's slots).
        List<Expression> shuffleExprs = new ArrayList<>(hashSpec.getOrderedShuffledColumns().size());
        Map<ExprId, Slot> exprIdSlotMap = new HashMap<>();
        for (Slot slot : agg.getOutput()) {
            exprIdSlotMap.put(slot.getExprId(), slot);
        }
        for (ExprId exprId : hashSpec.getOrderedShuffledColumns()) {
            if (!exprIdSlotMap.containsKey(exprId)) {
                return agg;
            }
            shuffleExprs.add(exprIdSlotMap.get(exprId));
        }
        Optional<List<Expression>> best = ShuffleKeyPruneUtils.selectBestShuffleKeyForAgg(agg, shuffleExprs,
                childStats, cascadesContext.getConnectContext());
        if (!best.isPresent()) {
            return agg;
        }
        List<ExprId> newIds = best.get().stream()
                .filter(SlotReference.class::isInstance)
                .map(SlotReference.class::cast)
                .map(SlotReference::getExprId)
                .collect(Collectors.toList());
        // Accept only if every chosen key is a plain slot reference and the key set strictly
        // shrank; otherwise the rewrite would be pointless or unrepresentable.
        if (newIds.size() != best.get().size() || newIds.size() >= hashSpec.getOrderedShuffledColumns().size()) {
            return agg;
        }
        // Rebuild the exchange with the pruned key list (same plan id) and re-parent the agg,
        // copying its statistics and group id.
        DistributionSpecHash newSpec = hashSpec.withShuffleExprs(newIds);
        Plan replaced = rebuildDistribute(dist, newSpec, dist.child());
        PhysicalHashAggregate<Plan> rewritten = asAgg(agg.withChildren(ImmutableList.of(replaced)));
        return (PhysicalHashAggregate<Plan>) rewritten.copyStatsAndGroupIdFrom(agg);
    }
    /**
     * Builds a copy of {@code origin} whose ordered shuffle columns are replaced by
     * {@code newOrderedKeys}, carrying over the shuffle type, table id, selected index id and
     * partition ids of the original spec.
     */
    private static DistributionSpecHash sliceHashSpec(DistributionSpecHash origin, List<ExprId> newOrderedKeys) {
        return new DistributionSpecHash(newOrderedKeys, origin.getShuffleType(),
                origin.getTableId(), origin.getSelectedIndexId(), origin.getPartitionIds());
    }
    /**
     * Rebuilds a {@link PhysicalDistribute} with a new hash spec and child while reusing the
     * original node's id (via {@link AbstractPlan#copyWithSameId}), group expression, logical
     * properties and statistics. The order spec of the original physical properties is layered
     * on top of the new hash distribution so ordering guarantees are not lost.
     */
    private static PhysicalDistribute<Plan> rebuildDistribute(PhysicalDistribute<Plan> origin,
            DistributionSpecHash newHashSpec, Plan newChild) {
        PhysicalProperties props = PhysicalProperties.createHash(newHashSpec)
                .withOrderSpec(origin.getPhysicalProperties().getOrderSpec());
        return AbstractPlan.copyWithSameId(origin,
                () -> new PhysicalDistribute<>(newHashSpec, origin.getGroupExpression(),
                        origin.getLogicalProperties(), props, origin.getStats(), newChild));
    }
    // Centralized unchecked casts: withChildren(...) is declared to return Plan, but callers
    // need the concrete physical type back. Assumes withChildren returns the same concrete node
    // class as its receiver — TODO(review): confirm against the Plan implementations.
    @SuppressWarnings("unchecked")
    private static PhysicalHashJoin<Plan, Plan> asHashJoin(Plan join) {
        return (PhysicalHashJoin<Plan, Plan>) join;
    }

    @SuppressWarnings("unchecked")
    private static PhysicalHashAggregate<Plan> asAgg(Plan agg) {
        return (PhysicalHashAggregate<Plan>) agg;
    }
}