JoinEstimation.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.stats;

import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.algebra.Join;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.statistics.ColumnStatistic;
import org.apache.doris.statistics.ColumnStatisticBuilder;
import org.apache.doris.statistics.Statistics;
import org.apache.doris.statistics.StatisticsBuilder;

import com.google.common.collect.Lists;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * Estimate hash join stats.
 * TODO: Update other props in the ColumnStats properly.
 */
public class JoinEstimation {
    // Anti join output never drops below this fraction of the preserved side's row count.
    private static final double DEFAULT_ANTI_JOIN_SELECTIVITY_COEFFICIENT = 0.3;
    // Floor applied to the build-side selectivity when join-key column stats are unknown.
    private static final double UNKNOWN_COL_STATS_FILTER_SEL_LOWER_BOUND = 0.5;
    // The exponent applied to each further trustable condition's selectivity is divided by this factor.
    private static final double TRUSTABLE_CONDITION_SELECTIVITY_POW_FACTOR = 2.0;
    // The estimated row count is multiplied by this factor once per untrustable condition.
    private static final double UNTRUSTABLE_CONDITION_SELECTIVITY_LINEAR_FACTOR = 0.9;
    // A join key is regarded as (almost) unique when ndv / rowCount exceeds this threshold.
    private static final double TRUSTABLE_UNIQ_THRESHOLD = 0.9;

    /**
     * Orients an equal predicate relative to the right child.
     *
     * @param equal the equal predicate taken from the join condition
     * @param rightStats statistics of the right child
     * @return {@code equal.commute()} when any input slot of the predicate's left operand has
     *         column statistics in {@code rightStats}; otherwise {@code equal} itself
     */
    private static EqualPredicate normalizeEqualPredJoinCondition(EqualPredicate equal, Statistics rightStats) {
        for (Slot inputSlot : equal.left().getInputSlots()) {
            if (rightStats.findColumnStatistics(inputSlot) != null) {
                // the left operand reads a right-child column, so swap the two operands
                return equal.commute();
            }
        }
        return equal;
    }

    /**
     * Checks whether any slot used by the join's equal predicates lacks usable column statistics.
     *
     * <p>A slot is looked up in {@code leftStats} first and then in {@code rightStats}; it counts as
     * unknown when neither side has an entry or the entry found is flagged {@code isUnKnown}.
     *
     * @return true if at least one such slot exists, false otherwise (including when there are
     *         no equal predicates)
     */
    private static boolean joinConditionContainsUnknownColumnStats(Statistics leftStats,
            Statistics rightStats, Join join) {
        return join.getEqualPredicates().stream()
                .flatMap(predicate -> predicate.getInputSlots().stream())
                .anyMatch(slot -> {
                    ColumnStatistic slotStats = leftStats.findColumnStatistics(slot);
                    if (slotStats == null) {
                        slotStats = rightStats.findColumnStatistics(slot);
                    }
                    return slotStats == null || slotStats.isUnKnown;
                });
    }

