AbstractSelectMaterializedIndexRule.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite.mv;

import org.apache.doris.analysis.CreateMaterializedViewStmt;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.MaterializedIndex;
import org.apache.doris.catalog.MaterializedIndexMeta;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;
import org.apache.doris.planner.PlanNode;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.collect.Lists;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Base class for selecting materialized index rules.
 */
public abstract class AbstractSelectMaterializedIndexRule {
    /**
     * Tells whether an index still has to be selected for the given scan ("with aggregate" variant).
     *
     * @param scan the olap scan to inspect
     * @return true when the table keys type is AGG_KEYS, UNIQUE_KEYS or DUP_KEYS and no index is selected yet
     */
    protected boolean shouldSelectIndexWithAgg(LogicalOlapScan scan) {
        boolean supportedKeysType;
        switch (scan.getTable().getKeysType()) {
            case AGG_KEYS:
            case UNIQUE_KEYS:
            case DUP_KEYS:
                supportedKeysType = true;
                break;
            default:
                supportedKeysType = false;
                break;
        }
        // isIndexSelected() is only consulted for the supported keys types
        return supportedKeysType && !scan.isIndexSelected();
    }

    /**
     * Tells whether an index still has to be selected for the given scan ("without aggregate" variant).
     *
     * @param scan the olap scan to inspect
     * @return true when the table keys type is AGG_KEYS, UNIQUE_KEYS or DUP_KEYS and no index is selected yet
     */
    protected boolean shouldSelectIndexWithoutAgg(LogicalOlapScan scan) {
        boolean supportedKeysType;
        switch (scan.getTable().getKeysType()) {
            case AGG_KEYS:
            case UNIQUE_KEYS:
            case DUP_KEYS:
                supportedKeysType = true;
                break;
            default:
                supportedKeysType = false;
                break;
        }
        // isIndexSelected() is only consulted for the supported keys types
        return supportedKeysType && !scan.isIndexSelected();
    }

    /**
     * Get the predicates that can be ignored when all aggregate functions are sum.
     *
     * <p>Only predicates of the form {@code NOT (x IS NULL)} are considered. Such a predicate is returned when
     * the normalized SUM mv column name built from the case-when form of {@code x} equals (ignoring case) the
     * SQL text of the first argument of one of the aggregate expressions.
     *
     * @param aggExpressions aggregate expressions; the SQL of each one's first child is collected
     * @param predicateExpr predicates to examine
     * @return a new mutable list with the predicates that can be pruned
     */
    protected static List<Expression> getPrunedPredicatesWithAllSumAgg(List<Expression> aggExpressions,
            Set<Expression> predicateExpr) {
        // case-insensitive set with the SQL text of every aggregate's first argument
        Set<String> sumArgumentSqls = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
        for (Expression aggregate : aggExpressions) {
            sumArgumentSqls.add(aggregate.child(0).toSql());
        }

        List<Expression> prunable = new ArrayList<>();
        for (Expression predicate : predicateExpr) {
            boolean isNotNullCheck = predicate instanceof Not && predicate.child(0) instanceof IsNull;
            if (!isNotNullCheck) {
                continue;
            }
            Expression checked = predicate.child(0).child(0);
            String caseWhenSql = slotToCaseWhen(checked).toSql();
            String countColumn = normalizeName(CreateMaterializedViewStmt.mvColumnBuilder(AggregateType.SUM,
                    CreateMaterializedViewStmt.mvColumnBuilder(caseWhenSql)));
            if (sumArgumentSqls.contains(countColumn)) {
                prunable.add(predicate);
            }
        }
        return prunable;
    }

    /**
     * We can prune some predicates when there is no group-by column.
     *
     * <p>Pruning is only attempted when every aggregate expression is a {@link Sum} (which also holds for an
     * empty list); otherwise nothing is pruned.
     *
     * @param aggExpressions aggregate expressions of the query
     * @param predicateExpr predicates to examine
     * @return a new mutable list with the predicates that can be pruned, possibly empty
     */
    protected static List<Expression> getPrunedPredicates(List<Expression> aggExpressions,
            Set<Expression> predicateExpr) {
        List<Expression> prunedExpr = new ArrayList<>();
        boolean onlySums = aggExpressions.stream().allMatch(Sum.class::isInstance);
        if (onlySums) {
            prunedExpr.addAll(getPrunedPredicatesWithAllSumAgg(aggExpressions, predicateExpr));
        }
        return prunedExpr;
    }

    /**
     * Checks that every key column of the given index is also a key column of the base schema.
     * Column names are parsed to SQL, normalized and compared case-insensitively.
     *
     * @param table the table owning the index
     * @param index the materialized index to check
     * @return true when the base schema key column names contain all key column names of the index
     */
    protected static boolean containAllKeyColumns(OlapTable table, MaterializedIndex index) {
        Function<Column, String> toComparableName =
                column -> normalizeName(parseMvColumnToSql(column.getNameWithoutMvPrefix()));

        Set<String> indexKeyNames = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
        for (Column keyColumn : table.getKeyColumnsByIndexId(index.getId())) {
            indexKeyNames.add(toComparableName.apply(keyColumn));
        }

        Set<String> baseKeyNames = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
        for (Column keyColumn : table.getBaseSchemaKeyColumns()) {
            baseKeyNames.add(toComparableName.apply(keyColumn));
        }

        return baseKeyNames.containsAll(indexKeyNames);
    }

