SumLiteralRewrite.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import com.google.common.collect.ImmutableList;
import org.apache.thrift.annotation.Nullable;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
/**
* sum(expr +/- literal) ==> sum(expr) +/- literal * count(expr)
*/
public class SumLiteralRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalAggregate()
.whenNot(agg -> agg.getSourceRepeat().isPresent())
.then(agg -> {
Map<NamedExpression, Pair<SumInfo, Literal>> sumLiteralMap = new HashMap<>();
for (NamedExpression namedExpression : agg.getOutputs()) {
Pair<NamedExpression, Pair<SumInfo, Literal>> pel = extractSumLiteral(namedExpression);
if (pel == null) {
continue;
}
sumLiteralMap.put(pel.first, pel.second);
}
Map<NamedExpression, Pair<SumInfo, Literal>> validSumLiteralMap =
removeOneSumLiteral(sumLiteralMap);
if (validSumLiteralMap.isEmpty()) {
return null;
}
return rewriteSumLiteral(agg, validSumLiteralMap);
}).toRule(RuleType.SUM_LITERAL_REWRITE);
}
// when there only one sum literal like select count(id1 + 1), count(id2 + 1) from t, we don't rewrite them.
private Map<NamedExpression, Pair<SumInfo, Literal>> removeOneSumLiteral(
Map<NamedExpression, Pair<SumInfo, Literal>> sumLiteralMap) {
Map<Expression, Integer> countSum = new HashMap<>();
for (Entry<NamedExpression, Pair<SumInfo, Literal>> e : sumLiteralMap.entrySet()) {
Expression expr = e.getValue().first.expr;
countSum.merge(expr, 1, Integer::sum);
}
Map<NamedExpression, Pair<SumInfo, Literal>> validSumLiteralMap = new HashMap<>();
for (Entry<NamedExpression, Pair<SumInfo, Literal>> e : sumLiteralMap.entrySet()) {
Expression expr = e.getValue().first.expr;
if (countSum.get(expr) > 1) {
validSumLiteralMap.put(e.getKey(), e.getValue());
}
}
return validSumLiteralMap;
}
private Plan rewriteSumLiteral(
LogicalAggregate<?> agg, Map<NamedExpression, Pair<SumInfo, Literal>> sumLiteralMap) {
Set<NamedExpression> newAggOutput = new HashSet<>();
for (NamedExpression expr : agg.getOutputExpressions()) {
if (!sumLiteralMap.containsKey(expr)) {
newAggOutput.add(expr);
}
}
Map<SumInfo, Slot> exprToSum = new HashMap<>();
Map<SumInfo, Slot> exprToCount = new HashMap<>();
Map<AggregateFunction, NamedExpression> existedAggFunc = new HashMap<>();
for (NamedExpression e : agg.getOutputExpressions()) {
if (e.children().size() == 1 && e.child(0) instanceof AggregateFunction) {
existedAggFunc.put((AggregateFunction) e.child(0), e);
}
}
Set<SumInfo> countSumExpr = new HashSet<>();
for (Pair<SumInfo, Literal> pair : sumLiteralMap.values()) {
countSumExpr.add(pair.first);
}
for (SumInfo info : countSumExpr) {
NamedExpression namedSum = constructSum(info, existedAggFunc);
NamedExpression namedCount = constructCount(info, existedAggFunc);
exprToSum.put(info, namedSum.toSlot());
exprToCount.put(info, namedCount.toSlot());
newAggOutput.add(namedSum);
newAggOutput.add(namedCount);
}
LogicalAggregate<?> newAgg = agg.withAggOutput(ImmutableList.copyOf(newAggOutput));
List<NamedExpression> newProjects = constructProjectExpression(agg, sumLiteralMap, exprToSum, exprToCount);
return new LogicalProject<>(newProjects, newAgg);
}
private List<NamedExpression> constructProjectExpression(
LogicalAggregate<?> agg, Map<NamedExpression, Pair<SumInfo, Literal>> sumLiteralMap,
Map<SumInfo, Slot> exprToSum, Map<SumInfo, Slot> exprToCount) {
List<NamedExpression> newProjects = new ArrayList<>();
for (NamedExpression namedExpr : agg.getOutputExpressions()) {
if (!sumLiteralMap.containsKey(namedExpr)) {
newProjects.add(namedExpr.toSlot());
continue;
}
SumInfo originExpr = sumLiteralMap.get(namedExpr).first;
Literal literal = sumLiteralMap.get(namedExpr).second;
Expression newExpr;
if (namedExpr.child(0).child(0) instanceof Add) {
newExpr = new Add(exprToSum.get(originExpr),
new Multiply(literal, exprToCount.get(originExpr)));
} else {
newExpr = new Subtract(exprToSum.get(originExpr),
new Multiply(literal, exprToCount.get(originExpr)));
}
newProjects.add(new Alias(namedExpr.getExprId(), newExpr, namedExpr.getName()));
}
return newProjects;
}
private NamedExpression constructSum(SumInfo info, Map<AggregateFunction, NamedExpression> existedAggFunc) {
Sum sum = new Sum(info.isDistinct, info.isAlwaysNullable, info.expr);
NamedExpression namedSum;
if (existedAggFunc.containsKey(sum)) {
namedSum = existedAggFunc.get(sum);
} else {
namedSum = new Alias(sum);
}
return namedSum;
}
private NamedExpression constructCount(SumInfo info, Map<AggregateFunction, NamedExpression> existedAggFunc) {
Count count = new Count(info.isDistinct, info.expr);
NamedExpression namedCount;
if (existedAggFunc.containsKey(count)) {
namedCount = existedAggFunc.get(count);
} else {
namedCount = new Alias(count);
}
return namedCount;
}
private @Nullable Pair<NamedExpression, Pair<SumInfo, Literal>> extractSumLiteral(
NamedExpression namedExpression) {
if (namedExpression.children().size() != 1) {
return null;
}
Expression func = namedExpression.child(0);
if (!(func instanceof Sum)) {
return null;
}
Expression child = func.child(0);
if (!(child instanceof Add) && !(child instanceof Subtract)) {
return null;
}
Expression left = ((BinaryArithmetic) child).left();
Expression right = ((BinaryArithmetic) child).right();
if (!(right.isLiteral() && left instanceof Slot)) {
// right now, only support slot +/- literal
return null;
}
if (!(right.getDataType().isIntegerLikeType() || right.getDataType().isFloatLikeType())) {
// only support integer or float types
return null;
}
SumInfo info = new SumInfo(left, ((Sum) func).isDistinct(), ((Sum) func).isAlwaysNullable());
return Pair.of(namedExpression, Pair.of(info, (Literal) right));
}
static class SumInfo {
Expression expr;
boolean isDistinct;
boolean isAlwaysNullable;
SumInfo(Expression expr, boolean isDistinct, boolean isAlwaysNullable) {
this.expr = expr;
this.isDistinct = isDistinct;
this.isAlwaysNullable = isAlwaysNullable;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
SumInfo sumInfo = (SumInfo) o;
if (isDistinct != sumInfo.isDistinct) {
return false;
}
if (isAlwaysNullable != sumInfo.isAlwaysNullable) {
return false;
}
return Objects.equals(expr, sumInfo.expr);
}
@Override
public int hashCode() {
return Objects.hash(expr, isDistinct, isAlwaysNullable);
}
}
}