// MergePercentileToArray.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.annotation.DependsRules;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Percentile;
import org.apache.doris.nereids.trees.expressions.functions.agg.PercentileArray;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Array;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt;
import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.collect.Sets.SetView;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
/**
 * MergePercentileToArray: merges several percentile() calls that share the same first argument
 * (and the same DISTINCT flag) into a single percentile_array() call.
 * <pre>
 * LogicalAggregate (outputExpression: [percentile(a, 0.1) as c1, percentile(a, 0.22) as c2])
 * ->
 * LogicalProject (projects: [element_at(percentile_array(a, [0.1, 0.22])#1, 1) as c1,
 *                            element_at(percentile_array(a, [0.1, 0.22])#1, 2) as c2])
 * --+LogicalAggregate(outputExpression: percentile_array(a, [0.1, 0.22]) as percentile_array(a, [0.1, 0.22])#1)
 * </pre>
 */
@DependsRules({
    NormalizeAggregate.class
})
public class MergePercentileToArray extends OneRewriteRuleFactory {

    /**
     * Builds the rule: matches any logical aggregate and rewrites it with {@link #doMerge}.
     */
    @Override
    public Rule build() {
        return logicalAggregate(any())
                .then(this::doMerge)
                .toRule(RuleType.MERGE_PERCENTILE_TO_ARRAY);
    }

    /**
     * Merge percentile into percentile_array according to funcMap.
     *
     * <p>One percentile_array() is produced per map entry. Its first argument is the entry key's
     * expression; its second argument is an array holding the second argument of every percentile
     * in the entry's list, in list order. {@link #doMerge} relies on that order when it emits
     * element_at(slot, i + 1) for the i-th percentile of the list.
     *
     * @param funcMap percentile functions grouped by (first argument, distinct flag)
     * @return the merged percentile_array functions, one per entry, in map iteration order
     */
    private List<AggregateFunction> getPercentileArrays(Map<DistinctAndExpr, List<AggregateFunction>> funcMap) {
        List<AggregateFunction> newPercentileArrays = Lists.newArrayList();
        for (Map.Entry<DistinctAndExpr, List<AggregateFunction>> entry : funcMap.entrySet()) {
            List<AggregateFunction> percentiles = entry.getValue();
            List<Expression> percents = new ArrayList<>(percentiles.size());
            // Typed view of the literal percents; avoids handing a raw List to ArrayLiteral.
            List<Literal> literalPercents = new ArrayList<>(percentiles.size());
            for (AggregateFunction percentile : percentiles) {
                Expression percent = percentile.child(1);
                percents.add(percent);
                if (percent instanceof Literal) {
                    literalPercents.add((Literal) percent);
                }
            }

            // Use an array literal only when every percent is a literal,
            // otherwise fall back to the array() scalar function.
            Expression percentArray;
            if (literalPercents.size() == percents.size()) {
                percentArray = new ArrayLiteral(literalPercents);
            } else {
                percentArray = new Array(percents.toArray(new Expression[0]));
            }
            Expression secondArg = TypeCoercionUtils.castIfNotSameType(
                    percentArray, ArrayType.of(DoubleType.INSTANCE));

            DistinctAndExpr key = entry.getKey();
            PercentileArray percentileArray;
            if (key.isDistinct) {
                percentileArray = new PercentileArray(true, key.getExpression(), secondArg);
            } else {
                percentileArray = new PercentileArray(key.getExpression(), secondArg);
            }
            newPercentileArrays.add(percentileArray);
        }
        return newPercentileArrays;
    }

    /**
     * Find all the percentile functions and place them in the map
     * with the first parameter of the percentile (plus its distinct flag) as the key.
     * Groups that contain a single percentile are removed because there is nothing to merge.
     *
     * @param aggregate the aggregate whose aggregate functions are inspected
     * @return the groups that contain at least two percentile functions; may be empty
     */
    private Map<DistinctAndExpr, List<AggregateFunction>> collectFuncMap(LogicalAggregate<Plan> aggregate) {
        Set<AggregateFunction> aggregateFunctions = aggregate.getAggregateFunctions();
        Map<DistinctAndExpr, List<AggregateFunction>> funcMap = new HashMap<>();
        for (AggregateFunction func : aggregateFunctions) {
            if (!(func instanceof Percentile)) {
                continue;
            }
            DistinctAndExpr distinctAndExpr = new DistinctAndExpr(func.child(0), func.isDistinct());
            funcMap.computeIfAbsent(distinctAndExpr, k -> new ArrayList<>()).add(func);
        }
        funcMap.entrySet().removeIf(entry -> entry.getValue().size() == 1);
        return funcMap;
    }

