ConflictRulesMaker.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.jobs.joinorder.hypergraphv2;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.jobs.joinorder.hypergraphv2.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraphv2.edge.Edge;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * This is a conflict rule maker to
 */
public class ConflictRulesMaker {
    private static ValEntry[][] assocTable = {
            //             inner-B       semi-B        anti-B        left-B         full-B
            /* inner-A */ {ValEntry.YES, ValEntry.YES, ValEntry.YES, ValEntry.YES, ValEntry.NO},
            /* semi-A  */ {ValEntry.NO, ValEntry.NO, ValEntry.NO, ValEntry.NO, ValEntry.NO},
            /* anti-A  */ {ValEntry.NO, ValEntry.NO, ValEntry.NO, ValEntry.NO, ValEntry.NO},
            /* left-A  */ {ValEntry.NO, ValEntry.NO, ValEntry.NO, ValEntry.BRejectRA, ValEntry.NO},
            /* full-A  */ {ValEntry.NO, ValEntry.NO, ValEntry.NO, ValEntry.BRejectRA, ValEntry.ABRejectRA},
    };

    private static ValEntry[][] leftAsscomTable = {
            //             inner-B       semi-B        anti-B        left-B         full-B
            /* inner-A */ {ValEntry.YES, ValEntry.YES, ValEntry.YES, ValEntry.YES, ValEntry.NO},
            /* semi-A  */ {ValEntry.YES, ValEntry.YES, ValEntry.YES, ValEntry.YES, ValEntry.NO},
            /* anti-A  */ {ValEntry.YES, ValEntry.YES, ValEntry.YES, ValEntry.YES, ValEntry.NO},
            /* left-A  */ {ValEntry.YES, ValEntry.YES, ValEntry.YES, ValEntry.YES, ValEntry.ARejectLA},
            /* full-A  */ {ValEntry.NO, ValEntry.NO, ValEntry.NO, ValEntry.BRejectRB, ValEntry.ABRejectLA},
    };

    private static ValEntry[][] rightAsscomTable = {
            //             inner-B       semi-B        anti-B        left-B         full-B
            /* inner-A */ {ValEntry.YES, ValEntry.NO, ValEntry.NO, ValEntry.NO, ValEntry.NO},
            /* semi-A  */ {ValEntry.NO, ValEntry.NO, ValEntry.NO, ValEntry.NO, ValEntry.NO},
            /* anti-A  */ {ValEntry.NO, ValEntry.NO, ValEntry.NO, ValEntry.NO, ValEntry.NO},
            /* left-A  */ {ValEntry.NO, ValEntry.NO, ValEntry.NO, ValEntry.NO, ValEntry.NO},
            /* full-A  */ {ValEntry.NO, ValEntry.NO, ValEntry.NO, ValEntry.NO, ValEntry.ABRejectRB},
    };

    private ConflictRulesMaker() {
    }

