AggScalarSubQueryToWindowFunction.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.window.SupportWindowAnalytic;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Rewrites the plan
 * logicalFilter(logicalApply(any(), logicalAggregate()))
 * into
 * logicalFilter(logicalWindow(logicalFilter(any())))
 * (a logicalProject sitting between the outer filter and the apply is skipped when matching).
 * <p>
 * Reference paper: WinMagic - Subquery Elimination Using Window Aggregation
 * <p>
 * TODO: use materialized view pattern match to do outer and inner tree match.
 */

public class AggScalarSubQueryToWindowFunction extends DefaultPlanRewriter<JobContext> implements CustomRewriter {

    // Plan node types that may appear in the outer (left) child of the Apply.
    // checkPlanType() rejects the rewrite if any collected outer node is not an instance of one of these.
    private static final Set<Class<? extends LogicalPlan>> OUTER_SUPPORTED_PLAN = ImmutableSet.of(
            LogicalJoin.class,
            LogicalProject.class,
            LogicalRelation.class
    );

    // Plan node types that may appear in the inner (right) child of the Apply, i.e. the subquery.
    // Compared to the outer set it additionally allows aggregate and filter nodes.
    private static final Set<Class<? extends LogicalPlan>> INNER_SUPPORTED_PLAN = ImmutableSet.of(
            LogicalAggregate.class,
            LogicalFilter.class,
            LogicalJoin.class,
            LogicalProject.class,
            LogicalRelation.class
    );

    // Working state filled by check(...) and its helpers and read again by rewrite(...).
    // outerPlans / innerPlans: LogicalPlan nodes collected from the Apply's child(0) / child(1).
    private final List<LogicalPlan> outerPlans = Lists.newArrayList();
    private final List<LogicalPlan> innerPlans = Lists.newArrayList();
    // aggregate functions collected from the inner aggregate's output expressions (see checkAggregate()).
    private final List<AggregateFunction> functions = Lists.newArrayList();
    // inner slot -> outer slot of the same table and column name (see createSlotMapping(...)).
    private final Map<Expression, Expression> innerOuterSlotMap = Maps.newHashMap();

    /**
     * Entry point of the rule: walks the whole plan with this rewriter.
     * Only {@link #visitLogicalFilter} is overridden, because the rewrite is driven by the
     * outer filter that sits on top of the Apply node.
     *
     * @param plan root of the plan to rewrite
     * @param context job context handed through to the visitor
     * @return the plan returned by visiting {@code plan} with this rewriter
     */
    @Override
    public Plan rewriteRoot(Plan plan, JobContext context) {
        Plan rewritten = plan.accept(this, context);
        return rewritten;
    }

    /**
     * Rewrites children first, then tries to turn {@code filter -> [project ->] apply} into a window plan.
     * A project between the filter and the Apply (produced by
     * {@link org.apache.doris.nereids.rules.analysis.SubqueryToApply} to restore the original output)
     * does not matter for this rule and is looked through.
     *
     * @return the rewritten plan, or the (children-rewritten) filter when no Apply is found
     *         or the pattern fails {@code check}
     */
    @Override
    public Plan visitLogicalFilter(LogicalFilter<? extends Plan> plan, JobContext context) {
        LogicalFilter<? extends Plan> rewrittenFilter = visitChildren(this, plan, context);

        Optional<LogicalApply<Plan, Plan>> applyOpt = findApply(rewrittenFilter);
        if (!applyOpt.isPresent()) {
            return rewrittenFilter;
        }

        LogicalApply<Plan, Plan> apply = applyOpt.get();
        if (!check(rewrittenFilter, apply)) {
            return rewrittenFilter;
        }
        return rewrite(rewrittenFilter, apply);
    }

    /**
     * Locate the Apply node directly below the filter, looking through at most one project.
     *
     * @return the Apply if the filter's child (or that child's first child, when the child is a
     *         project) is a {@link LogicalApply}; otherwise an empty Optional
     */
    private Optional<LogicalApply<Plan, Plan>> findApply(LogicalFilter<? extends Plan> filter) {
        Plan candidate = filter.child();
        if (candidate instanceof LogicalProject) {
            // skip the project and inspect what is underneath it
            candidate = candidate.child(0);
        }
        if (candidate instanceof LogicalApply) {
            return Optional.of((LogicalApply<Plan, Plan>) candidate);
        }
        return Optional.empty();
    }