    /**
     * Rewrites the aggregate into LogicalProject -> LogicalAggregate when at least two percentile
     * functions can be merged; otherwise returns the aggregate unchanged.
     *
     * <p>The new aggregate outputs the group by slots, the untouched aggregate functions and the
     * merged percentile_array functions. The new project outputs the untouched functions, one
     * element_at() per original percentile alias (reusing the original expr id and name, so that
     * the upper plan keeps referring to the same expression) and finally the group by slots.
     *
     * @param aggregate the aggregate to rewrite
     * @return the rewritten plan, or {@code aggregate} itself when nothing can be merged
     * @throws NullPointerException if a mergeable percentile has no alias in the aggregate output
     */
    private Plan doMerge(LogicalAggregate<Plan> aggregate) {
        Map<DistinctAndExpr, List<AggregateFunction>> funcMap = collectFuncMap(aggregate);
        if (funcMap.isEmpty()) {
            return aggregate;
        }
        Set<AggregateFunction> canMergePercentiles = Sets.newHashSet();
        for (List<AggregateFunction> percentiles : funcMap.values()) {
            canMergePercentiles.addAll(percentiles);
        }
        Set<AggregateFunction> aggregateFunctions = aggregate.getAggregateFunctions();
        SetView<AggregateFunction> aggFuncsNotChange = Sets.difference(aggregateFunctions, canMergePercentiles);

        // construct new Aggregate
        List<AggregateFunction> newPercentileArrays = getPercentileArrays(funcMap);
        ImmutableList.Builder<NamedExpression> normalizedAggOutputBuilder =
                ImmutableList.builderWithExpectedSize(aggregate.getGroupByExpressions().size()
                        + aggFuncsNotChange.size() + newPercentileArrays.size());
        List<NamedExpression> groupBySlots = new ArrayList<>();
        for (Expression groupBy : aggregate.getGroupByExpressions()) {
            // the cast assumes every group by expression is a NamedExpression
            groupBySlots.add(((NamedExpression) groupBy).toSlot());
        }
        normalizedAggOutputBuilder.addAll(groupBySlots);
        Set<Alias> existsAliases =
                ExpressionUtils.mutableCollect(aggregate.getOutputExpressions(), Alias.class::isInstance);
        NormalizeToSlotContext notChangeFuncContext = NormalizeToSlotContext.buildContext(existsAliases,
                aggFuncsNotChange);
        NormalizeToSlotContext percentileArrayContext = NormalizeToSlotContext.buildContext(new HashSet<>(),
                newPercentileArrays);
        normalizedAggOutputBuilder.addAll(notChangeFuncContext.pushDownToNamedExpression(aggFuncsNotChange));
        normalizedAggOutputBuilder.addAll(percentileArrayContext.pushDownToNamedExpression(newPercentileArrays));
        LogicalAggregate<Plan> newAggregate = aggregate.withAggOutput(normalizedAggOutputBuilder.build());

        // construct new Project
        List<Expression> notChangeForProject = notChangeFuncContext.normalizeToUseSlotRef(
                (Set<Expression>) (Set) aggFuncsNotChange);
        List<Expression> newPercentileArrayForProject = percentileArrayContext.normalizeToUseSlotRef(
                (List<Expression>) (List) newPercentileArrays);
        ImmutableList.Builder<NamedExpression> newProjectOutputExpressions = ImmutableList.builder();
        newProjectOutputExpressions.addAll((List<NamedExpression>) (List) notChangeForProject);

        // existsAliasMap is used to keep upper plan refer the same expr
        Map<Expression, List<Alias>> existsAliasMap = Maps.newHashMap();
        for (Alias alias : existsAliases) {
            existsAliasMap.computeIfAbsent(alias.child(), k -> new ArrayList<>()).add(alias);
        }

        // slotMap is used to find the correspondence
        // between LogicalProject's element_at(percentile_array_slot_reference, i) which replaces the old percentile()
        // and the merged percentile_array() in LogicalAggregate
        Map<DistinctAndExpr, Slot> slotMap = Maps.newHashMap();
        for (int i = 0; i < newPercentileArrays.size(); i++) {
            AggregateFunction percentileArray = newPercentileArrays.get(i);
            DistinctAndExpr distinctAndExpr = new DistinctAndExpr(percentileArray.child(0),
                    percentileArray.isDistinct());
            slotMap.put(distinctAndExpr, (Slot) newPercentileArrayForProject.get(i));
        }

        for (Map.Entry<DistinctAndExpr, List<AggregateFunction>> entry : funcMap.entrySet()) {
            // Every percentile of this group was keyed by (child(0), isDistinct) in collectFuncMap,
            // so the entry key addresses the merged array slot for all of them.
            Slot percentileArraySlot = slotMap.get(entry.getKey());
            List<AggregateFunction> percentiles = entry.getValue();
            for (int i = 0; i < percentiles.size(); i++) {
                AggregateFunction aggFunc = percentiles.get(i);
                // Fail with a descriptive message instead of an anonymous NPE in the loop below.
                List<Alias> originAliases = Objects.requireNonNull(existsAliasMap.get(aggFunc),
                        () -> "no alias found in aggregate output for percentile function: " + aggFunc);
                for (Alias originAlias : originAliases) {
                    // element_at is 1-based, the list index is 0-based
                    Alias newAlias = new Alias(originAlias.getExprId(),
                            new ElementAt(percentileArraySlot, new IntegerLiteral(i + 1)),
                            originAlias.getName());
                    newProjectOutputExpressions.add(newAlias);
                }
            }
        }
        newProjectOutputExpressions.addAll(groupBySlots);
        return new LogicalProject<>(newProjectOutputExpressions.build(), newAggregate);
    }

    /**
     * Map key that groups percentile functions: two percentiles can be merged only when they
     * share both the first argument and the distinct flag.
     */
    private static class DistinctAndExpr {
        private final Expression expression;
        private final boolean isDistinct;

        public DistinctAndExpr(Expression expression, boolean isDistinct) {
            this.expression = expression;
            this.isDistinct = isDistinct;
        }

        public Expression getExpression() {
            return expression;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            DistinctAndExpr a = (DistinctAndExpr) o;
            return isDistinct == a.isDistinct
                    && Objects.equals(expression, a.expression);
        }

        @Override
        public int hashCode() {
            return Objects.hash(expression, isDistinct);
        }
    }
}