// PullUpProjectExprUnderTopN.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.NoneMovableFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.L2DistanceApproximate;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Score;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* Pull up non-trivial expressions from Projects below TopN to above TopN,
* exposing their input base columns as lazy materialization candidates.
*
* <p>The rewriter runs bottom-up. Each LogicalTopN is treated as the current
* target TopN after its child has already been rewritten. The target TopN then
* collects only Projects in its own child subtree, stops at nested TopNs, and
* adds one upper Project to restore the original TopN output. This lets an
* upper TopN pull expressions that were just restored above a lower TopN.
*/
public class PullUpProjectExprUnderTopN implements CustomRewriter {
/**
 * Entry point of the rule: runs the bottom-up {@code Rewriter} over {@code plan}.
 *
 * <p>The plan is returned untouched when a connect context is available and its
 * session variable {@code enableTopnExprPullup} is off. When no connect context
 * is available the rewrite still runs.
 *
 * @param plan root of the plan to rewrite
 * @param jobContext job context used to reach the connect context
 * @return the rewritten plan, or {@code plan} itself when the feature is switched off
 */
@Override
public Plan rewriteRoot(Plan plan, JobContext jobContext) {
    ConnectContext connectContext = jobContext.getCascadesContext()
            .getStatementContext()
            .getConnectContext();
    boolean pullUpDisabled = connectContext != null
            && !connectContext.getSessionVariable().enableTopnExprPullup;
    if (pullUpDisabled) {
        return plan;
    }
    Rewriter rewriter = new Rewriter();
    return plan.accept(rewriter, new RewriteContext());
}
// =========================================================================
// Data structures
// =========================================================================
/**
 * Bookkeeping for one target TopN: which expressions get pulled up, and from which Projects.
 */
static class PullUpInfo {
    /** The TopN this info was created for. */
    final LogicalTopN topN;
    /** Snapshot of the TopN output taken at construction time. */
    final List<Slot> originalTopNOutput;
    /** Every pulled-up expression, in registration order. */
    final List<NamedExpression> allPulledUpExprs = new ArrayList<>();
    /** Pulled-up expressions grouped by the Project that defines them. */
    final Map<LogicalProject<? extends Plan>, List<NamedExpression>> projectToPulledUpExprs
            = new LinkedHashMap<>();
    /** Alias output slot -> aliased child expression; the first registration of a slot wins. */
    final Map<Slot, Expression> pullUpExprReplaceMap = new LinkedHashMap<>();

    PullUpInfo(LogicalTopN topN) {
        this.topN = topN;
        this.originalTopNOutput = ImmutableList.copyOf(topN.getOutput());
    }

    /**
     * Registers {@code expr} as pulled up from {@code project} and records its
     * slot replacement.
     */
    void addPulledUpExpr(LogicalProject<? extends Plan> project, NamedExpression expr) {
        allPulledUpExprs.add(expr);
        List<NamedExpression> exprsOfProject = projectToPulledUpExprs.get(project);
        if (exprsOfProject == null) {
            exprsOfProject = new ArrayList<>();
            projectToPulledUpExprs.put(project, exprsOfProject);
        }
        exprsOfProject.add(expr);
        addPullUpExprReplace(expr);
    }

