InferSetOperatorDistinct.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.stats.ExpressionEstimation;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.statistics.ColumnStatistic;
import org.apache.doris.statistics.Statistics;

import com.google.common.collect.ImmutableList;

import java.util.ArrayList;
import java.util.List;

/**
 * Infers a local distinct aggregate under the children of a DISTINCT set operation.
 *
 * <p>A set operation with the {@code DISTINCT} qualifier removes duplicate rows from its result
 * anyway, so de-duplicating an individual child first does not change the result. It only pays off
 * when it shrinks that child's output enough to offset the cost of the extra aggregation.
 *
 * <p>A child is wrapped in an aggregate keyed on all of its output columns only if both hold:
 * <ul>
 *   <li>it is not already an aggregate, possibly under a project;</li>
 *   <li>it passes the NDV heuristic in {@code shouldGenerateAggregateByNdv}.</li>
 * </ul>
 * Other children are left untouched.
 *
 * <p>Example:
 * <pre>
 *      Intersect                          Intersect
 *      /       \                          /       \
 *  child1     child2      --&gt;     Agg(distinct)    child2
 *                                       |
 *                                     child1
 * </pre>
 */
public class InferSetOperatorDistinct extends OneRewriteRuleFactory {
    private static final double LOW_AGGREGATE_EFFECT_COEFFICIENT = 10000;
    private static final double MEDIUM_AGGREGATE_EFFECT_COEFFICIENT = 1000;
    private static final double HIGH_AGGREGATE_EFFECT_COEFFICIENT = 100;

    // Cardinality buckets returned by groupByCardinality.
    private static final int LOW_NDV_BUCKET = 0;
    private static final int MEDIUM_NDV_BUCKET = 1;
    private static final int HIGH_NDV_BUCKET = 2;

    // Derives statistics for children that do not have any attached yet.
    // NOTE(review): this instance is reused by every invocation of the rule built from this
    // factory; confirm that StatsDerive keeps no per-query state.
    private final StatsDerive derive = new StatsDerive(false);

    @Override
    public Rule build() {
        return logicalSetOperation()
                .when(operation -> operation.getQualifier() == Qualifier.DISTINCT)
                .then(setOperation -> {
                    ImmutableList.Builder<Plan> newChildren =
                            ImmutableList.builderWithExpectedSize(setOperation.arity());
                    boolean hasNewChildren = false;
                    for (Plan child : setOperation.children()) {
                        if (shouldInferDistinct(child)) {
                            // The child's whole output list is handed to the aggregate, so it
                            // acts as a DISTINCT over that child.
                            newChildren.add(new LogicalAggregate<>(
                                    ImmutableList.copyOf(child.getOutput()), true, child));
                            hasNewChildren = true;
                        } else {
                            newChildren.add(child);
                        }
                    }
                    // Nothing was rewritten. null is returned in that case, presumably read by
                    // the rule framework as "no change".
                    if (!hasNewChildren) {
                        return null;
                    }
                    return setOperation.withChildren(newChildren.build());
                }).toRule(RuleType.INFER_SET_OPERATOR_DISTINCT);
    }

    /**
     * Returns whether {@code child} should be wrapped in a local distinct aggregate.
     *
     * <p>The child must not already be an aggregate (see {@code isAgg}), and its output columns
     * must pass the NDV heuristic.
     */
    private boolean shouldInferDistinct(Plan child) {
        return !isAgg(child) && shouldGenerateAggregateByNdv(child, child.getOutput());
    }

    /**
     * Returns whether {@code plan} is an aggregate, or a project directly on top of one.
     *
     * <p>Such children are skipped. This also keeps the rule from stacking a second aggregate on a
     * child it has already rewritten.
     */
    private boolean isAgg(Plan plan) {
        return plan instanceof LogicalAggregate
                || (plan instanceof LogicalProject && plan.child(0) instanceof LogicalAggregate);
    }

