ExpressionUtils.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.util;

import org.apache.doris.catalog.TableIf.TableType;
import org.apache.doris.common.Config;
import org.apache.doris.common.MaterializedViewException;
import org.apache.doris.common.NereidsException;
import org.apache.doris.common.Pair;
import org.apache.doris.common.UserException;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.analyzer.UnboundSlot;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.analysis.ExpressionAnalyzer;
import org.apache.doris.nereids.rules.expression.ExpressionRewrite;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor;
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule;
import org.apache.doris.nereids.rules.expression.rules.ReplaceVariableByLiteral;
import org.apache.doris.nereids.trees.TreeNode;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.ComparableLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.visitor.ExpressionLineageReplacer;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.coercion.NumericType;
import org.apache.doris.qe.ConnectContext;

import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Expression rewrite helper class.
 */
public class ExpressionUtils {

    public static final List<Expression> EMPTY_CONDITION = ImmutableList.of();

    public static List<Expression> extractConjunction(Expression expr) {
        return extract(And.class, expr);
    }

    public static Set<Expression> extractConjunctionToSet(Expression expr) {
        Set<Expression> exprSet = Sets.newLinkedHashSet();
        extract(And.class, expr, exprSet);
        return exprSet;
    }

    public static List<Expression> extractDisjunction(Expression expr) {
        return extract(Or.class, expr);
    }

    /**
     * Split predicates with `And/Or` form recursively.
     * Some examples for `And`:
     * <p>
     * a and b -> a, b
     * (a and b) and c -> a, b, c
     * (a or b) and (c and d) -> (a or b), c , d
     * <p>
     * Stop recursion when meeting `Or`, so this func will ignore `And` inside `Or`.
     * Warning examples:
     * (a and b) or c -> (a and b) or c
     */
    public static List<Expression> extract(CompoundPredicate expr) {
        return extract(expr.getClass(), expr);
    }

    private static List<Expression> extract(Class<? extends Expression> type, Expression expr) {
        List<Expression> result = Lists.newArrayList();
        Deque<Expression> stack = new ArrayDeque<>();
        stack.push(expr);
        while (!stack.isEmpty()) {
            Expression current = stack.pop();
            if (type.isInstance(current)) {
                for (Expression child : current.children()) {
                    stack.push(child);
                }
            } else {
                result.add(current);
            }
        }
        result = Lists.reverse(result);
        return result;
    }

    private static void extract(Class<? extends Expression> type, Expression expr, Collection<Expression> result) {
        result.addAll(extract(type, expr));
    }

    public static Optional<Pair<Slot, Slot>> extractEqualSlot(Expression expr) {
        if (expr instanceof EqualTo && expr.child(0).isSlot() && expr.child(1).isSlot()) {
            return Optional.of(Pair.of((Slot) expr.child(0), (Slot) expr.child(1)));
        }
        return Optional.empty();
    }

    public static Optional<Expression> optionalAnd(List<Expression> expressions) {
        if (expressions.isEmpty()) {
            return Optional.empty();
        } else {
            return Optional.of(ExpressionUtils.and(expressions));
        }
    }

    /**
     * And two list.
     */
    public static Optional<Expression> optionalAnd(List<Expression> left, List<Expression> right) {
        if (left.isEmpty() && right.isEmpty()) {
            return Optional.empty();
        } else if (left.isEmpty()) {
            return optionalAnd(right);
        } else if (right.isEmpty()) {
            return optionalAnd(left);
        } else {
            return Optional.of(new And(optionalAnd(left).get(), optionalAnd(right).get()));
        }
    }

    public static Optional<Expression> optionalAnd(Expression... expressions) {
        return optionalAnd(Lists.newArrayList(expressions));
    }

    public static Optional<Expression> optionalAnd(Collection<Expression> collection) {
        return optionalAnd(ImmutableList.copyOf(collection));
    }

    /**
     *  AND / OR expression, also remove duplicate expression, boolean literal
     */
    public static Expression compound(boolean isAnd, Collection<Expression> expressions) {
        return isAnd ? and(expressions) : or(expressions);
    }

    /**
     *  AND expression, also remove duplicate expression, boolean literal
     */
    public static Expression and(Collection<Expression> expressions) {
        if (expressions.size() == 1) {
            return expressions.iterator().next();
        }
        Set<Expression> distinctExpressions = Sets.newLinkedHashSetWithExpectedSize(expressions.size());
        for (Expression expression : expressions) {
            if (expression.equals(BooleanLiteral.FALSE)) {
                return BooleanLiteral.FALSE;
            } else if (!expression.equals(BooleanLiteral.TRUE)) {
                distinctExpressions.add(expression);
            }
        }

        List<Expression> exprList = Lists.newArrayList(distinctExpressions);
        if (exprList.isEmpty()) {
            return BooleanLiteral.TRUE;
        } else if (exprList.size() == 1) {
            return exprList.get(0);
        } else {
            return new And(exprList);
        }
    }

