EagerAggRewriter.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.eageraggregation;
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
import org.apache.doris.nereids.rules.rewrite.StatsDerive;
import org.apache.doris.nereids.stats.ExpressionEstimation;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.qe.SessionVariable;
import org.apache.doris.statistics.ColumnStatistic;
import org.apache.doris.statistics.Statistics;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
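/**
 * Eager aggregation: push a partial aggregate below a join so that fewer rows reach the join.
 * The join-condition slots of the chosen side are added to the pushed-down group-by keys.
 * <pre>
 * agg[sum(t1.A) group by t1.B]
 *   ->join(t1.C=t2.D)
 *       ->T1(A, B, C)
 *       ->T2(D)
 *
 * =>
 * agg[sum(x) group by t1.B]
 *   ->join(t1.C=t2.D)
 *       ->agg[sum(A) as x, group by B, C]
 *           ->T1(A, B, C)
 *       ->T2(D)
 * </pre>
 */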
public class EagerAggRewriter extends DefaultPlanRewriter<PushDownAggContext> {
private static final double LOWER_AGGREGATE_EFFECT_COEFFICIENT = 10000;
private static final double LOW_AGGREGATE_EFFECT_COEFFICIENT = 1000;
private static final double MEDIUM_AGGREGATE_EFFECT_COEFFICIENT = 100;
private final StatsDerive derive = new StatsDerive(false);
@Override
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, PushDownAggContext context) {
boolean toLeft = false;
boolean toRight = false;
boolean pushHere = false;
if (context.getAggFunctions().isEmpty()) {
// example: select x from T group by x
// if no agg function, try to push to large child
Statistics leftStats = join.left().getStats();
if (leftStats == null) {
leftStats = join.left().accept(derive, new StatsDerive.DeriveContext());
}
Statistics rightStats = join.right().getStats();
if (rightStats == null) {
rightStats = join.right().accept(derive, new StatsDerive.DeriveContext());
}
if (leftStats.getRowCount() > rightStats.getRowCount()) {
toLeft = true;
} else {
toRight = true;
}
} else {
for (AggregateFunction aggFunc : context.getAggFunctions()) {
if (join.left().getOutputSet().containsAll(aggFunc.getInputSlots())) {
toLeft = true;
} else if (join.right().getOutputSet().containsAll(aggFunc.getInputSlots())) {
toRight = true;
} else {
pushHere = true;
}
}
}
if (pushHere || (toLeft && toRight)) {
if (SessionVariable.isEagerAggregationOnJoin()) {
return genAggregate(join, context);
} else {
return join;
}
}
List<SlotReference> joinConditionSlots;
List<SlotReference> childGroupByKeys = new ArrayList<>();
if (toLeft) {
joinConditionSlots = getJoinConditionsInputSlotsFromOneSide(join, join.left());
for (SlotReference key : context.getGroupKeys()) {
if (join.left().getOutputSet().containsAll(key.getInputSlots())) {
childGroupByKeys.add(key);
}
}
} else {
joinConditionSlots = getJoinConditionsInputSlotsFromOneSide(join, join.right());
for (SlotReference key : context.getGroupKeys()) {
if (join.right().getOutputSet().containsAll(key.getInputSlots())) {
childGroupByKeys.add(key);
}
}
}
for (SlotReference slot : joinConditionSlots) {
if (!childGroupByKeys.contains(slot)) {
childGroupByKeys.add(slot);
}
}
PushDownAggContext childContext = context.withGroupKeys(childGroupByKeys);
Statistics stats = join.right().getStats();
if (stats == null) {
stats = join.right().accept(derive, new StatsDerive.DeriveContext());
}
if (stats.getRowCount() > PushDownAggContext.BIG_JOIN_BUILD_SIZE
|| SessionVariable.getEagerAggregationMode() > 0) {
childContext = childContext.passThroughBigJoin();
}
if (toLeft) {
Plan newLeft = join.left().accept(this, childContext);
if (newLeft != join.left()) {
return join.withChildren(newLeft, join.right());
}
} else {
Plan newRight = join.right().accept(this, childContext);
if (newRight != join.right()) {
return join.withChildren(join.left(), newRight);
}
}
return join;
}
private List<SlotReference> getJoinConditionsInputSlotsFromOneSide(
LogicalJoin<? extends Plan, ? extends Plan> join,
Plan side) {
List<SlotReference> oneSideSlots = new ArrayList<>();
for (Expression condition : join.getHashJoinConjuncts()) {
for (Slot slot : condition.getInputSlots()) {
if (side.getOutputSet().contains(slot)) {
oneSideSlots.add((SlotReference) slot);
}
}
}
for (Expression condition : join.getOtherJoinConjuncts()) {
for (Slot slot : condition.getInputSlots()) {
if (side.getOutputSet().contains(slot)) {
oneSideSlots.add((SlotReference) slot);
}
}
}
return oneSideSlots;
}
private PushDownAggContext createContextFromProject(
LogicalProject<? extends Plan> project,
PushDownAggContext context) {
/*
* context: sum(a) groupBy(y+z as x, l)
* proj: b+c as a, u+v as y, m+n as l
* newContext: sum(b+c), groupBy((u+v)+z as x, m+n as l)
*/
List<SlotReference> groupKeys = new ArrayList<>();
for (SlotReference key : context.getGroupKeys()) {
groupKeys.addAll(
project.pushDownExpressionPastProject(key).getInputSlots()
.stream().map(slot -> (SlotReference) slot).collect(Collectors.toList()));
}
List<AggregateFunction> aggFunctions = new ArrayList<>();
Map<AggregateFunction, Alias> aliasMap = new IdentityHashMap<>();
for (AggregateFunction aggFunc : context.getAggFunctions()) {
AggregateFunction newAggFunc = (AggregateFunction) project.pushDownExpressionPastProject(aggFunc);
Alias alias = context.getAliasMap().get(aggFunc);
aliasMap.put(newAggFunc, (Alias) alias.withChildren(newAggFunc));
aggFunctions.add(newAggFunc);
}
return new PushDownAggContext(aggFunctions, groupKeys, aliasMap,
context.getCascadesContext(), context.isPassThroughBigJoin());
}
private boolean canPushThroughProject(LogicalProject<? extends Plan> project, PushDownAggContext context) {
for (SlotReference slot : context.getGroupKeys()) {
if (!project.getOutputSet().contains(slot)) {
SessionVariable.throwRuntimeExceptionWhenFeDebug("eager agg failed: can not find group key("
+ slot + ") in " + project);
return false;
}
}
for (Slot slot : context.getAggFunctionsInputSlots()) {
if (!project.getOutputSet().contains(slot)) {
SessionVariable.throwRuntimeExceptionWhenFeDebug("eager agg failed: can not find aggFunc slot("
+ slot + ") in " + project);
return false;
}
}
// push sum(A) through project(x, x+y as A)
// if x is not used as group key, do not push through
for (Slot slot : context.getAggFunctionsInputSlots()) {
for (NamedExpression prj : project.getProjects()) {
if (prj instanceof Alias && prj.getExprId().equals(slot.getExprId())) {
if (prj.getInputSlots().stream()
.anyMatch(
s -> project.getOutputSet().contains(s)
&& !context.getGroupKeys().contains(s))) {
return false;
}
}
}
}
return true;
}
private Plan alignUnionChildrenDataType(Plan child, PushDownAggContext context) {
int outputSize = child.getOutput().size();
List<DataType> outputDataType = Lists.newArrayListWithExpectedSize(outputSize);
outputDataType.addAll(context.getAggFunctions().stream()
.map(func -> context.getAliasMap().get(func).getDataType()).collect(Collectors.toList()));
outputDataType.addAll(context.getGroupKeys().stream().map(s -> s.getDataType()).collect(Collectors.toList()));
List<NamedExpression> projection = Lists.newArrayListWithExpectedSize(outputSize);
boolean needProject = false;
for (int colIdx = 0; colIdx < outputSize; colIdx++) {
SlotReference slot = (SlotReference) child.getOutput().get(colIdx);
if (!slot.getDataType().equals(outputDataType.get(colIdx))) {
projection.add(new Alias(new Cast(slot, outputDataType.get(colIdx))));
needProject = true;
} else {
projection.add(slot);
}
}
if (needProject) {
return new LogicalProject<Plan>(projection, child);
} else {
return child;
}
}
@Override
public Plan visitLogicalUnion(LogicalUnion union, PushDownAggContext context) {
if (!union.getConstantExprsList().isEmpty()) {
return union;
}
if (!union.getOutputs().stream().allMatch(e -> e instanceof SlotReference)) {
return union;
}
List<Plan> newChildren = Lists.newArrayList();
List<PushDownAggContext> childrenContext = new ArrayList<>();
boolean changed = false;
for (int idx = 0; idx < union.children().size(); idx++) {
Plan child = union.children().get(idx);
final int childIdx = idx;
List<AggregateFunction> aggFunctionsForChild = new ArrayList<>();
IdentityHashMap<AggregateFunction, Alias> aliasMapForChild = new IdentityHashMap<>();
for (AggregateFunction func : context.getAggFunctions()) {
AggregateFunction newFunc = (AggregateFunction) union.pushDownExpressionPastSetOperator(func, childIdx);
aggFunctionsForChild.add(newFunc);
Alias alias = context.getAliasMap().get(func);
// aliasForChild should have its own ExprId
Alias aliasForChild = new Alias(newFunc, alias.getName(), alias.getQualifier());
aliasMapForChild.put(newFunc, aliasForChild);
}
List<SlotReference> groupKeysForChild = context.getGroupKeys().stream()
.map(slot -> (SlotReference) union.pushDownExpressionPastSetOperator(slot, childIdx))
.collect(Collectors.toList());
PushDownAggContext contextForChild = new PushDownAggContext(aggFunctionsForChild, groupKeysForChild,
aliasMapForChild, context.getCascadesContext(), context.isPassThroughBigJoin());
childrenContext.add(contextForChild);
Plan newChild = child.accept(this, contextForChild);
if (newChild != child) {
changed = true;
}
// all children need align data type, even if it is not rewritten
newChild = alignUnionChildrenDataType(newChild, context);
newChildren.add(newChild);
}
if (changed) {
List<List<SlotReference>> newRegularChildrenOutputs = Lists.newArrayListWithExpectedSize(union.arity());
for (int childIdx = 0; childIdx < union.arity(); childIdx++) {
newRegularChildrenOutputs.add(
newChildren.get(childIdx).getOutput().stream()
.map(s -> (SlotReference) s).collect(Collectors.toList()));
}
List<NamedExpression> newOutput = Lists.newArrayList();
for (AggregateFunction func : context.getAggFunctions()) {
Alias alias = context.getAliasMap().get(func);
if (alias == null) {
SessionVariable.throwRuntimeExceptionWhenFeDebug("push down agg failed. union: " + union
+ " context: " + context);
return union;
}
newOutput.add(alias.toSlot());
}
newOutput.addAll(context.getGroupKeys());
LogicalUnion newUnion = (LogicalUnion) union
.withChildrenAndOutputs(newChildren, newOutput, newRegularChildrenOutputs);
return newUnion;
} else {
return union;
}
}
@Override
public Plan visitLogicalProject(LogicalProject<? extends Plan> project, PushDownAggContext context) {
if (project.child() instanceof LogicalCatalogRelation
|| (project.child() instanceof LogicalFilter
&& project.child().child(0) instanceof LogicalCatalogRelation)) {
// project
// --> scan
// =>
// aggregate
// --> project
// --> scan
return genAggregate(project, context);
}
if (!canPushThroughProject(project, context)) {
return genAggregate(project, context);
}
PushDownAggContext newContext = createContextFromProject(project, context);
Plan newChild = project.child().accept(this, newContext);
if (newChild != project.child()) {
/*
* agg[sum(a), groupBy(b)]
* -> proj(a, b1+b2 as b)
* -> join(c = d)
* -> any(a, b1, b2, c,...)
* -> any(d, ...)
* =>
* agg[sum(x), groupBy(b)]
* -> proj(x, b1+b2 as b)
* -> join(c=d)
* ->agg[sum(a) as x, groupBy(b1, b2, c)]
* ->proj(a, b1, b2, c, ...)
* -> any(a, b1, b2, c)
* -> any(d, ...)
*/
List<NamedExpression> newProjections = new ArrayList<>();
for (Alias alias : context.getAliasMap().values()) {
newProjections.add(alias.toSlot());
}
for (SlotReference slot : context.getGroupKeys()) {
boolean valid = false;
for (NamedExpression ne : project.getProjects()) {
if (ne.toSlot().getExprId().equals(slot.getExprId())) {
valid = true;
newProjections.add(ne);
break;
}
}
if (!valid) {
SessionVariable.throwRuntimeExceptionWhenFeDebug(
"push agg failed. slot: " + "not found in " + project);
return project;
}
}
LogicalProject result = new LogicalProject(newProjections, newChild);
return result;
}
return project;
}
@Override
public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> agg, PushDownAggContext context) {
return agg;
}
@Override
public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, PushDownAggContext context) {
return genAggregate(filter, context);
}
@Override
public Plan visitLogicalRelation(LogicalRelation relation, PushDownAggContext context) {
return genAggregate(relation, context);
}
private Plan genAggregate(Plan child, PushDownAggContext context) {
if (checkStats(child, context)) {
List<NamedExpression> aggOutputExpressions = new ArrayList<>();
for (AggregateFunction func : context.getAggFunctions()) {
aggOutputExpressions.add(context.getAliasMap().get(func));
}
aggOutputExpressions.addAll(context.getGroupKeys());
LogicalAggregate genAgg = new LogicalAggregate(context.getGroupKeys(), aggOutputExpressions, child);
NormalizeAggregate normalizeAggregate = new NormalizeAggregate();
return normalizeAggregate.normalizeAgg(genAgg, Optional.empty(),
context.getCascadesContext());
} else {
return child;
}
}
private boolean checkStats(Plan plan, PushDownAggContext context) {
int mode = SessionVariable.getEagerAggregationMode();
if (mode < 0) {
return false;
}
if (mode > 0) {
// when mode=1, any join is regarded as big join in order to
// push down aggregation through at least one join
return context.isPassThroughBigJoin();
}
if (!context.isPassThroughBigJoin()) {
return false;
}
Statistics stats = plan.getStats();
if (stats == null) {
stats = plan.accept(derive, new StatsDerive.DeriveContext());
}
if (stats.getRowCount() == 0) {
return false;
}
List<ColumnStatistic> groupKeysStats = new ArrayList<>();
List<ColumnStatistic> lower = Lists.newArrayList();
List<ColumnStatistic> medium = Lists.newArrayList();
List<ColumnStatistic> high = Lists.newArrayList();
List<ColumnStatistic>[] cards = new List[] { lower, medium, high };
for (NamedExpression key : context.getGroupKeys()) {
ColumnStatistic colStats = ExpressionEstimation.INSTANCE.estimate(key, stats);
if (colStats.isUnKnown) {
return false;
}
groupKeysStats.add(colStats);
cards[groupByCardinality(colStats, stats.getRowCount())].add(colStats);
}
double lowerCartesian = 1.0;
for (ColumnStatistic colStats : lower) {
lowerCartesian = lowerCartesian * colStats.ndv;
}
// pow(row_count/20, a half of lower column size)
double lowerUpper = Math.max(stats.getRowCount() / 20, 1);
lowerUpper = Math.pow(lowerUpper, Math.max(lower.size() / 2, 1));
if (high.isEmpty() && (lower.size() + medium.size()) <= 2) {
return true;
}
if (high.isEmpty() && medium.isEmpty()) {
if (lower.size() == 1 && lowerCartesian * 20 <= stats.getRowCount()) {
return true;
} else if (lower.size() == 2 && lowerCartesian * 7 <= stats.getRowCount()) {
return true;
} else if (lower.size() <= 3 && lowerCartesian * 20 <= stats.getRowCount() && lowerCartesian < lowerUpper) {
return true;
} else {
return false;
}
}
if (high.size() >= 2 || medium.size() > 2 || (high.size() == 1 && !medium.isEmpty())) {
return false;
}
// 3. Extremely low cardinality for lower with at most one medium or high.
double lowerCartesianLowerBound = stats.getRowCount() / LOWER_AGGREGATE_EFFECT_COEFFICIENT;
if (high.size() + medium.size() == 1 && lower.size() <= 2 && lowerCartesian <= lowerCartesianLowerBound) {
return true;
}
return false;
}
// high(2): row_count / cardinality < MEDIUM_AGGREGATE_EFFECT_COEFFICIENT
// medium(1): row_count / cardinality >= MEDIUM_AGGREGATE_EFFECT_COEFFICIENT and
// < LOW_AGGREGATE_EFFECT_COEFFICIENT
// lower(0): row_count / cardinality >= LOW_AGGREGATE_EFFECT_COEFFICIENT
private int groupByCardinality(ColumnStatistic colStats, double rowCount) {
if (rowCount == 0 || colStats.ndv * MEDIUM_AGGREGATE_EFFECT_COEFFICIENT > rowCount) {
return 2;
} else if (colStats.ndv * MEDIUM_AGGREGATE_EFFECT_COEFFICIENT <= rowCount
&& colStats.ndv * LOW_AGGREGATE_EFFECT_COEFFICIENT > rowCount) {
return 1;
} else if (colStats.ndv * LOW_AGGREGATE_EFFECT_COEFFICIENT <= rowCount) {
return 0;
}
return 2;
}
}