// PullUpProjectExprUnderTopN.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.NoneMovableFunction;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;

import com.google.common.collect.ImmutableList;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Pull up non-trivial expressions from Projects below TopN to above TopN,
 * exposing their input base columns as lazy materialization candidates.
 *
 * <p>Two-pass CustomRewriter:
 * <ol>
 * <li><b>Collector (top-down)</b>: walk the plan tree, find qualifying TopNs,
 *     walk into their descendants to find Projects with pull-able expressions.</li>
 * <li><b>Replacer (bottom-up)</b>: simplify found Projects and add upper
 *     Projects above TopN to restore pulled-up expressions.</li>
 * </ol>
 */
public class PullUpProjectExprUnderTopN implements CustomRewriter {

    /**
     * Runs the two passes of this rewrite over {@code plan}.
     *
     * <p>The plan is returned untouched when a connect context is available and its
     * session variable {@code enableTopnExprPullup} is off, or when the collector pass
     * registers no TopN to rewrite.
     *
     * @param plan root of the plan to rewrite
     * @param jobContext job context used to reach the connect context
     * @return the plan produced by the replacer pass, or {@code plan} itself
     */
    @Override
    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
        ConnectContext connectContext = jobContext.getCascadesContext()
                .getStatementContext()
                .getConnectContext();
        boolean disabledBySession = connectContext != null
                && !connectContext.getSessionVariable().enableTopnExprPullup;
        if (disabledBySession) {
            return plan;
        }

        // First pass: record, per TopN, which Project expressions can move above it.
        CollectorContext sharedContext = new CollectorContext();
        plan.accept(new Collector(), sharedContext);