    /**
     * Decides whether a local aggregate can reduce the input enough to justify its cost.
     *
     * <p>The decision uses the estimated NDV of the grouping keys. This is the same empirical NDV
     * heuristic used by {@code EagerAggRewriter#checkStats} for aggregation pushdown. It
     * categorizes each key by its row-count-to-NDV ratio and bounds the Cartesian product of
     * low-cardinality keys.
     *
     * <p>The thresholds used here and in {@code groupByCardinality} (100, 1,000, 10,000, 20, 7, and
     * 0.9) are intentionally heuristic rather than a cardinality formula. They were copied from the
     * eager-aggregation rules and calibrated against TPC-DS; in particular, this rule changes the
     * plans of TPC-DS queries 8, 14, and 75.
     *
     * <p>The heuristic is deliberately copied rather than shared with eager aggregation. Local
     * distinct and aggregation pushdown have different optimization boundaries, so their
     * thresholds can evolve independently without coupling the two rules.
     *
     * @param plan the child whose output would be de-duplicated
     * @param groupKeys the expressions the local aggregate would group by
     * @return {@code true} if the local aggregate is expected to pay off. Returns {@code false} when
     *         statistics are missing, the row count is not positive (or NaN), any key's NDV is
     *         unknown or NaN, or the keys do not reduce the input enough.
     */
    private boolean shouldGenerateAggregateByNdv(Plan plan, List<? extends NamedExpression> groupKeys) {
        Statistics stats = plan.getStats();
        if (stats == null) {
            stats = plan.accept(derive, new StatsDerive.DeriveContext());
            if (stats == null) {
                return false;
            }
        }
        double rowCount = stats.getRowCount();
        // Written as !(x > 0) so that a NaN row count is rejected too. Every comparison below
        // would be false for NaN and silently skew the buckets.
        if (!(rowCount > 0)) {
            return false;
        }

        List<ColumnStatistic> low = new ArrayList<>();
        List<ColumnStatistic> medium = new ArrayList<>();
        List<ColumnStatistic> high = new ArrayList<>();
        for (NamedExpression key : groupKeys) {
            ColumnStatistic colStats = ExpressionEstimation.INSTANCE.estimate(key, stats);
            // An unknown or NaN NDV gives no basis for the estimate. A key whose NDV is at least
            // 90% of the rows bounds the distinct-row count from below, so de-duplicating cannot
            // shrink the input meaningfully.
            if (colStats.isUnKnown || Double.isNaN(colStats.ndv) || colStats.ndv >= rowCount * 0.9) {
                return false;
            }
            switch (groupByCardinality(colStats, rowCount)) {
                case LOW_NDV_BUCKET:
                    low.add(colStats);
                    break;
                case MEDIUM_NDV_BUCKET:
                    medium.add(colStats);
                    break;
                default:
                    high.add(colStats);
                    break;
            }
        }

        double lowerCartesian = 1.0;
        for (ColumnStatistic colStats : low) {
            lowerCartesian *= colStats.ndv;
        }

        // Accept at most two keys when none of them has high NDV.
        if (high.isEmpty() && low.size() + medium.size() <= 2) {
            return true;
        }

        if (high.isEmpty() && medium.isEmpty()) {
            // NOTE(review): the check above already accepts every case with at most two low-NDV
            // keys. The next two branches are therefore unreachable, and only the three-key case
            // below can fire. They are kept to mirror the eager-aggregation heuristic this rule
            // was copied from.
            if (low.size() == 1 && lowerCartesian * 20 <= rowCount) {
                return true;
            } else if (low.size() == 2 && lowerCartesian * 7 <= rowCount) {
                return true;
            }
            // Upper bound for the low-NDV Cartesian product: max(rowCount / 20, 1) raised to
            // max(low.size() / 2, 1), where the exponent uses integer division.
            double lowerUpper = Math.pow(Math.max(rowCount / 20, 1), Math.max(low.size() / 2, 1));
            return low.size() <= 3 && lowerCartesian * 20 <= rowCount && lowerCartesian < lowerUpper;
        }

        // From here on, at least one key has medium or high NDV. Reject any of these cases:
        // two or more high-NDV keys; more than two medium-NDV keys; a high-NDV key combined with
        // any medium-NDV key.
        if (high.size() >= 2 || medium.size() > 2 || (high.size() == 1 && !medium.isEmpty())) {
            return false;
        }

        // Accept a single medium- or high-NDV key with at most two low-NDV keys. The product of
        // the low-NDV keys' NDVs must also be at most rowCount / 10000.
        double lowerCartesianLowerBound = rowCount / LOW_AGGREGATE_EFFECT_COEFFICIENT;
        return high.size() + medium.size() == 1 && low.size() <= 2
                && lowerCartesian <= lowerCartesianLowerBound;
    }

    /**
     * Buckets a grouping key by how strongly grouping on it alone would reduce {@code rowCount} rows.
     *
     * @param colStats estimated statistics of the key
     * @param rowCount estimated input row count
     * @return one of the following buckets:
     *         <ul>
     *           <li>{@code LOW_NDV_BUCKET} when {@code ndv * 1000 <= rowCount};</li>
     *           <li>{@code MEDIUM_NDV_BUCKET} when
     *               {@code ndv * 100 <= rowCount < ndv * 1000};</li>
     *           <li>{@code HIGH_NDV_BUCKET} otherwise, including a zero row count or any NaN
     *               operand.</li>
     *         </ul>
     */
    private int groupByCardinality(ColumnStatistic colStats, double rowCount) {
        double highBound = colStats.ndv * HIGH_AGGREGATE_EFFECT_COEFFICIENT;
        double mediumBound = colStats.ndv * MEDIUM_AGGREGATE_EFFECT_COEFFICIENT;
        if (rowCount == 0 || highBound > rowCount) {
            return HIGH_NDV_BUCKET;
        } else if (highBound <= rowCount && mediumBound > rowCount) {
            return MEDIUM_NDV_BUCKET;
        } else if (mediumBound <= rowCount) {
            return LOW_NDV_BUCKET;
        }
        // Only reachable when one of the comparisons above involves NaN.
        return HIGH_NDV_BUCKET;
    }
}