    /**
     *  AND expression, also remove duplicate expression, boolean literal
     */
    public static Expression and(Expression... expressions) {
        return and(Lists.newArrayList(expressions));
    }

    public static Optional<Expression> optionalOr(List<Expression> expressions) {
        if (expressions.isEmpty()) {
            return Optional.empty();
        } else {
            return Optional.of(ExpressionUtils.or(expressions));
        }
    }

    /**
     *  OR expression, also remove duplicate expression, boolean literal
     */
    public static Expression or(Expression... expressions) {
        return or(Lists.newArrayList(expressions));
    }

    /**
     *  OR expression, also remove duplicate expression, boolean literal
     */
    public static Expression or(Collection<Expression> expressions) {
        if (expressions.size() == 1) {
            return expressions.iterator().next();
        }
        Set<Expression> distinctExpressions = Sets.newLinkedHashSetWithExpectedSize(expressions.size());
        for (Expression expression : expressions) {
            if (expression.equals(BooleanLiteral.TRUE)) {
                return BooleanLiteral.TRUE;
            } else if (!expression.equals(BooleanLiteral.FALSE)) {
                distinctExpressions.add(expression);
            }
        }

        List<Expression> exprList = Lists.newArrayList(distinctExpressions);
        if (exprList.isEmpty()) {
            return BooleanLiteral.FALSE;
        } else if (exprList.size() == 1) {
            return exprList.get(0);
        } else {
            return new Or(exprList);
        }
    }

    public static Expression falseOrNull(Expression expression) {
        if (expression.nullable()) {
            return new And(new IsNull(expression), new NullLiteral(BooleanType.INSTANCE));
        } else {
            return BooleanLiteral.FALSE;
        }
    }

    public static Expression trueOrNull(Expression expression) {
        if (expression.nullable()) {
            return new Or(new Not(new IsNull(expression)), new NullLiteral(BooleanType.INSTANCE));
        } else {
            return BooleanLiteral.TRUE;
        }
    }

    public static Expression toInPredicateOrEqualTo(Expression reference, Collection<? extends Expression> values) {
        if (values.size() < 2) {
            return or(values.stream().map(value -> new EqualTo(reference, value)).collect(Collectors.toList()));
        } else {
            return new InPredicate(reference, ImmutableList.copyOf(values));
        }
    }

    public static Expression shuttleExpressionWithLineage(Expression expression, Plan plan, BitSet tableBitSet) {
        return shuttleExpressionWithLineage(Lists.newArrayList(expression),
                plan, ImmutableSet.of(), ImmutableSet.of(), tableBitSet).get(0);
    }

    public static List<? extends Expression> shuttleExpressionWithLineage(List<? extends Expression> expressions,
            Plan plan, BitSet tableBitSet) {
        return shuttleExpressionWithLineage(expressions, plan, ImmutableSet.of(), ImmutableSet.of(), tableBitSet);
    }

    /**
     * Replace the slot in expressions with the lineage identifier from specifiedbaseTable sets or target table types
     * example as following:
     * select a + 10 as a1, d from (
     * select b - 5 as a, d from table
     * );
     * op expression before is: a + 10 as a1, d. after is: b - 5 + 10, d
     * todo to get from plan struct info
     */
    public static List<? extends Expression> shuttleExpressionWithLineage(List<? extends Expression> expressions,
            Plan plan,
            Set<TableType> targetTypes,
            Set<String> tableIdentifiers,
            BitSet tableBitSet) {
        if (expressions.isEmpty()) {
            return ImmutableList.of();
        }
        ExpressionLineageReplacer.ExpressionReplaceContext replaceContext =
                new ExpressionLineageReplacer.ExpressionReplaceContext(
                        expressions.stream().map(Expression.class::cast).collect(Collectors.toList()),
                        targetTypes,
                        tableIdentifiers,
                        tableBitSet);

        plan.accept(ExpressionLineageReplacer.INSTANCE, replaceContext);
        // Replace expressions by expression map
        List<Expression> replacedExpressions = replaceContext.getReplacedExpressions();
        if (expressions.size() != replacedExpressions.size()) {
            throw new NereidsException("shuttle expression fail",
                    new MaterializedViewException("shuttle expression fail"));
        }
        return replacedExpressions;
    }

