// FillUpMissingSlots.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.qe.SqlModeHelper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.collect.Streams;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Resolves slots referenced by HAVING / ORDER BY clauses against the aggregation/repeat
* (or project) below them.
* The plan must be traversed top-down, because FILL_UP_SORT_HAVING_AGGREGATE has to be
* processed before FILL_UP_HAVING_AGGREGATE.
* Be aware that when filling up the missing slots, the outer query's correlated slots must be
* excluded: these slots belong to the outer query, so we should not try to find them in the
* child node.
*/
public class FillUpMissingSlots implements AnalysisRuleFactory {
/**
* Builds the rules that fill up slots which are referenced by an ORDER BY (LogicalSort) or
* HAVING (LogicalHaving) node but are missing from the output of the project / aggregate below.
*
* <p>Two strategies are visible in the rules below:
* <ul>
* <li>child is a project: the missing slots are appended to the project and a new top project
* restores the original output;</li>
* <li>child is an aggregate: a {@link Resolver} matches the expressions against the aggregate and
* {@link #createPlan} pushes the required expressions into the aggregate's output.</li>
* </ul>
* Slots contained in the outer scope's correlated slots are never treated as missing.
*
* <p>NOTE(review): several lambdas return {@code null} when nothing has to change; presumably the
* rule framework interprets this as "keep the original plan" -- confirm against the pattern API.
*
* @return the immutable list of rules, in the order they are declared here
*/
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
// Sort -> Project: add the slots used by the order keys but not produced by the project.
RuleType.FILL_UP_SORT_PROJECT.build(
logicalSort(logicalProject())
.thenApply(ctx -> {
LogicalSort<LogicalProject<Plan>> sort = ctx.root;
Optional<Scope> outerScope = ctx.cascadesContext.getOuterScope();
LogicalProject<Plan> project = sort.child();
Set<Slot> projectOutputSet = project.getOutputSet();
// input slots of the order keys that are neither project outputs nor correlated slots
Set<Slot> notExistedInProject = sort.getOrderKeys().stream()
.map(OrderKey::getExpr)
.map(Expression::getInputSlots)
.flatMap(Set::stream)
.filter(s -> !projectOutputSet.contains(s)
&& (!outerScope.isPresent() || !outerScope.get()
.getCorrelatedSlots().contains(s)))
.collect(Collectors.toSet());
if (notExistedInProject.isEmpty()) {
return null;
}
// widen the bottom project with the missing slots ...
List<NamedExpression> projects = ImmutableList.<NamedExpression>builder()
.addAll(project.getProjects()).addAll(notExistedInProject).build();
// ... and put a project with the original output on top of the sort.
return new LogicalProject<>(ImmutableList.copyOf(project.getOutput()),
sort.withChildren(new LogicalProject<>(projects, project.child())));
})
),
// Sort -> Aggregate -> Having -> Aggregate, where the upper aggregate outputs only slot
// references: the order keys are resolved against the bottom aggregate.
RuleType.FILL_UP_SORT_AGGREGATE_HAVING_AGGREGATE.build(
logicalSort(
aggregate(logicalHaving(aggregate()))
.when(a -> a.getOutputExpressions().stream().allMatch(SlotReference.class::isInstance))
).when(this::checkSort)
.thenApply(ctx -> processDistinctProjectWithAggregate(ctx.root,
ctx.root.child(), ctx.root.child().child().child(),
ctx.cascadesContext.getOuterScope()))
),
// ATTN: process aggregate with distinct project, must run this rule before FILL_UP_SORT_AGGREGATE
// because this pattern will always fail in FILL_UP_SORT_AGGREGATE
RuleType.FILL_UP_SORT_AGGREGATE_AGGREGATE.build(
logicalSort(
aggregate(aggregate())
.when(a -> a.getOutputExpressions().stream().allMatch(SlotReference.class::isInstance))
).when(this::checkSort)
.thenApply(ctx -> processDistinctProjectWithAggregate(ctx.root,
ctx.root.child(), ctx.root.child().child(),
ctx.cascadesContext.getOuterScope()))
),
// Sort -> Aggregate: resolve every sort expression against the aggregate and rewrite the
// order keys with the resulting substitution.
RuleType.FILL_UP_SORT_AGGREGATE.build(
logicalSort(aggregate())
.when(this::checkSort)
.thenApply(ctx -> {
LogicalSort<Aggregate<Plan>> sort = ctx.root;
Aggregate<Plan> agg = sort.child();
Resolver resolver = new Resolver(agg, ctx.cascadesContext.getOuterScope());
sort.getExpressions().forEach(resolver::resolve);
return createPlan(resolver, agg, (r, a) -> {
List<OrderKey> newOrderKeys = sort.getOrderKeys().stream()
.map(ok -> new OrderKey(
ExpressionUtils.replace(ok.getExpr(), r.getSubstitution()),
ok.isAsc(),
ok.isNullFirst()))
.collect(ImmutableList.toImmutableList());
boolean notChanged = newOrderKeys.equals(sort.getOrderKeys());
// neither the order keys nor the aggregate changed: nothing to rewrite
if (notChanged && a.equals(agg)) {
return null;
}
return notChanged ? sort.withChildren(a) : new LogicalSort<>(newOrderKeys, a);
});
})
),
// Sort -> Having -> Aggregate: same as above, but the having node is kept between the sort
// and the (possibly widened) aggregate.
RuleType.FILL_UP_SORT_HAVING_AGGREGATE.build(
logicalSort(logicalHaving(aggregate()))
.when(this::checkSort)
.thenApply(ctx -> {
LogicalSort<LogicalHaving<Aggregate<Plan>>> sort = ctx.root;
LogicalHaving<Aggregate<Plan>> having = sort.child();
Aggregate<Plan> agg = having.child();
Resolver resolver = new Resolver(agg, ctx.cascadesContext.getOuterScope());
sort.getExpressions().forEach(resolver::resolve);
return createPlan(resolver, agg, (r, a) -> {
List<OrderKey> newOrderKeys = sort.getOrderKeys().stream()
.map(key -> key.withExpression(
ExpressionUtils.replace(key.getExpr(), r.getSubstitution())))
.collect(ImmutableList.toImmutableList());
boolean notChanged = newOrderKeys.equals(sort.getOrderKeys());
if (notChanged && a.equals(agg)) {
return null;
}
return notChanged ? sort.withChildren(sort.child().withChildren(a))
: new LogicalSort<>(newOrderKeys, sort.child().withChildren(a));
});
})
),
// Sort -> Having -> Project: add the slots used by the order keys but not produced by the
// sort's child to the project, and restore the original output with a top project.
RuleType.FILL_UP_SORT_HAVING_PROJECT.build(
logicalSort(logicalHaving(logicalProject())).thenApply(ctx -> {
LogicalSort<LogicalHaving<LogicalProject<Plan>>> sort = ctx.root;
Optional<Scope> outerScope = ctx.cascadesContext.getOuterScope();
Set<Slot> childOutput = sort.child().getOutputSet();
Set<Slot> notExistedInProject = sort.getOrderKeys().stream()
.map(OrderKey::getExpr)
.map(Expression::getInputSlots)
.flatMap(Set::stream)
.filter(s -> !childOutput.contains(s)
&& (!outerScope.isPresent() || !outerScope.get()
.getCorrelatedSlots().contains(s)))
.collect(Collectors.toSet());
if (notExistedInProject.isEmpty()) {
return null;
}
LogicalProject<?> project = sort.child().child();
List<NamedExpression> projects = ImmutableList.<NamedExpression>builder()
.addAll(project.getProjects())
.addAll(notExistedInProject).build();
Plan child = sort.withChildren(sort.child().withChildren(project.withProjects(projects)));
return new LogicalProject<>(ImmutableList.copyOf(project.getOutput()), child);
})
),
// Having -> Aggregate: resolve the having conjuncts against the aggregate and rewrite them
// with the resulting substitution.
RuleType.FILL_UP_HAVING_AGGREGATE.build(
logicalHaving(aggregate()).thenApply(ctx -> {
LogicalHaving<Aggregate<Plan>> having = ctx.root;
Aggregate<Plan> agg = having.child();
Resolver resolver = new Resolver(agg, ctx.cascadesContext.getOuterScope());
having.getConjuncts().forEach(resolver::resolve);
return createPlan(resolver, agg, (r, a) -> {
Set<Expression> newConjuncts = ExpressionUtils.replace(
having.getConjuncts(), r.getSubstitution());
boolean notChanged = newConjuncts.equals(having.getConjuncts());
if (notChanged && a.equals(agg)) {
return null;
}
return notChanged ? having.withChildren(a) : new LogicalHaving<>(newConjuncts, a);
});
})
),
// Convert having to filter
// NOTE(review): both branches below still build LogicalHaving nodes, so the comment above
// may be stale -- confirm where the conversion to filter actually happens.
RuleType.FILL_UP_HAVING_PROJECT.build(
logicalHaving(logicalProject()).thenApply(ctx -> {
LogicalHaving<LogicalProject<Plan>> having = ctx.root;
Optional<Scope> outerScope = ctx.cascadesContext.getOuterScope();
if (having.getExpressions().stream().anyMatch(e -> e.containsType(AggregateFunction.class))) {
// This is a very unusual pattern.
// There are aggregate functions in the having clause, but its child is a project.
// The having clause also references slots from the project.
// Having should execute after the project,
// but aggregate functions should execute before the project.
// Since there is no aggregate here, we add an empty aggregate below the project:
// we push the aggregate functions into the aggregate node first,
// then put the aggregate result slots and the original project slots into a new project.
// The final plan should be
// Having
// +-- Project
// +-- Aggregate
// Since the aggregate node has no group by key,
// the project should not contain any slot from its original child;
// otherwise a "slot cannot be found" error will be thrown.
LogicalProject<Plan> project = having.child();
// new an empty agg here
LogicalAggregate<Plan> agg = new LogicalAggregate<>(
ImmutableList.of(), ImmutableList.of(), project.child());
// avoid throwing an exception even if having has a slot from its child (checkSlot = false),
// because we will add a project between having and project.
Resolver resolver = new Resolver(agg, false, outerScope);
having.getConjuncts().forEach(resolver::resolve);
// the empty aggregate outputs exactly the expressions collected by the resolver
agg = agg.withAggOutput(resolver.getNewOutputSlots());
Set<Expression> newConjuncts = ExpressionUtils.replace(
having.getConjuncts(), resolver.getSubstitution());
ImmutableList.Builder<NamedExpression> projects = ImmutableList.builder();
projects.addAll(project.getOutputs()).addAll(agg.getOutput());
return new LogicalHaving<>(newConjuncts, new LogicalProject<>(projects.build(), agg));
} else {
// No aggregate function in having: same treatment as FILL_UP_SORT_PROJECT, using the
// input slots of the having expressions instead of the order keys.
LogicalProject<Plan> project = having.child();
Set<Slot> projectOutputSet = project.getOutputSet();
Set<Slot> notExistedInProject = having.getExpressions().stream()
.map(Expression::getInputSlots)
.flatMap(Set::stream)
.filter(s -> !projectOutputSet.contains(s)
&& (!outerScope.isPresent() || !outerScope.get()
.getCorrelatedSlots().contains(s)))
.collect(Collectors.toSet());
if (notExistedInProject.isEmpty()) {
return null;
}
List<NamedExpression> projects = ImmutableList.<NamedExpression>builder()
.addAll(project.getProjects()).addAll(notExistedInProject).build();
return new LogicalProject<>(ImmutableList.copyOf(project.getOutput()),
having.withChildren(new LogicalProject<>(projects, project.child())));
}
})
)
);
}
/**
 * Matches expressions used above an aggregate (having conjuncts, order keys) against the
 * aggregate's output and group-by expressions.
 * While resolving it records
 * <ul>
 * <li>{@code substitution}: original expression -> slot that replaces it above the aggregate;</li>
 * <li>{@code newOutputSlots}: expressions that have to be appended to the aggregate's output.</li>
 * </ul>
 */
