EliminateOuterJoin.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.nereids.util.TypeUtils;
import org.apache.doris.nereids.util.Utils;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSet.Builder;
import com.google.common.collect.Sets;

import java.util.Collection;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;

/**
 * Eliminate outer join.
 */
public class EliminateOuterJoin extends OneRewriteRuleFactory {

    @Override
    public Rule build() {
        return logicalFilter(
                logicalJoin().when(join -> join.getJoinType().isOuterJoin())
        ).then(filter -> {
            LogicalJoin<Plan, Plan> join = filter.child();

            Builder<Expression> conjunctsBuilder = ImmutableSet.builder();
            Set<Slot> notNullSlots = new HashSet<>();
            for (Expression predicate : filter.getConjuncts()) {
                Optional<Slot> notNullSlot = TypeUtils.isNotNull(predicate);
                if (notNullSlot.isPresent()) {
                    notNullSlots.add(notNullSlot.get());
                } else {
                    conjunctsBuilder.add(predicate);
                }
            }
            boolean canFilterLeftNull = Utils.isIntersecting(join.left().getOutputSet(), notNullSlots);
            boolean canFilterRightNull = Utils.isIntersecting(join.right().getOutputSet(), notNullSlots);
            if (!canFilterRightNull && !canFilterLeftNull) {
                return null;
            }

            JoinType newJoinType = tryEliminateOuterJoin(join.getJoinType(), canFilterLeftNull, canFilterRightNull);
            Set<Expression> conjuncts = Sets.newHashSet();
            conjuncts.addAll(filter.getConjuncts());
            boolean conjunctsChanged = false;
            if (!notNullSlots.isEmpty()) {
                for (Slot slot : notNullSlots) {
                    Not isNotNull = new Not(new IsNull(slot), true);
                    conjunctsChanged |= conjuncts.add(isNotNull);
                }
            }
            if (newJoinType.isInnerJoin()) {
                /*
                 * for example: (A left join B on A.a=B.b) join C on B.x=C.x
                 * inner join condition B.x=C.x implies 'B.x is not null',
                 * by which the left outer join could be eliminated. Finally, the join transformed to
                 * (A join B on A.a=B.b) join C on B.x=C.x.
                 * This elimination can be processed recursively.
                 *
                 * TODO: is_not_null can also be inferred from A < B and so on
                 */
                conjunctsChanged |= join.getEqualToConjuncts().stream()
                        .map(EqualTo.class::cast)
                        .map(equalTo -> JoinUtils.swapEqualToForChildrenOrder(equalTo, join.left().getOutputSet()))
                        .anyMatch(equalTo -> createIsNotNullIfNecessary(equalTo, conjuncts));

                JoinUtils.JoinSlotCoverageChecker checker = new JoinUtils.JoinSlotCoverageChecker(
                        join.left().getOutput(),
                        join.right().getOutput());
                conjunctsChanged |= join.getOtherJoinConjuncts().stream()
                        .filter(EqualTo.class::isInstance)
                        .filter(equalTo -> checker.isHashJoinCondition((EqualPredicate) equalTo))
                        .map(equalTo -> JoinUtils.swapEqualToForChildrenOrder((EqualPredicate) equalTo,
                                join.left().getOutputSet()))
                        .anyMatch(equalTo -> createIsNotNullIfNecessary(equalTo, conjuncts));
            }
            if (conjunctsChanged) {
                return filter.withConjuncts(conjuncts.stream().collect(ImmutableSet.toImmutableSet()))
                        .withChildren(join.withJoinTypeAndContext(newJoinType, join.getJoinReorderContext()));
            }
            return filter.withChildren(join.withJoinTypeAndContext(newJoinType, join.getJoinReorderContext()));
        }).toRule(RuleType.ELIMINATE_OUTER_JOIN);
    }

    private JoinType tryEliminateOuterJoin(JoinType joinType, boolean canFilterLeftNull, boolean canFilterRightNull) {
        if (joinType.isRightOuterJoin() && canFilterLeftNull) {
            return JoinType.INNER_JOIN;
        }
        if (joinType.isLeftOuterJoin() && canFilterRightNull) {
            return JoinType.INNER_JOIN;
        }
        if (joinType.isFullOuterJoin() && canFilterLeftNull && canFilterRightNull) {
            return JoinType.INNER_JOIN;
        }
        if (joinType.isFullOuterJoin() && canFilterLeftNull) {
            return JoinType.LEFT_OUTER_JOIN;
        }
        if (joinType.isFullOuterJoin() && canFilterRightNull) {
            return JoinType.RIGHT_OUTER_JOIN;
        }
        return joinType;
    }

    private boolean createIsNotNullIfNecessary(EqualPredicate swapedEqualTo, Collection<Expression> container) {
        boolean containerChanged = false;
        if (swapedEqualTo.left().nullable()) {
            Not not = new Not(new IsNull(swapedEqualTo.left()), true);
            containerChanged |= container.add(not);
        }
        if (swapedEqualTo.right().nullable()) {
            Not not = new Not(new IsNull(swapedEqualTo.right()), true);
            containerChanged |= container.add(not);
        }
        return containerChanged;
    }
}