AdjustPreAggStatus.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.MaterializedIndexMeta;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnionAgg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PreAggStatus;
import org.apache.doris.nereids.trees.plans.algebra.Project;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * AdjustPreAggStatus
 */
@Developing
public class AdjustPreAggStatus implements RewriteRuleFactory {
    ///////////////////////////////////////////////////////////////////////////
    // All the patterns
    ///////////////////////////////////////////////////////////////////////////
    /**
     * Builds the rules that fill in the pre-aggregation status of a {@link LogicalOlapScan} whose status is
     * still unset.
     * <p>
     * Each rule matches one supported plan shape above such a scan: an aggregate over the scan, optionally
     * through a repeat, a project and/or a filter; a filter (optionally over a project) directly on the scan;
     * or the bare scan. {@link #checkKeysType} is consulted first and its result is kept unless it is unset.
     * Otherwise the aggregate functions, group-by expressions and filter conjuncts of the matched shape are
     * rewritten in terms of the scan's slots (through the project's alias-to-producer map whenever a project
     * sits between them and the scan) and handed to {@link #checkPreAggStatus}. Each rule returns the matched
     * plan with the scan replaced by a copy carrying the computed status.
     */
    @Override
    public List<Rule> buildRules() {
        return ImmutableList.of(
                // Aggregate(Scan)
                logicalAggregate(logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))
                        .thenApply(ctx -> {
                            LogicalAggregate<LogicalOlapScan> agg = ctx.root;
                            LogicalOlapScan scan = agg.child();
                            PreAggStatus preAggStatus = checkKeysType(scan);
                            if (preAggStatus == PreAggStatus.unset()) {
                                List<AggregateFunction> aggregateFunctions =
                                        extractAggFunctionAndReplaceSlot(agg, Optional.empty());
                                List<Expression> groupByExpressions = agg.getGroupByExpressions();
                                Set<Expression> predicates = ImmutableSet.of();
                                preAggStatus = checkPreAggStatus(scan, predicates,
                                        aggregateFunctions, groupByExpressions);
                            }
                            return agg.withChildren(scan.withPreAggStatus(preAggStatus));
                        }).toRule(RuleType.PREAGG_STATUS_AGG_SCAN),

                // Aggregate(Filter(Scan))
                logicalAggregate(
                        logicalFilter(logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)))
                                .thenApply(ctx -> {
                                    LogicalAggregate<LogicalFilter<LogicalOlapScan>> agg = ctx.root;
                                    LogicalFilter<LogicalOlapScan> filter = agg.child();
                                    LogicalOlapScan scan = filter.child();
                                    PreAggStatus preAggStatus = checkKeysType(scan);
                                    if (preAggStatus == PreAggStatus.unset()) {
                                        List<AggregateFunction> aggregateFunctions =
                                                extractAggFunctionAndReplaceSlot(agg, Optional.empty());
                                        List<Expression> groupByExpressions =
                                                agg.getGroupByExpressions();
                                        Set<Expression> predicates = filter.getConjuncts();
                                        preAggStatus = checkPreAggStatus(scan, predicates,
                                                aggregateFunctions, groupByExpressions);
                                    }
                                    return agg.withChildren(filter
                                            .withChildren(scan.withPreAggStatus(preAggStatus)));
                                }).toRule(RuleType.PREAGG_STATUS_AGG_FILTER_SCAN),

