// SetPreAggStatus.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.MaterializedIndexMeta;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnionAgg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PreAggStatus;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
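/**
* SetPreAggStatus
* Traverses the plan tree and collects the required info (filter conjuncts, join conjuncts,
* group-by expressions and aggregate functions) into a PreAggInfoContext.
* A second pass then visits every LogicalOlapScan node and sets its preagg status using the
* info collected in the PreAggInfoContext.
*/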
public class SetPreAggStatus extends DefaultPlanRewriter<Stack<SetPreAggStatus.PreAggInfoContext>>
implements CustomRewriter {
// Pre-aggregation info keyed by the relation id of each olap scan it was collected for.
// Filled in visitLogicalAggregate and handed to SetOlapScanPreAgg in rewriteRoot.
// The reference is never reassigned, so it is declared final.
private final Map<RelationId, PreAggInfoContext> olapScanPreAggContexts = new HashMap<>();
/**
 * PreAggInfoContext
 * Accumulates the filter conjuncts, join conjuncts, group-by expressions and aggregate
 * functions seen above a set of olap scans. Each add* method appends the new items and then
 * re-applies the current replaceMap to everything collected so far in that bucket.
 */
public static class PreAggInfoContext {
    private List<Expression> filterConjuncts = new ArrayList<>();
    private List<Expression> joinConjuncts = new ArrayList<>();
    private List<Expression> groupByExpresssions = new ArrayList<>();
    private Set<AggregateFunction> aggregateFunctions = new HashSet<>();
    private Set<RelationId> olapScanIds = new HashSet<>();
    private Map<Slot, Expression> replaceMap = new HashMap<>();

    /** Replaces (does not merge) the slot-to-expression map used by the add* methods. */
    private void setReplaceMap(Map<Slot, Expression> replaceMap) {
        this.replaceMap = replaceMap;
    }

    /** Records the relation id of an olap scan that belongs to this context. */
    private void addRelationId(RelationId id) {
        olapScanIds.add(id);
    }

    /** Appends all expressions of the join and rewrites the collected join conjuncts. */
    private void addJoinInfo(LogicalJoin logicalJoin) {
        joinConjuncts.addAll(logicalJoin.getExpressions());
        joinConjuncts = applyReplaceMap(joinConjuncts);
    }

    /** Appends the given conjuncts and rewrites the collected filter conjuncts. */
    private void addFilterConjuncts(List<Expression> conjuncts) {
        filterConjuncts.addAll(conjuncts);
        filterConjuncts = applyReplaceMap(filterConjuncts);
    }

    /** Appends the given expressions and rewrites the collected group-by expressions. */
    private void addGroupByExpresssions(List<Expression> expressions) {
        groupByExpresssions.addAll(expressions);
        groupByExpresssions = applyReplaceMap(groupByExpresssions);
    }

    /** Appends the given functions and rewrites every collected aggregate function. */
    private void addAggregateFunctions(Set<AggregateFunction> functions) {
        aggregateFunctions.addAll(functions);
        Set<AggregateFunction> rewritten = Sets.newHashSet();
        aggregateFunctions.forEach(function ->
                rewritten.add((AggregateFunction) ExpressionUtils.replace(function, replaceMap)));
        aggregateFunctions = rewritten;
    }

    /** Returns a fresh list holding the collected expressions rewritten with replaceMap. */
    private List<Expression> applyReplaceMap(List<Expression> collected) {
        return Lists.newArrayList(ExpressionUtils.replace(collected, replaceMap));
    }
}
/**
 * Runs the rule in two passes: this visitor first collects pre-aggregation info per olap
 * scan, then SetOlapScanPreAgg applies that info to the scans of the resulting plan.
 */
@Override
public Plan rewriteRoot(Plan plan, JobContext jobContext) {
    Stack<PreAggInfoContext> contextStack = new Stack<>();
    Plan collectedPlan = plan.accept(this, contextStack);
    return collectedPlan.accept(SetOlapScanPreAgg.INSTANCE, olapScanPreAggContexts);
}
/**
 * Fallback for plan nodes without a dedicated visit method in this class: rewrites the node
 * through the parent implementation and then drops every context collected so far.
 * NOTE(review): presumably such nodes break the scan-to-aggregate chain - confirm.
 */
@Override
public Plan visit(Plan plan, Stack<PreAggInfoContext> context) {
    Plan rewritten = super.visit(plan, context);
    // discard whatever was gathered below this node
    context.clear();
    return rewritten;
}
/**
 * Handles a scan whose pre-aggregation status is still unset. A scan on a DUP_KEYS index, or
 * on a UNIQUE_KEYS index of a table with unique-key merge-on-write enabled, is switched on
 * immediately. Any other scan is registered in the context on top of the stack (one is
 * pushed when the stack is empty) and returned unchanged.
 */
@Override
public Plan visitLogicalOlapScan(LogicalOlapScan logicalOlapScan, Stack<PreAggInfoContext> context) {
    if (!logicalOlapScan.isPreAggStatusUnSet()) {
        // status already decided, nothing to do
        return logicalOlapScan;
    }
    long indexId = logicalOlapScan.getSelectedIndexId();
    MaterializedIndexMeta indexMeta = logicalOlapScan.getTable().getIndexMetaByIndexId(indexId);
    KeysType keysType = indexMeta.getKeysType();
    boolean alwaysOn = keysType == KeysType.DUP_KEYS
            || (keysType == KeysType.UNIQUE_KEYS
                    && logicalOlapScan.getTable().getEnableUniqueKeyMergeOnWrite());
    if (alwaysOn) {
        return logicalOlapScan.withPreAggStatus(PreAggStatus.on());
    }
    if (context.empty()) {
        context.push(new PreAggInfoContext());
    }
    context.peek().addRelationId(logicalOlapScan.getRelationId());
    return logicalOlapScan;
}
/**
 * Rewrites the filter's children, then adds the filter's expressions to the context on top
 * of the stack, if there is one.
 */
@Override
public Plan visitLogicalFilter(LogicalFilter<? extends Plan> logicalFilter, Stack<PreAggInfoContext> context) {
    LogicalFilter rewrittenFilter = (LogicalFilter) super.visit(logicalFilter, context);
    if (context.empty()) {
        return rewrittenFilter;
    }
    context.peek().addFilterConjuncts(rewrittenFilter.getExpressions());
    return rewrittenFilter;
}
/**
 * Rewrites the join's children, then adds the join's expressions to the context on top of
 * the stack, if there is one.
 */
@Override
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> logicalJoin,
        Stack<PreAggInfoContext> context) {
    LogicalJoin rewrittenJoin = (LogicalJoin) super.visit(logicalJoin, context);
    if (context.empty()) {
        return rewrittenJoin;
    }
    context.peek().addJoinInfo(rewrittenJoin);
    return rewrittenJoin;
}
/**
 * Rewrites the project's children, then installs the project's alias-to-producer map as the
 * replace map of the context on top of the stack, if there is one.
 */
@Override
public Plan visitLogicalProject(LogicalProject<? extends Plan> logicalProject,
        Stack<PreAggInfoContext> context) {
    LogicalProject rewrittenProject = (LogicalProject) super.visit(logicalProject, context);
    if (context.empty()) {
        return rewrittenProject;
    }
    context.peek().setReplaceMap(rewrittenProject.getAliasToProducer());
    return rewrittenProject;
}
/**
 * Rewrites the aggregate's children. If a context was collected below, it is popped,
 * completed with this aggregate's functions and non-virtual group-by expressions, and
 * registered under the relation id of every olap scan recorded in it.
 */
@Override
public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> logicalAggregate,
        Stack<PreAggInfoContext> context) {
    Plan rewritten = super.visit(logicalAggregate, context);
    if (context.isEmpty()) {
        return rewritten;
    }
    PreAggInfoContext finished = context.pop();
    // aggregate info is read from the node passed in, not from the rewritten plan
    finished.addAggregateFunctions(logicalAggregate.getAggregateFunctions());
    finished.addGroupByExpresssions(nonVirtualGroupByExprs(logicalAggregate));
    finished.olapScanIds.forEach(scanId -> olapScanPreAggContexts.put(scanId, finished));
    return rewritten;
}
/**
 * Rewrites the repeat node through the parent implementation directly, so the
 * context-clearing visit override of this class is not applied to it.
 */
@Override
public Plan visitLogicalRepeat(LogicalRepeat<? extends Plan> logicalRepeat, Stack<PreAggInfoContext> context) {
    Plan rewritten = super.visit(logicalRepeat, context);
    return rewritten;
}
/** Returns the aggregate's group-by expressions without any VirtualSlotReference entries. */
private List<Expression> nonVirtualGroupByExprs(LogicalAggregate<? extends Plan> agg) {
    ImmutableList.Builder<Expression> kept = ImmutableList.builder();
    for (Expression groupByExpr : agg.getGroupByExpressions()) {
        if (groupByExpr instanceof VirtualSlotReference) {
            continue;
        }
        kept.add(groupByExpr);
    }
    return kept.build();
}
private static class SetOlapScanPreAgg extends DefaultPlanRewriter<Map<RelationId, PreAggInfoContext>> {
// Shared rewriter instance used by rewriteRoot; the class declares no instance fields of its
// own. Declared final so the singleton reference can never be reassigned.
private static final SetOlapScanPreAgg INSTANCE = new SetOlapScanPreAgg();
/**
 * Sets the pre-aggregation status of a scan whose status is still unset. The status is
 * computed from the context registered for the scan's relation id; without such a context
 * it is turned off with the reason "No valid aggregate on scan.".
 */
@Override
public Plan visitLogicalOlapScan(LogicalOlapScan olapScan, Map<RelationId, PreAggInfoContext> context) {
    if (!olapScan.isPreAggStatusUnSet()) {
        return olapScan;
    }
    PreAggInfoContext collected = context.get(olapScan.getRelationId());
    PreAggStatus status = collected == null
            ? PreAggStatus.off("No valid aggregate on scan.")
            : createPreAggStatus(olapScan, collected);
    return olapScan.withPreAggStatus(status);
}
/**
 * Decides the pre-aggregation status of a scan from the info collected above it.
 * Pre-aggregation is turned off when a group-by expression, filter conjunct or join conjunct
 * reads a value (non-key) column of the scan, or when an aggregate function that does not
 * read this scan is anything other than max, min or count(distinct). Otherwise the aggregate
 * functions that read this scan are checked by checkAggregateFunctions.
 * NOTE(review): the "contains non-key column" messages print all input slots of the
 * offending expressions, not only the non-key ones.
 */
private PreAggStatus createPreAggStatus(LogicalOlapScan logicalOlapScan, PreAggInfoContext context) {
    Set<Slot> scanOutput = logicalOlapScan.getOutputSet();
    Pair<Set<SlotReference>, Set<SlotReference>> keyAndValue = splitKeyValueSlots(scanOutput);
    Set<SlotReference> valueSlots = keyAndValue.second;
    // every output slot must have been classified as either key or value
    Preconditions.checkState(scanOutput.size() == keyAndValue.first.size() + valueSlots.size(),
            "output slots contains no key or value slots");

    List<Expression> groupBy = context.groupByExpresssions;
    Set<Slot> groupBySlots = ExpressionUtils.getInputSlotSet(groupBy);
    if (!Sets.intersection(groupBySlots, valueSlots).isEmpty()) {
        return PreAggStatus.off(String.format("Grouping expression %s contains non-key column %s",
                groupBy, groupBySlots));
    }

    List<Expression> filters = context.filterConjuncts;
    Set<Slot> filterSlots = ExpressionUtils.getInputSlotSet(filters);
    if (!Sets.intersection(filterSlots, valueSlots).isEmpty()) {
        return PreAggStatus.off(String.format("Filter conjuncts %s contains non-key column %s",
                filters, filterSlots));
    }

    List<Expression> joins = context.joinConjuncts;
    Set<Slot> joinSlots = ExpressionUtils.getInputSlotSet(joins);
    if (!Sets.intersection(joinSlots, valueSlots).isEmpty()) {
        return PreAggStatus.off(String.format("Join conjuncts %s contains non-key column %s",
                joins, joinSlots));
    }

    Set<AggregateFunction> allAggs = context.aggregateFunctions;
    Set<AggregateFunction> aggsOnScan = Sets.newHashSet();
    for (AggregateFunction agg : allAggs) {
        if (!Sets.intersection(agg.getInputSlots(), scanOutput).isEmpty()) {
            // the function reads at least one column of this scan
            aggsOnScan.add(agg);
            continue;
        }
        boolean allowedOnOtherTable = agg instanceof Max || agg instanceof Min
                || (agg instanceof Count && agg.isDistinct());
        if (!allowedOnOtherTable) {
            return PreAggStatus.off(
                    String.format("can't turn preAgg on because aggregate function %s in other table",
                            agg));
        }
    }

    // group-by input slots that come from this scan
    Set<Slot> scanGroupBySlots = Sets.newHashSet(groupBySlots);
    scanGroupBySlots.retainAll(scanOutput);

    if (!aggsOnScan.isEmpty() || !scanGroupBySlots.isEmpty()) {
        return checkAggregateFunctions(aggsOnScan, scanGroupBySlots);
    }
    if (allAggs.isEmpty() && groupBy.isEmpty()) {
        return PreAggStatus.off("No aggregate on scan.");
    }
    return PreAggStatus.on();
}
/**
 * Checks every aggregate function in turn and returns the first "off" status found.
 * Returns off with "No aggregate on scan." when both arguments are empty, and on when no
 * function is rejected.
 */
private PreAggStatus checkAggregateFunctions(Set<AggregateFunction> aggregateFuncs,
        Set<Slot> groupingExprsInputSlots) {
    if (aggregateFuncs.isEmpty() && groupingExprsInputSlots.isEmpty()) {
        return PreAggStatus.off("No aggregate on scan.");
    }
    PreAggStatus status = PreAggStatus.on();
    for (AggregateFunction aggFunc : aggregateFuncs) {
        status = checkOneAggregateFunction(aggFunc);
        if (status.isOff()) {
            break;
        }
    }
    return status;
}

/**
 * Checks a single aggregate function. A function without arguments is rejected. A function
 * whose only argument is a slot is dispatched to the key or value checker according to the
 * slot's original column. Anything else goes to checkAggWithKeyAndValueSlots.
 */
private PreAggStatus checkOneAggregateFunction(AggregateFunction aggFunc) {
    int argCount = aggFunc.children().size();
    if (argCount == 0) {
        return PreAggStatus.off(
                String.format("can't turn preAgg on for aggregate function %s", aggFunc));
    }
    if (argCount != 1 || !(aggFunc.child(0) instanceof Slot)) {
        Pair<Set<SlotReference>, Set<SlotReference>> keyAndValue
                = splitKeyValueSlots(aggFunc.getInputSlots());
        return checkAggWithKeyAndValueSlots(aggFunc, keyAndValue.first, keyAndValue.second);
    }
    Slot aggSlot = (Slot) aggFunc.child(0);
    if (!(aggSlot instanceof SlotReference)
            || !((SlotReference) aggSlot).getOriginalColumn().isPresent()) {
        return PreAggStatus.off(
                String.format("aggregate function %s use unknown slot %s from scan",
                        aggFunc, aggSlot));
    }
    SlotReference slotRef = (SlotReference) aggSlot;
    if (slotRef.getOriginalColumn().get().isKey()) {
        return OneKeySlotAggChecker.INSTANCE.check(aggFunc);
    }
    return OneValueSlotAggChecker.INSTANCE.check(aggFunc,
            slotRef.getOriginalColumn().get().getAggregationType());
}
/**
 * Splits the slots into key slots (first) and value slots (second) by their original column.
 * Slots that are not SlotReference, or that have no original column, land in neither set.
 */
private Pair<Set<SlotReference>, Set<SlotReference>> splitKeyValueSlots(Set<Slot> slots) {
    int expectedSize = slots.size();
    Set<SlotReference> keys = Sets.newHashSetWithExpectedSize(expectedSize);
    Set<SlotReference> values = Sets.newHashSetWithExpectedSize(expectedSize);
    for (Slot slot : slots) {
        if (!(slot instanceof SlotReference)) {
            continue;
        }
        SlotReference slotRef = (SlotReference) slot;
        if (!slotRef.getOriginalColumn().isPresent()) {
            continue;
        }
        Set<SlotReference> target = slotRef.getOriginalColumn().get().isKey() ? keys : values;
        target.add(slotRef);
    }
    return Pair.of(keys, values);
}
/**
 * Checks an aggregate function whose first argument is not a bare slot. Numeric casts around
 * the argument are peeled off; a non-numeric cast turns pre-aggregation off. For IF and
 * CASE WHEN the conditions must read key slots only, and the result expressions (casts
 * removed) are handed to KeyAndValueSlotsAggChecker. Any other argument is treated as a
 * single result expression. The valueSlots parameter is not read by this method.
 */
private PreAggStatus checkAggWithKeyAndValueSlots(AggregateFunction aggFunc,
        Set<SlotReference> keySlots, Set<SlotReference> valueSlots) {
    // ignore cast
    Expression argument = aggFunc.child(0);
    while (argument instanceof Cast) {
        boolean numericCast = ((Cast) argument).getDataType().isNumericType();
        if (!numericCast) {
            return PreAggStatus.off(String.format("%s is not numeric CAST.", argument.toSql()));
        }
        argument = argument.child(0);
    }

    // step 1: extract all condition exprs and return exprs
    List<Expression> conditions = new ArrayList<>();
    List<Expression> results = new ArrayList<>();
    if (argument instanceof If) {
        conditions.add(argument.child(0));
        results.add(removeCast(argument.child(1)));
        results.add(removeCast(argument.child(2)));
    } else if (argument instanceof CaseWhen) {
        CaseWhen caseWhen = (CaseWhen) argument;
        // WHEN THEN
        for (WhenClause branch : caseWhen.getWhenClauses()) {
            conditions.add(branch.getOperand());
            results.add(removeCast(branch.getResult()));
        }
        // ELSE
        Expression elseValue = caseWhen.getDefaultValue().orElse(new NullLiteral());
        results.add(removeCast(elseValue));
    } else {
        // currently, only IF and CASE WHEN are supported
        results.add(removeCast(argument));
    }

    // step 2: check condition expressions
    Set<Slot> conditionSlots = ExpressionUtils.getInputSlotSet(conditions);
    if (keySlots.containsAll(conditionSlots)) {
        return KeyAndValueSlotsAggChecker.INSTANCE.check(aggFunc, results);
    }
    return PreAggStatus
            .off(String.format("some columns in condition %s is not key.", conditions));
}
/** Strips every enclosing Cast and returns the first expression that is not a Cast. */
private static Expression removeCast(Expression expression) {
    Expression unwrapped = expression;
    while (unwrapped instanceof Cast) {
        Cast cast = (Cast) unwrapped;
        unwrapped = cast.child();
    }
    return unwrapped;
}
/**
 * Checks an aggregate function applied to a single value slot. The visitor context is the
 * aggregation type of the slot's original column; pre-aggregation stays on only when the
 * function matches that type.
 */
private static class OneValueSlotAggChecker
        extends ExpressionVisitor<PreAggStatus, AggregateType> {
    public static final OneValueSlotAggChecker INSTANCE = new OneValueSlotAggChecker();

    /** Dispatches the function to the matching visit method below. */
    public PreAggStatus check(AggregateFunction aggFun, AggregateType aggregateType) {
        return aggFun.accept(INSTANCE, aggregateType);
    }

    @Override
    public PreAggStatus visit(Expression expr, AggregateType aggregateType) {
        String reason = String.format("%s is not aggregate function.", expr.toSql());
        return PreAggStatus.off(reason);
    }

    @Override
    public PreAggStatus visitAggregateFunction(AggregateFunction aggregateFunction,
            AggregateType aggregateType) {
        // fallback for aggregate functions without a dedicated visit method here
        String reason = String.format("%s is not supported.", aggregateFunction.toSql());
        return PreAggStatus.off(reason);
    }

    @Override
    public PreAggStatus visitMax(Max max, AggregateType aggregateType) {
        return requireNonDistinctOfType(max, AggregateType.MAX, aggregateType);
    }

    @Override
    public PreAggStatus visitMin(Min min, AggregateType aggregateType) {
        return requireNonDistinctOfType(min, AggregateType.MIN, aggregateType);
    }

    @Override
    public PreAggStatus visitSum(Sum sum, AggregateType aggregateType) {
        return requireNonDistinctOfType(sum, AggregateType.SUM, aggregateType);
    }

    @Override
    public PreAggStatus visitBitmapUnionCount(BitmapUnionCount bitmapUnionCount,
            AggregateType aggregateType) {
        return aggregateType == AggregateType.BITMAP_UNION
                ? PreAggStatus.on()
                : PreAggStatus.off("invalid bitmap_union_count: " + bitmapUnionCount.toSql());
    }

    @Override
    public PreAggStatus visitBitmapUnion(BitmapUnion bitmapUnion, AggregateType aggregateType) {
        return aggregateType == AggregateType.BITMAP_UNION
                ? PreAggStatus.on()
                : PreAggStatus.off("invalid bitmapUnion: " + bitmapUnion.toSql());
    }

    @Override
    public PreAggStatus visitHllUnionAgg(HllUnionAgg hllUnionAgg, AggregateType aggregateType) {
        return aggregateType == AggregateType.HLL_UNION
                ? PreAggStatus.on()
                : PreAggStatus.off("invalid hllUnionAgg: " + hllUnionAgg.toSql());
    }

    @Override
    public PreAggStatus visitHllUnion(HllUnion hllUnion, AggregateType aggregateType) {
        return aggregateType == AggregateType.HLL_UNION
                ? PreAggStatus.on()
                : PreAggStatus.off("invalid hllUnion: " + hllUnion.toSql());
    }

    /** On when the column's aggregation type equals the required one and there is no DISTINCT. */
    private PreAggStatus requireNonDistinctOfType(AggregateFunction function,
            AggregateType required, AggregateType actual) {
        if (actual == required && !function.isDistinct()) {
            return PreAggStatus.on();
        }
        return PreAggStatus
                .off(String.format("%s is not match agg mode %s or has distinct param",
                        function.toSql(), actual));
    }
}
/**
 * Checks an aggregate function applied to a single key slot: max and min are always
 * accepted, any other aggregate function only when it is distinct.
 */
private static class OneKeySlotAggChecker extends ExpressionVisitor<PreAggStatus, Void> {
    public static final OneKeySlotAggChecker INSTANCE = new OneKeySlotAggChecker();

    /** Dispatches the function to the matching visit method below. */
    public PreAggStatus check(AggregateFunction aggFun) {
        return aggFun.accept(INSTANCE, null);
    }

    @Override
    public PreAggStatus visit(Expression expr, Void context) {
        String reason = String.format("%s is not aggregate function.", expr.toSql());
        return PreAggStatus.off(reason);
    }

    @Override
    public PreAggStatus visitAggregateFunction(AggregateFunction aggregateFunction,
            Void context) {
        return aggregateFunction.isDistinct()
                ? PreAggStatus.on()
                : PreAggStatus.off(String.format("%s is not distinct.", aggregateFunction.toSql()));
    }

    @Override
    public PreAggStatus visitMax(Max max, Void context) {
        return PreAggStatus.on();
    }

    @Override
    public PreAggStatus visitMin(Min min, Void context) {
        return PreAggStatus.on();
    }
}
/**
 * Checks an aggregate function against the result expressions extracted from its argument
 * (the visitor context). Pre-aggregation stays on only when every result expression is
 * acceptable for the function; sum, max, min and count(distinct) are the supported ones.
 */
private static class KeyAndValueSlotsAggChecker
        extends ExpressionVisitor<PreAggStatus, List<Expression>> {
    public static final KeyAndValueSlotsAggChecker INSTANCE = new KeyAndValueSlotsAggChecker();

    /** Dispatches the function to the matching visit method below. */
    public PreAggStatus check(AggregateFunction aggFun, List<Expression> returnValues) {
        return aggFun.accept(INSTANCE, returnValues);
    }

    @Override
    public PreAggStatus visit(Expression expr, List<Expression> returnValues) {
        String reason = String.format("%s is not aggregate function.", expr.toSql());
        return PreAggStatus.off(reason);
    }

    @Override
    public PreAggStatus visitAggregateFunction(AggregateFunction aggregateFunction,
            List<Expression> returnValues) {
        return offNotSupported(aggregateFunction);
    }

    @Override
    public PreAggStatus visitSum(Sum sum, List<Expression> returnValues) {
        // each result must be a SUM column, a zero literal or a null literal
        boolean allAccepted = returnValues.stream().allMatch(value ->
                isAggTypeMatched(value, AggregateType.SUM) || value.isZeroLiteral()
                        || value.isNullLiteral());
        return allAccepted ? PreAggStatus.on() : offNotSupported(sum);
    }

    @Override
    public PreAggStatus visitMax(Max max, List<Expression> returnValues) {
        // each result must be a MAX column, a key slot or a null literal
        boolean allAccepted = returnValues.stream().allMatch(value ->
                isAggTypeMatched(value, AggregateType.MAX) || isKeySlot(value)
                        || value.isNullLiteral());
        return allAccepted ? PreAggStatus.on() : offNotSupported(max);
    }

    @Override
    public PreAggStatus visitMin(Min min, List<Expression> returnValues) {
        // each result must be a MIN column, a key slot or a null literal
        boolean allAccepted = returnValues.stream().allMatch(value ->
                isAggTypeMatched(value, AggregateType.MIN) || isKeySlot(value)
                        || value.isNullLiteral());
        return allAccepted ? PreAggStatus.on() : offNotSupported(min);
    }

    @Override
    public PreAggStatus visitCount(Count count, List<Expression> returnValues) {
        if (!count.isDistinct()) {
            return offNotSupported(count);
        }
        // each result must be a key slot, a zero literal or a null literal
        boolean allAccepted = returnValues.stream().allMatch(value ->
                isKeySlot(value) || value.isZeroLiteral() || value.isNullLiteral());
        return allAccepted ? PreAggStatus.on() : offNotSupported(count);
    }

    /** Builds the shared "is not supported." off status for the given function. */
    private PreAggStatus offNotSupported(Expression function) {
        return PreAggStatus.off(String.format("%s is not supported.", function.toSql()));
    }

    /** True when the expression is a SlotReference whose original column is a key. */
    private boolean isKeySlot(Expression expression) {
        if (!(expression instanceof SlotReference)) {
            return false;
        }
        SlotReference slotRef = (SlotReference) expression;
        return slotRef.getOriginalColumn().isPresent()
                && slotRef.getOriginalColumn().get().isKey();
    }

    /** True when the expression is a SlotReference whose original column has the given type. */
    private boolean isAggTypeMatched(Expression expression, AggregateType aggregateType) {
        if (!(expression instanceof SlotReference)) {
            return false;
        }
        SlotReference slotRef = (SlotReference) expression;
        return slotRef.getOriginalColumn().isPresent()
                && slotRef.getOriginalColumn().get().getAggregationType() == aggregateType;
    }
}
}
}