    /**
     * Checks whether the given index can serve the scan: every conjunct of the index's where clause must appear
     * among the query predicates, and the required columns/expressions must be covered by the index columns.
     *
     * @param index the candidate materialized index
     * @param scan the olap scan the index would be selected for
     * @param requiredScanOutput slots that the scan has to output
     * @param requiredExpr expressions required above the scan
     * @param predicateExpr predicates applied to the scan
     * @return true when the index satisfies the requirements described above
     */
    protected static boolean containAllRequiredColumns(MaterializedIndex index, LogicalOlapScan scan,
            Set<Slot> requiredScanOutput, Set<? extends Expression> requiredExpr, Set<Expression> predicateExpr) {
        OlapTable table = scan.getTable();
        MaterializedIndexMeta meta = table.getIndexMetaByIndexId(index.getId());

        // SQL text of the query predicates; matched index conjuncts are removed from this working set below
        Set<String> predicateExprSql = predicateExpr.stream().map(ExpressionTrait::toSql).collect(Collectors.toSet());

        // Table names are disabled on each where-clause conjunct before toSql(), and the text is re-parsed by
        // NereidsParser so that it is comparable with the SQL text of the Nereids predicates.
        // NOTE: setDisableTableName(true) mutates the conjunct expressions obtained from the index meta.
        Set<String> indexConjuncts = PlanNode.splitAndCompoundPredicateToConjuncts(meta.getWhereClause()).stream()
                .map(e -> {
                    e.setDisableTableName(true);
                    return e;
                }).map(e -> new NereidsParser().parseExpression(e.toSql()).toSql()).collect(Collectors.toSet());

        // every conjunct of the index's where clause must be present in the query predicates
        for (String indexConjunct : indexConjuncts) {
            if (predicateExprSql.contains(indexConjunct)) {
                predicateExprSql.remove(indexConjunct);
            } else {
                return false;
            }
        }

        // names are compared case-insensitively
        Set<String> requiredMvColumnNames = requiredScanOutput.stream()
                .map(s -> normalizeName(Column.getNameWithoutMvPrefix(s.getName())))
                .collect(Collectors.toCollection(() -> new TreeSet<String>(String.CASE_INSENSITIVE_ORDER)));

        Set<String> mvColNames = meta.getSchema().stream()
                .map(c -> normalizeName(parseMvColumnToSql(c.getNameWithoutMvPrefix())))
                .collect(Collectors.toCollection(() -> new TreeSet<String>(String.CASE_INSENSITIVE_ORDER)));
        // the where-clause conjuncts are treated as available "columns" of the index as well
        mvColNames.addAll(indexConjuncts);

        // fast path: all required slots are index columns and no predicate is left besides the index conjuncts
        if (mvColNames.containsAll(requiredMvColumnNames) && predicateExprSql.isEmpty()) {
            return true;
        }

        // required expressions that cannot be computed from the index columns
        Set<Expression> remained = requiredExpr.stream().filter(e -> !containsAllColumn(e, mvColNames))
                .collect(Collectors.toSet());
        if (remained.isEmpty()) {
            return true;
        }

        // when the scan has no group expression, predicates that can be pruned are dropped from the remainder
        if (!scan.getGroupExpression().isPresent()) {
            Set<Expression> prunedExpr = getPrunedPredicates(
                    requiredExpr.stream().filter(e -> e instanceof AggregateFunction).collect(Collectors.toList()),
                    predicateExpr).stream().collect(Collectors.toSet());
            remained = remained.stream().filter(e -> !prunedExpr.contains(e)).collect(Collectors.toSet());
        }

        return remained.isEmpty();
    }

    /**
     * Passes the mv column name through {@code CreateMaterializedViewStmt.mvColumnBreaker}, parses the result
     * as a Nereids expression and returns that expression's SQL text.
     *
     * @param mvName the materialized view column name
     * @return the SQL text of the parsed expression
     */
    public static String parseMvColumnToSql(String mvName) {
        String expressionText = CreateMaterializedViewStmt.mvColumnBreaker(mvName);
        return new NereidsParser().parseExpression(expressionText).toSql();
    }

    /**
     * Re-builds a mv column name: the given name is broken with {@code mvColumnBreaker}, parsed as a Nereids
     * expression, and the expression's SQL text is handed to {@code mvColumnBuilder} with the aggregate type.
     *
     * @param mvName the materialized view column name
     * @param aggTypeName optional aggregate type name forwarded to {@code mvColumnBuilder}
     * @return the name produced by {@code mvColumnBuilder}
     */
    public static String parseMvColumnToMvName(String mvName, Optional<String> aggTypeName) {
        String expressionSql = new NereidsParser()
                .parseExpression(CreateMaterializedViewStmt.mvColumnBreaker(mvName))
                .toSql();
        return CreateMaterializedViewStmt.mvColumnBuilder(aggTypeName, expressionSql);
    }

