ReorderJoin.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.annotation.DependsRules;
import org.apache.doris.nereids.hint.DistributeHint;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.DistributeType;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.nereids.util.PlanUtils;
import org.apache.doris.nereids.util.Utils;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;

import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Try to eliminate cross join via finding join conditions in filters and change the join orders.
 * <p>
 * <pre>
 * For example:
 *
 * input:
 * SELECT * FROM t1, t2, t3 WHERE t1.id=t3.id AND t2.id=t3.id
 *
 * output:
 * SELECT * FROM t1 JOIN t3 ON t1.id=t3.id JOIN t2 ON t2.id=t3.id
 * </pre>
 * </p>
 * Uses {@link MultiJoin} as an intermediate form to complete this task.
 * {Join cluster}: a group of multiple joins, possibly with filters between them.
 * <ul>
 * <li> {Join cluster} to MultiJoin</li>
 * <li> MultiJoin to {Join cluster}</li>
 * </ul>
 */
@DependsRules({
    MergeFilters.class
})
public class ReorderJoin extends OneRewriteRuleFactory {
    /**
     * Builds the rewrite rule: a filter sitting on a join/filter sub-tree is flattened into a
     * {@link MultiJoin} and then expanded back into a (possibly reordered) join tree.
     * Filters whose direct child is a mark join are not matched.
     */
    @Override
    public Rule build() {
        return logicalFilter(subTree(LogicalJoin.class, LogicalFilter.class))
            .whenNot(root -> {
                Plan below = root.child();
                return below instanceof LogicalJoin && ((LogicalJoin<?, ?>) below).isMarkJoin();
            })
            .thenApply(ctx -> {
                // The three opt-outs are tested one after another, in the same order as a
                // short-circuit OR, so the cast below only happens when the first two are false.
                if (ctx.statementContext.getConnectContext().getSessionVariable().isDisableJoinReorder()) {
                    return null;
                }
                if (ctx.cascadesContext.isLeadingDisableJoinReorder()) {
                    return null;
                }
                LogicalFilter<Plan> root = ctx.root;
                if (((LogicalJoin<?, ?>) root.child()).isLeadingJoin()) {
                    return null;
                }

                // Filled while flattening, consumed while rebuilding the join tree.
                Map<Plan, DistributeHint> hintByPlan = Maps.newHashMap();
                Plan flattened = joinToMultiJoin(root, hintByPlan);
                Preconditions.checkState(flattened instanceof MultiJoin, "join to multi join should return MultiJoin,"
                        + " but return plan is " + flattened.getType());
                MultiJoin multiJoin = (MultiJoin) flattened;

                // Record the collected filters and the fan-in of this MultiJoin on the statement context.
                ctx.statementContext.addJoinFilters(multiJoin.getJoinFilter());
                ctx.statementContext.setMaxNAryInnerJoin(multiJoin.children().size());
                return multiJoinToJoin(multiJoin, hintByPlan);
            }).toRule(RuleType.REORDER_JOIN);
    }

