PullUpProjectExprUnderTopN.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.NoneMovableFunction;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;

import com.google.common.collect.ImmutableList;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Pull up non-trivial expressions from Projects below TopN to above TopN,
 * exposing their input base columns as lazy materialization candidates.
 *
 * <p>Two-pass CustomRewriter:
 * <ol>
 * <li><b>Collector (top-down)</b>: walk the plan tree, find qualifying TopNs,
 *     walk into their descendants to find Projects with pull-able expressions.
 *     Any operator that references a slot blocks pulling up expressions that
 *     output that slot past it. Boundary nodes (Aggregate, Window, Repeat,
 *     Relation, CTEProducer) stop the walk.
 *     Set operators are treated as blockers for the current TopN but their
 *     children are still traversed so nested TopNs inside them are visited.</li>
 * <li><b>Replacer (bottom-up)</b>: simplify found Projects and add upper
 *     Projects above TopN to restore pulled-up expressions.</li>
 * </ol>
 */
public class PullUpProjectExprUnderTopN implements CustomRewriter {

    /**
     * Runs the rule on a whole plan.
     *
     * <p>The plan is returned unchanged when a connect context is available and its
     * session variable {@code enableTopnExprPullup} is off, or when the collector pass
     * records no pull-up candidates. Otherwise the collected candidates are
     * deduplicated and the replacer pass produces the rewritten plan.
     *
     * @param plan root of the plan to rewrite
     * @param jobContext job context used to reach the connect context
     * @return the original plan when nothing applies, otherwise the rewritten plan
     */
    @Override
    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
        ConnectContext connectContext = jobContext.getCascadesContext()
                .getStatementContext()
                .getConnectContext();
        boolean disabledBySession = connectContext != null
                && !connectContext.getSessionVariable().enableTopnExprPullup;
        if (disabledBySession) {
            return plan;
        }

        // Pass 1: record, per TopN, which project expressions may move above it.
        CollectorContext shared = new CollectorContext();
        plan.accept(new Collector(), shared);
        if (shared.topNToPullUpInfo.isEmpty()) {
            return plan;
        }

        // An expression collected by several nested TopNs stays only with the
        // first (outermost) TopN that recorded it.
        deduplicatePullUps(shared);