    /**
     * Recursively checks whether an expression can be computed from the given mv column names.
     *
     * <p>An expression is covered when its SQL text (or the text after {@code mvColumnBreaker}) is one of the
     * names, or when all of its non-literal children are covered. A leaf that is not a known name is covered
     * only if it is a {@link VirtualSlotReference}.
     *
     * @param expression the expression to check
     * @param mvColumnNames available column names
     * @return true when the expression is covered
     */
    protected static boolean containsAllColumn(Expression expression, Set<String> mvColumnNames) {
        String sql = expression.toSql();
        boolean directlyCovered = mvColumnNames.contains(sql)
                || mvColumnNames.contains(CreateMaterializedViewStmt.mvColumnBreaker(sql));
        if (directlyCovered) {
            return true;
        }

        List<Expression> children = expression.children();
        if (children.isEmpty()) {
            return expression instanceof VirtualSlotReference;
        }
        // literals never need a column; every other child has to be covered itself
        return children.stream()
                .allMatch(child -> child instanceof Literal || containsAllColumn(child, mvColumnNames));
    }

    /**
     * Selects the best index among the candidates.
     * 1. find matching key prefix most.
     * 2. if several remain, keep those matching most required columns.
     * 3. sort by row count, column count, non-base-first and index id.
     *
     * <p>The base index always takes part in the comparison. The caller's {@code candidates} list is never
     * modified, so immutable lists are accepted as well.
     *
     * @param candidates indexes that can serve the scan; when empty the base index id is returned
     * @param scan the olap scan an index is selected for
     * @param predicates predicates applied to the scan, used for key prefix matching
     * @param requiredExprs expressions required above the scan, used for column matching
     * @return the id of the selected index
     */
    protected static long selectBestIndex(
            List<MaterializedIndex> candidates,
            LogicalOlapScan scan,
            Set<Expression> predicates,
            Set<? extends Expression> requiredExprs) {
        OlapTable table = scan.getTable();
        if (candidates.isEmpty()) {
            return table.getBaseIndexId();
        }

        MaterializedIndex baseIndex = table.getBaseIndex();
        // work on a private copy: adding the base index must not leak into (or fail on) the caller's list
        List<MaterializedIndex> pool = new ArrayList<>(candidates);
        pool.add(baseIndex);

        // Scan slot exprId -> slot name
        Map<ExprId, String> exprIdToName = scan.getOutput()
                .stream()
                .collect(Collectors.toMap(NamedExpression::getExprId, NamedExpression::getName));

        // find matching key prefix most.
        List<MaterializedIndex> best = matchPrefixMost(scan, pool, predicates, exprIdToName);
        if (best.size() > 1) {
            // an alias is compared by the SQL of its child expression
            Set<String> requiredExprNames = requiredExprs.stream().map(e -> {
                if (e instanceof Alias) {
                    return ((Alias) e).child().toSql().toLowerCase();
                }
                return e.toSql().toLowerCase();
            }).collect(Collectors.toSet());
            best = matchColumnMost(table, best, requiredExprNames);
        }

        List<Long> partitionIds = scan.getSelectedPartitionIds();
        // sort by row count, column count and index id.
        List<Long> sortedIndexIds = best.stream()
                .map(MaterializedIndex::getId)
                .sorted(Comparator
                        // compare by row count
                        .comparing(rid -> partitionIds.stream()
                                .mapToLong(pid -> table.getPartition(pid).getIndex((Long) rid).getRowCount())
                                .sum())
                        // compare by column count
                        .thenComparing(rid -> table.getSchemaByIndexId((Long) rid).size())
                        // prioritize using non-base index
                        .thenComparing(rid -> (Long) rid == baseIndex.getId())
                        // compare by index id
                        .thenComparing(rid -> (Long) rid))
                .collect(Collectors.toList());

        return table.getBestMvIdWithHint(sortedIndexIds);
    }

    /**
     * Narrows the candidate indexes: indexes with a where clause are preferred, then indexes whose keys type
     * is not DUP_KEYS, and finally the ones whose key prefix matches the predicate columns best.
     * Each preference is only applied when at least one candidate satisfies it.
     *
     * @param scan the olap scan
     * @param candidate candidate indexes
     * @param predicates predicates applied to the scan
     * @param exprIdToName scan slot exprId -> slot name
     * @return the narrowed candidates
     */
    protected static List<MaterializedIndex> matchPrefixMost(
            LogicalOlapScan scan,
            List<MaterializedIndex> candidate,
            Set<Expression> predicates,
            Map<ExprId, String> exprIdToName) {
        Map<Boolean, Set<String>> byEquality = filterCanUsePrefixIndexAndSplitByEquality(predicates, exprIdToName);
        Set<String> equalColNames = byEquality.getOrDefault(true, ImmutableSet.of()).stream()
                .map(String::toLowerCase).collect(Collectors.toSet());
        Set<String> nonEqualColNames = byEquality.getOrDefault(false, ImmutableSet.of()).stream()
                .map(String::toLowerCase).collect(Collectors.toSet());

        OlapTable table = scan.getTable();
        List<MaterializedIndex> remaining = candidate;

        // prioritize using index with where clause
        List<MaterializedIndex> withWhereClause = remaining.stream()
                .filter(index -> table.getIndexMetaByIndexId(index.getId()).getWhereClause() != null)
                .collect(Collectors.toList());
        if (!withWhereClause.isEmpty()) {
            remaining = withWhereClause;
        }

        // prioritize using index with pre agg
        List<MaterializedIndex> preAggregated = remaining.stream()
                .filter(index -> table.getIndexMetaByIndexId(index.getId()).getKeysType() != KeysType.DUP_KEYS)
                .collect(Collectors.toList());
        if (!preAggregated.isEmpty()) {
            remaining = preAggregated;
        }

        if (equalColNames.isEmpty() && nonEqualColNames.isEmpty()) {
            return remaining;
        }
        List<MaterializedIndex> matched = matchKeyPrefixMost(table, remaining, equalColNames, nonEqualColNames);
        return matched.isEmpty() ? remaining : matched;
    }

