// FillUpQualifyMissingSlot.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalQualify;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
/**
* The missing slots of a qualify are not filled in FillUpMissingSlots,
* because distinct queries are converted into the form of agg.
* For example:
* select distinct year,country from sales having year > 2000 qualify row_number() over (order by year + 1) > 1;
* is converted as follows.
* The logical plan before applying ProjectWithDistinctToAggregate:
* qualify
* |
* project(distinct)
* |
* scan
* After applying the ProjectWithDistinctToAggregate rule,
* the logical plan is:
* qualify
* |
* agg
* |
* scan
* If the missing slots were filled in FillUpMissingSlots (after ProjectWithDistinctToAggregate), the qualify
* could hardly be pushed under the agg produced for distinct.
* If the FillUpQualifyMissingSlot rule is applied before ProjectWithDistinctToAggregate instead,
* the logical plan is:
* project(distinct)
* |
* qualify
* |
* project
* |
* scan
* and then apply ProjectWithDistinctToAggregate rule
* logical plan:
* agg
* |
* qualify
* |
* project
* |
* scan
* So it is easy to handle.
*/
public class FillUpQualifyMissingSlot extends FillUpMissingSlots {
/**
 * Builds the rewrite rules of this class, one per plan shape found directly under a qualify:
 * project, aggregate, having over aggregate, and having over project.
 * Every rule first runs {@link #checkWindow} on the matched qualify.
 *
 * @return an immutable list holding the four rules
 */
@Override
public List<Rule> buildRules() {
    return ImmutableList.of(
        /*
           Plan shapes listed for this rule:
           qualify -> project
           qualify -> project(distinct)
           qualify -> project(distinct) -> agg
           qualify -> project(distinct) -> having -> agg
         */
        RuleType.FILL_UP_QUALIFY_PROJECT.build(
            logicalQualify(logicalProject()).then(qualify -> {
                checkWindow(qualify);
                LogicalProject<Plan> upperProject = qualify.child();
                return createPlan(upperProject, qualify.getConjuncts(), (pushedConjuncts, widerProjects) -> {
                    // The top project re-exposes exactly the slots the matched project produced.
                    ImmutableList<NamedExpression> originalOutput =
                            ImmutableList.copyOf(upperProject.getOutput());
                    // The qualify moves below, on top of a project that also yields the added expressions.
                    LogicalQualify<Plan> loweredQualify = new LogicalQualify<>(pushedConjuncts,
                            new LogicalProject<>(widerProjects, upperProject.child()));
                    return new LogicalProject<>(originalOutput, upperProject.isDistinct(), loweredQualify);
                });
            })
        ),
        /*
           qualify -> agg
         */
        RuleType.FILL_UP_QUALIFY_AGGREGATE.build(
            logicalQualify(aggregate()).then(qualify -> {
                checkWindow(qualify);
                Aggregate<Plan> agg = qualify.child();
                Set<Expression> qualifyConjuncts = qualify.getConjuncts();
                Resolver resolver = new Resolver(agg);
                for (Expression conjunct : qualifyConjuncts) {
                    resolver.resolve(conjunct);
                }
                // This createPlan overload is not declared in this class; presumably inherited.
                return createPlan(resolver, agg, (resolved, rebuiltAgg) -> {
                    Set<Expression> substituted = ExpressionUtils.replace(
                            qualifyConjuncts, resolved.getSubstitution());
                    if (!substituted.equals(qualifyConjuncts)) {
                        return new LogicalQualify<>(substituted, rebuiltAgg);
                    }
                    // Conjuncts are untouched: null when the aggregate is unchanged as well.
                    if (rebuiltAgg.equals(agg)) {
                        return null;
                    }
                    return qualify.withChildren(rebuiltAgg);
                });
            })
        ),
        /*
           qualify -> having -> agg
         */
        RuleType.FILL_UP_QUALIFY_HAVING_AGGREGATE.build(
            logicalQualify(logicalHaving(aggregate())).then(qualify -> {
                checkWindow(qualify);
                LogicalHaving<Aggregate<Plan>> having = qualify.child();
                Aggregate<Plan> agg = having.child();
                Set<Expression> qualifyConjuncts = qualify.getConjuncts();
                Resolver resolver = new Resolver(agg);
                for (Expression conjunct : qualifyConjuncts) {
                    resolver.resolve(conjunct);
                }
                // This createPlan overload is not declared in this class; presumably inherited.
                return createPlan(resolver, agg, (resolved, rebuiltAgg) -> {
                    Set<Expression> substituted = ExpressionUtils.replace(
                            qualifyConjuncts, resolved.getSubstitution());
                    if (!substituted.equals(qualifyConjuncts)) {
                        return new LogicalQualify<>(substituted, having.withChildren(rebuiltAgg));
                    }
                    // Conjuncts are untouched: null when the aggregate is unchanged as well.
                    if (rebuiltAgg.equals(agg)) {
                        return null;
                    }
                    return qualify.withChildren(having.withChildren(rebuiltAgg));
                });
            })
        ),
        /*
           qualify -> having -> project
           qualify -> having -> project(distinct)
         */
        RuleType.FILL_UP_QUALIFY_HAVING_PROJECT.build(
            logicalQualify(logicalHaving(logicalProject())).then(qualify -> {
                checkWindow(qualify);
                LogicalHaving<LogicalProject<Plan>> having = qualify.child();
                LogicalProject<Plan> project = having.child();
                return createPlan(project, qualify.getConjuncts(), (pushedConjuncts, widerProjects) -> {
                    ImmutableList<NamedExpression> originalOutput = ImmutableList.copyOf(project.getOutput());
                    if (!project.isDistinct()) {
                        // Non-distinct: project(original output) -> qualify -> having -> wider project.
                        return new LogicalProject<>(originalOutput, new LogicalQualify<>(pushedConjuncts,
                                having.withChildren(project.withProjects(widerProjects))));
                    }
                    // Distinct: having -> project(original output) -> qualify -> new bottom project.
                    // The bottom project additionally carries the having input slots
                    // that are not contained in the widened project list.
                    Set<Slot> havingOnlySlots = having.getExpressions().stream()
                            .flatMap(expression -> expression.getInputSlots().stream())
                            .filter(slot -> !widerProjects.contains(slot))
                            .collect(Collectors.toSet());
                    List<NamedExpression> loweredOutput = ImmutableList.<NamedExpression>builder()
                            .addAll(widerProjects)
                            .addAll(havingOnlySlots)
                            .build();
                    LogicalQualify<LogicalProject<Plan>> loweredQualify = new LogicalQualify<>(
                            pushedConjuncts, new LogicalProject<>(loweredOutput, project.child()));
                    return having.withChildren(
                            project.withProjects(originalOutput).withChildren(loweredQualify));
                });
            })
        )
    );
}
/**
 * Callback that assembles a plan from a set of qualify conjuncts and a list of project expressions.
 */
@FunctionalInterface
interface PlanGenerator {
    /**
     * @param newConjuncts the conjuncts to use for the qualify of the generated plan
     * @param projects the named expressions to use as the project list of the generated plan
     * @return the generated plan
     */
    Plan apply(Set<Expression> newConjuncts, List<NamedExpression> projects);
}
/**
 * Replaces every window expression in the conjuncts by the slot of a fresh alias, collects those
 * aliases together with the conjunct input slots the project does not output, and hands the
 * rewritten conjuncts plus the extended project list to the generator.
 *
 * @param project the project whose output set and project list are inspected
 * @param conjuncts the qualify conjuncts to rewrite
 * @param planGenerator builds the resulting plan from the rewritten conjuncts and the extended projects
 * @return the generated plan, or null when there is neither a window alias nor a missing slot to add
 */
private Plan createPlan(LogicalProject<Plan> project, Set<Expression> conjuncts, PlanGenerator planGenerator) {
    Set<Slot> projectOutputSet = project.getOutputSet();
    List<NamedExpression> extraOutputs = Lists.newArrayList();

    // Stateless rewriter: each window expression becomes an alias (recorded in the context list)
    // and is replaced in the conjunct by that alias' slot.
    DefaultExpressionRewriter<List<NamedExpression>> windowToSlot =
            new DefaultExpressionRewriter<List<NamedExpression>>() {
                @Override
                public Expression visitWindow(WindowExpression window, List<NamedExpression> context) {
                    Alias windowAlias = new Alias(window);
                    context.add(windowAlias);
                    return windowAlias.toSlot();
                }
            };
    Set<Expression> rewrittenConjuncts = new HashSet<>();
    for (Expression conjunct : conjuncts) {
        rewrittenConjuncts.add(conjunct.accept(windowToSlot, extraOutputs));
    }

    // Input slots are taken from the ORIGINAL conjuncts, not from the rewritten ones.
    Set<Slot> absentFromProject = conjuncts.stream()
            .flatMap(expression -> expression.getInputSlots().stream())
            .filter(slot -> !projectOutputSet.contains(slot))
            .collect(Collectors.toSet());
    extraOutputs.addAll(absentFromProject);

    if (extraOutputs.isEmpty()) {
        return null;
    }

    ImmutableList.Builder<NamedExpression> extendedProjects = ImmutableList.builder();
    extendedProjects.addAll(project.getProjects());
    extendedProjects.addAll(extraOutputs);
    return planGenerator.apply(rewrittenConjuncts, extendedProjects.build());
}
/**
 * Verifies that the qualify is tied to a window expression. The check passes when a conjunct
 * contains a window expression itself, or when a slot reference collected from the conjuncts is
 * the slot of an alias over a window expression found in a visited project or aggregate.
 *
 * @param qualify the qualify plan to check
 * @throws AnalysisException if no window expression is found either way
 */
private void checkWindow(LogicalQualify<? extends Plan> qualify) throws AnalysisException {
    AtomicBoolean windowFound = new AtomicBoolean(false);
    Set<SlotReference> referencedSlots = new HashSet<>();

    // Flags window expressions (returning right away on them) and records slot references.
    DefaultExpressionVisitor<Void, Set<SlotReference>> conjunctScanner =
            new DefaultExpressionVisitor<Void, Set<SlotReference>>() {
                @Override
                public Void visitWindow(WindowExpression windowExpression, Set<SlotReference> context) {
                    windowFound.set(true);
                    return null;
                }

                @Override
                public Void visitSlotReference(SlotReference slotReference, Set<SlotReference> context) {
                    context.add(slotReference);
                    return null;
                }
            };
    for (Expression conjunct : qualify.getConjuncts()) {
        conjunct.accept(conjunctScanner, referencedSlots);
    }
    if (windowFound.get()) {
        return;
    }

    // No window expression in the conjuncts themselves: look for a referenced window alias
    // in the projects and aggregates reached by the plan visitor.
    qualify.accept(new DefaultPlanVisitor<Void, Void>() {
        private void findWindow(List<NamedExpression> namedExpressions) {
            for (NamedExpression candidate : namedExpressions) {
                boolean aliasesWindow = candidate instanceof Alias
                        && candidate.child(0) instanceof WindowExpression;
                if (aliasesWindow && referencedSlots.contains(candidate.toSlot())) {
                    windowFound.set(true);
                }
            }
        }

        @Override
        public Void visitLogicalProject(LogicalProject<? extends Plan> project, Void context) {
            findWindow(project.getProjects());
            return visit(project, context);
        }

        @Override
        public Void visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Void context) {
            findWindow(aggregate.getOutputExpressions());
            return visit(aggregate, context);
        }
    }, null);

    if (windowFound.get()) {
        return;
    }
    throw new AnalysisException("qualify only used for window expression");
}
}