    /**
     * Decide whether the given outer filter / Apply pair matches the shape this rule can rewrite.
     * As a side effect it (re)populates {@code outerPlans}, {@code innerPlans}, {@code functions}
     * and {@code innerOuterSlotMap}, which {@code rewrite} reads afterwards.
     *
     * @param outerFilter the filter sitting above the Apply
     * @param apply the Apply node whose right child is the subquery
     * @return true only if every individual check passes
     */
    private boolean check(LogicalFilter<? extends Plan> outerFilter, LogicalApply<Plan, Plan> apply) {
        // This rewriter instance visits every filter in the plan, so the working state must be
        // reset for each candidate; otherwise nodes, functions and slot mappings collected for a
        // previous Apply would leak into the checks (and the rewrite) of the current one.
        outerPlans.clear();
        innerPlans.clear();
        functions.clear();
        innerOuterSlotMap.clear();

        outerPlans.addAll(apply.child(0).collect(LogicalPlan.class::isInstance));
        innerPlans.addAll(apply.child(1).collect(LogicalPlan.class::isInstance));

        // order matters: checkRelation builds innerOuterSlotMap, which checkFilter consumes.
        return checkPlanType()
                && checkApply(apply)
                && checkAggregate()
                && checkJoin()
                && checkProject()
                && checkRelation(apply.getCorrelationSlot())
                && checkFilter(outerFilter);
    }

    /**
     * Every collected outer node must be an instance of some class in OUTER_SUPPORTED_PLAN and
     * every collected inner node an instance of some class in INNER_SUPPORTED_PLAN, because the
     * rewrite changes how the query is processed.
     */
    private boolean checkPlanType() {
        boolean outerAllSupported = outerPlans.stream()
                .noneMatch(node -> OUTER_SUPPORTED_PLAN.stream().noneMatch(type -> type.isInstance(node)));
        if (!outerAllSupported) {
            return false;
        }
        return innerPlans.stream()
                .noneMatch(node -> INNER_SUPPORTED_PLAN.stream().noneMatch(type -> type.isInstance(node)));
    }

    /**
     * Apply should be
     *   1. scalar
     *   2. not a mark join
     *   3. topped on its inner (right) side by a {@link LogicalAggregate}
     *   4. correlated
     * The shape of the correlated conjunct is not inspected here.
     */
    private boolean checkApply(LogicalApply<Plan, Plan> apply) {
        if (!apply.isScalar()) {
            return false;
        }
        if (apply.isMarkJoin()) {
            return false;
        }
        if (!(apply.right() instanceof LogicalAggregate)) {
            return false;
        }
        return apply.isCorrelated();
    }

    /**
     * The inner scope must contain exactly one Aggregate, and that Aggregate's output must contain
     * exactly one AggregateFunction, which has to support window analytic and must not be distinct.
     * The function found is appended to {@code functions}.
     */
    private boolean checkAggregate() {
        LogicalAggregate<Plan> onlyAggregate = null;
        for (LogicalPlan node : innerPlans) {
            if (!(node instanceof LogicalAggregate)) {
                continue;
            }
            if (onlyAggregate != null) {
                // a second aggregate: window functions don't support nesting.
                return false;
            }
            onlyAggregate = (LogicalAggregate<Plan>) node;
        }
        if (onlyAggregate == null) {
            return false;
        }

        functions.addAll(ExpressionUtils.collectAll(
                onlyAggregate.getOutputExpressions(), AggregateFunction.class::isInstance));
        if (functions.size() != 1) {
            return false;
        }
        AggregateFunction onlyFunction = functions.get(0);
        return onlyFunction instanceof SupportWindowAnalytic && !onlyFunction.isDistinct();
    }