    ///////////////////////////////////////////////////////////////////////////
    // Split conjuncts into equal-to and non-equal-to.
    ///////////////////////////////////////////////////////////////////////////

    /**
     * Keeps the conjuncts that can use the prefix index and splits their column names into two groups:
     * key {@code true} holds columns compared with an equal-to style predicate, key {@code false} holds
     * columns compared with any other supported predicate. A key is absent when its group is empty.
     */
    private static Map<Boolean, Set<String>> filterCanUsePrefixIndexAndSplitByEquality(
            Set<Expression> conjuncts, Map<ExprId, String> exprIdToColName) {
        Stream<PrefixIndexCheckResult> usableChecks = conjuncts.stream()
                .map(conjunct -> PredicateChecker.canUsePrefixIndex(conjunct, exprIdToColName))
                .filter(check -> !check.equals(PrefixIndexCheckResult.FAILURE));
        return usableChecks.collect(Collectors.groupingBy(
                check -> check.type == ResultType.SUCCESS_EQUAL,
                Collectors.mapping(check -> check.colName, Collectors.toSet())));
    }

    /** Outcome categories of a prefix-index check, see {@code PrefixIndexCheckResult}. */
    private enum ResultType {
        // the predicate cannot use the prefix index
        FAILURE,
        // usable; created through PrefixIndexCheckResult.createEqual
        SUCCESS_EQUAL,
        // usable; created through PrefixIndexCheckResult.createNonEqual
        SUCCESS_NON_EQUAL,
    }

    /**
     * Result of checking one predicate against the prefix index: the column name the predicate is about
     * (null for {@link #FAILURE}) together with the kind of match.
     */
    private static class PrefixIndexCheckResult {
        /** Shared instance for predicates that cannot use the prefix index. */
        public static final PrefixIndexCheckResult FAILURE = new PrefixIndexCheckResult(null, ResultType.FAILURE);

        private final String colName;
        private final ResultType type;

        private PrefixIndexCheckResult(String columnName, ResultType resultType) {
            colName = columnName;
            type = resultType;
        }

        /** Creates a successful result for an equal-to style predicate on the named column. */
        public static PrefixIndexCheckResult createEqual(String name) {
            return new PrefixIndexCheckResult(name, ResultType.SUCCESS_EQUAL);
        }

        /** Creates a successful result for a non-equal predicate on the named column. */
        public static PrefixIndexCheckResult createNonEqual(String name) {
            return new PrefixIndexCheckResult(name, ResultType.SUCCESS_NON_EQUAL);
        }
    }

    /**
     * Checks whether an expression could use the prefix key index.
     * Only IN predicates with literal options and comparisons between a (cast on a) slot and a constant
     * qualify; every other expression yields {@link PrefixIndexCheckResult#FAILURE}.
     */
    private static class PredicateChecker extends ExpressionVisitor<PrefixIndexCheckResult, Map<ExprId, String>> {
        private static final PredicateChecker INSTANCE = new PredicateChecker();

        private PredicateChecker() {
        }

        /**
         * Runs the check on one expression.
         *
         * @param expression the predicate to check
         * @param exprIdToName slot exprId -> column name, used to name the column in the result
         */
        public static PrefixIndexCheckResult canUsePrefixIndex(Expression expression,
                Map<ExprId, String> exprIdToName) {
            return expression.accept(INSTANCE, exprIdToName);
        }

        @Override
        public PrefixIndexCheckResult visit(Expression expr, Map<ExprId, String> context) {
            return PrefixIndexCheckResult.FAILURE;
        }

        @Override
        public PrefixIndexCheckResult visitInPredicate(InPredicate in, Map<ExprId, String> context) {
            Optional<ExprId> comparedSlot = ExpressionUtils.isSlotOrCastOnSlot(in.getCompareExpr());
            if (!comparedSlot.isPresent()) {
                return PrefixIndexCheckResult.FAILURE;
            }
            // every option has to be a literal
            for (Expression option : in.getOptions()) {
                if (!(option instanceof Literal)) {
                    return PrefixIndexCheckResult.FAILURE;
                }
            }
            return PrefixIndexCheckResult.createEqual(context.get(comparedSlot.get()));
        }

        @Override
        public PrefixIndexCheckResult visitComparisonPredicate(ComparisonPredicate cp, Map<ExprId, String> context) {
            Function<String, PrefixIndexCheckResult> resultFactory = cp instanceof EqualPredicate
                    ? PrefixIndexCheckResult::createEqual
                    : PrefixIndexCheckResult::createNonEqual;
            return check(cp, context, resultFactory);
        }

        // maps the slot found in the comparison (if any) to a result carrying its column name
        private PrefixIndexCheckResult check(ComparisonPredicate cp, Map<ExprId, String> exprIdToColumnName,
                Function<String, PrefixIndexCheckResult> resultMapper) {
            Optional<ExprId> slotId = check(cp);
            if (!slotId.isPresent()) {
                return PrefixIndexCheckResult.FAILURE;
            }
            return resultMapper.apply(exprIdToColumnName.get(slotId.get()));
        }