    /**
     * Choose the minimum slot from input parameter.
     */
    public static <S extends NamedExpression> S selectMinimumColumn(Collection<S> slots) {
        Preconditions.checkArgument(!slots.isEmpty());
        S minSlot = null;
        for (S slot : slots) {
            if (minSlot == null) {
                minSlot = slot;
            } else {
                int slotDataTypeWidth = slot.getDataType().width();
                if (slotDataTypeWidth < 0) {
                    continue;
                }
                minSlot = slotDataTypeWidth < minSlot.getDataType().width()
                        || minSlot.getDataType().width() <= 0 ? slot : minSlot;
            }
        }
        return minSlot;
    }

    /**
     * Check whether the input expression is a {@link org.apache.doris.nereids.trees.expressions.Slot}
     * or at least one {@link Cast} on a {@link org.apache.doris.nereids.trees.expressions.Slot}
     * <p>
     * for example:
     * - SlotReference to a column:
     * col
     * - Cast on SlotReference:
     * cast(int_col as string)
     * cast(cast(int_col as long) as string)
     *
     * @param expr input expression
     * @return Return Optional[ExprId] of underlying slot reference if input expression is a slot or cast on slot.
     *         Otherwise, return empty optional result.
     */
    public static Optional<ExprId> isSlotOrCastOnSlot(Expression expr) {
        return extractSlotOrCastOnSlot(expr).map(Slot::getExprId);
    }

    /**
     * Check whether the input expression is a {@link org.apache.doris.nereids.trees.expressions.Slot}
     * or at least one {@link Cast} on a {@link org.apache.doris.nereids.trees.expressions.Slot}
     */
    public static Optional<Slot> extractSlotOrCastOnSlot(Expression expr) {
        while (expr instanceof Cast) {
            expr = expr.child(0);
        }

        if (expr instanceof SlotReference) {
            return Optional.of((Slot) expr);
        } else {
            return Optional.empty();
        }
    }

    /**
     * Generate replaceMap Slot -> Expression from NamedExpression[Expression as name]
     */
    public static Map<Slot, Expression> generateReplaceMap(List<? extends NamedExpression> namedExpressions) {
        Map<Slot, Expression> replaceMap = Maps.newLinkedHashMapWithExpectedSize(namedExpressions.size());
        for (NamedExpression namedExpression : namedExpressions) {
            if (namedExpression instanceof Alias) {
                // Avoid cast to alias, retrieving the first child expression.
                Slot slot = namedExpression.toSlot();
                replaceMap.putIfAbsent(slot, namedExpression.child(0));
            }
        }
        return replaceMap;
    }

    /**
     * replace NameExpression.
     */
    public static NamedExpression replaceNameExpression(NamedExpression expr,
            Map<? extends Expression, ? extends Expression> replaceMap) {
        Expression newExpr = replace(expr, replaceMap);
        if (newExpr instanceof NamedExpression) {
            return (NamedExpression) newExpr;
        } else {
            return new Alias(expr.getExprId(), newExpr, expr.getName());
        }
    }

    /**
     * Replace expression node in the expression tree by `replaceMap` in top-down manner.
     * For example.
     * <pre>
     * input expression: a > 1
     * replaceMap: a -> b + c
     *
     * output:
     * b + c > 1
     * </pre>
     */
    public static Expression replace(Expression expr, Map<? extends Expression, ? extends Expression> replaceMap) {
        return expr.rewriteDownShortCircuit(e -> {
            Expression replacedExpr = replaceMap.get(e);
            return replacedExpr == null ? e : replacedExpr;
        });
    }

    /**
     * Replace expression node in the expression tree by `replaceMap` in top-down manner.
     * For example.
     * <pre>
     * input expression: a > 1
     * replaceMap: d -> b + c, transferMap: a -> d
     * firstly try to get mapping expression from replaceMap by a, if can not then
     * get mapping d from transferMap by a
     * and get mapping b + c from replaceMap by d
     * output:
     * b + c > 1
     * </pre>
     */
    public static Expression replace(Expression expr, Map<? extends Expression, ? extends Expression> replaceMap,
            Map<? extends Expression, ? extends Expression> transferMap) {
        return expr.rewriteDownShortCircuit(e -> {
            Expression replacedExpr = replaceMap.get(e);
            if (replacedExpr != null) {
                return replacedExpr;
            }
            replacedExpr = replaceMap.get(transferMap.get(e));
            return replacedExpr == null ? e : replacedExpr;
        });
    }

    public static List<Expression> replace(List<Expression> exprs,
            Map<? extends Expression, ? extends Expression> replaceMap) {
        ImmutableList.Builder<Expression> result = ImmutableList.builderWithExpectedSize(exprs.size());
        for (Expression expr : exprs) {
            result.add(replace(expr, replaceMap));
        }
        return result.build();
    }