    /**
     * Check that the inner scope has exactly one filter and that every conjunct of it, once its
     * slots are mapped to the outer scope through {@code innerOuterSlotMap}, has an identical
     * counterpart among the outer filter's conjuncts (inner predicates are a subset of outer ones).
     *
     * @param outerFilter the filter above the Apply
     * @return true if each inner conjunct could be paired with a distinct outer conjunct
     */
    private boolean checkFilter(LogicalFilter<? extends Plan> outerFilter) {
        List<LogicalFilter<Plan>> innerFilters = innerPlans.stream()
                .filter(LogicalFilter.class::isInstance)
                .map(p -> (LogicalFilter<Plan>) p).collect(Collectors.toList());
        if (innerFilters.size() != 1) {
            return false;
        }
        // work on private, explicitly mutable copies: matched entries are removed from both sets.
        Set<Expression> outerConjunctSet = Sets.newHashSet(outerFilter.getConjuncts());
        Set<Expression> innerConjunctSet = Sets.newHashSet();
        for (Expression innerConjunct : innerFilters.get(0).getConjuncts()) {
            innerConjunctSet.add(ExpressionUtils.replace(innerConjunct, innerOuterSlotMap));
        }

        Iterator<Expression> innerIterator = innerConjunctSet.iterator();
        while (innerIterator.hasNext()) {
            Expression innerExpr = innerIterator.next();
            Iterator<Expression> outerIterator = outerConjunctSet.iterator();
            while (outerIterator.hasNext()) {
                Expression outerExpr = outerIterator.next();
                if (ExpressionIdenticalChecker.INSTANCE.check(innerExpr, outerExpr)) {
                    innerIterator.remove();
                    outerIterator.remove();
                    // stop after the first match: a second innerIterator.remove() for the same
                    // element would throw IllegalStateException (e.g. outer has both a = b and b = a).
                    break;
                }
            }
        }
        // anything left in the inner set has no counterpart in the outer filter.
        return innerConjunctSet.isEmpty();
    }