        // tries "slot op constant" first, then "constant op slot"
        private Optional<ExprId> check(ComparisonPredicate cp) {
            Optional<ExprId> slotOnLeft = check(cp.left(), cp.right());
            if (slotOnLeft.isPresent()) {
                return slotOnLeft;
            }
            return check(cp.right(), cp.left());
        }

        // returns the slot's exprId when one side is a (cast on a) slot and the other side is constant
        private Optional<ExprId> check(Expression maybeSlot, Expression maybeConst) {
            Optional<ExprId> slotId = ExpressionUtils.isSlotOrCastOnSlot(maybeSlot);
            if (slotId.isPresent() && maybeConst.isConstant()) {
                return slotId;
            }
            return Optional.empty();
        }
    }

    ///////////////////////////////////////////////////////////////////////////
    // Matching key prefix
    ///////////////////////////////////////////////////////////////////////////
    /**
     * Returns the indexes whose key prefix matches the most predicate columns.
     *
     * @param table the table owning the indexes
     * @param indexes candidate indexes; may be empty
     * @param equalColumns lower-cased names of columns compared with equal-to style predicates
     * @param nonEqualColumns lower-cased names of columns compared with other predicates
     * @return the indexes sharing the highest prefix match count, in input order; empty when
     *         {@code indexes} is empty
     */
    private static List<MaterializedIndex> matchKeyPrefixMost(
            OlapTable table,
            List<MaterializedIndex> indexes,
            Set<String> equalColumns,
            Set<String> nonEqualColumns) {
        // nothing to rank: return an empty result instead of failing on a missing map entry
        if (indexes.isEmpty()) {
            return new ArrayList<>();
        }
        // match count -> indexes with that count, in input order
        TreeMap<Integer, List<MaterializedIndex>> byMatchCount = new TreeMap<>();
        for (MaterializedIndex index : indexes) {
            int matchCount = indexKeyPrefixMatchCount(table, index, equalColumns, nonEqualColumns);
            byMatchCount.computeIfAbsent(matchCount, count -> new ArrayList<>()).add(index);
        }
        // the last entry holds the highest match count
        return byMatchCount.lastEntry().getValue();
    }

    /**
     * Counts how many leading columns of the index schema are matched by the predicate columns.
     * Counting stops at the first column that is not matched, and also right after the first column that is
     * matched only by a non-equal predicate.
     */
    private static int indexKeyPrefixMatchCount(
            OlapTable table,
            MaterializedIndex index,
            Set<String> equalColNames,
            Set<String> nonEqualColNames) {
        int matched = 0;
        for (Column column : table.getSchemaByIndexId(index.getId())) {
            String columnName = normalizeName(column.getNameWithoutMvPrefix().toLowerCase());
            if (equalColNames.contains(columnName)) {
                matched++;
                continue;
            }
            // un-equivalence predicate's columns can match only first column in index.
            if (nonEqualColNames.contains(columnName)) {
                matched++;
            }
            break;
        }
        return matched;
    }

    /**
     * Returns the indexes whose schema contains the most required expression names, in input order.
     */
    private static List<MaterializedIndex> matchColumnMost(
            OlapTable table,
            List<MaterializedIndex> indexes,
            Set<String> requiredExprs) {
        // match count -> indexes with that count, in input order
        TreeMap<Integer, List<MaterializedIndex>> byMatchCount = new TreeMap<>();
        for (MaterializedIndex index : indexes) {
            int matchCount = columnMatchCount(table, index, requiredExprs);
            byMatchCount.computeIfAbsent(matchCount, count -> new ArrayList<>()).add(index);
        }
        // the last entry holds the highest match count
        return byMatchCount.lastEntry().getValue();
    }

    /**
     * Counts the columns of the index schema whose normalized name is one of the required names.
     */
    private static int columnMatchCount(OlapTable table, MaterializedIndex index, Set<String> requiredColNames) {
        return (int) table.getSchemaByIndexId(index.getId()).stream()
                .map(column -> normalizeName(column.getNameWithoutMvPrefix().toLowerCase()))
                .filter(columnName -> requiredColNames.contains(columnName))
                .count();
    }

    /**
     * Tells whether the scan carries the {@code PREAGGOPEN} hint (compared ignoring case).
     *
     * @param olapScan the scan whose hints are inspected
     * @return true when at least one hint equals "PREAGGOPEN" ignoring case
     */
    protected static boolean preAggEnabledByHint(LogicalOlapScan olapScan) {
        return olapScan.getHints().stream().anyMatch(hint -> "PREAGGOPEN".equalsIgnoreCase(hint));
    }

    /**
     * Normalizes an identifier for comparison: removes every backquote and lower-cases the rest.
     *
     * <p>Lower-casing uses {@code Locale.ROOT} so the result does not depend on the JVM default locale
     * (e.g. under a Turkish locale "ID" would otherwise become "ıd" and no longer match "id").
     *
     * @param name the identifier, must not be null
     * @return the identifier without backquotes, in lower case
     */
    public static String normalizeName(String name) {
        return name.replace("`", "").toLowerCase(java.util.Locale.ROOT);
    }

    /**
     * Builds {@code CASE WHEN expression IS NULL THEN 0 ELSE 1 END} with tinyint literals.
     *
     * @param expression the expression tested for null
     * @return the case-when expression
     */
    public static Expression slotToCaseWhen(Expression expression) {
        WhenClause zeroWhenNull = new WhenClause(new IsNull(expression), new TinyIntLiteral((byte) 0));
        Expression otherwiseOne = new TinyIntLiteral((byte) 1);
        return new CaseWhen(ImmutableList.of(zeroWhenNull), otherwiseOne);
    }