        // Pass 2: simplify the affected Projects and add Projects above the TopNs.
        return plan.accept(new Replacer(), shared);
    }

    // =========================================================================
    // Data structures
    // =========================================================================

    /**
     * Bookkeeping for a single TopN: the expressions selected for pull-up, grouped by
     * the Project they come from, plus the input slots each expression reads.
     */
    static class PullUpInfo {
        /** The TopN this record belongs to. */
        final LogicalTopN topN;
        /** Snapshot of the TopN output taken when the record was created. */
        final List<Slot> originalTopNOutput;
        /** Every expression selected for pull-up, in collection order. */
        final List<NamedExpression> allPulledUpExprs = new ArrayList<>();
        /** Selected expressions grouped by source Project, in insertion order. */
        final Map<LogicalProject<? extends Plan>, List<NamedExpression>> projectToPulledUpExprs
                = new LinkedHashMap<>();
        /** Input slots of each selected expression, keyed by the expression's ExprId. */
        final Map<ExprId, List<Slot>> baseSlotsByExpr = new HashMap<>();

        PullUpInfo(LogicalTopN topN) {
            this.topN = topN;
            this.originalTopNOutput = ImmutableList.copyOf(topN.getOutput());
        }

        /**
         * Registers {@code expr}, found in {@code project}, as a pull-up candidate and
         * remembers a copy of its input slots.
         */
        void addPulledUpExpr(LogicalProject<? extends Plan> project, NamedExpression expr) {
            List<NamedExpression> perProject = projectToPulledUpExprs.get(project);
            if (perProject == null) {
                perProject = new ArrayList<>();
                projectToPulledUpExprs.put(project, perProject);
            }
            perProject.add(expr);
            allPulledUpExprs.add(expr);
            baseSlotsByExpr.put(expr.getExprId(), ImmutableList.copyOf(expr.getInputSlots()));
        }
    }

    /**
     * State handed to both passes: the per-TopN pull-up records (kept in insertion
     * order) and a counter of the CTE producers currently being visited.
     */
    static class CollectorContext {
        /** Pull-up record of every TopN for which at least one candidate was found. */
        final Map<LogicalTopN, PullUpInfo> topNToPullUpInfo = new LinkedHashMap<>();
        /** Incremented while a CTE producer subtree is visited; starts at zero. */
        int cteProducerDepth;

        /** Returns whether a record is stored for {@code topN}. */
        boolean hasPullUpInfo(LogicalTopN topN) {
            return topNToPullUpInfo.containsKey(topN);
        }

        /** Returns the record stored for {@code topN}, or {@code null} if there is none. */
        PullUpInfo getPullUpInfo(LogicalTopN topN) {
            return topNToPullUpInfo.get(topN);
        }
    }

    // =========================================================================
    // Pass 1: Collector (top-down)
    // =========================================================================

    /**
     * Returns {@code true} when the TopN has a positive limit that does not exceed the
     * value reported by {@link SessionVariable#getTopNLazyMaterializationThreshold()}.
     * The threshold is only read when the limit is positive.
     */
    private static boolean qualifiesForLazyMatThreshold(LogicalTopN topN) {
        long rowLimit = topN.getLimit();
        return rowLimit > 0
                && SessionVariable.getTopNLazyMaterializationThreshold() >= rowLimit;
    }

    /**
     * First pass: visits the plan and fills {@link CollectorContext#topNToPullUpInfo}
     * for every TopN that qualifies and has at least one pull-able expression below it.
     */
    static class Collector extends DefaultPlanRewriter<CollectorContext> {

        /**
         * Tracks entry into and exit from a CTE producer so that TopNs visited
         * underneath it are skipped by {@link #visitLogicalTopN}.
         */
        @Override
        public Plan visitLogicalCTEProducer(
                LogicalCTEProducer<? extends Plan> cteProducer, CollectorContext context) {
            context.cteProducerDepth++;
            try {
                return visit(cteProducer, context);
            } finally {
                // Restore the counter even if the traversal throws.
                context.cteProducerDepth--;
            }
        }

        /**
         * Collects candidates for {@code topN} unless it is visited under a CTE
         * producer or fails the limit/threshold check, then continues the traversal.
         */
        @Override
        public Plan visitLogicalTopN(LogicalTopN topN, CollectorContext context) {
            boolean candidate = context.cteProducerDepth <= 0
                    && qualifiesForLazyMatThreshold(topN);
            if (candidate) {
                PullUpInfo info = new PullUpInfo(topN);
                // Slots used by this TopN's order keys start out blocked, so the
                // expressions producing them stay below the TopN.
                Set<ExprId> orderKeyIds = buildOrderKeyExprIds(topN);
                collectFromNode((Plan) topN.child(0), info, orderKeyIds);
                if (!info.allPulledUpExprs.isEmpty()) {
                    context.topNToPullUpInfo.put(topN, info);
                }
            }
            return visit(topN, context);
        }
    }

    /**
     * Recursively walks down from a TopN's child and records, in {@code info}, the
     * Project expressions that may be pulled up above that TopN.
     *
     * <p>{@code blockedExprIds} holds the ExprIds of slots referenced by operators on
     * the path from the TopN to {@code node}. An expression whose output ExprId is in
     * this set is never recorded, because an operator above it still needs the slot.
     * The set passed in is never mutated; each level that adds entries works on a copy.
     *
     * <p>A Project blocks, for everything below it, every slot that any of its own
     * expressions reads. This covers plain pass-through slots and the inputs of
     * aliases, including aliases that are themselves pulled up: their input slots
     * become outputs of the simplified Project and therefore must still be produced
     * underneath it.
     *
     * @param node current plan node
     * @param info pull-up record of the TopN being processed
     * @param blockedExprIds ExprIds that must not be pulled up from {@code node} or below
     */
    private static void collectFromNode(Plan node, PullUpInfo info, Set<ExprId> blockedExprIds) {
        if (node instanceof LogicalProject) {
            LogicalProject<? extends Plan> project = (LogicalProject<? extends Plan>) node;
            Set<ExprId> blockedBelowProject = new HashSet<>(blockedExprIds);
            for (NamedExpression ne : project.getProjects()) {
                if (canPullUp(ne) && !blockedExprIds.contains(ne.getExprId())) {
                    info.addPulledUpExpr(project, ne);
                }
                // Whatever this Project reads has to stay available from its child.
                // Without this, an alias in a lower Project could be pulled up even
                // though this Project still references its output slot.
                blockedBelowProject.addAll(ne.getInputSlotExprIds());
            }
            // Continue into the project's child so chained Projects are visited too.
            collectFromNode((Plan) project.child(0), info, blockedBelowProject);
            return;
        }

        if (node instanceof LogicalTopN) {
            LogicalTopN inner = (LogicalTopN) node;
            // A nested TopN does not stop the walk, but its order keys consume
            // slots, so add them to the blocked set.
            // Do NOT reset blockedExprIds -- operators between the outer and the
            // inner TopN must still block expressions.
            Set<ExprId> newBlocked = new HashSet<>(blockedExprIds);
            newBlocked.addAll(buildOrderKeyExprIds(inner));
            collectFromNode((Plan) inner.child(0), info, newBlocked);
            return;
        }

        // Stop at boundary nodes that transform the schema or are data sources.
        if (node instanceof LogicalRelation || node instanceof LogicalCTEProducer
                || isBlockingNode(node)) {
            return;
        }

        // Set operations are a boundary for the current TopN: do NOT collect
        // expressions from below them. UNION ALL children may compute the same
        // output column with different expressions (e.g. a+1 vs a+2), and a
        // single pull-up Project above the TopN cannot represent branch-specific
        // semantics. TopNs nested inside set operations are reached separately by
        // the Collector's own traversal.
        if (node instanceof LogicalSetOperation) {
            return;
        }

        // For all other nodes, add the slots their expressions reference to the
        // blocked set: an operator that uses a slot prevents the expression that
        // outputs that slot from being pulled up past it.
        Set<ExprId> newBlocked = new HashSet<>(blockedExprIds);
        for (Expression expr : node.getExpressions()) {
            newBlocked.addAll(expr.getInputSlotExprIds());
            if (expr instanceof NamedExpression) {
                newBlocked.add(((NamedExpression) expr).getExprId());
            }
        }

        for (Plan child : node.children()) {
            collectFromNode(child, info, newBlocked);
        }
    }

    // =========================================================================
    // Pull-up eligibility
    // =========================================================================

    /**
     * Decides whether a named expression is a pull-up candidate.
     *
     * <p>Only an {@link Alias} qualifies, and only when its child is neither a plain
     * {@link Slot} nor a {@link Literal} and the expression tree contains no
     * {@link NoneMovableFunction}. Whether the expression is blocked by an operator
     * above it is checked by the caller, not here.
     *
     * @param ne expression taken from a Project
     * @return {@code true} if the expression may be moved above a TopN
     */
    static boolean canPullUp(NamedExpression ne) {
        if (ne instanceof Alias) {
            Expression aliased = ((Alias) ne).child();
            boolean trivial = aliased instanceof Slot || aliased instanceof Literal;
            return !trivial && !ne.anyMatch(e -> e instanceof NoneMovableFunction);
        }
        return false;
    }

    /** Returns whether {@code node} is an Aggregate, Window or Repeat, where the walk stops. */
    private static boolean isBlockingNode(Plan node) {
        if (node instanceof LogicalAggregate) {
            return true;
        }
        if (node instanceof LogicalWindow) {
            return true;
        }
        return node instanceof LogicalRepeat;
    }

    /**
     * Gathers the ExprIds that the order keys of {@code topN} depend on: the input
     * slots of every key expression and, when the key expression is itself a
     * {@link NamedExpression}, its own ExprId.
     *
     * @return a new mutable set owned by the caller
     */
    private static Set<ExprId> buildOrderKeyExprIds(LogicalTopN<?> topN) {
        Set<ExprId> ids = new HashSet<>();
        for (OrderKey key : topN.getOrderKeys()) {
            Expression sortExpr = key.getExpr();
            if (sortExpr instanceof NamedExpression) {
                ids.add(((NamedExpression) sortExpr).getExprId());
            }
            ids.addAll(sortExpr.getInputSlotExprIds());
        }
        return ids;
    }

    /**
     * Ensures each (Project, ExprId) pair is kept by only one TopN record.
     *
     * <p>Records are processed in the iteration order of
     * {@link CollectorContext#topNToPullUpInfo}, a {@link LinkedHashMap} that the
     * Collector fills as it visits TopNs. The first record to claim a pair keeps it;
     * later records have the expression removed from all of their collections.
     * Records that end up empty stay in the map.
     */
    private static void deduplicatePullUps(CollectorContext context) {
        // Projects are told apart by object reference, not by content equality,
        // hence the IdentityHashMap.
        Map<LogicalProject<? extends Plan>, Set<ExprId>> claimed = new IdentityHashMap<>();

        for (PullUpInfo info : context.topNToPullUpInfo.values()) {
            // First find the duplicates, then remove them, so the lists are not
            // modified while they are being iterated.
            List<NamedExpression> duplicates = new ArrayList<>();
            for (Map.Entry<LogicalProject<? extends Plan>, List<NamedExpression>> perProject
                    : info.projectToPulledUpExprs.entrySet()) {
                Set<ExprId> claimedIds =
                        claimed.computeIfAbsent(perProject.getKey(), k -> new HashSet<>());
                for (NamedExpression candidate : perProject.getValue()) {
                    // add() returns false when an earlier record already owns the id.
                    if (!claimedIds.add(candidate.getExprId())) {
                        duplicates.add(candidate);
                    }
                }
            }

            for (NamedExpression duplicate : duplicates) {
                info.allPulledUpExprs.remove(duplicate);
                for (List<NamedExpression> exprsOfProject : info.projectToPulledUpExprs.values()) {
                    exprsOfProject.removeIf(e -> e == duplicate);
                }
                info.projectToPulledUpExprs.entrySet()
                        .removeIf(entry -> entry.getValue().isEmpty());
                info.baseSlotsByExpr.remove(duplicate.getExprId());
            }
        }
    }

    // =========================================================================
    // Pass 2: Replacer (bottom-up)
    // =========================================================================

    /**
     * Second pass: children are rewritten first, then each Project that has recorded
     * pull-up expressions is simplified and each TopN with a non-empty record gets a
     * restoring Project placed above it.
     */
    static class Replacer extends DefaultPlanRewriter<CollectorContext> {

        /**
         * Simplifies a Project by removing the expressions recorded for it under any
         * TopN. Records are looked up with the rewritten node first; if that finds
         * nothing and the rewrite produced a new object with an equal project list,
         * the lookup is repeated with the original node.
         */
        @Override
        public Plan visitLogicalProject(LogicalProject<? extends Plan> project, CollectorContext context) {
            LogicalProject<? extends Plan> rewritten = (LogicalProject<? extends Plan>) visit(project, context);

            List<NamedExpression> toRemove = collectAllPulledUpExprs(context, rewritten);
            boolean retryWithOriginal = toRemove.isEmpty()
                    && rewritten != project
                    && rewritten.getProjects().equals(project.getProjects());
            if (retryWithOriginal) {
                toRemove = collectAllPulledUpExprs(context, project);
            }

            if (!toRemove.isEmpty()) {
                return simplifyProject(rewritten, toRemove, context);
            }
            return rewritten;
        }

        /**
         * Wraps a TopN in an upper Project when a non-empty record exists for it. The
         * record is looked up by the rewritten node and, failing that, by the
         * original node.
         */
        @Override
        public Plan visitLogicalTopN(LogicalTopN topN, CollectorContext context) {
            LogicalTopN rewritten = (LogicalTopN) visit(topN, context);

            PullUpInfo info = context.getPullUpInfo(rewritten);
            if (info == null && rewritten != topN) {
                info = context.getPullUpInfo(topN);
            }

            if (info != null && !info.allPulledUpExprs.isEmpty()) {
                return addUpperProject(rewritten, info);
            }
            return rewritten;
        }
    }

    /**
     * Gathers the expressions recorded for {@code project} across every TopN record
     * in {@code context}, in record iteration order.
     *
     * @return a new list; empty when no record mentions the project
     */
    private static List<NamedExpression> collectAllPulledUpExprs(
            CollectorContext context, LogicalProject<?> project) {
        List<NamedExpression> gathered = new ArrayList<>();
        for (PullUpInfo info : context.topNToPullUpInfo.values()) {
            List<NamedExpression> recorded = info.projectToPulledUpExprs.get(project);
            if (recorded == null) {
                continue;
            }
            gathered.addAll(recorded);
        }
        return gathered;
    }

    /**
     * Builds the simplified form of {@code project}: the pulled-up expressions are
     * dropped and the input slots they read are appended after the remaining
     * expressions, skipping slots whose ExprId is already present.
     *
     * @param project Project to simplify
     * @param pulledUpExprs expressions to remove from it
     * @param context source of the recorded input slots of each expression
     * @return {@code project} itself when nothing changes, otherwise a copy with the
     *         new project list
     */
    private static LogicalProject<? extends Plan> simplifyProject(
            LogicalProject<? extends Plan> project,
            List<NamedExpression> pulledUpExprs,
            CollectorContext context) {
        if (pulledUpExprs.isEmpty()) {
            return project;
        }

        Set<ExprId> removedIds = new HashSet<>();
        for (NamedExpression removed : pulledUpExprs) {
            removedIds.add(removed.getExprId());
        }

        // Keep every expression that is not pulled up, in its original order.
        List<NamedExpression> kept = new ArrayList<>();
        Set<ExprId> presentIds = new HashSet<>();
        for (NamedExpression ne : project.getProjects()) {
            ExprId id = ne.getExprId();
            if (removedIds.contains(id)) {
                continue;
            }
            kept.add(ne);
            presentIds.add(id);
        }

        // Append the input slots of each removed expression, taken from the first
        // record that knows the expression.
        for (NamedExpression removed : pulledUpExprs) {
            List<Slot> baseSlots = null;
            for (PullUpInfo info : context.topNToPullUpInfo.values()) {
                baseSlots = info.baseSlotsByExpr.get(removed.getExprId());
                if (baseSlots != null) {
                    break;
                }
            }
            if (baseSlots == null) {
                continue;
            }
            for (Slot base : baseSlots) {
                // add() returns true only for an ExprId not yet in the output.
                if (presentIds.add(base.getExprId())) {
                    kept.add(base);
                }
            }
        }

        return kept.equals(project.getProjects())
                ? project
                : (LogicalProject<? extends Plan>) project.withProjects(kept);
    }

    /**
     * Creates the Project placed above a rewritten TopN to restore the expressions
     * that were pulled up for it.
     *
     * <p>The upper Project follows the order of the TopN's original output. For each
     * original slot:
     * <ul>
     * <li>if its expression was pulled up for this TopN, the expression itself is
     *     emitted;</li>
     * <li>otherwise, if the rewritten TopN still outputs a slot with the same ExprId,
     *     that slot is emitted;</li>
     * <li>otherwise the slot no longer exists below this TopN, which happens when its
     *     expression was deduplicated to an outer TopN; it is skipped.</li>
     * </ul>
     *
     * <p>Slots are matched by ExprId rather than by list position. Simplifying a lower
     * Project removes pulled-up expressions and appends their input slots at the end,
     * so positions in the rewritten output do not line up with the original output.
     *
     * <p>When at least one original slot was skipped, every remaining slot of the
     * rewritten TopN output is appended. This lets the input slots needed by the outer
     * TopN's restored expressions pass through this Project.
     *
     * @param topN the rewritten TopN that becomes the child of the new Project
     * @param info pull-up record of that TopN; expected to have at least one expression
     * @return the new Project on top of {@code topN}
     */
    private static LogicalProject<Plan> addUpperProject(LogicalTopN topN, PullUpInfo info) {
        Map<ExprId, NamedExpression> pulledUpBySlotExprId = new HashMap<>();
        for (NamedExpression e : info.allPulledUpExprs) {
            pulledUpBySlotExprId.put(e.toSlot().getExprId(), e);
        }

        // Index the current (possibly rewritten) TopN output by ExprId so that
        // pass-through slots are found wherever they now sit.
        List<Slot> currentOutput = topN.getOutput();
        Map<ExprId, Slot> currentByExprId = new LinkedHashMap<>();
        for (Slot slot : currentOutput) {
            currentByExprId.put(slot.getExprId(), slot);
        }

        List<NamedExpression> upperOutput = new ArrayList<>();
        Set<ExprId> emittedExprIds = new HashSet<>();
        boolean slotMovedToOuterTopN = false;
        for (Slot origSlot : info.originalTopNOutput) {
            ExprId exprId = origSlot.getExprId();
            NamedExpression pulledUpExpr = pulledUpBySlotExprId.get(exprId);
            if (pulledUpExpr != null) {
                upperOutput.add(pulledUpExpr);
                emittedExprIds.add(exprId);
                continue;
            }
            Slot currentSlot = currentByExprId.get(exprId);
            if (currentSlot != null) {
                upperOutput.add(currentSlot);
                emittedExprIds.add(exprId);
            } else {
                // The producing expression was deduplicated to an outer TopN and
                // will be restored there -- skip it here.
                slotMovedToOuterTopN = true;
            }
        }

        if (slotMovedToOuterTopN) {
            // Forward the remaining slots so the outer TopN can still compute
            // the expressions it took over.
            for (Slot slot : currentOutput) {
                if (emittedExprIds.add(slot.getExprId())) {
                    upperOutput.add(slot);
                }
            }
        }

        return new LogicalProject<>(ImmutableList.copyOf(upperOutput), topN);
    }
}