    /**
     * Remembers how to rewrite the output slot of an Alias back to its child expression.
     * Non-Alias expressions are ignored, and an already-registered slot keeps its mapping.
     */
    void addPullUpExprReplace(NamedExpression expr) {
        if (!(expr instanceof Alias)) {
            return;
        }
        pullUpExprReplaceMap.putIfAbsent(expr.toSlot(), expr.child(0));
    }
}
/** Mutable state threaded through the bottom-up traversal. */
static class RewriteContext {
    /**
     * Number of LogicalCTEProducer nodes currently being visited above the current node.
     * Starts at zero; the Rewriter increments it on entering a producer and decrements
     * it on leaving.
     */
    int cteProducerDepth;
}
// =========================================================================
// Bottom-up TopN rewriter
// =========================================================================
/**
 * Tells whether the TopN's limit is positive and does not exceed the value returned by
 * {@code SessionVariable.getTopNLazyMaterializationThreshold()}.
 *
 * @param topN the TopN to test
 * @return {@code true} when {@code 0 < limit <= threshold}
 */
private static boolean qualifiesForLazyMatThreshold(LogicalTopN topN) {
    long rowLimit = topN.getLimit();
    // The threshold is only consulted for a positive limit.
    return rowLimit > 0 && rowLimit <= SessionVariable.getTopNLazyMaterializationThreshold();
}
/**
 * Bottom-up visitor: each TopN is processed only after its own subtree has been rewritten.
 */
static class Rewriter extends DefaultPlanRewriter<RewriteContext> {
    /**
     * Tracks the CTE-producer nesting depth while the producer's subtree is visited.
     * The counter is restored even when the visit throws.
     */
    @Override
    public Plan visitLogicalCTEProducer(
            LogicalCTEProducer<? extends Plan> cteProducer, RewriteContext context) {
        context.cteProducerDepth += 1;
        try {
            return visit(cteProducer, context);
        } finally {
            context.cteProducerDepth -= 1;
        }
    }