    /**
     * Builds the slot mapping used to rewrite expressions on the base index into expressions on the
     * selected index of the scan.
     *
     * @param mvPlan the scan whose index has been selected
     * @return {@link SlotContext#EMPTY} when the base index is selected; otherwise a context holding the
     *         base-slot to mv-slot mapping, the remaining mv slots by normalized name, and the conjuncts of
     *         the index's where clause as expressions known to be true
     */
    protected SlotContext generateBaseScanExprToMvExpr(LogicalOlapScan mvPlan) {
        if (mvPlan.getSelectedIndexId() == mvPlan.getTable().getBaseIndexId()) {
            return SlotContext.EMPTY;
        }
        Map<Slot, Slot> baseSlotToMvSlot = new HashMap<>();
        Map<String, Slot> mvNameToMvSlot = new HashMap<>();
        for (Slot mvSlot : mvPlan.getOutputByIndex(mvPlan.getSelectedIndexId())) {
            // becomes true when the mv slot corresponds directly to a slot of the scan output
            boolean isPushed = false;
            for (Slot baseSlot : mvPlan.getOutput()) {
                // compare the base slot's SQL with the mv column name after mvColumnBreaker, ignoring case
                if (baseSlot.toSql().equalsIgnoreCase(
                        org.apache.doris.analysis.CreateMaterializedViewStmt.mvColumnBreaker(
                            normalizeName(mvSlot.getName())))) {
                    baseSlotToMvSlot.put(baseSlot, mvSlot);
                    isPushed = true;
                    break;
                }
            }
            // mv slots without a matching base slot are looked up by their normalized name later
            if (!isPushed) {
                mvNameToMvSlot.put(normalizeName(mvSlot.getName()), mvSlot);
            }
        }
        OlapTable table = mvPlan.getTable();
        MaterializedIndexMeta meta = table.getIndexMetaByIndexId(mvPlan.getSelectedIndexId());

        // the where-clause conjuncts are re-parsed as Nereids expressions (with table names disabled, which
        // mutates the legacy expressions) and passed as the context's trueExprs
        return new SlotContext(baseSlotToMvSlot, mvNameToMvSlot,
                PlanNode.splitAndCompoundPredicateToConjuncts(meta.getWhereClause()).stream()
                        .map(e -> {
                            e.setDisableTableName(true);
                            return e;
                        })
                        .map(e -> new NereidsParser().parseExpression(e.toSql()))
                        .collect(Collectors.toSet()));
    }

    /**
     * Call this generateBaseScanExprToMvExpr only when we have both agg and filter.
     *
     * <p>Starts from the single-argument variant. When the scan has no group expression, the predicates
     * that can be pruned for the required aggregate functions are added to the context's true expressions.
     *
     * @param mvPlan the scan whose index has been selected
     * @param requiredExpr expressions required above the scan
     * @param predicateExpr predicates applied to the scan
     * @return the slot context to rewrite with
     */
    protected SlotContext generateBaseScanExprToMvExpr(LogicalOlapScan mvPlan, Set<Expression> requiredExpr,
            Set<Expression> predicateExpr) {
        SlotContext base = generateBaseScanExprToMvExpr(mvPlan);
        if (mvPlan.getGroupExpression().isPresent()) {
            return base;
        }

        List<Expression> aggregateFunctions = new ArrayList<>();
        for (Expression required : requiredExpr) {
            if (required instanceof AggregateFunction) {
                aggregateFunctions.add(required);
            }
        }
        List<Expression> pruned = getPrunedPredicates(aggregateFunctions, predicateExpr);

        // pruned predicates are treated as true, in addition to the expressions that already were
        Set<Expression> trueExprs = Stream.concat(pruned.stream(), base.trueExprs.stream())
                .collect(Collectors.toSet());
        return new SlotContext(base.baseSlotToMvSlot, base.mvNameToMvSlot, trueExprs);
    }

    /**
     * SlotContext: immutable holder of the information needed to rewrite expressions on the base index
     * into expressions on the selected index. All inputs are copied into immutable collections.
     */
    protected static class SlotContext {
        // context without any mapping
        public static final SlotContext EMPTY
                = new SlotContext(ImmutableMap.of(), ImmutableMap.of(), ImmutableSet.of());

        // base index Slot to selected mv Slot
        public final Map<Slot, Slot> baseSlotToMvSlot;
        // selected mv Slot name to mv Slot, we must use ImmutableSortedMap because column name could be uppercase
        // (the map is ordered by String.CASE_INSENSITIVE_ORDER, so lookups ignore case)
        public final ImmutableSortedMap<String, Slot> mvNameToMvSlot;

        // expressions that the rewriter replaces with the boolean literal TRUE
        public final ImmutableSet<Expression> trueExprs;

        /**
         * @param baseSlotToMvSlot base index slot -> selected mv slot
         * @param mvNameToMvSlot mv slot name -> mv slot; copied into a case-insensitive sorted map
         * @param trueExprs expressions known to be true
         */
        public SlotContext(Map<Slot, Slot> baseSlotToMvSlot, Map<String, Slot> mvNameToMvSlot,
                Set<Expression> trueExprs) {
            this.baseSlotToMvSlot = ImmutableMap.copyOf(baseSlotToMvSlot);
            this.mvNameToMvSlot = ImmutableSortedMap.copyOf(mvNameToMvSlot, String.CASE_INSENSITIVE_ORDER);
            this.trueExprs = ImmutableSet.copyOf(trueExprs);
        }
    }

