// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.implementation;
import org.apache.doris.analysis.IndexDef.IndexType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.MaterializedIndexMeta;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.annotation.DependsRules;
import org.apache.doris.nereids.pattern.PatternDescriptor;
import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.properties.RequireProperties;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE;
import org.apache.doris.nereids.trees.expressions.AggregateExpression;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregatePhase;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
import org.apache.doris.nereids.trees.expressions.functions.agg.SupportMultiDistinct;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.plans.AggMode;
import org.apache.doris.nereids.trees.plans.AggPhase;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.algebra.Project;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFileScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalHudiScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.trees.plans.physical.PhysicalFileScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate.PushDownAggOp;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import org.apache.doris.qe.ConnectContext;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
/**
 * AggregateStrategies: implementation rules that turn a logical aggregate into physical
 * alternatives, including storage-layer pushdown (count/min/max on scan nodes) and
 * one-, two-, three- and four-phase hash aggregate plans, with and without distinct.
 */
@DependsRules({
NormalizeAggregate.class,
FoldConstantRuleOnFE.class
})
public class AggregateStrategies implements ImplementationRuleFactory {
@Override
public List<Rule> buildRules() {
PatternDescriptor<LogicalAggregate<GroupPlan>> basePattern = logicalAggregate();
return ImmutableList.of(
RuleType.COUNT_ON_INDEX_WITHOUT_PROJECT.build(
logicalAggregate(
logicalFilter(
logicalOlapScan().when(this::isDupOrMowKeyTable).when(this::isInvertedIndexEnabledOnTable)
)
)
.when(agg -> enablePushDownCountOnIndex())
.when(agg -> agg.getGroupByExpressions().isEmpty())
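                        // every aggregate function must be a non-distinct count(*) or count(slot),
                        // and each filter conjunct must pass the OR / IS NULL slot checks below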
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
if (funcs.isEmpty() || !funcs.stream().allMatch(f -> f instanceof Count && !f.isDistinct()
&& (((Count) f).isCountStar() || f.child(0) instanceof Slot))) {
return false;
}
Set<Expression> conjuncts = agg.child().getConjuncts();
if (conjuncts.isEmpty()) {
return false;
}
Set<Slot> aggSlots = funcs.stream()
.flatMap(f -> f.getInputSlots().stream())
.collect(Collectors.toSet());
return aggSlots.isEmpty() || conjuncts.stream().allMatch(expr ->
checkSlotInOrExpression(expr, aggSlots) && checkIsNullExpr(expr, aggSlots));
})
.thenApply(ctx -> {
LogicalAggregate<LogicalFilter<LogicalOlapScan>> agg = ctx.root;
LogicalFilter<LogicalOlapScan> filter = agg.child();
LogicalOlapScan olapScan = filter.child();
return pushdownCountOnIndex(agg, null, filter, olapScan, ctx.cascadesContext);
})
),
RuleType.COUNT_ON_INDEX.build(
logicalAggregate(
logicalProject(
logicalFilter(
logicalOlapScan().when(this::isDupOrMowKeyTable).when(this::isInvertedIndexEnabledOnTable)
)
)
)
.when(agg -> enablePushDownCountOnIndex())
.when(agg -> agg.getGroupByExpressions().isEmpty())
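                        // same gating as COUNT_ON_INDEX_WITHOUT_PROJECT, with a projection
                        // between the aggregate and the filter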
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
if (funcs.isEmpty() || !funcs.stream().allMatch(f -> f instanceof Count && !f.isDistinct()
&& (((Count) f).isCountStar() || f.child(0) instanceof Slot))) {
return false;
}
Set<Expression> conjuncts = agg.child().child().getConjuncts();
if (conjuncts.isEmpty()) {
return false;
}
Set<Slot> aggSlots = funcs.stream()
.flatMap(f -> f.getInputSlots().stream())
.collect(Collectors.toSet());
return aggSlots.isEmpty() || conjuncts.stream().allMatch(expr ->
checkSlotInOrExpression(expr, aggSlots) && checkIsNullExpr(expr, aggSlots));
})
.thenApply(ctx -> {
LogicalAggregate<LogicalProject<LogicalFilter<LogicalOlapScan>>> agg = ctx.root;
LogicalProject<LogicalFilter<LogicalOlapScan>> project = agg.child();
LogicalFilter<LogicalOlapScan> filter = project.child();
LogicalOlapScan olapScan = filter.child();
return pushdownCountOnIndex(agg, project, filter, olapScan, ctx.cascadesContext);
})
),
RuleType.STORAGE_LAYER_AGGREGATE_MINMAX_ON_UNIQUE_WITHOUT_PROJECT.build(
logicalAggregate(
logicalFilter(
logicalOlapScan().when(this::isUniqueKeyTable))
.when(filter -> {
if (filter.getConjuncts().size() != 1) {
return false;
}
Expression childExpr = filter.getConjuncts().iterator().next().children().get(0);
if (childExpr instanceof SlotReference) {
Optional<Column> column = ((SlotReference) childExpr).getOriginalColumn();
return column.map(Column::isDeleteSignColumn).orElse(false);
}
return false;
})
)
.when(agg -> enablePushDownMinMaxOnUnique())
.when(agg -> agg.getGroupByExpressions().isEmpty())
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
.allMatch(f -> (f instanceof Min) || (f instanceof Max));
})
.thenApply(ctx -> {
LogicalAggregate<LogicalFilter<LogicalOlapScan>> agg = ctx.root;
LogicalFilter<LogicalOlapScan> filter = agg.child();
LogicalOlapScan olapScan = filter.child();
return pushdownMinMaxOnUniqueTable(agg, null, filter, olapScan,
ctx.cascadesContext);
})
),
RuleType.STORAGE_LAYER_AGGREGATE_MINMAX_ON_UNIQUE.build(
logicalAggregate(logicalProject(logicalFilter(logicalOlapScan().when(this::isUniqueKeyTable))
.when(filter -> {
if (filter.getConjuncts().size() != 1) {
return false;
}
Expression childExpr = filter.getConjuncts().iterator().next()
.children().get(0);
if (childExpr instanceof SlotReference) {
Optional<Column> column = ((SlotReference) childExpr).getOriginalColumn();
return column.map(Column::isDeleteSignColumn).orElse(false);
}
return false;
})))
.when(agg -> enablePushDownMinMaxOnUnique())
.when(agg -> agg.getGroupByExpressions().isEmpty())
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty()
&& funcs.stream().allMatch(f -> (f instanceof Min) || (f instanceof Max));
})
.thenApply(ctx -> {
LogicalAggregate<LogicalProject<LogicalFilter<LogicalOlapScan>>> agg = ctx.root;
LogicalProject<LogicalFilter<LogicalOlapScan>> project = agg.child();
LogicalFilter<LogicalOlapScan> filter = project.child();
LogicalOlapScan olapScan = filter.child();
return pushdownMinMaxOnUniqueTable(agg, project, filter, olapScan,
ctx.cascadesContext);
})
),
RuleType.STORAGE_LAYER_AGGREGATE_WITHOUT_PROJECT.build(
logicalAggregate(
logicalOlapScan()
)
.when(agg -> agg.isNormalized() && enablePushDownNoGroupAgg())
.thenApply(ctx -> storageLayerAggregate(ctx.root, null, ctx.root.child(), ctx.cascadesContext))
),
RuleType.STORAGE_LAYER_WITH_PROJECT_NO_SLOT_REF.build(
logicalProject(
logicalOlapScan()
)
.thenApply(ctx -> {
LogicalProject<LogicalOlapScan> project = ctx.root;
LogicalOlapScan olapScan = project.child();
return pushDownCountWithoutSlotRef(project, olapScan, ctx.cascadesContext);
})
),
RuleType.STORAGE_LAYER_AGGREGATE_WITH_PROJECT.build(
logicalAggregate(
logicalProject(
logicalOlapScan()
)
)
.when(agg -> agg.isNormalized() && enablePushDownNoGroupAgg())
.thenApply(ctx -> {
LogicalAggregate<LogicalProject<LogicalOlapScan>> agg = ctx.root;
LogicalProject<LogicalOlapScan> project = agg.child();
LogicalOlapScan olapScan = project.child();
return storageLayerAggregate(agg, project, olapScan, ctx.cascadesContext);
})
),
RuleType.STORAGE_LAYER_AGGREGATE_WITHOUT_PROJECT_FOR_FILE_SCAN.build(
logicalAggregate(
logicalFileScan()
)
.when(agg -> agg.isNormalized() && enablePushDownNoGroupAgg())
.thenApply(ctx -> storageLayerAggregate(ctx.root, null, ctx.root.child(), ctx.cascadesContext))
),
RuleType.STORAGE_LAYER_AGGREGATE_WITH_PROJECT_FOR_FILE_SCAN.build(
logicalAggregate(
logicalProject(
logicalFileScan()
)
).when(agg -> agg.isNormalized() && enablePushDownNoGroupAgg())
.thenApply(ctx -> {
LogicalAggregate<LogicalProject<LogicalFileScan>> agg = ctx.root;
LogicalProject<LogicalFileScan> project = agg.child();
LogicalFileScan fileScan = project.child();
return storageLayerAggregate(agg, project, fileScan, ctx.cascadesContext);
})
),
RuleType.ONE_PHASE_AGGREGATE_WITHOUT_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().isEmpty())
.when(agg -> agg.supportAggregatePhase(AggregatePhase.ONE))
.thenApplyMulti(ctx -> onePhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext))
),
RuleType.TWO_PHASE_AGGREGATE_WITHOUT_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().isEmpty())
.when(agg -> agg.supportAggregatePhase(AggregatePhase.TWO))
.thenApplyMulti(ctx -> twoPhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext))
),
// RuleType.TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build(
// basePattern
// .when(this::containsCountDistinctMultiExpr)
// .when(agg -> agg.supportAggregatePhase(AggregatePhase.TWO))
// .thenApplyMulti(ctx -> twoPhaseAggregateWithCountDistinctMulti(ctx.root, ctx.cascadesContext))
// ),
RuleType.THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build(
basePattern
.when(this::containsCountDistinctMultiExpr)
.when(agg -> agg.supportAggregatePhase(AggregatePhase.THREE))
.thenApplyMulti(ctx -> threePhaseAggregateWithCountDistinctMulti(ctx.root, ctx.cascadesContext))
),
RuleType.ONE_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1 && couldConvertToMulti(agg))
.when(agg -> agg.supportAggregatePhase(AggregatePhase.ONE))
.thenApplyMulti(ctx -> onePhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext))
),
RuleType.TWO_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1 && couldConvertToMulti(agg))
.when(agg -> agg.supportAggregatePhase(AggregatePhase.TWO))
.thenApplyMulti(ctx -> twoPhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext))
),
RuleType.TWO_PHASE_AGGREGATE_WITH_MULTI_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() > 1
&& !containsCountDistinctMultiExpr(agg)
&& couldConvertToMulti(agg))
.when(agg -> agg.supportAggregatePhase(AggregatePhase.TWO))
.thenApplyMulti(ctx -> twoPhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext))
),
// RuleType.TWO_PHASE_AGGREGATE_WITH_DISTINCT.build(
// basePattern
// .when(agg -> agg.getDistinctArguments().size() == 1)
// .when(agg -> agg.supportAggregatePhase(AggregatePhase.TWO))
// .thenApplyMulti(ctx -> twoPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
// ),
RuleType.THREE_PHASE_AGGREGATE_WITH_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
.whenNot(agg -> agg.mustUseMultiDistinctAgg())
.when(agg -> agg.supportAggregatePhase(AggregatePhase.THREE))
.thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
),
/*
* sql:
* select count(distinct name), sum(age) from student;
* <p>
* 4 phase plan
* DISTINCT_GLOBAL(BUFFER_TO_RESULT, groupBy(),
* output[count(partial_count(name)), sum(partial_sum(partial_sum(age)))],
* GATHER)
* +--DISTINCT_LOCAL(INPUT_TO_BUFFER, groupBy(),
* output(partial_count(name), partial_sum(partial_sum(age))),
* hash distribute by name)
* +--GLOBAL(BUFFER_TO_BUFFER, groupBy(name),
* output(name, partial_sum(age)),
            * hash distribute by name)
* +--LOCAL(INPUT_TO_BUFFER, groupBy(name), output(name, partial_sum(age)))
* +--scan(name, age)
*/
RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
.when(agg -> agg.getGroupByExpressions().isEmpty())
.whenNot(agg -> agg.mustUseMultiDistinctAgg())
.when(agg -> agg.supportAggregatePhase(AggregatePhase.FOUR))
.thenApplyMulti(ctx -> {
Function<List<Expression>, RequireProperties> secondPhaseRequireDistinctHash =
groupByAndDistinct -> RequireProperties.of(
PhysicalProperties.createHash(
ctx.root.getDistinctArguments(), ShuffleType.REQUIRE
)
);
Function<LogicalAggregate<? extends Plan>, RequireProperties> fourPhaseRequireGather =
agg -> RequireProperties.of(PhysicalProperties.GATHER);
return fourPhaseAggregateWithDistinct(
ctx.root, ctx.connectContext,
secondPhaseRequireDistinctHash, fourPhaseRequireGather
);
})
),
/*
* sql:
* select age, count(distinct name) from student group by age;
* <p>
* 4 phase plan
* DISTINCT_GLOBAL(BUFFER_TO_RESULT, groupBy(age),
* output[age, sum(partial_count(name))],
            * hash distribute by age)
* +--DISTINCT_LOCAL(INPUT_TO_BUFFER, groupBy(age),
* output(age, partial_count(name)),
* hash distribute by age, name)
* +--GLOBAL(BUFFER_TO_BUFFER, groupBy(age, name),
* output(age, name),
            * hash distribute by age, name)
* +--LOCAL(INPUT_TO_BUFFER, groupBy(age, name), output(age, name))
* +--scan(age, name)
*/
RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT_WITH_FULL_DISTRIBUTE.build(
basePattern
.when(agg -> agg.everyDistinctArgumentNumIsOne() && !agg.getGroupByExpressions().isEmpty())
.when(agg ->
ImmutableSet.builder()
.addAll(agg.getGroupByExpressions())
.addAll(agg.getDistinctArguments())
.build().size() > agg.getGroupByExpressions().size()
)
.when(agg -> {
if (agg.getDistinctArguments().size() == 1) {
return true;
}
return couldConvertToMulti(agg);
})
.when(agg -> agg.supportAggregatePhase(AggregatePhase.FOUR))
.whenNot(Aggregate::mustUseMultiDistinctAgg)
.thenApplyMulti(ctx -> {
Function<List<Expression>, RequireProperties> secondPhaseRequireGroupByAndDistinctHash =
groupByAndDistinct -> RequireProperties.of(
PhysicalProperties.createHash(groupByAndDistinct, ShuffleType.REQUIRE)
);
Function<LogicalAggregate<? extends Plan>, RequireProperties> fourPhaseRequireGroupByHash =
agg -> RequireProperties.of(
PhysicalProperties.createHash(
agg.getGroupByExpressions(), ShuffleType.REQUIRE
)
);
return fourPhaseAggregateWithDistinct(
ctx.root, ctx.connectContext,
secondPhaseRequireGroupByAndDistinctHash, fourPhaseRequireGroupByHash
);
})
)
);
}
/*
     * select 66 from baseall_dup; can use pushAggOp=COUNT to avoid scanning real data.
*/
private LogicalProject<? extends Plan> pushDownCountWithoutSlotRef(
LogicalProject<? extends Plan> project,
LogicalOlapScan logicalScan,
CascadesContext cascadesContext) {
final LogicalProject<? extends Plan> canNotPush = project;
if (!enablePushDownNoGroupAgg()) {
return canNotPush;
}
if (logicalScan != null) {
KeysType keysType = logicalScan.getTable().getKeysType();
if (keysType != KeysType.DUP_KEYS) {
return canNotPush;
}
}
for (Expression e : project.getProjects()) {
if (e.anyMatch(SlotReference.class::isInstance)) {
return canNotPush;
}
}
PhysicalOlapScan physicalOlapScan
= (PhysicalOlapScan) new LogicalOlapScanToPhysicalOlapScan()
.build()
.transform(logicalScan, cascadesContext)
.get(0);
return project.withChildren(ImmutableList.of(new PhysicalStorageLayerAggregate(
physicalOlapScan, PushDownAggOp.COUNT)));
}
private boolean enablePushDownMinMaxOnUnique() {
ConnectContext connectContext = ConnectContext.get();
return connectContext != null && connectContext.getSessionVariable().isEnablePushDownMinMaxOnUnique();
}
private boolean isUniqueKeyTable(LogicalOlapScan logicalScan) {
if (logicalScan != null) {
KeysType keysType = logicalScan.getTable().getKeysType();
return keysType == KeysType.UNIQUE_KEYS;
}
return false;
}
private boolean enablePushDownCountOnIndex() {
ConnectContext connectContext = ConnectContext.get();
return connectContext != null && connectContext.getSessionVariable().isEnablePushDownCountOnIndex();
}
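    /**
     * Used by the count-on-index rules above: every slot referenced under an OR predicate
     * must be an input slot of some aggregate function; non-OR expressions are checked
     * recursively.
     */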
private boolean checkSlotInOrExpression(Expression expr, Set<Slot> aggSlots) {
if (expr instanceof Or) {
Set<Slot> slots = expr.getInputSlots();
if (!slots.stream().allMatch(aggSlots::contains)) {
return false;
}
} else {
for (Expression child : expr.children()) {
if (!checkSlotInOrExpression(child, aggSlots)) {
return false;
}
}
}
return true;
}
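    /**
     * Used by the count-on-index rules above: reject the pushdown when an IS NULL predicate
     * references a counted slot, since count(col) must not count NULL values.
     */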
private boolean checkIsNullExpr(Expression expr, Set<Slot> aggSlots) {
if (expr instanceof IsNull) {
Set<Slot> slots = expr.getInputSlots();
if (slots.stream().anyMatch(aggSlots::contains)) {
return false;
}
} else {
for (Expression child : expr.children()) {
if (!checkIsNullExpr(child, aggSlots)) {
return false;
}
}
}
return true;
}
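    // true for duplicate-key tables and for merge-on-write unique-key tables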
private boolean isDupOrMowKeyTable(LogicalOlapScan logicalScan) {
if (logicalScan != null) {
KeysType keysType = logicalScan.getTable().getKeysType();
return (keysType == KeysType.DUP_KEYS)
|| (keysType == KeysType.UNIQUE_KEYS && logicalScan.getTable().getEnableUniqueKeyMergeOnWrite());
}
return false;
}
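    // true if any materialized index of the table carries an INVERTED or BITMAP index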
private boolean isInvertedIndexEnabledOnTable(LogicalOlapScan logicalScan) {
if (logicalScan == null) {
return false;
}
OlapTable olapTable = logicalScan.getTable();
Map<Long, MaterializedIndexMeta> indexIdToMeta = olapTable.getIndexIdToMeta();
return indexIdToMeta.values().stream()
.anyMatch(indexMeta -> indexMeta.getIndexes().stream()
.anyMatch(index -> index.getIndexType() == IndexType.INVERTED
|| index.getIndexType() == IndexType.BITMAP));
}
/**
* sql: select count(*) from tbl where column match 'token'
* <p>
* before:
* <p>
* LogicalAggregate(groupBy=[], output=[count(*)])
* |
* LogicalFilter(column match 'token')
* |
* LogicalOlapScan(table=tbl)
* <p>
* after:
* <p>
* LogicalAggregate(groupBy=[], output=[count(*)])
* |
* LogicalFilter(column match 'token')
* |
     * PhysicalStorageLayerAggregate(pushAggOp=COUNT_ON_MATCH, table=PhysicalOlapScan(table=tbl))
*
*/
private LogicalAggregate<? extends Plan> pushdownCountOnIndex(
LogicalAggregate<? extends Plan> agg,
@Nullable LogicalProject<? extends Plan> project,
LogicalFilter<? extends Plan> filter,
LogicalOlapScan olapScan,
CascadesContext cascadesContext) {
PhysicalOlapScan physicalOlapScan = (PhysicalOlapScan) new LogicalOlapScanToPhysicalOlapScan()
.build()
.transform(olapScan, cascadesContext)
.get(0);
List<Expression> argumentsOfAggregateFunction = normalizeArguments(agg.getAggregateFunctions(), project);
if (!onlyContainsSlot(argumentsOfAggregateFunction)) {
return agg;
}
return agg.withChildren(ImmutableList.of(
project != null
? project.withChildren(ImmutableList.of(
filter.withChildren(ImmutableList.of(
new PhysicalStorageLayerAggregate(
physicalOlapScan, PushDownAggOp.COUNT_ON_MATCH)))))
: filter.withChildren(ImmutableList.of(
new PhysicalStorageLayerAggregate(
physicalOlapScan, PushDownAggOp.COUNT_ON_MATCH)))
));
}
private List<Expression> normalizeArguments(Set<AggregateFunction> aggregateFunctions,
@Nullable LogicalProject<? extends Plan> project) {
List<Expression> arguments = aggregateFunctions.stream()
.flatMap(aggregateFunction -> aggregateFunction.getArguments().stream())
.collect(ImmutableList.toImmutableList());
if (project != null) {
arguments = Project.findProject(arguments, project.getProjects())
.stream()
.map(p -> p instanceof Alias ? p.child(0) : p)
.collect(ImmutableList.toImmutableList());
}
return arguments;
}
private boolean onlyContainsSlot(List<Expression> arguments) {
return arguments.stream().allMatch(argument -> {
if (argument instanceof SlotReference) {
return true;
}
return false;
});
}
    // example: select /*+SET_VAR(enable_pushdown_minmax_on_unique=true) */ min(user_id) from table_unique;
    // pushes pushAggOp=MIN_MAX down to the scan node
private LogicalAggregate<? extends Plan> pushdownMinMaxOnUniqueTable(
LogicalAggregate<? extends Plan> aggregate,
@Nullable LogicalProject<? extends Plan> project,
LogicalFilter<? extends Plan> filter,
LogicalOlapScan olapScan,
CascadesContext cascadesContext) {
final LogicalAggregate<? extends Plan> canNotPush = aggregate;
Set<AggregateFunction> aggregateFunctions = aggregate.getAggregateFunctions();
if (checkWhetherPushDownMinMax(aggregateFunctions, project, olapScan.getOutput())) {
PhysicalOlapScan physicalOlapScan = (PhysicalOlapScan) new LogicalOlapScanToPhysicalOlapScan()
.build()
.transform(olapScan, cascadesContext)
.get(0);
if (project != null) {
return aggregate.withChildren(ImmutableList.of(
project.withChildren(ImmutableList.of(
filter.withChildren(ImmutableList.of(
new PhysicalStorageLayerAggregate(
physicalOlapScan,
PushDownAggOp.MIN_MAX)))))));
} else {
return aggregate.withChildren(ImmutableList.of(
filter.withChildren(ImmutableList.of(
new PhysicalStorageLayerAggregate(
physicalOlapScan,
PushDownAggOp.MIN_MAX)))));
}
} else {
return canNotPush;
}
}
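    /**
     * min/max can be pushed to the storage layer only when every aggregate argument is a bare
     * slot (possibly resolved through the projection) and none of the referenced columns has a
     * complex, HLL or bitmap type.
     */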
private boolean checkWhetherPushDownMinMax(Set<AggregateFunction> aggregateFunctions,
@Nullable LogicalProject<? extends Plan> project, List<Slot> outPutSlots) {
        // unlike storageLayerAggregate, min/max pushdown only accepts bare slot arguments
        boolean onlyContainsSlot = aggregateFunctions.stream()
                .map(ExpressionTrait::getArguments)
                .flatMap(List::stream)
                .allMatch(argument -> argument instanceof SlotReference);
        if (!onlyContainsSlot) {
            return false;
        }
List<Expression> argumentsOfAggregateFunction = aggregateFunctions.stream()
.flatMap(aggregateFunction -> aggregateFunction.getArguments().stream())
.collect(ImmutableList.toImmutableList());
if (project != null) {
argumentsOfAggregateFunction = Project.findProject(
argumentsOfAggregateFunction, project.getProjects())
.stream()
.map(p -> p instanceof Alias ? p.child(0) : p)
.collect(ImmutableList.toImmutableList());
}
        onlyContainsSlot = argumentsOfAggregateFunction
                .stream()
                .allMatch(argument -> argument instanceof SlotReference);
        if (!onlyContainsSlot) {
            return false;
        }
Set<SlotReference> aggUsedSlots = ExpressionUtils.collect(argumentsOfAggregateFunction,
SlotReference.class::isInstance);
List<SlotReference> usedSlotInTable = (List<SlotReference>) Project.findProject(aggUsedSlots, outPutSlots);
for (SlotReference slot : usedSlotInTable) {
Column column = slot.getOriginalColumn().get();
PrimitiveType colType = column.getType().getPrimitiveType();
if (colType.isComplexType() || colType.isHllType() || colType.isBitmapType()) {
return false;
}
}
return true;
}
/**
* sql: select count(*) from tbl
* <p>
* before:
* <p>
* LogicalAggregate(groupBy=[], output=[count(*)])
* |
* LogicalOlapScan(table=tbl)
* <p>
* after:
* <p>
* LogicalAggregate(groupBy=[], output=[count(*)])
* |
* PhysicalStorageLayerAggregate(pushAggOp=COUNT, table=PhysicalOlapScan(table=tbl))
*
*/
private LogicalAggregate<? extends Plan> storageLayerAggregate(
LogicalAggregate<? extends Plan> aggregate,
@Nullable LogicalProject<? extends Plan> project,
LogicalRelation logicalScan, CascadesContext cascadesContext) {
final LogicalAggregate<? extends Plan> canNotPush = aggregate;
if (!(logicalScan instanceof LogicalOlapScan) && !(logicalScan instanceof LogicalFileScan)) {
return canNotPush;
}
if (logicalScan instanceof LogicalOlapScan) {
KeysType keysType = ((LogicalOlapScan) logicalScan).getTable().getKeysType();
if (keysType != KeysType.AGG_KEYS && keysType != KeysType.DUP_KEYS) {
return canNotPush;
}
}
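        // the storage layer can only answer aggregates with no group by and no distinct arguments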
List<Expression> groupByExpressions = aggregate.getGroupByExpressions();
if (!groupByExpressions.isEmpty() || !aggregate.getDistinctArguments().isEmpty()) {
return canNotPush;
}
Set<AggregateFunction> aggregateFunctions = aggregate.getAggregateFunctions();
Set<Class<? extends AggregateFunction>> functionClasses = aggregateFunctions
.stream()
.map(AggregateFunction::getClass)
.collect(Collectors.toSet());
Map<Class<? extends AggregateFunction>, PushDownAggOp> supportedAgg = PushDownAggOp.supportedFunctions();
if (!supportedAgg.keySet().containsAll(functionClasses)) {
return canNotPush;
}
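        // count pushdown on an OLAP table additionally requires duplicate keys
        // and must not read a materialized view directly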
if (logicalScan instanceof LogicalOlapScan) {
LogicalOlapScan logicalOlapScan = (LogicalOlapScan) logicalScan;
KeysType keysType = logicalOlapScan.getTable().getKeysType();
if (functionClasses.contains(Count.class) && keysType != KeysType.DUP_KEYS) {
return canNotPush;
}
if (functionClasses.contains(Count.class) && logicalOlapScan.isDirectMvScan()) {
return canNotPush;
}
}
if (aggregateFunctions.stream().anyMatch(fun -> fun.arity() > 1)) {
return canNotPush;
}
// TODO: refactor this to process slot reference or expression together
boolean onlyContainsSlotOrNumericCastSlot = aggregateFunctions.stream()
.map(ExpressionTrait::getArguments)
.flatMap(List::stream)
.allMatch(argument -> {
if (argument instanceof SlotReference) {
return true;
}
if (argument instanceof Cast) {
return argument.child(0) instanceof SlotReference
&& argument.getDataType().isNumericType()
&& argument.child(0).getDataType().isNumericType();
}
return false;
});
if (!onlyContainsSlotOrNumericCastSlot) {
return canNotPush;
}
        // the arguments have already been normalized to SlotReference at this point
List<Expression> argumentsOfAggregateFunction = aggregateFunctions.stream()
.flatMap(aggregateFunction -> aggregateFunction.getArguments().stream())
.collect(ImmutableList.toImmutableList());
if (project != null) {
argumentsOfAggregateFunction = Project.findProject(
argumentsOfAggregateFunction, project.getProjects())
.stream()
.map(p -> p instanceof Alias ? p.child(0) : p)
.collect(ImmutableList.toImmutableList());
}
onlyContainsSlotOrNumericCastSlot = argumentsOfAggregateFunction
.stream()
.allMatch(argument -> {
if (argument instanceof SlotReference) {
return true;
}
if (argument instanceof Cast) {
return argument.child(0) instanceof SlotReference
&& argument.getDataType().isNumericType()
&& argument.child(0).getDataType().isNumericType();
}
return false;
});
if (!onlyContainsSlotOrNumericCastSlot) {
return canNotPush;
}
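        // map each aggregate function class to its storage-layer op; mixed classes become MIX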
Set<PushDownAggOp> pushDownAggOps = functionClasses.stream()
.map(supportedAgg::get)
.collect(Collectors.toSet());
PushDownAggOp mergeOp = pushDownAggOps.size() == 1
? pushDownAggOps.iterator().next()
: PushDownAggOp.MIX;
Set<SlotReference> aggUsedSlots =
ExpressionUtils.collect(argumentsOfAggregateFunction, SlotReference.class::isInstance);
List<SlotReference> usedSlotInTable = (List<SlotReference>) Project.findProject(aggUsedSlots,
logicalScan.getOutput());
for (SlotReference slot : usedSlotInTable) {
Column column = slot.getOriginalColumn().get();
if (column.isAggregated()) {
return canNotPush;
}
            // The zone map max length for CHAR-family columns is 512; do not
            // exceed that length: https://github.com/apache/doris/pull/6293
if (mergeOp == PushDownAggOp.MIN_MAX || mergeOp == PushDownAggOp.MIX) {
PrimitiveType colType = column.getType().getPrimitiveType();
if (colType.isComplexType() || colType.isHllType() || colType.isBitmapType()
|| (colType == PrimitiveType.STRING && !enablePushDownStringMinMax())) {
return canNotPush;
}
if (colType.isCharFamily() && column.getType().getLength() > 512 && !enablePushDownStringMinMax()) {
return canNotPush;
}
}
if (mergeOp == PushDownAggOp.COUNT || mergeOp == PushDownAggOp.MIX) {
                // count(col) skips NULL values, so the row count can only answer
                // the query when the column is NOT NULL
if (column.isAllowNull()) {
return canNotPush;
}
}
}
if (logicalScan instanceof LogicalOlapScan) {
PhysicalOlapScan physicalScan = (PhysicalOlapScan) new LogicalOlapScanToPhysicalOlapScan()
.build()
.transform(logicalScan, cascadesContext)
.get(0);
if (project != null) {
return aggregate.withChildren(ImmutableList.of(
project.withChildren(
ImmutableList.of(new PhysicalStorageLayerAggregate(physicalScan, mergeOp)))
));
} else {
return aggregate.withChildren(ImmutableList.of(
new PhysicalStorageLayerAggregate(physicalScan, mergeOp)
));
}
} else if (logicalScan instanceof LogicalFileScan) {
Rule rule = (logicalScan instanceof LogicalHudiScan) ? new LogicalHudiScanToPhysicalHudiScan().build()
: new LogicalFileScanToPhysicalFileScan().build();
PhysicalFileScan physicalScan = (PhysicalFileScan) rule.transform(logicalScan, cascadesContext)
.get(0);
if (project != null) {
return aggregate.withChildren(ImmutableList.of(
project.withChildren(
ImmutableList.of(new PhysicalStorageLayerAggregate(physicalScan, mergeOp)))
));
} else {
return aggregate.withChildren(ImmutableList.of(
new PhysicalStorageLayerAggregate(physicalScan, mergeOp)
));
}
} else {
return canNotPush;
}
}
private boolean enablePushDownStringMinMax() {
ConnectContext connectContext = ConnectContext.get();
return connectContext != null && connectContext.getSessionVariable().isEnablePushDownStringMinMax();
}
/**
* sql: select count(*) from tbl group by id
* <p>
* before:
* <p>
* LogicalAggregate(groupBy=[id], output=[count(*)])
* |
* LogicalOlapScan(table=tbl)
* <p>
* after:
* <p>
* single node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[id], output=[count(*)])
* |
* PhysicalDistribute(distributionSpec=GATHER)
* |
* LogicalOlapScan(table=tbl)
* <p>
* distribute node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[id], output=[count(*)])
* |
     * LogicalOlapScan(table=tbl, **already distributed by id**)
*
*/
private List<PhysicalHashAggregate<Plan>> onePhaseAggregateWithoutDistinct(
LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext) {
RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER);
AggregateParam inputToResultParam = AggregateParam.LOCAL_RESULT;
List<NamedExpression> newOutput = ExpressionUtils.rewriteDownShortCircuit(
logicalAgg.getOutputExpressions(), outputChild -> {
if (outputChild instanceof AggregateFunction) {
return new AggregateExpression((AggregateFunction) outputChild, inputToResultParam);
}
return outputChild;
});
PhysicalHashAggregate<Plan> gatherLocalAgg = new PhysicalHashAggregate<>(
logicalAgg.getGroupByExpressions(), newOutput, Optional.empty(),
inputToResultParam, false,
logicalAgg.getLogicalProperties(),
requireGather, logicalAgg.child());
if (logicalAgg.getGroupByExpressions().isEmpty()) {
            // TODO: usually bad, disable it until we can do better cost computation.
// return ImmutableList.of(gatherLocalAgg);
return ImmutableList.of();
} else {
RequireProperties requireHash = RequireProperties.of(
PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE));
PhysicalHashAggregate<Plan> hashLocalAgg = gatherLocalAgg
.withRequire(requireHash)
.withPartitionExpressions(logicalAgg.getGroupByExpressions());
return ImmutableList.<PhysicalHashAggregate<Plan>>builder()
                    // TODO: usually bad, disable it until we can do better cost computation.
//.add(gatherLocalAgg)
.add(hashLocalAgg)
.build();
}
}
/**
* sql: select count(distinct id, name) from tbl group by name
* <p>
* before:
* <p>
* LogicalAggregate(groupBy=[name], output=[count(distinct id, name)])
* |
* LogicalOlapScan(table=tbl)
* <p>
* after:
* <p>
* single node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, null, name))])
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id])
* |
* PhysicalDistribute(distributionSpec=GATHER)
* |
* LogicalOlapScan(table=tbl)
* <p>
* distribute node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, null, name))])
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id])
* |
* PhysicalDistribute(distributionSpec=HASH(name))
* |
     * LogicalOlapScan(table=tbl, **already distributed by name**)
*
*/
private List<PhysicalHashAggregate<Plan>> twoPhaseAggregateWithCountDistinctMulti(
LogicalAggregate<? extends Plan> logicalAgg, CascadesContext cascadesContext) {
AggregateParam inputToBufferParam = AggregateParam.LOCAL_BUFFER;
Collection<Expression> countDistinctArguments = logicalAgg.getDistinctArguments();
List<Expression> localAggGroupBy = ImmutableList.copyOf(ImmutableSet.<Expression>builder()
.addAll(logicalAgg.getGroupByExpressions())
.addAll(countDistinctArguments)
.build());
Set<AggregateFunction> aggregateFunctions = logicalAgg.getAggregateFunctions();
Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase1 = aggregateFunctions.stream()
.filter(aggregateFunction -> !aggregateFunction.isDistinct())
.collect(ImmutableMap.toImmutableMap(expr -> expr, expr -> {
AggregateExpression localAggExpr = new AggregateExpression(expr, inputToBufferParam);
return new Alias(localAggExpr);
}));
List<Expression> partitionExpressions = getHashAggregatePartitionExpressions(logicalAgg);
RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER);
List<NamedExpression> localOutput = ImmutableList.<NamedExpression>builder()
.addAll((List<NamedExpression>) (List) localAggGroupBy.stream()
.filter(g -> !(g instanceof Literal))
.collect(ImmutableList.toImmutableList()))
.addAll(nonDistinctAggFunctionToAliasPhase1.values())
.build();
PhysicalHashAggregate<Plan> gatherLocalAgg = new PhysicalHashAggregate<>(
localAggGroupBy, localOutput, Optional.of(partitionExpressions),
new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER),
maybeUsingStreamAgg(cascadesContext.getConnectContext(), logicalAgg),
logicalAgg.getLogicalProperties(), requireGather, logicalAgg.child()
);
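        // rewrite count(distinct a, b, ...) into a count(if(...)) over the rows the first
        // phase already deduplicated, as shown in the javadoc plan shape above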
List<Expression> distinctGroupBy = logicalAgg.getGroupByExpressions();
LogicalAggregate<? extends Plan> countIfAgg = countDistinctMultiExprToCountIf(
logicalAgg, cascadesContext).first;
AggregateParam distinctInputToResultParam
= new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_RESULT);
AggregateParam globalBufferToResultParam
= new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT);
List<NamedExpression> distinctOutput = ExpressionUtils.rewriteDownShortCircuit(
countIfAgg.getOutputExpressions(), outputChild -> {
if (outputChild instanceof AggregateFunction) {
AggregateFunction aggregateFunction = (AggregateFunction) outputChild;
Alias alias = nonDistinctAggFunctionToAliasPhase1.get(aggregateFunction);
if (alias == null) {
return new AggregateExpression(aggregateFunction, distinctInputToResultParam);
} else {
return new AggregateExpression(aggregateFunction,
globalBufferToResultParam, alias.toSlot());
}
} else {
return outputChild;
}
});
PhysicalHashAggregate<Plan> gatherLocalGatherDistinctAgg = new PhysicalHashAggregate<>(
distinctGroupBy, distinctOutput, Optional.of(partitionExpressions),
distinctInputToResultParam, false,
logicalAgg.getLogicalProperties(), requireGather, gatherLocalAgg
);
if (logicalAgg.getGroupByExpressions().isEmpty()) {
return ImmutableList.of(gatherLocalGatherDistinctAgg);
} else {
RequireProperties requireHash = RequireProperties.of(
PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE));
PhysicalHashAggregate<Plan> hashLocalHashGlobalAgg = gatherLocalGatherDistinctAgg
.withRequireTree(requireHash.withChildren(requireHash))
.withPartitionExpressions(logicalAgg.getGroupByExpressions());
return ImmutableList.<PhysicalHashAggregate<Plan>>builder()
                    // TODO: usually bad, disable it until we can do better cost computation.
//.add(gatherLocalGatherDistinctAgg)
.add(hashLocalHashGlobalAgg)
.build();
}
}
/**
* sql: select count(distinct id, name) from tbl group by name
* <p>
* before:
* <p>
* LogicalAggregate(groupBy=[name], output=[count(distinct id, name)])
* |
* LogicalOlapScan(table=tbl)
* <p>
* after:
* <p>
* single node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, null, name))])
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER)
* |
* PhysicalDistribute(distributionSpec=GATHER)
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER)
* |
* LogicalOlapScan(table=tbl)
* <p>
* distribute node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, null, name))])
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER)
* |
* PhysicalDistribute(distributionSpec=HASH(name))
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER)
* |
* LogicalOlapScan(table=tbl)
*
*/
private List<PhysicalHashAggregate<? extends Plan>> threePhaseAggregateWithCountDistinctMulti(
LogicalAggregate<? extends Plan> logicalAgg, CascadesContext cascadesContext) {
AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER);
Collection<Expression> countDistinctArguments = logicalAgg.getDistinctArguments();
List<Expression> localAggGroupBy = ImmutableList.copyOf(ImmutableSet.<Expression>builder()
.addAll(logicalAgg.getGroupByExpressions())
.addAll(countDistinctArguments)
.build());
Set<AggregateFunction> aggregateFunctions = logicalAgg.getAggregateFunctions();
Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase1 = aggregateFunctions.stream()
.filter(aggregateFunction -> !aggregateFunction.isDistinct())
.collect(ImmutableMap.toImmutableMap(expr -> expr, expr -> {
AggregateExpression localAggExpr = new AggregateExpression(expr, inputToBufferParam);
return new Alias(localAggExpr);
}));
List<Expression> partitionExpressions = getHashAggregatePartitionExpressions(logicalAgg);
RequireProperties requireAny = RequireProperties.of(PhysicalProperties.ANY);
List<NamedExpression> localOutput = ImmutableList.<NamedExpression>builder()
.addAll((List<NamedExpression>) (List) localAggGroupBy.stream()
.filter(g -> !(g instanceof Literal))
.collect(ImmutableList.toImmutableList()))
.addAll(nonDistinctAggFunctionToAliasPhase1.values())
.build();
PhysicalHashAggregate<Plan> anyLocalAgg = new PhysicalHashAggregate<>(
localAggGroupBy, localOutput, Optional.of(partitionExpressions),
new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER),
maybeUsingStreamAgg(cascadesContext.getConnectContext(), logicalAgg),
logicalAgg.getLogicalProperties(), requireAny, logicalAgg.child()
);
List<Expression> globalAggGroupBy = localAggGroupBy;
boolean hasCountDistinctMulti = logicalAgg.getAggregateFunctions().stream()
.filter(AggregateFunction::isDistinct)
.filter(Count.class::isInstance)
.anyMatch(c -> c.arity() > 1);
AggregateParam bufferToBufferParam = new AggregateParam(
AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER, !hasCountDistinctMulti);
Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase2 =
nonDistinctAggFunctionToAliasPhase1.entrySet()
.stream()
.collect(ImmutableMap.toImmutableMap(Entry::getKey, kv -> {
AggregateFunction originFunction = kv.getKey();
Alias localOutputAlias = kv.getValue();
AggregateExpression globalAggExpr = new AggregateExpression(
originFunction, bufferToBufferParam, localOutputAlias.toSlot());
return new Alias(globalAggExpr);
}));
Set<SlotReference> slotInCountDistinct = ExpressionUtils.collect(
ImmutableList.copyOf(countDistinctArguments), SlotReference.class::isInstance);
List<NamedExpression> globalAggOutput = ImmutableList.copyOf(ImmutableSet.<NamedExpression>builder()
.addAll((List<NamedExpression>) (List) logicalAgg.getGroupByExpressions())
.addAll(slotInCountDistinct)
.addAll(nonDistinctAggFunctionToAliasPhase2.values())
.build());
RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER);
PhysicalHashAggregate<Plan> anyLocalGatherGlobalAgg = new PhysicalHashAggregate<>(
globalAggGroupBy, globalAggOutput, Optional.of(partitionExpressions),
bufferToBufferParam, false, logicalAgg.getLogicalProperties(),
requireGather, anyLocalAgg);
LogicalAggregate<? extends Plan> countIfAgg = countDistinctMultiExprToCountIf(
logicalAgg, cascadesContext).first;
AggregateParam distinctInputToResultParam
= new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_RESULT, !hasCountDistinctMulti);
AggregateParam globalBufferToResultParam
= new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT);
List<NamedExpression> distinctOutput = ExpressionUtils.rewriteDownShortCircuit(
countIfAgg.getOutputExpressions(), outputChild -> {
if (outputChild instanceof AggregateFunction) {
AggregateFunction aggregateFunction = (AggregateFunction) outputChild;
Alias alias = nonDistinctAggFunctionToAliasPhase2.get(aggregateFunction);
if (alias == null) {
return new AggregateExpression(aggregateFunction, distinctInputToResultParam);
} else {
return new AggregateExpression(aggregateFunction,
globalBufferToResultParam, alias.toSlot());
}
} else {
return outputChild;
}
});
PhysicalHashAggregate<Plan> anyLocalGatherGlobalGatherAgg = new PhysicalHashAggregate<>(
logicalAgg.getGroupByExpressions(), distinctOutput, Optional.empty(),
distinctInputToResultParam, false,
logicalAgg.getLogicalProperties(), requireGather, anyLocalGatherGlobalAgg
);
// RequireProperties requireDistinctHash = RequireProperties.of(
// PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE));
// PhysicalHashAggregate<? extends Plan> anyLocalHashGlobalGatherDistinctAgg
// = anyLocalGatherGlobalGatherAgg.withChildren(ImmutableList.of(
// anyLocalGatherGlobalAgg
// .withRequire(requireDistinctHash)
// .withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments()))
// ));
if (logicalAgg.getGroupByExpressions().isEmpty()) {
return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
.add(anyLocalGatherGlobalGatherAgg)
//.add(anyLocalHashGlobalGatherDistinctAgg)
.build();
} else {
RequireProperties requireGroupByHash = RequireProperties.of(
PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE));
PhysicalHashAggregate<PhysicalHashAggregate<Plan>> anyLocalHashGlobalHashDistinctAgg
= anyLocalGatherGlobalGatherAgg.withRequirePropertiesAndChild(requireGroupByHash,
anyLocalGatherGlobalAgg
.withRequire(requireGroupByHash)
.withPartitionExpressions(logicalAgg.getGroupByExpressions())
)
.withPartitionExpressions(logicalAgg.getGroupByExpressions());
return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
// .add(anyLocalGatherGlobalGatherAgg)
// .add(anyLocalHashGlobalGatherDistinctAgg)
.add(anyLocalHashGlobalHashDistinctAgg)
.build();
}
}
/**
* sql: select name, count(value) from tbl group by name
* <p>
* before:
* <p>
* LogicalAggregate(groupBy=[name], output=[name, count(value)])
* |
* LogicalOlapScan(table=tbl)
* <p>
* after:
* <p>
* single node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, count(value)], mode=BUFFER_TO_RESULT)
* |
* PhysicalDistribute(distributionSpec=GATHER)
* |
* PhysicalHashAggregate(groupBy=[name], output=[name, count(value)], mode=INPUT_TO_BUFFER)
* |
* LogicalOlapScan(table=tbl)
* <p>
* distribute node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, count(value)], mode=BUFFER_TO_RESULT)
* |
* PhysicalDistribute(distributionSpec=HASH(name))
* |
* PhysicalHashAggregate(groupBy=[name], output=[name, count(value)], mode=INPUT_TO_BUFFER)
* |
* LogicalOlapScan(table=tbl)
*
*/
private List<PhysicalHashAggregate<Plan>> twoPhaseAggregateWithoutDistinct(
LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext) {
AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER);
Map<AggregateFunction, Alias> inputToBufferAliases = logicalAgg.getAggregateFunctions()
.stream()
.collect(ImmutableMap.toImmutableMap(function -> function, function -> {
AggregateExpression inputToBuffer = new AggregateExpression(function, inputToBufferParam);
return new Alias(inputToBuffer);
}));
List<Expression> localAggGroupBy = logicalAgg.getGroupByExpressions();
List<Expression> partitionExpressions = getHashAggregatePartitionExpressions(logicalAgg);
List<NamedExpression> localAggOutput = ImmutableList.<NamedExpression>builder()
                // the NormalizeAggregate rule has already normalized the group-by expressions to List<Slot>
.addAll((List) localAggGroupBy)
.addAll(inputToBufferAliases.values())
.build();
RequireProperties requireAny = RequireProperties.of(PhysicalProperties.ANY);
PhysicalHashAggregate<? extends Plan> anyLocalAgg = new PhysicalHashAggregate<>(
localAggGroupBy, localAggOutput, Optional.of(partitionExpressions),
inputToBufferParam, maybeUsingStreamAgg(connectContext, logicalAgg),
logicalAgg.getLogicalProperties(), requireAny,
logicalAgg.child());
AggregateParam bufferToResultParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT);
List<NamedExpression> globalAggOutput = ExpressionUtils.rewriteDownShortCircuit(
logicalAgg.getOutputExpressions(), outputChild -> {
if (!(outputChild instanceof AggregateFunction)) {
return outputChild;
}
Alias inputToBufferAlias = inputToBufferAliases.get(outputChild);
if (inputToBufferAlias == null) {
return outputChild;
}
AggregateFunction function = (AggregateFunction) outputChild;
return new AggregateExpression(function, bufferToResultParam, inputToBufferAlias.toSlot());
});
RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER);
PhysicalHashAggregate<Plan> anyLocalGatherGlobalAgg = new PhysicalHashAggregate<>(
localAggGroupBy, globalAggOutput, Optional.of(partitionExpressions),
bufferToResultParam, false, anyLocalAgg.getLogicalProperties(),
requireGather, anyLocalAgg);
if (logicalAgg.getGroupByExpressions().isEmpty()) {
return ImmutableList.of(anyLocalGatherGlobalAgg);
} else {
RequireProperties requireHash = RequireProperties.of(
PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE));
PhysicalHashAggregate<Plan> anyLocalHashGlobalAgg = anyLocalGatherGlobalAgg
.withRequire(requireHash)
.withPartitionExpressions(logicalAgg.getGroupByExpressions());
return ImmutableList.<PhysicalHashAggregate<Plan>>builder()
                    // TODO: usually bad, disable it until we can do better cost computation.
// .add(anyLocalGatherGlobalAgg)
.add(anyLocalHashGlobalAgg)
.build();
}
}
/**
* sql: select count(distinct id) from tbl group by name
* <p>
* before:
* <p>
* LogicalAggregate(groupBy=[name], output=[name, count(distinct id)])
* |
* LogicalOlapScan(table=tbl)
* <p>
* after:
* <p>
* single node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id))], mode=BUFFER_TO_RESULT)
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER)
* |
* PhysicalDistribute(distributionSpec=GATHER)
* |
* LogicalOlapScan(table=tbl)
* <p>
* distribute node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id))], mode=BUFFER_TO_RESULT)
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER)
* |
     * LogicalOlapScan(table=tbl, **if already distributed by name**)
*
*/
private List<PhysicalHashAggregate<? extends Plan>> twoPhaseAggregateWithDistinct(
LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext) {
Set<AggregateFunction> aggregateFunctions = logicalAgg.getAggregateFunctions();
Set<NamedExpression> distinctArguments = aggregateFunctions.stream()
.filter(AggregateFunction::isDistinct)
.flatMap(aggregateExpression -> aggregateExpression.getArguments().stream())
.filter(NamedExpression.class::isInstance)
.map(NamedExpression.class::cast)
.collect(ImmutableSet.toImmutableSet());
Set<NamedExpression> localAggGroupBy = ImmutableSet.<NamedExpression>builder()
.addAll((List<NamedExpression>) (List) logicalAgg.getGroupByExpressions())
.addAll(distinctArguments)
.build();
AggregateParam inputToBufferParam = AggregateParam.LOCAL_BUFFER;
Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase1 = aggregateFunctions.stream()
.filter(aggregateFunction -> !aggregateFunction.isDistinct())
.collect(ImmutableMap.toImmutableMap(expr -> expr, expr -> {
AggregateExpression localAggExpr = new AggregateExpression(expr, inputToBufferParam);
return new Alias(localAggExpr);
}));
List<NamedExpression> localAggOutput = ImmutableList.<NamedExpression>builder()
.addAll(localAggGroupBy)
.addAll(nonDistinctAggFunctionToAliasPhase1.values())
.build();
RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER);
List<Expression> partitionExpressions = getHashAggregatePartitionExpressions(logicalAgg);
PhysicalHashAggregate<Plan> gatherLocalAgg = new PhysicalHashAggregate<>(ImmutableList.copyOf(localAggGroupBy),
localAggOutput, Optional.of(partitionExpressions), inputToBufferParam,
                /*
                 * should not use streaming aggregation here: a bug in BE computes wrong
                 * results with it, see aggregate_strategies.groovy
                 */
false, Optional.empty(), logicalAgg.getLogicalProperties(),
requireGather, logicalAgg.child());
AggregateParam bufferToResultParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT);
List<NamedExpression> globalOutput = ExpressionUtils.rewriteDownShortCircuit(
logicalAgg.getOutputExpressions(), outputChild -> {
if (outputChild instanceof AggregateFunction) {
AggregateFunction aggregateFunction = (AggregateFunction) outputChild;
if (aggregateFunction.isDistinct()) {
Set<Expression> aggChild = Sets.newLinkedHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1
|| aggregateFunction.getDistinctArguments().size() == 1,
"cannot process more than one child in aggregate distinct function: "
+ aggregateFunction);
AggregateFunction nonDistinct = aggregateFunction
.withDistinctAndChildren(false, ImmutableList.copyOf(aggChild));
return new AggregateExpression(nonDistinct, AggregateParam.LOCAL_RESULT);
} else {
Alias alias = nonDistinctAggFunctionToAliasPhase1.get(outputChild);
return new AggregateExpression(aggregateFunction, bufferToResultParam, alias.toSlot());
}
} else {
return outputChild;
}
});
PhysicalHashAggregate<Plan> gatherLocalGatherGlobalAgg
= new PhysicalHashAggregate<>(logicalAgg.getGroupByExpressions(), globalOutput,
Optional.empty(), bufferToResultParam, false,
logicalAgg.getLogicalProperties(), requireGather, gatherLocalAgg);
if (logicalAgg.getGroupByExpressions().isEmpty()) {
RequireProperties requireDistinctHash = RequireProperties.of(PhysicalProperties.createHash(
distinctArguments, ShuffleType.REQUIRE));
PhysicalHashAggregate<? extends Plan> hashLocalGatherGlobalAgg = gatherLocalGatherGlobalAgg
.withChildren(ImmutableList.of(gatherLocalAgg
.withRequire(requireDistinctHash)
.withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments()))
));
return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
//.add(gatherLocalGatherGlobalAgg)
.add(hashLocalGatherGlobalAgg)
.build();
} else {
RequireProperties requireGroupByHash = RequireProperties.of(PhysicalProperties.createHash(
logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE));
PhysicalHashAggregate<PhysicalHashAggregate<Plan>> hashLocalHashGlobalAgg = gatherLocalGatherGlobalAgg
.withRequirePropertiesAndChild(requireGroupByHash, gatherLocalAgg
.withRequire(requireGroupByHash)
.withPartitionExpressions(logicalAgg.getGroupByExpressions())
)
.withPartitionExpressions(logicalAgg.getGroupByExpressions());
return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
// .add(gatherLocalGatherGlobalAgg)
.add(hashLocalHashGlobalAgg)
.build();
}
}
/**
* sql: select count(distinct id) from tbl group by name
* <p>
* before:
* <p>
* LogicalAggregate(groupBy=[name], output=[name, count(distinct id)])
* |
* LogicalOlapScan(table=tbl)
* <p>
* after:
* single node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id))], mode=BUFFER_TO_RESULT)
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER)
* |
* PhysicalDistribute(distributionSpec=GATHER)
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER)
* |
* LogicalOlapScan(table=tbl)
* <p>
* distribute node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id))], mode=BUFFER_TO_RESULT)
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER)
* |
* PhysicalDistribute(distributionSpec=HASH(name))
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER)
* |
* LogicalOlapScan(table=tbl)
*
*/
// TODO: support one phase aggregate(group by columns + distinct columns) + two phase distinct aggregate
private List<PhysicalHashAggregate<? extends Plan>> threePhaseAggregateWithDistinct(
LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext) {
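        // whether this plan could alternatively use multi_distinct functions (see
        // couldConvertToMulti); threaded into each phase's AggregateParam below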
boolean couldBanned = couldConvertToMulti(logicalAgg);
Set<AggregateFunction> aggregateFunctions = logicalAgg.getAggregateFunctions();
Set<NamedExpression> distinctArguments = aggregateFunctions.stream()
.filter(AggregateFunction::isDistinct)
.flatMap(aggregateExpression -> aggregateExpression.getArguments().stream())
.filter(NamedExpression.class::isInstance)
.map(NamedExpression.class::cast)
.collect(ImmutableSet.toImmutableSet());
Set<NamedExpression> localAggGroupBySet = ImmutableSet.<NamedExpression>builder()
.addAll((List<NamedExpression>) (List) logicalAgg.getGroupByExpressions())
.addAll(distinctArguments)
.build();
AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned);
Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase1 = aggregateFunctions.stream()
.filter(aggregateFunction -> !aggregateFunction.isDistinct())
.collect(ImmutableMap.toImmutableMap(expr -> expr, expr -> {
AggregateExpression localAggExpr = new AggregateExpression(expr, inputToBufferParam);
return new Alias(localAggExpr);
}));
List<NamedExpression> localAggOutput = ImmutableList.<NamedExpression>builder()
.addAll(localAggGroupBySet)
.addAll(nonDistinctAggFunctionToAliasPhase1.values())
.build();
List<Expression> localAggGroupBy = ImmutableList.copyOf(localAggGroupBySet);
boolean isGroupByEmptySelectEmpty = localAggGroupBy.isEmpty() && localAggOutput.isEmpty();
        // generating an aggregate node with both an empty group-by and an empty output is not
        // recommended for BE, so add a null tinyint literal to both the group-by and the output
if (isGroupByEmptySelectEmpty) {
localAggGroupBy = ImmutableList.of(new NullLiteral(TinyIntType.INSTANCE));
localAggOutput = ImmutableList.of(new Alias(new NullLiteral(TinyIntType.INSTANCE)));
}
boolean maybeUsingStreamAgg = maybeUsingStreamAgg(connectContext, localAggGroupBy);
List<Expression> partitionExpressions = getHashAggregatePartitionExpressions(logicalAgg);
RequireProperties requireAny = RequireProperties.of(PhysicalProperties.ANY);
PhysicalHashAggregate<Plan> anyLocalAgg = new PhysicalHashAggregate<>(localAggGroupBy,
localAggOutput, Optional.of(partitionExpressions), inputToBufferParam,
maybeUsingStreamAgg, Optional.empty(), logicalAgg.getLogicalProperties(),
requireAny, logicalAgg.child());
AggregateParam bufferToBufferParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER, couldBanned);
Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase2 =
nonDistinctAggFunctionToAliasPhase1.entrySet()
.stream()
.collect(ImmutableMap.toImmutableMap(Entry::getKey, kv -> {
AggregateFunction originFunction = kv.getKey();
Alias localOutput = kv.getValue();
AggregateExpression globalAggExpr = new AggregateExpression(
originFunction, bufferToBufferParam, localOutput.toSlot());
return new Alias(globalAggExpr);
}));
List<NamedExpression> globalAggOutput = ImmutableList.<NamedExpression>builder()
.addAll(localAggGroupBySet)
.addAll(nonDistinctAggFunctionToAliasPhase2.values())
.build();
// it is not recommended to generate an aggregate node with an empty group by and empty output,
// so add a null tinyint literal to both the group by list and the output
if (isGroupByEmptySelectEmpty) {
globalAggOutput = ImmutableList.of(new Alias(new NullLiteral(TinyIntType.INSTANCE)));
}
RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER);
PhysicalHashAggregate<Plan> anyLocalGatherGlobalAgg = new PhysicalHashAggregate<>(
localAggGroupBy, globalAggOutput, Optional.of(partitionExpressions),
bufferToBufferParam, false, logicalAgg.getLogicalProperties(),
requireGather, anyLocalAgg);
AggregateParam bufferToResultParam = new AggregateParam(
AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_RESULT, couldBanned);
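// the distinct functions consume the (group by, distinct argument) rows that the two
// phases above already deduplicated, so they run in INPUT_TO_RESULT mode, while the
// non-distinct functions keep merging buffers in BUFFER_TO_RESULT mode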
List<NamedExpression> distinctOutput = ExpressionUtils.rewriteDownShortCircuit(
logicalAgg.getOutputExpressions(), expr -> {
if (expr instanceof AggregateFunction) {
AggregateFunction aggregateFunction = (AggregateFunction) expr;
if (aggregateFunction.isDistinct()) {
Set<Expression> aggChild = Sets.newLinkedHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1
|| aggregateFunction.getDistinctArguments().size() == 1,
"cannot process more than one child in aggregate distinct function: "
+ aggregateFunction);
AggregateFunction nonDistinct = aggregateFunction
.withDistinctAndChildren(false, ImmutableList.copyOf(aggChild));
return new AggregateExpression(nonDistinct, bufferToResultParam, aggregateFunction);
} else {
Alias alias = nonDistinctAggFunctionToAliasPhase2.get(expr);
return new AggregateExpression(aggregateFunction,
new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.BUFFER_TO_RESULT),
alias.toSlot());
}
}
return expr;
});
PhysicalHashAggregate<Plan> anyLocalGatherGlobalGatherDistinctAgg = new PhysicalHashAggregate<>(
logicalAgg.getGroupByExpressions(), distinctOutput, Optional.empty(),
bufferToResultParam, false, logicalAgg.getLogicalProperties(),
requireGather, anyLocalGatherGlobalAgg);
RequireProperties requireDistinctHash = RequireProperties.of(
PhysicalProperties.createHash(logicalAgg.getDistinctArguments(), ShuffleType.REQUIRE));
PhysicalHashAggregate<? extends Plan> anyLocalHashGlobalGatherDistinctAgg
= anyLocalGatherGlobalGatherDistinctAgg
.withChildren(ImmutableList.of(anyLocalGatherGlobalAgg
.withRequire(requireDistinctHash)
.withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments()))
));
if (logicalAgg.getGroupByExpressions().isEmpty()) {
return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
// TODO: this plan pattern is usually not good, so we disable it temporarily.
// .add(anyLocalGatherGlobalGatherDistinctAgg)
.add(anyLocalHashGlobalGatherDistinctAgg)
.build();
} else {
RequireProperties requireGroupByHash = RequireProperties.of(
PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE));
PhysicalHashAggregate<? extends Plan> anyLocalHashGlobalHashDistinctAgg
= anyLocalGatherGlobalGatherDistinctAgg
.withRequirePropertiesAndChild(requireGroupByHash, anyLocalGatherGlobalAgg
.withRequire(requireGroupByHash)
.withPartitionExpressions(logicalAgg.getGroupByExpressions())
)
.withPartitionExpressions(logicalAgg.getGroupByExpressions());
return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
// TODO: this plan pattern is usually not good, so we disable it temporarily.
//.add(anyLocalGatherGlobalGatherDistinctAgg)
//.add(anyLocalHashGlobalGatherDistinctAgg)
.add(anyLocalHashGlobalHashDistinctAgg)
.build();
}
}
/**
* sql: select count(distinct id) from (...) group by name
* <p>
* before:
* <p>
* LogicalAggregate(groupBy=[name], output=[count(distinct id)])
* |
* any plan
* <p>
* after:
* <p>
* single node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[multi_distinct_count(id)])
* |
* PhysicalDistribute(distributionSpec=GATHER)
* |
* any plan
* <p>
* distribute node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[multi_distinct_count(id)])
* |
* any plan(**already distribute by name**)
*
*/
private List<PhysicalHashAggregate<? extends Plan>> onePhaseAggregateWithMultiDistinct(
LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext) {
AggregateParam inputToResultParam = AggregateParam.LOCAL_RESULT;
List<NamedExpression> newOutput = ExpressionUtils.rewriteDownShortCircuit(
logicalAgg.getOutputExpressions(), outputChild -> {
if (outputChild instanceof AggregateFunction) {
AggregateFunction function = tryConvertToMultiDistinct((AggregateFunction) outputChild);
return new AggregateExpression(function, inputToResultParam);
}
return outputChild;
});
RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER);
PhysicalHashAggregate<? extends Plan> gatherLocalAgg = new PhysicalHashAggregate<>(
logicalAgg.getGroupByExpressions(), newOutput, inputToResultParam,
maybeUsingStreamAgg(connectContext, logicalAgg),
logicalAgg.getLogicalProperties(), requireGather, logicalAgg.child());
if (logicalAgg.getGroupByExpressions().isEmpty()) {
// TODO: usually a bad plan; disabled until we can do better cost computation.
// return ImmutableList.of(gatherLocalAgg);
return ImmutableList.of();
} else {
RequireProperties requireHash = RequireProperties.of(
PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE));
PhysicalHashAggregate<? extends Plan> hashLocalAgg = gatherLocalAgg
.withRequire(requireHash)
.withPartitionExpressions(logicalAgg.getGroupByExpressions());
return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
// TODO: usually a bad plan; disabled until we can do better cost computation.
// .add(gatherLocalAgg)
.add(hashLocalAgg)
.build();
}
}
/**
* sql: select count(distinct id) from tbl group by name
* <p>
* before:
* <p>
* LogicalAggregate(groupBy=[name], output=[name, count(distinct id)])
* |
* LogicalOlapScan(table=tbl)
* <p>
* after:
* <p>
* single node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, multi_distinct_count(id)], mode=BUFFER_TO_RESULT)
* |
* PhysicalDistribute(distributionSpec=GATHER)
* |
* PhysicalHashAggregate(groupBy=[name], output=[name, multi_distinct_count(id)], mode=INPUT_TO_BUFFER)
* |
* LogicalOlapScan(table=tbl)
* <p>
* distribute node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, multi_distinct_count(id)], mode=BUFFER_TO_RESULT)
* |
* PhysicalDistribute(distributionSpec=HASH(name))
* |
* PhysicalHashAggregate(groupBy=[name], output=[name, multi_distinct_count(id)], mode=INPUT_TO_BUFFER)
* |
* LogicalOlapScan(table=tbl)
*
*/
private List<PhysicalHashAggregate<? extends Plan>> twoPhaseAggregateWithMultiDistinct(
LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext) {
Set<AggregateFunction> aggregateFunctions = logicalAgg.getAggregateFunctions();
AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER);
Map<AggregateFunction, Alias> aggFunctionToAliasPhase1 = aggregateFunctions.stream()
.collect(ImmutableMap.toImmutableMap(function -> function, function -> {
AggregateFunction multiDistinct = tryConvertToMultiDistinct(function);
AggregateExpression localAggExpr = new AggregateExpression(multiDistinct, inputToBufferParam);
return new Alias(localAggExpr);
}));
List<NamedExpression> localAggOutput = ImmutableList.<NamedExpression>builder()
// the group by expressions have already been normalized to List<SlotReference>
.addAll((List<NamedExpression>) (List) logicalAgg.getGroupByExpressions())
.addAll(aggFunctionToAliasPhase1.values())
.build();
RequireProperties requireAny = RequireProperties.of(PhysicalProperties.ANY);
PhysicalHashAggregate<? extends Plan> anyLocalAgg = new PhysicalHashAggregate<>(
logicalAgg.getGroupByExpressions(), localAggOutput,
inputToBufferParam, maybeUsingStreamAgg(connectContext, logicalAgg),
logicalAgg.getLogicalProperties(), requireAny, logicalAgg.child());
AggregateParam bufferToResultParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT);
List<NamedExpression> globalOutput = ExpressionUtils.rewriteDownShortCircuit(
logicalAgg.getOutputExpressions(), outputChild -> {
if (outputChild instanceof AggregateFunction) {
Alias alias = aggFunctionToAliasPhase1.get(outputChild);
AggregateExpression localAggExpr = (AggregateExpression) alias.child();
return new AggregateExpression(localAggExpr.getFunction(),
bufferToResultParam, alias.toSlot());
} else {
return outputChild;
}
});
PhysicalHashAggregate<? extends Plan> anyLocalGatherGlobalAgg = new PhysicalHashAggregate<>(
logicalAgg.getGroupByExpressions(), globalOutput, Optional.empty(),
bufferToResultParam, false, logicalAgg.getLogicalProperties(),
RequireProperties.of(PhysicalProperties.GATHER), anyLocalAgg);
if (logicalAgg.getGroupByExpressions().isEmpty()) {
// Collection<Expression> distinctArguments = logicalAgg.getDistinctArguments();
// RequireProperties requireDistinctHash = RequireProperties.of(PhysicalProperties.createHash(
// distinctArguments, ShuffleType.REQUIRE));
// PhysicalHashAggregate<? extends Plan> hashLocalGatherGlobalAgg = anyLocalGatherGlobalAgg
// .withChildren(ImmutableList.of(anyLocalAgg
// .withRequire(requireDistinctHash)
// .withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments()))
// ));
return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
.add(anyLocalGatherGlobalAgg)
.build();
} else {
RequireProperties requireHash = RequireProperties.of(
PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE));
PhysicalHashAggregate<? extends Plan> anyLocalHashGlobalAgg = anyLocalGatherGlobalAgg
.withRequire(requireHash)
.withPartitionExpressions(logicalAgg.getGroupByExpressions());
return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
// TODO: usually a bad plan; disabled until we can do better cost computation.
// .add(anyLocalGatherGlobalAgg)
.add(anyLocalHashGlobalAgg)
.build();
}
}
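/**
 * whether the local aggregate may use streaming pre-aggregation: it is only a candidate
 * when the session does not disable stream pre-aggregation and the group by is non-empty
 * (with no group by keys there is nothing to pre-aggregate by).
 */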
private boolean maybeUsingStreamAgg(
ConnectContext connectContext, LogicalAggregate<? extends Plan> logicalAggregate) {
return !connectContext.getSessionVariable().disableStreamPreaggregations
&& !logicalAggregate.getGroupByExpressions().isEmpty();
}
private boolean maybeUsingStreamAgg(
ConnectContext connectContext, List<? extends Expression> groupByExpressions) {
return !connectContext.getSessionVariable().disableStreamPreaggregations
&& !groupByExpressions.isEmpty();
}
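/**
 * partition expressions for the local aggregate: the group by keys when present, otherwise
 * the distinct arguments, e.g. `select count(distinct id) from tbl` partitions by [id].
 */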
private List<Expression> getHashAggregatePartitionExpressions(
LogicalAggregate<? extends Plan> logicalAggregate) {
return logicalAggregate.getGroupByExpressions().isEmpty()
? ImmutableList.copyOf(logicalAggregate.getDistinctArguments())
: logicalAggregate.getGroupByExpressions();
}
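/**
 * convert a distinct aggregate function to its multi-distinct variant if it implements
 * SupportMultiDistinct, e.g. count(distinct id) becomes multi_distinct_count(id);
 * other functions are returned unchanged.
 */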
private AggregateFunction tryConvertToMultiDistinct(AggregateFunction function) {
if (function instanceof SupportMultiDistinct && function.isDistinct()) {
return ((SupportMultiDistinct) function).convertToMultiDistinct();
}
return function;
}
/**
* countDistinctMultiExprToCountIf.
* <p>
* NOTE: this function breaks the normalized output, e.g. it rewrites `count(distinct slot1, slot2)` to
* `count(if(slot1 is null, null, slot2))`. So if you invoke this method and then split the
* aggregate into phases, please normalize the arguments back to slots and create a bottom
* project, as NormalizeAggregate does.
*/
private Pair<LogicalAggregate<? extends Plan>, List<Count>> countDistinctMultiExprToCountIf(
LogicalAggregate<? extends Plan> aggregate, CascadesContext cascadesContext) {
ImmutableList.Builder<Count> countIfList = ImmutableList.builder();
List<NamedExpression> newOutput = ExpressionUtils.rewriteDownShortCircuit(
aggregate.getOutputExpressions(), outputChild -> {
if (outputChild instanceof Count) {
Count count = (Count) outputChild;
if (count.isDistinct() && count.arity() > 1) {
// fold the arguments right-to-left into nested ifs, so the whole expression is
// null (and therefore not counted) whenever any argument is null:
// count(distinct a, b) -> count(if(a is null, null, b))
Set<Expression> arguments = ImmutableSet.copyOf(count.getArguments());
Expression countExpr = count.getArgument(arguments.size() - 1);
for (int i = arguments.size() - 2; i >= 0; --i) {
Expression argument = count.getArgument(i);
If ifNull = new If(new IsNull(argument), NullLiteral.INSTANCE, countExpr);
countExpr = assignNullType(ifNull, cascadesContext);
}
Count countIf = new Count(countExpr);
countIfList.add(countIf);
return countIf;
}
}
return outputChild;
});
return Pair.of(aggregate.withAggOutput(newOutput), countIfList.build());
}
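/**
 * whether the output contains a count distinct over more than one expression,
 * e.g. count(distinct a, b), which countDistinctMultiExprToCountIf can rewrite.
 */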
private boolean containsCountDistinctMultiExpr(LogicalAggregate<? extends Plan> aggregate) {
return ExpressionUtils.deapAnyMatch(aggregate.getOutputExpressions(), expr ->
expr instanceof Count && ((Count) expr).isDistinct() && expr.arity() > 1);
}
// don't invoke ExpressionNormalization here, because the expression may be simplified and lose some slots
private If assignNullType(If ifExpr, CascadesContext cascadesContext) {
If ifWithCoercion = (If) TypeCoercionUtils.processBoundFunction(ifExpr);
Expression trueValue = ifWithCoercion.getArgument(1);
if (trueValue instanceof Cast && trueValue.child(0) instanceof NullLiteral) {
List<Expression> newArgs = Lists.newArrayList(ifWithCoercion.getArguments());
// the backend doesn't support the NULL type, so set a concrete type here
newArgs.set(1, new NullLiteral(((Cast) trueValue).getDataType()));
return ifWithCoercion.withChildren(newArgs);
}
return ifWithCoercion;
}
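/**
 * whether the no-group-by aggregate push down is enabled for the current session;
 * true when there is no connect context.
 */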
private boolean enablePushDownNoGroupAgg() {
ConnectContext connectContext = ConnectContext.get();
return connectContext == null || connectContext.getSessionVariable().enablePushDownNoGroupAgg();
}
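/**
 * four phase aggregate for distinct aggregate functions. The distributions required by
 * the middle and top phases come from secondPhaseRequireSupplier and
 * fourPhaseRequireSupplier, so the plan shape below is illustrative only
 * (it omits the PhysicalDistribute nodes those require properties introduce):
 * <p>
 * sql: select count(distinct id) from tbl group by name
 * <p>
 * before:
 * <p>
 * LogicalAggregate(groupBy=[name], output=[name, count(distinct id)])
 * |
 * LogicalOlapScan(table=tbl)
 * <p>
 * after:
 * <p>
 * PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct id)], mode=BUFFER_TO_RESULT)
 * |
 * PhysicalHashAggregate(groupBy=[name], output=[name, partial count(distinct id)], mode=INPUT_TO_BUFFER)
 * |
 * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER)
 * |
 * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER)
 * |
 * LogicalOlapScan(table=tbl)
 */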
private List<PhysicalHashAggregate<? extends Plan>> fourPhaseAggregateWithDistinct(
LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext,
Function<List<Expression>, RequireProperties> secondPhaseRequireSupplier,
Function<LogicalAggregate<? extends Plan>, RequireProperties> fourPhaseRequireSupplier) {
boolean couldBanned = couldConvertToMulti(logicalAgg);
Set<AggregateFunction> aggregateFunctions = logicalAgg.getAggregateFunctions();
Set<NamedExpression> distinctArguments = aggregateFunctions.stream()
.filter(AggregateFunction::isDistinct)
.flatMap(aggregateExpression -> aggregateExpression.getArguments().stream())
.filter(NamedExpression.class::isInstance)
.map(NamedExpression.class::cast)
.collect(ImmutableSet.toImmutableSet());
Set<NamedExpression> localAggGroupBySet = ImmutableSet.<NamedExpression>builder()
.addAll((List<NamedExpression>) (List) logicalAgg.getGroupByExpressions())
.addAll(distinctArguments)
.build();
AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned);
Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase1 = aggregateFunctions.stream()
.filter(aggregateFunction -> !aggregateFunction.isDistinct())
.collect(ImmutableMap.toImmutableMap(expr -> expr, expr -> {
AggregateExpression localAggExpr = new AggregateExpression(expr, inputToBufferParam);
return new Alias(localAggExpr);
}, (oldValue, newValue) -> newValue));
List<NamedExpression> localAggOutput = ImmutableList.<NamedExpression>builder()
.addAll(localAggGroupBySet)
.addAll(nonDistinctAggFunctionToAliasPhase1.values())
.build();
List<Expression> localAggGroupBy = ImmutableList.copyOf(localAggGroupBySet);
boolean maybeUsingStreamAgg = maybeUsingStreamAgg(connectContext, localAggGroupBy);
List<Expression> partitionExpressions = getHashAggregatePartitionExpressions(logicalAgg);
RequireProperties requireAny = RequireProperties.of(PhysicalProperties.ANY);
boolean isGroupByEmptySelectEmpty = localAggGroupBy.isEmpty() && localAggOutput.isEmpty();
// it is not recommended to generate an aggregate node with an empty group by and empty output,
// so add a null tinyint literal to both the group by list and the output
if (isGroupByEmptySelectEmpty) {
localAggGroupBy = ImmutableList.of(new NullLiteral(TinyIntType.INSTANCE));
localAggOutput = ImmutableList.of(new Alias(new NullLiteral(TinyIntType.INSTANCE)));
}
PhysicalHashAggregate<Plan> anyLocalAgg = new PhysicalHashAggregate<>(localAggGroupBy,
localAggOutput, Optional.of(partitionExpressions), inputToBufferParam,
maybeUsingStreamAgg, Optional.empty(), logicalAgg.getLogicalProperties(),
requireAny, logicalAgg.child());
AggregateParam bufferToBufferParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER, couldBanned);
Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase2 =
nonDistinctAggFunctionToAliasPhase1.entrySet()
.stream()
.collect(ImmutableMap.toImmutableMap(Entry::getKey, kv -> {
AggregateFunction originFunction = kv.getKey();
Alias localOutput = kv.getValue();
AggregateExpression globalAggExpr = new AggregateExpression(
originFunction, bufferToBufferParam, localOutput.toSlot());
return new Alias(globalAggExpr);
}));
List<NamedExpression> globalAggOutput = ImmutableList.<NamedExpression>builder()
.addAll(localAggGroupBySet)
.addAll(nonDistinctAggFunctionToAliasPhase2.values())
.build();
// it is not recommended to generate an aggregate node with an empty group by and empty output,
// so add a null tinyint literal to both the group by list and the output
if (isGroupByEmptySelectEmpty) {
globalAggOutput = ImmutableList.of(new Alias(new NullLiteral(TinyIntType.INSTANCE)));
}
RequireProperties secondPhaseRequire = secondPhaseRequireSupplier.apply(localAggGroupBy);
// phase 2
PhysicalHashAggregate<? extends Plan> anyLocalHashGlobalAgg = new PhysicalHashAggregate<>(
localAggGroupBy, globalAggOutput, Optional.of(ImmutableList.copyOf(logicalAgg.getDistinctArguments())),
bufferToBufferParam, false, logicalAgg.getLogicalProperties(),
secondPhaseRequire, anyLocalAgg);
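// with more than one distinct argument, the group by of the first two phases cannot
// deduplicate for each function separately, so phases 3 and 4 must keep the distinct
// semantics by converting to multi-distinct functions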
boolean shouldDistinctAfterPhase2 = distinctArguments.size() > 1;
// phase 3
AggregateParam distinctLocalParam = new AggregateParam(
AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned);
Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase3 = new HashMap<>();
List<NamedExpression> localDistinctOutput = Lists.newArrayList();
for (int i = 0; i < logicalAgg.getOutputExpressions().size(); i++) {
NamedExpression outputExpr = logicalAgg.getOutputExpressions().get(i);
List<AggregateFunction> needUpdateSlot = Lists.newArrayList();
NamedExpression outputExprPhase3 = (NamedExpression) outputExpr
.rewriteDownShortCircuit(expr -> {
if (expr instanceof AggregateFunction) {
AggregateFunction aggregateFunction = (AggregateFunction) expr;
if (aggregateFunction.isDistinct()) {
Set<Expression> aggChild = Sets.newLinkedHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1
|| aggregateFunction.getDistinctArguments().size() == 1,
"cannot process more than one child in aggregate distinct function: "
+ aggregateFunction);
AggregateFunction newDistinct;
if (shouldDistinctAfterPhase2) {
// we use an aggregate function to process the distinct,
// so convert it to its multi-distinct variant
newDistinct = tryConvertToMultiDistinct(
aggregateFunction.withDistinctAndChildren(
true, ImmutableList.copyOf(aggChild))
);
} else {
// the group by already deduplicates the distinct arguments,
// so drop the distinct flag from the aggregate function
newDistinct = aggregateFunction.withDistinctAndChildren(
false, ImmutableList.copyOf(aggChild));
}
AggregateExpression newDistinctAggExpr = new AggregateExpression(
newDistinct, distinctLocalParam, newDistinct);
return newDistinctAggExpr;
} else {
needUpdateSlot.add(aggregateFunction);
Alias alias = nonDistinctAggFunctionToAliasPhase2.get(expr);
return new AggregateExpression(aggregateFunction,
new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.BUFFER_TO_BUFFER),
alias.toSlot());
}
}
return expr;
});
for (AggregateFunction originFunction : needUpdateSlot) {
nonDistinctAggFunctionToAliasPhase3.put(originFunction, (Alias) outputExprPhase3);
}
localDistinctOutput.add(outputExprPhase3);
}
PhysicalHashAggregate<? extends Plan> distinctLocal = new PhysicalHashAggregate<>(
logicalAgg.getGroupByExpressions(), localDistinctOutput, Optional.empty(),
distinctLocalParam, false, logicalAgg.getLogicalProperties(),
secondPhaseRequire, anyLocalHashGlobalAgg);
// phase 4
AggregateParam distinctGlobalParam = new AggregateParam(
AggPhase.DISTINCT_GLOBAL, AggMode.BUFFER_TO_RESULT, couldBanned);
List<NamedExpression> globalDistinctOutput = Lists.newArrayList();
for (int i = 0; i < logicalAgg.getOutputExpressions().size(); i++) {
NamedExpression outputExpr = logicalAgg.getOutputExpressions().get(i);
NamedExpression outputExprPhase4 = (NamedExpression) outputExpr.rewriteDownShortCircuit(expr -> {
if (expr instanceof AggregateFunction) {
AggregateFunction aggregateFunction = (AggregateFunction) expr;
if (aggregateFunction.isDistinct()) {
Set<Expression> aggChild = Sets.newLinkedHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1
|| aggregateFunction.getDistinctArguments().size() == 1,
"cannot process more than one child in aggregate distinct function: "
+ aggregateFunction);
AggregateFunction newDistinct;
if (shouldDistinctAfterPhase2) {
newDistinct = tryConvertToMultiDistinct(
aggregateFunction.withDistinctAndChildren(
true, ImmutableList.copyOf(aggChild))
);
} else {
newDistinct = aggregateFunction
.withDistinctAndChildren(false, ImmutableList.copyOf(aggChild));
}
int idx = logicalAgg.getOutputExpressions().indexOf(outputExpr);
Alias localDistinctAlias = (Alias) (localDistinctOutput.get(idx));
return new AggregateExpression(newDistinct,
distinctGlobalParam, localDistinctAlias.toSlot());
} else {
Alias alias = nonDistinctAggFunctionToAliasPhase3.get(expr);
return new AggregateExpression(aggregateFunction,
new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.BUFFER_TO_RESULT),
alias.toSlot());
}
}
return expr;
});
globalDistinctOutput.add(outputExprPhase4);
}
RequireProperties fourPhaseRequire = fourPhaseRequireSupplier.apply(logicalAgg);
PhysicalHashAggregate<? extends Plan> distinctGlobal = new PhysicalHashAggregate<>(
logicalAgg.getGroupByExpressions(), globalDistinctOutput, Optional.empty(),
distinctGlobalParam, false, logicalAgg.getLogicalProperties(),
fourPhaseRequire, distinctLocal);
return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
.add(distinctGlobal)
.build();
}
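/**
 * whether every distinct aggregate function can be converted to a multi-distinct function:
 * only count, sum, sum0 and group_concat qualify, and any argument after the first must be
 * an order key or a slot-free literal, e.g. the separator in group_concat(distinct col_1, ',').
 */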
private boolean couldConvertToMulti(LogicalAggregate<? extends Plan> aggregate) {
Set<AggregateFunction> aggregateFunctions = aggregate.getAggregateFunctions();
for (AggregateFunction func : aggregateFunctions) {
if (!func.isDistinct()) {
continue;
}
if (!(func instanceof Count || func instanceof Sum || func instanceof GroupConcat
|| func instanceof Sum0)) {
return false;
}
if (func.arity() <= 1) {
continue;
}
for (int i = 1; i < func.arity(); i++) {
// consider group_concat(distinct col_1, ','): trailing literal arguments (no input slots) are allowed
if (!(func.child(i) instanceof OrderExpression) && !func.child(i).getInputSlots().isEmpty()) {
return false;
}
}
}
return true;
}
}