    /**
     * Estimates inner join statistics from the join's equal predicates.
     *
     * <p>When we estimate a condition A=B, if either side of the equation is almost unique
     * (ndv / rowCount above TRUSTABLE_UNIQ_THRESHOLD), the confidence level of the estimation is high
     * and the condition is "trustable". If neither side is unique, the confidence level is low and
     * the condition is "untrustable".
     *
     * <p>If there are trustable conditions, their selectivities against the cross join are combined
     * with decaying exponents, and every untrustable condition only contributes a fixed linear factor.
     * If there are none, then to avoid error propagation only a single ratio, the smallest one
     * computed among the untrustable conditions, is applied to the bigger child's row count.
     *
     * @return the cross-join column statistics with the estimated row count applied through
     *         {@code withRowCountAndEnforceValid}
     */
    private static Statistics estimateInnerJoinWithEqualPredicate(Statistics leftStats,
            Statistics rightStats, Join join) {
        boolean leftBigger = leftStats.getRowCount() > rightStats.getRowCount();
        double rightStatsRowCount = StatsMathUtil.nonZeroDivisor(rightStats.getRowCount());
        double leftStatsRowCount = StatsMathUtil.nonZeroDivisor(leftStats.getRowCount());

        List<EqualPredicate> trustedPredicates = Lists.newArrayList();
        List<EqualPredicate> untrustedPredicates = Lists.newArrayList();
        List<Double> untrustedRatios = Lists.newArrayList();
        for (Expression expression : join.getEqualPredicates()) {
            EqualPredicate predicate = (EqualPredicate) expression;
            EqualPredicate oriented = normalizeEqualPredJoinCondition(predicate, rightStats);
            ColumnStatistic leftKeyStats = ExpressionEstimation.estimate(oriented.left(), leftStats);
            ColumnStatistic rightKeyStats = ExpressionEstimation.estimate(oriented.right(), rightStats);
            // since ndv is not accurate, a column whose ndv/rowcount > TRUSTABLE_UNIQ_THRESHOLD
            // is regarded as unique.
            boolean trustable = rightKeyStats.ndv / rightStatsRowCount > TRUSTABLE_UNIQ_THRESHOLD
                    || leftKeyStats.ndv / leftStatsRowCount > TRUSTABLE_UNIQ_THRESHOLD;
            if (trustable) {
                // the predicate is kept in its original (not re-oriented) form
                trustedPredicates.add(predicate);
                continue;
            }
            double rNdv = StatsMathUtil.nonZeroDivisor(rightKeyStats.ndv);
            double lNdv = StatsMathUtil.nonZeroDivisor(leftKeyStats.ndv);
            double minNdv = Math.min(leftKeyStats.ndv, rightKeyStats.ndv);
            double ratio = leftBigger
                    ? (rightStatsRowCount / rNdv) * minNdv / lNdv
                    : (leftStatsRowCount / lNdv) * minNdv / rNdv;
            untrustedRatios.add(ratio);
            untrustedPredicates.add(oriented);
        }

        Statistics crossJoinStats = new StatisticsBuilder()
                .setRowCount(Math.max(1, leftStats.getRowCount()) * Math.max(1, rightStats.getRowCount()))
                .putColumnStatistics(leftStats.columnStatistics())
                .putColumnStatistics(rightStats.columnStatistics())
                .build();

        double outputRowCount;
        if (trustedPredicates.isEmpty()) {
            outputRowCount = Math.max(leftStats.getRowCount(), rightStats.getRowCount());
            Optional<Double> smallestRatio = untrustedRatios.stream().min(Double::compareTo);
            if (smallestRatio.isPresent()) {
                outputRowCount = Math.max(1, outputRowCount * smallestRatio.get());
            }
        } else {
            // TODO: strict pk-fk can use one-side stats instead of crossJoinStats
            // in estimateJoinConditionSel, to get more accurate estimation.
            List<Double> sortedSelectivities = trustedPredicates.stream()
                    .map(predicate -> estimateJoinConditionSel(crossJoinStats, predicate))
                    .sorted()
                    .collect(Collectors.toList());

            // selectivities are visited in ascending order; each one is raised to 1/denominator,
            // and the denominator grows by the pow factor after every condition
            double combinedSel = 1.0;
            double denominator = 1.0;
            for (double conditionSel : sortedSelectivities) {
                combinedSel *= Math.pow(conditionSel, 1 / denominator);
                denominator *= TRUSTABLE_CONDITION_SELECTIVITY_POW_FACTOR;
            }
            double trustedRowCount = Math.max(1, crossJoinStats.getRowCount() * combinedSel);
            outputRowCount = trustedRowCount * Math.pow(UNTRUSTABLE_CONDITION_SELECTIVITY_LINEAR_FACTOR,
                    untrustedPredicates.size());
        }
        return crossJoinStats.withRowCountAndEnforceValid(outputRowCount);
    }

