OrExpansion.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext;
import org.apache.doris.nereids.rules.rewrite.OrExpansion.OrExpandsionContext;
import org.apache.doris.nereids.trees.copier.DeepCopierContext;
import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.JoinUtils;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.stream.Collectors;
import javax.annotation.Nullable;

/**
 * Or-expansion: rewrites a nested loop join whose condition is a disjunction of
 * hash-able predicates into a UNION ALL of hash joins. See
 * https://blogs.oracle.com/optimizer/post/optimizer-transformations-or-expansion
 * <pre>
 * NLJ (cond1 or cond2)              UnionAll
 *                      =>         /        \
 *                           HJ(cond1) HJ(cond2 and !cond1)
 * </pre>
 * Both join children are deep-copied into CTE producers; every generated join reads them
 * through fresh CTE consumers. The producers are collected in the {@link OrExpandsionContext}
 * and anchored above the rewritten plan.
 */
public class OrExpansion extends DefaultPlanRewriter<OrExpandsionContext> implements CustomRewriter {
    public static final OrExpansion INSTANCE = new OrExpansion();
    /** Join types this rule expands; joins of any other type are returned unchanged. */
    public static final ImmutableSet<JoinType> supportJoinType = ImmutableSet.of(
            JoinType.INNER_JOIN,
            JoinType.LEFT_ANTI_JOIN,
            JoinType.LEFT_OUTER_JOIN,
            JoinType.FULL_OUTER_JOIN);

