CollectJoinConstraint.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.analysis;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.hint.Hint;
import org.apache.doris.nereids.hint.JoinConstraint;
import org.apache.doris.nereids.hint.LeadingHint;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Set;

/**
 * CollectJoinConstraint
 *
 * <p>Rewrite rules that gather the information a {@code Leading} hint needs to reorder joins. Both rules are
 * no-ops unless {@code cascadesContext.isLeadingJoin()} is true. In that case the {@link LeadingHint} stored
 * under the key "Leading" in the hint map is updated in place:
 * <ul>
 *   <li>for every {@link LogicalJoin}: the join's table bitmap is set, every hash / other join conjunct is
 *       registered as a filter together with the tables it references, and either a {@link JoinConstraint}
 *       or the hint's inner-join bitmap is updated;</li>
 *   <li>for every {@link LogicalProject} directly over a {@link LogicalOlapScan}: the project is recorded
 *       under the scan's relation id.</li>
 * </ul>
 * Every rule returns the matched plan unchanged.
 */
public class CollectJoinConstraint implements RewriteRuleFactory {

    @Override
    public List<Rule> buildRules() {
        return ImmutableList.of(
            logicalJoin().thenApply(ctx -> {
                if (!ctx.cascadesContext.isLeadingJoin()) {
                    return ctx.root;
                }
                LeadingHint leading = (LeadingHint) ctx.cascadesContext
                            .getHintMap().get("Leading");
                LogicalJoin<?, ?> join = ctx.root;
                JoinType joinType = join.getJoinType();
                if (joinType.isNullAwareLeftAntiJoin()) {
                    // mark the hint as unused; collection below still continues for this join
                    leading.setStatus(Hint.HintStatus.UNUSED);
                    leading.setErrorMessage("condition does not matched joinType");
                }
                long leftHand = LongBitmap.computeTableBitmap(join.left().getInputRelations());
                long rightHand = LongBitmap.computeTableBitmap(join.right().getInputRelations());
                join.setBitmap(LongBitmap.or(leftHand, rightHand));

                // Hash conjuncts are registered first, then the other conjuncts; both get identical treatment.
                List<Expression> conjuncts = ImmutableList.<Expression>builder()
                        .addAll(join.getHashJoinConjuncts())
                        .addAll(join.getOtherJoinConjuncts())
                        .build();
                long totalFilterBitMap = 0L;
                long nonNullableSlotBitMap = 0L;
                for (Expression conjunct : conjuncts) {
                    Set<Slot> inputSlots = conjunct.getInputSlots();
                    nonNullableSlotBitMap = LongBitmap.or(nonNullableSlotBitMap,
                            calSlotsTableBitMap(leading, inputSlots, true));
                    long filterBitMap = calSlotsTableBitMap(leading, inputSlots, false);
                    totalFilterBitMap = LongBitmap.or(totalFilterBitMap, filterBitMap);
                    if (joinType.isLeftJoin()) {
                        // for a left join the registered filter also requires every table of the right side
                        filterBitMap = LongBitmap.or(filterBitMap, rightHand);
                    }
                    leading.getFilters().add(Pair.of(filterBitMap, conjunct));
                    leading.putConditionJoinType(conjunct, joinType);
                }
                collectJoinConstraintList(leading, leftHand, rightHand, join,
                        totalFilterBitMap, nonNullableSlotBitMap);

                return ctx.root;
            }).toRule(RuleType.COLLECT_JOIN_CONSTRAINT),

            logicalProject(logicalOlapScan()).thenApply(
                ctx -> {
                    if (!ctx.cascadesContext.isLeadingJoin()) {
                        return ctx.root;
                    }
                    LeadingHint leading = (LeadingHint) ctx.cascadesContext
                            .getHintMap().get("Leading");
                    LogicalProject<LogicalOlapScan> project = ctx.root;
                    LogicalOlapScan scan = project.child();
                    // remember which project sits on top of this scan, keyed by the scan's relation id
                    leading.getRelationIdToScanMap().put(scan.getRelationId(), project);
                    return ctx.root;
                }
            ).toRule(RuleType.COLLECT_JOIN_CONSTRAINT)
        );
    }