    /**
     * Make edge's conflict rule by CD-C algorithm in
     * On the correct and complete enumeration of the core search
     */
    public static void makeConflictRules(Edge edgeB, List<Edge> joinEdges, ExpressionRewriteContext ctx) {
        // find all left and right subtree edges and ready for CD-C check
        BitSet leftSubTreeEdges = subTreeEdges(edgeB.getLeftChildEdges(), joinEdges);
        BitSet rightSubTreeEdges = subTreeEdges(edgeB.getRightChildEdges(), joinEdges);
        List<Pair<Long, Long>> conflictRules = new ArrayList<>();
        for (int i = leftSubTreeEdges.nextSetBit(0); i >= 0; i = leftSubTreeEdges.nextSetBit(i + 1)) {
            Edge childA = joinEdges.get(i);
            if (!isAssocLeftTree(childA, edgeB, ctx)) {
                generateAssocLeftTreeCR(childA, conflictRules);
            }
            if (!isLAssoc(childA, edgeB, ctx)) {
                generateLAssocCR(childA, conflictRules);
            }
        }

        for (int i = rightSubTreeEdges.nextSetBit(0); i >= 0; i = rightSubTreeEdges.nextSetBit(i + 1)) {
            Edge childA = joinEdges.get(i);
            if (!isAssocRightTree(edgeB, childA, ctx)) {
                generateAssocRightTreeCR(childA, conflictRules);
            }
            if (!isRAssoc(edgeB, childA, ctx)) {
                generateRAssocCR(childA, conflictRules);
            }
        }

        long tes = simplifyConflictRules(edgeB.getRequireNodes(), conflictRules);
        if (!LongBitmap.isOverlap(tes, edgeB.getLeftSubtreeNodes())) {
            tes = LongBitmap.or(tes, edgeB.getLeftSubtreeNodes());
            tes = simplifyConflictRules(tes, conflictRules);
        }
        if (!LongBitmap.isOverlap(tes, edgeB.getRightSubtreeNodes())) {
            tes = LongBitmap.or(tes, edgeB.getRightSubtreeNodes());
            tes = simplifyConflictRules(tes, conflictRules);
        }
        edgeB.setLeftExtendedNodes(LongBitmap.and(tes, edgeB.getLeftSubtreeNodes()));
        edgeB.setRightExtendedNodes(LongBitmap.and(tes, edgeB.getRightSubtreeNodes()));
        edgeB.setConflictRules(conflictRules);
    }

    private static long simplifyConflictRules(long tes, List<Pair<Long, Long>> conflictRules) {
        long oldTes;
        do {
            oldTes = tes;
            for (Pair<Long, Long> rule : conflictRules) {
                if (LongBitmap.isOverlap(rule.first, tes)) {
                    tes = LongBitmap.or(tes, rule.second);
                }
            }
            final long tempTes = tes;
            conflictRules.removeIf(rule -> LongBitmap.isSubset(rule.second, tempTes));
        } while (tes != oldTes && !conflictRules.isEmpty());
        return tes;
    }

    private static boolean isAssocLeftTree(Edge leftChildEdge, Edge currentEdge, ExpressionRewriteContext ctx) {
        int indexA = getIndexForJoinType(leftChildEdge.getJoinType());
        int indexB = getIndexForJoinType(currentEdge.getJoinType());
        if (indexA >= 0 && indexB >= 0) {
            return isValidToReorder(assocTable[indexA][indexB], leftChildEdge.getJoin(), currentEdge.getJoin(), ctx);
        } else {
            return false;
        }
    }

    private static boolean isAssocRightTree(Edge currentEdge, Edge rightChildEdge, ExpressionRewriteContext ctx) {
        int indexA = getIndexForJoinType(currentEdge.getJoinType());
        int indexB = getIndexForJoinType(rightChildEdge.getJoinType());
        if (indexA >= 0 && indexB >= 0) {
            return isValidToReorder(assocTable[indexA][indexB], currentEdge.getJoin(), rightChildEdge.getJoin(), ctx);
        } else {
            return false;
        }
    }

    private static boolean isLAssoc(Edge leftChildEdge, Edge currentEdge, ExpressionRewriteContext ctx) {
        int indexA = getIndexForJoinType(leftChildEdge.getJoinType());
        int indexB = getIndexForJoinType(currentEdge.getJoinType());
        if (indexA >= 0 && indexB >= 0) {
            return isValidToReorder(leftAsscomTable[indexA][indexB], leftChildEdge.getJoin(), currentEdge.getJoin(),
                    ctx);
        } else {
            return false;
        }
    }

    private static boolean isRAssoc(Edge currentEdge, Edge rightChildEdge, ExpressionRewriteContext ctx) {
        int indexA = getIndexForJoinType(currentEdge.getJoinType());
        int indexB = getIndexForJoinType(rightChildEdge.getJoinType());
        if (indexA >= 0 && indexB >= 0) {
            return isValidToReorder(rightAsscomTable[indexA][indexB], currentEdge.getJoin(), rightChildEdge.getJoin(),
                    ctx);
        } else {
            return false;
        }
    }