static class Resolver {
    private final List<NamedExpression> outputExpressions;
    private final List<Expression> groupByExpressions;
    private final Map<Expression, Slot> substitution = Maps.newHashMap();
    private final List<NamedExpression> newOutputSlots = Lists.newArrayList();
    // slot of each alias in the aggregate output -> the aliased child expression
    private final Map<Slot, Expression> outputSubstitutionMap;
    private final boolean checkSlot;
    private final Optional<Scope> outerScope;

    /**
     * @param aggregate aggregate whose output / group-by expressions are matched against
     * @param checkSlot whether an unresolvable slot raises an exception under only-full-group-by
     * @param outerScope scope of the outer query; its correlated slots are left untouched
     */
    Resolver(Aggregate<?> aggregate, boolean checkSlot, Optional<Scope> outerScope) {
        this.outputExpressions = aggregate.getOutputExpressions();
        this.groupByExpressions = aggregate.getGroupByExpressions();
        Map<Slot, Expression> aliasedChildren = Maps.newHashMap();
        for (NamedExpression output : this.outputExpressions) {
            if (output instanceof Alias) {
                // the first alias wins when two aliases map to an equal slot
                aliasedChildren.putIfAbsent(output.toSlot(), output.child(0));
            }
        }
        this.outputSubstitutionMap = aliasedChildren;
        this.checkSlot = checkSlot;
        this.outerScope = outerScope;
    }

