SelectMaterializedIndexWithAggregate.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite.mv;

import org.apache.doris.analysis.CreateMaterializedViewStmt;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.MaterializedIndex;
import org.apache.doris.catalog.MaterializedIndexMeta;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotNotFromChildren;
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnionAgg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.Ndv;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.functions.combinator.MergeCombinator;
import org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HllHash;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmap;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmapWithCheck;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PreAggStatus;
import org.apache.doris.nereids.trees.plans.algebra.Project;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.VarcharType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.planner.PlanNode;

import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.collect.Streams;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Selects the materialized index (i.e., a rollup or a materialized view) to scan
 * when an aggregate is present above the scan.
 * TODO: optimize queries whose aggregate is not directly on top of the scan,
 *   e.g., aggregate -> join -> scan, to use a materialized index.
 */
@Developing
public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterializedIndexRule
        implements RewriteRuleFactory {
    ///////////////////////////////////////////////////////////////////////////
    // All the patterns
    ///////////////////////////////////////////////////////////////////////////
    @Override
    public List<Rule> buildRules() {
        return ImmutableList.of(
                // only agg above scan
                // Aggregate(Scan)
                logicalAggregate(logicalOlapScan().when(this::shouldSelectIndexWithAgg)).thenApplyNoThrow(ctx -> {
                    if (ctx.connectContext.getSessionVariable().isEnableSyncMvCostBasedRewrite()) {
                        return ctx.root;
                    }
                    LogicalAggregate<LogicalOlapScan> agg = ctx.root;
                    LogicalOlapScan scan = agg.child();
                    SelectResult result = select(
                            scan,
                            agg.getInputSlots(),
                            ImmutableSet.of(),
                            extractAggFunctionAndReplaceSlot(agg, Optional.empty()),
                            agg.getGroupByExpressions(),
                            new HashSet<>(agg.getExpressions()));

                    LogicalOlapScan mvPlan = createLogicalOlapScan(scan, result);
                    SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan);

                    return new LogicalProject<>(
                        generateProjectsAlias(agg.getOutputs(), slotContext),
                            new ReplaceExpressions(slotContext).replace(
                                new LogicalAggregate<>(
                                    agg.getGroupByExpressions(),
                                    replaceAggOutput(
                                        agg, Optional.empty(), Optional.empty(), result.exprRewriteMap),
                                    agg.isNormalized(),
                                    agg.getSourceRepeat(),
                                    mvPlan
                                ), mvPlan));
                }).toRule(RuleType.MATERIALIZED_INDEX_AGG_SCAN),

                // filter could push down scan.
                // Aggregate(Filter(Scan))
                logicalAggregate(logicalFilter(logicalOlapScan().when(this::shouldSelectIndexWithAgg)))
                        .thenApplyNoThrow(ctx -> {
                            if (ctx.connectContext.getSessionVariable().isEnableSyncMvCostBasedRewrite()) {
                                return ctx.root;
                            }
                            LogicalAggregate<LogicalFilter<LogicalOlapScan>> agg = ctx.root;
                            LogicalFilter<LogicalOlapScan> filter = agg.child();
                            LogicalOlapScan scan = filter.child();
                            ImmutableSet<Slot> requiredSlots = ImmutableSet.<Slot>builder()
                                    .addAll(agg.getInputSlots())
                                    .addAll(filter.getInputSlots())
                                    .build();
                            ImmutableSet<Expression> requiredExpr = ImmutableSet.<Expression>builder()
                                    .addAll(agg.getExpressions())
                                    .addAll(filter.getExpressions())
                                    .build();

                            SelectResult result = select(
                                    scan,
                                    requiredSlots,
                                    filter.getConjuncts(),
                                    extractAggFunctionAndReplaceSlot(agg, Optional.empty()),
                                    agg.getGroupByExpressions(),
                                    requiredExpr
                            );

                            LogicalOlapScan mvPlan = createLogicalOlapScan(scan, result);
                            SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan, requiredExpr.stream()
                                    .map(e -> result.exprRewriteMap.replaceAgg(e)).collect(Collectors.toSet()),
                                    filter.getConjuncts());

                            return new LogicalProject<>(
                                generateProjectsAlias(agg.getOutputs(), slotContext),
                                    new ReplaceExpressions(slotContext).replace(
                                        new LogicalAggregate<>(
                                            agg.getGroupByExpressions(),
                                            replaceAggOutput(agg, Optional.empty(), Optional.empty(),
                                                    result.exprRewriteMap),
                                            agg.isNormalized(),
                                            agg.getSourceRepeat(),
                                            // Note that there is no need to replace slots in the filter,
                                            // because the slots to
                                            // replace are value columns, which shouldn't appear in filters.
                                            filter.withChildren(mvPlan)
                                        ), mvPlan));
                        }).toRule(RuleType.MATERIALIZED_INDEX_AGG_FILTER_SCAN),

                // column pruning or other projections such as alias, etc.
                // Aggregate(Project(Scan))
                logicalAggregate(logicalProject(logicalOlapScan().when(this::shouldSelectIndexWithAgg)))
                        .thenApplyNoThrow(ctx -> {
                            if (ctx.connectContext.getSessionVariable().isEnableSyncMvCostBasedRewrite()) {
                                return ctx.root;
                            }
                            LogicalAggregate<LogicalProject<LogicalOlapScan>> agg = ctx.root;
                            LogicalProject<LogicalOlapScan> project = agg.child();
                            LogicalOlapScan scan = project.child();
                            SelectResult result = select(
                                    scan,
                                    project.getInputSlots(),
                                    ImmutableSet.of(),
                                    extractAggFunctionAndReplaceSlot(agg,
                                            Optional.of(project)),
                                    ExpressionUtils.replace(agg.getGroupByExpressions(),
                                            project.getAliasToProducer()),
                                    collectRequireExprWithAggAndProject(agg.getExpressions(), Optional.of(project))
                            );

                            LogicalOlapScan mvPlan = createLogicalOlapScan(scan, result);
                            SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan);

                            List<NamedExpression> newProjectList = replaceOutput(project.getProjects(),
                                    result.exprRewriteMap.projectExprMap);
                            LogicalProject<LogicalOlapScan> newProject = new LogicalProject<>(
                                    generateNewOutputsWithMvOutputs(mvPlan, newProjectList),
                                    scan.withMaterializedIndexSelected(result.indexId));
                            return new LogicalProject<>(generateProjectsAlias(agg.getOutputs(), slotContext),
                                    new ReplaceExpressions(slotContext)
                                            .replace(
                                                    new LogicalAggregate<>(agg.getGroupByExpressions(),
                                                            replaceAggOutput(agg, Optional.of(project),
                                                                    Optional.of(newProject), result.exprRewriteMap),
                                                            agg.isNormalized(), agg.getSourceRepeat(), newProject),
                                                    mvPlan));
                        }).toRule(RuleType.MATERIALIZED_INDEX_AGG_PROJECT_SCAN),

                // filter could push down and project.
                // Aggregate(Project(Filter(Scan)))
                logicalAggregate(logicalProject(logicalFilter(logicalOlapScan()
                        .when(this::shouldSelectIndexWithAgg)))).thenApplyNoThrow(ctx -> {
                            if (ctx.connectContext.getSessionVariable().isEnableSyncMvCostBasedRewrite()) {
                                return ctx.root;
                            }
                            LogicalAggregate<LogicalProject<LogicalFilter<LogicalOlapScan>>> agg = ctx.root;
                            LogicalProject<LogicalFilter<LogicalOlapScan>> project = agg.child();
                            LogicalFilter<LogicalOlapScan> filter = project.child();
                            LogicalOlapScan scan = filter.child();
                            Set<Slot> requiredSlots = Stream.concat(
                                    project.getInputSlots().stream(), filter.getInputSlots().stream())
                                    .collect(Collectors.toSet());
                            ImmutableSet<Expression> requiredExpr = ImmutableSet.<Expression>builder()
                                    .addAll(collectRequireExprWithAggAndProject(
                                            agg.getExpressions(), Optional.of(project)))
                                    .addAll(filter.getExpressions())
                                    .build();
                            SelectResult result = select(
                                    scan,
                                    requiredSlots,
                                    filter.getConjuncts(),
                                    extractAggFunctionAndReplaceSlot(agg, Optional.of(project)),
                                    ExpressionUtils.replace(agg.getGroupByExpressions(),
                                            project.getAliasToProducer()),
                                    requiredExpr
                            );

                            LogicalOlapScan mvPlan = createLogicalOlapScan(scan, result);
                            SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan, requiredExpr.stream()
                                    .map(e -> result.exprRewriteMap.replaceAgg(e)).collect(Collectors.toSet()),
                                    filter.getConjuncts());
                            if (result.indexId == scan.getTable().getBaseIndexId()) {
                                LogicalOlapScan mvPlanWithoutAgg = SelectMaterializedIndexWithoutAggregate.select(scan,
                                        project::getInputSlots, filter::getConjuncts,
                                        Suppliers.memoize(() -> Utils.concatToSet(
                                                filter.getExpressions(), project.getExpressions()
                                        ))
                                );
                                SlotContext slotContextWithoutAgg = generateBaseScanExprToMvExpr(mvPlanWithoutAgg);

                                return agg.withChildren(new LogicalProject(
                                        generateProjectsAlias(project.getOutput(), slotContextWithoutAgg),
                                        new ReplaceExpressions(slotContextWithoutAgg).replace(
                                                project.withChildren(filter.withChildren(mvPlanWithoutAgg)),
                                                mvPlanWithoutAgg)));
                            }

                            List<NamedExpression> newProjectList = replaceOutput(project.getProjects(),
                                    result.exprRewriteMap.projectExprMap);
                            LogicalProject<Plan> newProject = new LogicalProject<>(
                                    generateNewOutputsWithMvOutputs(mvPlan, newProjectList),
                                    filter.withChildren(mvPlan));

                            return new LogicalProject<>(
                                generateProjectsAlias(agg.getOutputs(), slotContext),
                                    new ReplaceExpressions(slotContext).replace(
                                        new LogicalAggregate<>(
                                            agg.getGroupByExpressions(),
                                            replaceAggOutput(agg, Optional.of(project), Optional.of(newProject),
                                                    result.exprRewriteMap),
                                            agg.isNormalized(),
                                            agg.getSourceRepeat(),
                                            newProject
                                        ), mvPlan));
                        }).toRule(RuleType.MATERIALIZED_INDEX_AGG_PROJECT_FILTER_SCAN),

                // filter can't push down
                // Aggregate(Filter(Project(Scan)))
                logicalAggregate(logicalFilter(logicalProject(logicalOlapScan()
                        .when(this::shouldSelectIndexWithAgg)))).thenApplyNoThrow(ctx -> {
                            if (ctx.connectContext.getSessionVariable().isEnableSyncMvCostBasedRewrite()) {
                                return ctx.root;
                            }
                            LogicalAggregate<LogicalFilter<LogicalProject<LogicalOlapScan>>> agg = ctx.root;
                            LogicalFilter<LogicalProject<LogicalOlapScan>> filter = agg.child();
                            LogicalProject<LogicalOlapScan> project = filter.child();
                            LogicalOlapScan scan = project.child();
                            ImmutableSet<Expression> requiredExpr = ImmutableSet.<Expression>builder()
                                    .addAll(collectRequireExprWithAggAndProject(
                                            agg.getExpressions(), Optional.of(project)))
                                    .addAll(collectRequireExprWithAggAndProject(
                                            filter.getExpressions(), Optional.of(project)))
                                    .build();
                            SelectResult result = select(
                                    scan,
                                    project.getInputSlots(),
                                    filter.getConjuncts(),
                                    extractAggFunctionAndReplaceSlot(agg, Optional.of(project)),
                                    ExpressionUtils.replace(agg.getGroupByExpressions(),
                                            project.getAliasToProducer()),
                                    requiredExpr
                            );

                            LogicalOlapScan mvPlan = createLogicalOlapScan(scan, result);
                            SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan, requiredExpr.stream()
                                    .map(e -> result.exprRewriteMap.replaceAgg(e)).collect(Collectors.toSet()),
                                    filter.getConjuncts());

                            List<NamedExpression> newProjectList = replaceOutput(project.getProjects(),
                                    result.exprRewriteMap.projectExprMap);
                            LogicalProject<Plan> newProject = new LogicalProject<>(
                                    generateNewOutputsWithMvOutputs(mvPlan, newProjectList), mvPlan);

                            return new LogicalProject<>(
                                generateProjectsAlias(agg.getOutputs(), slotContext),
                                    new ReplaceExpressions(slotContext).replace(
                                        new LogicalAggregate<>(
                                            agg.getGroupByExpressions(),
                                            replaceAggOutput(agg, Optional.of(project), Optional.of(newProject),
                                                    result.exprRewriteMap),
                                            agg.isNormalized(),
                                            agg.getSourceRepeat(),
                                            filter.withChildren(newProject)
                                        ), mvPlan));
                        }).toRule(RuleType.MATERIALIZED_INDEX_AGG_FILTER_PROJECT_SCAN),

                // only agg and repeat above scan
                // Aggregate(Repeat(Scan))
                logicalAggregate(logicalRepeat(logicalOlapScan().when(this::shouldSelectIndexWithAgg)))
                        .thenApplyNoThrow(ctx -> {
                            if (ctx.connectContext.getSessionVariable().isEnableSyncMvCostBasedRewrite()) {
                                return ctx.root;
                            }
                            LogicalAggregate<LogicalRepeat<LogicalOlapScan>> agg = ctx.root;
                            LogicalRepeat<LogicalOlapScan> repeat = agg.child();
                            LogicalOlapScan scan = repeat.child();
                            SelectResult result = select(scan, agg.getInputSlots(), ImmutableSet.of(),
                                    extractAggFunctionAndReplaceSlot(agg, Optional.empty()),
                                    nonVirtualGroupByExprs(agg), new HashSet<>(agg.getExpressions()));

                            LogicalOlapScan mvPlan = createLogicalOlapScan(scan, result);
                            SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan);

                            return new LogicalProject<>(generateProjectsAlias(agg.getOutputs(), slotContext),
                                    new ReplaceExpressions(slotContext)
                                            .replace(
                                                    new LogicalAggregate<>(agg.getGroupByExpressions(),
                                                            replaceAggOutput(agg, Optional.empty(), Optional.empty(),
                                                                    result.exprRewriteMap),
                                                            agg.isNormalized(), agg.getSourceRepeat(),
                                                            repeat.withAggOutputAndChild(
                                                                    replaceOutput(repeat.getOutputs(),
                                                                            result.exprRewriteMap.projectExprMap),
                                                                    mvPlan)),
                                                    mvPlan));
                        }).toRule(RuleType.MATERIALIZED_INDEX_AGG_REPEAT_SCAN),

                // filter could push down scan.
                // Aggregate(Repeat(Filter(Scan)))
                logicalAggregate(logicalRepeat(logicalFilter(logicalOlapScan().when(this::shouldSelectIndexWithAgg))))
                        .thenApplyNoThrow(ctx -> {
                            if (ctx.connectContext.getSessionVariable().isEnableSyncMvCostBasedRewrite()) {
                                return ctx.root;
                            }
                            LogicalAggregate<LogicalRepeat<LogicalFilter<LogicalOlapScan>>> agg = ctx.root;
                            LogicalRepeat<LogicalFilter<LogicalOlapScan>> repeat = agg.child();
                            LogicalFilter<LogicalOlapScan> filter = repeat.child();
                            LogicalOlapScan scan = filter.child();
                            ImmutableSet<Slot> requiredSlots = ImmutableSet.<Slot>builder()
                                    .addAll(agg.getInputSlots())
                                    .addAll(filter.getInputSlots())
                                    .build();
                            ImmutableSet<Expression> requiredExpr = ImmutableSet.<Expression>builder()
                                    .addAll(agg.getExpressions())
                                    .addAll(filter.getExpressions())
                                    .build();

                            SelectResult result = select(
                                    scan,
                                    requiredSlots,
                                    filter.getConjuncts(),
                                    extractAggFunctionAndReplaceSlot(agg, Optional.empty()),
                                    nonVirtualGroupByExprs(agg),
                                    requiredExpr
                            );

                            LogicalOlapScan mvPlan = createLogicalOlapScan(scan, result);
                            SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan, requiredExpr.stream()
                                    .map(e -> result.exprRewriteMap.replaceAgg(e)).collect(Collectors.toSet()),
                                    filter.getConjuncts());

                            return new LogicalProject<>(generateProjectsAlias(agg.getOutputs(), slotContext),
                                    new ReplaceExpressions(slotContext).replace(new LogicalAggregate<>(
                                            agg.getGroupByExpressions(),
                                            replaceAggOutput(agg, Optional.empty(), Optional.empty(),
                                                    result.exprRewriteMap),
                                            agg.isNormalized(), agg.getSourceRepeat(),
                                            // Note that there is no need to replace slots in the filter,
                                            // because the slots to replace
                                            // are value columns, which shouldn't appear in filters.
                                            repeat.withAggOutputAndChild(
                                                    replaceOutput(repeat.getOutputs(),
                                                            result.exprRewriteMap.projectExprMap),
                                                    filter.withChildren(mvPlan))),
                                            mvPlan));
                        }).toRule(RuleType.MATERIALIZED_INDEX_AGG_REPEAT_FILTER_SCAN),

                // column pruning or other projections such as alias, etc.
                // Aggregate(Repeat(Project(Scan)))
                logicalAggregate(logicalRepeat(logicalProject(logicalOlapScan().when(this::shouldSelectIndexWithAgg))))
                        .thenApplyNoThrow(ctx -> {
                            if (ctx.connectContext.getSessionVariable().isEnableSyncMvCostBasedRewrite()) {
                                return ctx.root;
                            }
                            LogicalAggregate<LogicalRepeat<LogicalProject<LogicalOlapScan>>> agg = ctx.root;
                            LogicalRepeat<LogicalProject<LogicalOlapScan>> repeat = agg.child();
                            LogicalProject<LogicalOlapScan> project = repeat.child();
                            LogicalOlapScan scan = project.child();
                            SelectResult result = select(
                                    scan,
                                    project.getInputSlots(),
                                    ImmutableSet.of(),
                                    extractAggFunctionAndReplaceSlot(agg,
                                            Optional.of(project)),
                                    ExpressionUtils.replace(nonVirtualGroupByExprs(agg),
                                            project.getAliasToProducer()),
                                    collectRequireExprWithAggAndProject(agg.getExpressions(), Optional.of(project))
                            );

                            LogicalOlapScan mvPlan = createLogicalOlapScan(scan, result);
                            SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan);

                            List<NamedExpression> newProjectList = replaceOutput(project.getProjects(),
                                    result.exprRewriteMap.projectExprMap);
                            LogicalProject<LogicalOlapScan> newProject = new LogicalProject<>(
                                    generateNewOutputsWithMvOutputs(mvPlan, newProjectList), mvPlan);

                            return new LogicalProject<>(generateProjectsAlias(agg.getOutputs(), slotContext),
                                    new ReplaceExpressions(slotContext).replace(
                                            new LogicalAggregate<>(agg.getGroupByExpressions(),
                                                    replaceAggOutput(agg, Optional.of(project), Optional.of(newProject),
                                                            result.exprRewriteMap),
                                                    agg.isNormalized(), agg.getSourceRepeat(),
                                                    repeat.withAggOutputAndChild(replaceOutput(repeat.getOutputs(),
                                                            result.exprRewriteMap.projectExprMap), newProject)),
                                            mvPlan));
                        }).toRule(RuleType.MATERIALIZED_INDEX_AGG_REPEAT_PROJECT_SCAN),

                // filter could push down and project.
                // Aggregate(Repeat(Project(Filter(Scan))))
                logicalAggregate(logicalRepeat(logicalProject(logicalFilter(logicalOlapScan()
                        .when(this::shouldSelectIndexWithAgg))))).thenApplyNoThrow(ctx -> {
                            if (ctx.connectContext.getSessionVariable().isEnableSyncMvCostBasedRewrite()) {
                                return ctx.root;
                            }
                            LogicalAggregate<LogicalRepeat<LogicalProject
                                    <LogicalFilter<LogicalOlapScan>>>> agg = ctx.root;
                            LogicalRepeat<LogicalProject<LogicalFilter<LogicalOlapScan>>> repeat = agg.child();
                            LogicalProject<LogicalFilter<LogicalOlapScan>> project = repeat.child();
                            LogicalFilter<LogicalOlapScan> filter = project.child();
                            LogicalOlapScan scan = filter.child();
                            Set<Slot> requiredSlots = Stream.concat(
                                    project.getInputSlots().stream(), filter.getInputSlots().stream())
                                    .collect(Collectors.toSet());
                            ImmutableSet<Expression> requiredExpr = ImmutableSet.<Expression>builder()
                                    .addAll(collectRequireExprWithAggAndProject(
                                            agg.getExpressions(), Optional.of(project)))
                                    .addAll(filter.getExpressions())
                                    .build();
                            SelectResult result = select(
                                    scan,
                                    requiredSlots,
                                    filter.getConjuncts(),
                                    extractAggFunctionAndReplaceSlot(agg, Optional.of(project)),
                                    ExpressionUtils.replace(nonVirtualGroupByExprs(agg),
                                            project.getAliasToProducer()),
                                    requiredExpr
                            );

                            LogicalOlapScan mvPlan = createLogicalOlapScan(scan, result);
                            SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan, requiredExpr.stream()
                                    .map(e -> result.exprRewriteMap.replaceAgg(e)).collect(Collectors.toSet()),
                                    filter.getConjuncts());

                            List<NamedExpression> newProjectList = replaceOutput(project.getProjects(),
                                    result.exprRewriteMap.projectExprMap);
                            LogicalProject<Plan> newProject = new LogicalProject<>(
                                    generateNewOutputsWithMvOutputs(mvPlan, newProjectList),
                                    filter.withChildren(mvPlan));

                            return new LogicalProject<>(generateProjectsAlias(agg.getOutputs(), slotContext),
                                    new ReplaceExpressions(slotContext).replace(
                                            new LogicalAggregate<>(agg.getGroupByExpressions(),
                                                    replaceAggOutput(agg, Optional.of(project), Optional.of(newProject),
                                                            result.exprRewriteMap),
                                                    agg.isNormalized(), agg.getSourceRepeat(),
                                                    repeat.withAggOutputAndChild(replaceOutput(repeat.getOutputs(),
                                                            result.exprRewriteMap.projectExprMap), newProject)),
                                            mvPlan));
                        }).toRule(RuleType.MATERIALIZED_INDEX_AGG_REPEAT_PROJECT_FILTER_SCAN),

                // filter can't push down
                // Aggregate(Repeat(Filter(Project(Scan))))
                logicalAggregate(logicalRepeat(logicalFilter(logicalProject(logicalOlapScan()
                        .when(this::shouldSelectIndexWithAgg))))).thenApplyNoThrow(ctx -> {
                            if (ctx.connectContext.getSessionVariable().isEnableSyncMvCostBasedRewrite()) {
                                return ctx.root;
                            }
                            LogicalAggregate<LogicalRepeat<LogicalFilter
                                    <LogicalProject<LogicalOlapScan>>>> agg = ctx.root;
                            LogicalRepeat<LogicalFilter<LogicalProject<LogicalOlapScan>>> repeat = agg.child();
                            LogicalFilter<LogicalProject<LogicalOlapScan>> filter = repeat.child();
                            LogicalProject<LogicalOlapScan> project = filter.child();
                            LogicalOlapScan scan = project.child();
                            ImmutableSet<Expression> requiredExpr = ImmutableSet.<Expression>builder()
                                    .addAll(collectRequireExprWithAggAndProject(
                                            agg.getExpressions(), Optional.of(project)))
                                    .addAll(collectRequireExprWithAggAndProject(
                                            filter.getExpressions(), Optional.of(project)))
                                    .build();
                            SelectResult result = select(
                                    scan,
                                    project.getInputSlots(),
                                    filter.getConjuncts(),
                                    extractAggFunctionAndReplaceSlot(agg, Optional.of(project)),
                                    ExpressionUtils.replace(nonVirtualGroupByExprs(agg),
                                            project.getAliasToProducer()),
                                    requiredExpr
                            );

                            LogicalOlapScan mvPlan = createLogicalOlapScan(scan, result);
                            SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan, requiredExpr.stream()
                                    .map(e -> result.exprRewriteMap.replaceAgg(e)).collect(Collectors.toSet()),
                                    filter.getConjuncts());

                            List<NamedExpression> newProjectList = replaceOutput(project.getProjects(),
                                    result.exprRewriteMap.projectExprMap);
                            LogicalProject<Plan> newProject = new LogicalProject<>(
                                    generateNewOutputsWithMvOutputs(mvPlan, newProjectList),
                                    scan.withMaterializedIndexSelected(result.indexId));

                            return new LogicalProject<>(generateProjectsAlias(agg.getOutputs(), slotContext),
                                    new ReplaceExpressions(slotContext).replace(new LogicalAggregate<>(
                                            agg.getGroupByExpressions(), replaceAggOutput(agg, Optional.of(project),
                                                    Optional.of(newProject), result.exprRewriteMap),
                                            agg.isNormalized(), agg.getSourceRepeat(),
                                            repeat.withAggOutputAndChild(
                                                    replaceOutput(repeat.getOutputs(),
                                                            result.exprRewriteMap.projectExprMap),
                                                    filter.withChildren(newProject))),
                                            mvPlan));
                        }).toRule(RuleType.MATERIALIZED_INDEX_AGG_REPEAT_FILTER_PROJECT_SCAN)
        );
    }

    /**
     * Builds the scan that reads from the materialized index recorded in {@code result}.
     *
     * @param scan the original olap scan
     * @param result the outcome of index selection; only its {@code indexId} is used here
     * @return {@code scan} with the selected materialized index applied
     */
    private static LogicalOlapScan createLogicalOlapScan(LogicalOlapScan scan, SelectResult result) {
        long selectedIndexId = result.indexId;
        return scan.withMaterializedIndexSelected(selectedIndexId);
    }

    ///////////////////////////////////////////////////////////////////////////
    // Main entrance of select materialized index.
    ///////////////////////////////////////////////////////////////////////////

    /**
     * Pick the materialized index that should serve the given scan.
     * <p>
     * 1. collect the non-base indexes that keep pre-aggregation on, either after the aggregate functions
     * were rewritten for the index or with the aggregate functions left untouched.
     * 2. keep only the indexes that contain all the required columns.
     * 3. let {@code selectBestIndex} choose among the remaining indexes.
     * <p>
     * When the chosen index is the base index, the scan's current pre-agg status is re-checked against
     * the base index (only if it is currently on) and an empty rewrite map is returned.
     *
     * @throws AnalysisException if the scan does not output every required non-virtual slot
     */
    private SelectResult select(LogicalOlapScan scan, Set<Slot> requiredScanOutput, Set<Expression> predicates,
            List<AggregateFunction> aggregateFunctions, List<Expression> groupingExprs,
            Set<? extends Expression> requiredExpr) {
        // virtual slots come from grouping sets, so they are excluded from the scan output check.
        Set<Slot> physicalRequiredSlots = requiredScanOutput.stream()
                .filter(slot -> !(slot instanceof VirtualSlotReference))
                .collect(ImmutableSet.toImmutableSet());

        // check first, format later: the message is only built when the check fails.
        boolean scanCoversRequiredSlots = scan.getOutputSet().containsAll(physicalRequiredSlots);
        if (!scanCoversRequiredSlots) {
            throw new AnalysisException(
                    String.format("Scan's output (%s) should contains all the input required scan output (%s).",
                            scan.getOutput(), physicalRequiredSlots));
        }

        OlapTable table = scan.getTable();
        long baseIndexId = table.getBaseIndexId();

        // only non-base visible indexes are candidates for rewriting.
        List<MaterializedIndex> nonBaseIndexes = table.getVisibleIndex()
                .stream()
                .filter(index -> index.getId() != baseIndexId)
                .collect(Collectors.toList());

        // try to rewrite bitmap, hll by materialized index columns.
        Set<AggRewriteResult> candidatesWithRewriting = nonBaseIndexes.stream()
                .map(index -> rewriteAgg(index, scan, physicalRequiredSlots, predicates, aggregateFunctions,
                        groupingExprs))
                .filter(rewritten -> {
                    // only the aggregate functions that couldn't be rewritten still need a pre-agg check.
                    List<AggregateFunction> notRewritten = aggFuncsDiff(aggregateFunctions, rewritten);
                    return checkPreAggStatus(scan, rewritten.index.getId(), predicates, notRewritten,
                            groupingExprs).isOn();
                })
                .collect(Collectors.toSet());

        Set<MaterializedIndex> rewrittenIndexes = candidatesWithRewriting.stream()
                .map(rewritten -> rewritten.index)
                .collect(Collectors.toSet());

        Set<MaterializedIndex> candidatesWithoutRewriting = nonBaseIndexes.stream()
                .filter(index -> !rewrittenIndexes.contains(index)
                        && (preAggEnabledByHint(scan)
                        || checkPreAggStatus(scan, index.getId(), predicates, aggregateFunctions, groupingExprs)
                                .isOn()))
                .collect(Collectors.toSet());

        Stream<MaterializedIndex> usableWithoutRewriting = candidatesWithoutRewriting.stream()
                .filter(index -> containAllRequiredColumns(index, scan, physicalRequiredSlots,
                        requiredExpr, predicates));
        Stream<MaterializedIndex> usableWithRewriting = candidatesWithRewriting.stream()
                .filter(rewritten -> containAllRequiredColumns(rewritten.index, scan,
                        rewritten.requiredScanOutput,
                        requiredExpr.stream().map(e -> rewritten.exprRewriteMap.replaceAgg(e))
                                .collect(Collectors.toSet()),
                        predicates))
                .map(rewritten -> rewritten.index);
        List<MaterializedIndex> haveAllRequiredColumns = Streams.concat(usableWithoutRewriting, usableWithRewriting)
                .collect(Collectors.toList());

        long selectIndexId = selectBestIndex(haveAllRequiredColumns, scan, predicates, requiredExpr);
        // Pre-aggregation is set to `on` by default for duplicate-keys table.
        // In other cases where mv is not hit, preagg may turn off from on.
        if ((new CheckContext(scan, selectIndexId)).isBaseIndex()) {
            PreAggStatus scanStatus = scan.getPreAggStatus();
            PreAggStatus baseStatus = scanStatus.isOn()
                    ? checkPreAggStatus(scan, baseIndexId, predicates, aggregateFunctions, groupingExprs)
                    : scanStatus;
            return new SelectResult(baseStatus, selectIndexId, new ExprRewriteMap());
        }

        // reuse the rewrite map produced for the chosen index, if that index was rewritten.
        ExprRewriteMap chosenRewriteMap = candidatesWithRewriting.stream()
                .filter(rewritten -> rewritten.index.getId() == selectIndexId)
                .findAny()
                .map(rewritten -> rewritten.exprRewriteMap)
                .orElse(new ExprRewriteMap());
        return new SelectResult(PreAggStatus.on(), selectIndexId, chosenRewriteMap);
    }

    /**
     * Returns the distinct aggregate functions from {@code aggregateFunctions} that are not keys of the
     * rewrite result's {@code aggFuncMap}, i.e. the ones that were not rewritten.
     */
    private List<AggregateFunction> aggFuncsDiff(List<AggregateFunction> aggregateFunctions,
            AggRewriteResult aggRewriteResult) {
        Set<AggregateFunction> requested = ImmutableSet.copyOf(aggregateFunctions);
        Set<?> alreadyRewritten = aggRewriteResult.exprRewriteMap.aggFuncMap.keySet();
        return ImmutableList.copyOf(Sets.difference(requested, alreadyRewritten));
    }

    /**
     * Value holder returned by {@code select}: the pre-aggregation status to apply, the id of the
     * chosen index, and the expression rewrite map that goes with that index.
     */
    private static class SelectResult {
        // pre-aggregation status decided for the selected index.
        public final PreAggStatus preAggStatus;
        // id of the selected index (may be the base index).
        public final long indexId;
        // rewrites for the selected index; select passes a fresh ExprRewriteMap when none was produced.
        public ExprRewriteMap exprRewriteMap;

        public SelectResult(PreAggStatus preAggStatus, long indexId, ExprRewriteMap exprRewriteMap) {
            this.preAggStatus = preAggStatus;
            this.indexId = indexId;
            this.exprRewriteMap = exprRewriteMap;
        }
    }

    /**
     * Collect the aggregate functions of an aggregate plan, expressed over the inputs of the
     * underlying project when one is given.
     * <p>
     * 1. gather every aggregate function found in the aggregate's output expressions.
     * <p>
     * 2. if a project is present, substitute each slot it produces by the producing expression.
     * <p>
     * For example:
     * <pre>
     * input arguments:
     * agg: Aggregate(sum(v) as sum_value)
     * underlying project: Project(a + b as v)
     *
     * output:
     * sum(a + b)
     * </pre>
     */
    private List<AggregateFunction> extractAggFunctionAndReplaceSlot(
            LogicalAggregate<?> agg,
            Optional<LogicalProject<?>> project) {
        Optional<Map<Slot, Expression>> slotToProducerOpt = project.map(Project::getAliasToProducer);
        return agg.getOutputExpressions().stream()
                // pull the aggregate functions out of each output expression.
                .flatMap(output -> output.<AggregateFunction>collect(AggregateFunction.class::isInstance).stream())
                // rewrite each function's input slots to the expressions that produce them.
                .map(aggFunc -> replaceAggFuncInput(aggFunc, slotToProducerOpt))
                .collect(Collectors.toList());
    }

    /**
     * Replaces the slots used by {@code aggFunc} with their producing expressions when a
     * slot-to-producer map is present; otherwise returns {@code aggFunc} unchanged.
     */
    private static AggregateFunction replaceAggFuncInput(AggregateFunction aggFunc,
            Optional<Map<Slot, Expression>> slotToProducerOpt) {
        Optional<AggregateFunction> replaced = slotToProducerOpt.map(
                producers -> (AggregateFunction) ExpressionUtils.replace(aggFunc, producers));
        return replaced.orElse(aggFunc);
    }

    ///////////////////////////////////////////////////////////////////////////
    // Set pre-aggregation status.
    ///////////////////////////////////////////////////////////////////////////
    /**
     * Decide the pre-aggregation status of one index for the given query shape.
     * <p>
     * Indexes flagged as duplicate-keys or merge-on-write are always on. Otherwise the aggregate
     * functions, then the grouping expressions, then the predicates are checked, chained through
     * {@code offOrElse}.
     */
    private PreAggStatus checkPreAggStatus(LogicalOlapScan olapScan, long indexId, Set<Expression> predicates,
            List<AggregateFunction> aggregateFuncs, List<Expression> groupingExprs) {
        CheckContext ctx = new CheckContext(olapScan, indexId);
        if (!ctx.isDupKeysOrMergeOnWrite) {
            PreAggStatus aggFuncStatus = checkAggregateFunctions(aggregateFuncs, ctx);
            return aggFuncStatus
                    .offOrElse(() -> checkGroupingExprs(groupingExprs, ctx))
                    .offOrElse(() -> checkPredicates(ImmutableList.copyOf(predicates), ctx));
        }
        return PreAggStatus.on();
    }

    /**
     * Check pre agg status according to aggregate functions: the first function whose check
     * reports off decides the result, otherwise pre-aggregation stays on.
     */
    private PreAggStatus checkAggregateFunctions(
            List<AggregateFunction> aggregateFuncs,
            CheckContext checkContext) {
        for (AggregateFunction aggFunc : aggregateFuncs) {
            PreAggStatus status = AggregateFunctionChecker.INSTANCE.check(aggFunc, checkContext);
            if (status.isOff()) {
                return status;
            }
        }
        return PreAggStatus.on();
    }

    // TODO: support all the aggregate function types in storage engine.
    /**
     * Expression visitor that decides, for a single aggregate function, whether pre-aggregation can
     * stay on for the index described by the {@link CheckContext}.
     * <p>
     * Each visit method returns {@code PreAggStatus.on()} or a {@code PreAggStatus.off(reason)}.
     */
    private static class AggregateFunctionChecker extends ExpressionVisitor<PreAggStatus, CheckContext> {

        public static final AggregateFunctionChecker INSTANCE = new AggregateFunctionChecker();

        /**
         * Entry point: dispatches {@code aggFun} to the matching visit method of {@link #INSTANCE}.
         */
        public PreAggStatus check(AggregateFunction aggFun, CheckContext ctx) {
            return aggFun.accept(INSTANCE, ctx);
        }

        /**
         * Generic fallback: turns pre-aggregation off with a "not aggregate function" reason.
         * NOTE(review): presumably only reached for expression kinds without a more specific
         * override here - confirm against ExpressionVisitor.
         */
        @Override
        public PreAggStatus visit(Expression expr, CheckContext context) {
            return PreAggStatus.off(String.format("%s is not aggregate function.", expr.toSql()));
        }

        /**
         * Aggregate functions without a dedicated override are checked against
         * {@code AggregateType.NONE} and may not rely on key columns.
         */
        @Override
        public PreAggStatus visitAggregateFunction(AggregateFunction aggregateFunction, CheckContext context) {
            return checkAggFunc(aggregateFunction, AggregateType.NONE, context, false);
        }

        /** MAX is matched against MAX value columns and is also allowed on key columns. */
        @Override
        public PreAggStatus visitMax(Max max, CheckContext context) {
            return checkAggFunc(max, AggregateType.MAX, context, true);
        }

        /** MIN is matched against MIN value columns and is also allowed on key columns. */
        @Override
        public PreAggStatus visitMin(Min min, CheckContext context) {
            return checkAggFunc(min, AggregateType.MIN, context, true);
        }

        /** SUM is matched against SUM value columns; canUseKeyColumn is false for it. */
        @Override
        public PreAggStatus visitSum(Sum sum, CheckContext context) {
            return checkAggFunc(sum, AggregateType.SUM, context, false);
        }

        /**
         * Only a distinct, single-argument count can keep pre-aggregation on:
         * either its argument is a slot (or cast on a slot) that is a key column or the index is
         * dup-keys / merge-on-write, or its argument has children and passes checkSubExpressions.
         * Every other count turns pre-aggregation off.
         */
        @Override
        public PreAggStatus visitCount(Count count, CheckContext context) {
            if (count.isDistinct() && count.arity() == 1) {
                Optional<Slot> slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(count.child(0));
                if (slotOpt.isPresent() && (context.isDupKeysOrMergeOnWrite
                        || context.keyNameToColumn.containsKey(normalizeName(slotOpt.get().toSql())))) {
                    return PreAggStatus.on();
                }
                // the argument is a compound expression: check it piece by piece.
                // note that the matching aggregate type passed down is null for count.
                if (count.child(0).arity() != 0) {
                    return checkSubExpressions(count, null, context);
                }
            }
            return PreAggStatus.off(String.format(
                    "Count distinct is only valid for key columns, but meet %s.", count.toSql()));
        }

        /**
         * bitmap_union_count keeps pre-aggregation on when its argument (with an outer to_bitmap
         * removed) is, by normalized SQL text, the name of a value column.
         */
        @Override
        public PreAggStatus visitBitmapUnionCount(BitmapUnionCount bitmapUnionCount, CheckContext context) {
            Expression expr = bitmapUnionCount.child();
            if (expr instanceof ToBitmap) {
                expr = expr.child(0);
            }
            if (context.valueNameToColumn.containsKey(normalizeName(expr.toSql()))) {
                return PreAggStatus.on();
            } else {
                return PreAggStatus.off("invalid bitmap_union_count: " + bitmapUnionCount.toSql());
            }
        }

        /**
         * bitmap_union keeps pre-aggregation on when its argument (with an outer to_bitmap removed)
         * is a slot, or a cast on a slot, whose normalized name is a value column.
         */
        @Override
        public PreAggStatus visitBitmapUnion(BitmapUnion bitmapUnion, CheckContext context) {
            Expression expr = bitmapUnion.child();
            if (expr instanceof ToBitmap) {
                expr = expr.child(0);
            }
            Optional<Slot> slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(expr);
            if (slotOpt.isPresent() && context.valueNameToColumn.containsKey(normalizeName(slotOpt.get().toSql()))) {
                return PreAggStatus.on();
            } else {
                return PreAggStatus.off("invalid bitmap_union: " + bitmapUnion.toSql());
            }
        }

        /**
         * hll_union_agg keeps pre-aggregation on when its argument is a slot, or a cast on a slot,
         * whose normalized name is a value column.
         */
        @Override
        public PreAggStatus visitHllUnionAgg(HllUnionAgg hllUnionAgg, CheckContext context) {
            Optional<Slot> slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(hllUnionAgg.child());
            if (slotOpt.isPresent() && context.valueNameToColumn.containsKey(normalizeName(slotOpt.get().toSql()))) {
                return PreAggStatus.on();
            } else {
                return PreAggStatus.off("invalid hll_union_agg: " + hllUnionAgg.toSql());
            }
        }

        /**
         * hll_union keeps pre-aggregation on when its argument is a slot, or a cast on a slot,
         * whose normalized name is a value column.
         */
        @Override
        public PreAggStatus visitHllUnion(HllUnion hllUnion, CheckContext context) {
            Optional<Slot> slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(hllUnion.child());
            if (slotOpt.isPresent() && context.valueNameToColumn.containsKey(normalizeName(slotOpt.get().toSql()))) {
                return PreAggStatus.on();
            } else {
                return PreAggStatus.off("invalid hll_union: " + hllUnion.toSql());
            }
        }

        /**
         * Shared check for aggregate functions whose first argument is matched by name against the
         * index columns.
         *
         * @param aggFunc the aggregate function being checked; only its first child is inspected
         * @param matchingAggType the column aggregate type that this function is compatible with
         * @param ctx key/value column lookup for the index
         * @param canUseKeyColumn whether the function may keep pre-aggregation on over a key column
         *         even when the index is neither dup-keys nor merge-on-write
         */
        private PreAggStatus checkAggFunc(
                AggregateFunction aggFunc,
                AggregateType matchingAggType,
                CheckContext ctx,
                boolean canUseKeyColumn) {
            // lookup name: the argument's normalized SQL for the base index, otherwise the name
            // built by mvColumnBuilder from the aggregate type and the argument's normalized SQL.
            String childNameWithFuncName = ctx.isBaseIndex()
                    ? normalizeName(aggFunc.child(0).toSql())
                    : normalizeName(CreateMaterializedViewStmt.mvColumnBuilder(
                        matchingAggType, normalizeName(aggFunc.child(0).toSql())));

            // NOTE(review): presumably true when every column used by the argument is a key column
            // - confirm against containsAllColumn.
            boolean contains = containsAllColumn(aggFunc.child(0), ctx.keyNameToColumn.keySet());
            if ((contains || ctx.keyNameToColumn.containsKey(childNameWithFuncName))
                    && checkWhenUseKey(aggFunc, matchingAggType)) {
                // key column case: the function itself must be acceptable on keys (checkWhenUseKey).
                if (canUseKeyColumn || ctx.isDupKeysOrMergeOnWrite || (!ctx.isBaseIndex() && contains)) {
                    return PreAggStatus.on();
                } else {
                    Column column = ctx.keyNameToColumn.get(childNameWithFuncName);
                    return PreAggStatus.off(String.format("Aggregate function %s contains key column %s.",
                            aggFunc.toSql(), column == null ? "empty column" : column.getName()));
                }
            } else if (ctx.valueNameToColumn.containsKey(childNameWithFuncName)) {
                // value column case: the column's aggregate type must match and distinct is rejected.
                AggregateType aggType = ctx.valueNameToColumn.get(childNameWithFuncName).getAggregationType();
                if (aggType == matchingAggType) {
                    if (aggFunc.isDistinct()) {
                        return PreAggStatus.off(
                                String.format("Aggregate function %s is distinct aggregation", aggFunc.toSql()));
                    }
                    return PreAggStatus.on();
                } else {
                    return PreAggStatus.off(String.format("Aggregate operator don't match, aggregate function: %s"
                            + ", column aggregate type: %s", aggFunc.toSql(), aggType));
                }
            } else if (!aggFunc.child(0).children().isEmpty()) {
                // the argument is a compound expression: check it piece by piece.
                return checkSubExpressions(aggFunc, matchingAggType, ctx);
            } else {
                return PreAggStatus.off(String.format("Slot(%s) in %s is neither key column nor value column.",
                        childNameWithFuncName, aggFunc.toSql()));
            }
        }

        // check sub expressions in AggregateFunction.
        // The first argument is stripped of numeric casts, split into condition and return
        // expressions (IF / CASE WHEN), and then each part is validated.
        // matchingAggType is null when called from visitCount.
        private PreAggStatus checkSubExpressions(AggregateFunction aggFunc, AggregateType matchingAggType,
                                                 CheckContext ctx) {
            Expression child = aggFunc.child(0);
            List<Expression> conditionExps = new ArrayList<>();
            List<Expression> returnExps = new ArrayList<>();

            // ignore cast
            // (only casts to a numeric type are skipped; any other cast turns pre-aggregation off)
            while (child instanceof Cast) {
                if (!((Cast) child).getDataType().isNumericType()) {
                    return PreAggStatus.off(String.format("[%s] is not numeric CAST.", child.toSql()));
                }
                child = child.child(0);
            }
            // step 1: extract all condition exprs and return exprs
            if (child instanceof If) {
                conditionExps.add(child.child(0));
                returnExps.add(child.child(1));
                returnExps.add(child.child(2));
            } else if (child instanceof CaseWhen) {
                CaseWhen caseWhen = (CaseWhen) child;
                // WHEN THEN
                for (WhenClause whenClause : caseWhen.getWhenClauses()) {
                    conditionExps.add(whenClause.getOperand());
                    returnExps.add(whenClause.getResult());
                }
                // ELSE
                // (a missing ELSE is treated as a NULL literal return value)
                returnExps.add(caseWhen.getDefaultValue().orElse(new NullLiteral()));
            } else {
                // currently, only IF and CASE WHEN are supported
                // (anything else is checked as a single return expression with no conditions)
                returnExps.add(child);
            }

            // step 2: check condition expressions
            for (Expression conditionExp : conditionExps) {
                if (!containsAllColumn(conditionExp, ctx.keyNameToColumn.keySet())) {
                    return PreAggStatus.off(String.format("some columns in condition [%s] is not key.",
                            conditionExp.toSql()));
                }
            }

            // step 3: check return expressions
            // NOTE: now we just support SUM, MIN, MAX and COUNT DISTINCT
            int returnExprValidateNum = 0;
            for (Expression returnExp : returnExps) {
                // ignore cast in return expr
                while (returnExp instanceof Cast) {
                    returnExp = returnExp.child(0);
                }
                // now we only check simple return expressions
                String exprName = returnExp.getExpressionName();
                if (!returnExp.children().isEmpty()) {
                    return PreAggStatus.off(String.format("do not support compound expression [%s] in %s.",
                            returnExp.toSql(), matchingAggType));
                }
                if (ctx.keyNameToColumn.containsKey(exprName)) {
                    if (!checkWhenUseKey(aggFunc, matchingAggType)) {
                        return PreAggStatus.off("agg on key column should be MAX, MIN or COUNT DISTINCT.");
                    }
                }

                if (matchingAggType == AggregateType.SUM) {
                    // SUM accepts a SUM value column, a zero literal or a NULL literal.
                    if ((ctx.valueNameToColumn.containsKey(exprName)
                            && ctx.valueNameToColumn.get(exprName).getAggregationType() == matchingAggType)
                            || returnExp.isZeroLiteral() || returnExp.isNullLiteral()) {
                        returnExprValidateNum++;
                    } else {
                        return PreAggStatus.off(String.format("SUM cant preagg for [%s].", aggFunc.toSql()));
                    }
                } else if (matchingAggType == AggregateType.MAX || matchingAggType == AggregateType.MIN) {
                    // MAX/MIN accept a key column, a NULL literal or a value column of the same type.
                    if (ctx.keyNameToColumn.containsKey(exprName) || returnExp.isNullLiteral()
                            || (ctx.valueNameToColumn.containsKey(exprName)
                            && ctx.valueNameToColumn.get(exprName).getAggregationType() == matchingAggType)) {
                        returnExprValidateNum++;
                    } else {
                        return PreAggStatus.off(String.format("MAX/MIN cant preagg for [%s].", aggFunc.toSql()));
                    }
                } else if (aggFunc.getName().equalsIgnoreCase("COUNT") && aggFunc.isDistinct()) {
                    // COUNT DISTINCT accepts a key column, a zero literal or a NULL literal.
                    if (ctx.keyNameToColumn.containsKey(exprName)
                            || returnExp.isZeroLiteral() || returnExp.isNullLiteral()) {
                        returnExprValidateNum++;
                    } else {
                        return PreAggStatus.off(String.format("COUNT DISTINCT cant preagg for [%s].", aggFunc.toSql()));
                    }
                }
                // any other aggregate leaves the counter untouched, so the final comparison fails.
            }
            if (returnExprValidateNum == returnExps.size()) {
                return PreAggStatus.on();
            }
            return PreAggStatus.off(String.format("cant preagg for [%s].", aggFunc.toSql()));
        }
    }

    /**
     * Tells whether an aggregate function may be applied to a key column: MAX and MIN always may,
     * SUM / COUNT / AVG only in their DISTINCT form.
     *
     * @return true if the aggregation is valid on a key column
     */
    private static boolean checkWhenUseKey(AggregateFunction aggFunc, AggregateType matchingAggType) {
        if (matchingAggType == AggregateType.MAX || matchingAggType == AggregateType.MIN) {
            return true;
        }
        boolean allowedWhenDistinct = aggFunc instanceof Sum
                || aggFunc instanceof Count
                || aggFunc instanceof Avg;
        return allowedWhenDistinct && aggFunc.isDistinct();
    }

    /**
     * Lookup data for the pre-aggregation checks of one index: its key and value columns keyed
     * case-insensitively by name, and whether the index is duplicate-keys or unique-keys with
     * merge-on-write enabled.
     */
    private static class CheckContext {

        public final LogicalOlapScan scan;
        public final long index;
        public final Map<String, Column> keyNameToColumn;
        public final Map<String, Column> valueNameToColumn;
        public final boolean isDupKeysOrMergeOnWrite;

        public CheckContext(LogicalOlapScan scan, long indexId) {
            this.scan = scan;
            this.index = indexId;

            OlapTable table = scan.getTable();
            boolean forBaseIndex = indexId == table.getBaseIndexId();
            Supplier<Map<String, Column>> newNameMap = () -> Maps.newTreeMap(String.CASE_INSENSITIVE_ORDER);

            // base index columns keep their own name; columns of other indexes are registered under
            // two derived spellings (see the two functions below).
            Function<Column, String> sqlStyleName = column -> forBaseIndex
                    ? column.getName()
                    : normalizeName(parseMvColumnToSql(column.getName()));
            Function<Column, String> mvStyleName = column -> forBaseIndex
                    ? column.getName()
                    : normalizeName(parseMvColumnToMvName(
                            column.getNameWithoutMvPrefix(),
                            column.isAggregated()
                                    ? Optional.of(column.getAggregationType().name())
                                    : Optional.empty()));

            // map<is_key, map<column_name, column>>; on a name clash the first column wins.
            Map<Boolean, Map<String, Column>> bySqlStyleName = table.getSchemaByIndexId(indexId).stream()
                    .collect(Collectors.groupingBy(Column::isKey,
                            Collectors.toMap(sqlStyleName, Function.identity(),
                                    (kept, ignored) -> kept, newNameMap)));
            Map<Boolean, Map<String, Column>> byMvStyleName = table.getSchemaByIndexId(indexId).stream()
                    .collect(Collectors.groupingBy(Column::isKey,
                            Collectors.toMap(mvStyleName, Function.identity(),
                                    (kept, ignored) -> kept, newNameMap)));

            // mv-style names take precedence; sql-style names only fill the gaps.
            Map<String, Column> keyColumns = byMvStyleName.getOrDefault(true, newNameMap.get());
            Map<String, Column> keyColumnsBySqlName = bySqlStyleName.getOrDefault(true, newNameMap.get());
            keyColumnsBySqlName.forEach(keyColumns::putIfAbsent);
            this.keyNameToColumn = keyColumns;

            Map<String, Column> valueColumns = byMvStyleName.getOrDefault(false, newNameMap.get());
            Map<String, Column> valueColumnsBySqlName = bySqlStyleName.getOrDefault(false, ImmutableMap.of());
            valueColumnsBySqlName.forEach(valueColumns::putIfAbsent);
            this.valueNameToColumn = valueColumns;

            KeysType keysType = getMeta().getKeysType();
            this.isDupKeysOrMergeOnWrite = keysType == KeysType.DUP_KEYS
                    || (table.getEnableUniqueKeyMergeOnWrite() && keysType == KeysType.UNIQUE_KEYS);
        }

        /** Whether this context describes the table's base index. */
        public boolean isBaseIndex() {
            return index == scan.getTable().getBaseIndexId();
        }

        /** Meta of the index this context was built for. */
        public MaterializedIndexMeta getMeta() {
            return scan.getTable().getIndexMetaByIndexId(index);
        }

        /** Looks a column up in the index meta by its define name. */
        public Column getColumn(String name) {
            return getMeta().getColumnByDefineName(name);
        }
    }

    /**
     * Pre-aggregation is turned off as soon as a grouping expression reads a value type column.
     */
    private PreAggStatus checkGroupingExprs(
            List<Expression> groupingExprs,
            CheckContext checkContext) {
        String offReasonTemplate = "Grouping expression %s contains value column %s";
        return disablePreAggIfContainsAnyValueColumn(groupingExprs, checkContext, offReasonTemplate);
    }

    /**
     * Predicates should not have value type columns.
     * <p>
     * Predicates whose SQL text equals one of the conjuncts of the index's own where clause are
     * skipped before the check.
     */
    private PreAggStatus checkPredicates(List<Expression> predicates, CheckContext checkContext) {
        // re-parse each where-clause conjunct so its SQL text is comparable with the predicates' text.
        Set<String> indexConjuncts = PlanNode
                .splitAndCompoundPredicateToConjuncts(checkContext.getMeta().getWhereClause()).stream()
                .map(e -> new NereidsParser().parseExpression(e.toSql()).toSql()).collect(Collectors.toSet());
        List<Expression> predicatesToCheck = predicates.stream()
                .filter(predicate -> !indexConjuncts.contains(predicate.toSql()))
                .collect(Collectors.toList());
        return disablePreAggIfContainsAnyValueColumn(predicatesToCheck, checkContext,
                "Predicate %s contains value column %s");
    }

    /**
     * Verifies that none of the given expressions reads a value column of the index under check.
     *
     * <p>An input slot counts as a value column when its normalized SQL text is a key of
     * {@code ctx.valueNameToColumn}. The first offending expression found decides the result.
     *
     * @param exprs expressions to inspect
     * @param ctx key/value column information of the index under check
     * @param errorMsg format string taking the expression SQL and the value column name, in that order
     * @return {@code PreAggStatus.off} with the formatted reason when a value column is referenced,
     *         {@code PreAggStatus.on} otherwise
     */
    private PreAggStatus disablePreAggIfContainsAnyValueColumn(List<Expression> exprs, CheckContext ctx,
            String errorMsg) {
        for (Expression expr : exprs) {
            Optional<String> valueColumnKey = expr.getInputSlots()
                    .stream()
                    .map(slot -> normalizeName(slot.toSql()))
                    .filter(ctx.valueNameToColumn::containsKey)
                    .findAny();
            if (valueColumnKey.isPresent()) {
                return PreAggStatus.off(String.format(errorMsg,
                        expr.toSql(), ctx.valueNameToColumn.get(valueColumnKey.get()).getName()));
            }
        }
        return PreAggStatus.on();
    }

    /**
     * Runs {@link AggFuncRewriter} over every aggregate function against the given index and returns
     * the collected rewrite mappings together with the index and the required scan output.
     *
     * <p>{@code predicates} and {@code groupingExprs} are accepted but not read by this method.
     *
     * @param index candidate index the aggregate functions are matched against
     * @param scan the olap scan being rewritten
     * @param requiredScanOutput slots required from the scan, passed through to the result unchanged
     * @param predicates unused
     * @param aggregateFunctions aggregate functions of the query
     * @param groupingExprs unused
     * @return the rewrite result; its {@code exprRewriteMap} has no aggregate mapping when nothing matched
     */
    private AggRewriteResult rewriteAgg(MaterializedIndex index,
            LogicalOlapScan scan,
            Set<Slot> requiredScanOutput,
            Set<Expression> predicates,
            List<AggregateFunction> aggregateFunctions,
            List<Expression> groupingExprs) {
        ExprRewriteMap collectedRewrites = new ExprRewriteMap();
        RewriteContext rewriteContext =
                new RewriteContext(new CheckContext(scan, index.getId()), collectedRewrites);
        for (AggregateFunction aggregateFunction : aggregateFunctions) {
            AggFuncRewriter.rewrite(aggregateFunction, rewriteContext);
        }
        return new AggRewriteResult(index, requiredScanOutput, collectedRewrites);
    }

    /**
     * Mappings collected while aggregate functions are rewritten onto materialized view columns.
     */
    private static class ExprRewriteMap {
        /**
         * Replacements for expressions in a project, keyed by the original expression. For example,
         * if the query has avg(v) and stddev_samp(v), the entry for v is
         * [mva_GENERIC__avg_state(`v`), mva_GENERIC__stddev_samp_state(CAST(`v` AS DOUBLE))],
         * and a plan that outputs v additionally outputs those two expressions.
         */
        public final Map<Expression, List<Expression>> projectExprMap;
        /**
         * Original aggregate function -> aggregate function that replaces it.
         */
        public final Map<AggregateFunction, AggregateFunction> aggFuncMap;

        /**
         * View of {@link #aggFuncMap} keyed by the SQL text of the original function, compared
         * case-insensitively. Built on first use by {@link #buildStrMap()} and not refreshed afterwards.
         */
        private Map<String, AggregateFunction> aggFuncStrMap;

        public ExprRewriteMap() {
            this.projectExprMap = Maps.newHashMap();
            this.aggFuncMap = Maps.newHashMap();
        }

        /**
         * Returns whether no aggregate function replacement has been recorded.
         */
        public boolean isEmpty() {
            return aggFuncMap.isEmpty();
        }

        /**
         * Populates {@link #aggFuncStrMap} from {@link #aggFuncMap} unless it already exists.
         */
        private void buildStrMap() {
            if (aggFuncStrMap == null) {
                this.aggFuncStrMap = Maps.newTreeMap(String.CASE_INSENSITIVE_ORDER);
                for (Map.Entry<AggregateFunction, AggregateFunction> entry : aggFuncMap.entrySet()) {
                    this.aggFuncStrMap.put(entry.getKey().toSql(), entry.getValue());
                }
            }
        }

        /**
         * Strips any enclosing aliases from {@code e} and, when an aggregate function remains, returns
         * the replacement registered for a function with the same SQL text.
         *
         * @param e expression to translate, possibly wrapped in aliases
         * @return the registered replacement, or the un-aliased expression when it is not an aggregate
         *         function or has no replacement
         */
        public Expression replaceAgg(Expression e) {
            Expression unwrapped = e;
            while (unwrapped instanceof Alias) {
                unwrapped = unwrapped.child(0);
            }
            if (unwrapped instanceof AggregateFunction) {
                buildStrMap();
                return aggFuncStrMap.getOrDefault(unwrapped.toSql(), (AggregateFunction) unwrapped);
            }
            return unwrapped;
        }

        /**
         * Appends {@code value} to the replacement list of {@code key}, creating the list on first use.
         */
        public void putIntoProjectExprMap(Expression key, Expression value) {
            projectExprMap.computeIfAbsent(key, ignored -> Lists.newArrayList()).add(value);
        }
    }

    /**
     * Value holder returned by {@code rewriteAgg}: the index that was checked, the slots required from
     * the scan, and the rewrite mappings collected for that index.
     */
    private static class AggRewriteResult {
        public final MaterializedIndex index;
        public final Set<Slot> requiredScanOutput;
        public ExprRewriteMap exprRewriteMap;

        public AggRewriteResult(
                MaterializedIndex index,
                Set<Slot> requiredScanOutput,
                ExprRewriteMap exprRewriteMap) {
            this.exprRewriteMap = exprRewriteMap;
            this.requiredScanOutput = requiredScanOutput;
            this.index = index;
        }
    }

    /**
     * Returns whether the input slots of {@code expressions} and {@code slotsToCheck} have no slot in
     * common.
     */
    private boolean isInputSlotsContainsNone(List<Expression> expressions, Set<Slot> slotsToCheck) {
        return ExpressionUtils.getInputSlotSet(expressions)
                .stream()
                .noneMatch(slotsToCheck::contains);
    }

    /**
     * Context handed to {@code AggFuncRewriter}: the index being checked plus the map that receives
     * the rewrites found for it.
     */
    private static class RewriteContext {
        public final CheckContext checkContext;
        public final ExprRewriteMap exprRewriteMap;

        public RewriteContext(CheckContext context, ExprRewriteMap exprRewriteMap) {
            this.exprRewriteMap = exprRewriteMap;
            this.checkContext = context;
        }
    }

    /**
     * Wraps {@code expr} in a {@link Cast} to {@code targetType} unless it already has that data type.
     *
     * @param expr expression to convert
     * @param targetType desired data type
     * @return {@code expr} itself when the types are equal, otherwise a new cast around it
     */
    private static Expression castIfNeed(Expression expr, DataType targetType) {
        boolean alreadyTargetType = expr.getDataType().equals(targetType);
        return alreadyTargetType ? expr : new Cast(expr, targetType);
    }

    private static class AggFuncRewriter extends DefaultExpressionRewriter<RewriteContext> {
        public static final AggFuncRewriter INSTANCE = new AggFuncRewriter();

        private static void rewrite(Expression expr, RewriteContext context) {
            expr.accept(INSTANCE, context);
        }

        /**
         * count(distinct col) -> bitmap_union_count(mva_BITMAP_UNION__to_bitmap__with_check(col))
         * count(col) -> sum(mva_SUM__CASE WHEN col IS NULL THEN 0 ELSE 1 END)
         */
        @Override
        public Expression visitCount(Count count, RewriteContext context) {
            Expression result = visitAggregateFunction(count, context);
            if (result != count) {
                return result;
            }
            if (count.isDistinct() && count.arity() == 1) {
                // count(distinct col) -> bitmap_union_count(mv_bitmap_union_col)
                Optional<Slot> slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(count.child(0));

                Expression expr = new ToBitmapWithCheck(castIfNeed(count.child(0), BigIntType.INSTANCE));
                // count distinct a value column.
                if (slotOpt.isPresent()) {
                    String bitmapUnionColumn = normalizeName(CreateMaterializedViewStmt.mvColumnBuilder(
                            AggregateType.BITMAP_UNION, CreateMaterializedViewStmt.mvColumnBuilder(expr.toSql())));

                    Column mvColumn = context.checkContext.getColumn(bitmapUnionColumn);
                    // has bitmap_union column
                    if (mvColumn != null && context.checkContext.valueNameToColumn.containsValue(mvColumn)) {
                        Slot bitmapUnionSlot = context.checkContext.scan.getOutputByIndex(context.checkContext.index)
                                .stream()
                                .filter(s -> bitmapUnionColumn.equalsIgnoreCase(normalizeName(s.getName())))
                                .findFirst()
                                .orElseThrow(() -> new AnalysisException(
                                        "cannot find bitmap union slot when select mv"));

                        context.exprRewriteMap.putIntoProjectExprMap(slotOpt.get(), bitmapUnionSlot);
                        BitmapUnionCount bitmapUnionCount = new BitmapUnionCount(bitmapUnionSlot);
                        context.exprRewriteMap.aggFuncMap.put(count, bitmapUnionCount);
                        return bitmapUnionCount;
                    }
                }
            }

            Expression child = null;
            if (!count.isDistinct() && count.arity() == 1) {
                // count(col) -> sum(mva_SUM__CASE WHEN col IS NULL THEN 0 ELSE 1 END)
                Optional<Slot> slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(count.child(0));
                if (slotOpt.isPresent()) {
                    child = slotOpt.get();
                }
            } else if (count.arity() == 0) {
                // count(*) / count(1) -> sum(mva_SUM__CASE WHEN 1 IS NULL THEN 0 ELSE 1 END)
                child = new TinyIntLiteral((byte) 1);
            }

            if (child != null) {
                String countColumn = normalizeName(CreateMaterializedViewStmt.mvColumnBuilder(AggregateType.SUM,
                        CreateMaterializedViewStmt.mvColumnBuilder(slotToCaseWhen(child).toSql())));

                Column mvColumn = context.checkContext.getColumn(countColumn);
                if (mvColumn != null && context.checkContext.valueNameToColumn.containsValue(mvColumn)) {
                    Slot countSlot = context.checkContext.scan.getOutputByIndex(context.checkContext.index).stream()
                            .filter(s -> countColumn.equalsIgnoreCase(normalizeName(s.getName()))).findFirst()
                            .orElseThrow(() -> new AnalysisException("cannot find count slot when select mv"));

                    context.exprRewriteMap.putIntoProjectExprMap(child, countSlot);
                    Sum sum = new Sum(countSlot);
                    context.exprRewriteMap.aggFuncMap.put(count, sum);
                    return sum;
                }
            }
            return count;
        }

        /**
         * bitmap_union(to_bitmap(col)) ->
         * bitmap_union(mva_BITMAP_UNION__to_bitmap_with_check(col))
         */
        @Override
        public Expression visitBitmapUnion(BitmapUnion bitmapUnion, RewriteContext context) {
            Expression result = visitAggregateFunction(bitmapUnion, context);
            if (result != bitmapUnion) {
                return result;
            }
            if (bitmapUnion.child() instanceof ToBitmap) {
                ToBitmap toBitmap = (ToBitmap) bitmapUnion.child();
                Optional<Slot> slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(toBitmap.child());
                if (slotOpt.isPresent()) {
                    String bitmapUnionColumn = normalizeName(CreateMaterializedViewStmt
                            .mvColumnBuilder(AggregateType.BITMAP_UNION, CreateMaterializedViewStmt
                                    .mvColumnBuilder(new ToBitmapWithCheck(toBitmap.child()).toSql())));

                    Column mvColumn = context.checkContext.getColumn(bitmapUnionColumn);
                    // has bitmap_union column
                    if (mvColumn != null && context.checkContext.valueNameToColumn.containsValue(mvColumn)) {

                        Slot bitmapUnionSlot = context.checkContext.scan.getOutputByIndex(context.checkContext.index)
                                .stream().filter(s -> bitmapUnionColumn.equalsIgnoreCase(normalizeName(s.getName())))
                                .findFirst().orElseThrow(
                                        () -> new AnalysisException("cannot find bitmap union slot when select mv"));

                        context.exprRewriteMap.putIntoProjectExprMap(toBitmap, bitmapUnionSlot);
                        BitmapUnion newBitmapUnion = new BitmapUnion(bitmapUnionSlot);
                        context.exprRewriteMap.aggFuncMap.put(bitmapUnion, newBitmapUnion);
                        return newBitmapUnion;
                    }
                }
            } else {
                Expression child = bitmapUnion.child();
                String bitmapUnionColumn = normalizeName(CreateMaterializedViewStmt.mvColumnBuilder(
                        AggregateType.BITMAP_UNION, CreateMaterializedViewStmt.mvColumnBuilder(child.toSql())));

                Column mvColumn = context.checkContext.getColumn(bitmapUnionColumn);
                // has bitmap_union column
                if (mvColumn != null && context.checkContext.valueNameToColumn.containsValue(mvColumn)) {

                    Slot bitmapUnionSlot = context.checkContext.scan.getOutputByIndex(context.checkContext.index)
                            .stream().filter(s -> bitmapUnionColumn.equalsIgnoreCase(normalizeName(s.getName())))
                            .findFirst()
                            .orElseThrow(() -> new AnalysisException("cannot find bitmap union slot when select mv"));
                    context.exprRewriteMap.putIntoProjectExprMap(child, bitmapUnionSlot);
                    BitmapUnion newBitmapUnion = new BitmapUnion(bitmapUnionSlot);
                    context.exprRewriteMap.aggFuncMap.put(bitmapUnion, newBitmapUnion);
                    return newBitmapUnion;
                }
            }

            return bitmapUnion;
        }

        /**
         * bitmap_union_count(to_bitmap(col)) -> bitmap_union_count(mva_BITMAP_UNION__to_bitmap_with_check(col))
         */
        @Override
        public Expression visitBitmapUnionCount(BitmapUnionCount bitmapUnionCount, RewriteContext context) {
            Expression result = visitAggregateFunction(bitmapUnionCount, context);
            if (result != bitmapUnionCount) {
                return result;
            }
            if (bitmapUnionCount.child() instanceof ToBitmap) {
                ToBitmap toBitmap = (ToBitmap) bitmapUnionCount.child();
                Optional<Slot> slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(toBitmap.child());
                if (slotOpt.isPresent()) {
                    String bitmapUnionCountColumn = normalizeName(CreateMaterializedViewStmt
                            .mvColumnBuilder(AggregateType.BITMAP_UNION, CreateMaterializedViewStmt
                                    .mvColumnBuilder(new ToBitmapWithCheck(toBitmap.child()).toSql())));

                    Column mvColumn = context.checkContext.getColumn(bitmapUnionCountColumn);
                    // has bitmap_union_count column
                    if (mvColumn != null && context.checkContext.valueNameToColumn.containsValue(mvColumn)) {

                        Slot bitmapUnionCountSlot = context.checkContext.scan
                                .getOutputByIndex(context.checkContext.index)
                                .stream()
                                .filter(s -> bitmapUnionCountColumn.equalsIgnoreCase(normalizeName(s.getName())))
                                .findFirst()
                                .orElseThrow(() -> new AnalysisException(
                                        "cannot find bitmap union count slot when select mv"));

                        context.exprRewriteMap.putIntoProjectExprMap(toBitmap, bitmapUnionCountSlot);
                        BitmapUnionCount newBitmapUnionCount = new BitmapUnionCount(bitmapUnionCountSlot);
                        context.exprRewriteMap.aggFuncMap.put(bitmapUnionCount, newBitmapUnionCount);
                        return newBitmapUnionCount;
                    }
                }
            } else {
                Expression child = bitmapUnionCount.child();
                String bitmapUnionCountColumn = normalizeName(CreateMaterializedViewStmt.mvColumnBuilder(
                        AggregateType.BITMAP_UNION, CreateMaterializedViewStmt.mvColumnBuilder(child.toSql())));

                Column mvColumn = context.checkContext.getColumn(bitmapUnionCountColumn);
                // has bitmap_union_count column
                if (mvColumn != null && context.checkContext.valueNameToColumn.containsValue(mvColumn)) {

                    Slot bitmapUnionCountSlot = context.checkContext.scan.getOutputByIndex(context.checkContext.index)
                            .stream().filter(s -> bitmapUnionCountColumn.equalsIgnoreCase(normalizeName(s.getName())))
                            .findFirst().orElseThrow(
                                    () -> new AnalysisException("cannot find bitmap union count slot when select mv"));
                    context.exprRewriteMap.putIntoProjectExprMap(child, bitmapUnionCountSlot);
                    BitmapUnionCount newBitmapUnionCount = new BitmapUnionCount(bitmapUnionCountSlot);
                    context.exprRewriteMap.aggFuncMap.put(bitmapUnionCount, newBitmapUnionCount);
                    return newBitmapUnionCount;
                }
            }

            return bitmapUnionCount;
        }

        /**
         * hll_union(hll_hash(col)) to hll_union(mva_HLL_UNION__hll_hash_(col))
         */
        @Override
        public Expression visitHllUnion(HllUnion hllUnion, RewriteContext context) {
            Expression result = visitAggregateFunction(hllUnion, context);
            if (result != hllUnion) {
                return result;
            }
            if (hllUnion.child() instanceof HllHash) {
                HllHash hllHash = (HllHash) hllUnion.child();
                Optional<Slot> slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(hllHash.child());
                if (slotOpt.isPresent()) {
                    String hllUnionColumn = normalizeName(CreateMaterializedViewStmt.mvColumnBuilder(
                            AggregateType.HLL_UNION, CreateMaterializedViewStmt.mvColumnBuilder(hllHash.toSql())));

                    Column mvColumn = context.checkContext.getColumn(hllUnionColumn);
                    // has hll_union column
                    if (mvColumn != null && context.checkContext.valueNameToColumn.containsValue(mvColumn)) {
                        Slot hllUnionSlot = context.checkContext.scan.getOutputByIndex(context.checkContext.index)
                                .stream()
                                .filter(s -> hllUnionColumn.equalsIgnoreCase(normalizeName(s.getName())))
                                .findFirst()
                                .orElseThrow(() -> new AnalysisException(
                                        "cannot find hll union slot when select mv"));

                        context.exprRewriteMap.putIntoProjectExprMap(hllHash, hllUnionSlot);
                        HllUnion newHllUnion = new HllUnion(hllUnionSlot);
                        context.exprRewriteMap.aggFuncMap.put(hllUnion, newHllUnion);
                        return newHllUnion;
                    }
                }
            }

            return hllUnion;
        }

        /**
         * hll_union_agg(hll_hash(col)) -> hll_union_agg(mva_HLL_UNION__hll_hash_(col))
         */
        @Override
        public Expression visitHllUnionAgg(HllUnionAgg hllUnionAgg, RewriteContext context) {
            Expression result = visitAggregateFunction(hllUnionAgg, context);
            if (result != hllUnionAgg) {
                return result;
            }
            if (hllUnionAgg.child() instanceof HllHash) {
                HllHash hllHash = (HllHash) hllUnionAgg.child();
                Optional<Slot> slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(hllHash.child());
                if (slotOpt.isPresent()) {
                    String hllUnionColumn = normalizeName(CreateMaterializedViewStmt.mvColumnBuilder(
                            AggregateType.HLL_UNION, CreateMaterializedViewStmt.mvColumnBuilder(hllHash.toSql())));

                    Column mvColumn = context.checkContext.getColumn(hllUnionColumn);
                    // has hll_union column
                    if (mvColumn != null && context.checkContext.valueNameToColumn.containsValue(mvColumn)) {
                        Slot hllUnionSlot = context.checkContext.scan.getOutputByIndex(context.checkContext.index)
                                .stream()
                                .filter(s -> hllUnionColumn.equalsIgnoreCase(normalizeName(s.getName())))
                                .findFirst()
                                .orElseThrow(() -> new AnalysisException(
                                        "cannot find hll union slot when select mv"));

                        context.exprRewriteMap.putIntoProjectExprMap(hllHash, hllUnionSlot);
                        HllUnionAgg newHllUnionAgg = new HllUnionAgg(hllUnionSlot);
                        context.exprRewriteMap.aggFuncMap.put(hllUnionAgg, newHllUnionAgg);
                        return newHllUnionAgg;
                    }
                }
            }

            return hllUnionAgg;
        }

        /**
         * ndv(col) -> hll_union_agg(mva_HLL_UNION__hll_hash_(col))
         */
        @Override
        public Expression visitNdv(Ndv ndv, RewriteContext context) {
            Expression result = visitAggregateFunction(ndv, context);
            if (result != ndv) {
                return result;
            }
            Optional<Slot> slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(ndv.child(0));
            // ndv on a value column.
            if (slotOpt.isPresent()) {
                Expression expr = castIfNeed(ndv.child(), VarcharType.SYSTEM_DEFAULT);
                String hllUnionColumn = normalizeName(
                        CreateMaterializedViewStmt.mvColumnBuilder(AggregateType.HLL_UNION,
                                CreateMaterializedViewStmt.mvColumnBuilder(new HllHash(expr).toSql())));

                Column mvColumn = context.checkContext.getColumn(hllUnionColumn);
                // has hll_union column
                if (mvColumn != null && context.checkContext.valueNameToColumn.containsValue(mvColumn)) {
                    Slot hllUnionSlot = context.checkContext.scan.getOutputByIndex(context.checkContext.index)
                            .stream()
                            .filter(s -> hllUnionColumn.equalsIgnoreCase(normalizeName(s.getName())))
                            .findFirst()
                            .orElseThrow(() -> new AnalysisException(
                                    "cannot find hll union slot when select mv"));

                    context.exprRewriteMap.putIntoProjectExprMap(slotOpt.get(), hllUnionSlot);
                    HllUnionAgg hllUnionAgg = new HllUnionAgg(hllUnionSlot);
                    context.exprRewriteMap.aggFuncMap.put(ndv, hllUnionAgg);
                    return hllUnionAgg;
                }
            }
            return ndv;
        }

        @Override
        public Expression visitSum(Sum sum, RewriteContext context) {
            Expression result = visitAggregateFunction(sum, context);
            if (result != sum) {
                return result;
            }
            if (!sum.isDistinct()) {
                Expression expr = castIfNeed(sum.child(), BigIntType.INSTANCE);
                String sumColumn = normalizeName(CreateMaterializedViewStmt.mvColumnBuilder(AggregateType.SUM,
                        CreateMaterializedViewStmt.mvColumnBuilder(expr.toSql())));
                Column mvColumn = context.checkContext.getColumn(sumColumn);
                if (mvColumn != null && context.checkContext.valueNameToColumn.containsValue(mvColumn)) {
                    Slot sumSlot = context.checkContext.scan.getOutputByIndex(context.checkContext.index).stream()
                            .filter(s -> sumColumn.equalsIgnoreCase(normalizeName(s.getName()))).findFirst()
                            .orElseThrow(() -> new AnalysisException("cannot find sum slot when select mv"));
                    context.exprRewriteMap.putIntoProjectExprMap(sum.child(), sumSlot);
                    Sum newSum = new Sum(sumSlot);
                    context.exprRewriteMap.aggFuncMap.put(sum, newSum);
                    return newSum;
                }
            }
            return sum;
        }

        /**
         * agg(col) -> agg_merge(mva_generic_aggregation__agg_state(col)) eg: max_by(k2,
         * k3) -> max_by_merge(mva_generic_aggregation__max_by_state(k2, k3))
         */
        @Override
        public Expression visitAggregateFunction(AggregateFunction aggregateFunction, RewriteContext context) {
            String aggStateName = normalizeName(CreateMaterializedViewStmt.mvColumnBuilder(
                    AggregateType.GENERIC, StateCombinator.create(aggregateFunction).toSql()));

            Column mvColumn = context.checkContext.getColumn(aggStateName);
            if (mvColumn != null && context.checkContext.valueNameToColumn.containsValue(mvColumn)) {
                Slot aggStateSlot = context.checkContext.scan.getOutputByIndex(context.checkContext.index).stream()
                        .filter(s -> aggStateName.equalsIgnoreCase(normalizeName(s.getName()))).findFirst()
                        .orElseThrow(() -> new AnalysisException("cannot find agg state slot when select mv"));

                for (Expression child : aggregateFunction.children()) {
                    context.exprRewriteMap.putIntoProjectExprMap(child, aggStateSlot);
                }

                MergeCombinator mergeCombinator = new MergeCombinator(Arrays.asList(aggStateSlot), aggregateFunction);
                context.exprRewriteMap.aggFuncMap.put(aggregateFunction, mergeCombinator);
                return mergeCombinator;
            }
            return aggregateFunction;
        }
    }

    /**
     * Rewrites every output expression of {@code agg} with {@link ResultAggFuncRewriter}, so that
     * aggregate functions with a registered replacement are swapped for it.
     *
     * @param agg aggregate whose output expressions are rewritten
     * @param oldProjectOpt project below the aggregate before the rewrite, if any
     * @param newProjectOpt project below the aggregate after the rewrite, if any
     * @param exprRewriteMap replacements collected for the selected index
     * @return immutable list of the rewritten output expressions, in the original order
     */
    private List<NamedExpression> replaceAggOutput(
            LogicalAggregate<? extends Plan> agg,
            Optional<Project> oldProjectOpt,
            Optional<Project> newProjectOpt,
            ExprRewriteMap exprRewriteMap) {
        ResultAggFuncRewriteCtx rewriteCtx =
                new ResultAggFuncRewriteCtx(oldProjectOpt, newProjectOpt, exprRewriteMap);
        ImmutableList.Builder<NamedExpression> rewrittenOutputs = ImmutableList.builder();
        for (NamedExpression output : agg.getOutputExpressions()) {
            rewrittenOutputs.add((NamedExpression) ResultAggFuncRewriter.rewrite(output, rewriteCtx));
        }
        return rewrittenOutputs.build();
    }

    /**
     * Lookup tables used by {@code ResultAggFuncRewriter} while the aggregate output is rewritten.
     */
    private static class ResultAggFuncRewriteCtx {
        /** Alias slot -> producing expression of the old project; empty when there is no old project. */
        public final Optional<Map<Slot, Expression>> oldProjectSlotToProducerOpt;
        /** Aliased expression -> alias slot of the new project; empty when there is no new project. */
        public final Optional<Map<Expression, Slot>> newProjectExprMapOpt;
        public final ExprRewriteMap exprRewriteMap;

        public ResultAggFuncRewriteCtx(
                Optional<Project> oldProject,
                Optional<Project> newProject,
                ExprRewriteMap exprRewriteMap) {
            this.oldProjectSlotToProducerOpt = oldProject.map(Project::getAliasToProducer);
            this.newProjectExprMapOpt = newProject.map(project -> project.getProjects()
                    .stream()
                    .filter(Alias.class::isInstance)
                    .collect(
                            Collectors.toMap(
                                    // Avoid cast to alias, retrieving the first child expression.
                                    alias -> alias.child(0),
                                    NamedExpression::toSlot,
                                    // Two aliases may wrap equal expressions; without a merge function
                                    // toMap throws IllegalStateException. Either slot produces the same
                                    // value, so keep the first one.
                                    (first, duplicate) -> first
                            )
                    ));
            this.exprRewriteMap = exprRewriteMap;
        }
    }

    /**
     * Replaces aggregate functions in the aggregate output with the replacements registered in
     * {@code ResultAggFuncRewriteCtx.exprRewriteMap}.
     */
    private static class ResultAggFuncRewriter extends DefaultExpressionRewriter<ResultAggFuncRewriteCtx> {
        public static final ResultAggFuncRewriter INSTANCE = new ResultAggFuncRewriter();

        public static Expression rewrite(Expression expr, ResultAggFuncRewriteCtx ctx) {
            return expr.accept(INSTANCE, ctx);
        }

        @Override
        public Expression visitAggregateFunction(AggregateFunction aggregateFunction,
                ResultAggFuncRewriteCtx ctx) {
            // Express the function over the old project's producers so it can match a key of the map.
            AggregateFunction lookupKey = replaceAggFuncInput(aggregateFunction, ctx.oldProjectSlotToProducerOpt);
            Map<AggregateFunction, AggregateFunction> replacements = ctx.exprRewriteMap.aggFuncMap;
            if (!replacements.containsKey(lookupKey)) {
                return aggregateFunction;
            }
            AggregateFunction replacement = replacements.get(lookupKey);
            // When a new project exists, point the replacement's inputs at that project's alias slots.
            return ctx.newProjectExprMapOpt
                    .map(exprToSlot -> ExpressionUtils.replace(replacement, exprToSlot))
                    .orElse(replacement);
        }
    }

    /**
     * Extends {@code outputs} with the replacement expressions registered in {@code projectMap}.
     *
     * <p>Each output is kept. When its SQL text equals the SQL text of a key in {@code projectMap},
     * the mapped expressions are appended right after it: named expressions as they are, anything
     * else wrapped in an alias that reuses the output's expression id and name.
     *
     * @param outputs original output expressions
     * @param projectMap original expression -> expressions to add alongside it
     * @return a new mutable list with the originals and their replacements
     */
    private List<NamedExpression> replaceOutput(List<NamedExpression> outputs,
            Map<Expression, List<Expression>> projectMap) {
        Map<String, List<Expression>> replacementsBySql = Maps.newHashMap();
        for (Map.Entry<Expression, List<Expression>> entry : projectMap.entrySet()) {
            replacementsBySql.put(entry.getKey().toSql(), entry.getValue());
        }

        List<NamedExpression> extendedOutputs = Lists.newArrayList();
        for (NamedExpression output : outputs) {
            extendedOutputs.add(output);

            String outputSql = output.toSql();
            if (replacementsBySql.containsKey(outputSql)) {
                for (Expression replacement : replacementsBySql.get(outputSql)) {
                    extendedOutputs.add(replacement instanceof NamedExpression
                            ? (NamedExpression) replacement
                            : new Alias(output.getExprId(), replacement, output.getName()));
                }
            }
        }
        return extendedOutputs;
    }

    /**
     * Returns the group-by expressions of {@code agg} that are not {@link VirtualSlotReference}s, as an
     * immutable list in their original order.
     */
    private List<Expression> nonVirtualGroupByExprs(LogicalAggregate<? extends Plan> agg) {
        ImmutableList.Builder<Expression> nonVirtual = ImmutableList.builder();
        for (Expression groupByExpr : agg.getGroupByExpressions()) {
            if (groupByExpr instanceof VirtualSlotReference) {
                continue;
            }
            nonVirtual.add(groupByExpr);
        }
        return nonVirtual.build();
    }

    /**
     * Builds the project list to use once a materialized view has been selected.
     *
     * <p>All slots provided by the mv are put into the project. They cannot simply be substituted
     * later in the ReplaceExpressionWithMvColumn rule, because one base column may correspond to
     * several mv columns, e.g. k2 -> max_k2/min_k2/sum_k2.
     * TODO: Do not add redundant columns
     *
     * @param mvPlan scan with the selected index
     * @param outputs original project outputs
     * @return {@code outputs} itself when the base index is selected; otherwise the selected index's
     *         output slots, followed by the non-slot outputs, followed by the
     *         {@link SlotNotFromChildren} outputs
     */
    private List<NamedExpression> generateNewOutputsWithMvOutputs(
            LogicalOlapScan mvPlan, List<NamedExpression> outputs) {
        if (mvPlan.getSelectedIndexId() == mvPlan.getTable().getBaseIndexId()) {
            return outputs;
        }
        ImmutableList.Builder<NamedExpression> newOutputs = ImmutableList.builder();
        newOutputs.addAll(mvPlan.getOutputByIndex(mvPlan.getSelectedIndexId()));
        for (NamedExpression output : outputs) {
            if (!(output instanceof Slot)) {
                newOutputs.add(output);
            }
        }
        for (NamedExpression output : outputs) {
            if (output instanceof SlotNotFromChildren) {
                newOutputs.add(output);
            }
        }
        return newOutputs.build();
    }

    /**
     * Collects the aggregate expressions with references to the project below resolved back to the
     * expressions that produce them.
     *
     * <p>eg: select abs(k1)+1 t,sum(abs(k2+1)) from single_slot group by t order by t;
     * <pre>
     *  +--LogicalAggregate[88] ( groupByExpr=[t#4], outputExpr=[t#4, sum(abs((k2#1 + 1))) AS `sum(abs(k2 + 1))`#5])
     *      +--LogicalProject[87] ( distinct=false, projects=[(abs(k1#0) + 1) AS `t`#4, k2#1])
     *          +--LogicalOlapScan()
     * </pre>
     * t -> abs(k1#0) + 1
     *
     * @param aggExpressions expressions taken from the aggregate
     * @param project the project below the aggregate, if any
     * @return immutable set of the resolved expressions; the aggregate expressions themselves when
     *         there is no project
     */
    private Set<Expression> collectRequireExprWithAggAndProject(List<? extends Expression> aggExpressions,
            Optional<LogicalProject<?>> project) {
        List<NamedExpression> projections = project.map(LogicalProject::getProjects).orElse(null);
        if (projections == null) {
            return ImmutableSet.copyOf(aggExpressions);
        }
        Optional<Map<Slot, Expression>> aliasToProducerOpt = project.map(Project::getAliasToProducer);
        // Expression id of each projection -> the projection itself, or the aliased child for an alias.
        Map<ExprId, Expression> producerByExprId = projections.stream()
                .collect(Collectors.toMap(NamedExpression::getExprId,
                        projection -> projection instanceof Alias ? ((Alias) projection).child() : projection));

        ImmutableSet.Builder<Expression> required = ImmutableSet.builder();
        for (Expression aggExpr : aggExpressions) {
            // Step 1: an expression that is itself a projection output becomes its producer.
            Expression resolved = aggExpr;
            if (aggExpr instanceof NamedExpression
                    && producerByExprId.containsKey(((NamedExpression) aggExpr).getExprId())) {
                resolved = producerByExprId.get(((NamedExpression) aggExpr).getExprId());
            }
            // Step 2: alias slots nested inside the expression are expanded to their producers.
            final Expression toExpand = resolved;
            required.add(aliasToProducerOpt
                    .map(aliasToProducer -> ExpressionUtils.replace(toExpand, aliasToProducer))
                    .orElse(toExpand));
        }
        return required.build();
    }
}