    /**
     * Reject any join that carries an on-clause condition, in the outer as well as the inner scope.
     * Accurate pattern matching between the two scopes is not possible yet, so joins with
     * conditions are forbidden for now.
     */
    private boolean checkJoin() {
        for (List<LogicalPlan> scope : ImmutableList.of(outerPlans, innerPlans)) {
            for (LogicalPlan node : scope) {
                if (node instanceof LogicalJoin
                        && ((LogicalJoin<Plan, Plan>) node).getOnClauseCondition().isPresent()) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Projects in both scopes may only do column pruning: every expression of every project
     * has to be a plain {@link SlotReference}.
     */
    private boolean checkProject() {
        for (List<LogicalPlan> scope : ImmutableList.of(outerPlans, innerPlans)) {
            for (LogicalPlan node : scope) {
                if (!(node instanceof LogicalProject)) {
                    continue;
                }
                for (Expression projected : ((LogicalProject<Plan>) node).getExpressions()) {
                    if (!(projected instanceof SlotReference)) {
                        return false;
                    }
                }
            }
        }
        return true;
    }

    /**
     * Compare the tables of the outer and the inner scope:
     * 1. neither scope may reference the same table id twice
     * 2. the outer scope must have exactly one table more than the inner scope
     * 3. after removing the inner table ids from the outer ones exactly one table id must remain
     * 4. every named expression in the correlated slots must be output by that remaining table
     * Before step 4 the inner-to-outer slot mapping is built via {@code createSlotMapping}.
     *
     * @param correlatedSlots correlation slots of the Apply
     */
    private boolean checkRelation(List<Expression> correlatedSlots) {
        List<CatalogRelation> outerTables = Lists.newArrayList();
        for (LogicalPlan node : outerPlans) {
            if (node instanceof CatalogRelation) {
                outerTables.add((CatalogRelation) node);
            }
        }
        List<CatalogRelation> innerTables = Lists.newArrayList();
        for (LogicalPlan node : innerPlans) {
            if (node instanceof CatalogRelation) {
                innerTables.add((CatalogRelation) node);
            }
        }

        List<Long> outerIds = Lists.newArrayList();
        for (CatalogRelation table : outerTables) {
            outerIds.add(table.getTable().getId());
        }
        List<Long> innerIds = Lists.newArrayList();
        for (CatalogRelation table : innerTables) {
            innerIds.add(table.getTable().getId());
        }

        boolean outerHasDuplicate = Sets.newHashSet(outerIds).size() != outerIds.size();
        boolean innerHasDuplicate = Sets.newHashSet(innerIds).size() != innerIds.size();
        if (outerHasDuplicate || innerHasDuplicate) {
            return false;
        }
        if (outerIds.size() - innerIds.size() != 1) {
            return false;
        }
        for (Long innerId : innerIds) {
            // remove by value (List.remove(Object)), not by index
            outerIds.remove((Object) innerId);
        }
        if (outerIds.size() != 1) {
            return false;
        }

        createSlotMapping(outerTables, innerTables);

        // output expr ids of the single outer table that has no inner counterpart
        Set<ExprId> correlatedRelationOutput = Sets.newHashSet();
        for (CatalogRelation table : outerTables) {
            if (outerIds.contains(table.getTable().getId())) {
                correlatedRelationOutput.addAll(((LogicalRelation) table).getOutputExprIdSet());
            }
        }
        return ExpressionUtils.collect(correlatedSlots, NamedExpression.class::isInstance).stream()
                .map(NamedExpression.class::cast)
                .allMatch(named -> correlatedRelationOutput.contains(named.getExprId()));
    }

    /**
     * Fill {@code innerOuterSlotMap}: for each outer table, take the first inner table with the
     * same table id and map each of its output slots to the first outer output slot with the
     * same name. Inner slots without a same-named outer slot are left unmapped.
     */
    private void createSlotMapping(List<CatalogRelation> outerTables, List<CatalogRelation> innerTables) {
        for (CatalogRelation outerTable : outerTables) {
            for (CatalogRelation innerTable : innerTables) {
                if (innerTable.getTable().getId() != outerTable.getTable().getId()) {
                    continue;
                }
                for (Slot innerSlot : innerTable.getOutput()) {
                    outerTable.getOutput().stream()
                            .filter(outerSlot -> innerSlot.getName().equals(outerSlot.getName()))
                            .findFirst()
                            .ifPresent(outerSlot -> innerOuterSlotMap.put(innerSlot, outerSlot));
                }
                // only the first inner table with a matching id is considered
                break;
            }
        }
    }

    /**
     * Replace {@code filter -> [project ->] apply(outer, aggregate)} by
     * {@code filter(window(filter(outer)))}.
     * Must be called right after a successful {@code check(filter, apply)}, because it reads
     * {@code functions} and {@code innerOuterSlotMap} populated there.
     *
     * @param filter the outer filter above the Apply
     * @param apply the Apply; its right child must be a {@link LogicalAggregate}
     * @return the rewritten plan, or {@code filter} unchanged when the shape is not supported
     *         (not exactly one conjunct referencing the aggregate output, that conjunct is not a
     *         {@link ComparisonPredicate}, or the aggregate output is not an {@link Alias})
     * @throws IllegalArgumentException if the right child of {@code apply} is not a LogicalAggregate
     */
    private Plan rewrite(LogicalFilter<? extends Plan> filter, LogicalApply<Plan, Plan> apply) {
        Preconditions.checkArgument(apply.right() instanceof LogicalAggregate,
                "right child of Apply should be LogicalAggregate");
        LogicalAggregate<Plan> agg = (LogicalAggregate<Plan>) apply.right();

        // transform algorithm
        // first: find the slot in outer scope corresponding to the slot in aggregate function in inner scope.
        // second: find the aggregation function in inner scope, and replace it to window function, and the aggregate
        // slot is the slot in outer scope in the first step.
        // third: the expression containing aggregation function in inner scope will be the child of an alias,
        // so in the predicate between outer and inner, we change the alias to expression which is the alias's child,
        // and change the aggregation function to the alias of window function.

        // for example, in tpc-h Q17
        // window filter conjuncts is
        // cast(l_quantity#id1 as decimal(27, 9)) < `0.2 * avg(l_quantity)`#id2
        // and
        // 0.2 * avg(l_quantity#id3) as `0.2 * l_quantity`#id2
        // is aggregate's output expression
        // we change it to
        // cast(l_quantity#id1 as decimal(27, 9)) < 0.2 * `avg(l_quantity#id1) over(window)`#id4
        // and
        // avg(l_quantity#id1) over(window) as `avg(l_quantity#id1) over(window)`#id4

        // it's a simple case, but we may meet some complex cases in ut.
        // TODO: support compound predicate and multi apply node.

        // partitioningBy (unlike groupingBy) always yields both keys, so neither lookup below is null:
        // true  -> conjuncts that do not touch the aggregate output (stay below the window)
        // false -> conjuncts that reference the aggregate output (become the window filter)
        Map<Boolean, Set<Expression>> conjuncts = filter.getConjuncts().stream()
                .collect(Collectors.partitioningBy(conjunct -> Sets
                        .intersection(conjunct.getInputSlotExprIds(), agg.getOutputExprIdSet())
                        .isEmpty(), Collectors.toSet()));
        Set<Expression> correlatedConjuncts = conjuncts.get(false);
        if (correlatedConjuncts.isEmpty() || correlatedConjuncts.size() > 1
                || !(correlatedConjuncts.iterator().next() instanceof ComparisonPredicate)) {
            //TODO: only support simple comparison predicate now
            return filter;
        }

        // the aggregate output is unwrapped via child(0) below, which is only valid for an Alias.
        NamedExpression aggOut = agg.getOutputExpressions().get(0);
        if (!(aggOut instanceof Alias)) {
            return filter;
        }

        Expression windowFilterConjunct = correlatedConjuncts.iterator().next();
        windowFilterConjunct = PlanUtils.maybeCommuteComparisonPredicate(
                (ComparisonPredicate) windowFilterConjunct, apply.left());

        AggregateFunction function = functions.get(0);
        if (function instanceof NullableAggregateFunction) {
            // adjust agg function's nullable.
            function = ((NullableAggregateFunction) function).withAlwaysNullable(false);
        }

        // the window function works on outer slots, so map the inner slots of the aggregate function.
        WindowExpression windowFunction = createWindowFunction(apply.getCorrelationSlot(),
                (AggregateFunction) ExpressionUtils.replace(function, innerOuterSlotMap));
        NamedExpression windowFunctionAlias = new Alias(windowFunction);

        // build filter conjunct: take the expression under the agg output alias
        // and replace the agg function in it by the slot of the window function alias.
        Expression aggOutExpr = aggOut.child(0);
        aggOutExpr = ExpressionUtils.replace(aggOutExpr, ImmutableMap
                .of(functions.get(0), windowFunctionAlias.toSlot()));

        // then substitute that expression for the agg output slot in the correlated conjunct.
        windowFilterConjunct = ExpressionUtils.replace(windowFilterConjunct,
                ImmutableMap.of(aggOut.toSlot(), aggOutExpr));

        LogicalFilter<Plan> newFilter = filter.withConjunctsAndChild(conjuncts.get(true), apply.left());
        LogicalWindow<Plan> newWindow = new LogicalWindow<>(ImmutableList.of(windowFunctionAlias), newFilter);
        return new LogicalFilter<>(ImmutableSet.of(windowFilterConjunct), newWindow);
    }

    /**
     * Build the window expression for {@code function}: the partition-by list is exactly the
     * correlated slots and the order-by list is empty.
     *
     * @throws IllegalArgumentException if any element of {@code correlatedSlots} is not a {@link Slot}
     */
    private WindowExpression createWindowFunction(List<Expression> correlatedSlots, AggregateFunction function) {
        boolean onlySlots = correlatedSlots.stream().noneMatch(e -> !(e instanceof Slot));
        Preconditions.checkArgument(onlySlots);
        List<Expression> partitionKeys = correlatedSlots;
        return new WindowExpression(function, partitionKeys, Collections.emptyList());
    }

    /**
     * Structural comparison of two expressions. The visited expression is compared against the
     * expression passed as visitor context:
     * slots and literals by {@code equals}, comparison predicates by {@code equals} on either the
     * predicate itself or its commuted form, everything else by same class plus pairwise
     * comparison of children.
     */
    private static class ExpressionIdenticalChecker extends DefaultExpressionVisitor<Boolean, Expression> {
        public static final ExpressionIdenticalChecker INSTANCE = new ExpressionIdenticalChecker();

        /** @return true if {@code expression} is considered identical to {@code expression1} */
        public boolean check(Expression expression, Expression expression1) {
            return expression.accept(this, expression1);
        }

        private boolean isClassMatch(Object o1, Object o2) {
            return o1.getClass() == o2.getClass();
        }

        // children must agree in number and be identical position by position
        private boolean isSameChild(Expression expression, Expression expression1) {
            List<Expression> leftChildren = expression.children();
            List<Expression> rightChildren = expression1.children();
            if (leftChildren.size() != rightChildren.size()) {
                return false;
            }
            Iterator<Expression> rightIterator = rightChildren.iterator();
            for (Expression leftChild : leftChildren) {
                if (!leftChild.accept(this, rightIterator.next())) {
                    return false;
                }
            }
            return true;
        }

        @Override
        public Boolean visit(Expression expression, Expression expression1) {
            if (!isClassMatch(expression, expression1)) {
                return false;
            }
            return isSameChild(expression, expression1);
        }

        @Override
        public Boolean visitSlotReference(SlotReference slotReference, Expression other) {
            return slotReference.equals(other);
        }

        @Override
        public Boolean visitLiteral(Literal literal, Expression other) {
            return literal.equals(other);
        }

        @Override
        public Boolean visitComparisonPredicate(ComparisonPredicate cp, Expression other) {
            if (cp.equals(other)) {
                return true;
            }
            // also accept the flipped form of the predicate
            return cp.commute().equals(other);
        }
    }
}