    Resolver(Aggregate<?> aggregate, boolean checkSlot) {
        this(aggregate, checkSlot, Optional.empty());
    }

    Resolver(Aggregate<?> aggregate) {
        this(aggregate, true, Optional.empty());
    }

    Resolver(Aggregate<?> aggregate, Optional<Scope> outerScope) {
        this(aggregate, true, outerScope);
    }

    /**
     * Resolves one expression (and, when needed, its children) against the aggregate,
     * updating the substitution map and the list of new output expressions.
     *
     * @param expression expression referenced above the aggregate
     * @throws AnalysisException if a slot must be grouped by but is not, or if aggregate
     *         functions are nested
     */
    public void resolve(Expression expression) {
        Pair<Optional<Expression>, Boolean> lookedUp = lookUp(expression);
        Optional<Expression> match = lookedUp.first;
        if (match.isPresent()) {
            Expression matched = match.get();
            if (!(matched instanceof NamedExpression)) {
                // A plain group-by expression (neither slot nor alias) matched, e.g.
                // group by (a + 1) having (a + 1): alias it and push it into the aggregate.
                generateAliasForNewOutputSlots(expression);
                return;
            }
            // An equivalent slot or alias exists: use its slot above the aggregate.
            Slot matchedSlot = ((NamedExpression) matched).toSlot();
            substitution.put(expression, matchedSlot);
            boolean foundInOutput = lookedUp.second;
            if (!foundInOutput) {
                // Found only in the group-by list, so the aggregate has to output it as well.
                newOutputSlots.add(matchedSlot);
            }
            return;
        }

        // Nothing equivalent in the aggregate: decide by the kind of expression.
        if (expression instanceof SlotReference) {
            resolveUngroupedSlot(expression);
            return;
        }
        if (expression instanceof AggregateFunction) {
            if (checkWhetherNestedAggregateFunctionsExist((AggregateFunction) expression)) {
                throw new AnalysisException("Aggregate functions in having clause can't be nested: "
                        + expression.toSql() + ".");
            }
            generateAliasForNewOutputSlots(expression);
            return;
        }
        if (expression instanceof WindowExpression) {
            generateAliasForNewOutputSlots(expression);
            return;
        }
        // Any other expression: try to resolve its children instead.
        expression.children().forEach(this::resolve);
    }