    /**
     * Rewrites the subtree first, then tries to pull Project expressions above this TopN.
     * The rewritten TopN is returned as is when it sits inside a CTE producer, fails the
     * lazy-materialization threshold check, offers nothing to pull up, or when simplifying
     * its child changes nothing.
     */
    @Override
    public Plan visitLogicalTopN(LogicalTopN topN, RewriteContext context) {
        LogicalTopN current = (LogicalTopN) visit(topN, context);
        boolean insideCteProducer = context.cteProducerDepth > 0;
        if (insideCteProducer || !qualifiesForLazyMatThreshold(current)) {
            return current;
        }

        Plan originalChild = (Plan) current.child(0);
        PullUpInfo info = new PullUpInfo(current);
        // Start with the order-key ExprIds blocked, so that nothing the sort depends on
        // is moved above this TopN.
        collectFromNode(originalChild, info, buildOrderKeyExprIds(current));
        if (info.allPulledUpExprs.isEmpty()) {
            return current;
        }

        Plan simplifiedChild = originalChild.accept(new ProjectSimplifier(), info);
        if (simplifiedChild == originalChild) {
            return current;
        }
        return addUpperProject(current.withChildren(ImmutableList.of(simplifiedChild)), info);
    }
}
/**
 * Walks down from a TopN's child and registers in {@code info} every Project expression
 * that may be pulled above the TopN.
 *
 * <p>{@code blockedExprIds} holds the ExprIds referenced by operators on the path from the
 * TopN down to {@code node}. A Project expression whose output ExprId is in that set stays
 * in place, and its input slots become blocked for everything below that Project.
 *
 * @param node current node of the walk
 * @param info collector belonging to the target TopN
 * @param blockedExprIds ExprIds that must not be pulled up; this set is copied, never mutated
 */
private static void collectFromNode(Plan node, PullUpInfo info, Set<ExprId> blockedExprIds) {
    if (node instanceof LogicalProject) {
        LogicalProject<? extends Plan> project = (LogicalProject<? extends Plan>) node;
        Set<ExprId> blockedBelowProject = new HashSet<>(blockedExprIds);
        for (NamedExpression projection : project.getProjects()) {
            // Every alias is remembered for later slot resolution, pulled up or not.
            info.addPullUpExprReplace(projection);
            if (blockedExprIds.contains(projection.getExprId())) {
                // A blocked output keeps its inputs pinned below as well.
                blockedBelowProject.addAll(projection.getInputSlotExprIds());
            } else if (canPullUp(projection)) {
                info.addPulledUpExpr(project, projection);
            }
        }
        // Keep descending, so chained Projects are all visited.
        collectFromNode((Plan) project.child(0), info, blockedBelowProject);
        return;
    }

    // Boundaries of the walk:
    //  - a nested TopN: the bottom-up rewriter has already handled it, so the current
    //    target TopN only collects Projects above it;
    //  - data sources, CTE producers and schema-transforming nodes;
    //  - set operations: UNION ALL children may compute the same output column with
    //    different expressions (e.g. a+1 vs a+2), and a single pull-up Project above the
    //    TopN cannot represent branch-specific semantics. The normal visitor still
    //    traverses into the children, so TopNs nested inside set operations are handled
    //    independently.
    boolean walkEndsHere = node instanceof LogicalTopN
            || node instanceof LogicalRelation
            || node instanceof LogicalCTEProducer
            || isBlockingNode(node)
            || node instanceof LogicalSetOperation;
    if (walkEndsHere) {
        return;
    }

    // Any operator that references a slot in its expressions prevents the expression
    // producing that slot from being pulled up past it.
    Set<ExprId> blockedBelow = new HashSet<>(blockedExprIds);
    for (Expression expr : node.getExpressions()) {
        blockedBelow.addAll(expr.getInputSlotExprIds());
        if (expr instanceof NamedExpression) {
            blockedBelow.add(((NamedExpression) expr).getExprId());
        }
    }

    // For null-generating outer joins, additionally block every output slot of the
    // nullable side(s). Expressions inside a nullable side are protected by join
    // null-extension: without a match the whole nullable tuple becomes NULL. Pulling such
    // an expression above the join would break this.
    // Example: SELECT l.id, sub.x FROM l LEFT JOIN (
    //              SELECT id, ifnull(b, 0) AS x FROM r) sub ON l.id = sub.id
    //          ORDER BY l.id LIMIT 3;
    // x = ifnull(b, 0) lives in a Project on the nullable (right) side; evaluated above
    // the join it would turn the NULL of unmatched rows into 0.
    if (node instanceof LogicalJoin) {
        LogicalJoin<?, ?> join = (LogicalJoin<?, ?>) node;
        JoinType joinType = join.getJoinType();
        boolean fullOuter = joinType.isFullOuterJoin();
        boolean rightSideNullable = joinType.isLeftOuterJoin()
                || joinType.isAsofLeftOuterJoin() || fullOuter;
        boolean leftSideNullable = joinType.isRightOuterJoin()
                || joinType.isAsofRightOuterJoin() || fullOuter;
        if (rightSideNullable) {
            for (Slot rightSlot : join.right().getOutput()) {
                blockedBelow.add(rightSlot.getExprId());
            }
        }
        if (leftSideNullable) {
            for (Slot leftSlot : join.left().getOutput()) {
                blockedBelow.add(leftSlot.getExprId());
            }
        }
    }

    for (Plan child : node.children()) {
        collectFromNode(child, info, blockedBelow);
    }
}
// =========================================================================
// Pull-up eligibility
// =========================================================================
/**
 * Checks whether a named expression is, by itself, eligible to move above a TopN.
 *
 * <p>Eligible means: an Alias whose child is neither a plain Slot nor a Literal, and whose
 * tree contains no NoneMovableFunction, Score or L2DistanceApproximate node and no
 * volatile expression. Whether the expression is blocked by an operator on the way up is
 * not checked here; the caller does that.
 *
 * @param ne the projection to test
 * @return {@code true} when the expression may be pulled up
 */
static boolean canPullUp(NamedExpression ne) {
    if (!(ne instanceof Alias)) {
        return false;
    }
    Expression aliased = ((Alias) ne).child();
    boolean trivialChild = aliased instanceof Slot || aliased instanceof Literal;
    if (trivialChild) {
        return false;
    }
    // One scan covers every node kind that has to stay where it is.
    boolean containsPinnedNode = ne.anyMatch(e -> e instanceof NoneMovableFunction
            || e instanceof Score
            || e instanceof L2DistanceApproximate);
    return !containsPinnedNode && !ne.containsVolatileExpression();
}
/**
 * Tells whether {@code node} is one of the operators the collection walk never descends
 * through: aggregate, window or repeat.
 */
private static boolean isBlockingNode(Plan node) {
    if (node instanceof LogicalAggregate) {
        return true;
    }
    if (node instanceof LogicalWindow) {
        return true;
    }
    return node instanceof LogicalRepeat;
}
/**
 * Gathers the ExprIds the TopN's order keys depend on: the input slots of every key
 * expression, plus the key's own ExprId when the key is a named expression.
 *
 * @param topN the TopN whose order keys are inspected
 * @return a fresh, mutable set of ExprIds
 */
private static Set<ExprId> buildOrderKeyExprIds(LogicalTopN<?> topN) {
    Set<ExprId> sortExprIds = new HashSet<>();
    for (OrderKey key : topN.getOrderKeys()) {
        Expression sortExpr = key.getExpr();
        if (sortExpr instanceof NamedExpression) {
            sortExprIds.add(((NamedExpression) sortExpr).getExprId());
        }
        sortExprIds.addAll(sortExpr.getInputSlotExprIds());
    }
    return sortExprIds;
}
/**
 * Rewrites the Projects below the target TopN: pulled-up expressions are dropped and the
 * slots they need are exposed instead. Nested TopNs are left untouched.
 */
static class ProjectSimplifier extends DefaultPlanRewriter<PullUpInfo> {
    /** A nested TopN ends the rewrite; it is returned unchanged. */
    @Override
    public Plan visitLogicalTopN(LogicalTopN topN, PullUpInfo info) {
        return topN;
    }