    private static BitSet subTreeEdge(Edge edge, List<Edge> joinEdges) {
        long subTreeNodes = edge.getSubTreeNodes();
        BitSet subEdges = new BitSet();
        joinEdges.stream()
                .filter(e -> LongBitmap.isSubset(e.getReferenceNodes(), subTreeNodes))
                .forEach(e -> subEdges.set(e.getIndex()));
        return subEdges;
    }

    private static BitSet subTreeEdges(BitSet edgeSet, List<Edge> joinEdges) {
        BitSet bitSet = new BitSet();
        edgeSet.stream()
                .mapToObj(i -> subTreeEdge(joinEdges.get(i), joinEdges))
                .forEach(bitSet::or);
        return bitSet;
    }

    private static int getIndexForJoinType(JoinType joinType) {
        switch (joinType) {
            case CROSS_JOIN:
            case INNER_JOIN:
                return 0;
            case LEFT_SEMI_JOIN:
                return 1;
            case LEFT_ANTI_JOIN:
            case NULL_AWARE_LEFT_ANTI_JOIN:
                return 2;
            case LEFT_OUTER_JOIN:
                return 3;
            case FULL_OUTER_JOIN:
                return 4;
            default:
                return -1;
        }
    }

    private static boolean isValidToReorder(ValEntry valEntry,
                                            LogicalJoin joinA,
                                            LogicalJoin joinB,
                                            ExpressionRewriteContext ctx) {
        switch (valEntry) {
            case YES:
                return true;
            case BRejectRA: {
                Set<Slot> outputBL = joinB.left().getOutputSet();
                for (Object expression : joinB.getExpressions()) {
                    if (isEvalToNullOrFalse(outputBL, (Expression) expression, ctx)) {
                        return true;
                    }
                }
                return false;
            }
            case ABRejectRA: {
                boolean aRejectRA = false;
                Set<Slot> outputAR = joinA.right().getOutputSet();
                for (Object expression : joinA.getExpressions()) {
                    if (isEvalToNullOrFalse(outputAR, (Expression) expression, ctx)) {
                        aRejectRA = true;
                        break;
                    }
                }
                if (aRejectRA) {
                    Set<Slot> outputBL = joinB.left().getOutputSet();
                    for (Object expression : joinB.getExpressions()) {
                        if (isEvalToNullOrFalse(outputBL, (Expression) expression, ctx)) {
                            return true;
                        }
                    }
                }
                return false;
            }
            case ARejectLA: {
                Set<Slot> outputAL = joinA.left().getOutputSet();
                for (Object expression : joinA.getExpressions()) {
                    if (isEvalToNullOrFalse(outputAL, (Expression) expression, ctx)) {
                        return true;
                    }
                }
                return false;
            }
            case BRejectRB: {
                Set<Slot> outputBR = joinB.right().getOutputSet();
                for (Object expression : joinB.getExpressions()) {
                    if (isEvalToNullOrFalse(outputBR, (Expression) expression, ctx)) {
                        return true;
                    }
                }
                return false;
            }
            case ABRejectLA: {
                boolean aRejectLA = false;
                Set<Slot> outputAL = joinA.left().getOutputSet();
                for (Object expression : joinA.getExpressions()) {
                    if (isEvalToNullOrFalse(outputAL, (Expression) expression, ctx)) {
                        aRejectLA = true;
                        break;
                    }
                }
                if (aRejectLA) {
                    Set<Slot> outputBL = joinB.left().getOutputSet();
                    for (Object expression : joinB.getExpressions()) {
                        if (isEvalToNullOrFalse(outputBL, (Expression) expression, ctx)) {
                            return true;
                        }
                    }
                }
                return false;
            }
            case ABRejectRB: {
                boolean aRejectRB = false;
                Set<Slot> outputAR = joinA.right().getOutputSet();
                for (Object expression : joinA.getExpressions()) {
                    if (isEvalToNullOrFalse(outputAR, (Expression) expression, ctx)) {
                        aRejectRB = true;
                        break;
                    }
                }
                if (aRejectRB) {
                    Set<Slot> outputBR = joinB.right().getOutputSet();
                    for (Object expression : joinB.getExpressions()) {
                        if (isEvalToNullOrFalse(outputBR, (Expression) expression, ctx)) {
                            return true;
                        }
                    }
                }
                return false;
            }
            case NO:
            default:
                return false;
        }
    }

