SaltJoin.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.hint.DistributeHint;
import org.apache.doris.nereids.hint.Hint.HintStatus;
import org.apache.doris.nereids.pattern.MatchingContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.generator.ExplodeNumbers;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Random;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.plans.DistributeType;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import org.apache.doris.nereids.util.Utils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import org.jetbrains.annotations.Nullable;

import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Current capabilities and limitations of the SaltJoin rewrite:
 * - INNER JOIN: supports skew on a single side (the left table); skew on both tables is NOT supported
 * - LEFT JOIN: supports skew in the left table; skew in the right table is NOT supported
 * - RIGHT JOIN: supports skew in the right table; skew in the left table is NOT supported
 *
 * INNER JOIN and LEFT JOIN use case:
 * Applicable when the left table is skewed and the right table is too large to broadcast.
 *
 * Examples of the rewrite (assuming a salt factor of 1000):
 * case1:
 * LogicalJoin(type:inner, t1.a=t2.a, hint:skew(t1.a(null,1,2)))
 *   +--LogicalOlapScan(t1)
 *   +--LogicalOlapScan(t2)
 * ->
 * LogicalJoin (type:inner, t1.a=t2.a and r1=r2)
 *   |--LogicalProject (t1.a, if (t1.a IN (1, 2), random(0, 999), DEFAULT_SALT_VALUE) AS r1)
 *   |  +--LogicalFilter(t1.a is not null)
 *   |    +--LogicalOlapScan(t1)
 *   +--LogicalProject (projections: t2.a, if(explodeColumn IS NULL, DEFAULT_SALT_VALUE, explodeColumn) as r2)
 *     +--LogicalJoin (type=right_outer_join, t2.a = skewValue)
 *       |--LogicalGenerate(generators=[explode_numbers(1000)], generatorOutput=[explodeColumn])
 *       |  +--LogicalUnion(outputs=[skewValue], constantExprsList(1,2))
 *       +--LogicalFilter(t2.a is not null)
 *         +--LogicalOlapScan(t2)
 *
 * case2:
 * LogicalJoin(type:inner, t1.a=t2.a, hint:skew(t1.a(1,2)))
 *   +--LogicalOlapScan(t1)
 *   +--LogicalOlapScan(t2)
 * ->
 * LogicalJoin (type:inner, t1.a=t2.a and r1=r2)
 *   |--LogicalProject (t1.a, if (t1.a IN (1, 2), random(0, 999), DEFAULT_SALT_VALUE) AS r1)
 *   |  +--LogicalOlapScan(t1)
 *   +--LogicalProject (projections: t2.a, if(explodeColumn IS NULL, DEFAULT_SALT_VALUE, explodeColumn) as r2)
 *     +--LogicalJoin (type=right_outer_join, t2.a = skewValue)
 *       |--LogicalGenerate(generators=[explode_numbers(1000)], generatorOutput=[explodeColumn])
 *       |  +--LogicalUnion(outputs=[skewValue], constantExprsList(1,2))
 *       +--LogicalOlapScan(t2)
 *
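 * case3: no salting is applied, because rows whose join key is null never match in an '=' inner join.
 * Note: transform does not currently call addNotNull, which is the method that builds the
 * "is not null" filters shown in case1 and case3.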
 * LogicalJoin(type:inner, t1.a=t2.a, hint:skew(t1.a(null)))
 *   |--LogicalOlapScan(t1)
 *   +--LogicalOlapScan(t2)
 * ->
 * LogicalJoin(type:inner, t1.a=t2.a)
 *   |--LogicalFilter(t1.a is not null)
 *   |  +--LogicalOlapScan(t1)
 *   +--LogicalFilter(t2.a is not null)
 *     +--LogicalOlapScan(t2)
 * */
public class SaltJoin extends OneRewriteRuleFactory {
    private static final String RANDOM_COLUMN_NAME_LEFT = "r1";
    private static final String RANDOM_COLUMN_NAME_RIGHT = "r2";
    private static final String SKEW_VALUE_COLUMN_NAME = "skewValue";
    private static final String EXPLODE_NUMBER_COLUMN_NAME = "explodeColumn";
    // Multiplier on (backend count * parallel exec instances) used when the session sets no explicit factor.
    private static final int SALT_FACTOR = 4;
    // Salt value given to rows whose join key is not one of the salted skew values.
    private static final int DEFAULT_SALT_VALUE = 0;

    /**
     * Matches inner and one-side outer joins (excluding mark joins) that carry a skew hint which has not yet
     * been applied successfully, and rewrites them with {@link #transform}.
     */
    @Override
    public Rule build() {
        return logicalJoin()
                .when(join -> join.getJoinType().isOneSideOuterJoin() || join.getJoinType().isInnerJoin())
                .when(join -> join.getDistributeHint() != null && join.getDistributeHint().getSkewInfo() != null)
                .whenNot(LogicalJoin::isMarkJoin)
                .whenNot(join -> join.getDistributeHint().isSuccessInSkewRewrite())
                .thenApply(SaltJoin::transform).toRule(RuleType.SALT_JOIN);
    }

    /**
     * Rewrites a skew-hinted join into a salted join.
     *
     * <p>The skewed child gets an extra salt column (a random value in [0, factor - 1] for rows whose key is a
     * salted skew value, DEFAULT_SALT_VALUE otherwise). The other child has every salted skew value replicated
     * {@code factor} times with salts 0..factor-1. The two salt columns are then added as an extra hash-join
     * equality. On success the hint is marked SUCCESS and flagged as applied, so the rule does not fire again.
     *
     * @return the rewritten join, or null when the hint or join shape is not supported
     */
    private static Plan transform(MatchingContext<LogicalJoin<Plan, Plan>> ctx) {
        LogicalJoin<Plan, Plan> join = ctx.root;
        DistributeHint hint = join.getDistributeHint();
        if (hint.distributeType != DistributeType.SHUFFLE_RIGHT) {
            return null;
        }
        Expression skewExpr = hint.getSkewExpr();
        if (!skewExpr.isSlot()) {
            return null;
        }
        // The skew expression must come from the salted side: left for INNER/LEFT OUTER, right for RIGHT OUTER.
        if (((join.getJoinType().isLeftOuterJoin() || join.getJoinType().isInnerJoin())
                && !join.left().getOutput().contains((Slot) skewExpr))
                || (join.getJoinType().isRightOuterJoin() && !join.right().getOutput().contains((Slot) skewExpr))) {
            return null;
        }
        int factor = getSaltFactor(ctx);
        Optional<Expression> literalType = TypeCoercionUtils.characterLiteralTypeCoercion(String.valueOf(factor),
                TinyIntType.INSTANCE);
        if (!literalType.isPresent()) {
            return null;
        }
        Expression skewConjunct = findSkewConjunct(join, skewExpr);
        if (skewConjunct == null) {
            return null;
        }
        // findSkewConjunct orients the conjunct so child(0) is from the left child and child(1) from the right.
        Expression leftSkewExpr = skewConjunct.child(0);
        Expression rightSkewExpr = skewConjunct.child(1);
        Set<Expression> skewValuesSet = new HashSet<>(hint.getSkewValues());
        List<Expression> expandSideValues = getSaltedSkewValuesForExpandSide(skewConjunct, skewValuesSet);
        List<Expression> skewSideValues = getSaltedSkewValuesForSkewSide(skewConjunct, skewValuesSet, join);
        if (skewSideValues.isEmpty()) {
            return null;
        }
        DataType type = literalType.get().getDataType();
        LogicalProject<Plan> leftProject;
        LogicalProject<Plan> rightProject;
        if (join.getJoinType() == JoinType.INNER_JOIN || join.getJoinType() == JoinType.LEFT_OUTER_JOIN) {
            leftProject = addRandomSlot(leftSkewExpr, skewSideValues, join.left(), factor, type);
            rightProject = expandSkewValueRows(rightSkewExpr, expandSideValues, join.right(), factor, type);
        } else {
            leftProject = expandSkewValueRows(leftSkewExpr, expandSideValues, join.left(), factor, type);
            rightProject = addRandomSlot(rightSkewExpr, skewSideValues, join.right(), factor, type);
        }
        EqualTo saltEqual = new EqualTo(lastProjectedSlot(leftProject), lastProjectedSlot(rightProject));
        saltEqual = (EqualTo) TypeCoercionUtils.processComparisonPredicate(saltEqual);
        ImmutableList.Builder<Expression> newHashJoinConjuncts = ImmutableList.builderWithExpectedSize(
                join.getHashJoinConjuncts().size() + 1);
        newHashJoinConjuncts.addAll(join.getHashJoinConjuncts());
        newHashJoinConjuncts.add(saltEqual);
        hint.setStatus(HintStatus.SUCCESS);
        hint.setSkewInfo(hint.getSkewInfo().withSuccessInSaltJoin(true));
        return new LogicalJoin<>(join.getJoinType(), newHashJoinConjuncts.build(), join.getOtherJoinConjuncts(),
                hint, leftProject, rightProject, JoinReorderContext.EMPTY);
    }

    /**
     * Finds the first hash-join conjunct with a side equal to skewExpr. The result is oriented so that child(0)
     * is an output slot of the left child and child(1) is an output slot of the right child, using
     * {@link ComparisonPredicate#commute()} when needed.
     *
     * @return the oriented conjunct, or null when no conjunct references skewExpr or the first one that does
     *         is not a slot-to-slot comparison across the two children
     */
    private static @Nullable Expression findSkewConjunct(LogicalJoin<Plan, Plan> join, Expression skewExpr) {
        Set<Slot> leftOutput = join.left().getOutputSet();
        Set<Slot> rightOutput = join.right().getOutputSet();
        for (Expression conjunct : join.getHashJoinConjuncts()) {
            Expression first = conjunct.child(0);
            Expression second = conjunct.child(1);
            if (!skewExpr.equals(first) && !skewExpr.equals(second)) {
                continue;
            }
            // Use Set.contains(Object) without casting: a side may be a non-slot expression (e.g. an implicit
            // Cast), and casting it to Slot would throw ClassCastException instead of declining the rewrite.
            if (leftOutput.contains(first) && rightOutput.contains(second)) {
                return conjunct;
            }
            if (leftOutput.contains(second) && rightOutput.contains(first)) {
                return ((ComparisonPredicate) conjunct).commute();
            }
            return null;
        }
        return null;
    }

    // Returns the slot of the last projection; both addRandomSlot and expandSkewValueRows append the salt there.
    private static Slot lastProjectedSlot(LogicalProject<Plan> project) {
        List<NamedExpression> projects = project.getProjects();
        return projects.get(projects.size() - 1).toSlot();
    }

    /**
     * Adds a project on top of originPlan that keeps all of its output columns and appends a salt column:
     * {@code if(<skewExpr matches a skew value>, cast(random(0, factor - 1) as type), DEFAULT_SALT_VALUE)}.
     *
     * @param skewExpr join key on the skewed side
     * @param skewValues skew values to salt; a NullLiteral entry salts rows whose key is null
     * @param originPlan the skewed child of the join
     * @param factor number of distinct salt values
     * @param type data type of the salt column
     * @return the new project; the salt column is its last projection
     */
    private static LogicalProject<Plan> addRandomSlot(Expression skewExpr, List<Expression> skewValues,
            Plan originPlan, int factor, DataType type) {
        List<Expression> skewValuesExceptNull = skewValues.stream()
                .filter(value -> !(value instanceof NullLiteral))
                .collect(Collectors.toList());
        Expression ifCondition = getIfCondition(skewExpr, skewValues, skewValuesExceptNull);
        Cast randomSalt = new Cast(new Random(new BigIntLiteral(0), new BigIntLiteral(factor - 1)), type);
        If saltExpr = new If(ifCondition, randomSalt, Literal.convertToTypedLiteral(DEFAULT_SALT_VALUE, type));
        ImmutableList.Builder<NamedExpression> projects = ImmutableList.builderWithExpectedSize(
                originPlan.getOutput().size() + 1);
        projects.addAll(originPlan.getOutput());
        projects.add(new Alias(saltExpr, RANDOM_COLUMN_NAME_LEFT));
        return new LogicalProject<>(projects.build(), originPlan);
    }

    /**
     * Builds the predicate selecting rows whose key is one of the skew values.
     * If skewValues is [1, 2, null], the result is "skewExpr in (1, 2) or skewExpr is null".
     * If skewValues is [1, 2], the result is "skewExpr in (1, 2)".
     * If skewValues is [null], the result is "skewExpr is null".
     *
     * @return the predicate, or null when skewValues is empty
     */
    private static @Nullable Expression getIfCondition(Expression skewExpr, List<Expression> skewValues,
            List<Expression> skewValuesExceptNull) {
        boolean hasNull = skewValuesExceptNull.size() < skewValues.size();
        boolean hasNonNull = !skewValuesExceptNull.isEmpty();
        if (hasNull && hasNonNull) {
            return new Or(new InPredicate(skewExpr, skewValuesExceptNull), new IsNull(skewExpr));
        }
        if (hasNonNull) {
            return new InPredicate(skewExpr, skewValuesExceptNull);
        }
        if (hasNull) {
            return new IsNull(skewExpr);
        }
        return null;
    }

    /**
     * Builds the expand side of the salted join.
     *
     * <p>Each salted skew value is replicated {@code factor} times with salts 0..factor-1, using a constant
     * union plus explode_numbers. The result is right-outer-joined to originPlan on the skew key, so rows of
     * originPlan whose key is a skew value appear once per salt. Rows with any other key get
     * DEFAULT_SALT_VALUE.
     *
     * @param skewExpr join key on the expand side
     * @param saltedSkewValues skew values to replicate; when empty, every row gets DEFAULT_SALT_VALUE
     * @param originPlan the non-skewed child of the join
     * @param factor number of salt values per skew value
     * @param type data type of the salt column
     * @return a project over originPlan's columns plus the salt column as its last projection
     */
    private static LogicalProject<Plan> expandSkewValueRows(Expression skewExpr, List<Expression> saltedSkewValues,
            Plan originPlan, int factor, DataType type) {
        if (saltedSkewValues.isEmpty()) {
            ImmutableList.Builder<NamedExpression> namedExpressionsBuilder = ImmutableList.builderWithExpectedSize(
                    originPlan.getOutput().size() + 1);
            namedExpressionsBuilder.addAll(originPlan.getOutput());
            namedExpressionsBuilder.add(new Alias(Literal.convertToTypedLiteral(DEFAULT_SALT_VALUE, type),
                    RANDOM_COLUMN_NAME_RIGHT));
            return new LogicalProject<>(namedExpressionsBuilder.build(), originPlan);
        }
        boolean saltedSkewValuesHasNull = saltedSkewValues.stream().anyMatch(value -> value instanceof NullLiteral);
        // construct LogicalUnion and LogicalGenerate
        // if skew values are: 1 and null, the equal sql is:
        // select skewValue, explodeColumn from (select 1 as skewValue union all select null) as t11
        // lateral view explode_numbers(factor) tmp1 as explodeColumn
        // The union emits a NULL row when null is a salted skew value, so its output slot must then be nullable.
        SlotReference skewValueSlot = new SlotReference(SKEW_VALUE_COLUMN_NAME, skewExpr.getDataType(),
                saltedSkewValuesHasNull);
        List<NamedExpression> outputs = ImmutableList.of(skewValueSlot);
        ImmutableList.Builder<List<NamedExpression>> constantExprsList = ImmutableList.builderWithExpectedSize(
                saltedSkewValues.size());
        for (Expression skewValue : saltedSkewValues) {
            constantExprsList.add(ImmutableList.of(new Alias(skewValue, SKEW_VALUE_COLUMN_NAME)));
        }
        LogicalUnion union = new LogicalUnion(Qualifier.ALL, outputs, ImmutableList.of(), constantExprsList.build(),
                false, ImmutableList.of());
        List<Function> generators = ImmutableList.of(new ExplodeNumbers(new IntegerLiteral(factor)));
        SlotReference generateSlot = new SlotReference(EXPLODE_NUMBER_COLUMN_NAME, IntegerType.INSTANCE, false);
        LogicalGenerate<Plan> generate = new LogicalGenerate<>(generators, ImmutableList.of(generateSlot), union);
        Alias castGenerated = new Alias(new Cast(generateSlot, type));
        ImmutableList.Builder<NamedExpression> projectsBuilder = ImmutableList.builderWithExpectedSize(
                union.getOutput().size() + 1);
        projectsBuilder.addAll(union.getOutput());
        projectsBuilder.add(castGenerated);
        LogicalProject<Plan> project = new LogicalProject<>(projectsBuilder.build(), generate);
        // construct right join; null-safe equality is needed so a null skew value matches null keys
        EqualPredicate equalTo;
        if (saltedSkewValuesHasNull) {
            equalTo = new NullSafeEqual(skewValueSlot, skewExpr);
        } else {
            equalTo = new EqualTo(skewValueSlot, skewExpr);
        }
        equalTo = (EqualPredicate) TypeCoercionUtils.processComparisonPredicate(equalTo);
        JoinReorderContext joinReorderContext = new JoinReorderContext();
        joinReorderContext.setLeadingJoin(true);
        LogicalJoin<Plan, Plan> rightJoin = new LogicalJoin<>(JoinType.RIGHT_OUTER_JOIN, ImmutableList.of(equalTo),
                project, originPlan, joinReorderContext);
        // construct upper project: rows of originPlan without a matching skew value get a null salt from the
        // outer join, which is replaced by DEFAULT_SALT_VALUE.
        Slot castGeneratedSlot = castGenerated.toSlot();
        If ifExpr = new If(new IsNull(castGeneratedSlot), Literal.convertToTypedLiteral(DEFAULT_SALT_VALUE, type),
                castGeneratedSlot);
        ImmutableList.Builder<NamedExpression> namedExpressionsBuilder = ImmutableList.builderWithExpectedSize(
                originPlan.getOutput().size() + 1);
        namedExpressionsBuilder.addAll(originPlan.getOutput());
        namedExpressionsBuilder.add(new Alias(ifExpr, RANDOM_COLUMN_NAME_RIGHT));
        return new LogicalProject<>(namedExpressionsBuilder.build(), rightJoin);
    }

    /**
     * Returns the salt factor. This is the session variable joinSkewAddSaltExplodeFactor when it is positive;
     * otherwise it is (backend count, at least 1) * (parallel exec instances, at least 1) * SALT_FACTOR,
     * clamped to Integer.MAX_VALUE.
     */
    private static int getSaltFactor(MatchingContext<LogicalJoin<Plan, Plan>> ctx) {
        int factor = ctx.connectContext.getStatementContext().getConnectContext()
                .getSessionVariable().joinSkewAddSaltExplodeFactor;
        if (factor <= 0) {
            int beNumber = Math.max(1, ctx.connectContext.getEnv().getClusterInfo().getBackendsNumber(true));
            int parallelInstance = Math.max(1, ctx.connectContext.getSessionVariable().getParallelExecInstanceNum());
            factor = (int) Math.min((long) beNumber * parallelInstance * SALT_FACTOR, Integer.MAX_VALUE);
        }
        return factor;
    }

    /**
     * Returns the skew values to replicate on the expand side. For NullSafeEqual this is all values. For
     * EqualTo it is all values except null, since a null key never matches under '='. For any other
     * predicate it is empty.
     */
    private static List<Expression> getSaltedSkewValuesForExpandSide(Expression skewConjunct,
            Set<Expression> skewValuesSet) {
        if (skewConjunct instanceof NullSafeEqual) {
            return Utils.fastToImmutableList(skewValuesSet);
        } else if (skewConjunct instanceof EqualTo) {
            return skewValuesSet.stream().filter(value -> !(value instanceof NullLiteral))
                    .collect(ImmutableList.toImmutableList());
        }
        return ImmutableList.of();
    }

    /**
     * Returns the skew values to salt on the skewed side.
     * For NullSafeEqual this is all values.
     * For EqualTo it is all values except null in an inner join, and all values otherwise (the outer join
     * still outputs null-key rows of the preserved side).
     * For any other predicate it is empty.
     */
    private static List<Expression> getSaltedSkewValuesForSkewSide(Expression skewConjunct,
            Set<Expression> skewValuesSet, LogicalJoin<Plan, Plan> join) {
        if (skewConjunct instanceof NullSafeEqual) {
            return Utils.fastToImmutableList(skewValuesSet);
        } else if (skewConjunct instanceof EqualTo) {
            if (join.getJoinType().isInnerJoin()) {
                return skewValuesSet.stream().filter(value -> !(value instanceof NullLiteral))
                        .collect(ImmutableList.toImmutableList());
            } else {
                return Utils.fastToImmutableList(skewValuesSet);
            }
        }
        return ImmutableList.of();
    }

    /**
     * For an EqualTo skew conjunct whose skew values include null, wraps the join children that can never
     * match on a null key in "key is not null" filters. INNER filters both children, LEFT OUTER only the
     * right child, and RIGHT OUTER only the left child. In those cases the shared hint is marked SUCCESS and
     * flagged as applied. Otherwise the join is returned unchanged.
     *
     * <p>NOTE(review): this method is not called anywhere in this class; confirm whether transform should use
     * it, or remove it.
     */
    private static LogicalJoin<Plan, Plan> addNotNull(LogicalJoin<Plan, Plan> join, Expression skewConjunct,
            Set<Expression> skewValuesSet) {
        if (skewConjunct instanceof NullSafeEqual) {
            return join;
        }
        boolean containsNull = skewValuesSet.stream().anyMatch(value -> value instanceof NullLiteral);
        if (!containsNull) {
            return join;
        }

        LogicalFilter<Plan> leftFilter =
                new LogicalFilter<>(ImmutableSet.of(new Not(new IsNull(skewConjunct.child(0)))), join.left());
        LogicalFilter<Plan> rightFilter =
                new LogicalFilter<>(ImmutableSet.of(new Not(new IsNull(skewConjunct.child(1)))), join.right());
        DistributeHint hint = join.getDistributeHint();
        switch (join.getJoinType()) {
            case INNER_JOIN:
                hint.setStatus(HintStatus.SUCCESS);
                hint.setSkewInfo(hint.getSkewInfo().withSuccessInSaltJoin(true));
                return join.withDistributeHintChildren(hint, leftFilter, rightFilter);
            case LEFT_OUTER_JOIN:
                hint.setStatus(HintStatus.SUCCESS);
                hint.setSkewInfo(hint.getSkewInfo().withSuccessInSaltJoin(true));
                return join.withDistributeHintChildren(hint, join.left(), rightFilter);
            case RIGHT_OUTER_JOIN:
                hint.setStatus(HintStatus.SUCCESS);
                hint.setSkewInfo(hint.getSkewInfo().withSuccessInSaltJoin(true));
                return join.withDistributeHintChildren(hint, leftFilter, join.right());
            default:
                return join;
        }
    }
}