InferSetOperatorDistinct.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.stats.ExpressionEstimation;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.statistics.ColumnStatistic;
import org.apache.doris.statistics.Statistics;

import com.google.common.collect.ImmutableList;

import java.util.ArrayList;
import java.util.List;

/**
 * Infer Distinct from SetOperator;
 * Example:
 * <pre>
 *                    Intersect
 *   Intersect ->        |
 *                  Agg for Distinct
 * </pre>
 */
public class InferSetOperatorDistinct extends OneRewriteRuleFactory {
    private static final double LOWER_AGGREGATE_EFFECT_COEFFICIENT = 10000;
    private static final double LOW_AGGREGATE_EFFECT_COEFFICIENT = 1000;
    private static final double MEDIUM_AGGREGATE_EFFECT_COEFFICIENT = 100;

    private final StatsDerive derive = new StatsDerive(false);

    @Override
    public Rule build() {
        return logicalSetOperation()
                .when(operation -> operation.getQualifier() == Qualifier.DISTINCT)
                .then(setOperation -> {
                    ImmutableList.Builder<Plan> newChildren =
                            ImmutableList.builderWithExpectedSize(setOperation.arity());
                    boolean hasNewChildren = false;
                    for (Plan child : setOperation.children()) {
                        if (shouldInferDistinct(child)) {
                            newChildren.add(new LogicalAggregate<>(
                                    ImmutableList.copyOf(child.getOutput()), true, child));
                            hasNewChildren = true;
                        } else {
                            newChildren.add(child);
                        }
                    }
                    if (!hasNewChildren) {
                        return null;
                    }
                    return setOperation.withChildren(newChildren.build());
                }).toRule(RuleType.INFER_SET_OPERATOR_DISTINCT);
    }

    private boolean shouldInferDistinct(Plan child) {
        return !isAgg(child) && rejectNLJ(child)
                && shouldGenerateAggregateByNdv(child, child.getOutput());
    }

    private boolean isAgg(Plan plan) {
        return plan instanceof LogicalAggregate || (plan instanceof LogicalProject && plan.child(
                0) instanceof LogicalAggregate);
    }

    // if children exist NLJ, we can't infer distinct
    // because NLJ could generate bitmap runtime filter. and it will execute failed when we do infer distinct.
    private boolean rejectNLJ(Plan plan) {
        if (plan instanceof LogicalProject) {
            plan = plan.child(0);
        }
        if (plan instanceof LogicalJoin) {
            LogicalJoin<?, ?> join = (LogicalJoin<?, ?>) plan;
            return join.getOtherJoinConjuncts().isEmpty();
        }
        return true;
    }

    private boolean shouldGenerateAggregateByNdv(Plan plan, List<? extends NamedExpression> groupKeys) {
        Statistics stats = plan.getStats();
        if (stats == null) {
            stats = plan.accept(derive, new StatsDerive.DeriveContext());
        }
        if (stats.getRowCount() <= 0) {
            return false;
        }

        List<ColumnStatistic> lower = new ArrayList<>();
        List<ColumnStatistic> medium = new ArrayList<>();
        List<ColumnStatistic> high = new ArrayList<>();

        List<ColumnStatistic>[] cards = new List[] { lower, medium, high };

        for (NamedExpression key : groupKeys) {
            ColumnStatistic colStats = ExpressionEstimation.INSTANCE.estimate(key, stats);
            if (colStats.isUnKnown) {
                return false;
            }
            if (stats.getRowCount() * 0.9 <= colStats.ndv) {
                return false;
            }
            cards[groupByCardinality(colStats, stats.getRowCount())].add(colStats);
        }

        double lowerCartesian = 1.0;
        for (ColumnStatistic colStats : lower) {
            lowerCartesian = lowerCartesian * colStats.ndv;
        }

        // Same NDV heuristic as EagerAggRewriter#checkStats, but kept local because set-op
        // local distinct and eager aggregation have different optimization boundaries.
        double lowerUpper = Math.max(stats.getRowCount() / 20, 1);
        lowerUpper = Math.pow(lowerUpper, Math.max(lower.size() / 2, 1));

        if (high.isEmpty() && (lower.size() + medium.size()) <= 2) {
            return true;
        }

        if (high.isEmpty() && medium.isEmpty()) {
            if (lower.size() == 1 && lowerCartesian * 20 <= stats.getRowCount()) {
                return true;
            } else if (lower.size() == 2 && lowerCartesian * 7 <= stats.getRowCount()) {
                return true;
            } else if (lower.size() <= 3 && lowerCartesian * 20 <= stats.getRowCount()
                    && lowerCartesian < lowerUpper) {
                return true;
            } else {
                return false;
            }
        }

        if (high.size() >= 2 || medium.size() > 2 || (high.size() == 1 && !medium.isEmpty())) {
            return false;
        }

        double lowerCartesianLowerBound = stats.getRowCount() / LOWER_AGGREGATE_EFFECT_COEFFICIENT;
        if (high.size() + medium.size() == 1 && lower.size() <= 2
                && lowerCartesian <= lowerCartesianLowerBound) {
            return true;
        }

        return false;
    }

    private int groupByCardinality(ColumnStatistic colStats, double rowCount) {
        if (rowCount == 0 || colStats.ndv * MEDIUM_AGGREGATE_EFFECT_COEFFICIENT > rowCount) {
            return 2;
        } else if (colStats.ndv * MEDIUM_AGGREGATE_EFFECT_COEFFICIENT <= rowCount
                && colStats.ndv * LOW_AGGREGATE_EFFECT_COEFFICIENT > rowCount) {
            return 1;
        } else if (colStats.ndv * LOW_AGGREGATE_EFFECT_COEFFICIENT <= rowCount) {
            return 0;
        }
        return 2;
    }
}