        // Second pass only runs when the first one found something to do.
        if (!sharedContext.topNToPullUpInfo.isEmpty()) {
            return plan.accept(new Replacer(), sharedContext);
        }
        return plan;
    }

    // =========================================================================
    // Data structures
    // =========================================================================

    /**
     * Bookkeeping for a single TopN: the expressions selected for pull-up, the Project
     * each one came from, and the input slots each one reads.
     */
    static class PullUpInfo {
        /** The TopN this record belongs to. */
        final LogicalTopN topN;
        /** Snapshot of the TopN output taken when this record was created. */
        final List<Slot> originalTopNOutput;
        /** Every selected expression, in collection order. */
        final List<NamedExpression> allPulledUpExprs = new ArrayList<>();
        /** Selected expressions grouped by the Project they were found in. */
        final Map<LogicalProject<? extends Plan>, List<NamedExpression>> projectToPulledUpExprs
                = new LinkedHashMap<>();
        /** Input slots of each selected expression, keyed by the expression's ExprId. */
        final Map<ExprId, List<Slot>> baseSlotsByExpr = new HashMap<>();

        PullUpInfo(LogicalTopN topN) {
            this.topN = topN;
            this.originalTopNOutput = ImmutableList.copyOf(topN.getOutput());
        }

        /** Registers {@code expr}, found in {@code project}, in all three collections. */
        void addPulledUpExpr(LogicalProject<? extends Plan> project, NamedExpression expr) {
            allPulledUpExprs.add(expr);

            List<NamedExpression> perProject = projectToPulledUpExprs.get(project);
            if (perProject == null) {
                perProject = new ArrayList<>();
                projectToPulledUpExprs.put(project, perProject);
            }
            perProject.add(expr);

            baseSlotsByExpr.put(expr.getExprId(), ImmutableList.copyOf(expr.getInputSlots()));
        }
    }

    /**
     * Mutable state handed to both visitors: the per-TopN pull-up records plus two
     * traversal flags maintained by the collector.
     */
    static class CollectorContext {
        /** Pull-up records keyed by TopN, in registration order. */
        final Map<LogicalTopN, PullUpInfo> topNToPullUpInfo = new LinkedHashMap<>();
        /** True while the collector is visiting below a TopN that passed its checks. */
        boolean insideQualifyingTopN = false;
        /** Number of CTE producers currently enclosing the visited node. */
        int cteProducerDepth = 0;

        /** Whether a record is registered for {@code topN}. */
        boolean hasPullUpInfo(LogicalTopN topN) {
            return topNToPullUpInfo.containsKey(topN);
        }

        /** The record registered for {@code topN}, or {@code null} when there is none. */
        PullUpInfo getPullUpInfo(LogicalTopN topN) {
            return topNToPullUpInfo.get(topN);
        }

        /**
         * The first record (in registration order) that lists {@code project} as a source
         * of pulled-up expressions, or {@code null} when no record does.
         */
        PullUpInfo getPullUpInfoForProject(LogicalProject<? extends Plan> project) {
            return topNToPullUpInfo.values().stream()
                    .filter(candidate -> candidate.projectToPulledUpExprs.containsKey(project))
                    .findFirst()
                    .orElse(null);
        }
    }

    // =========================================================================
    // Pass 1: Collector (top-down)
    // =========================================================================

    /**
     * First pass. Visits the plan and, for each TopN that passes the checks in
     * {@link #visitLogicalTopN}, records the Project expressions that can move above it.
     */
    static class Collector extends DefaultPlanRewriter<CollectorContext> {

        /**
         * A TopN qualifies when its limit is positive and does not exceed
         * {@code SessionVariable.getTopNLazyMaterializationThreshold()}.
         */
        private static boolean qualifiesForLazyMat(LogicalTopN topN) {
            long limit = topN.getLimit();
            return limit > 0 && limit <= SessionVariable.getTopNLazyMaterializationThreshold();
        }

        /** Tracks CTE-producer nesting so that TopNs underneath are left alone. */
        @Override
        public Plan visitLogicalCTEProducer(
                LogicalCTEProducer<? extends Plan> cteProducer, CollectorContext context) {
            context.cteProducerDepth += 1;
            try {
                return visit(cteProducer, context);
            } finally {
                context.cteProducerDepth -= 1;
            }
        }

        /**
         * Collects pull-up candidates for {@code topN} unless it sits under a CTE producer,
         * fails the limit check, or is nested under another TopN that already passed.
         */
        @Override
        public Plan visitLogicalTopN(LogicalTopN topN, CollectorContext context) {
            boolean eligible = context.cteProducerDepth <= 0
                    && qualifiesForLazyMat(topN)
                    && !context.insideQualifyingTopN;
            if (!eligible) {
                return visit(topN, context);
            }

            PullUpInfo candidate = new PullUpInfo(topN);
            Set<ExprId> collectedIds = new HashSet<>();
            boolean blocked = walkAndCollect((Plan) topN.child(0), topN, candidate, collectedIds);
            boolean worthRecording = !blocked && !candidate.allPulledUpExprs.isEmpty();
            if (worthRecording) {
                context.topNToPullUpInfo.put(topN, candidate);
            }

            // The flag is raised whether or not anything was recorded; it is lowered
            // again once the subtree has been visited.
            context.insideQualifyingTopN = true;
            try {
                return visit(topN, context);
            } finally {
                context.insideQualifyingTopN = false;
            }
        }
    }

    /**
     * Walk down from a qualifying TopN's child to find Projects with pull-able expressions.
     *
     * <p>Only Project, Join and Filter nodes are traversed; any other node ends the walk.
     * Every operator on the way is checked <em>after</em> its subtree has been walked,
     * because the aliases it may consume are produced by Projects below it:
     * <ul>
     * <li>Project: collected expressions whose output this Project reads are dropped
     *     from the pull-up set.</li>
     * <li>Join: collected expressions referenced by the join conjuncts are dropped.</li>
     * <li>Filter: if a conjunct reads a collected output, everything collected so far is
     *     discarded and the walk reports "blocked".</li>
     * </ul>
     *
     * @param node current node of the walk
     * @param topN the TopN the walk started from; its order keys are never pulled up
     * @param info receives the collected expressions
     * @param collectedOutputExprIds ExprIds of all expressions collected so far; updated in place
     * @return true if walk was blocked (expression outputs used by intermediate operator)
     */
    private static boolean walkAndCollect(Plan node, LogicalTopN topN, PullUpInfo info,
            Set<ExprId> collectedOutputExprIds) {
        if (node instanceof LogicalProject) {
            LogicalProject<? extends Plan> project = (LogicalProject<? extends Plan>) node;
            Set<ExprId> orderKeyExprIds = buildOrderKeyExprIds(topN);
            for (NamedExpression ne : project.getProjects()) {
                if (canPullUp(ne, orderKeyExprIds)) {
                    info.addPulledUpExpr(project, ne);
                    collectedOutputExprIds.add(ne.getExprId());
                }
            }
            Plan child = (Plan) project.child(0);
            // Do not descend into a directly stacked Project or a scan.
            if (child instanceof LogicalProject || isScanNode(child)) {
                return false;
            }
            if (walkAndCollect(child, topN, info, collectedOutputExprIds)) {
                return true;
            }
            // The Replacer only rewrites Projects that own pulled-up expressions and only
            // re-exposes the direct input slots of those expressions. Any alias from
            // below that this Project reads (as a pass-through slot or as an input of one
            // of its own expressions) must therefore stay where it is, otherwise this
            // Project would reference a slot that is no longer produced.
            Set<ExprId> consumedByProject = new HashSet<>();
            for (NamedExpression ne : project.getProjects()) {
                for (Slot slot : ne.getInputSlots()) {
                    consumedByProject.add(slot.getExprId());
                }
            }
            removePulledUpExprsByOutput(info, consumedByProject, collectedOutputExprIds);
            return false;
        }

        if (node instanceof LogicalJoin) {
            LogicalJoin<? extends Plan, ? extends Plan> join = (LogicalJoin<? extends Plan, ? extends Plan>) node;
            // Walk children first, then check if join conditions reference ANY
            // collected expressions (both from children and from ancestors).
            boolean leftBlocked = walkAndCollect((Plan) join.left(), topN, info, collectedOutputExprIds);
            boolean rightBlocked = walkAndCollect((Plan) join.right(), topN, info, collectedOutputExprIds);
            if (leftBlocked || rightBlocked) {
                return true;
            }
            // Check join conditions against ALL collected outputs (not just new ones),
            // because expressions may have been collected from a Project above this Join.
            if (operatorUsesPulledUpOutputs(node, collectedOutputExprIds)) {
                // Remove only what's referenced by the join condition
                Set<ExprId> referenced = collectReferencedExprIds(join);
                removePulledUpExprsByOutput(info, referenced, collectedOutputExprIds);
            }
            return false;
        }

        if (node instanceof LogicalFilter) {
            LogicalFilter<? extends Plan> filter = (LogicalFilter<? extends Plan>) node;
            // Walk the child first: the aliases a filter can consume are produced by
            // Projects below it, so checking before the walk would never see them.
            if (walkAndCollect((Plan) filter.child(0), topN, info, collectedOutputExprIds)) {
                return true;
            }
            if (operatorUsesPulledUpOutputs(filter, collectedOutputExprIds)) {
                info.allPulledUpExprs.clear();
                info.projectToPulledUpExprs.clear();
                info.baseSlotsByExpr.clear();
                collectedOutputExprIds.clear();
                return true;
            }
            return false;
        }

        // Scan, Agg, Sort, Window, Union, CTE Producer: stop walking
        return false;
    }

    /** Returns true when {@code node} is a {@link LogicalRelation}. */
    private static boolean isScanNode(Plan node) {
        return LogicalRelation.class.isInstance(node);
    }

    /**
     * Tells whether a Filter or Join reads the output of any already-collected expression.
     *
     * <p>For a Filter the conjuncts are inspected; for a Join the hash, other and mark
     * join conjuncts. Any other node kind contributes no expressions and yields false.
     *
     * @param node operator to inspect
     * @param pulledUpOutputExprIds ExprIds of the collected pull-up expressions
     * @return true if at least one input slot of the operator's expressions is in the set
     */
    private static boolean operatorUsesPulledUpOutputs(Plan node, Set<ExprId> pulledUpOutputExprIds) {
        if (pulledUpOutputExprIds.isEmpty()) {
            return false;
        }

        List<Expression> operatorExprs = new ArrayList<>();
        if (node instanceof LogicalJoin) {
            LogicalJoin<? extends Plan, ? extends Plan> join
                    = (LogicalJoin<? extends Plan, ? extends Plan>) node;
            operatorExprs.addAll(join.getHashJoinConjuncts());
            operatorExprs.addAll(join.getOtherJoinConjuncts());
            operatorExprs.addAll(join.getMarkJoinConjuncts());
        } else if (node instanceof LogicalFilter) {
            operatorExprs.addAll(((LogicalFilter<? extends Plan>) node).getConjuncts());
        }

        return operatorExprs.stream()
                .flatMap(expr -> expr.getInputSlots().stream())
                .anyMatch(slot -> pulledUpOutputExprIds.contains(slot.getExprId()));
    }

    /**
     * Gathers every ExprId the conjuncts of {@code join} refer to: the ExprId of each
     * conjunct that is itself a {@link NamedExpression}, plus the ExprIds of all input
     * slots of every conjunct (hash, other and mark join conjuncts).
     */
    private static Set<ExprId> collectReferencedExprIds(LogicalJoin<?, ?> join) {
        List<Expression> conjuncts = new ArrayList<>(join.getHashJoinConjuncts());
        conjuncts.addAll(join.getOtherJoinConjuncts());
        conjuncts.addAll(join.getMarkJoinConjuncts());

        Set<ExprId> referenced = new HashSet<>();
        for (Expression conjunct : conjuncts) {
            // A conjunct may itself be a named expression (e.g. an extracted Alias);
            // its own id counts as referenced too.
            if (conjunct instanceof NamedExpression) {
                referenced.add(((NamedExpression) conjunct).getExprId());
            }
            conjunct.getInputSlots().forEach(slot -> referenced.add(slot.getExprId()));
        }
        return referenced;
    }

    /**
     * Withdraws from {@code info} every collected expression whose ExprId is in
     * {@code toRemove}, keeping all of its collections consistent, and removes the ids in
     * {@code toRemove} from {@code collectedOutputExprIds}.
     *
     * @param info pull-up record to prune
     * @param toRemove ExprIds of expressions that must no longer be pulled up
     * @param collectedOutputExprIds running set of collected ExprIds; updated in place
     */
    private static void removePulledUpExprsByOutput(PullUpInfo info, Set<ExprId> toRemove,
            Set<ExprId> collectedOutputExprIds) {
        // Forget the base slots of each expression that is about to be withdrawn.
        for (NamedExpression collected : info.allPulledUpExprs) {
            ExprId id = collected.getExprId();
            if (toRemove.contains(id)) {
                info.baseSlotsByExpr.remove(id);
            }
        }
        info.allPulledUpExprs.removeIf(collected -> toRemove.contains(collected.getExprId()));

        // Prune the per-Project lists and drop Projects left with nothing to pull up.
        info.projectToPulledUpExprs.entrySet().removeIf(entry -> {
            List<NamedExpression> perProject = entry.getValue();
            perProject.removeIf(collected -> toRemove.contains(collected.getExprId()));
            return perProject.isEmpty();
        });

        collectedOutputExprIds.removeAll(toRemove);
    }

    // =========================================================================
    // Pull-up eligibility
    // =========================================================================

    /**
     * Decides whether {@code ne} may be moved above the TopN.
     *
     * <p>It must be an {@link Alias} whose child is neither a {@link Slot} nor a
     * {@link Literal}, its ExprId must not be among the order-key ids, and no node of the
     * expression tree may be a {@link NoneMovableFunction}.
     *
     * @param ne candidate projection
     * @param orderKeyExprIds ExprIds the TopN order keys depend on
     * @return true when all conditions hold
     */
    static boolean canPullUp(NamedExpression ne, Set<ExprId> orderKeyExprIds) {
        if (ne instanceof Alias) {
            Expression aliased = ((Alias) ne).child();
            boolean trivial = aliased instanceof Slot || aliased instanceof Literal;
            return !trivial
                    && !orderKeyExprIds.contains(ne.getExprId())
                    && !ne.anyMatch(e -> e instanceof NoneMovableFunction);
        }
        return false;
    }

    /**
     * Collects the ExprIds the order keys of {@code topN} depend on: the id of each key
     * expression that is a {@link NamedExpression}, plus the ids of all its input slots.
     */
    private static Set<ExprId> buildOrderKeyExprIds(LogicalTopN topN) {
        Set<ExprId> ids = new HashSet<>();
        for (Object rawKey : topN.getOrderKeys()) {
            Expression sortExpr = ((OrderKey) rawKey).getExpr();
            if (sortExpr instanceof NamedExpression) {
                ids.add(((NamedExpression) sortExpr).getExprId());
            }
            sortExpr.getInputSlots().forEach(slot -> ids.add(slot.getExprId()));
        }
        return ids;
    }

    // =========================================================================
    // Pass 2: Replacer (bottom-up)
    // =========================================================================

    /**
     * Second pass. Children are rewritten first; afterwards Projects that own pulled-up
     * expressions are simplified and every registered TopN gets a Project on top that
     * restores those expressions.
     */
    static class Replacer extends DefaultPlanRewriter<CollectorContext> {

        /**
         * Simplifies {@code project} when the collector registered pulled-up expressions
         * for it. The lookup tries the rewritten node first and then the original one,
         * since the collector registered the node as it was before its children changed.
         */
        @Override
        public Plan visitLogicalProject(LogicalProject<? extends Plan> project, CollectorContext context) {
            LogicalProject<? extends Plan> rewritten = (LogicalProject<? extends Plan>) visit(project, context);
            PullUpInfo info = context.getPullUpInfoForProject(rewritten);
            if (info == null && rewritten != project
                    && rewritten.getProjects().equals(project.getProjects())) {
                info = context.getPullUpInfoForProject(project);
            }
            if (info == null) {
                return rewritten;
            }
            return simplifyProject(rewritten, info, project);
        }

        /**
         * Adds the restoring Project above a TopN that has pull-up info.
         *
         * <p>The info is looked up by the original {@code topN} first, because that is the
         * key the collector registered; the rewritten node is only a fallback. Looking up
         * by the rewritten node alone would lose the info whenever the lookup by the new
         * node misses, leaving simplified Projects below with nothing above to restore
         * their expressions.
         */
        @Override
        public Plan visitLogicalTopN(LogicalTopN topN, CollectorContext context) {
            LogicalTopN rewritten = (LogicalTopN) visit(topN, context);
            PullUpInfo info = context.getPullUpInfo(topN);
            if (info == null) {
                info = context.getPullUpInfo(rewritten);
            }
            if (info == null || info.allPulledUpExprs.isEmpty()) {
                return rewritten;
            }
            return addUpperProject(rewritten, info);
        }
    }

    /**
     * Rebuilds {@code project} without its pulled-up expressions and appends the input
     * slots those expressions read, skipping slots the Project already outputs.
     *
     * @param project the Project after its children were rewritten
     * @param info pull-up record the Project belongs to
     * @param original the Project as seen by the collector; tried first as lookup key
     * @return {@code project} itself when nothing is registered for it or the projection
     *         list would not change, otherwise a copy with the new projection list
     */
    private static LogicalProject<? extends Plan> simplifyProject(
            LogicalProject<? extends Plan> project, PullUpInfo info, LogicalProject<? extends Plan> original) {
        List<NamedExpression> pulledUpExprs = info.projectToPulledUpExprs.get(original);
        if (pulledUpExprs == null) {
            pulledUpExprs = info.projectToPulledUpExprs.get(project);
        }
        if (pulledUpExprs == null || pulledUpExprs.isEmpty()) {
            return project;
        }

        Set<ExprId> droppedIds = new HashSet<>();
        pulledUpExprs.forEach(pulled -> droppedIds.add(pulled.getExprId()));

        // Keep everything that is not pulled up, remembering which ids are present.
        List<NamedExpression> kept = new ArrayList<>();
        Set<ExprId> presentIds = new HashSet<>();
        for (NamedExpression ne : project.getProjects()) {
            if (droppedIds.contains(ne.getExprId())) {
                continue;
            }
            kept.add(ne);
            presentIds.add(ne.getExprId());
        }

        // Append the input slots of the pulled-up expressions, each at most once.
        for (NamedExpression pulled : pulledUpExprs) {
            List<Slot> inputs = info.baseSlotsByExpr.get(pulled.getExprId());
            if (inputs == null) {
                continue;
            }
            for (Slot input : inputs) {
                if (presentIds.add(input.getExprId())) {
                    kept.add(input);
                }
            }
        }

        if (kept.equals(project.getProjects())) {
            return project;
        }
        return (LogicalProject<? extends Plan>) project.withProjects(kept);
    }

    /**
     * Builds the Project placed on top of {@code topN}. It lists the TopN's original
     * output in order, substituting the full pulled-up expression for every slot whose
     * ExprId matches one, and passing all other slots through unchanged.
     */
    private static LogicalProject<Plan> addUpperProject(LogicalTopN topN, PullUpInfo info) {
        Map<ExprId, NamedExpression> restoreByExprId = new HashMap<>();
        info.allPulledUpExprs.forEach(
                pulled -> restoreByExprId.put(pulled.toSlot().getExprId(), pulled));

        ImmutableList.Builder<NamedExpression> upperOutput = ImmutableList.builder();
        for (Slot slot : info.originalTopNOutput) {
            upperOutput.add(restoreByExprId.getOrDefault(slot.getExprId(), slot));
        }

        return new LogicalProject<>(upperOutput.build(), topN);
    }
}