FindHashConditionForJoin.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.annotation.DependsRules;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.JoinUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Streams;

import java.util.List;

/**
 * Rewrite rule that scans a join's other-join conjuncts for those that can serve
 * as hash-table keys and moves them into the join's hash-join conjuncts.
 * <p>
 * For example, given:
 * A join B on A.x=B.x and A.y>1 and A.x+1=B.x+B.y and A.z=B.z+A.x and (A.z=B.z or A.x=B.x)
 * {A.x=B.x, A.x+1=B.x+B.y} could be used to build a hash table,
 * but {A.y>1, A.z=B.z+A.x, (A.z=B.z or A.x=B.x)} could not.
 * <p>
 * Which conjuncts qualify is decided by {@code JoinUtils.extractExpressionForHashTable}.
 * If at least one conjunct is promoted and the join is a cross join, the join is turned
 * into an inner join. If nothing is promoted, the join is returned unchanged.
 * <p>
 * CAUTION:
 * This rule must be applied after BindSlotReference.
 */
@DependsRules({
    PushFilterInsideJoin.class
})
public class FindHashConditionForJoin extends OneRewriteRuleFactory {

    /**
     * Creates the rule that matches every logical join and promotes its
     * hash-table-usable other-join conjuncts to hash-join conjuncts.
     *
     * @return the rule registered as {@code RuleType.FIND_HASH_CONDITION_FOR_JOIN}
     */
    @Override
    public Rule build() {
        return logicalJoin().then(join -> {
            // Partition the other-join conjuncts, relative to each child's output slots,
            // into (usable as hash keys, everything else).
            Pair<List<Expression>, List<Expression>> split = JoinUtils.extractExpressionForHashTable(
                    join.left().getOutput(), join.right().getOutput(), join.getOtherJoinConjuncts());
            List<Expression> newHashConjuncts = split.first;
            List<Expression> residualConjuncts = split.second;

            // Nothing to promote: keep the original plan node as is.
            if (newHashConjuncts.isEmpty()) {
                return join;
            }

            // Existing hash conjuncts go first; duplicates are dropped, keeping first occurrence.
            List<Expression> mergedHashConjuncts = Streams.concat(
                            join.getHashJoinConjuncts().stream(), newHashConjuncts.stream())
                    .distinct()
                    .collect(ImmutableList.toImmutableList());

            // A cross join that now carries equi-join keys becomes an inner join.
            JoinType originalJoinType = join.getJoinType();
            boolean promoteCrossJoin = originalJoinType == JoinType.CROSS_JOIN
                    && !mergedHashConjuncts.isEmpty();
            JoinType resultJoinType = promoteCrossJoin ? JoinType.INNER_JOIN : originalJoinType;

            // Only the join type and the two conjunct lists change; everything else is carried over.
            return new LogicalJoin<>(
                    resultJoinType,
                    mergedHashConjuncts,
                    residualConjuncts,
                    join.getMarkJoinConjuncts(),
                    join.getDistributeHint(),
                    join.getMarkJoinSlotReference(),
                    join.children(),
                    join.getJoinReorderContext());
        }).toRule(RuleType.FIND_HASH_CONDITION_FOR_JOIN);
    }
}