MergeAggregate.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.annotation.DependsRules;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
/**MergeAggregate*/
@DependsRules({
NormalizeAggregate.class
})
public class MergeAggregate implements RewriteRuleFactory {
private static final ImmutableSet<String> ALLOW_MERGE_AGGREGATE_FUNCTIONS =
ImmutableSet.of("min", "max", "sum", "any_value");
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
logicalAggregate(logicalAggregate()).when(this::canMergeAggregateWithoutProject)
.then(this::mergeTwoAggregate)
.toRule(RuleType.MERGE_AGGREGATE),
logicalAggregate(logicalProject(logicalAggregate()))
.when(this::canMergeAggregateWithProject)
.then(this::mergeAggProjectAgg)
.toRule(RuleType.MERGE_AGGREGATE));
}
/**
* before:
* LogicalAggregate
* +--LogicalAggregate
* after:
* LogicalAggregate
*/
private Plan mergeTwoAggregate(LogicalAggregate<LogicalAggregate<Plan>> outerAgg) {
LogicalAggregate<Plan> innerAgg = outerAgg.child();
Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc = getInnerAggExprIdToAggFuncMap(innerAgg);
List<NamedExpression> newOutputExpressions = outerAgg.getOutputExpressions().stream()
.map(e -> rewriteAggregateFunction(e, innerAggExprIdToAggFunc))
.collect(Collectors.toList());
return outerAgg.withAggOutput(newOutputExpressions).withChildren(innerAgg.children());
}
/**
* before:
* LogicalAggregate (outputExpressions = [col2, sum(col1)], groupByKeys = [col2])
* +--LogicalProject (projects = [a as col2, col1])
* +--LogicalAggregate (outputExpressions = [a, b, sum(c) as col1], groupByKeys = [a,b])
* after:
* LogicalProject (projects = [a as col2, sum(col1) as sum(col1)]
* +--LogicalAggregate (outputExpression = [a, sum(c) as sum(col1)], groupByKeys = [a])
*/
private Plan mergeAggProjectAgg(LogicalAggregate<LogicalProject<LogicalAggregate<Plan>>> outerAgg) {
LogicalProject<LogicalAggregate<Plan>> project = outerAgg.child();
LogicalAggregate<Plan> innerAgg = project.child();
List<NamedExpression> outputExpressions = outerAgg.getOutputExpressions();
List<NamedExpression> replacedOutputExpressions = PlanUtils.replaceExpressionByProjections(
project.getProjects(), (List) outputExpressions);
Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc = getInnerAggExprIdToAggFuncMap(innerAgg);
// rewrite agg function. e.g. max(max)
List<NamedExpression> replacedAggFunc = replacedOutputExpressions.stream()
.filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
.map(e -> rewriteAggregateFunction(e, innerAggExprIdToAggFunc))
.collect(Collectors.toList());
// replace groupByKeys directly refer to the slot below the project
List<Expression> replacedGroupBy = PlanUtils.replaceExpressionByProjections(project.getProjects(),
outerAgg.getGroupByExpressions());
List<NamedExpression> newOutputExpressions = ImmutableList.<NamedExpression>builder()
.addAll(replacedGroupBy.stream().map(slot -> (NamedExpression) slot).iterator())
.addAll(replacedAggFunc.stream().map(alias -> (NamedExpression) alias).iterator()).build();
// construct agg
LogicalAggregate<Plan> resAgg = outerAgg.withGroupByAndOutput(replacedGroupBy, newOutputExpressions)
.withChildren(innerAgg.children());
// construct upper project
Map<ExprId, NamedExpression> exprIdToNameExpressionMap = new HashMap<>();
for (NamedExpression pro : project.getProjects()) {
exprIdToNameExpressionMap.put(pro.getExprId(), pro);
}
List<Expression> originOuterAggGroupBy = outerAgg.getGroupByExpressions();
List<Expression> projectGroupBy = new ArrayList<>();
for (Expression expression : originOuterAggGroupBy) {
ExprId exprId = ((NamedExpression) expression).getExprId();
NamedExpression namedExpression = exprIdToNameExpressionMap.get(exprId);
projectGroupBy.add(namedExpression);
}
List<NamedExpression> upperProjects = ImmutableList.<NamedExpression>builder()
.addAll(projectGroupBy.stream().map(namedExpr -> (NamedExpression) namedExpr).iterator())
.addAll(replacedAggFunc.stream().map(expr -> ((NamedExpression) expr).toSlot()).iterator())
.build();
return new LogicalProject<Plan>(upperProjects, resAgg);
}
private NamedExpression rewriteAggregateFunction(NamedExpression e,
Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc) {
return (NamedExpression) e.rewriteDownShortCircuit(expr -> {
if (expr instanceof Alias && ((Alias) expr).child() instanceof AggregateFunction) {
Alias alias = (Alias) expr;
AggregateFunction aggFunc = (AggregateFunction) alias.child();
ExprId childExprId = ((SlotReference) aggFunc.child(0)).getExprId();
if (innerAggExprIdToAggFunc.containsKey(childExprId)) {
return new Alias(alias.getExprId(), innerAggExprIdToAggFunc.get(childExprId),
alias.getName());
} else {
return expr;
}
} else {
return expr;
}
});
}
private boolean commonCheck(LogicalAggregate<? extends Plan> outerAgg, LogicalAggregate<Plan> innerAgg,
boolean sameGroupBy, Optional<LogicalProject> projectOptional) {
Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc = getInnerAggExprIdToAggFuncMap(innerAgg);
Set<AggregateFunction> aggregateFunctions = outerAgg.getAggregateFunctions();
List<AggregateFunction> replacedAggFunctions = projectOptional.map(project ->
(List<AggregateFunction>) PlanUtils.replaceExpressionByProjections(
projectOptional.get().getProjects(), new ArrayList<>(aggregateFunctions)))
.orElse(new ArrayList<>(aggregateFunctions));
for (AggregateFunction outerFunc : replacedAggFunctions) {
if (!(ALLOW_MERGE_AGGREGATE_FUNCTIONS.contains(outerFunc.getName()))) {
return false;
}
if (outerFunc.isDistinct() && !sameGroupBy) {
return false;
}
// not support outerAggFunc: sum(a+1),sum(a+b)
if (!(outerFunc.child(0) instanceof SlotReference)) {
return false;
}
ExprId childExprId = ((SlotReference) outerFunc.child(0)).getExprId();
if (innerAggExprIdToAggFunc.containsKey(childExprId)) {
AggregateFunction innerFunc = innerAggExprIdToAggFunc.get(childExprId);
if (innerFunc.isDistinct() && !sameGroupBy) {
return false;
}
// support sum(sum),min(min),max(max),any_value(any_value),sum(count)
// sum(count) -> count() need outerAgg having group by keys (reason: nullable)
if (!(outerFunc.getName().equals("sum") && innerFunc.getName().equals("count")
&& !outerAgg.getGroupByExpressions().isEmpty())
&& !innerFunc.getName().equals(outerFunc.getName())) {
return false;
}
} else {
// select a, max(b), min(b), any_value(b) from (select a,b from t1 group by a, b) group by a;
// equals select a, max(b), min(b), any_value(b) from t1 group by a;
if (!outerFunc.getName().equals("max")
&& !outerFunc.getName().equals("min")
&& !outerFunc.getName().equals("any_value")) {
return false;
}
}
}
return true;
}
private boolean canMergeAggregateWithoutProject(LogicalAggregate<LogicalAggregate<Plan>> outerAgg) {
LogicalAggregate<Plan> innerAgg = outerAgg.child();
if (!new HashSet<>(innerAgg.getGroupByExpressions()).containsAll(outerAgg.getGroupByExpressions())) {
return false;
}
boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size());
return commonCheck(outerAgg, innerAgg, sameGroupBy, Optional.empty());
}
private boolean canMergeAggregateWithProject(LogicalAggregate<LogicalProject<LogicalAggregate<Plan>>> outerAgg) {
LogicalProject<LogicalAggregate<Plan>> project = outerAgg.child();
LogicalAggregate<Plan> innerAgg = project.child();
List<Expression> outerAggGroupByKeys = PlanUtils.replaceExpressionByProjections(project.getProjects(),
outerAgg.getGroupByExpressions());
if (!new HashSet<>(innerAgg.getGroupByExpressions()).containsAll(outerAggGroupByKeys)) {
return false;
}
// project cannot have expressions like a+1
if (ExpressionUtils.deapAnyMatch(project.getProjects(),
expr -> !(expr instanceof SlotReference) && !(expr instanceof Alias))) {
return false;
}
boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size());
return commonCheck(outerAgg, innerAgg, sameGroupBy, Optional.of(project));
}
private Map<ExprId, AggregateFunction> getInnerAggExprIdToAggFuncMap(LogicalAggregate<Plan> innerAgg) {
return innerAgg.getOutputExpressions().stream()
.filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
.collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0),
(existValue, newValue) -> existValue));
}
}