    /**
     * Handles a slot reference that is neither output nor grouped by the aggregate.
     * Correlated slots of the outer scope are skipped. Otherwise, when the SQL mode is not
     * only-full-group-by, the slot is wrapped in any_value and pushed into the aggregate;
     * under only-full-group-by an exception is raised if slot checking is enabled.
     */
    private void resolveUngroupedSlot(Expression slot) {
        boolean correlated = outerScope.isPresent()
                && outerScope.get().getCorrelatedSlots().contains(slot);
        if (correlated) {
            return;
        }
        if (!SqlModeHelper.hasOnlyFullGroupBy()) {
            // ATTN: any_value must be added to the aggregate's output instead of the bare slot,
            // because normalize agg cannot replace the upper slot with a new output.
            Alias anyValue = new Alias(new AnyValue(slot));
            newOutputSlots.add(anyValue);
            substitution.put(slot, anyValue.toSlot());
            return;
        }
        if (checkSlot) {
            throw new AnalysisException(slot.toSql() + " should be grouped by.");
        }
    }

    /**
     * Searches the aggregate for an expression equivalent to the given one, output
     * expressions first and group-by expressions second.
     *
     * @param expression expression referenced above the aggregate
     * @return first: the equivalent expression of the aggregate, if any;
     *         second: whether it was found among the output expressions
     */
    private Pair<Optional<Expression>, Boolean> lookUp(Expression expression) {
        for (NamedExpression output : outputExpressions) {
            if (isEquivalent(output, expression)) {
                return Pair.of(Optional.<Expression>of(output), true);
            }
        }
        for (Expression groupBy : groupByExpressions) {
            if (isEquivalent(groupBy, expression)) {
                return Pair.of(Optional.<Expression>of(groupBy), false);
            }
        }
        return Pair.of(Optional.<Expression>empty(), false);
    }