    /**
     * Estimates inner join statistics without using equal predicates for selectivity.
     *
     * <p>If the join condition has unknown column stats, the row count is the sum of both children's
     * row counts, scaled by 0.99 when either child has exactly one row and by 1.01 otherwise.
     * Otherwise the row count is the product of both children's row counts. In both cases the result
     * is at least 1, and the column statistics of both children are carried over.
     */
    private static Statistics estimateInnerJoinWithoutEqualPredicate(Statistics leftStats,
            Statistics rightStats, Join join) {
        double rowCount;
        if (joinConditionContainsUnknownColumnStats(leftStats, rightStats, join)) {
            // We do more like the nested loop join with one rows than inner join
            boolean hasSingleRowChild = leftStats.getRowCount() == 1 || rightStats.getRowCount() == 1;
            double factor = hasSingleRowChild ? 0.99 : 1.01;
            rowCount = (leftStats.getRowCount() + rightStats.getRowCount()) * factor;
        } else {
            rowCount = leftStats.getRowCount() * rightStats.getRowCount();
        }
        return new StatisticsBuilder()
                .setRowCount(Math.max(1, rowCount))
                .putColumnStatistics(leftStats.columnStatistics())
                .putColumnStatistics(rightStats.columnStatistics())
                .build();
    }

    /**
     * Computes a selectivity for the build side when join-key column stats are unknown.
     *
     * <p>Only {@code EqualTo} conditions whose both operands are slots are considered. For each of
     * them, the build-side column statistic is looked up by the left operand and, failing that, by
     * the right operand. The per-condition selectivity is {@code rowCount / count}, clamped to
     * {@code [UNKNOWN_COL_STATS_FILTER_SEL_LOWER_BOUND, 1.0]}; the smallest one wins.
     *
     * @param buildStats statistics of the build side
     * @param join the join whose equal predicates are inspected
     * @return the selectivity, 1.0 when no condition qualifies
     */
    private static double computeSelectivityForBuildSideWhenColStatsUnknown(Statistics buildStats, Join join) {
        double sel = 1.0;
        for (Expression cond : join.getEqualPredicates()) {
            if (!(cond instanceof EqualTo)) {
                continue;
            }
            EqualTo equal = (EqualTo) cond;
            if (!(equal.left() instanceof Slot && equal.right() instanceof Slot)) {
                continue;
            }
            ColumnStatistic buildColStats = buildStats.findColumnStatistics(equal.left());
            if (buildColStats == null) {
                buildColStats = buildStats.findColumnStatistics(equal.right());
            }
            if (buildColStats == null) {
                continue;
            }
            // Guard the divisor the same way the other divisions in this class do: with a zero
            // row count and a zero column count the plain division is 0/0 = NaN, and NaN passes
            // through Math.min/Math.max into the returned selectivity.
            double buildSel = Math.min(
                    buildStats.getRowCount() / StatsMathUtil.nonZeroDivisor(buildColStats.count), 1.0);
            buildSel = Math.max(buildSel, UNKNOWN_COL_STATS_FILTER_SEL_LOWER_BOUND);
            sel = Math.min(sel, buildSel);
        }
        return sel;
    }

    /**
     * Estimates inner join statistics for the given children.
     *
     * <p>When the join condition has unknown column stats, the row count is the bigger of the two
     * children's row counts (at least 1). Otherwise the estimate is delegated depending on whether
     * equal predicates exist, and any other join conjuncts are then applied as a filter; a filtered
     * row count that is not positive is replaced by 1.
     */
    private static Statistics estimateInnerJoin(Statistics leftStats, Statistics rightStats, Join join) {
        if (joinConditionContainsUnknownColumnStats(leftStats, rightStats, join)) {
            double biggerRowCount = Math.max(leftStats.getRowCount(), rightStats.getRowCount());
            return new StatisticsBuilder()
                .setRowCount(Math.max(1, biggerRowCount))
                .putColumnStatistics(leftStats.columnStatistics())
                .putColumnStatistics(rightStats.columnStatistics())
                .build();
        }

        Statistics equalJoinStats = join.getEqualPredicates().isEmpty()
                ? estimateInnerJoinWithoutEqualPredicate(leftStats, rightStats, join)
                : estimateInnerJoinWithEqualPredicate(leftStats, rightStats, join);
        if (join.getOtherJoinConjuncts().isEmpty()) {
            return equalJoinStats;
        }

        // apply the remaining (non-equal) conjuncts as a filter on top of the equal-join estimate
        Statistics filteredStats = new FilterEstimation().estimate(
                ExpressionUtils.and(join.getOtherJoinConjuncts()), equalJoinStats);
        if (filteredStats.getRowCount() <= 0) {
            return new StatisticsBuilder(filteredStats).setRowCount(1).build();
        }
        return filteredStats;
    }