    /**
     * Recursively flattens a join cluster into a {@link MultiJoin}.
     * <p>
     * Accepted shapes are {@link LogicalJoin} or {@link LogicalFilter}--{@link LogicalJoin};
     * any other plan (or a filter above such a plan) is returned unchanged.
     *
     * @param plan root of the (sub-)cluster to flatten
     * @param planToHintType out-parameter: for every visited join, its left child is mapped to a
     *         {@link DistributeType#NONE} hint and its right child to the join's own distribute hint
     * @return a {@link MultiJoin} for a join cluster, otherwise {@code plan} itself
     */
    public Plan joinToMultiJoin(Plan plan, Map<Plan, DistributeHint> planToHintType) {
        // The subTree pattern cannot pin down where the cluster ends, so the recursion stops at
        // anything that is neither join nor filter, or at a filter sitting directly on such a node.
        boolean isFilter = plan instanceof LogicalFilter;
        if (nonJoinAndNonFilter(plan) || (isFilter && nonJoinAndNonFilter(plan.child(0)))) {
            return plan;
        }

        List<Expression> conjuncts = Lists.newArrayList();
        LogicalJoin<?, ?> join;
        // Relies on {rule: MergeFilters}: a filter's child is expected to be a join, never another filter.
        if (isFilter) {
            LogicalFilter<?> filter = (LogicalFilter<?>) plan;
            conjuncts.addAll(filter.getConjuncts());
            join = (LogicalJoin<?, ?>) filter.child();
        } else {
            join = (LogicalJoin<?, ?>) plan;
        }

        // Inner/cross join conditions are pooled with the filter conjuncts; conditions of any other
        // join type are kept apart so they stay attached to that join.
        JoinType joinType = join.getJoinType();
        List<Expression> outerConditions = Lists.newArrayList();
        List<Expression> target = joinType.isInnerOrCrossJoin() ? conjuncts : outerConditions;
        target.addAll(join.getHashJoinConjuncts());
        target.addAll(join.getOtherJoinConjuncts());

        // Register hints and flatten the children, left side first.
        Plan leftChild = join.left();
        planToHintType.put(leftChild, new DistributeHint(DistributeType.NONE));
        Plan flatLeft = joinToMultiJoin(leftChild, planToHintType);
        Plan rightChild = join.right();
        planToHintType.put(rightChild, join.getDistributeHint());
        Plan flatRight = joinToMultiJoin(rightChild, planToHintType);

        // A side is merged into this MultiJoin only when canCombine allows it for this join type.
        List<Plan> inputs = Lists.newArrayList();
        boolean fullOuter = joinType.isFullOuterJoin();
        mergeInput(flatLeft, joinType.isRightJoin() || fullOuter, inputs, conjuncts);
        mergeInput(flatRight, joinType.isLeftJoin() || fullOuter, inputs, conjuncts);

        return new MultiJoin(inputs, conjuncts, joinType, outerConditions);
    }

    /**
     * Adds one flattened side to the inputs of the MultiJoin being built: a combinable
     * {@link MultiJoin} contributes its children and join filter, anything else is added as-is.
     */
    private static void mergeInput(Plan side, boolean changeChildren,
            List<Plan> inputs, List<Expression> conjuncts) {
        if (canCombine(side, changeChildren)) {
            MultiJoin nested = (MultiJoin) side;
            inputs.addAll(nested.children());
            conjuncts.addAll(nested.getJoinFilter());
        } else {
            inputs.add(side);
        }
    }

