JoinUtils.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.util;
import org.apache.doris.catalog.ColocateTableIndex;
import org.apache.doris.catalog.Env;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.properties.DataTrait;
import org.apache.doris.nereids.properties.DistributionSpec;
import org.apache.doris.nereids.properties.DistributionSpecHash;
import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
import org.apache.doris.nereids.properties.DistributionSpecReplicated;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.rewrite.ForeignKeyContext;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapContains;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Join;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Utils for join
*/
public class JoinUtils {
/**
* couldShuffle
*/
public static boolean couldShuffle(Join join) {
// Cross-join and Null-Aware-Left-Anti-Join only can be broadcast join.
// standalone mark join would consider null value from both build and probe side, so must use broadcast join.
// mark join with hash conjuncts can shuffle by hash conjuncts
// TODO actually standalone mark join can use shuffle, but need do nullaware shuffle to broadcast null value
// to all instances
return !(join.getJoinType().isCrossJoin() || join.getJoinType().isNullAwareLeftAntiJoin()
|| (!join.getMarkJoinConjuncts().isEmpty() && join.getHashJoinConjuncts().isEmpty()));
}
public static boolean couldBroadcast(Join join) {
return !(join.getJoinType().isRightJoin() || join.getJoinType().isFullOuterJoin());
}
/**
* check if the row count of the left child in the broadcast join is less than a threshold value.
*/
public static boolean checkBroadcastJoinStats(PhysicalHashJoin<? extends Plan, ? extends Plan> join) {
SessionVariable sessionVariable = ConnectContext.get().getSessionVariable();
double memLimit = sessionVariable.getMaxExecMemByte();
double rowsLimit = sessionVariable.getBroadcastRowCountLimit();
double brMemlimit = sessionVariable.getBroadcastHashtableMemLimitPercentage();
double datasize = join.getGroupExpression().get().child(1)
.getStatistics().computeSize(join.right().getOutput());
double rowCount = join.getGroupExpression().get().child(1).getStatistics().getRowCount();
return rowCount <= rowsLimit && datasize <= memLimit * brMemlimit;
}
/**
* for a given equation, judge if it can be used as hash join condition
*/
public static final class JoinSlotCoverageChecker {
Set<ExprId> leftExprIds;
Set<ExprId> rightExprIds;
public JoinSlotCoverageChecker(List<Slot> left, List<Slot> right) {
leftExprIds = left.stream().map(Slot::getExprId).collect(ImmutableSet.toImmutableSet());
rightExprIds = right.stream().map(Slot::getExprId).collect(ImmutableSet.toImmutableSet());
}
/**
* consider following cases:
* 1# A=1 => not for hash table
* 2# t1.a=t2.a + t2.b => hash table
* 3# t1.a=t1.a + t2.b => not for hash table
* 4# t1.a=t2.a or t1.b=t2.b not for hash table
* 5# t1.a > 1 not for hash table
*
* @param equal a conjunct in on clause condition
* @return true if the equal can be used as hash join condition
*/
public boolean isHashJoinCondition(EqualPredicate equal) {
Set<ExprId> equalLeftExprIds = equal.left().getInputSlotExprIds();
if (equalLeftExprIds.isEmpty()) {
return false;
}
Set<ExprId> equalRightExprIds = equal.right().getInputSlotExprIds();
if (equalRightExprIds.isEmpty()) {
return false;
}
return leftExprIds.containsAll(equalLeftExprIds) && rightExprIds.containsAll(equalRightExprIds)
|| leftExprIds.containsAll(equalRightExprIds) && rightExprIds.containsAll(equalLeftExprIds);
}
}
/**
* extract expression
*
* @param leftSlots left child output slots
* @param rightSlots right child output slots
* @param onConditions conditions to be split
* @return pair of hashCondition and otherCondition
*/
public static Pair<List<Expression>, List<Expression>> extractExpressionForHashTable(List<Slot> leftSlots,
List<Slot> rightSlots, List<Expression> onConditions) {
JoinSlotCoverageChecker checker = new JoinSlotCoverageChecker(leftSlots, rightSlots);
ImmutableList.Builder<Expression> hashConditions = ImmutableList.builderWithExpectedSize(onConditions.size());
ImmutableList.Builder<Expression> otherConditions = ImmutableList.builderWithExpectedSize(onConditions.size());
for (Expression expr : onConditions) {
if (expr instanceof EqualPredicate && checker.isHashJoinCondition((EqualPredicate) expr)) {
hashConditions.add(expr);
} else {
otherConditions.add(expr);
}
}
return Pair.of(
hashConditions.build(),
otherConditions.build()
);
}
/**
* This is used for bitmap runtime filter only.
* Extract bitmap_contains conjunct:
* like: bitmap_contains(a, b) and ..., Not(bitmap_contains(a, b)) and ...,
* where `a` and `b` are from right child and left child, respectively.
*
* @return condition for bitmap runtime filter: bitmap_contains
*/
public static List<Expression> extractBitmapRuntimeFilterConditions(List<Slot> leftSlots,
List<Slot> rightSlots, List<Expression> onConditions) {
List<Expression> result = Lists.newArrayList();
for (Expression expr : onConditions) {
BitmapContains bitmapContains = null;
if (expr instanceof Not) {
List<Expression> notChildren = ExpressionUtils.extractConjunction(expr.child(0));
if (notChildren.size() == 1 && notChildren.get(0) instanceof BitmapContains) {
bitmapContains = (BitmapContains) notChildren.get(0);
}
} else if (expr instanceof BitmapContains) {
bitmapContains = (BitmapContains) expr;
}
if (bitmapContains == null) {
continue;
}
//first child in right, second child in left
if (leftSlots.containsAll(bitmapContains.child(1).collect(Slot.class::isInstance))
&& rightSlots.containsAll(bitmapContains.child(0).collect(Slot.class::isInstance))) {
result.add(expr);
}
}
return result;
}
public static boolean shouldNestedLoopJoin(Join join) {
// currently, mark join conjuncts only has one conjunct, so we always get the first element here
return join.getHashJoinConjuncts().isEmpty() && (join.getMarkJoinConjuncts().isEmpty()
|| !(join.getMarkJoinConjuncts().get(0) instanceof EqualPredicate));
}
public static boolean shouldNestedLoopJoin(JoinType joinType, List<Expression> hashConjuncts) {
// this function is only called by hyper graph, which reject mark join
// so mark join is not processed here
return hashConjuncts.isEmpty();
}
/**
* The left and right child of origin predicates need to swap sometimes.
* Case A:
* select * from t1 join t2 on t2.id=t1.id
* The left plan node is t1 and the right plan node is t2.
* The left child of origin predicate is t2.id and the right child of origin predicate is t1.id.
* In this situation, the children of predicate need to be swap => t1.id=t2.id.
*/
public static EqualPredicate swapEqualToForChildrenOrder(EqualPredicate equalTo, Set<Slot> leftOutput) {
if (leftOutput.containsAll(equalTo.left().getInputSlots())) {
return equalTo;
} else {
return equalTo.commute();
}
}
/**
* return true if we should do bucket shuffle join when translate plan.
*/
public static boolean shouldBucketShuffleJoin(AbstractPhysicalJoin<PhysicalPlan, PhysicalPlan> join) {
if (isStorageBucketed(join.right().getPhysicalProperties())) {
return true;
} else if (SessionVariable.canUseNereidsDistributePlanner()
&& isStorageBucketed(join.left().getPhysicalProperties())) {
return true;
}
return false;
}
private static boolean isStorageBucketed(PhysicalProperties physicalProperties) {
DistributionSpec distributionSpec = physicalProperties.getDistributionSpec();
if (!(distributionSpec instanceof DistributionSpecHash)) {
return false;
}
DistributionSpecHash rightHash = (DistributionSpecHash) distributionSpec;
if (rightHash.getShuffleType() == ShuffleType.STORAGE_BUCKETED) {
return true;
}
return false;
}
/**
* return true if we should do broadcast join when translate plan.
*/
public static boolean shouldBroadcastJoin(AbstractPhysicalJoin<PhysicalPlan, PhysicalPlan> join) {
PhysicalPlan right = join.right();
if (right instanceof PhysicalDistribute) {
return ((PhysicalDistribute<?>) right).getDistributionSpec() instanceof DistributionSpecReplicated;
}
return false;
}
/**
* return true if we should do colocate join when translate plan.
*/
public static boolean shouldColocateJoin(AbstractPhysicalJoin<PhysicalPlan, PhysicalPlan> join) {
if (ConnectContext.get() == null
|| ConnectContext.get().getSessionVariable().isDisableColocatePlan()) {
return false;
}
DistributionSpec leftDistributionSpec = join.left().getPhysicalProperties().getDistributionSpec();
DistributionSpec rightDistributionSpec = join.right().getPhysicalProperties().getDistributionSpec();
if (!(leftDistributionSpec instanceof DistributionSpecHash)
|| !(rightDistributionSpec instanceof DistributionSpecHash)) {
return false;
}
return couldColocateJoin((DistributionSpecHash) leftDistributionSpec,
(DistributionSpecHash) rightDistributionSpec, join.getHashJoinConjuncts());
}
/**
* could do colocate join with left and right child distribution spec.
*/
public static boolean couldColocateJoin(DistributionSpecHash leftHashSpec, DistributionSpecHash rightHashSpec,
List<Expression> conjuncts) {
if (ConnectContext.get() == null
|| ConnectContext.get().getSessionVariable().isDisableColocatePlan()) {
return false;
}
if (leftHashSpec.getShuffleType() != ShuffleType.NATURAL
|| rightHashSpec.getShuffleType() != ShuffleType.NATURAL) {
return false;
}
final long leftTableId = leftHashSpec.getTableId();
final long rightTableId = rightHashSpec.getTableId();
final Set<Long> leftTablePartitions = leftHashSpec.getPartitionIds();
final Set<Long> rightTablePartitions = rightHashSpec.getPartitionIds();
// For UT or no partition is selected, getSelectedIndexId() == -1, see selectMaterializedView()
boolean hitSameIndex = (leftTableId == rightTableId)
&& (leftHashSpec.getSelectedIndexId() != -1 && rightHashSpec.getSelectedIndexId() != -1)
&& (leftHashSpec.getSelectedIndexId() == rightHashSpec.getSelectedIndexId());
boolean noNeedCheckColocateGroup = hitSameIndex && (leftTablePartitions.equals(rightTablePartitions))
&& (leftTablePartitions.size() <= 1);
ColocateTableIndex colocateIndex = Env.getCurrentColocateIndex();
if (noNeedCheckColocateGroup) {
return true;
}
if (!colocateIndex.isSameGroup(leftTableId, rightTableId)
|| colocateIndex.isGroupUnstable(colocateIndex.getGroup(leftTableId))) {
return false;
}
Set<Integer> equalIndices = new HashSet<>();
for (Expression expr : conjuncts) {
// only simple equal predicate can use colocate join
if (!(expr instanceof EqualPredicate)) {
return false;
}
Expression leftChild = ((EqualPredicate) expr).left();
Expression rightChild = ((EqualPredicate) expr).right();
if (!(leftChild instanceof SlotReference) || !(rightChild instanceof SlotReference)) {
return false;
}
SlotReference leftSlot = (SlotReference) leftChild;
SlotReference rightSlot = (SlotReference) rightChild;
Integer leftIndex = leftHashSpec.getExprIdToEquivalenceSet().get(leftSlot.getExprId());
Integer rightIndex = rightHashSpec.getExprIdToEquivalenceSet().get(rightSlot.getExprId());
if (leftIndex == null) {
leftIndex = rightHashSpec.getExprIdToEquivalenceSet().get(leftSlot.getExprId());
rightIndex = leftHashSpec.getExprIdToEquivalenceSet().get(rightSlot.getExprId());
}
if (!Objects.equals(leftIndex, rightIndex)) {
return false;
}
if (leftIndex != null) {
equalIndices.add(leftIndex);
}
}
// on conditions must contain all distributed columns
if (equalIndices.containsAll(leftHashSpec.getExprIdToEquivalenceSet().values())) {
return true;
} else {
return false;
}
}
public static Set<ExprId> getJoinOutputExprIdSet(Plan left, Plan right) {
Set<ExprId> joinOutputExprIdSet = new HashSet<>();
joinOutputExprIdSet.addAll(left.getOutputExprIdSet());
joinOutputExprIdSet.addAll(right.getOutputExprIdSet());
return joinOutputExprIdSet;
}
private static List<Slot> applyNullable(List<Slot> slots, boolean nullable) {
Builder<Slot> newSlots = ImmutableList.builderWithExpectedSize(slots.size());
for (Slot slot : slots) {
newSlots.add(slot.withNullable(nullable));
}
return newSlots.build();
}
private static Map<Slot, Slot> mapPrimaryToForeign(ImmutableEqualSet<Slot> equivalenceSet,
Set<Slot> foreignKeys) {
ImmutableMap.Builder<Slot, Slot> builder = new ImmutableMap.Builder<>();
for (Slot foreignSlot : foreignKeys) {
Set<Slot> primarySlots = equivalenceSet.calEqualSet(foreignSlot);
if (primarySlots.size() != 1) {
return ImmutableMap.of();
}
builder.put(primarySlots.iterator().next(), foreignSlot);
}
return builder.build();
}
/**
* Check whether the given join can be eliminated by pk-fk
*/
public static boolean canEliminateByFk(LogicalJoin<?, ?> join, Plan primaryPlan, Plan foreignPlan) {
if (!join.getJoinType().isInnerJoin() || !join.getOtherJoinConjuncts().isEmpty() || join.isMarkJoin()) {
return false;
}
ForeignKeyContext context = new ForeignKeyContext();
context.collectForeignKeyConstraint(primaryPlan);
context.collectForeignKeyConstraint(foreignPlan);
ImmutableEqualSet<Slot> equalSet = join.getEqualSlots();
Set<Slot> primaryKey = Sets.intersection(equalSet.getAllItemSet(), primaryPlan.getOutputSet());
Set<Slot> foreignKey = Sets.intersection(equalSet.getAllItemSet(), foreignPlan.getOutputSet());
if (!context.isForeignKey(foreignKey) || !context.isPrimaryKey(primaryKey)) {
return false;
}
Map<Slot, Slot> primaryToForeignKey = mapPrimaryToForeign(equalSet, foreignKey);
return context.satisfyConstraint(primaryToForeignKey);
}
/**
* can this join be eliminated by its left child
*/
public static boolean canEliminateByLeft(LogicalJoin<?, ?> join, DataTrait rightFuncDeps) {
if (join.getJoinType().isLeftOuterJoin()) {
Pair<Set<Slot>, Set<Slot>> njHashKeys = join.extractNullRejectHashKeys();
if (!join.getOtherJoinConjuncts().isEmpty() || njHashKeys == null) {
return false;
}
return rightFuncDeps.isUnique(njHashKeys.second);
}
return false;
}
/**
* calculate the output slot of a join operator according join type and its children
*
* @param joinType the type of join operator
* @param left left child
* @param right right child
* @return return the output slots
*/
public static List<Slot> getJoinOutput(JoinType joinType, Plan left, Plan right) {
return getJoinOutput(joinType, left, right, false);
}
/**
* calculate the output slot of a join operator according join type and its children
*
* @param joinType the type of join operator
* @param left left child
* @param right right child
* @param asteriskOutput when true, return output for asterisk
*
* @return return the output slots
*/
public static List<Slot> getJoinOutput(JoinType joinType, Plan left, Plan right, boolean asteriskOutput) {
List<Slot> leftOutput = asteriskOutput ? left.getAsteriskOutput() : left.getOutput();
List<Slot> rightOutput = asteriskOutput ? right.getAsteriskOutput() : right.getOutput();
switch (joinType) {
case LEFT_SEMI_JOIN:
case LEFT_ANTI_JOIN:
case NULL_AWARE_LEFT_ANTI_JOIN:
return ImmutableList.copyOf(leftOutput);
case RIGHT_SEMI_JOIN:
case RIGHT_ANTI_JOIN:
return ImmutableList.copyOf(rightOutput);
case LEFT_OUTER_JOIN:
return ImmutableList.<Slot>builder()
.addAll(leftOutput)
.addAll(applyNullable(rightOutput, true))
.build();
case RIGHT_OUTER_JOIN:
return ImmutableList.<Slot>builder()
.addAll(applyNullable(leftOutput, true))
.addAll(rightOutput)
.build();
case FULL_OUTER_JOIN:
return ImmutableList.<Slot>builder()
.addAll(applyNullable(leftOutput, true))
.addAll(applyNullable(rightOutput, true))
.build();
default:
return ImmutableList.<Slot>builder()
.addAll(leftOutput)
.addAll(rightOutput)
.build();
}
}
public static boolean hasMarkConjuncts(Join join) {
return !join.getMarkJoinConjuncts().isEmpty();
}
public static boolean isNullAwareMarkJoin(Join join) {
// if mark join's hash conjuncts is empty, we use mark conjuncts as hash conjuncts
// and translate join type to NULL_AWARE_LEFT_SEMI_JOIN or NULL_AWARE_LEFT_ANTI_JOIN
return join.getHashJoinConjuncts().isEmpty() && !join.getMarkJoinConjuncts().isEmpty();
}
/**
* forbid join reorder if top join's condition use mark join slot produced by bottom join
*/
public static boolean checkReorderPrecondition(LogicalJoin<?, ?> top, LogicalJoin<?, ?> bottom) {
Set<Slot> markSlots = top.getConditionSlot().stream()
.filter(MarkJoinSlotReference.class::isInstance)
.collect(Collectors.toSet());
markSlots.retainAll(bottom.getOutputSet());
return markSlots.isEmpty();
}
}