    public static Set<Expression> replace(Set<Expression> exprs,
            Map<? extends Expression, ? extends Expression> replaceMap) {
        ImmutableSet.Builder<Expression> result = ImmutableSet.builderWithExpectedSize(exprs.size());
        for (Expression expr : exprs) {
            result.add(replace(expr, replaceMap));
        }
        return result.build();
    }

    /**
     * Replace expression node in the expression tree by `replaceMap` in top-down manner.
     */
    public static List<NamedExpression> replaceNamedExpressions(List<? extends NamedExpression> namedExpressions,
            Map<? extends Expression, ? extends Expression> replaceMap) {
        Builder<NamedExpression> replaceExprs = ImmutableList.builderWithExpectedSize(namedExpressions.size());
        for (NamedExpression namedExpression : namedExpressions) {
            NamedExpression newExpr = replaceNameExpression(namedExpression, replaceMap);
            if (newExpr.getExprId().equals(namedExpression.getExprId())) {
                replaceExprs.add(newExpr);
            } else {
                replaceExprs.add(new Alias(namedExpression.getExprId(), newExpr, namedExpression.getName()));
            }
        }
        return replaceExprs.build();
    }

    public static <E extends Expression> List<E> rewriteDownShortCircuit(
            Collection<E> exprs, Function<Expression, Expression> rewriteFunction) {
        return exprs.stream()
                .map(expr -> (E) expr.rewriteDownShortCircuit(rewriteFunction))
                .collect(ImmutableList.toImmutableList());
    }

    private static class ExpressionReplacer
            extends DefaultExpressionRewriter<Map<? extends Expression, ? extends Expression>> {
        public static final ExpressionReplacer INSTANCE = new ExpressionReplacer();

        private ExpressionReplacer() {
        }

        @Override
        public Expression visit(Expression expr, Map<? extends Expression, ? extends Expression> replaceMap) {
            if (replaceMap.containsKey(expr)) {
                return replaceMap.get(expr);
            }
            return super.visit(expr, replaceMap);
        }
    }

    /**
     * merge arguments into an expression array
     *
     * @param arguments instance of Expression or Expression Array
     * @return Expression Array
     */
    public static List<Expression> mergeArguments(Object... arguments) {
        Builder<Expression> builder = ImmutableList.builder();
        for (Object argument : arguments) {
            if (argument instanceof Expression[]) {
                builder.addAll(Arrays.asList((Expression[]) argument));
            } else {
                builder.add((Expression) argument);
            }
        }
        return builder.build();
    }

    /** isAllLiteral */
    public static boolean isAllLiteral(List<Expression> children) {
        for (Expression child : children) {
            if (!(child instanceof Literal)) {
                return false;
            }
        }
        return true;
    }

    /**
     * return true if all children are literal but not null literal.
     */
    public static boolean isAllNonNullComparableLiteral(List<Expression> children) {
        for (Expression child : children) {
            if ((!(child instanceof ComparableLiteral)) || (child instanceof NullLiteral)) {
                return false;
            }
        }
        return true;
    }

    /** matchNumericType */
    public static boolean matchNumericType(List<Expression> children) {
        for (Expression child : children) {
            if (!child.getDataType().isNumericType()) {
                return false;
            }
        }
        return true;
    }

    /** matchDateLikeType */
    public static boolean matchDateLikeType(List<Expression> children) {
        for (Expression child : children) {
            if (!child.getDataType().isDateLikeType()) {
                return false;
            }
        }
        return true;
    }

    /** hasNullLiteral */
    public static boolean hasNullLiteral(List<Expression> children) {
        for (Expression child : children) {
            if (child instanceof NullLiteral) {
                return true;
            }
        }
        return false;
    }

    /** hasOnlyMetricType */
    public static boolean hasOnlyMetricType(List<Expression> children) {
        for (Expression child : children) {
            if (child.getDataType().isOnlyMetricType()) {
                return true;
            }
        }
        return false;
    }