    /**
     * ReplaceExpressions: plan visitor that rewrites the expressions of aggregate, repeat, filter and
     * project nodes so that they reference the columns of the selected index.
     * Notes: For the sum type, the original column and the mv column may have inconsistent types,
     *        but the be side will not check the column type, so it can work normally
     */
    protected static class ReplaceExpressions extends DefaultPlanVisitor<Plan, Void> {
        private final SlotContext slotContext;

        public ReplaceExpressions(SlotContext slotContext) {
            this.slotContext = slotContext;
        }

        /**
         * Rewrites the plan unless the scan uses the base index, in which case the plan is returned as is.
         *
         * @param plan the plan to rewrite
         * @param scan the scan whose selected index decides whether a rewrite is needed
         */
        public Plan replace(Plan plan, LogicalOlapScan scan) {
            if (scan.getSelectedIndexId() == scan.getTable().getBaseIndexId()) {
                return plan;
            }
            return plan.accept(this, null);
        }

        @Override
        public LogicalAggregate<Plan> visitLogicalAggregate(LogicalAggregate<? extends Plan> agg, Void ctx) {
            Plan child = agg.child(0).accept(this, ctx);
            // one rewriter per node: it keeps no state across calls, so it can be shared by all expressions
            ReplaceExpressionWithMvColumn rewriter = new ReplaceExpressionWithMvColumn(slotContext);

            List<Expression> newGroupByExprs = agg.getGroupByExpressions().stream()
                    .map(rewriter::replace)
                    .collect(ImmutableList.toImmutableList());

            List<NamedExpression> newOutputExpressions = agg.getOutputExpressions().stream()
                    .map(expr -> (NamedExpression) rewriter.replace(expr))
                    .collect(ImmutableList.toImmutableList());

            return agg.withNormalized(newGroupByExprs, newOutputExpressions, child);
        }

        @Override
        public LogicalRepeat<Plan> visitLogicalRepeat(LogicalRepeat<? extends Plan> repeat, Void ctx) {
            Plan child = repeat.child(0).accept(this, ctx);
            // one rewriter per node: it keeps no state across calls, so it can be shared by all expressions
            ReplaceExpressionWithMvColumn rewriter = new ReplaceExpressionWithMvColumn(slotContext);

            List<List<Expression>> newGroupingExprs = Lists.newArrayList();
            for (List<Expression> expressions : repeat.getGroupingSets()) {
                newGroupingExprs.add(expressions.stream()
                        .map(rewriter::replace)
                        .collect(ImmutableList.toImmutableList()));
            }

            // nullability of the outputs is adjusted against the rewritten grouping sets
            List<NamedExpression> newOutputExpressions = PlanUtils.adjustNullableForRepeat(newGroupingExprs,
                    repeat.getOutputExpressions().stream()
                            .map(expr -> (NamedExpression) rewriter.replace(expr))
                            .collect(ImmutableList.toImmutableList()));

            return repeat.withNormalizedExpr(newGroupingExprs, newOutputExpressions, child);
        }

        @Override
        public LogicalFilter<Plan> visitLogicalFilter(LogicalFilter<? extends Plan> filter, Void ctx) {
            Plan child = filter.child(0).accept(this, ctx);
            // the whole predicate is rewritten first and then split into conjuncts again
            Set<Expression> newConjuncts = ImmutableSet.copyOf(ExpressionUtils.extractConjunction(
                    new ReplaceExpressionWithMvColumn(slotContext).replace(filter.getPredicate())));

            return filter.withConjunctsAndChild(newConjuncts, child);
        }

        @Override
        public LogicalProject<Plan> visitLogicalProject(LogicalProject<? extends Plan> project, Void ctx) {
            Plan child = project.child(0).accept(this, ctx);
            // one rewriter per node: it keeps no state across calls, so it can be shared by all expressions
            ReplaceExpressionWithMvColumn rewriter = new ReplaceExpressionWithMvColumn(slotContext);

            List<NamedExpression> newProjects = project.getProjects().stream()
                    .map(expr -> (NamedExpression) rewriter.replace(expr))
                    .collect(ImmutableList.toImmutableList());

            return project.withProjectsAndChild(newProjects, child);
        }

        @Override
        public LogicalOlapScan visitLogicalOlapScan(LogicalOlapScan scan, Void ctx) {
            // re-create the scan with its group expression but without logical properties
            return (LogicalOlapScan) scan.withGroupExprLogicalPropChildren(scan.getGroupExpression(), Optional.empty(),
                    ImmutableList.of());
        }
    }

    /**
     * ReplaceExpressionWithMvColumn: expression rewriter that replaces slots and expressions of the base
     * index with the slots of the selected index, using the mappings of a {@link SlotContext}.
     */
    protected static class ReplaceExpressionWithMvColumn extends DefaultExpressionRewriter<Void> {

        // base index Slot to selected mv Slot
        private final Map<Slot, Slot> baseSlotToMvSlot;
        // selected mv Slot name to mv Slot,  we must use ImmutableSortedMap because column name could be uppercase
        private final ImmutableSortedMap<String, Slot> mvNameToMvSlot;
        // SQL text of the expressions that are replaced with TRUE
        private final ImmutableSet<String> trueExprs;

