// SkewJoin.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.hint.DistributeHint;
import org.apache.doris.nereids.hint.JoinSkewInfo;
import org.apache.doris.nereids.pattern.MatchingContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.StatsDerive.DeriveContext;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.AbstractPlan;
import org.apache.doris.nereids.trees.plans.DistributeType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;

import java.util.ArrayList;
import java.util.List;

/**
 * Rewrite rule for data-skewed joins. Two optimizations are currently applied:
 * salt-join or broadcast join.
 * When hot values are found on the join key during the RBO phase, a small right table is
 * joined with a broadcast hint; a relatively large right table is salted automatically.
 *
 * Depends on InitJoinOrder rule
 */
public class SkewJoin extends OneRewriteRuleFactory {

    /**
     * The right side is broadcast only when its row count is below the session's
     * broadcast row count limit divided by this factor; otherwise the join is salted.
     */
    private static final int BROADCAST_LIMIT_DIVISOR = 100;

    /**
     * Matches inner joins and one-side outer joins that carry no distribute hint yet
     * and are not mark joins, and rewrites them with {@link #transform}.
     */
    @Override
    public Rule build() {
        return logicalJoin()
                .when(join -> join.getJoinType().isOneSideOuterJoin()
                        || join.getJoinType().isInnerJoin())
                .when(join -> join.getDistributeHint().distributeType == DistributeType.NONE)
                .whenNot(LogicalJoin::isMarkJoin)
                .thenApply(SkewJoin::transform).toRule(RuleType.SALT_JOIN);
    }

    /**
     * Attaches a distribute hint to a join whose single hash-join key has hot values.
     *
     * @param ctx matching context whose root is the join to inspect
     * @return the join with a BROADCAST_RIGHT hint, the result of {@code SaltJoin.transform}
     *         for a join hinted SHUFFLE_RIGHT with skew info, or {@code null} when the rule
     *         does not apply (no connect context, not exactly one hash conjunct, no usable
     *         statistics, or no hot values on the join key)
     */
    private static Plan transform(MatchingContext<LogicalJoin<Plan, Plan>> ctx) {
        ConnectContext connectContext = ConnectContext.get();
        if (connectContext == null) {
            return null;
        }

        LogicalJoin<Plan, Plan> join = ctx.root;
        if (join.getHashJoinConjuncts().size() != 1) {
            return null;
        }

        // Derive statistics only for children that do not carry them yet.
        StatsDerive derive = new StatsDerive(false);
        AbstractPlan left = (AbstractPlan) join.left();
        if (left.getStats() == null) {
            left.accept(derive, new DeriveContext());
        }
        AbstractPlan right = (AbstractPlan) join.right();
        if (right.getStats() == null) {
            right.accept(derive, new DeriveContext());
        }

        // Orient the predicate so that child(0) is the operand found in the left output.
        EqualPredicate equal = (EqualPredicate) join.getHashJoinConjuncts().get(0);
        if (join.left().getOutputSet().contains(equal.right())) {
            equal = equal.commute();
        }

        // Inner and left outer joins look for skew on the left key,
        // right outer joins on the right key.
        Expression skewExpr;
        AbstractPlan skewSide;
        if (join.getJoinType().isInnerJoin() || join.getJoinType().isLeftOuterJoin()) {
            skewExpr = equal.child(0);
            skewSide = left;
        } else if (join.getJoinType().isRightOuterJoin()) {
            skewExpr = equal.child(1);
            skewSide = right;
        } else {
            return null;
        }

        List<Expression> hotValues = collectHotValues(skewSide, skewExpr);
        if (hotValues.isEmpty()) {
            return null;
        }

        // Without right-side statistics the broadcast/salt decision cannot be made;
        // decline the rewrite instead of failing with a NullPointerException.
        if (right.getStats() == null) {
            return null;
        }

        SessionVariable sessionVariable = connectContext.getSessionVariable();
        // broadcast join for small right table
        // salt join for large right table
        if (right.getStats().getRowCount()
                < sessionVariable.getBroadcastRowCountLimit() / BROADCAST_LIMIT_DIVISOR) {
            join.setHint(new DistributeHint(DistributeType.BROADCAST_RIGHT));
            return join;
        }
        join.setHint(new DistributeHint(DistributeType.SHUFFLE_RIGHT,
                new JoinSkewInfo(skewExpr, hotValues, false)));
        return SaltJoin.transform(join);
    }

    /**
     * Collects the hot values recorded in the statistics of {@code side} for {@code key}.
     *
     * @param side the join child whose statistics are consulted
     * @param key the join key expression to look up in those statistics
     * @return a new mutable list of the hot values; empty when the plan has no statistics,
     *         the key has no column statistics, or no hot values are recorded
     */
    private static List<Expression> collectHotValues(AbstractPlan side, Expression key) {
        List<Expression> hotValues = new ArrayList<>();
        if (side.getStats() == null
                || side.getStats().findColumnStatistics(key) == null
                || side.getStats().findColumnStatistics(key).getHotValues() == null) {
            return hotValues;
        }
        hotValues.addAll(side.getStats().findColumnStatistics(key).getHotValues().keySet());
        return hotValues;
    }
}