    /**
     * canInferNotNullForMarkSlot
     */
    public static boolean canInferNotNullForMarkSlot(Expression predicate, ExpressionRewriteContext ctx) {
        /*
         * assume predicate is from LogicalFilter
         * the idea is replacing each mark join slot with null and false literal then run FoldConstant rule
         * if the evaluate result are:
         * 1. all true
         * 2. all null and false (in logicalFilter, we discard both null and false values)
         * the mark slot can be non-nullable boolean
         * and in semi join, we can safely change the mark conjunct to hash conjunct
         */
        ImmutableList<Literal> literals =
                ImmutableList.of(new NullLiteral(BooleanType.INSTANCE), BooleanLiteral.FALSE);
        List<MarkJoinSlotReference> markJoinSlotReferenceList =
                new ArrayList<>((predicate.collect(MarkJoinSlotReference.class::isInstance)));
        int markSlotSize = markJoinSlotReferenceList.size();
        int maxMarkSlotCount = 4;
        // if the conjunct has mark slot, and maximum 4 mark slots(for performance)
        if (markSlotSize > 0 && markSlotSize <= maxMarkSlotCount) {
            Map<Expression, Expression> replaceMap = Maps.newHashMap();
            boolean meetTrue = false;
            boolean meetNullOrFalse = false;
            /*
             * markSlotSize = 1 -> loopCount = 2  ---- 0, 1
             * markSlotSize = 2 -> loopCount = 4  ---- 00, 01, 10, 11
             * markSlotSize = 3 -> loopCount = 8  ---- 000, 001, 010, 011, 100, 101, 110, 111
             * markSlotSize = 4 -> loopCount = 16 ---- 0000, 0001, ... 1111
             */
            int loopCount = 1 << markSlotSize;
            for (int i = 0; i < loopCount; ++i) {
                replaceMap.clear();
                /*
                 * replace each mark slot with null or false
                 * literals.get(0) -> NullLiteral(BooleanType.INSTANCE)
                 * literals.get(1) -> BooleanLiteral.FALSE
                 */
                for (int j = 0; j < markSlotSize; ++j) {
                    replaceMap.put(markJoinSlotReferenceList.get(j), literals.get((i >> j) & 1));
                }
                Expression evalResult = FoldConstantRule.evaluate(
                        ExpressionUtils.replace(predicate, replaceMap),
                        ctx
                );

                if (evalResult.equals(BooleanLiteral.TRUE)) {
                    if (meetNullOrFalse) {
                        return false;
                    } else {
                        meetTrue = true;
                    }
                } else if ((isNullOrFalse(evalResult))) {
                    if (meetTrue) {
                        return false;
                    } else {
                        meetNullOrFalse = true;
                    }
                } else {
                    return false;
                }
            }
            return true;
        }
        return false;
    }

    private static boolean isNullOrFalse(Expression expression) {
        return expression.isNullLiteral() || expression.equals(BooleanLiteral.FALSE);
    }

    /**
     * infer notNulls slot from predicate
     */
    public static Set<Slot> inferNotNullSlots(Set<Expression> predicates, CascadesContext cascadesContext) {
        ImmutableSet.Builder<Slot> notNullSlots = ImmutableSet.builderWithExpectedSize(predicates.size());
        for (Expression predicate : predicates) {
            for (Slot slot : predicate.getInputSlots()) {
                Map<Expression, Expression> replaceMap = new HashMap<>();
                Literal nullLiteral = new NullLiteral(slot.getDataType());
                replaceMap.put(slot, nullLiteral);
                Expression evalExpr = FoldConstantRule.evaluate(
                        ExpressionUtils.replace(predicate, replaceMap),
                        new ExpressionRewriteContext(cascadesContext)
                );
                if (evalExpr.isNullLiteral() || BooleanLiteral.FALSE.equals(evalExpr)) {
                    notNullSlots.add(slot);
                }
            }
        }
        return notNullSlots.build();
    }

    /**
     * infer notNulls slot from predicate
     */
    public static Set<Expression> inferNotNull(Set<Expression> predicates, CascadesContext cascadesContext) {
        ImmutableSet.Builder<Expression> newPredicates = ImmutableSet.builderWithExpectedSize(predicates.size());
        for (Slot slot : inferNotNullSlots(predicates, cascadesContext)) {
            newPredicates.add(new Not(new IsNull(slot), false));
        }
        return newPredicates.build();
    }

    /**
     * infer notNulls slot from predicate but these slots must be in the given slots.
     */
    public static Set<Expression> inferNotNull(Set<Expression> predicates, Set<Slot> slots,
            CascadesContext cascadesContext) {
        ImmutableSet.Builder<Expression> newPredicates = ImmutableSet.builderWithExpectedSize(predicates.size());
        for (Slot slot : inferNotNullSlots(predicates, cascadesContext)) {
            if (slots.contains(slot)) {
                newPredicates.add(new Not(new IsNull(slot), true));
            }
        }
        return newPredicates.build();
    }

    /** flatExpressions */
    public static <E extends Expression> List<E> flatExpressions(List<List<E>> expressionLists) {
        int num = 0;
        for (List<E> expressionList : expressionLists) {
            num += expressionList.size();
        }

        ImmutableList.Builder<E> flatten = ImmutableList.builderWithExpectedSize(num);
        for (List<E> expressionList : expressionLists) {
            flatten.addAll(expressionList);
        }
        return flatten.build();
    }

