CountLiteralRewrite.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.Lists;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Simplifies non-distinct {@code count} calls whose single argument is a literal:
 * <ul>
 *   <li>count(1) ==> count(*)</li>
 *   <li>count(null) ==> 0</li>
 * </ul>
 */
public class CountLiteralRewrite extends OneRewriteRuleFactory {
    /**
     * Builds the rule that matches a logical aggregate and rewrites the literal counts
     * found in its output expressions.
     *
     * <p>Outputs that are constant after the rewrite are moved into a project placed on top
     * of the aggregate; all remaining outputs stay inside the aggregate.
     */
    @Override
    public Rule build() {
        return logicalAggregate().then(
                agg -> {
                    List<NamedExpression> originalOutputs = agg.getOutputExpressions();
                    List<NamedExpression> rewrittenOutputs = Lists.newArrayListWithCapacity(originalOutputs.size());
                    boolean anyRewritten = rewriteCountLiteral(originalOutputs, rewrittenOutputs);
                    if (!anyRewritten) {
                        // nothing matched, keep the aggregate as it is
                        return agg;
                    }

                    // split the rewritten outputs: constants go to the project, the rest stay in the aggregate
                    List<NamedExpression> projections = Lists.newArrayListWithCapacity(rewrittenOutputs.size());
                    Builder<NamedExpression> remainingBuilder
                            = ImmutableList.builderWithExpectedSize(rewrittenOutputs.size());
                    for (NamedExpression output : rewrittenOutputs) {
                        if (!output.isConstant()) {
                            remainingBuilder.add(output);
                            continue;
                        }
                        projections.add(output);
                    }

                    List<NamedExpression> remaining = remainingBuilder.build();
                    if (remaining.isEmpty()) {
                        // every output turned into a constant (no group by key and no other
                        // agg func would be left in the aggregate), so don't rewrite
                        return null;
                    }

                    Plan rewrittenAgg = agg.withAggOutput(remaining);
                    if (projections.isEmpty()) {
                        return rewrittenAgg;
                    }

                    // some outputs became constants, expose them through a project, such as
                    // project(0 as count(null), k1)
                    // --Aggregate(k1, group by k1)
                    remaining.forEach(kept -> projections.add(kept.toSlot()));
                    Plan project = new LogicalProject<>(projections, rewrittenAgg);
                    return project;
                }
        ).toRule(RuleType.COUNT_LITERAL_REWRITE);
    }

    /**
     * Rewrites the literal counts inside every expression of {@code oldExprs} and appends the
     * results to {@code newExprs}, one entry per input expression and in the same order.
     *
     * @param oldExprs expressions to inspect
     * @param newExprs receives the (possibly rewritten) expressions
     * @return true if at least one aggregate function was replaced
     */
    private boolean rewriteCountLiteral(List<NamedExpression> oldExprs, List<NamedExpression> newExprs) {
        boolean anyReplaced = false;
        for (Expression original : oldExprs) {
            Set<AggregateFunction> aggFuncs = original.collect(AggregateFunction.class::isInstance);
            Map<Expression, Expression> substitutions = new HashMap<>();
            for (AggregateFunction candidate : aggFuncs) {
                if (!isCountLiteral(candidate)) {
                    continue;
                }
                substitutions.put(candidate, rewrite((Count) candidate));
            }
            if (!substitutions.isEmpty()) {
                anyReplaced = true;
            }
            // nodes without a recorded substitution are mapped to themselves
            Expression rewritten = original.rewriteUp(node -> substitutions.getOrDefault(node, node));
            newExprs.add((NamedExpression) rewritten);
        }
        return anyReplaced;
    }

    /**
     * Tells whether {@code aggFunc} is a non-distinct {@link Count} with exactly one child
     * that is a literal.
     */
    private boolean isCountLiteral(AggregateFunction aggFunc) {
        if (aggFunc.isDistinct()) {
            return false;
        }
        if (!(aggFunc instanceof Count)) {
            return false;
        }
        if (aggFunc.children().size() != 1) {
            return false;
        }
        return aggFunc.child(0).isLiteral();
    }

    /**
     * Produces the replacement for a literal count: the bigint literal 0 when the argument is
     * the null literal, otherwise a {@link Count} without arguments.
     */
    private Expression rewrite(Count count) {
        return count.child(0).isNullLiteral() ? new BigIntLiteral(0) : new Count();
    }
}