AggScalarSubQueryToWindowFunction.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.properties.DataTrait;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.window.SupportWindowAnalytic;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
import org.apache.doris.nereids.trees.plans.algebra.Filter;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

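/**
 * Rewrites a filter over a correlated scalar aggregate subquery into a window function.
 * It changes the plan
 * logicalFilter(logicalApply(any(), logicalAggregate()))
 * into
 * logicalFilter(logicalWindow(logicalFilter(any())))
 * (a logicalProject sitting between the filter and the apply is skipped when matching).
 * <p>
 * See the paper: WinMagic - Subquery Elimination Using Window Aggregation.
 * <p>
 * TODO: use materialized view pattern match to do outer and inner tree match.
 */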

public class AggScalarSubQueryToWindowFunction extends DefaultPlanRewriter<JobContext> implements CustomRewriter {

    // Plan node types accepted on the outer side (left child of the Apply); used by checkPlanType().
    private static final Set<Class<? extends LogicalPlan>> OUTER_SUPPORTED_PLAN = ImmutableSet.of(
            LogicalJoin.class,
            LogicalProject.class,
            LogicalRelation.class
    );

    // Plan node types accepted on the inner side (right child of the Apply); used by checkPlanType().
    private static final Set<Class<? extends LogicalPlan>> INNER_SUPPORTED_PLAN = ImmutableSet.of(
            LogicalAggregate.class,
            LogicalFilter.class,
            LogicalJoin.class,
            LogicalProject.class,
            LogicalRelation.class
    );

    // Working state shared by check(), the check*() helpers and rewrite().
    // LogicalPlan nodes collected from apply.child(0) by check().
    private final List<LogicalPlan> outerPlans = Lists.newArrayList();
    // LogicalPlan nodes collected from apply.child(1) by check().
    private final List<LogicalPlan> innerPlans = Lists.newArrayList();
    // Aggregate functions found in the inner aggregate's output by checkAggregate(); rewrite() reads entry 0.
    private final List<AggregateFunction> functions = Lists.newArrayList();
    // Inner relation slot -> same-named slot of the outer relation over the same table (see createSlotMapping()).
    private final Map<Expression, Expression> innerOuterSlotMap = Maps.newHashMap();

    /**
     * Entry point of this rule. A plan that fails the {@code containsType} pre-check for
     * Filter / LogicalApply / LogicalAggregate is returned untouched; any other plan is traversed
     * with this rewriter. Only {@link #visitLogicalFilter} is overridden, because the filter of
     * the outer plan sits on top of the Apply and is where the rewrite starts.
     *
     * @param plan root of the plan to rewrite
     * @param context job context passed through to the visitor
     * @return the plan produced by the traversal, or {@code plan} itself when the pre-check fails
     */
    @Override
    public Plan rewriteRoot(Plan plan, JobContext context) {
        if (plan.containsType(Filter.class, LogicalApply.class, LogicalAggregate.class)) {
            return plan.accept(this, context);
        }
        return plan;
    }

    /**
     * Handles a Filter sitting on an Apply. Children are rewritten first; afterwards the Apply is
     * looked up under the filter, tolerating one Project in between. According to the
     * {@link org.apache.doris.nereids.rules.analysis.SubqueryToApply} rule that project only maps
     * the apply output back to the original output and does not affect this rule, so it is ignored.
     *
     * @return the rewritten plan when an Apply is found and passes {@code check}, otherwise the
     *         filter with its rewritten children
     */
    @Override
    public Plan visitLogicalFilter(LogicalFilter<? extends Plan> plan, JobContext context) {
        LogicalFilter<? extends Plan> filter = visitChildren(this, plan, context);
        Optional<LogicalApply<Plan, Plan>> apply = findApply(filter);
        if (!apply.isPresent() || !check(filter, apply.get())) {
            return filter;
        }
        return rewrite(filter, apply.get());
    }

    /**
     * Locates the Apply below {@code filter}: either its direct child, or the first child of a
     * Project that is the filter's direct child.
     *
     * @return the Apply, or an empty Optional when the node found there is not a LogicalApply
     */
    private Optional<LogicalApply<Plan, Plan>> findApply(LogicalFilter<? extends Plan> filter) {
        Plan candidate = filter.child();
        if (candidate instanceof LogicalProject) {
            // step over the project
            candidate = candidate.child(0);
        }
        if (candidate instanceof LogicalApply) {
            return Optional.of((LogicalApply<Plan, Plan>) candidate);
        }
        return Optional.empty();
    }

    /**
     * Decides whether the Filter/Apply pair can be rewritten into a window function.
     *
     * <p>The working collections ({@code outerPlans}, {@code innerPlans}, {@code functions},
     * {@code innerOuterSlotMap}) are instance fields and this method runs once per candidate
     * filter in the plan. They are reset here first; otherwise nodes, aggregate functions and
     * slot mappings gathered for an earlier candidate would leak into the checks (and into
     * {@code rewrite}) of the current one.
     *
     * @param outerFilter the filter on top of the Apply
     * @param apply the Apply whose right child is the scalar subquery
     * @return true when every structural check passes
     */
    private boolean check(LogicalFilter<? extends Plan> outerFilter, LogicalApply<Plan, Plan> apply) {
        // drop whatever a previous candidate left behind
        outerPlans.clear();
        innerPlans.clear();
        functions.clear();
        innerOuterSlotMap.clear();

        outerPlans.addAll(apply.child(0).collect(LogicalPlan.class::isInstance));
        innerPlans.addAll(apply.child(1).collect(LogicalPlan.class::isInstance));

        // order matters: checkAggregate fills functions, checkRelation fills innerOuterSlotMap,
        // and checkFilter reads that map.
        return checkPlanType()
                && checkApply(apply)
                && checkAggregate()
                && checkJoin()
                && checkProject()
                && checkRelation(apply.getCorrelationSlot())
                && checkFilter(outerFilter)
                && checkUniqueCorrelatedTable(apply.getCorrelationSlot());
    }

    /**
     * Checks the node types on both sides of the Apply, because the rewrite changes how the
     * query is processed: every collected outer node must be an instance of a type in
     * {@code OUTER_SUPPORTED_PLAN} and every inner node of a type in {@code INNER_SUPPORTED_PLAN}.
     */
    private boolean checkPlanType() {
        return allSupported(outerPlans, OUTER_SUPPORTED_PLAN)
                && allSupported(innerPlans, INNER_SUPPORTED_PLAN);
    }

    /** Returns true when each plan is an instance of at least one of the supported types. */
    private static boolean allSupported(List<LogicalPlan> plans,
            Set<Class<? extends LogicalPlan>> supportedTypes) {
        for (LogicalPlan node : plans) {
            boolean supported = false;
            for (Class<? extends LogicalPlan> type : supportedTypes) {
                if (type.isInstance(node)) {
                    supported = true;
                    break;
                }
            }
            if (!supported) {
                return false;
            }
        }
        return true;
    }

    /**
     * The Apply must
     *   1. be scalar
     *   2. not be a mark join
     *   3. have a {@link LogicalAggregate} as the top plan of its inner (right) side
     *   4. be correlated
     * <p>
     * Whether the conjunct comparing against the subquery is a {@link ComparisonPredicate}
     * is not verified by this method.
     */
    private boolean checkApply(LogicalApply<Plan, Plan> apply) {
        if (!apply.isScalar() || apply.isMarkJoin()) {
            return false;
        }
        return apply.right() instanceof LogicalAggregate && apply.isCorrelated();
    }

    /**
     * Checks the aggregation of the inner scope: there must be exactly one Aggregate node, its
     * output expressions must contain exactly one AggregateFunction, and that function must
     * implement {@link SupportWindowAnalytic} and must not be distinct. The functions found are
     * appended to {@code functions}.
     */
    private boolean checkAggregate() {
        int aggregateCount = 0;
        LogicalAggregate<Plan> aggOp = null;
        for (LogicalPlan node : innerPlans) {
            if (node instanceof LogicalAggregate) {
                aggregateCount++;
                aggOp = (LogicalAggregate<Plan>) node;
            }
        }
        if (aggregateCount != 1) {
            // window functions don't support nesting.
            return false;
        }
        functions.addAll(ExpressionUtils.collectAll(
                aggOp.getOutputExpressions(), AggregateFunction.class::isInstance));
        if (functions.size() != 1) {
            return false;
        }
        AggregateFunction onlyFunction = functions.get(0);
        return onlyFunction instanceof SupportWindowAnalytic && !onlyFunction.isDistinct();
    }

    /**
     * Checks that the inner scope has exactly one filter and that its conjuncts, after inner slots
     * are replaced through {@code innerOuterSlotMap}, form a subset of the outer filter's conjuncts.
     *
     * <p>Matching uses {@link ExpressionIdenticalChecker}, which also accepts a commuted comparison.
     * Each outer conjunct is consumed by at most one inner conjunct, and the search for an inner
     * conjunct stops at its first match. Continuing the scan after a match is unsafe: when two
     * outer conjuncts are identical to the same inner one (for example {@code a = b} and
     * {@code b = a}), removing the inner element a second time through its iterator throws
     * {@link IllegalStateException}.
     *
     * @param outerFilter the filter on top of the Apply
     * @return true when every inner conjunct has a counterpart in the outer filter
     */
    private boolean checkFilter(LogicalFilter<? extends Plan> outerFilter) {
        List<LogicalFilter<Plan>> innerFilters = innerPlans.stream()
                .filter(LogicalFilter.class::isInstance)
                .map(p -> (LogicalFilter<Plan>) p).collect(Collectors.toList());
        if (innerFilters.size() != 1) {
            return false;
        }
        // mutable working copy; matched outer conjuncts are taken out of it
        Set<Expression> unmatchedOuter = Sets.newHashSet(outerFilter.getConjuncts());
        Set<Expression> innerConjunctSet = innerFilters.get(0).getConjuncts().stream()
                .map(e -> ExpressionUtils.replace(e, innerOuterSlotMap))
                .collect(Collectors.toSet());
        // inner predicate should be the sub-set of outer predicate.
        for (Expression innerExpr : innerConjunctSet) {
            boolean matched = false;
            Iterator<Expression> outerIterator = unmatchedOuter.iterator();
            while (outerIterator.hasNext()) {
                if (ExpressionIdenticalChecker.INSTANCE.check(innerExpr, outerIterator.next())) {
                    outerIterator.remove();
                    matched = true;
                    break;
                }
            }
            if (!matched) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks that no join in the outer or the inner scope carries an on-clause condition.
     * Accurate pattern matching between the two scopes is not available, so joins with a
     * condition are currently rejected.
     */
    private boolean checkJoin() {
        return !containsJoinWithOnClause(outerPlans) && !containsJoinWithOnClause(innerPlans);
    }

    /** Returns true when any LogicalJoin among the plans has an on-clause condition present. */
    private static boolean containsJoinWithOnClause(List<LogicalPlan> plans) {
        for (LogicalPlan node : plans) {
            if (node instanceof LogicalJoin
                    && ((LogicalJoin<Plan, Plan>) node).getOnClauseCondition().isPresent()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Checks that the projects of both scopes do nothing but column pruning, i.e. every
     * expression of every project is a {@link SlotReference}.
     */
    private boolean checkProject() {
        return projectsOnlySlots(outerPlans) && projectsOnlySlots(innerPlans);
    }

    /** Returns false as soon as a LogicalProject among the plans has a non-SlotReference expression. */
    private static boolean projectsOnlySlots(List<LogicalPlan> plans) {
        for (LogicalPlan node : plans) {
            if (!(node instanceof LogicalProject)) {
                continue;
            }
            for (Expression projected : ((LogicalProject<Plan>) node).getExpressions()) {
                if (!(projected instanceof SlotReference)) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Checks the relations of the inner and the outer scope:
     * 1. no table id may occur twice within one scope
     * 2. the outer scope must have exactly one table more than the inner scope
     * 3. after removing the inner table ids from the outer ones exactly one table id must remain
     * 4. all correlated slots must be output slots of that remaining (correlated) table
     * <p>
     * Once steps 1-3 pass, the inner-to-outer slot mapping is created as a side effect.
     */
    private boolean checkRelation(List<Slot> correlatedSlots) {
        List<CatalogRelation> outerTables = Lists.newArrayList();
        for (LogicalPlan node : outerPlans) {
            if (node instanceof CatalogRelation) {
                outerTables.add((CatalogRelation) node);
            }
        }
        List<CatalogRelation> innerTables = Lists.newArrayList();
        for (LogicalPlan node : innerPlans) {
            if (node instanceof CatalogRelation) {
                innerTables.add((CatalogRelation) node);
            }
        }

        List<Long> outerIds = Lists.newArrayList();
        for (CatalogRelation relation : outerTables) {
            outerIds.add(relation.getTable().getId());
        }
        List<Long> innerIds = Lists.newArrayList();
        for (CatalogRelation relation : innerTables) {
            innerIds.add(relation.getTable().getId());
        }

        boolean outerHasDuplicates = Sets.newHashSet(outerIds).size() != outerIds.size();
        if (outerHasDuplicates || Sets.newHashSet(innerIds).size() != innerIds.size()) {
            return false;
        }
        if (outerIds.size() - innerIds.size() != 1) {
            return false;
        }
        for (Long innerId : innerIds) {
            // innerId is a Long, so this resolves to List.remove(Object), not remove(int)
            outerIds.remove(innerId);
        }
        if (outerIds.size() != 1) {
            return false;
        }

        createSlotMapping(outerTables, innerTables);

        Set<ExprId> correlatedRelationOutput = Sets.newHashSet();
        for (CatalogRelation relation : outerTables) {
            if (outerIds.contains(relation.getTable().getId())) {
                correlatedRelationOutput.addAll(((LogicalRelation) relation).getOutputExprIdSet());
            }
        }
        for (Slot correlated : correlatedSlots) {
            if (!correlatedRelationOutput.contains(correlated.getExprId())) {
                return false;
            }
        }
        return true;
    }

    /**
     * The correlated columns of the outer-only table must form a unique key
     * for the WinMagic window-function rewrite to be correct. Without uniqueness,
     * the window function may aggregate over duplicated outer rows, producing
     * wrong results for aggregates like SUM and COUNT.
     * <p>
     * Uniqueness is decided by {@link DataTrait#isUnique(Set)} on the outer-only relation's
     * logical properties (per the original author's note this covers both OLAP key metadata
     * and declared PRIMARY KEY / UNIQUE constraints).
     * <p>
     * This is a correctness guard, so it fails closed: when a single outer-only table cannot be
     * identified, uniqueness cannot be established and the rewrite is refused.
     *
     * @param correlatedSlots the correlation slots of the Apply
     * @return true only when the correlated slots are proven unique on the outer-only table
     */
    private boolean checkUniqueCorrelatedTable(List<Slot> correlatedSlots) {
        List<CatalogRelation> outerTables = outerPlans.stream()
                .filter(CatalogRelation.class::isInstance)
                .map(CatalogRelation.class::cast)
                .collect(Collectors.toList());
        List<CatalogRelation> innerTables = innerPlans.stream()
                .filter(CatalogRelation.class::isInstance)
                .map(CatalogRelation.class::cast)
                .collect(Collectors.toList());

        List<Long> outerIds = outerTables.stream().map(node -> node.getTable().getId()).collect(Collectors.toList());
        List<Long> innerIds = innerTables.stream().map(node -> node.getTable().getId()).collect(Collectors.toList());

        // what is left after removing the shared tables is the outer-only (correlated) table
        innerIds.forEach(outerIds::remove);
        if (outerIds.size() != 1) {
            // no unambiguous outer-only table: uniqueness cannot be verified
            return false;
        }

        Optional<CatalogRelation> outerOnlyTable = outerTables.stream()
                .filter(node -> outerIds.contains(node.getTable().getId()))
                .findFirst();
        if (!outerOnlyTable.isPresent()) {
            return false;
        }

        // Check uniqueness via DataTrait on the correlated (outer-only) table.
        DataTrait dataTrait = outerOnlyTable.get().getLogicalProperties().getTrait();
        return dataTrait.isUnique(Sets.newHashSet(correlatedSlots));
    }

    /**
     * Fills {@code innerOuterSlotMap}. Each outer table is paired with the first inner table
     * that has the same table id, and every output slot of that inner table is mapped to the
     * first output slot of the outer table carrying the same name.
     */
    private void createSlotMapping(List<CatalogRelation> outerTables, List<CatalogRelation> innerTables) {
        for (CatalogRelation outerTable : outerTables) {
            innerTables.stream()
                    .filter(innerTable -> innerTable.getTable().getId() == outerTable.getTable().getId())
                    .findFirst()
                    .ifPresent(innerTable -> mapSlotsByName(innerTable, outerTable));
        }
    }

    /** Maps each output slot of the inner table to the first same-named output slot of the outer table. */
    private void mapSlotsByName(CatalogRelation innerTable, CatalogRelation outerTable) {
        for (Slot innerSlot : innerTable.getOutput()) {
            for (Slot outerSlot : outerTable.getOutput()) {
                if (innerSlot.getName().equals(outerSlot.getName())) {
                    innerOuterSlotMap.put(innerSlot, outerSlot);
                    break;
                }
            }
        }
    }

    /**
     * Rewrite a correlated scalar subquery into a Window function.
     *
     * <h3>Input Plan Shape</h3>
     * <pre>
     * Filter(outer predicates + comparison against the subquery)
     *   Apply(correlation: outer-only.col)
     *     Join / CrossJoin
     *       Scan(shared_tbl)      -- appears in both outer & inner
     *       Scan(outer-only_tbl)  -- only in outer, correlated table
     *     Aggregate(agg_func)
     *       Filter(inner predicates)
     *         Scan(shared_tbl)
     * </pre>
     *
     * <h3>Output Plan Shape</h3>
     * <pre>
     * Filter(window comparison + remaining predicates on shared columns)
     *   Window(agg_func OVER (PARTITION BY correlated slots))
     *     Filter(predicates also applied by the subquery + outer-only predicates)
     *       Join / CrossJoin
     *         Scan(shared_tbl)
     *         Scan(outer-only_tbl)
     * </pre>
     *
     * <h3>Key Correctness Rule</h3>
     * Within one partition the window must aggregate exactly the rows the scalar subquery would
     * have aggregated for that outer row. Therefore an outer conjunct is placed
     * <ul>
     *   <li>BELOW the window when it is identical to a conjunct of the inner filter (the subquery
     *       applies it before aggregating), or when it references no shared-table column (with
     *       the correlated slots unique it can only drop whole partitions);</li>
     *   <li>ABOVE the window otherwise, i.e. when it references a shared-table column and the
     *       subquery does not apply it. Below the window it would hide rows the subquery sees.</li>
     * </ul>
     *
     * <p>Example: Given fact(f) as shared table and dim(d) as outer-only table,
     * with the query:
     * <pre>
     *   SELECT ... FROM fact f, dim d
     *   WHERE f.k = d.k AND f.v > 6
     *     AND f.v * 2 > (SELECT SUM(f2.v) FROM fact f2 WHERE f2.k = d.k)
     * </pre>
     *
     * {@code f.k = d.k} matches the inner conjunct {@code f2.k = d.k} and goes below the window.
     * {@code f.v > 6} is not applied by the subquery, so it stays above: the window computes SUM
     * over ALL fact rows per d.k, matching the original scalar subquery semantics.
     *
     * @param filter the filter on top of the Apply
     * @param apply the Apply that passed {@code check}
     * @return the rewritten plan, or {@code filter} unchanged when the comparison against the
     *         subquery is not exactly one simple {@link ComparisonPredicate}
     */
    private Plan rewrite(LogicalFilter<? extends Plan> filter, LogicalApply<Plan, Plan> apply) {
        Preconditions.checkArgument(apply.right() instanceof LogicalAggregate,
                "right child of Apply should be LogicalAggregate");
        LogicalAggregate<Plan> agg = (LogicalAggregate<Plan>) apply.right();

        // transform algorithm
        // first: find the slot in outer scope corresponding to the slot in aggregate function in inner scope.
        // second: find the aggregation function in inner scope, and replace it to window function, and the aggregate
        // slot is the slot in outer scope in the first step.
        // third: the expression containing aggregation function in inner scope will be the child of an alias,
        // so in the predicate between outer and inner, we change the alias to expression which is the alias's child,
        // and change the aggregation function to the alias of window function.

        // for example, in tpc-h Q17
        // window filter conjuncts is
        // cast(l_quantity#id1 as decimal(27, 9)) < `0.2 * avg(l_quantity)`#id2
        // and
        // 0.2 * avg(l_quantity#id3) as `0.2 * l_quantity`#id2
        // is aggregate's output expression
        // we change it to
        // cast(l_quantity#id1 as decimal(27, 9)) < 0.2 * `avg(l_quantity#id1) over(window)`#id4
        // and
        // avg(l_quantity#id1) over(window) as `avg(l_quantity#id1) over(window)`#id4

        // it's a simple case, but we may meet some complex cases in ut.
        // TODO: support compound predicate and multi apply node.

        // key true: conjunct does not touch the aggregate output; key false: it does.
        // partitioningBy (unlike groupingBy) always provides both keys, so neither lookup is null.
        Map<Boolean, Set<Expression>> conjuncts = filter.getConjuncts().stream()
                .collect(Collectors.partitioningBy(conjunct -> Sets
                        .intersection(conjunct.getInputSlotExprIds(), agg.getOutputExprIdSet())
                        .isEmpty(), Collectors.toSet()));
        Set<Expression> correlatedConjuncts = conjuncts.get(false);
        if (correlatedConjuncts.size() != 1
                || !(correlatedConjuncts.iterator().next() instanceof ComparisonPredicate)) {
            //TODO: only support simple comparison predicate now
            return filter;
        }
        Expression windowFilterConjunct = correlatedConjuncts.iterator().next();
        windowFilterConjunct = PlanUtils.maybeCommuteComparisonPredicate(
                (ComparisonPredicate) windowFilterConjunct, apply.left());

        AggregateFunction function = functions.get(0);
        if (function instanceof NullableAggregateFunction) {
            // adjust agg function's nullable.
            function = ((NullableAggregateFunction) function).withAlwaysNullable(false);
        }

        WindowExpression windowFunction = createWindowFunction(apply.getCorrelationSlot(),
                (AggregateFunction) ExpressionUtils.replace(function, innerOuterSlotMap));
        NamedExpression windowFunctionAlias = new Alias(windowFunction);

        // build filter conjunct, get the alias of the agg output and extract its child.
        // then replace the agg to window function, then build conjunct
        // we ensure aggOut is Alias.
        NamedExpression aggOut = agg.getOutputExpressions().get(0);
        Expression aggOutExpr = aggOut.child(0);
        // change the agg function to window function alias.
        aggOutExpr = ExpressionUtils.replace(aggOutExpr, ImmutableMap
                .of(functions.get(0), windowFunctionAlias.toSlot()));

        windowFilterConjunct = ExpressionUtils.replace(windowFilterConjunct,
                ImmutableMap.of(aggOut.toSlot(), aggOutExpr));

        // Shared tables are those whose table id appears in both the outer and the inner plans.
        // ALL output slots of the outer instances of those tables count as shared slots
        // (not just the columns referenced by the inner query).
        Set<Long> innerTableIds = innerPlans.stream()
                .filter(CatalogRelation.class::isInstance)
                .map(CatalogRelation.class::cast)
                .map(r -> r.getTable().getId())
                .collect(Collectors.toSet());
        Set<ExprId> sharedOuterExprIds = outerPlans.stream()
                .filter(CatalogRelation.class::isInstance)
                .map(CatalogRelation.class::cast)
                .filter(r -> innerTableIds.contains(r.getTable().getId()))
                .flatMap(r -> r.getOutput().stream())
                .map(Slot::getExprId)
                .collect(Collectors.toSet());

        // Conjuncts of the inner filter expressed on outer slots: these are the predicates the
        // scalar subquery applies before it aggregates.
        List<Expression> innerConjuncts = innerPlans.stream()
                .filter(LogicalFilter.class::isInstance)
                .flatMap(p -> ((LogicalFilter<Plan>) p).getConjuncts().stream())
                .map(e -> ExpressionUtils.replace(e, innerOuterSlotMap))
                .collect(Collectors.toList());

        // For example, with fact(f) as shared table and dim(d) as outer-only:
        //   f.k = d.k   → also in the inner filter        → below the window
        //   d.tag > 0   → no shared column                → below the window
        //   f.v > 6     → shared column, not in the inner → above the window
        Set<Expression> belowWindowConjuncts = Sets.newHashSet();
        Set<Expression> aboveWindowConjuncts = Sets.newHashSet();
        for (Expression conj : conjuncts.get(true)) {
            boolean referencesShared = conj.getInputSlotExprIds().stream()
                    .anyMatch(sharedOuterExprIds::contains);
            boolean appliedBySubquery = innerConjuncts.stream()
                    .anyMatch(inner -> ExpressionIdenticalChecker.INSTANCE.check(inner, conj));
            if (referencesShared && !appliedBySubquery) {
                // would remove rows the subquery aggregates over → must stay above
                aboveWindowConjuncts.add(conj);
            } else {
                belowWindowConjuncts.add(conj);
            }
        }

        LogicalFilter<Plan> newFilter = filter.withConjunctsAndChild(belowWindowConjuncts, apply.left());
        LogicalWindow<Plan> newWindow = new LogicalWindow<>(ImmutableList.of(windowFunctionAlias), newFilter);

        // Combine the predicates kept above with the window comparison predicate
        Set<Expression> topConjuncts = Sets.newHashSet(windowFilterConjunct);
        topConjuncts.addAll(aboveWindowConjuncts);
        return new LogicalFilter<>(ImmutableSet.copyOf(topConjuncts), newWindow);
    }

    /**
     * Wraps {@code function} into a window expression that is partitioned by all the
     * correlated slots and has no order keys.
     */
    private WindowExpression createWindowFunction(List<Slot> correlatedSlots, AggregateFunction function) {
        List<Expression> partitionKeys = ImmutableList.copyOf(correlatedSlots);
        return new WindowExpression(function, partitionKeys, Collections.emptyList());
    }

    /**
     * Structural comparison of two expressions. The visited expression is compared against the
     * expression passed as visitor context: slot references and literals by {@code equals},
     * comparison predicates by {@code equals} on the predicate or on its commuted form, and any
     * other expression by runtime class plus pairwise comparison of the children.
     */
    private static class ExpressionIdenticalChecker extends DefaultExpressionVisitor<Boolean, Expression> {
        public static final ExpressionIdenticalChecker INSTANCE = new ExpressionIdenticalChecker();

        /** Returns true when {@code expression} is considered identical to {@code expression1}. */
        public boolean check(Expression expression, Expression expression1) {
            return expression.accept(this, expression1);
        }

        private boolean hasSameClass(Object first, Object second) {
            return first.getClass().equals(second.getClass());
        }

        private boolean childrenIdentical(Expression expression, Expression expression1) {
            List<Expression> leftChildren = expression.children();
            List<Expression> rightChildren = expression1.children();
            if (leftChildren.size() != rightChildren.size()) {
                return false;
            }
            // walk both child lists in lock step
            Iterator<Expression> rightCursor = rightChildren.iterator();
            for (Expression leftChild : leftChildren) {
                if (!leftChild.accept(this, rightCursor.next())) {
                    return false;
                }
            }
            return true;
        }

        @Override
        public Boolean visit(Expression expression, Expression expression1) {
            if (!hasSameClass(expression, expression1)) {
                return false;
            }
            return childrenIdentical(expression, expression1);
        }

        @Override
        public Boolean visitSlotReference(SlotReference slotReference, Expression other) {
            return slotReference.equals(other);
        }

        @Override
        public Boolean visitLiteral(Literal literal, Expression other) {
            return literal.equals(other);
        }

        @Override
        public Boolean visitComparisonPredicate(ComparisonPredicate cp, Expression other) {
            if (cp.equals(other)) {
                return true;
            }
            return cp.commute().equals(other);
        }
    }
}