    /** containsType */
    public static boolean containsType(Collection<? extends Expression> expressions, Class type) {
        for (Expression expression : expressions) {
            if (expression.anyMatch(expr -> expr.anyMatch(type::isInstance))) {
                return true;
            }
        }
        return false;
    }

    /** allMatch */
    public static boolean allMatch(
            Collection<? extends Expression> expressions, Predicate<Expression> predicate) {
        for (Expression expression : expressions) {
            if (!predicate.test(expression)) {
                return false;
            }
        }
        return true;
    }

    /** anyMatch */
    public static boolean anyMatch(
            Collection<? extends Expression> expressions, Predicate<Expression> predicate) {
        for (Expression expression : expressions) {
            if (predicate.test(expression)) {
                return true;
            }
        }
        return false;
    }

    /** deapAnyMatch */
    public static boolean deapAnyMatch(
            Collection<? extends Expression> expressions, Predicate<TreeNode<Expression>> predicate) {
        for (Expression expression : expressions) {
            if (expression.anyMatch(expr -> expr.anyMatch(predicate))) {
                return true;
            }
        }
        return false;
    }

    /** deapNoneMatch */
    public static boolean deapNoneMatch(
            Collection<? extends Expression> expressions, Predicate<TreeNode<Expression>> predicate) {
        for (Expression expression : expressions) {
            if (expression.anyMatch(expr -> expr.anyMatch(predicate))) {
                return false;
            }
        }
        return true;
    }

    public static <E> Set<E> collect(Collection<? extends Expression> expressions,
            Predicate<TreeNode<Expression>> predicate) {
        ImmutableSet.Builder<E> set = ImmutableSet.builder();
        for (Expression expr : expressions) {
            set.addAll(expr.collectToList(predicate));
        }
        return set.build();
    }

    public static <E> List<E> collectToList(Collection<? extends Expression> expressions,
            Predicate<TreeNode<Expression>> predicate) {
        ImmutableList.Builder<E> list = ImmutableList.builder();
        for (Expression expr : expressions) {
            list.addAll(expr.collectToList(predicate));
        }
        return list.build();
    }

    /**
     * extract uniform slot for the given predicate, such as a = 1 and b = 2
     */
    public static ImmutableMap<Slot, Expression> extractUniformSlot(Expression expression) {
        ImmutableMap.Builder<Slot, Expression> builder = new ImmutableMap.Builder<>();
        if (expression instanceof And) {
            expression.children().forEach(child -> builder.putAll(extractUniformSlot(child)));
        }
        if (expression instanceof EqualTo) {
            if (isInjective(expression.child(0)) && expression.child(1).isConstant()) {
                builder.put((Slot) expression.child(0), expression.child(1));
            }
        }
        return builder.build();
    }

    // TODO: Add more injective functions
    public static boolean isInjective(Expression expression) {
        return expression instanceof Slot;
    }

    // if the input is unique,  the output of agg is unique, too
    public static boolean isInjectiveAgg(Expression agg) {
        return agg instanceof Sum || agg instanceof Avg || agg instanceof Max || agg instanceof Min;
    }

    public static <E> Set<E> mutableCollect(List<? extends Expression> expressions,
            Predicate<TreeNode<Expression>> predicate) {
        Set<E> set = new HashSet<>();
        for (Expression expr : expressions) {
            set.addAll(expr.collect(predicate));
        }
        return set;
    }

    /** collectAll */
    public static <E> List<E> collectAll(Collection<? extends Expression> expressions,
            Predicate<TreeNode<Expression>> predicate) {
        switch (expressions.size()) {
            case 0: return ImmutableList.of();
            default: {
                ImmutableList.Builder<E> result = ImmutableList.builder();
                for (Expression expr : expressions) {
                    result.addAll((Set) expr.collect(predicate));
                }
                return result.build();
            }
        }
    }

    public static List<List<Expression>> rollupToGroupingSets(List<Expression> rollupExpressions) {
        List<List<Expression>> groupingSets = Lists.newArrayList();
        for (int end = rollupExpressions.size(); end >= 0; --end) {
            groupingSets.add(rollupExpressions.subList(0, end));
        }
        return groupingSets;
    }

    /**
     * check and maybe commute for predications except not pred.
     */
    public static Optional<Expression> checkAndMaybeCommute(Expression expression) {
        if (expression instanceof Not) {
            return Optional.empty();
        }
        if (expression instanceof InPredicate) {
            InPredicate predicate = ((InPredicate) expression);
            if (!predicate.getCompareExpr().isSlot()
                    || predicate.getOptions().size() > Config.max_distribution_pruner_recursion_depth) {
                return Optional.empty();
            }
            return Optional.ofNullable(predicate.optionsAreLiterals() ? expression : null);
        } else if (expression instanceof ComparisonPredicate) {
            ComparisonPredicate predicate = ((ComparisonPredicate) expression);
            if (predicate.left() instanceof Literal) {
                predicate = predicate.commute();
            }
            return Optional.ofNullable(predicate.left().isSlot() && predicate.right().isLiteral() ? predicate : null);
        } else if (expression instanceof IsNull) {
            return Optional.ofNullable(((IsNull) expression).child().isSlot() ? expression : null);
        }
        return Optional.empty();
    }