    /**
     * Recursively convert to
     * {@link MultiJoin}
     * -->
     * {@link LogicalJoin} or
     * {@link LogicalFilter}--{@link LogicalJoin}
     * <p>
     * When all inputs are CROSS/INNER joins, all joins will be flattened.
     * Otherwise, we will split the {join cluster} into multiple {@link MultiJoin}s.
     * <p>
     * Here are examples of the {@link MultiJoin}s constructed after this rule has been applied.
     * <p>
     * simple example:
     * <ul>
     * <li>A JOIN B --> MJ(A, B)
     * <li>A JOIN B JOIN C JOIN D --> MJ(A, B, C, D)
     * <li>A LEFT (OUTER/SEMI/ANTI) JOIN B --> MJ([LOJ/LSJ/LAJ]A, B)
     * <li>A RIGHT (OUTER/SEMI/ANTI) JOIN B --> MJ([ROJ/RSJ/RAJ]A, B)
     * <li>A FULL JOIN B --> MJ[FOJ](A, B)
     * </ul>
     * </p>
     * <p>
     * complex example:
     * <ul>
     * <li>A LEFT OUTER JOIN (B JOIN C) --> MJ([LOJ]A, MJ(B, C))
     * <li>(A JOIN B) LEFT JOIN C --> MJ(A, B, C)
     * <li>(A LEFT OUTER JOIN B) JOIN C --> MJ(MJ(A, B), C)
     * <li>A LEFT JOIN (B FULL JOIN C) --> MJ(A, MJ[full](B, C))
     * <li>(A LEFT JOIN B) FULL JOIN (C RIGHT JOIN D) --> MJ[full](MJ(A, B), MJ(C, D))
     * </ul>
     * </p>
     * more complex example:
     * <ul>
     * <li> A JOIN B JOIN C LEFT JOIN D --> MJ([LOJ]A, B, C, D)
     * <li> A JOIN B JOIN C LEFT JOIN (D JOIN F) --> MJ([LOJ]A, B, C, MJ(D, F))
     * <li> A RIGHT JOIN (B JOIN C JOIN D)--> MJ([ROJ]A, B, C, D)
     * <li> A JOIN B RIGHT JOIN (C JOIN D) --> MJ(A, B, MJ([ROJ]C, D))
     * </ul>
     * </p>
     * Graphic presentation:
     * <pre>
     * A JOIN B JOIN C LEFT JOIN D JOIN F
     *      left                  left│
     * A  B  C  D  F   ──►   A  B  C  │ D  F   ──►  MJ(LOJ A,B,C,MJ(DF))
     *
     * A JOIN B RIGHT JOIN C JOIN D JOIN F
     *     right                  │right
     * A  B  C  D  F   ──►   A  B │  C  D  F   ──►  MJ(A,B,MJ(ROJ C,D,F))
     *
     * (A JOIN B JOIN C) FULL JOIN (D JOIN F)
     *       full                    │
     * A  B  C  D  F   ──►   A  B  C │ D  F    ──►  MJ(FOJ MJ(A,B,C) MJ(D,F))
     * </pre>
     *
     * @param multiJoin the flattened join cluster to expand back into a join tree
     * @param planToHintType distribute hints keyed by plan; looked up with the right-hand input of
     *         every join that is created, defaulting to {@link DistributeType#NONE}
     * @return the rebuilt join tree, handed to {@code PlanUtils.filterOrSelf} together with the
     *         join filters that were not consumed as join conditions
     */
    public Plan multiJoinToJoin(MultiJoin multiJoin, Map<Plan, DistributeHint> planToHintType) {
        // A single input needs no join: only the collected filters are applied to it.
        if (multiJoin.arity() == 1) {
            return PlanUtils.filterOrSelf(ImmutableSet.copyOf(multiJoin.getJoinFilter()), multiJoin.child(0));
        }

        Builder<Plan> builder = ImmutableList.builder();
        // recursively handle multiJoin children.
        for (Plan child : multiJoin.children()) {
            if (child instanceof MultiJoin) {
                MultiJoin childMultiJoin = (MultiJoin) child;
                builder.add(multiJoinToJoin(childMultiJoin, planToHintType));
            } else {
                builder.add(child);
            }
        }
        MultiJoin multiJoinHandleChildren = multiJoin.withChildren(builder.build());

        // Any join type other than INNER/CROSS produces exactly one LogicalJoin of that type here;
        // the inputs on the other side are grouped into a new INNER MultiJoin and converted recursively.
        if (!multiJoinHandleChildren.getJoinType().isInnerOrCrossJoin()) {
            // Filters that stay above the join produced by this branch.
            List<Expression> remainingFilter;

            Plan left;
            Plan right;
            if (multiJoinHandleChildren.getJoinType().isLeftJoin()) {
                // Left join: the last input is the right side, all preceding inputs form the left side.
                right = multiJoinHandleChildren.child(multiJoinHandleChildren.arity() - 1);
                Set<ExprId> rightOutputExprIdSet = right.getOutputExprIdSet();
                // true  -> the conjunct reads at least one slot of the right side: keep it above the join
                // false -> it only concerns the left inputs: push it into the left MultiJoin
                // NOTE(review): the filter is read from the original multiJoin rather than from
                // multiJoinHandleChildren; presumably withChildren keeps the same filter - confirm.
                Map<Boolean, List<Expression>> split = multiJoin.getJoinFilter().stream()
                        .collect(Collectors.partitioningBy(expr ->
                                Utils.isIntersecting(rightOutputExprIdSet, expr.getInputSlotExprIds())
                        ));
                remainingFilter = split.get(true);
                List<Expression> pushedFilter = split.get(false);
                left = multiJoinToJoin(new MultiJoin(
                        multiJoinHandleChildren.children().subList(0, multiJoinHandleChildren.arity() - 1),
                        pushedFilter,
                        JoinType.INNER_JOIN,
                        ExpressionUtils.EMPTY_CONDITION), planToHintType);
            } else if (multiJoinHandleChildren.getJoinType().isRightJoin()) {
                // Right join: mirror image - the first input is the left side, the rest form the right side.
                left = multiJoinHandleChildren.child(0);
                Set<ExprId> leftOutputExprIdSet = left.getOutputExprIdSet();
                // true  -> the conjunct reads at least one slot of the left side: keep it above the join
                // false -> push it into the right MultiJoin
                Map<Boolean, List<Expression>> split = multiJoin.getJoinFilter().stream()
                        .collect(Collectors.partitioningBy(expr ->
                                Utils.isIntersecting(leftOutputExprIdSet, expr.getInputSlotExprIds())
                        ));
                remainingFilter = split.get(true);
                List<Expression> pushedFilter = split.get(false);
                right = multiJoinToJoin(new MultiJoin(
                        multiJoinHandleChildren.children().subList(1, multiJoinHandleChildren.arity()),
                        pushedFilter,
                        JoinType.INNER_JOIN,
                        ExpressionUtils.EMPTY_CONDITION), planToHintType);
            } else {
                // Remaining join types (e.g. FULL OUTER JOIN): nothing is pushed down and exactly
                // two inputs are required.
                remainingFilter = multiJoin.getJoinFilter();
                Preconditions.checkState(multiJoinHandleChildren.arity() == 2);
                // NOTE(review): the children were already converted by the loop above, so this
                // mapping presumably never meets a MultiJoin - confirm before simplifying.
                List<Plan> children = multiJoinHandleChildren.children().stream().map(child -> {
                    if (child instanceof MultiJoin) {
                        return multiJoinToJoin((MultiJoin) child, planToHintType);
                    } else {
                        return child;
                    }
                }).collect(Collectors.toList());
                left = children.get(0);
                right = children.get(1);
            }

            // The join gets the original join type; the second condition list is empty and the
            // not-inner-join conditions are passed as the third argument. The hint is looked up
            // by the right input.
            return PlanUtils.filterOrSelf(ImmutableSet.copyOf(remainingFilter), new LogicalJoin<>(
                    multiJoinHandleChildren.getJoinType(),
                    ExpressionUtils.EMPTY_CONDITION, multiJoinHandleChildren.getNotInnerJoinConditions(),
                    planToHintType.getOrDefault(right, new DistributeHint(DistributeType.NONE)),
                    Optional.empty(),
                    left, right, null));
        }

        // following this multiJoin just contain INNER/CROSS.
        // Mutable, insertion-ordered copy: conjuncts are removed as soon as a join consumes them.
        Set<Expression> joinFilter = new LinkedHashSet<>(multiJoinHandleChildren.getJoinFilter());

        // Build a left-deep tree starting from the first input; usedPlansIndex tracks which
        // inputs have already been joined in.
        Plan left = multiJoinHandleChildren.child(0);
        Set<Integer> usedPlansIndex = new LinkedHashSet<>();
        usedPlansIndex.add(0);

        while (usedPlansIndex.size() != multiJoinHandleChildren.children().size()) {
            // findInnerJoin marks exactly one more input as used per call, so the loop terminates.
            LogicalJoin<? extends Plan, ? extends Plan> join = findInnerJoin(left, multiJoinHandleChildren.children(),
                    joinFilter, usedPlansIndex, planToHintType);
            join.getHashJoinConjuncts().forEach(joinFilter::remove);
            join.getOtherJoinConjuncts().forEach(joinFilter::remove);

            left = join;
        }

        // Whatever was not consumed as a join condition stays on top of the final join tree.
        return PlanUtils.filterOrSelf(joinFilter, left);
    }