    /**
     * Tells whether an expression of the aggregate is equivalent to the given expression:
     * either they are equal, or the source is an alias whose slot or child equals it.
     *
     * @param source expression taken from the aggregate
     * @param expression expression to compare with
     * @return true if both denote the same value
     */
    private boolean isEquivalent(Expression source, Expression expression) {
        if (source.equals(expression)) {
            return true;
        }
        if (!(source instanceof Alias)) {
            return false;
        }
        Alias aliased = (Alias) source;
        return aliased.toSlot().equals(expression) || aliased.child().equals(expression);
    }

    /** Returns true if any argument of the aggregate function contains an aggregate function. */
    private boolean checkWhetherNestedAggregateFunctionsExist(AggregateFunction aggregateFunction) {
        for (Expression argument : aggregateFunction.children()) {
            if (argument.anyMatch(AggregateFunction.class::isInstance)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Registers a new aggregate output for the expression: alias slots inside it are replaced
     * by their aliased expressions, the result is wrapped in a new alias, and the original
     * expression is mapped to that alias' slot.
     */
    private void generateAliasForNewOutputSlots(Expression expression) {
        Expression unaliased = ExpressionUtils.replace(expression, outputSubstitutionMap);
        Alias pushedDown = new Alias(unaliased);
        newOutputSlots.add(pushedDown);
        substitution.put(expression, pushedDown.toSlot());
    }

    public Map<Expression, Slot> getSubstitution() {
        return substitution;
    }

    public List<NamedExpression> getNewOutputSlots() {
        return newOutputSlots;
    }
}
/**
 * Callback that builds the plan placed on top of the (possibly widened) aggregate.
 * It is only ever supplied as a lambda, hence the {@code @FunctionalInterface} contract.
 */
@FunctionalInterface
interface PlanGenerator {
    /**
     * @param resolver resolver holding the substitution map and the new aggregate outputs
     * @param aggregate the aggregate to build on; the original one if nothing was pushed down
     * @return the generated plan, or null when neither the parent node nor the aggregate changed
     */
    Plan apply(Resolver resolver, Aggregate<?> aggregate);
}
/**
 * Appends the resolver's new output expressions to the aggregate (if there are any), lets the
 * generator build the parent plan on top of it, and wraps the result in a project that exposes
 * exactly the slots of the original aggregate's output.
 *
 * @param resolver resolver that has already resolved the parent's expressions
 * @param aggregate the original aggregate
 * @param planGenerator builds the plan above the (possibly widened) aggregate
 * @return the projected plan, or null if the generator returned null
 */
protected Plan createPlan(Resolver resolver, Aggregate<? extends Plan> aggregate, PlanGenerator planGenerator) {
    Aggregate<? extends Plan> targetAggregate = aggregate;
    if (!resolver.getNewOutputSlots().isEmpty()) {
        // original outputs first, pushed-down expressions after them
        List<NamedExpression> widenedOutput = ImmutableList.<NamedExpression>builder()
                .addAll(aggregate.getOutputExpressions())
                .addAll(resolver.getNewOutputSlots())
                .build();
        targetAggregate = aggregate.withAggOutput(widenedOutput);
    }

    Plan generated = planGenerator.apply(resolver, targetAggregate);
    if (generated == null) {
        return null;
    }

    // hide the pushed-down expressions again: project only the original output slots
    ImmutableList.Builder<NamedExpression> originalSlots = ImmutableList.builder();
    for (NamedExpression output : aggregate.getOutputExpressions()) {
        originalSlots.add(output.toSlot());
    }
    List<NamedExpression> projections = originalSlots.build();
    return new LogicalProject<>(projections, generated);
}
/**
 * Tells whether a sort needs slot filling: true if any order key contains an aggregate
 * function or reads a slot that the sort's child does not output.
 */
private boolean checkSort(LogicalSort<? extends Plan> logicalSort) {
    Plan sortInput = logicalSort.child();
    return logicalSort.getOrderKeys().stream()
            .map(OrderKey::getExpr)
            .anyMatch(keyExpr -> keyExpr.anyMatch(AggregateFunction.class::isInstance)
                    || keyExpr.getInputSlots().stream()
                            .anyMatch(slot -> !sortInput.getOutputSet().contains(slot)));
}
/**
 * Handles sql like SELECT DISTINCT a FROM t GROUP BY a HAVING b > 0 ORDER BY a.
 * Here the ORDER BY has to be bound against the bottom aggregate's output and the output of
 * the bottom aggregate's child; the upper aggregate is only used to verify that the order keys
 * appear in the select list.
 *
 * @param sort top sort
 * @param upperAggregate upper aggregate, used to check that slots in order by are in the select list
 * @param bottomAggregate bottom aggregate the sort expressions are resolved against
 * @param outerScope scope of the outer query, handed to the resolver
 *
 * @return the rewritten plan, or null if neither the order keys nor the bottom aggregate changed
 * @throws AnalysisException if binding the order keys would require changing the bottom aggregate
 */
private Plan processDistinctProjectWithAggregate(LogicalSort<?> sort,
        Aggregate<?> upperAggregate, Aggregate<Plan> bottomAggregate, Optional<Scope> outerScope) {
    Resolver resolver = new Resolver(bottomAggregate, outerScope);
    for (Expression sortExpression : sort.getExpressions()) {
        resolver.resolve(sortExpression);
    }
    return createPlan(resolver, bottomAggregate, (r, a) -> {
        ImmutableList.Builder<OrderKey> rewrittenKeys = ImmutableList.builder();
        for (OrderKey key : sort.getOrderKeys()) {
            Expression boundExpr = ExpressionUtils.replace(key.getExpr(), r.getSubstitution());
            rewrittenKeys.add(new OrderKey(boundExpr, key.isAsc(), key.isNullFirst()));
        }
        List<OrderKey> newOrderKeys = rewrittenKeys.build();
        boolean sortChanged = !newOrderKeys.equals(sort.getOrderKeys());
        boolean aggChanged = !a.equals(bottomAggregate);

        if (aggChanged) {
            // The bottom aggregate would need extra output, i.e. some order key is not in the
            // select list. Name the first offending key if one can be pinpointed.
            Set<NamedExpression> selectList = Sets.newHashSet(upperAggregate.getOutputExpressions());
            int index = 0;
            for (OrderKey key : newOrderKeys) {
                if (!selectList.containsAll(key.getExpr().getInputSlots())) {
                    throw new AnalysisException(sort.getOrderKeys().get(index).getExpr().toSql()
                            + " of ORDER BY clause is not in SELECT list");
                }
                index++;
            }
            throw new AnalysisException("Expression of ORDER BY clause is not in SELECT list");
        }
        if (!sortChanged) {
            return null;
        }
        // since sort expr must in select list, we should not change agg at all.
        return new LogicalSort<>(newOrderKeys, sort.child());
    });
}
}