        public ReplaceExpressionWithMvColumn(SlotContext slotContext) {
            this.baseSlotToMvSlot = ImmutableMap.copyOf(slotContext.baseSlotToMvSlot);
            this.mvNameToMvSlot = ImmutableSortedMap.copyOf(slotContext.mvNameToMvSlot, String.CASE_INSENSITIVE_ORDER);
            this.trueExprs = slotContext.trueExprs.stream()
                    .map(Expression::toSql)
                    .collect(ImmutableSet.toImmutableSet());
        }

        /** Rewrites the given expression and returns the result. */
        public Expression replace(Expression expression) {
            return expression.accept(this, null);
        }

        @Override
        public Expression visit(Expression expr, Void context) {
            // nothing to do without mappings, or when the expression already is a mv column
            if (notUseMv() || CreateMaterializedViewStmt.isMVColumn(expr.toSql())) {
                return expr;
            }
            if (trueExprs.contains(expr.toSql())) {
                return BooleanLiteral.TRUE;
            }
            if (checkExprIsMvColumn(expr)) {
                return mvNameToMvSlot.get(CreateMaterializedViewStmt.mvColumnBuilder(expr.toSql()));
            }
            return super.visit(expr, context);
        }

        @Override
        public Expression visitSlotReference(SlotReference slotReference, Void context) {
            // the direct slot mapping wins over the lookup by name
            Slot mappedSlot = baseSlotToMvSlot.get(slotReference);
            if (mappedSlot != null) {
                return mappedSlot;
            }
            Slot slotByName = mvNameToMvSlot.get(slotReference.toSql());
            if (slotByName != null) {
                return slotByName;
            }
            return slotReference;
        }

        @Override
        public Expression visitAggregateFunction(AggregateFunction aggregateFunction, Void context) {
            List<String> childSqls = new ArrayList<>();
            for (Expression child : aggregateFunction.children()) {
                childSqls.add(child.toSql());
            }
            String childrenName = String.join(", ", childSqls);
            String mvName = CreateMaterializedViewStmt.mvAggregateColumnBuilder(
                    aggregateFunction.getName(), childrenName);

            Slot replacement = mvNameToMvSlot.get(mvName);
            if (replacement == null) {
                // aggRewrite eg: bitmap_union_count -> bitmap_union
                replacement = mvNameToMvSlot.get(childrenName);
            }
            if (replacement != null) {
                return aggregateFunction.withChildren(replacement);
            }
            return visit(aggregateFunction, context);
        }

        @Override
        public Expression visitScalarFunction(ScalarFunction scalarFunction, Void context) {
            // strip one level of cast from each argument before looking the function up as a mv column
            ImmutableList.Builder<Expression> argumentsWithoutCast = ImmutableList.builder();
            for (Expression child : scalarFunction.children()) {
                if (child instanceof Cast) {
                    argumentsWithoutCast.add(((Cast) child).child());
                } else {
                    argumentsWithoutCast.add(child);
                }
            }
            Expression functionWithoutCast = scalarFunction.withChildren(argumentsWithoutCast.build());

            String mvColumnName = CreateMaterializedViewStmt.mvColumnBuilder(functionWithoutCast.toSql());
            if (mvNameToMvSlot.containsKey(mvColumnName)) {
                return mvNameToMvSlot.get(mvColumnName);
            }
            return visit(scalarFunction, context);
        }

        // true when the slot context carries no mapping at all
        private boolean notUseMv() {
            return baseSlotToMvSlot.isEmpty() && mvNameToMvSlot.isEmpty();
        }

        // true when the mv column name built from the expression's SQL is a known mv slot name
        private boolean checkExprIsMvColumn(Expression expr) {
            String mvColumnName = CreateMaterializedViewStmt.mvColumnBuilder(expr.toSql());
            return mvNameToMvSlot.containsKey(mvColumnName);
        }
    }

    /**
     * Builds the project list that keeps the original output names and expr ids after the rewrite.
     *
     * <p>For each old project: if its slot maps to a mv slot, or the SQL of its expression (the child for an
     * alias) is a known mv slot name, an alias of that mv slot with the old expr id and name is produced;
     * otherwise the old project's slot is kept.
     *
     * @param oldProjects the output expressions before the rewrite
     * @param slotContext the slot mappings of the selected index
     * @return an immutable list with one entry per old project, in the same order
     */
    protected List<NamedExpression> generateProjectsAlias(List<? extends NamedExpression> oldProjects,
            SlotContext slotContext) {
        ImmutableList.Builder<NamedExpression> newProjects = ImmutableList.builder();
        for (NamedExpression oldProject : oldProjects) {
            Slot outputSlot = oldProject.toSlot();
            Slot mvSlot = slotContext.baseSlotToMvSlot.get(outputSlot);
            if (mvSlot == null) {
                // fall back to a lookup by the SQL text of the real (unaliased) expression
                Expression real = oldProject;
                if (real instanceof Alias) {
                    real = real.child(0);
                }
                mvSlot = slotContext.mvNameToMvSlot.get(real.toSql());
            }

            if (mvSlot != null) {
                newProjects.add(new Alias(oldProject.getExprId(), mvSlot, oldProject.getName()));
            } else {
                newProjects.add(outputSlot);
            }
        }
        return newProjects.build();
    }
}