    /**
     * Rewrites the children first, then simplifies this Project with the expressions
     * registered for it. The registration is looked up by the rewritten Project and, if
     * that misses while the projection list is unchanged, by the original Project. With
     * no registration at all the Project is still simplified against an empty list.
     */
    @Override
    public Plan visitLogicalProject(LogicalProject<? extends Plan> project, PullUpInfo info) {
        LogicalProject<? extends Plan> rewritten = (LogicalProject<? extends Plan>) visit(project, info);
        List<NamedExpression> registered = info.projectToPulledUpExprs.get(rewritten);
        boolean onlyChildrenChanged = rewritten != project
                && rewritten.getProjects().equals(project.getProjects());
        if (registered == null && onlyChildrenChanged) {
            registered = info.projectToPulledUpExprs.get(project);
        }
        if (registered == null) {
            registered = ImmutableList.of();
        }
        return simplifyProject(rewritten, registered, info);
    }
}
/**
 * Drops pulled-up expressions from a Project and exposes the base slots they read.
 *
 * <p>Example, pulling up {@code x = a + 1} through cascaded Projects:
 *
 * <pre>
 * Before:
 *   TopN
 *     Project(id, x)              -- forwards x from its child
 *       Project(id, a + 1 AS x)
 *         Scan(id, a)
 *
 * After simplifyProject (both Projects lose x and gain a):
 *   TopN
 *     Project(id, a)              -- x removed, the child no longer outputs it
 *       Project(id, a)            -- a + 1 removed, base slot a exposed
 *         Scan(id, a)
 *
 * {@code addUpperProject} then restores the computation above the TopN:
 *   Project(id, a + 1 AS x)       -- new upper Project
 *     TopN
 *       Project(id, a)
 *         Project(id, a)
 *           Scan(id, a)
 * </pre>
 *
 * @param project the Project to simplify; its child has already been rewritten
 * @param pulledUpExprs expressions of this Project that move above the TopN; may be empty
 * @param info collector belonging to the target TopN
 * @return {@code project} itself when nothing changes, otherwise a copy with the new
 *         projection list
 */
private static LogicalProject<? extends Plan> simplifyProject(
        LogicalProject<? extends Plan> project,
        List<NamedExpression> pulledUpExprs,
        PullUpInfo info) {
    Set<ExprId> childOutputExprIds = ((Plan) project.child(0)).getOutputExprIdSet();
    List<Expression> passThroughExprs = collectUnavailablePullUpExprs(project, info, childOutputExprIds);
    if (pulledUpExprs.isEmpty() && passThroughExprs.isEmpty()) {
        return project;
    }

    Set<ExprId> removedExprIds = new HashSet<>();
    for (NamedExpression pulledUp : pulledUpExprs) {
        removedExprIds.add(pulledUp.getExprId());
    }

    List<NamedExpression> simplified = new ArrayList<>();
    Set<ExprId> emittedExprIds = new HashSet<>();

    // Step 1: keep every projection that is neither pulled up nor a replaceable
    // projection the child can no longer feed.
    for (NamedExpression projection : project.getProjects()) {
        if (removedExprIds.contains(projection.getExprId())) {
            continue;
        }
        Expression replacement = getPullUpReplaceExpression(projection.toSlot(), info);
        if (replacement != null && isUnavailableExpression(projection, childOutputExprIds)) {
            continue;
        }
        NamedExpression kept = resolveAliasChildIfNeeded(projection, info, childOutputExprIds);
        simplified.add(kept);
        emittedExprIds.add(kept.getExprId());
    }

    // Step 2: expose the slots read by the removed expressions, pulled-up ones first and
    // pass-through ones second, skipping slots that are already in the output.
    List<Expression> slotSources = new ArrayList<>(pulledUpExprs.size() + passThroughExprs.size());
    for (NamedExpression pulledUp : pulledUpExprs) {
        slotSources.add(pulledUp.child(0));
    }
    slotSources.addAll(passThroughExprs);
    for (Expression source : slotSources) {
        for (Slot baseSlot : resolveInputSlots(source, info, childOutputExprIds)) {
            if (emittedExprIds.add(baseSlot.getExprId())) {
                simplified.add(baseSlot);
            }
        }
    }

    if (simplified.equals(project.getProjects())) {
        return project;
    }
    return (LogicalProject<? extends Plan>) project.withProjects(simplified);
}
/**
 * Collects, in projection order, the replacement expression of every projection that has
 * an entry in the replace map and can no longer be computed from the child's output.
 *
 * @param project the Project whose projections are inspected
 * @param info collector holding the replace map
 * @param childOutputExprIds ExprIds output by the Project's child
 * @return the replacement expressions; empty when no projection qualifies
 */
private static List<Expression> collectUnavailablePullUpExprs(
        LogicalProject<? extends Plan> project, PullUpInfo info, Set<ExprId> childOutputExprIds) {
    List<Expression> replacements = new ArrayList<>();
    for (NamedExpression projection : project.getProjects()) {
        Expression replacement = getPullUpReplaceExpression(projection.toSlot(), info);
        if (replacement == null) {
            continue;
        }
        if (isUnavailableExpression(projection, childOutputExprIds)) {
            replacements.add(replacement);
        }
    }
    return replacements;
}
/**
 * Tells whether {@code ne} can no longer be computed from the child's output.
 *
 * <p>A Slot is unavailable when the child does not output it. An Alias is unavailable when
 * at least one of its input slots is missing from the child's output. Any other kind of
 * expression is reported as available. Callers are expected to have checked separately
 * that a pull-up replacement exists for {@code ne}; that condition is not tested here.
 */
private static boolean isUnavailableExpression(NamedExpression ne, Set<ExprId> childOutputExprIds) {
    if (ne instanceof Slot) {
        return !childOutputExprIds.contains(ne.getExprId());
    }
    if (!(ne instanceof Alias)) {
        return false;
    }
    for (Slot input : ne.getInputSlots()) {
        if (!childOutputExprIds.contains(input.getExprId())) {
            return true;
        }
    }
    return false;
}
/**
 * Looks up the replacement for {@code slot} and follows slot-to-slot mappings to their end.
 *
 * @param slot the slot to resolve
 * @param info collector holding the replace map
 * @return {@code null} when {@code slot} has no mapping; otherwise the first non-Slot
 *         expression reached, or the last Slot of the chain when that Slot has no mapping
 */
private static Expression getPullUpReplaceExpression(Slot slot, PullUpInfo info) {
    Map<Slot, Expression> replacements = info.pullUpExprReplaceMap;
    Expression current = replacements.get(slot);
    while (current instanceof Slot) {
        Expression successor = replacements.get((Slot) current);
        if (successor == null) {
            break;
        }
        current = successor;
    }
    return current;
}
/**
 * Creates the Project that sits above the TopN and restores the TopN's original output.
 *
 * <p>Each slot of the original TopN output is rebuilt, in order, as one of:
 * <ul>
 *   <li>the pulled-up expression that produces it, with its child resolved against what
 *       the TopN outputs now;</li>
 *   <li>the slot itself, when the TopN still outputs it;</li>
 *   <li>an Alias that keeps the slot's ExprId and name and wraps the slot resolved through
 *       the replace map.</li>
 * </ul>
 *
 * @param topN the TopN whose child has already been simplified
 * @param info collector holding the pulled-up expressions and the original output
 * @return a new Project with {@code topN} as its child
 */
private static LogicalProject<Plan> addUpperProject(LogicalTopN topN, PullUpInfo info) {
    Set<ExprId> currentOutputExprIds = topN.getOutputExprIdSet();

    // Index the pulled-up expressions by the ExprId of the slot they produce.
    Map<ExprId, NamedExpression> pulledUpBySlotExprId = new HashMap<>();
    for (NamedExpression e : info.allPulledUpExprs) {
        pulledUpBySlotExprId.put(e.toSlot().getExprId(),
                resolveAliasChildIfNeeded(e, info, currentOutputExprIds));
    }

    List<NamedExpression> upperOutput = new ArrayList<>(info.originalTopNOutput.size());
    for (Slot origSlot : info.originalTopNOutput) {
        NamedExpression pulledUpExpr = pulledUpBySlotExprId.get(origSlot.getExprId());
        if (pulledUpExpr != null) {
            upperOutput.add(pulledUpExpr);
        } else if (currentOutputExprIds.contains(origSlot.getExprId())) {
            upperOutput.add(origSlot);
        } else {
            // The slot vanished from below without being pulled up itself: recompute it
            // from the slots that are still available, under its original ExprId.
            Expression resolvedExpr = resolveExpression(origSlot, info, currentOutputExprIds);
            upperOutput.add(new Alias(origSlot.getExprId(), resolvedExpr, origSlot.getName()));
        }
    }
    return new LogicalProject<>(ImmutableList.copyOf(upperOutput), topN);
}
/**
 * Rewrites the child of an Alias so that it only refers to available slots.
 *
 * @param expr the expression to rewrite; anything but an Alias is returned as is
 * @param info collector holding the replace map
 * @param availableExprIds ExprIds that may be referenced directly
 * @return {@code expr} itself when it is not an Alias or its child needs no rewriting,
 *         otherwise a new Alias with the same ExprId and name over the resolved child
 */
private static NamedExpression resolveAliasChildIfNeeded(NamedExpression expr, PullUpInfo info,
        Set<ExprId> availableExprIds) {
    if (!(expr instanceof Alias)) {
        return expr;
    }
    Expression originalChild = expr.child(0);
    Expression resolvedChild = resolveExpression(originalChild, info, availableExprIds);
    boolean unchanged = resolvedChild.equals(originalChild);
    return unchanged ? expr : new Alias(expr.getExprId(), resolvedChild, expr.getName());
}
/**
 * Resolves {@code expr} against the available slots and returns the input slots of the
 * resolved expression as an immutable list.
 */
private static List<Slot> resolveInputSlots(Expression expr, PullUpInfo info,
        Set<ExprId> availableExprIds) {
    Expression resolved = resolveExpression(expr, info, availableExprIds);
    return ImmutableList.copyOf(resolved.getInputSlots());
}
/**
 * Replaces unavailable slots in {@code expression} over and over until a pass leaves the
 * expression unchanged, so chains of replacements are followed to the end.
 *
 * @param expression the expression to resolve
 * @param info collector holding the replace map
 * @param availableExprIds ExprIds that may be referenced directly
 * @return the result of the last replacement pass
 */
private static Expression resolveExpression(Expression expression, PullUpInfo info,
        Set<ExprId> availableExprIds) {
    Expression before;
    Expression after = expression;
    do {
        before = after;
        after = replaceUnavailableSlots(before, info, availableExprIds);
    } while (!after.equals(before));
    return after;
}
/**
 * Performs one replacement pass: every replace-map entry whose key slot is not in
 * {@code availableExprIds} is applied to {@code expression}.
 *
 * @param expression the expression to rewrite
 * @param info collector holding the replace map
 * @param availableExprIds ExprIds that may be referenced directly and are left alone
 * @return the result of {@code ExpressionUtils.replace} with the filtered map
 */
private static Expression replaceUnavailableSlots(Expression expression, PullUpInfo info,
        Set<ExprId> availableExprIds) {
    Map<Slot, Expression> applicable = new LinkedHashMap<>();
    info.pullUpExprReplaceMap.forEach((slot, replacement) -> {
        if (!availableExprIds.contains(slot.getExprId())) {
            applicable.put(slot, replacement);
        }
    });
    return ExpressionUtils.replace(expression, applicable);
}
}