    /**
     * Estimates the selectivity of a single join condition against the cross join.
     *
     * @param crossJoinStats statistics of the cross join of both children
     * @param joinCond the join condition, evaluated as a filter on {@code crossJoinStats}
     * @return filtered row count divided by the cross join row count
     */
    private static double estimateJoinConditionSel(Statistics crossJoinStats, Expression joinCond) {
        FilterEstimation filterEstimation = new FilterEstimation();
        double filteredRowCount = filterEstimation.estimate(joinCond, crossJoinStats).getRowCount();
        return filteredRowCount / crossJoinStats.getRowCount();
    }

    /**
     * Estimates the semi/anti join row count implied by one equal predicate.
     *
     * <p>The probe column statistic is looked up in {@code leftStats} and the build column statistic
     * in {@code rightStats}, trying the predicate's operands in both orders. For a left semi/anti
     * join the semi row count is {@code leftRowCount * buildNdv / buildOriginalNdv}; for a right
     * semi/anti join it is {@code rightRowCount * probeNdv / probeOriginalNdv}. An anti join yields
     * the remainder, but never less than DEFAULT_ANTI_JOIN_SELECTIVITY_COEFFICIENT of the base rows.
     *
     * @return the estimated row count (at least 1), or {@code Double.POSITIVE_INFINITY} when either
     *         column statistic cannot be found
     */
    private static double estimateSemiOrAntiRowCountBySlotsEqual(Statistics leftStats,
            Statistics rightStats, Join join, EqualPredicate equalTo) {
        Expression eqLeft = equalTo.left();
        Expression eqRight = equalTo.right();
        ColumnStatistic probColStats = leftStats.findColumnStatistics(eqLeft);
        ColumnStatistic buildColStats;
        if (probColStats != null) {
            buildColStats = rightStats.findColumnStatistics(eqRight);
        } else {
            // operands are in the opposite order: the right operand belongs to the left child
            probColStats = leftStats.findColumnStatistics(eqRight);
            buildColStats = rightStats.findColumnStatistics(eqLeft);
        }
        if (probColStats == null || buildColStats == null) {
            return Double.POSITIVE_INFINITY;
        }

        boolean isLeftSemiOrAnti = join.getJoinType().isLeftSemiOrAntiJoin();
        // the side whose rows are output, and the opposite side's key column that filters it
        double baseRowCount = isLeftSemiOrAnti ? leftStats.getRowCount() : rightStats.getRowCount();
        ColumnStatistic filteringColStats = isLeftSemiOrAnti ? buildColStats : probColStats;
        double semiRowCount = StatsMathUtil.divide(baseRowCount * filteringColStats.ndv,
                filteringColStats.getOriginalNdv());

        double rowCount;
        if (join.getJoinType().isSemiJoin()) {
            rowCount = semiRowCount;
        } else {
            rowCount = Math.max(baseRowCount - semiRowCount,
                    baseRowCount * DEFAULT_ANTI_JOIN_SELECTIVITY_COEFFICIENT);
        }
        return Math.max(1, rowCount);
    }

