PushDownProjectThroughInnerOuterJoin.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.exploration.join;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.CBOUtils;
import org.apache.doris.nereids.rules.exploration.ExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Rule for pushdown project through inner/outer join
 * Just push down project inside join to avoid to push the top of Join-Cluster.
 * <pre>
 *    Project                   Join
 *      |            ──►       /    \
 *     Join               Project  Project
 *    /   \                  |       |
 *   A     B                 A       B
 * </pre>
 */
public class PushDownProjectThroughInnerOuterJoin implements ExplorationRuleFactory {
    public static final PushDownProjectThroughInnerOuterJoin INSTANCE = new PushDownProjectThroughInnerOuterJoin();

    @Override
    public List<Rule> buildRules() {
        return ImmutableList.of(
                logicalJoin(logicalProject(logicalJoin()), group())
                        .when(j -> j.left().child().getJoinType().isOuterJoin()
                                || j.left().child().getJoinType().isInnerJoin())
                        // Just pushdown project with non-column expr like (t.id + 1)
                        .whenNot(j -> j.left().isAllSlots())
                        .whenNot(j -> j.left().child().hasDistributeHint())
                        .then(topJoin -> {
                            LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project = topJoin.left();
                            Plan newLeft = pushdownProject(project);
                            if (newLeft == null) {
                                return null;
                            }
                            return topJoin.withChildren(newLeft, topJoin.right());
                        }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_INNER_OUTER_JOIN_LEFT),
                logicalJoin(group(), logicalProject(logicalJoin()))
                        .when(j -> j.right().child().getJoinType().isOuterJoin()
                                || j.right().child().getJoinType().isInnerJoin())
                        // Just pushdown project with non-column expr like (t.id + 1)
                        .whenNot(j -> j.right().isAllSlots())
                        .whenNot(j -> j.right().child().hasDistributeHint())
                        .then(topJoin -> {
                            LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project = topJoin.right();
                            Plan newRight = pushdownProject(project);
                            if (newRight == null) {
                                return null;
                            }
                            return topJoin.withChildren(topJoin.left(), newRight);
                        }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_INNER_OUTER_JOIN_RIGHT)
        );
    }

    private Plan pushdownProject(LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project) {
        LogicalJoin<GroupPlan, GroupPlan> join = project.child();
        Set<ExprId> aOutputExprIdSet = join.left().getOutputExprIdSet();
        Set<ExprId> bOutputExprIdSet = join.right().getOutputExprIdSet();

        // reject hyper edge in Project.
        if (!project.getProjects().stream().allMatch(expr -> {
            Set<ExprId> inputSlotExprIds = expr.getInputSlotExprIds();
            return aOutputExprIdSet.containsAll(inputSlotExprIds)
                    || bOutputExprIdSet.containsAll(inputSlotExprIds);
        })) {
            return null;
        }

        List<NamedExpression> aProjects = new ArrayList<>();
        List<NamedExpression> bProjects = new ArrayList<>();
        List<NamedExpression> projects;
        if (join.getJoinType().isInnerJoin()) {
            projects = project.getProjects();
        } else {
            Map<Slot, Slot> childrenSlots = new HashMap<>();
            join.left().getOutputSet().forEach(slot -> childrenSlots.put(slot, slot));
            join.right().getOutputSet().forEach(slot -> childrenSlots.put(slot, slot));
            join.getOutputSet().forEach(slot -> {
                if (childrenSlots.containsKey(slot)) {
                    childrenSlots.put(slot, childrenSlots.get(slot));
                }
            });

            projects = project.getProjects().stream().map(expr -> expr.rewriteUp(e ->
                    e instanceof Slot ? childrenSlots.get((Slot) e) : e
            )).map(e -> (NamedExpression) e).collect(Collectors.toList());
        }
        for (NamedExpression namedExpression : projects) {
            Set<ExprId> usedExprIds = namedExpression.getInputSlotExprIds();
            if (aOutputExprIdSet.containsAll(usedExprIds)) {
                aProjects.add(namedExpression);
            } else {
                bProjects.add(namedExpression);
            }
        }

        boolean leftContains = aProjects.stream().anyMatch(e -> !(e instanceof Slot));
        boolean rightContains = bProjects.stream().anyMatch(e -> !(e instanceof Slot));
        // due to JoinCommute, we don't need to consider just right contains.
        if (!leftContains) {
            return null;
        }
        // we could not push nullable side project
        if (((join.getJoinType().isLeftOuterJoin() || join.getJoinType().isFullOuterJoin()) && rightContains)
                || ((join.getJoinType().isRightOuterJoin() || join.getJoinType().isFullOuterJoin()) && leftContains)) {
            return null;
        }

        Builder<NamedExpression> newAProject = ImmutableList.<NamedExpression>builder().addAll(aProjects);
        Set<Slot> aConditionSlots = CBOUtils.joinChildConditionSlots(join, true);
        Set<Slot> aProjectSlots = aProjects.stream().map(NamedExpression::toSlot)
                .collect(Collectors.toSet());
        aConditionSlots.stream().filter(slot -> !aProjectSlots.contains(slot)).forEach(newAProject::add);
        Plan newLeft = new LogicalProject<>(newAProject.build(), join.left());

        if (!rightContains) {
            Plan newJoin = join.withChildren(newLeft, join.right());
            return new LogicalProject<>(ImmutableList.copyOf(project.getOutput()), newJoin);
        }

        Builder<NamedExpression> newBProject = ImmutableList.<NamedExpression>builder().addAll(bProjects);
        Set<Slot> bConditionSlots = CBOUtils.joinChildConditionSlots(join, false);
        Set<Slot> bProjectSlots = bProjects.stream().map(NamedExpression::toSlot)
                .collect(Collectors.toSet());
        bConditionSlots.stream().filter(slot -> !bProjectSlots.contains(slot)).forEach(newBProject::add);
        Plan newRight = new LogicalProject<>(newBProject.build(), join.right());

        Plan newJoin = join.withChildren(newLeft, newRight);
        return new LogicalProject<>(ImmutableList.copyOf(project.getOutput()), newJoin);
    }

}