// PushDownTopNDistinctThroughJoin.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.doris.nereids.util.PlanUtils;
import org.apache.doris.qe.ConnectContext;

import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Push a TopN that sits on top of a DISTINCT aggregate down through a join.
 *
 * <p>Two plan shapes are matched:
 * <pre>
 *   topN -> distinct -> join
 *   topN -> distinct -> project(all slots) -> join
 * </pre>
 * For a LEFT OUTER JOIN the TopN + distinct is copied onto the left child, for a RIGHT OUTER JOIN onto the
 * right child, and for a CROSS JOIN onto whichever children qualify. The original TopN and distinct are kept
 * above the join; the pushed-down copy uses {@code limit + offset} as its limit and {@code 0} as its offset.
 * Other join types are left untouched.
 *
 * <p>Both rules only fire when {@code limit + offset} does not exceed the session variable
 * {@code topnOptLimitThreshold} and every order key is a plain {@link Slot}.
 */
public class PushDownTopNDistinctThroughJoin implements RewriteRuleFactory {

    @Override
    public List<Rule> buildRules() {
        return ImmutableList.of(
                // topN -> distinct -> join
                logicalTopN(logicalAggregate(logicalJoin()).when(LogicalAggregate::isDistinct))
                        // TODO: complex order by
                        .when(PushDownTopNDistinctThroughJoin::isWithinTopnOptThreshold)
                        .when(PushDownTopNDistinctThroughJoin::isOrderedBySlotsOnly)
                        .then(topN -> {
                            LogicalAggregate<LogicalJoin<Plan, Plan>> distinct = topN.child();
                            LogicalJoin<Plan, Plan> join = distinct.child();
                            Plan newJoin = pushTopNThroughJoin(topN, join);
                            // returning null tells the framework that nothing was rewritten
                            if (newJoin == null || join.children().equals(newJoin.children())) {
                                return null;
                            }
                            return topN.withChildren(distinct.withChildren(newJoin));
                        })
                        .toRule(RuleType.PUSH_DOWN_TOP_N_DISTINCT_THROUGH_JOIN),

                // topN -> distinct -> project -> join
                logicalTopN(logicalAggregate(logicalProject(logicalJoin()).when(LogicalProject::isAllSlots))
                        .when(LogicalAggregate::isDistinct))
                        // same guards as the rule above, so both variants respect the session threshold
                        .when(PushDownTopNDistinctThroughJoin::isWithinTopnOptThreshold)
                        .when(PushDownTopNDistinctThroughJoin::isOrderedBySlotsOnly)
                        .then(topN -> {
                            LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>> distinct = topN.child();
                            LogicalProject<LogicalJoin<Plan, Plan>> project = distinct.child();
                            LogicalJoin<Plan, Plan> join = project.child();

                            // If the slots used by the order-by exprs are not all produced by the join
                            // below the project, we can't push down.
                            // TODO: in the future, we also can push down it.
                            Set<Slot> outputSet = project.child().getOutputSet();
                            if (!topN.getOrderKeys().stream().map(OrderKey::getExpr)
                                    .flatMap(e -> e.getInputSlots().stream())
                                    .allMatch(outputSet::contains)) {
                                return null;
                            }

                            Plan newJoin = pushTopNThroughJoin(topN, join);
                            if (newJoin == null || join.children().equals(newJoin.children())) {
                                return null;
                            }
                            // rebuild in the matched order: topN -> distinct -> project -> join
                            return topN.withChildren(distinct.withChildren(project.withChildren(newJoin)));
                        }).toRule(RuleType.PUSH_DOWN_TOP_N_DISTINCT_THROUGH_PROJECT_JOIN)
        );
    }

    /**
     * Whether {@code limit + offset} of the TopN is within the session's {@code topnOptLimitThreshold}.
     * Returns false when there is no {@link ConnectContext} bound to the current thread.
     */
    private static boolean isWithinTopnOptThreshold(LogicalTopN<? extends Plan> topN) {
        ConnectContext context = ConnectContext.get();
        return context != null
                && context.getSessionVariable().topnOptLimitThreshold >= topN.getLimit() + topN.getOffset();
    }

    /**
     * Whether every order key of the TopN is a plain {@link Slot} (no complex order-by expression).
     */
    private static boolean isOrderedBySlotsOnly(LogicalTopN<? extends Plan> topN) {
        return topN.getOrderKeys().stream().map(OrderKey::getExpr).allMatch(Slot.class::isInstance);
    }