                // Aggregate(Project(Scan))
                logicalAggregate(logicalProject(
                        logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)))
                                .thenApply(ctx -> {
                                    LogicalAggregate<LogicalProject<LogicalOlapScan>> agg =
                                            ctx.root;
                                    LogicalProject<LogicalOlapScan> project = agg.child();
                                    LogicalOlapScan scan = project.child();
                                    PreAggStatus preAggStatus = checkKeysType(scan);
                                    if (preAggStatus == PreAggStatus.unset()) {
                                        List<AggregateFunction> aggregateFunctions =
                                                extractAggFunctionAndReplaceSlot(agg,
                                                        Optional.of(project));
                                        List<Expression> groupByExpressions =
                                                ExpressionUtils.replace(agg.getGroupByExpressions(),
                                                        project.getAliasToProducer());
                                        Set<Expression> predicates = ImmutableSet.of();
                                        preAggStatus = checkPreAggStatus(scan, predicates,
                                                aggregateFunctions, groupByExpressions);
                                    }
                                    return agg.withChildren(project
                                            .withChildren(scan.withPreAggStatus(preAggStatus)));
                                }).toRule(RuleType.PREAGG_STATUS_AGG_PROJECT_SCAN),

                // Aggregate(Project(Filter(Scan)))
                logicalAggregate(logicalProject(logicalFilter(
                        logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))))
                                .thenApply(ctx -> {
                                    LogicalAggregate<LogicalProject<LogicalFilter<LogicalOlapScan>>> agg = ctx.root;
                                    LogicalProject<LogicalFilter<LogicalOlapScan>> project = agg.child();
                                    LogicalFilter<LogicalOlapScan> filter = project.child();
                                    LogicalOlapScan scan = filter.child();
                                    PreAggStatus preAggStatus = checkKeysType(scan);
                                    if (preAggStatus == PreAggStatus.unset()) {
                                        List<AggregateFunction> aggregateFunctions =
                                                extractAggFunctionAndReplaceSlot(agg, Optional.of(project));
                                        List<Expression> groupByExpressions =
                                                ExpressionUtils.replace(agg.getGroupByExpressions(),
                                                        project.getAliasToProducer());
                                        // the filter is below the project, so its conjuncts already use scan slots
                                        Set<Expression> predicates = filter.getConjuncts();
                                        preAggStatus = checkPreAggStatus(scan, predicates,
                                                aggregateFunctions, groupByExpressions);
                                    }
                                    return agg.withChildren(project.withChildren(filter
                                            .withChildren(scan.withPreAggStatus(preAggStatus))));
                                }).toRule(RuleType.PREAGG_STATUS_AGG_PROJECT_FILTER_SCAN),

                // Aggregate(Filter(Project(Scan)))
                logicalAggregate(logicalFilter(logicalProject(
                        logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))))
                                .thenApply(ctx -> {
                                    LogicalAggregate<LogicalFilter<LogicalProject<LogicalOlapScan>>> agg = ctx.root;
                                    LogicalFilter<LogicalProject<LogicalOlapScan>> filter =
                                            agg.child();
                                    LogicalProject<LogicalOlapScan> project = filter.child();
                                    LogicalOlapScan scan = project.child();
                                    PreAggStatus preAggStatus = checkKeysType(scan);
                                    if (preAggStatus == PreAggStatus.unset()) {
                                        List<AggregateFunction> aggregateFunctions =
                                                extractAggFunctionAndReplaceSlot(agg, Optional.of(project));
                                        List<Expression> groupByExpressions =
                                                ExpressionUtils.replace(agg.getGroupByExpressions(),
                                                        project.getAliasToProducer());
                                        // the filter is above the project, so resolve its conjuncts through it
                                        Set<Expression> predicates = ExpressionUtils.replace(
                                                filter.getConjuncts(), project.getAliasToProducer());
                                        preAggStatus = checkPreAggStatus(scan, predicates,
                                                aggregateFunctions, groupByExpressions);
                                    }
                                    return agg.withChildren(filter.withChildren(project
                                            .withChildren(scan.withPreAggStatus(preAggStatus))));
                                }).toRule(RuleType.PREAGG_STATUS_AGG_FILTER_PROJECT_SCAN),

                // Aggregate(Repeat(Scan))
                logicalAggregate(
                        logicalRepeat(logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)))
                                .thenApply(ctx -> {
                                    LogicalAggregate<LogicalRepeat<LogicalOlapScan>> agg = ctx.root;
                                    LogicalRepeat<LogicalOlapScan> repeat = agg.child();
                                    LogicalOlapScan scan = repeat.child();
                                    PreAggStatus preAggStatus = checkKeysType(scan);
                                    if (preAggStatus == PreAggStatus.unset()) {
                                        List<AggregateFunction> aggregateFunctions =
                                                extractAggFunctionAndReplaceSlot(agg, Optional.empty());
                                        List<Expression> groupByExpressions = nonVirtualGroupByExprs(agg);
                                        Set<Expression> predicates = ImmutableSet.of();
                                        preAggStatus = checkPreAggStatus(scan, predicates,
                                                aggregateFunctions, groupByExpressions);
                                    }
                                    return agg.withChildren(repeat
                                            .withChildren(scan.withPreAggStatus(preAggStatus)));
                                }).toRule(RuleType.PREAGG_STATUS_AGG_REPEAT_SCAN),

                // Aggregate(Repeat(Filter(Scan)))
                logicalAggregate(logicalRepeat(logicalFilter(
                        logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))))
                                .thenApply(ctx -> {
                                    LogicalAggregate<LogicalRepeat<LogicalFilter<LogicalOlapScan>>> agg = ctx.root;
                                    LogicalRepeat<LogicalFilter<LogicalOlapScan>> repeat = agg.child();
                                    LogicalFilter<LogicalOlapScan> filter = repeat.child();
                                    LogicalOlapScan scan = filter.child();
                                    PreAggStatus preAggStatus = checkKeysType(scan);
                                    if (preAggStatus == PreAggStatus.unset()) {
                                        List<AggregateFunction> aggregateFunctions =
                                                extractAggFunctionAndReplaceSlot(agg, Optional.empty());
                                        List<Expression> groupByExpressions =
                                                nonVirtualGroupByExprs(agg);
                                        Set<Expression> predicates = filter.getConjuncts();
                                        preAggStatus = checkPreAggStatus(scan, predicates,
                                                aggregateFunctions, groupByExpressions);
                                    }
                                    return agg.withChildren(repeat.withChildren(filter
                                            .withChildren(scan.withPreAggStatus(preAggStatus))));
                                }).toRule(RuleType.PREAGG_STATUS_AGG_REPEAT_FILTER_SCAN),

                // Aggregate(Repeat(Project(Scan)))
                logicalAggregate(logicalRepeat(logicalProject(
                        logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))))
                                .thenApply(ctx -> {
                                    LogicalAggregate<LogicalRepeat<LogicalProject<LogicalOlapScan>>> agg = ctx.root;
                                    LogicalRepeat<LogicalProject<LogicalOlapScan>> repeat = agg.child();
                                    LogicalProject<LogicalOlapScan> project = repeat.child();
                                    LogicalOlapScan scan = project.child();
                                    PreAggStatus preAggStatus = checkKeysType(scan);
                                    if (preAggStatus == PreAggStatus.unset()) {
                                        // Resolve aggregate arguments through the project, consistent with the
                                        // group-by expressions below and with the other project-based rules.
                                        List<AggregateFunction> aggregateFunctions =
                                                extractAggFunctionAndReplaceSlot(agg, Optional.of(project));
                                        List<Expression> groupByExpressions =
                                                ExpressionUtils.replace(nonVirtualGroupByExprs(agg),
                                                        project.getAliasToProducer());
                                        Set<Expression> predicates = ImmutableSet.of();
                                        preAggStatus = checkPreAggStatus(scan, predicates,
                                                aggregateFunctions, groupByExpressions);
                                    }
                                    return agg.withChildren(repeat.withChildren(project
                                            .withChildren(scan.withPreAggStatus(preAggStatus))));
                                }).toRule(RuleType.PREAGG_STATUS_AGG_REPEAT_PROJECT_SCAN),

                // Aggregate(Repeat(Project(Filter(Scan))))
                logicalAggregate(logicalRepeat(logicalProject(logicalFilter(
                        logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)))))
                                .thenApply(ctx -> {
                                    LogicalAggregate<LogicalRepeat<LogicalProject<LogicalFilter<LogicalOlapScan>>>> agg
                                            = ctx.root;
                                    LogicalRepeat<LogicalProject<LogicalFilter<LogicalOlapScan>>> repeat = agg.child();
                                    LogicalProject<LogicalFilter<LogicalOlapScan>> project = repeat.child();
                                    LogicalFilter<LogicalOlapScan> filter = project.child();
                                    LogicalOlapScan scan = filter.child();
                                    PreAggStatus preAggStatus = checkKeysType(scan);
                                    if (preAggStatus == PreAggStatus.unset()) {
                                        // Resolve aggregate arguments through the project, consistent with the
                                        // group-by expressions below and with the other project-based rules.
                                        List<AggregateFunction> aggregateFunctions =
                                                extractAggFunctionAndReplaceSlot(agg, Optional.of(project));
                                        List<Expression> groupByExpressions =
                                                ExpressionUtils.replace(nonVirtualGroupByExprs(agg),
                                                        project.getAliasToProducer());
                                        Set<Expression> predicates = filter.getConjuncts();
                                        preAggStatus = checkPreAggStatus(scan, predicates,
                                                aggregateFunctions, groupByExpressions);
                                    }
                                    return agg.withChildren(repeat
                                            .withChildren(project.withChildren(filter.withChildren(
                                                    scan.withPreAggStatus(preAggStatus)))));
                                }).toRule(RuleType.PREAGG_STATUS_AGG_REPEAT_PROJECT_FILTER_SCAN),

                // Aggregate(Repeat(Filter(Project(Scan))))
                logicalAggregate(logicalRepeat(logicalFilter(logicalProject(
                        logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)))))
                                .thenApply(ctx -> {
                                    LogicalAggregate<LogicalRepeat<LogicalFilter<LogicalProject<LogicalOlapScan>>>> agg
                                            = ctx.root;
                                    LogicalRepeat<LogicalFilter<LogicalProject<LogicalOlapScan>>> repeat = agg.child();
                                    LogicalFilter<LogicalProject<LogicalOlapScan>> filter = repeat.child();
                                    LogicalProject<LogicalOlapScan> project = filter.child();
                                    LogicalOlapScan scan = project.child();
                                    PreAggStatus preAggStatus = checkKeysType(scan);
                                    if (preAggStatus == PreAggStatus.unset()) {
                                        List<AggregateFunction> aggregateFunctions =
                                                extractAggFunctionAndReplaceSlot(agg, Optional.of(project));
                                        List<Expression> groupByExpressions =
                                                ExpressionUtils.replace(nonVirtualGroupByExprs(agg),
                                                        project.getAliasToProducer());
                                        Set<Expression> predicates = ExpressionUtils.replace(
                                                filter.getConjuncts(), project.getAliasToProducer());
                                        preAggStatus = checkPreAggStatus(scan, predicates,
                                                aggregateFunctions, groupByExpressions);
                                    }
                                    return agg.withChildren(repeat
                                            .withChildren(filter.withChildren(project.withChildren(
                                                    scan.withPreAggStatus(preAggStatus)))));
                                }).toRule(RuleType.PREAGG_STATUS_AGG_REPEAT_FILTER_PROJECT_SCAN),

                // Filter(Project(Scan))
                logicalFilter(logicalProject(
                        logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)))
                                .thenApply(ctx -> {
                                    LogicalFilter<LogicalProject<LogicalOlapScan>> filter = ctx.root;
                                    LogicalProject<LogicalOlapScan> project = filter.child();
                                    LogicalOlapScan scan = project.child();
                                    PreAggStatus preAggStatus = checkKeysType(scan);
                                    if (preAggStatus == PreAggStatus.unset()) {
                                        List<AggregateFunction> aggregateFunctions = ImmutableList.of();
                                        List<Expression> groupByExpressions = ImmutableList.of();
                                        Set<Expression> predicates = ExpressionUtils.replace(
                                                filter.getConjuncts(), project.getAliasToProducer());
                                        preAggStatus = checkPreAggStatus(scan, predicates,
                                                aggregateFunctions, groupByExpressions);
                                    }
                                    return filter.withChildren(project
                                            .withChildren(scan.withPreAggStatus(preAggStatus)));
                                }).toRule(RuleType.PREAGG_STATUS_FILTER_PROJECT_SCAN),

                // Filter(Scan)
                logicalFilter(logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))
                        .thenApply(ctx -> {
                            LogicalFilter<LogicalOlapScan> filter = ctx.root;
                            LogicalOlapScan scan = filter.child();
                            PreAggStatus preAggStatus = checkKeysType(scan);
                            if (preAggStatus == PreAggStatus.unset()) {
                                List<AggregateFunction> aggregateFunctions = ImmutableList.of();
                                List<Expression> groupByExpressions = ImmutableList.of();
                                Set<Expression> predicates = filter.getConjuncts();
                                preAggStatus = checkPreAggStatus(scan, predicates,
                                        aggregateFunctions, groupByExpressions);
                            }
                            return filter.withChildren(scan.withPreAggStatus(preAggStatus));
                        }).toRule(RuleType.PREAGG_STATUS_FILTER_SCAN),

                // only scan.
                logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)
                        .thenApply(ctx -> {
                            LogicalOlapScan scan = ctx.root;
                            PreAggStatus preAggStatus = checkKeysType(scan);
                            if (preAggStatus == PreAggStatus.unset()) {
                                List<AggregateFunction> aggregateFunctions = ImmutableList.of();
                                List<Expression> groupByExpressions = ImmutableList.of();
                                Set<Expression> predicates = ImmutableSet.of();
                                preAggStatus = checkPreAggStatus(scan, predicates,
                                        aggregateFunctions, groupByExpressions);
                            }
                            return scan.withPreAggStatus(preAggStatus);
                        }).toRule(RuleType.PREAGG_STATUS_SCAN));
    }

    ///////////////////////////////////////////////////////////////////////////
    // Set pre-aggregation status.
    ///////////////////////////////////////////////////////////////////////////

    /**
     * Gathers every aggregate function appearing in the aggregate's output expressions and, if a project is
     * given, rewrites each function's input slots using that project's alias-to-producer map.
     * <p>
     * Example:
     * <pre>
     * agg:     Aggregate(sum(v) as sum_value)
     * project: Project(a + b as v)
     * result:  [sum(a + b)]
     * </pre>
     *
     * @param agg the aggregate whose output expressions are searched for aggregate functions
     * @param project the project beneath the aggregate; when empty the functions are returned unchanged
     * @return the aggregate functions in encounter order, rewritten through the project when one is present
     */
    private List<AggregateFunction> extractAggFunctionAndReplaceSlot(LogicalAggregate<?> agg,
            Optional<LogicalProject<?>> project) {
        Optional<Map<Slot, Expression>> aliasToProducer = project.map(Project::getAliasToProducer);
        return agg.getOutputExpressions().stream()
                .map(output -> output.<AggregateFunction>collect(AggregateFunction.class::isInstance))
                .flatMap(found -> found.stream())
                .map(aggFunction -> aliasToProducer.isPresent()
                        ? (AggregateFunction) ExpressionUtils.replace(aggFunction, aliasToProducer.get())
                        : aggFunction)
                .collect(Collectors.toList());
    }

    /**
     * Decides pre-aggregation from the keys type of the scan's selected index alone.
     *
     * @param olapScan the scan whose selected index metadata is inspected
     * @return {@code PreAggStatus.on()} for a duplicate-key index, or for a unique-key index when the table has
     *         merge-on-write enabled; {@code PreAggStatus.unset()} for every other keys type
     */
    private PreAggStatus checkKeysType(LogicalOlapScan olapScan) {
        MaterializedIndexMeta indexMeta =
                olapScan.getTable().getIndexMetaByIndexId(olapScan.getSelectedIndexId());
        KeysType keysType = indexMeta.getKeysType();
        if (keysType == KeysType.DUP_KEYS) {
            return PreAggStatus.on();
        }
        boolean mergeOnWriteUniqueKeys = keysType == KeysType.UNIQUE_KEYS
                && olapScan.getTable().getEnableUniqueKeyMergeOnWrite();
        return mergeOnWriteUniqueKeys ? PreAggStatus.on() : PreAggStatus.unset();
    }

    /**
     * Decides whether pre-aggregation can be on for a scan whose keys type did not already decide it.
     * <p>
     * Pre-aggregation is turned off when a grouping expression or a predicate reads a non-key column of the
     * scan; otherwise the aggregate functions decide via {@link #checkAggregateFunctions}.
     *
     * @param olapScan the scan whose output slots are classified into key and value slots
     * @param predicates filter conjuncts, expressed over the scan's slots
     * @param aggregateFuncs aggregate functions, expressed over the scan's slots
     * @param groupingExprs group-by expressions, expressed over the scan's slots
     * @return the pre-aggregation status; when off because of non-key columns, the reason lists those columns
     */
    private PreAggStatus checkPreAggStatus(LogicalOlapScan olapScan, Set<Expression> predicates,
            List<AggregateFunction> aggregateFuncs, List<Expression> groupingExprs) {
        Set<Slot> outputSlots = olapScan.getOutputSet();
        Pair<Set<SlotReference>, Set<SlotReference>> splittedSlots = splitSlots(outputSlots);
        Set<SlotReference> keySlots = splittedSlots.first;
        Set<SlotReference> valueSlots = splittedSlots.second;
        // every output slot must be a column-backed slot that splitSlots classified as key or value
        Preconditions.checkState(outputSlots.size() == keySlots.size() + valueSlots.size(),
                "output slots contains no key or value slots");

        Set<Slot> groupingExprsInputSlots = ExpressionUtils.getInputSlotSet(groupingExprs);
        // Compute the offending slots without mutating the input-slot set: retainAll() would leave only the
        // KEY slots behind, so the message would name the wrong columns, and the set is reused below.
        Set<Slot> nonKeyGroupingSlots = Sets.difference(groupingExprsInputSlots, keySlots).immutableCopy();
        if (!nonKeyGroupingSlots.isEmpty()) {
            return PreAggStatus
                    .off(String.format("Grouping expression %s contains non-key column %s",
                            groupingExprs, nonKeyGroupingSlots));
        }

        Set<Slot> predicateInputSlots = ExpressionUtils.getInputSlotSet(predicates);
        Set<Slot> nonKeyPredicateSlots = Sets.difference(predicateInputSlots, keySlots).immutableCopy();
        if (!nonKeyPredicateSlots.isEmpty()) {
            return PreAggStatus.off(String.format("Predicate %s contains non-key column %s",
                    predicates, nonKeyPredicateSlots));
        }

        return checkAggregateFunctions(aggregateFuncs, groupingExprsInputSlots);
    }

    /**
     * Partitions slots into key-column slots and value-column slots.
     * <p>
     * Only {@link SlotReference}s that carry a column are classified; any other slot appears in neither set.
     *
     * @param slots the slots to classify
     * @return a pair of (key-column slots, value-column slots)
     */
    private Pair<Set<SlotReference>, Set<SlotReference>> splitSlots(Set<Slot> slots) {
        int expectedSize = slots.size();
        Set<SlotReference> keys = Sets.newHashSetWithExpectedSize(expectedSize);
        Set<SlotReference> values = Sets.newHashSetWithExpectedSize(expectedSize);
        for (Slot candidate : slots) {
            if (!(candidate instanceof SlotReference)) {
                continue;
            }
            SlotReference slotRef = (SlotReference) candidate;
            if (!slotRef.getColumn().isPresent()) {
                continue;
            }
            Set<SlotReference> target = slotRef.getColumn().get().isKey() ? keys : values;
            target.add(slotRef);
        }
        return Pair.of(keys, values);
    }

    /**
     * Strips any number of nested {@link Cast} wrappers from an expression.
     *
     * @param expression the expression to unwrap
     * @return the expression itself if it is not a {@code Cast}, otherwise its first non-{@code Cast} descendant
     *         along the chain of cast children
     */
    private static Expression removeCast(Expression expression) {
        return expression instanceof Cast ? removeCast(((Cast) expression).child()) : expression;
    }

    /**
     * Checks an aggregate function whose argument is not a single bare slot.
     * <p>
     * Outer casts on the first argument are peeled off; pre-aggregation is turned off if any of them is not a
     * numeric cast. An {@code IF} or {@code CASE WHEN} argument is then split into its condition expressions
     * and its result branches (a missing {@code ELSE} counts as a {@code NULL} literal); any other argument is
     * treated as a single result. Conditions must not read value columns. The cast-stripped result expressions
     * are then handed to {@code KeyAndValueSlotsAggChecker}.
     *
     * @param aggFunc the aggregate function; only its first child is inspected
     * @param keySlots key-column slots read by the function (not consulted by this method)
     * @param valueSlots value-column slots read by the function
     * @return the pre-aggregation status for this function
     */
    private PreAggStatus checkAggWithKeyAndValueSlots(AggregateFunction aggFunc,
            Set<SlotReference> keySlots, Set<SlotReference> valueSlots) {
        // peel outer casts; only numeric casts are tolerated on top of the argument
        Expression argument = aggFunc.child(0);
        while (argument instanceof Cast) {
            if (!((Cast) argument).getDataType().isNumericType()) {
                return PreAggStatus.off(String.format("%s is not numeric CAST.", argument.toSql()));
            }
            argument = argument.child(0);
        }

        List<Expression> conditions = new ArrayList<>();
        List<Expression> branchResults = new ArrayList<>();
        if (argument instanceof CaseWhen) {
            CaseWhen caseWhen = (CaseWhen) argument;
            for (WhenClause clause : caseWhen.getWhenClauses()) {
                conditions.add(clause.getOperand());
                branchResults.add(removeCast(clause.getResult()));
            }
            Expression elseValue = caseWhen.getDefaultValue().orElse(new NullLiteral());
            branchResults.add(removeCast(elseValue));
        } else if (argument instanceof If) {
            conditions.add(argument.child(0));
            branchResults.add(removeCast(argument.child(1)));
            branchResults.add(removeCast(argument.child(2)));
        } else {
            // only IF and CASE WHEN are decomposed; anything else is checked as one result expression
            branchResults.add(removeCast(argument));
        }

        Set<Slot> valueSlotsInConditions = ExpressionUtils.getInputSlotSet(conditions);
        valueSlotsInConditions.retainAll(valueSlots);
        if (!valueSlotsInConditions.isEmpty()) {
            return PreAggStatus
                    .off(String.format("some columns in condition %s is not key.", conditions));
        }
        return KeyAndValueSlotsAggChecker.INSTANCE.check(aggFunc, branchResults);
    }

    /**
     * Decides the pre-aggregation status from the aggregate functions.
     * <p>
     * With neither aggregate functions nor grouping input slots there is nothing to aggregate, so the status
     * is off. With grouping slots but no functions the status is on. Otherwise each function is checked in
     * order: the first "off" result is returned, and if none is off the last function's result is returned.
     *
     * @param aggregateFuncs aggregate functions expressed over the scan's slots
     * @param groupingExprsInputSlots input slots of the group-by expressions (only emptiness is inspected)
     * @return the resulting pre-aggregation status
     */
    private PreAggStatus checkAggregateFunctions(List<AggregateFunction> aggregateFuncs,
            Set<Slot> groupingExprsInputSlots) {
        if (aggregateFuncs.isEmpty()) {
            return groupingExprsInputSlots.isEmpty()
                    ? PreAggStatus.off("No aggregate on scan.")
                    : PreAggStatus.on();
        }
        PreAggStatus lastStatus = null;
        for (AggregateFunction aggFunc : aggregateFuncs) {
            lastStatus = checkSingleAggregateFunction(aggFunc);
            if (lastStatus.isOff()) {
                return lastStatus;
            }
        }
        return lastStatus;
    }

    /** Checks one aggregate function against the key/value classification of the slots it reads. */
    private PreAggStatus checkSingleAggregateFunction(AggregateFunction aggFunc) {
        if (aggFunc.children().isEmpty()) {
            return PreAggStatus.off(
                    String.format("can't turn preAgg on for aggregate function %s", aggFunc));
        }
        if (aggFunc.children().size() != 1 || !(aggFunc.child(0) instanceof Slot)) {
            // several arguments or a composite argument: classify every slot it reads
            Pair<Set<SlotReference>, Set<SlotReference>> argumentSlots = splitSlots(aggFunc.getInputSlots());
            return checkAggWithKeyAndValueSlots(aggFunc, argumentSlots.first, argumentSlots.second);
        }
        Slot aggSlot = (Slot) aggFunc.child(0);
        if (!(aggSlot instanceof SlotReference) || !((SlotReference) aggSlot).getColumn().isPresent()) {
            return PreAggStatus.off(
                    String.format("aggregate function %s use unknown slot %s from scan", aggFunc, aggSlot));
        }
        SlotReference slotRef = (SlotReference) aggSlot;
        return slotRef.getColumn().get().isKey()
                ? OneKeySlotAggChecker.INSTANCE.check(aggFunc)
                : OneValueSlotAggChecker.INSTANCE.check(aggFunc,
                        slotRef.getColumn().get().getAggregationType());
    }

    /**
     * Returns the aggregate's group-by expressions, excluding {@link VirtualSlotReference}s,
     * as an immutable list that preserves the original order.
     */
    private List<Expression> nonVirtualGroupByExprs(LogicalAggregate<? extends Plan> agg) {
        ImmutableList.Builder<Expression> realGroupBys = ImmutableList.builder();
        for (Expression groupByExpr : agg.getGroupByExpressions()) {
            if (groupByExpr instanceof VirtualSlotReference) {
                continue;
            }
            realGroupBys.add(groupByExpr);
        }
        return realGroupBys.build();
    }

    /**
     * Pre-aggregation checker for an aggregate function whose single argument is a value
     * (non-key) column. The visitor context is that column's aggregation type; pre-aggregation
     * stays on only when the function matches that type.
     */
    private static class OneValueSlotAggChecker
            extends ExpressionVisitor<PreAggStatus, AggregateType> {
        public static final OneValueSlotAggChecker INSTANCE = new OneValueSlotAggChecker();

        public PreAggStatus check(AggregateFunction aggFun, AggregateType aggregateType) {
            return aggFun.accept(INSTANCE, aggregateType);
        }

        @Override
        public PreAggStatus visit(Expression expr, AggregateType aggregateType) {
            return PreAggStatus.off(String.format("%s is not aggregate function.", expr.toSql()));
        }

        @Override
        public PreAggStatus visitAggregateFunction(AggregateFunction aggregateFunction,
                AggregateType aggregateType) {
            return PreAggStatus
                    .off(String.format("%s is not supported.", aggregateFunction.toSql()));
        }

        @Override
        public PreAggStatus visitMax(Max max, AggregateType aggregateType) {
            return requireModeWithoutDistinct(max, AggregateType.MAX, aggregateType);
        }

        @Override
        public PreAggStatus visitMin(Min min, AggregateType aggregateType) {
            return requireModeWithoutDistinct(min, AggregateType.MIN, aggregateType);
        }

        @Override
        public PreAggStatus visitSum(Sum sum, AggregateType aggregateType) {
            return requireModeWithoutDistinct(sum, AggregateType.SUM, aggregateType);
        }

        @Override
        public PreAggStatus visitBitmapUnionCount(BitmapUnionCount bitmapUnionCount,
                AggregateType aggregateType) {
            return requireMode(bitmapUnionCount, AggregateType.BITMAP_UNION, aggregateType,
                    "invalid bitmap_union_count: ");
        }

        @Override
        public PreAggStatus visitBitmapUnion(BitmapUnion bitmapUnion, AggregateType aggregateType) {
            return requireMode(bitmapUnion, AggregateType.BITMAP_UNION, aggregateType,
                    "invalid bitmapUnion: ");
        }

        @Override
        public PreAggStatus visitHllUnionAgg(HllUnionAgg hllUnionAgg, AggregateType aggregateType) {
            return requireMode(hllUnionAgg, AggregateType.HLL_UNION, aggregateType,
                    "invalid hllUnionAgg: ");
        }

        @Override
        public PreAggStatus visitHllUnion(HllUnion hllUnion, AggregateType aggregateType) {
            return requireMode(hllUnion, AggregateType.HLL_UNION, aggregateType,
                    "invalid hllUnion: ");
        }

        /** ON iff the column's aggregation type is {@code required} and the call is not DISTINCT. */
        private static PreAggStatus requireModeWithoutDistinct(AggregateFunction function,
                AggregateType required, AggregateType actual) {
            if (actual != required || function.isDistinct()) {
                return PreAggStatus
                        .off(String.format("%s is not match agg mode %s or has distinct param",
                                function.toSql(), actual));
            }
            return PreAggStatus.on();
        }

        /** ON iff the column's aggregation type is {@code required}; otherwise OFF with the prefix. */
        private static PreAggStatus requireMode(AggregateFunction function,
                AggregateType required, AggregateType actual, String failurePrefix) {
            if (actual != required) {
                return PreAggStatus.off(failurePrefix + function.toSql());
            }
            return PreAggStatus.on();
        }
    }

    /**
     * Pre-aggregation checker for an aggregate function whose single argument is a key column.
     * Only MAX, MIN and COUNT(DISTINCT ...) are accepted: their results do not change when the
     * same key value appears in several not-yet-merged rows.
     */
    private static class OneKeySlotAggChecker extends ExpressionVisitor<PreAggStatus, Void> {
        public static final OneKeySlotAggChecker INSTANCE = new OneKeySlotAggChecker();

        public PreAggStatus check(AggregateFunction aggFun) {
            return aggFun.accept(INSTANCE, null);
        }

        @Override
        public PreAggStatus visit(Expression expr, Void context) {
            return PreAggStatus.off(String.format("%s is not aggregate function.", expr.toSql()));
        }

        @Override
        public PreAggStatus visitAggregateFunction(AggregateFunction aggregateFunction,
                Void context) {
            // callers only route single-argument functions here, so child(0) is the key column
            String functionSql = aggregateFunction.toSql();
            String keySql = aggregateFunction.child(0).toSql();
            return PreAggStatus.off(String.format("Aggregate function %s contains key column %s",
                    functionSql, keySql));
        }

        @Override
        public PreAggStatus visitMax(Max max, Void context) {
            return PreAggStatus.on();
        }

        @Override
        public PreAggStatus visitMin(Min min, Void context) {
            return PreAggStatus.on();
        }

        @Override
        public PreAggStatus visitCount(Count count, Void context) {
            return count.isDistinct()
                    ? PreAggStatus.on()
                    : PreAggStatus.off(String.format("%s is not distinct.", count.toSql()));
        }
    }

    private static class KeyAndValueSlotsAggChecker
            extends ExpressionVisitor<PreAggStatus, List<Expression>> {
        public static final KeyAndValueSlotsAggChecker INSTANCE = new KeyAndValueSlotsAggChecker();

        public PreAggStatus check(AggregateFunction aggFun, List<Expression> returnValues) {
            return aggFun.accept(INSTANCE, returnValues);
        }

        @Override
        public PreAggStatus visit(Expression expr, List<Expression> returnValues) {
            return PreAggStatus.off(String.format("%s is not aggregate function.", expr.toSql()));
        }

        @Override
        public PreAggStatus visitAggregateFunction(AggregateFunction aggregateFunction,
                List<Expression> returnValues) {
            return PreAggStatus
                    .off(String.format("%s is not supported.", aggregateFunction.toSql()));
        }

        @Override
        public PreAggStatus visitSum(Sum sum, List<Expression> returnValues) {
            for (Expression value : returnValues) {
                if (!(isAggTypeMatched(value, AggregateType.SUM) || value.isZeroLiteral()
                        || value.isNullLiteral())) {
                    return PreAggStatus.off(String.format("%s is not supported.", sum.toSql()));
                }
            }
            return PreAggStatus.on();
        }

        @Override
        public PreAggStatus visitMax(Max max, List<Expression> returnValues) {
            for (Expression value : returnValues) {
                if (!(isAggTypeMatched(value, AggregateType.MAX) || isKeySlot(value)
                        || value.isNullLiteral())) {
                    return PreAggStatus.off(String.format("%s is not supported.", max.toSql()));
                }
            }
            return PreAggStatus.on();
        }

        @Override
        public PreAggStatus visitMin(Min min, List<Expression> returnValues) {
            for (Expression value : returnValues) {
                if (!(isAggTypeMatched(value, AggregateType.MIN) || isKeySlot(value)
                        || value.isNullLiteral())) {
                    return PreAggStatus.off(String.format("%s is not supported.", min.toSql()));
                }
            }
            return PreAggStatus.on();
        }

        @Override
        public PreAggStatus visitCount(Count count, List<Expression> returnValues) {
            if (count.isDistinct()) {
                for (Expression value : returnValues) {
                    if (!(isKeySlot(value) || value.isZeroLiteral() || value.isNullLiteral())) {
                        return PreAggStatus
                                .off(String.format("%s is not supported.", count.toSql()));
                    }
                }
                return PreAggStatus.on();
            } else {
                return PreAggStatus.off(String.format("%s is not supported.", count.toSql()));
            }
        }

        private boolean isKeySlot(Expression expression) {
            return expression instanceof SlotReference
                    && ((SlotReference) expression).getColumn().isPresent()
                    && ((SlotReference) expression).getColumn().get().isKey();
        }

        private boolean isAggTypeMatched(Expression expression, AggregateType aggregateType) {
            return expression instanceof SlotReference
                    && ((SlotReference) expression).getColumn().isPresent()
                    && ((SlotReference) expression).getColumn().get()
                            .getAggregationType() == aggregateType;
        }
    }
}