    /**
     * Estimates statistics of a semi or anti join.
     *
     * <p>For a mark join, or when the join condition has unknown column stats, the output row count
     * is the preserved side's row count multiplied by a selectivity (1.0 for a mark join), and the
     * column statistics of both children are kept. Otherwise the smallest row count implied by any
     * equal predicate is used with the preserved side's statistics; if no predicate yields a finite
     * estimate, the result falls back to the inner join statistics capped by the preserved side's
     * row count.
     */
    private static Statistics estimateSemiOrAnti(Statistics leftStats, Statistics rightStats,
            Statistics innerJoinStats, Join join) {
        if (joinConditionContainsUnknownColumnStats(leftStats, rightStats, join) || join.isMarkJoin()) {
            double sel = join.isMarkJoin() ? 1.0 : computeSelectivityForBuildSideWhenColStatsUnknown(rightStats, join);
            // left semi/anti keeps left rows, right semi/anti keeps right rows
            double preservedRowCount = join.getJoinType().isLeftSemiOrAntiJoin()
                    ? leftStats.getRowCount()
                    : rightStats.getRowCount();
            Statistics result = new StatisticsBuilder().setRowCount(preservedRowCount * sel)
                    .putColumnStatistics(leftStats.columnStatistics())
                    .putColumnStatistics(rightStats.columnStatistics())
                    .build();
            result.normalizeColumnStatistics();
            return result;
        }

        // keep the smallest row count implied by any single equal predicate
        double rowCount = Double.POSITIVE_INFINITY;
        for (Expression conjunct : join.getEqualPredicates()) {
            double eqRowCount = estimateSemiOrAntiRowCountBySlotsEqual(leftStats, rightStats,
                    join, (EqualPredicate) conjunct);
            if (eqRowCount < rowCount) {
                rowCount = eqRowCount;
            }
        }

        if (Double.isInfinite(rowCount)) {
            //slotsEqual estimation failed, fall back to original algorithm
            double baseRowCount =
                    join.getJoinType().isLeftSemiOrAntiJoin() ? leftStats.getRowCount() : rightStats.getRowCount();
            return innerJoinStats.withRowCountAndEnforceValid(Math.min(innerJoinStats.getRowCount(), baseRowCount));
        }

        // TODO: tuning the new semi/anti estimation method.
        // A previously drafted alternative derived selectivity = innerJoinRowCount / crossJoinRowCount
        // (made valid through Statistics.getValidSelectivity) and scaled the preserved side's row count
        // by selectivity for semi joins, or by (1 - selectivity) for anti joins, keeping the preserved
        // side's full row count when the other side had fewer than one row.
        Statistics preservedSideStats = join.getJoinType().isLeftSemiOrAntiJoin() ? leftStats : rightStats;
        StatisticsBuilder builder = new StatisticsBuilder(preservedSideStats);
        builder.setRowCount(rowCount);
        Statistics outputStats = builder.build();
        outputStats.normalizeColumnStatistics();
        return outputStats;
    }

    /**
     * Estimate join statistics for the given children and join.
     *
     * <p>Semi/anti and inner joins derive their result from the inner join estimate. Outer joins use
     * the cross-join column statistics with a row count that is at least the preserved side's row
     * count and at least the inner join row count. A cross join returns the cross-join statistics.
     * In every case the join-condition column statistics are updated on the returned object.
     *
     * @param leftStats statistics of the left child
     * @param rightStats statistics of the right child
     * @param join the join to estimate
     * @return estimated statistics of the join output
     * @throws AnalysisException if the join type is not handled
     */
    public static Statistics estimate(Statistics leftStats, Statistics rightStats, Join join) {
        JoinType joinType = join.getJoinType();
        Statistics crossJoinStats = new StatisticsBuilder()
                .setRowCount(Math.max(1, leftStats.getRowCount()) * Math.max(1, rightStats.getRowCount()))
                .putColumnStatistics(leftStats.columnStatistics())
                .putColumnStatistics(rightStats.columnStatistics())
                .build();
        Statistics innerJoinStats = estimateInnerJoin(leftStats, rightStats, join);

        if (joinType.isSemiOrAntiJoin()) {
            Statistics semiOrAntiStats = estimateSemiOrAnti(leftStats, rightStats, innerJoinStats, join);
            updateJoinConditionColumnStatistics(semiOrAntiStats, join);
            return semiOrAntiStats;
        }
        if (joinType == JoinType.INNER_JOIN) {
            updateJoinConditionColumnStatistics(innerJoinStats, join);
            return innerJoinStats;
        }
        if (joinType == JoinType.CROSS_JOIN) {
            updateJoinConditionColumnStatistics(crossJoinStats, join);
            return crossJoinStats;
        }

        // outer joins: every row of a preserved side shows up at least once in the output
        double outerRowCount;
        if (joinType == JoinType.LEFT_OUTER_JOIN) {
            outerRowCount = Math.max(leftStats.getRowCount(), innerJoinStats.getRowCount());
        } else if (joinType == JoinType.RIGHT_OUTER_JOIN) {
            outerRowCount = Math.max(rightStats.getRowCount(), innerJoinStats.getRowCount());
        } else if (joinType == JoinType.FULL_OUTER_JOIN) {
            double leftOrInner = Math.max(leftStats.getRowCount(), innerJoinStats.getRowCount());
            outerRowCount = Math.max(rightStats.getRowCount(), leftOrInner);
        } else {
            throw new AnalysisException("join type not supported: " + join.getJoinType());
        }
        updateJoinConditionColumnStatistics(crossJoinStats, join);
        return crossJoinStats.withRowCountAndEnforceValid(outerRowCount);
    }