    public static List<List<Expression>> cubeToGroupingSets(List<Expression> cubeExpressions) {
        List<List<Expression>> groupingSets = Lists.newArrayList();
        cubeToGroupingSets(cubeExpressions, 0, Lists.newArrayList(), groupingSets);
        return groupingSets;
    }

    private static void cubeToGroupingSets(List<Expression> cubeExpressions, int activeIndex,
            List<Expression> currentGroupingSet, List<List<Expression>> groupingSets) {
        if (activeIndex == cubeExpressions.size()) {
            groupingSets.add(currentGroupingSet);
            return;
        }

        // use current expression
        List<Expression> newCurrentGroupingSet = Lists.newArrayList(currentGroupingSet);
        newCurrentGroupingSet.add(cubeExpressions.get(activeIndex));
        cubeToGroupingSets(cubeExpressions, activeIndex + 1, newCurrentGroupingSet, groupingSets);

        // skip current expression
        cubeToGroupingSets(cubeExpressions, activeIndex + 1, currentGroupingSet, groupingSets);
    }

    /**
     * Get input slot set from list of expressions.
     */
    public static Set<Slot> getInputSlotSet(Collection<? extends Expression> exprs) {
        Set<Slot> set = new HashSet<>();
        for (Expression expr : exprs) {
            set.addAll(expr.getInputSlots());
        }
        return set;
    }

    public static Expression getExpressionCoveredByCast(Expression expression) {
        while (expression instanceof Cast) {
            expression = ((Cast) expression).child();
        }
        return expression;
    }

    /**
     * the expressions can be used as runtime filter targets
     */
    public static Expression getSingleNumericSlotOrExpressionCoveredByCast(Expression expression) {
        if (expression.getInputSlots().size() == 1) {
            Slot slot = expression.getInputSlots().iterator().next();
            if (slot.getDataType() instanceof NumericType) {
                return expression.getInputSlots().iterator().next();
            }
        }
        // for other datatype, only support cast.
        // example: T1 join T2 on subStr(T1.a, 1,4) = subStr(T2.a, 1,4)
        // the cost of subStr is too high, and hence we do not generate RF subStr(T2.a, 1,4)->subStr(T1.a, 1,4)
        while (expression instanceof Cast) {
            expression = ((Cast) expression).child();
        }
        return expression;
    }

    /**
     * To check whether a slot is constant after passing through a filter
     */
    public static boolean checkSlotConstant(Slot slot, Set<Expression> predicates) {
        return predicates.stream().anyMatch(predicate -> {
                    if (predicate instanceof EqualTo) {
                        EqualTo equalTo = (EqualTo) predicate;
                        return (equalTo.left() instanceof Literal && equalTo.right().equals(slot))
                                || (equalTo.right() instanceof Literal && equalTo.left().equals(slot));
                    }
                    return false;
                }
        );
    }

    /**
     * Check the expression is inferred or not, if inferred return true, nor return false
     */
    public static boolean isInferred(Expression expression) {
        return expression.accept(new DefaultExpressionVisitor<Boolean, Void>() {

            @Override
            public Boolean visit(Expression expr, Void context) {
                boolean inferred = expr.isInferred();
                if (expr.isInferred() || expr.children().isEmpty()) {
                    return inferred;
                }
                inferred = true;
                for (Expression child : expr.children()) {
                    inferred = inferred && child.accept(this, context);
                }
                return inferred;
            }
        }, null);
    }

    /** distinctSlotByName */
    public static List<Slot> distinctSlotByName(List<Slot> slots) {
        Set<String> existSlotNames = new HashSet<>(slots.size() * 2);
        Builder<Slot> distinctSlots = ImmutableList.builderWithExpectedSize(slots.size());
        for (Slot slot : slots) {
            String name = slot.getName();
            if (existSlotNames.add(name)) {
                distinctSlots.add(slot);
            }
        }
        return distinctSlots.build();
    }

    /** containsWindowExpression */
    public static boolean containsWindowExpression(List<NamedExpression> expressions) {
        for (NamedExpression expression : expressions) {
            if (expression.anyMatch(WindowExpression.class::isInstance)) {
                return true;
            }
        }
        return false;
    }

