HashJoinStatsDerive.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.statistics;
import org.apache.doris.analysis.BinaryPredicate;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.JoinOperator;
import org.apache.doris.analysis.SlotDescriptor;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.catalog.ColumnStats;
import org.apache.doris.common.CheckedMath;
import org.apache.doris.common.UserException;
import org.apache.doris.planner.HashJoinNode;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.ArrayList;
import java.util.List;
/**
* Derive HashJoinNode statistics.
*/
public class HashJoinStatsDerive extends BaseStatsDerive {
private static final Logger LOG = LogManager.getLogger(HashJoinStatsDerive.class);
private JoinOperator joinOp;
private List<BinaryPredicate> eqJoinConjuncts = Lists.newArrayList();
@Override
public void init(PlanStats node) throws UserException {
Preconditions.checkState(node instanceof HashJoinNode);
super.init(node);
joinOp = ((HashJoinNode) node).getJoinOp();
eqJoinConjuncts.addAll(((HashJoinNode) node).getEqJoinConjuncts());
}
@Override
protected long deriveRowCount() {
if (joinOp.isSemiAntiJoin()) {
rowCount = getSemiJoinrowCount();
} else if (joinOp.isInnerJoin() || joinOp.isOuterJoin()) {
rowCount = getJoinrowCount();
} else {
if (LOG.isDebugEnabled()) {
LOG.debug("joinOp:{} is not supported for HashJoinStatsDerive", joinOp);
}
}
capRowCountAtLimit();
return rowCount;
}
/**
* Returns the estimated rowCount of a semi join node.
* For a left semi join between child(0) and child(1), we look for equality join
* conditions "L.c = R.d" (with L being from child(0) and R from child(1)) and use as
* the rowCount estimate the minimum of
* |child(0)| * Min(NDV(L.c), NDV(R.d)) / NDV(L.c)
* over all suitable join conditions. The reasoning is that:
* -each row in child(0) is returned at most once
* -the probability of a row in child(0) having a match in R is
* Min(NDV(L.c), NDV(R.d)) / NDV(L.c)
*
*<p>
* For a left anti join we estimate the rowCount as the minimum of:
* |L| * Max(NDV(L.c) - NDV(R.d), NDV(L.c)) / NDV(L.c)
* over all suitable join conditions. The reasoning is that:
* - each row in child(0) is returned at most once
* - if NDV(L.c) > NDV(R.d) then the probability of row in L having a match
* in child(1) is (NDV(L.c) - NDV(R.d)) / NDV(L.c)
* - otherwise, we conservatively use |L| to avoid underestimation
*</p>
*
*<p>
* We analogously estimate the rowCount for right semi/anti joins, and treat the
* null-aware anti join like a regular anti join
*</p>
*/
private long getSemiJoinrowCount() {
Preconditions.checkState(joinOp.isSemiJoin());
// Return -1 if the rowCount of the returned side is unknown.
double rowCount;
if (joinOp == JoinOperator.RIGHT_SEMI_JOIN
|| joinOp == JoinOperator.RIGHT_ANTI_JOIN) {
if (childrenStatsResult.get(1).getRowCount() == -1) {
return -1;
}
rowCount = childrenStatsResult.get(1).getRowCount();
} else {
if (childrenStatsResult.get(0).getRowCount() == -1) {
return -1;
}
rowCount = childrenStatsResult.get(0).getRowCount();
}
double minSelectivity = 1.0;
for (Expr eqJoinPredicate : eqJoinConjuncts) {
double lhsNdv = getNdv(eqJoinPredicate.getChild(0));
lhsNdv = Math.min(lhsNdv, childrenStatsResult.get(0).getRowCount());
double rhsNdv = getNdv(eqJoinPredicate.getChild(1));
rhsNdv = Math.min(rhsNdv, childrenStatsResult.get(1).getRowCount());
// Skip conjuncts with unknown NDV on either side.
if (lhsNdv == -1 || rhsNdv == -1) {
continue;
}
double selectivity = 1.0;
switch (joinOp) {
case LEFT_SEMI_JOIN: {
selectivity = (double) Math.min(lhsNdv, rhsNdv) / (double) (lhsNdv);
break;
}
case RIGHT_SEMI_JOIN: {
selectivity = (double) Math.min(lhsNdv, rhsNdv) / (double) (rhsNdv);
break;
}
case LEFT_ANTI_JOIN:
case NULL_AWARE_LEFT_ANTI_JOIN: {
selectivity = (double) (lhsNdv > rhsNdv ? (lhsNdv - rhsNdv) : lhsNdv) / (double) lhsNdv;
break;
}
case RIGHT_ANTI_JOIN: {
selectivity = (double) (rhsNdv > lhsNdv ? (rhsNdv - lhsNdv) : rhsNdv) / (double) rhsNdv;
break;
}
default:
Preconditions.checkState(false);
}
minSelectivity = Math.min(minSelectivity, selectivity);
}
Preconditions.checkState(rowCount != -1);
return Math.round(rowCount * minSelectivity);
}
/**
* Unwraps the SlotRef in expr and returns the NDVs of it.
* Returns -1 if the NDVs are unknown or if expr is not a SlotRef.
*/
private long getNdv(Expr expr) {
SlotRef slotRef = expr.unwrapSlotRef(false);
if (slotRef == null) {
return -1;
}
SlotDescriptor slotDesc = slotRef.getDesc();
if (slotDesc == null) {
return -1;
}
ColumnStats stats = slotDesc.getStats();
if (!stats.hasNumDistinctValues()) {
return -1;
}
return stats.getNumDistinctValues();
}
private long getJoinrowCount() {
Preconditions.checkState(joinOp.isInnerJoin() || joinOp.isOuterJoin());
Preconditions.checkState(childrenStatsResult.size() == 2);
long lhsCard = (long) childrenStatsResult.get(0).getRowCount();
long rhsCard = (long) childrenStatsResult.get(1).getRowCount();
if (lhsCard == -1 || rhsCard == -1) {
return lhsCard;
}
// Collect join conjuncts that are eligible to participate in rowCount estimation.
List<HashJoinNode.EqJoinConjunctScanSlots> eqJoinConjunctSlots = new ArrayList<>();
for (Expr eqJoinConjunct : eqJoinConjuncts) {
HashJoinNode.EqJoinConjunctScanSlots slots = HashJoinNode.EqJoinConjunctScanSlots.create(eqJoinConjunct);
if (slots != null) {
eqJoinConjunctSlots.add(slots);
}
}
if (eqJoinConjunctSlots.isEmpty()) {
// There are no eligible equi-join conjuncts.
return lhsCard;
}
return getGenericJoinrowCount(eqJoinConjunctSlots, lhsCard, rhsCard);
}
/**
* Returns the estimated join rowCount of a generic N:M inner or outer join based
* on the given list of equi-join conjunct slots and the join input cardinalities.
* The returned result is >= 0.
* The list of join conjuncts must be non-empty and the cardinalities must be >= 0.
*
* <p>
* Generic estimation:
* rowCount = |child(0)| * |child(1)| / max(NDV(L.c), NDV(R.d))
* - case A: NDV(L.c) <= NDV(R.d)
* every row from child(0) joins with |child(1)| / NDV(R.d) rows
* - case B: NDV(L.c) > NDV(R.d)
* every row from child(1) joins with |child(0)| / NDV(L.c) rows
* - we adjust the NDVs from both sides to account for predicates that may
* might have reduce the rowCount and NDVs
*</p>
*/
private long getGenericJoinrowCount(List<HashJoinNode.EqJoinConjunctScanSlots> eqJoinConjunctSlots,
long lhsCard,
long rhsCard) {
Preconditions.checkState(joinOp.isInnerJoin() || joinOp.isOuterJoin());
Preconditions.checkState(!eqJoinConjunctSlots.isEmpty());
Preconditions.checkState(lhsCard >= 0 && rhsCard >= 0);
long result = -1;
for (HashJoinNode.EqJoinConjunctScanSlots slots : eqJoinConjunctSlots) {
// Adjust the NDVs on both sides to account for predicates. Intuitively, the NDVs
// should only decrease. We ignore adjustments that would lead to an increase.
double lhsAdjNdv = slots.lhsNdv();
if (slots.lhsNumRows() > lhsCard) {
lhsAdjNdv *= lhsCard / slots.lhsNumRows();
}
double rhsAdjNdv = slots.rhsNdv();
if (slots.rhsNumRows() > rhsCard) {
rhsAdjNdv *= rhsCard / slots.rhsNumRows();
}
// A lower limit of 1 on the max Adjusted Ndv ensures we don't estimate
// rowCount more than the max possible.
long tmpNdv = Double.doubleToLongBits(Math.max(1, Math.max(lhsAdjNdv, rhsAdjNdv)));
long joinCard = tmpNdv == rhsCard
? lhsCard
: CheckedMath.checkedMultiply(
Math.round((lhsCard / Math.max(1, Math.max(lhsAdjNdv, rhsAdjNdv)))), rhsCard);
if (result == -1) {
result = joinCard;
} else {
result = Math.min(result, joinCard);
}
}
Preconditions.checkState(result >= 0);
return result;
}
}