// EliminateGroupByKey.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.properties.DataTrait;
import org.apache.doris.nereids.properties.FuncDeps;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
/**
 * Eliminate group by key based on fd item information.
 * such as:
 * for a -> b, we can get:
 * group by a, b, c  => group by a, c
 *
 * <p>When a group-by key is FD-redundant but still needed in the output,
 * it is dropped from the group-by list and its output is replaced by
 * {@code any_value(key)} under a fresh ExprId. The old-to-new ExprId mapping is
 * accumulated in the rewriter context so that every node visited afterwards
 * (i.e. every ancestor) has its references rewritten via {@link ExprIdRewriter}.
 *
 * <p>Only aggregates sitting directly under a project, or under a filter that
 * sits directly under a project, are transformed; other aggregates only get
 * their expressions rewritten by the generic {@link #visit(Plan, Map)}.
 */
public class EliminateGroupByKey extends DefaultPlanRewriter<Map<ExprId, ExprId>> implements CustomRewriter {
    // Created per rewriteRoot() call; shares the same map instance that is passed around as visitor context.
    private ExprIdRewriter exprIdReplacer;

    /**
     * Entry point of the rule.
     *
     * @param plan root of the plan tree to rewrite
     * @param jobContext context handed to the {@link ExprIdRewriter}
     * @return the rewritten plan, or {@code plan} itself when the tree contains no aggregate
     */
    @Override
    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
        if (!plan.containsType(Aggregate.class)) {
            return plan;
        }
        Map<ExprId, ExprId> replaceMap = new HashMap<>();
        // The rule holds a reference to replaceMap, so mappings added during traversal are visible to it.
        ExprIdRewriter.ReplaceRule replaceRule = new ExprIdRewriter.ReplaceRule(replaceMap, false);
        exprIdReplacer = new ExprIdRewriter(replaceRule, jobContext);
        return plan.accept(this, replaceMap);
    }

    /**
     * Default handling for every plan node: rewrite the children first, then rewrite this
     * node's own expressions with the mappings collected so far.
     */
    @Override
    public Plan visit(Plan plan, Map<ExprId, ExprId> replaceMap) {
        plan = visitChildren(this, plan, replaceMap);
        return exprIdReplacer.rewriteExpr(plan, replaceMap);
    }

    /**
     * Try to eliminate FD-redundant group-by keys of an aggregate located at
     * {@code Project -> Aggregate} or {@code Project -> Filter -> Aggregate}.
     *
     * @param proj the project node being visited
     * @param replaceMap accumulated old-to-new ExprId mapping; extended when keys are wrapped with any_value
     * @return the project with rewritten expressions, on top of the (possibly) simplified aggregate
     */
    @Override
    public Plan visitLogicalProject(LogicalProject<? extends Plan> proj, Map<ExprId, ExprId> replaceMap) {
        proj = visitChildren(this, proj, replaceMap);
        // The descendants have already been rewritten by visit(), so the aggregate below may
        // expose ExprIds introduced by a deeper elimination. Bring the project's own expressions
        // up to date BEFORE deriving requireOutput; otherwise a slot the project still needs
        // would be compared under its stale ExprId and be treated as not required.
        Plan rewrittenProj = exprIdReplacer.rewriteExpr(proj, replaceMap);

        // Find the Aggregate child, possibly through a Filter
        Plan child = rewrittenProj.child(0);
        LogicalAggregate<? extends Plan> agg;
        boolean hasFilter = child instanceof LogicalFilter;
        if (hasFilter && child.child(0) instanceof LogicalAggregate) {
            agg = (LogicalAggregate<? extends Plan>) child.child(0);
        } else if (child instanceof LogicalAggregate) {
            agg = (LogicalAggregate<? extends Plan>) child;
        } else {
            return rewrittenProj;
        }

        // Don't transform if source repeat is present
        if (agg.getSourceRepeat().isPresent()) {
            return rewrittenProj;
        }

        // Compute requireOutput: slots needed by the Project (and Filter, if present)
        Set<Slot> requireOutput = new HashSet<>(rewrittenProj.getInputSlots());
        if (hasFilter) {
            requireOutput.addAll(child.getInputSlots());
        }

        // Transform the aggregate
        EliminateResult result = eliminateGroupByKeyWithMap(agg, requireOutput);
        if (!result.changed) {
            return rewrittenProj;
        }

        // Merge into the global replaceMap so that all ancestor nodes get rewritten
        replaceMap.putAll(result.replaceMap);

        // Rebuild the child chain with the new aggregate,
        // and rewrite the Filter (if present) and Project expressions
        Plan newChild;
        if (hasFilter) {
            Plan updatedFilter = child.withChildren(result.newAgg);
            newChild = exprIdReplacer.rewriteExpr(updatedFilter, replaceMap);
        } else {
            newChild = result.newAgg;
        }
        return exprIdReplacer.rewriteExpr(rewrittenProj.withChildren(newChild), replaceMap);
    }

    /** Result of eliminateGroupByKey: the new aggregate and a map of old->new ExprIds. */
    private static class EliminateResult {
        final LogicalAggregate<Plan> newAgg;
        // Only contains entries for keys that were wrapped with any_value.
        final Map<ExprId, ExprId> replaceMap;
        // True when at least one group-by key was removed or wrapped.
        final boolean changed;

        EliminateResult(LogicalAggregate<Plan> newAgg, Map<ExprId, ExprId> replaceMap, boolean changed) {
            this.newAgg = newAgg;
            this.replaceMap = replaceMap;
            this.changed = changed;
        }
    }

    /**
     * Build a new aggregate without the FD-redundant group-by keys.
     *
     * <p>Keys that are redundant and not required are removed from both the group-by list and
     * the output. Keys that are redundant but required stay in the output as
     * {@code any_value(key)} with a new ExprId.
     *
     * @param agg the aggregate to simplify
     * @param requireOutput slots that the nodes above the aggregate read
     * @return the rebuilt aggregate, the ExprId mapping for wrapped keys, and a changed flag;
     *         the aggregate is rebuilt even when nothing changed, so callers should check the flag
     */
    EliminateResult eliminateGroupByKeyWithMap(LogicalAggregate<? extends Plan> agg, Set<Slot> requireOutput) {
        FindResult result = findCanBeRemovedExpressionsInternal(agg, requireOutput,
                agg.child().getLogicalProperties().getTrait());
        Set<Expression> removeExpression = result.removeExpression;
        Set<Expression> wrapWithAnyValue = result.wrapWithAnyValue;

        // Keep the original order of the surviving group-by keys.
        List<Expression> newGroupExpression = new ArrayList<>();
        for (Expression expression : agg.getGroupByExpressions()) {
            if (!removeExpression.contains(expression)
                    && !wrapWithAnyValue.contains(expression)) {
                newGroupExpression.add(expression);
            }
        }

        List<NamedExpression> newOutput = new ArrayList<>();
        Map<ExprId, ExprId> replaceMap = new HashMap<>();
        boolean changed = !removeExpression.isEmpty() || !wrapWithAnyValue.isEmpty();
        for (NamedExpression expression : agg.getOutputExpressions()) {
            if (removeExpression.contains(expression)) {
                continue;
            }
            if (wrapWithAnyValue.contains(expression)) {
                // expression is FD-redundant but needed in output: wrap with any_value
                // Use fresh ExprId (auto-generated by Alias) to avoid ExprId collision,
                // and record the mapping for rewriting upper plan references.
                Alias newAlias = new Alias(new AnyValue(expression.toSlot()), expression.getName());
                replaceMap.put(expression.getExprId(), newAlias.getExprId());
                expression = newAlias;
            }
            newOutput.add(expression);
        }
        return new EliminateResult(agg.withGroupByAndOutput(newGroupExpression, newOutput), replaceMap, changed);
    }

    /**
     * Return expressions that can be completely removed from both group-by and output.
     * Kept for backward compatibility with external callers (e.g. PushDownAggThroughJoinOnPkFk).
     *
     * @param agg the aggregate whose group-by keys are inspected
     * @param requireOutput slots that must stay available above the aggregate
     * @param dataTrait data trait used to look up functional dependencies
     * @return a new mutable set of group-by expressions that are redundant and not required;
     *         keys that would need any_value wrapping are not included
     */
    public static Set<Expression> findCanBeRemovedExpressions(LogicalAggregate<? extends Plan> agg,
            Set<Slot> requireOutput, DataTrait dataTrait) {
        FindResult result = findCanBeRemovedExpressionsInternal(agg, requireOutput, dataTrait);
        return new HashSet<>(result.removeExpression);
    }

    /** Result of findCanBeRemovedExpressionsInternal: two sets of expressions. */
    private static class FindResult {
        final Set<Expression> removeExpression;   // remove from group-by and output
        final Set<Expression> wrapWithAnyValue;   // remove from group-by, wrap with ANY_VALUE in output

        FindResult(Set<Expression> removeExpression, Set<Expression> wrapWithAnyValue) {
            this.removeExpression = removeExpression;
            this.wrapWithAnyValue = wrapWithAnyValue;
        }
    }

    /**
     * Classify the FD-redundant group-by keys of {@code agg}.
     *
     * <p>A key is treated as redundant when its input-slot set is absent from the result of
     * {@code FuncDeps.eliminateDeps}. Redundant keys whose slots are all contained in
     * {@code requireOutput} go to {@code wrapWithAnyValue}; the others go to
     * {@code removeExpression}.
     *
     * @return both sets; both are empty when no functional dependency applies to the group-by slots
     */
    private static FindResult findCanBeRemovedExpressionsInternal(LogicalAggregate<? extends Plan> agg,
            Set<Slot> requireOutput, DataTrait dataTrait) {
        Map<Expression, Set<Slot>> groupBySlots = new HashMap<>();
        Set<Slot> validSlots = new HashSet<>();
        for (Expression expression : agg.getGroupByExpressions()) {
            groupBySlots.put(expression, expression.getInputSlots());
            validSlots.addAll(expression.getInputSlots());
        }

        FuncDeps funcDeps = dataTrait.getAllValidFuncDeps(validSlots);
        if (funcDeps.isEmpty()) {
            return new FindResult(new HashSet<>(), new HashSet<>());
        }

        // NOTE(review): eliminateDeps presumably returns the slot sets that must remain as
        // group-by keys after applying the dependencies -- confirm against FuncDeps.
        Set<Set<Slot>> minGroupBySlots = funcDeps.eliminateDeps(
                new HashSet<>(groupBySlots.values()), requireOutput);
        Set<Expression> removeExpression = new HashSet<>();
        Set<Expression> wrapWithAnyValue = new HashSet<>();
        for (Entry<Expression, Set<Slot>> entry : groupBySlots.entrySet()) {
            if (!minGroupBySlots.contains(entry.getValue())) {
                // FD redundant: can remove from group-by
                if (!requireOutput.containsAll(entry.getValue())) {
                    // Not needed in output either: remove completely
                    removeExpression.add(entry.getKey());
                } else {
                    // Still needed in output: remove from group-by, wrap with ANY_VALUE in output
                    wrapWithAnyValue.add(entry.getKey());
                }
            }
        }
        return new FindResult(removeExpression, wrapWithAnyValue);
    }
}