    /**
     * Update join condition columns' ColumnStatistics, based on different join type.
     *
     * <p>For every equal predicate the ndv of its operands is adjusted in {@code inputStats}:
     * <ul>
     *   <li>inner join: both sides get min(leftNdv, rightNdv);</li>
     *   <li>left outer join: left keeps its ndv, right gets the min;</li>
     *   <li>right outer join: right keeps its ndv, left gets the min;</li>
     *   <li>left semi / anti joins: only the left side is updated, with the min;</li>
     *   <li>right semi / anti joins: only the right side is updated, with the min;</li>
     *   <li>full outer and cross joins: both sides keep their own ndv.</li>
     * </ul>
     * A top-level {@code Cast} on an operand is stripped before the operand is used as the key
     * under which the updated statistic is stored.
     *
     * @param inputStats statistics object that is modified in place
     * @param join the join whose equal predicates are processed
     */
    private static void updateJoinConditionColumnStatistics(Statistics inputStats, Join join) {
        Map<Expression, ColumnStatistic> updatedCols = new HashMap<>();
        JoinType joinType = join.getJoinType();
        for (Expression expr : join.getEqualPredicates()) {
            EqualPredicate equalTo = (EqualPredicate) expr;
            ColumnStatistic leftColStats = ExpressionEstimation.estimate(equalTo.left(), inputStats);
            ColumnStatistic rightColStats = ExpressionEstimation.estimate(equalTo.right(), inputStats);
            double minNdv = Math.min(leftColStats.ndv, rightColStats.ndv);
            double leftNdv = 1.0;
            double rightNdv = 1.0;
            boolean updateLeft = false;
            boolean updateRight = false;
            Expression eqLeft = equalTo.left();
            if (eqLeft instanceof Cast) {
                eqLeft = eqLeft.child(0);
            }
            Expression eqRight = equalTo.right();
            if (eqRight instanceof Cast) {
                eqRight = eqRight.child(0);
            }
            if (joinType == JoinType.INNER_JOIN) {
                leftNdv = minNdv;
                rightNdv = minNdv;
                updateLeft = true;
                updateRight = true;
            } else if (joinType == JoinType.LEFT_OUTER_JOIN) {
                leftNdv = leftColStats.ndv;
                rightNdv = minNdv;
                updateLeft = true;
                updateRight = true;
            } else if (joinType == JoinType.LEFT_SEMI_JOIN
                    || joinType == JoinType.LEFT_ANTI_JOIN
                    || joinType == JoinType.NULL_AWARE_LEFT_ANTI_JOIN) {
                leftNdv = minNdv;
                updateLeft = true;
            } else if (joinType == JoinType.RIGHT_OUTER_JOIN) {
                leftNdv = minNdv;
                rightNdv = rightColStats.ndv;
                // mirror of LEFT_OUTER_JOIN: without these flags the ndv values computed above
                // were never written back for right outer joins
                updateLeft = true;
                updateRight = true;
            } else if (joinType == JoinType.RIGHT_SEMI_JOIN
                    || joinType == JoinType.RIGHT_ANTI_JOIN) {
                rightNdv = minNdv;
                updateRight = true;
            } else if (joinType == JoinType.FULL_OUTER_JOIN || joinType == JoinType.CROSS_JOIN) {
                leftNdv = leftColStats.ndv;
                rightNdv = rightColStats.ndv;
                updateLeft = true;
                updateRight = true;
            }

            if (updateLeft) {
                leftColStats = new ColumnStatisticBuilder(leftColStats).setNdv(leftNdv).build();
                updatedCols.put(eqLeft, leftColStats);
            }
            if (updateRight) {
                rightColStats = new ColumnStatisticBuilder(rightColStats).setNdv(rightNdv).build();
                updatedCols.put(eqRight, rightColStats);
            }
        }
        // applied after the loop so every predicate is estimated against the unmodified inputStats
        updatedCols.forEach((expression, colStats) -> inputStats.addColumnStats(expression, colStats));
    }

}