    /**
     * Records how {@code join} constrains the join order chosen by the Leading hint.
     *
     * <ul>
     *   <li>inner / cross join: its tables are added to the hint's inner-join bitmap; no constraint is created;</li>
     *   <li>full outer join: a non-strict constraint whose minimal sides equal its syntactic sides;</li>
     *   <li>any other join: the minimal left/right sides start from the tables referenced by the join
     *       conditions (the right side also includes inner-joined tables under this join) and are widened
     *       by previously collected constraints according to the rules in the loop below. An empty minimal
     *       side falls back to the full syntactic side.</li>
     * </ul>
     * NOTE(review): the structure appears to follow PostgreSQL's make_outerjoininfo; confirm before relying on it.
     *
     * @param leading hint that receives the inner-join bitmap update or the new constraint
     * @param leftHand bitmap of all tables under the join's left child
     * @param rightHand bitmap of all tables under the join's right child
     * @param join the join being processed
     * @param filterTableBitMap tables referenced by the join's conjuncts
     * @param nonNullableSlotBitMap tables owning a non-nullable slot referenced by the join's conjuncts
     */
    private void collectJoinConstraintList(LeadingHint leading, long leftHand, long rightHand,
                                            LogicalJoin<?, ?> join, long filterTableBitMap,
                                            long nonNullableSlotBitMap) {
        JoinType joinType = join.getJoinType();
        long totalTables = LongBitmap.or(leftHand, rightHand);
        if (joinType.isInnerOrCrossJoin()) {
            leading.setInnerJoinBitmap(LongBitmap.or(leading.getInnerJoinBitmap(), totalTables));
            return;
        }
        if (joinType.isFullOuterJoin()) {
            JoinConstraint newJoinConstraint = new JoinConstraint(leftHand, rightHand, leftHand, rightHand,
                    JoinType.FULL_OUTER_JOIN, false);
            leading.getJoinConstraintList().add(newJoinConstraint);
            return;
        }
        // strict on the left side when a non-nullable slot of the conditions comes from the left tables
        boolean isStrict = LongBitmap.isOverlap(nonNullableSlotBitMap, leftHand);
        long minLeftHand = LongBitmap.newBitmapIntersect(filterTableBitMap, leftHand);
        long innerJoinTableBitmap = LongBitmap.and(totalTables, leading.getInnerJoinBitmap());
        long filterAndInnerBelow = LongBitmap.newBitmapUnion(filterTableBitMap, innerJoinTableBitmap);
        long minRightHand = LongBitmap.newBitmapIntersect(filterAndInnerBelow, rightHand);
        boolean isSemiOrAnti = joinType.isSemiOrAntiJoin();
        for (JoinConstraint other : leading.getJoinConstraintList()) {
            long otherTables = LongBitmap.or(other.getLeftHand(), other.getRightHand());
            if (other.getJoinType() == JoinType.FULL_OUTER_JOIN) {
                // a full outer join touching either side pins all of its tables to that side
                if (LongBitmap.isOverlap(leftHand, other.getLeftHand())
                        || LongBitmap.isOverlap(leftHand, other.getRightHand())) {
                    minLeftHand = LongBitmap.or(minLeftHand, otherTables);
                }
                if (LongBitmap.isOverlap(rightHand, other.getLeftHand())
                        || LongBitmap.isOverlap(rightHand, other.getRightHand())) {
                    minRightHand = LongBitmap.or(minRightHand, otherTables);
                }
                /* Needn't do anything else with the full join */
                continue;
            }

            // lower join whose right side lies in our left side
            if (LongBitmap.isOverlap(leftHand, other.getRightHand())
                    && LongBitmap.isOverlap(filterTableBitMap, other.getRightHand())
                    && (isSemiOrAnti || !LongBitmap.isOverlap(nonNullableSlotBitMap, other.getMinRightHand()))) {
                minLeftHand = LongBitmap.or(minLeftHand, otherTables);
            }

            // lower join whose right side lies in our right side
            if (LongBitmap.isOverlap(rightHand, other.getRightHand())
                    && (LongBitmap.isOverlap(filterTableBitMap, other.getRightHand())
                        || !LongBitmap.isOverlap(filterTableBitMap, other.getMinLeftHand())
                        || isSemiOrAnti
                        || other.getJoinType().isSemiOrAntiJoin()
                        || !other.isLhsStrict())) {
                minRightHand = LongBitmap.or(minRightHand, otherTables);
            }
        }
        if (minLeftHand == 0L) {
            minLeftHand = leftHand;
        }
        if (minRightHand == 0L) {
            minRightHand = rightHand;
        }

        JoinConstraint newJoinConstraint = new JoinConstraint(minLeftHand, minRightHand, leftHand, rightHand,
                joinType, isStrict);
        leading.getJoinConstraintList().add(newJoinConstraint);
    }

    /**
     * Computes the bitmap of relations referenced by {@code slots}, resolving each slot's table name
     * (the last element of its qualifier) through the hint.
     *
     * <p>Slots without a qualifier that are not table columns carry no table information and are skipped.
     * If a slot cannot be resolved to a relation, the hint's status and error message are set and the bits
     * collected so far are returned.
     *
     * @param leading hint used to resolve table names
     * @param slots input slots of one join conjunct; must not be empty
     * @param getNotNullable when true, nullable slots are ignored
     * @return bitmap of the resolved relations
     * @throws IllegalArgumentException if {@code slots} is empty
     */
    private long calSlotsTableBitMap(LeadingHint leading, Set<Slot> slots, boolean getNotNullable) {
        Preconditions.checkArgument(!slots.isEmpty(), "join conjunct must reference at least one slot");
        long bitmap = LongBitmap.newBitmap();
        for (Slot slot : slots) {
            if (getNotNullable && slot.nullable()) {
                continue;
            }
            List<String> qualifier = slot.getQualifier();
            if (qualifier == null || qualifier.isEmpty()) {
                if (!slot.isColumnFromTable()) {
                    // we can not get info from column not from table
                    continue;
                }
                // A table column without a qualifier cannot be mapped to a relation; give up on the hint
                // instead of failing with NullPointerException / IndexOutOfBoundsException.
                leading.setStatus(Hint.HintStatus.UNUSED);
                leading.setErrorMessage("can not find table of slot: " + slot);
                return bitmap;
            }
            String tableName = qualifier.get(qualifier.size() - 1);
            RelationId id = leading.findRelationIdAndTableName(tableName);
            if (id == null) {
                leading.setStatus(Hint.HintStatus.SYNTAX_ERROR);
                leading.setErrorMessage("can not find table: " + tableName);
                return bitmap;
            }
            bitmap = LongBitmap.set(bitmap, id.asInt());
        }
        return bitmap;
    }
}