    /** filter */
    public static <E extends Expression> List<E> filter(List<? extends Expression> expressions, Class<E> clazz) {
        ImmutableList.Builder<E> result = ImmutableList.builderWithExpectedSize(expressions.size());
        for (Expression expression : expressions) {
            if (clazz.isInstance(expression)) {
                result.add((E) expression);
            }
        }
        return result.build();
    }

    /** test whether unionConstExprs satisfy conjuncts */
    public static boolean unionConstExprsSatisfyConjuncts(LogicalUnion union, Set<Expression> conjuncts) {
        CascadesContext tempCascadeContext = CascadesContext.initContext(
                ConnectContext.get().getStatementContext(), union, PhysicalProperties.ANY);
        ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(tempCascadeContext);
        for (List<NamedExpression> constOutput : union.getConstantExprsList()) {
            Map<Expression, Expression> replaceMap = new HashMap<>();
            for (int i = 0; i < constOutput.size(); i++) {
                Expression output = constOutput.get(i);
                if (output instanceof Alias) {
                    replaceMap.put(union.getOutput().get(i), ((Alias) output).child());
                } else {
                    replaceMap.put(union.getOutput().get(i), output);
                }
            }
            for (Expression conjunct : conjuncts) {
                Expression res = FoldConstantRule.evaluate(ExpressionUtils.replace(conjunct, replaceMap),
                        rewriteContext);
                if (!res.equals(BooleanLiteral.TRUE)) {
                    return false;
                }
            }
        }
        return true;
    }

    /** analyze the unbound expression and fold it to literal */
    public static Literal analyzeAndFoldToLiteral(ConnectContext ctx, Expression expression) throws UserException {
        Scope scope = new Scope(new ArrayList<>());
        LogicalEmptyRelation plan = new LogicalEmptyRelation(
                ConnectContext.get().getStatementContext().getNextRelationId(),
                new ArrayList<>());
        CascadesContext cascadesContext = CascadesContext.initContext(ctx.getStatementContext(), plan,
                PhysicalProperties.ANY);
        ExpressionAnalyzer analyzer = new ExpressionAnalyzer(null, scope, cascadesContext, false, false);
        Expression boundExpr = UnboundSlotRewriter.INSTANCE.rewrite(expression, null);
        Expression analyzedExpr;
        try {
            analyzedExpr = analyzer.analyze(boundExpr, new ExpressionRewriteContext(cascadesContext));
        } catch (AnalysisException e) {
            throw new UserException(expression + " must be constant value");
        }
        ExpressionRewriteContext context = new ExpressionRewriteContext(cascadesContext);
        ExpressionRuleExecutor executor = new ExpressionRuleExecutor(ImmutableList.of(
                ExpressionRewrite.bottomUp(ReplaceVariableByLiteral.INSTANCE)
        ));
        Expression rewrittenExpression = executor.rewrite(analyzedExpr, context);
        Expression foldExpression = FoldConstantRule.evaluate(rewrittenExpression, context);
        if (foldExpression instanceof Literal) {
            return (Literal) foldExpression;
        } else {
            throw new UserException(expression + " must be constant value");
        }
    }

    /**
     * mergeList
     */
    public static List<Expression> mergeList(List<Expression> list1, List<Expression> list2) {
        ImmutableList.Builder<Expression> builder = ImmutableList.builder();
        for (Expression expression : list1) {
            if (expression != null) {
                builder.add(expression);
            }
        }
        for (Expression expression : list2) {
            if (expression != null) {
                builder.add(expression);
            }
        }
        return builder.build();
    }

    private static class UnboundSlotRewriter extends DefaultExpressionRewriter<Void> {
        public static final UnboundSlotRewriter INSTANCE = new UnboundSlotRewriter();

        public Expression rewrite(Expression e, Void ctx) {
            return e.accept(this, ctx);
        }

        @Override
        public Expression visitUnboundSlot(UnboundSlot unboundSlot, Void ctx) {
            // set exec_mem_limit=21G, '21G' will be parsed as unbound slot
            // we need to rewrite it to String Literal '21G'
            return new StringLiteral(unboundSlot.getName());
        }
    }

    /**
     * format a list of slots
     */
    public static String slotListShapeInfo(List<Slot> materializedSlots) {
        StringBuilder shapeBuilder = new StringBuilder();
        shapeBuilder.append("(");
        boolean isFirst = true;
        for (Slot slot : materializedSlots) {
            if (isFirst) {
                shapeBuilder.append(slot.shapeInfo());
                isFirst = false;
            } else {
                shapeBuilder.append(",").append(slot.shapeInfo());
            }
        }
        shapeBuilder.append(")");
        return shapeBuilder.toString();
    }
}