    /**
     * Try to push the TopN + distinct below the join.
     *
     * @param topN the TopN whose child is the distinct aggregate
     * @param join the join found below the distinct (possibly through a project)
     * @return a new join with the TopN + distinct placed on the qualifying child(ren),
     *         or null when the join type is not supported or no child qualifies
     */
    private Plan pushTopNThroughJoin(LogicalTopN<? extends Plan> topN, LogicalJoin<Plan, Plan> join) {
        Set<Slot> groupBySlots = ((LogicalAggregate<?>) topN.child()).getGroupByExpressions().stream()
                .flatMap(e -> e.getInputSlots().stream()).collect(Collectors.toSet());
        switch (join.getJoinType()) {
            case LEFT_OUTER_JOIN: {
                Plan left = pushTopNDistinctIntoChild(topN, groupBySlots, join.left());
                return left == join.left() ? null : join.withChildren(left, join.right());
            }
            case RIGHT_OUTER_JOIN: {
                Plan right = pushTopNDistinctIntoChild(topN, groupBySlots, join.right());
                return right == join.right() ? null : join.withChildren(join.left(), right);
            }
            case CROSS_JOIN: {
                Plan leftChild = pushTopNDistinctIntoChild(topN, groupBySlots, join.left());
                Plan rightChild = pushTopNDistinctIntoChild(topN, groupBySlots, join.right());
                if (leftChild == join.left() && rightChild == join.right()) {
                    return null;
                } else {
                    return join.withChildren(leftChild, rightChild);
                }
            }
            default:
                // don't push limit.
                return null;
        }
    }

    /**
     * Wrap {@code child} with distinct + TopN when some order keys can be pushed to it.
     *
     * @return the wrapped child, or the very same {@code child} instance when nothing can be pushed
     *         (callers rely on reference equality to detect the "unchanged" case)
     */
    private Plan pushTopNDistinctIntoChild(LogicalTopN<? extends Plan> topN, Set<Slot> groupBySlots, Plan child) {
        List<OrderKey> pushedOrderKeys = getPushedOrderKeys(groupBySlots, child.getOutputSet(),
                topN.getOrderKeys());
        if (pushedOrderKeys.isEmpty()) {
            return child;
        }
        // the pushed TopN must keep the first (limit + offset) rows; the offset is applied by the upper TopN
        return topN.withLimitOrderKeyAndChild(
                topN.getLimit() + topN.getOffset(), 0, pushedOrderKeys,
                PlanUtils.distinct(child));
    }

    /**
     * return pushed order-keys. If top-n distinct cannot be pushed, return empty list.
     *
     * @param groupBySlots input slots of the distinct's group-by expressions
     * @param joinChildSlot output slots of the join child that is the push-down target
     * @param orderKeys order keys of the original TopN
     * @return the leading order keys that belong to the join child; empty when the child's output is not
     *         fully covered by the distinct columns, when no order key belongs to the child, or when a
     *         key of the child appears after a key that does not belong to it
     */
    private List<OrderKey> getPushedOrderKeys(Set<Slot> groupBySlots, Set<Slot> joinChildSlot,
            List<OrderKey> orderKeys) {
        // NOTICE: Currently, we have implemented strict restrictions to ensure that the distinct columns are
        //   a superset of the output from the corresponding child of the join operator. In the future, we can relax
        //   this restriction and only require that there is overlap between the output of the corresponding child of
        //   the join operator and the distinct columns.
        //   However, this would require changes to the optimized plan, converting the pushed-down aggregation distinct
        //   to the window function "row number". Partition by distinct columns, and a filtering condition of
        //   "row number = 1" would be added.
        if (!groupBySlots.containsAll(joinChildSlot)) {
            return ImmutableList.of();
        }
        // we must check the order of order keys. A slot that is not from this join child must not appear before
        // a slot from this join child, otherwise we would get a wrong result if we push top-n under the join.
        ImmutableList.Builder<OrderKey> pushedOrderKeys = ImmutableList.builder();
        boolean notFound = false;
        for (OrderKey orderKey : orderKeys) {
            if (joinChildSlot.contains(orderKey.getExpr())) {
                if (notFound) {
                    return ImmutableList.of();
                } else {
                    pushedOrderKeys.add(orderKey);
                }
            } else {
                notFound = true;
            }
        }
        return pushedOrderKeys.build();
    }
}