    private static boolean isEvalToNullOrFalse(Set<Slot> slots, Expression expression, ExpressionRewriteContext ctx) {
        Map<Slot, NullLiteral> replaceMap = new HashMap<>();
        for (Slot slot : slots) {
            replaceMap.put(slot, new NullLiteral(slot.getDataType()));
        }
        Expression evalExpr = FoldConstantRule.evaluate(
                ExpressionUtils.replace(expression, replaceMap), ctx);
        return evalExpr.isNullLiteral() || BooleanLiteral.FALSE.equals(evalExpr);
    }

    private static void generateAssocLeftTreeCR(Edge leftChildEdge, List<Pair<Long, Long>> conflictRules) {
        long childReferencedNodes = leftChildEdge.getReferenceNodes();
        long childLeftSubtreeNodes = leftChildEdge.getLeftSubtreeNodes();
        long childRightSubtreeNodes = leftChildEdge.getRightSubtreeNodes();
        if (LongBitmap.isOverlap(childReferencedNodes, childLeftSubtreeNodes)) {
            conflictRules.add(Pair.of(childRightSubtreeNodes,
                    LongBitmap.newBitmapIntersect(childReferencedNodes, childLeftSubtreeNodes)));
        } else {
            conflictRules.add(Pair.of(childRightSubtreeNodes, childLeftSubtreeNodes));
        }
    }

    private static void generateAssocRightTreeCR(Edge rightChildEdge, List<Pair<Long, Long>> conflictRules) {
        long childReferencedNodes = rightChildEdge.getReferenceNodes();
        long childLeftSubtreeNodes = rightChildEdge.getLeftSubtreeNodes();
        long childRightSubtreeNodes = rightChildEdge.getRightSubtreeNodes();
        if (LongBitmap.isOverlap(childReferencedNodes, childRightSubtreeNodes)) {
            conflictRules.add(Pair.of(childLeftSubtreeNodes,
                    LongBitmap.newBitmapIntersect(childReferencedNodes, childRightSubtreeNodes)));
        } else {
            conflictRules.add(Pair.of(childLeftSubtreeNodes, childRightSubtreeNodes));
        }
    }

    private static void generateLAssocCR(Edge leftChildEdge, List<Pair<Long, Long>> conflictRules) {
        long childReferencedNodes = leftChildEdge.getReferenceNodes();
        long childLeftSubtreeNodes = leftChildEdge.getLeftSubtreeNodes();
        long childRightSubtreeNodes = leftChildEdge.getRightSubtreeNodes();
        if (LongBitmap.isOverlap(childReferencedNodes, childRightSubtreeNodes)) {
            conflictRules.add(Pair.of(childLeftSubtreeNodes,
                    LongBitmap.newBitmapIntersect(childReferencedNodes, childRightSubtreeNodes)));
        } else {
            conflictRules.add(Pair.of(childLeftSubtreeNodes, childRightSubtreeNodes));
        }
    }

    private static void generateRAssocCR(Edge rightChildEdge, List<Pair<Long, Long>> conflictRules) {
        long childReferencedNodes = rightChildEdge.getReferenceNodes();
        long childLeftSubtreeNodes = rightChildEdge.getLeftSubtreeNodes();
        long childRightSubtreeNodes = rightChildEdge.getRightSubtreeNodes();
        if (LongBitmap.isOverlap(childReferencedNodes, childLeftSubtreeNodes)) {
            conflictRules.add(Pair.of(childRightSubtreeNodes,
                    LongBitmap.newBitmapIntersect(childReferencedNodes, childLeftSubtreeNodes)));
        } else {
            conflictRules.add(Pair.of(childRightSubtreeNodes, childLeftSubtreeNodes));
        }
    }

    private enum ValEntry {
        YES,
        NO,
        BRejectRA,
        ABRejectRA,
        ARejectLA,
        BRejectRB,
        ABRejectLA,
        ABRejectRB
    }
}