    /**
     * Rewrites the whole plan and then puts one CTE anchor per collected producer on top of it.
     *
     * @param plan root of the plan to rewrite
     * @param jobContext supplies the cascades context and, through it, the statement context
     * @return the rewritten plan wrapped in the anchors of all producers registered in the root context
     */
    @Override
    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
        CascadesContext cascadesContext = jobContext.getCascadesContext();
        OrExpandsionContext ctx = new OrExpandsionContext(
                cascadesContext.getStatementContext(), cascadesContext);
        Plan rewritten = plan.accept(this, ctx);
        return wrapWithAnchors(rewritten, ctx.cteProducerList);
    }

    /**
     * Rewrites both children of an existing CTE anchor.
     *
     * <p>Producers created while rewriting the anchor's second child are collected in a separate
     * context and anchored directly above that child, i.e. below {@code anchor}. They may consume
     * the CTE this anchor defines, so they must stay inside it:
     * <pre>
     * anchor                                    anchor
     * +-producer1                               +-producer1
     * +-agg(consumer1) join agg(consumer1)  =>  +-anchor
     *                                           +--producer2(agg2(consumer1))
     *                                           +--producer3(agg3(consumer1))
     *                                           +-consumer2 join consumer3
     * </pre>
     * Producers created while rewriting the first child are added to the caller's {@code ctx}.
     */
    @Override
    public Plan visitLogicalCTEAnchor(
            LogicalCTEAnchor<? extends Plan, ? extends Plan> anchor, OrExpandsionContext ctx) {
        Plan producerSide = anchor.child(0).accept(this, ctx);
        OrExpandsionContext consumerContext =
                new OrExpandsionContext(ctx.statementContext, ctx.cascadesContext);
        Plan consumerSide = anchor.child(1).accept(this, consumerContext);
        consumerSide = wrapWithAnchors(consumerSide, consumerContext.cteProducerList);
        return anchor.withChildren(ImmutableList.of(producerSide, consumerSide));
    }

    /**
     * Expands a supported nested loop join with an OR condition into a UNION ALL of hash joins.
     *
     * <p>The join is returned unchanged (after its children have been visited) when it is a mark
     * join, when {@link JoinUtils#shouldNestedLoopJoin} is false, when its type is not in
     * {@link #supportJoinType}, or when no conjunct splits into at least two hash-able disjuncts.
     *
     * @param join the join to rewrite
     * @param ctx receives the two CTE producers created for the join's children
     * @return a {@link LogicalUnion} over the expanded joins, or the visited join
     */
    @Override
    public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, OrExpandsionContext ctx) {
        join = (LogicalJoin<? extends Plan, ? extends Plan>) this.visit(join, ctx);
        if (join.isMarkJoin() || !JoinUtils.shouldNestedLoopJoin(join)
                || !supportJoinType.contains(join.getJoinType())) {
            return join;
        }
        Preconditions.checkArgument(join.getHashJoinConjuncts().isEmpty(),
                "Only Expansion nest loop join without hashCond");

        // 1. Try to split or conditions
        Pair<List<Expression>, List<Expression>> hashOtherConditions = splitOrCondition(join);
        if (hashOtherConditions == null || hashOtherConditions.first.size() <= 1) {
            return join;
        }

        // 2. Construct CTE with the children. Keep this statement order: copies and CTE ids are
        //    generated left first, then right.
        LogicalPlan leftClone = LogicalPlanDeepCopier.INSTANCE
                .deepCopy((LogicalPlan) join.left(), new DeepCopierContext());
        LogicalCTEProducer<? extends Plan> leftProducer = new LogicalCTEProducer<>(
                ctx.statementContext.getNextCTEId(), leftClone);
        LogicalPlan rightClone = LogicalPlanDeepCopier.INSTANCE
                .deepCopy((LogicalPlan) join.right(), new DeepCopierContext());
        LogicalCTEProducer<? extends Plan> rightProducer = new LogicalCTEProducer<>(
                ctx.statementContext.getNextCTEId(), rightClone);
        Map<Slot, Slot> leftCloneToLeft = mapOutputSlots(leftClone, join.left());
        Map<Slot, Slot> rightCloneToRight = mapOutputSlots(rightClone, join.right());

        // 3. Expand join to hash join with CTE
        List<Plan> joins = new ArrayList<>();
        JoinType joinType = join.getJoinType();
        if (joinType.isInnerJoin()) {
            joins.addAll(expandInnerJoin(ctx.cascadesContext, hashOtherConditions,
                    join, leftProducer, rightProducer, leftCloneToLeft, rightCloneToRight));
        } else if (joinType.isOuterJoin()) {
            // left outer join = inner join union left anti join
            joins.addAll(expandInnerJoin(ctx.cascadesContext, hashOtherConditions,
                    join, leftProducer, rightProducer, leftCloneToLeft, rightCloneToRight));
            joins.add(expandLeftAntiJoin(ctx.cascadesContext,
                    hashOtherConditions, join, leftProducer, rightProducer, leftCloneToLeft, rightCloneToRight));
            if (joinType.equals(JoinType.FULL_OUTER_JOIN)) {
                // full outer join = inner join union left anti join union right anti join;
                // the right anti join is built as a left anti join with both sides swapped
                joins.add(expandLeftAntiJoin(ctx.cascadesContext, hashOtherConditions,
                        join, rightProducer, leftProducer, rightCloneToRight, leftCloneToLeft));
            }
        } else if (joinType.equals(JoinType.LEFT_ANTI_JOIN)) {
            joins.add(expandLeftAntiJoin(ctx.cascadesContext, hashOtherConditions,
                    join, leftProducer, rightProducer, leftCloneToLeft, rightCloneToRight));
        } else {
            throw new RuntimeException("or-expansion is not supported for " + join);
        }

        // 4. union all joins and put producers to context
        List<List<SlotReference>> childrenOutputs = joins.stream()
                .map(j -> j.getOutput().stream()
                        .map(SlotReference.class::cast)
                        .collect(ImmutableList.toImmutableList()))
                .collect(ImmutableList.toImmutableList());
        LogicalUnion union = new LogicalUnion(Qualifier.ALL, new ArrayList<>(join.getOutput()),
                childrenOutputs, ImmutableList.of(), false, joins);
        ctx.cteProducerList.add(leftProducer);
        ctx.cteProducerList.add(rightProducer);
        return union;
    }

    /**
     * Wraps {@code plan} in one {@link LogicalCTEAnchor} per producer. The list is walked from
     * last to first, so the first producer in the list ends up as the outermost anchor.
     */
    private static Plan wrapWithAnchors(Plan plan, List<LogicalCTEProducer<? extends Plan>> producers) {
        Plan result = plan;
        for (int i = producers.size() - 1; i >= 0; i--) {
            LogicalCTEProducer<? extends Plan> producer = producers.get(i);
            result = new LogicalCTEAnchor<>(producer.getCteId(), producer, result);
        }
        return result;
    }

    /** Maps the i-th output slot of {@code clone} to the i-th output slot of {@code origin}. */
    private static Map<Slot, Slot> mapOutputSlots(Plan clone, Plan origin) {
        List<Slot> cloneOutput = clone.getOutput();
        List<Slot> originOutput = origin.getOutput();
        Map<Slot, Slot> cloneToOrigin = new HashMap<>();
        for (int i = 0; i < cloneOutput.size(); i++) {
            cloneToOrigin.put(cloneOutput.get(i), originOutput.get(i));
        }
        return cloneToOrigin;
    }

    /**
     * Finds the first "other" join conjunct whose disjuncts can all be used as hash conditions.
     *
     * @return a pair of (disjuncts of that conjunct, remaining other conjuncts without it),
     *         or {@code null} if no conjunct qualifies
     */
    private @Nullable Pair<List<Expression>, List<Expression>> splitOrCondition(
            LogicalJoin<? extends Plan, ? extends Plan> join) {
        List<Expression> otherConditions = new ArrayList<>(join.getOtherJoinConjuncts());
        // Index loop: the chosen conjunct is removed by position, so the list is never
        // modified through a live for-each iterator.
        for (int i = 0; i < otherConditions.size(); i++) {
            List<Expression> disjunctions = ExpressionUtils.extractDisjunction(otherConditions.get(i));
            Pair<List<Expression>, List<Expression>> hashOtherCond = JoinUtils.extractExpressionForHashTable(
                    join.left().getOutput(), join.right().getOutput(), disjunctions);
            if (hashOtherCond.second.isEmpty()) {
                otherConditions.remove(i);
                return Pair.of(disjunctions, otherConditions);
            }
        }
        return null;
    }

    /**
     * Builds the map used to rewrite the original join's expressions onto the CTE consumers.
     * For each consumer, its producer output slot is looked up in the clone-to-original map,
     * giving an entry (original slot -> consumer slot).
     */
    private Map<Slot, Slot> constructReplaceMap(LogicalCTEConsumer leftConsumer, Map<Slot, Slot> leftCloneToLeft,
            LogicalCTEConsumer rightConsumer, Map<Slot, Slot> rightCloneToRight) {
        Map<Slot, Slot> replaced = new HashMap<>();
        for (Entry<Slot, Slot> entry : leftConsumer.getConsumerToProducerOutputMap().entrySet()) {
            replaced.put(leftCloneToLeft.get(entry.getValue()), entry.getKey());
        }
        for (Entry<Slot, Slot> entry : rightConsumer.getConsumerToProducerOutputMap().entrySet()) {
            replaced.put(rightCloneToRight.get(entry.getValue()), entry.getKey());
        }
        return replaced;
    }

    /** Creates (but does not register) a new consumer of {@code producer} with a fresh relation id. */
    private static LogicalCTEConsumer newConsumer(CascadesContext ctx,
            LogicalCTEProducer<? extends Plan> producer) {
        return new LogicalCTEConsumer(ctx.getStatementContext().getNextRelationId(),
                producer.getCteId(), "", producer);
    }

    /** Rewrites {@code expr} bottom-up, swapping every node that is a key of {@code replaceMap}. */
    private static Expression replaceSlots(Expression expr, Map<Slot, Slot> replaceMap) {
        return expr.rewriteUp(e -> replaceMap.containsKey(e) ? replaceMap.get(e) : e);
    }

    /** Applies {@link #replaceSlots} to every expression and returns the results as a new list. */
    private static List<Expression> replaceSlotsInAll(List<Expression> exprs, Map<Slot, Slot> replaceMap) {
        return exprs.stream()
                .map(e -> replaceSlots(e, replaceMap))
                .collect(Collectors.toList());
    }

    /** Returns true if any direct child of {@code cond} is not a {@link Slot}. */
    private static boolean hasNonSlotOperand(Expression cond) {
        return cond.children().stream().anyMatch(e -> !(e instanceof Slot));
    }

    // expand Anti Join: the first disjunct forms the innermost anti join, every further
    // disjunct adds one more anti join on top with its own consumer of the right CTE.
    // Left Anti join cond1 or cond2, other           Left Anti join cond2 and other
    // /                      \                         /                        \
    //left                   right          ===>  Anti join cond1 and other    CTERight2
    //                                                  /      \
    //                                               CTELeft CTERight1
    // The result is projected to the origin join's output: slots not produced by the anti join
    // chain are replaced by NULL literals aliased with the slot's name.
    private Plan expandLeftAntiJoin(CascadesContext ctx,
            Pair<List<Expression>, List<Expression>> hashOtherConditions,
            LogicalJoin<? extends Plan, ? extends Plan> originJoin,
            LogicalCTEProducer<? extends Plan> leftProducer,
            LogicalCTEProducer<? extends Plan> rightProducer,
            Map<Slot, Slot> leftCloneToLeft, Map<Slot, Slot> rightCloneToRight) {
        LogicalCTEConsumer left = newConsumer(ctx, leftProducer);
        LogicalCTEConsumer right = newConsumer(ctx, rightProducer);
        ctx.putCTEIdToConsumer(left);
        ctx.putCTEIdToConsumer(right);

        Map<Slot, Slot> replaced = constructReplaceMap(left, leftCloneToLeft, right, rightCloneToRight);
        List<Expression> disjunctions = hashOtherConditions.first;
        List<Expression> otherConditions = hashOtherConditions.second;

        // innermost anti join on the first disjunct
        List<Expression> newOtherConditions = replaceSlotsInAll(otherConditions, replaced);
        Expression hashCond = replaceSlots(disjunctions.get(0), replaced);
        Plan newPlan = new LogicalJoin<>(JoinType.LEFT_ANTI_JOIN, Lists.newArrayList(hashCond),
                newOtherConditions, originJoin.getDistributeHint(),
                originJoin.getMarkJoinSlotReference(), left, right, JoinReorderContext.EMPTY);
        if (hasNonSlotOperand(hashCond)) {
            // push the non-slot operands below the join, then project back to the join's output
            Plan normalizedPlan = PushDownExpressionsInHashCondition.pushDownHashExpression(
                    (LogicalJoin<? extends Plan, ? extends Plan>) newPlan);
            newPlan = new LogicalProject<>(new ArrayList<>(newPlan.getOutput()), normalizedPlan);
        }

        // one more anti join per remaining disjunct, stacked on the previous result
        for (int i = 1; i < disjunctions.size(); i++) {
            LogicalCTEConsumer newRight = newConsumer(ctx, rightProducer);
            ctx.putCTEIdToConsumer(newRight);
            Map<Slot, Slot> newReplaced = constructReplaceMap(left, leftCloneToLeft, newRight, rightCloneToRight);
            newOtherConditions = replaceSlotsInAll(otherConditions, newReplaced);
            hashCond = replaceSlots(disjunctions.get(i), newReplaced);
            newPlan = new LogicalJoin<>(JoinType.LEFT_ANTI_JOIN, Lists.newArrayList(hashCond),
                    newOtherConditions, originJoin.getDistributeHint(),
                    originJoin.getMarkJoinSlotReference(), newPlan, newRight, JoinReorderContext.EMPTY);
            if (hasNonSlotOperand(hashCond)) {
                newPlan = PushDownExpressionsInHashCondition.pushDownHashExpression(
                        (LogicalJoin<? extends Plan, ? extends Plan>) newPlan);
            }
        }

        // project to the origin join's output, using the replace map of the first (left, right) pair
        Plan finalNewPlan = newPlan;
        List<NamedExpression> projects = originJoin.getOutput().stream()
                .map(replaced::get)
                .map(s -> finalNewPlan.getOutputSet().contains(s) ? s :
                        new Alias(new NullLiteral(s.getDataType()), s.getName()))
                .collect(Collectors.toList());
        return new LogicalProject<>(projects, newPlan);
    }

    // expand Inner Join:
    // Inner join cond1 or cond2                   UnionAll
    // /                     \                /               \
    //left                   right   ===>  Inner join cond1   Inner join cond2 and !cond1
    //                                      /      \          /       \
    //                                  CTELeft1 CTERight1  CTELeft2  CTERight2
    // where "!cond" stands for (NOT cond OR cond IS NULL). Returns one join per disjunct.
    private List<Plan> expandInnerJoin(CascadesContext ctx, Pair<List<Expression>,
            List<Expression>> hashOtherConditions,
            LogicalJoin<? extends Plan, ? extends Plan> join, LogicalCTEProducer<? extends Plan> leftProducer,
            LogicalCTEProducer<? extends Plan> rightProducer,
            Map<Slot, Slot> leftCloneToLeft, Map<Slot, Slot> rightCloneToRight) {
        List<Expression> disjunctions = hashOtherConditions.first;
        List<Expression> otherConditions = hashOtherConditions.second;
        // For null values, equalTo and not equalTo both return false
        // To avoid it, we always return true when there is null
        List<Expression> notExprs = disjunctions.stream()
                .map(e -> ExpressionUtils.or(new Not(e), new IsNull(e)))
                .collect(ImmutableList.toImmutableList());
        List<Plan> joins = Lists.newArrayList();

        for (int hashCondIdx = 0; hashCondIdx < disjunctions.size(); hashCondIdx++) {
            // extract hash conditions and other condition
            Pair<List<Expression>, List<Expression>> pair = extractHashAndOtherConditions(hashCondIdx, disjunctions,
                    notExprs);
            pair.second.addAll(otherConditions);

            LogicalCTEConsumer left = newConsumer(ctx, leftProducer);
            LogicalCTEConsumer right = newConsumer(ctx, rightProducer);
            ctx.putCTEIdToConsumer(left);
            ctx.putCTEIdToConsumer(right);

            // rewrite conjuncts to replace the old slots with CTE slots
            Map<Slot, Slot> replaced = constructReplaceMap(left, leftCloneToLeft, right, rightCloneToRight);
            List<Expression> hashCond = replaceSlotsInAll(pair.first, replaced);
            List<Expression> otherCond = replaceSlotsInAll(pair.second, replaced);

            LogicalJoin<? extends Plan, ? extends Plan> newJoin = new LogicalJoin<>(
                    JoinType.INNER_JOIN, hashCond, otherCond, join.getDistributeHint(),
                    join.getMarkJoinSlotReference(), left, right, null);
            if (newJoin.getHashJoinConjuncts().stream().anyMatch(OrExpansion::hasNonSlotOperand)) {
                // push the non-slot operands below the join, then project back to the join's output
                Plan plan = PushDownExpressionsInHashCondition.pushDownHashExpression(newJoin);
                plan = new LogicalProject<>(new ArrayList<>(newJoin.getOutput()), plan);
                joins.add(plan);
            } else {
                joins.add(newJoin);
            }
        }
        return joins;
    }

    // join(a or b or c) = join(a) union join(b) union join(c)
    //                   = join(a) union all (join b and !a) union all join(c and !b and !a)
    // return hashConditions (the disjunct at hashCondIdx) and otherConditions (the negations of
    // all earlier disjuncts); both lists are mutable
    private Pair<List<Expression>, List<Expression>> extractHashAndOtherConditions(int hashCondIdx,
            List<Expression> equal, List<Expression> not) {
        List<Expression> others = new ArrayList<>(not.subList(0, hashCondIdx));
        return Pair.of(Lists.newArrayList(equal.get(hashCondIdx)), others);
    }

    /**
     * Rewrite context: the statement and cascades contexts used to allocate ids and register
     * consumers, plus the CTE producers created so far that still need an anchor.
     */
    static class OrExpandsionContext {
        final List<LogicalCTEProducer<? extends Plan>> cteProducerList;
        final StatementContext statementContext;
        final CascadesContext cascadesContext;

        public OrExpandsionContext(StatementContext statementContext, CascadesContext cascadesContext) {
            this.statementContext = statementContext;
            this.cteProducerList = new ArrayList<>();
            this.cascadesContext = cascadesContext;
        }
    }
}