    /**
     * Decides whether a flattened child may be merged into its parent {@link MultiJoin}.
     * Merging is allowed only for an inner/cross {@link MultiJoin} whose side is not flagged
     * by the parent join.
     *
     * @param input flattened child: a MultiJoin or any other plan (e.g. GroupPlan, LogicalFilter)
     * @param changeChildren true when the parent join generates nullable output for this side or
     *         drops this side from its output
     * @return true if the input can be combined into a parent MultiJoin
     */
    private static boolean canCombine(Plan input, boolean changeChildren) {
        if (!(input instanceof MultiJoin)) {
            return false;
        }
        boolean innerOrCross = ((MultiJoin) input).getJoinType().isInnerOrCrossJoin();
        return innerOrCross && !changeChildren;
    }

    /**
     * Picks the next input to join with {@code left}.
     * <p>
     * Unused candidates are scanned in order. The first one for which at least one hash join
     * condition can be extracted from {@code joinFilter} is joined as an INNER JOIN. If no
     * candidate qualifies, the first unused candidate is joined as a CROSS JOIN.
     * The chosen candidate's index is added to {@code usedPlansIndex}.
     *
     * @param left the join tree built so far
     * @param candidates all inputs of the MultiJoin
     * @param joinFilter remaining join filter conjuncts (read only here)
     * @param usedPlansIndex indexes of candidates already joined; updated with the chosen one
     * @param planToHintType distribute hints, looked up by the chosen right input
     * @return InnerJoin{left, matching candidate} or CrossJoin{left, first unused candidate}
     */
    private LogicalJoin<? extends Plan, ? extends Plan> findInnerJoin(Plan left, List<Plan> candidates,
            Set<Expression> joinFilter, Set<Integer> usedPlansIndex, Map<Plan, DistributeHint> planToHintType) {
        Set<ExprId> leftIds = left.getOutputExprIdSet();
        int fallbackIndex = -1;
        List<Expression> fallbackOtherConditions = ExpressionUtils.EMPTY_CONDITION;

        for (int idx = 0; idx < candidates.size(); idx++) {
            if (usedPlansIndex.contains(idx)) {
                continue;
            }

            Plan candidate = candidates.get(idx);
            Set<ExprId> candidateIds = candidate.getOutputExprIdSet();
            Set<ExprId> combinedIds = JoinUtils.getJoinOutputExprIdSet(left, candidate);

            // Keep only conjuncts that are fully covered by the join output but by neither side alone.
            List<Expression> spanning = Lists.newArrayList();
            for (Expression conjunct : joinFilter) {
                Set<ExprId> usedIds = conjunct.getInputSlotExprIds();
                boolean oneSided = leftIds.containsAll(usedIds) || candidateIds.containsAll(usedIds);
                if (!oneSided && combinedIds.containsAll(usedIds)) {
                    spanning.add(conjunct);
                }
            }

            // first: hash join conditions, second: the remaining (other) join conditions.
            Pair<List<Expression>, List<Expression>> split = JoinUtils.extractExpressionForHashTable(
                    left.getOutput(), candidate.getOutput(), spanning);
            if (split.first.isEmpty()) {
                // Remember only the first candidate without a hash condition as the cross join fallback.
                if (fallbackIndex == -1) {
                    fallbackIndex = idx;
                    fallbackOtherConditions = split.second;
                }
                continue;
            }

            usedPlansIndex.add(idx);
            return new LogicalJoin<>(JoinType.INNER_JOIN,
                    split.first, split.second,
                    planToHintType.getOrDefault(candidate, new DistributeHint(DistributeType.NONE)),
                    Optional.empty(),
                    left, candidate, null);
        }

        // Every { left -> unused candidate } pairing lacks a hash condition: generate a CrossJoin.
        // NOTICE: candidates are traversed from head to tail so that the result is stable.
        usedPlansIndex.add(fallbackIndex);
        Plan right = candidates.get(fallbackIndex);
        return new LogicalJoin<>(JoinType.CROSS_JOIN,
                ExpressionUtils.EMPTY_CONDITION,
                fallbackOtherConditions,
                planToHintType.getOrDefault(right, new DistributeHint(DistributeType.NONE)),
                Optional.empty(),
                left, right, null);
    }

    /**
     * Tells whether {@code plan} is neither a {@link LogicalJoin} nor a {@link LogicalFilter}.
     */
    private boolean nonJoinAndNonFilter(Plan plan) {
        boolean joinOrFilter = plan instanceof LogicalJoin || plan instanceof LogicalFilter;
        return !joinOrFilter;
    }
}