FoldConstantRuleOnFE.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.catalog.EncryptKey;
import org.apache.doris.catalog.Env;
import org.apache.doris.cluster.ClusterNamespace;
import org.apache.doris.common.ErrorCode;
import org.apache.doris.common.util.DebugUtil;
import org.apache.doris.datasource.InternalCatalog;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.nereids.analyzer.UnboundVariable;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.analysis.ExpressionAnalyzer;
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.ExpressionListenerMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionMatchingContext;
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.ExpressionRuleType;
import org.apache.doris.nereids.rules.expression.ExpressionTraverseListener;
import org.apache.doris.nereids.rules.expression.ExpressionTraverseListenerFactory;
import org.apache.doris.nereids.trees.expressions.AggregateExpression;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.ExpressionEvaluator;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Match;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.Variable;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullLiteral;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Array;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ConnectionId;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentCatalog;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentUser;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Database;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Date;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncryptKeyRef;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.expressions.functions.scalar.LastQueryId;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Password;
import org.apache.doris.nereids.trees.expressions.functions.scalar.SessionUser;
import org.apache.doris.nereids.trees.expressions.functions.scalar.User;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Version;
import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.ComparableLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.trees.expressions.literal.format.DateTimeChecker;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.DateLikeType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.GlobalVariable;
import org.apache.doris.thrift.TUniqueId;

import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.Lists;
import org.apache.commons.codec.digest.DigestUtils;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.Predicate;

/**
 * evaluate an expression on fe.
 */
public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule
        implements ExpressionPatternRuleFactory, ExpressionTraverseListenerFactory {

    public static final FoldConstantRuleOnFE VISITOR_INSTANCE = new FoldConstantRuleOnFE(true);
    public static final FoldConstantRuleOnFE PATTERN_MATCH_INSTANCE = new FoldConstantRuleOnFE(false);

    // record whether current expression is in an aggregate function with distinct,
    // if is, we will skip to fold constant
    private static final ListenAggDistinct LISTEN_AGG_DISTINCT = new ListenAggDistinct();
    private static final CheckWhetherUnderAggDistinct NOT_UNDER_AGG_DISTINCT = new CheckWhetherUnderAggDistinct();

    private final boolean deepRewrite;

    public FoldConstantRuleOnFE(boolean deepRewrite) {
        this.deepRewrite = deepRewrite;
    }

    public static Expression evaluate(Expression expression, ExpressionRewriteContext expressionRewriteContext) {
        return VISITOR_INSTANCE.rewrite(expression, expressionRewriteContext);
    }

    @Override
    public List<ExpressionListenerMatcher<? extends Expression>> buildListeners() {
        return ImmutableList.of(
                listenerType(AggregateFunction.class)
                        .when(AggregateFunction::isDistinct)
                        .then(LISTEN_AGG_DISTINCT.as()),

                listenerType(AggregateExpression.class)
                        .when(AggregateExpression::isDistinct)
                        .then(LISTEN_AGG_DISTINCT.as())
        );
    }

    @Override
    public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
        return ImmutableList.of(
                matches(EncryptKeyRef.class, this::visitEncryptKeyRef),
                matches(EqualTo.class, this::visitEqualTo),
                matches(GreaterThan.class, this::visitGreaterThan),
                matches(GreaterThanEqual.class, this::visitGreaterThanEqual),
                matches(LessThan.class, this::visitLessThan),
                matches(LessThanEqual.class, this::visitLessThanEqual),
                matches(NullSafeEqual.class, this::visitNullSafeEqual),
                matches(Not.class, this::visitNot),
                matches(Database.class, this::visitDatabase),
                matches(CurrentUser.class, this::visitCurrentUser),
                matches(CurrentCatalog.class, this::visitCurrentCatalog),
                matches(User.class, this::visitUser),
                matches(ConnectionId.class, this::visitConnectionId),
                matches(And.class, this::visitAnd),
                matches(Or.class, this::visitOr),
                matches(Cast.class, this::visitCast),
                matches(BoundFunction.class, this::visitBoundFunction),
                matches(BinaryArithmetic.class, this::visitBinaryArithmetic),
                matches(CaseWhen.class, this::visitCaseWhen),
                matches(If.class, this::visitIf),
                matches(InPredicate.class, this::visitInPredicate),
                matches(IsNull.class, this::visitIsNull),
                matches(TimestampArithmetic.class, this::visitTimestampArithmetic),
                matches(Password.class, this::visitPassword),
                matches(Array.class, this::visitArray),
                matches(Date.class, this::visitDate),
                matches(Version.class, this::visitVersion),
                matches(SessionUser.class, this::visitSessionUser),
                matches(LastQueryId.class, this::visitLastQueryId),
                matches(Nvl.class, this::visitNvl)
        );
    }

    @Override
    public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
        if (expr instanceof AggregateFunction && ((AggregateFunction) expr).isDistinct()) {
            return expr;
        } else if (expr instanceof AggregateExpression && ((AggregateExpression) expr).getFunction().isDistinct()) {
            return expr;
        }
        // ATTN: we must return original expr, because OrToIn is implemented with MutableState,
        //   newExpr will lose these states leading to dead loop by OrToIn -> SimplifyRange -> FoldConstantByFE
        Expression newExpr = expr.accept(this, ctx);
        if (newExpr.equals(expr)) {
            return expr;
        }
        return newExpr;
    }

    /**
     * process constant expression.
     */
    @Override
    public Expression visitSlot(Slot slot, ExpressionRewriteContext context) {
        return slot;
    }

    @Override
    public Expression visitLiteral(Literal literal, ExpressionRewriteContext context) {
        return literal;
    }

    @Override
    public Expression visitMatch(Match match, ExpressionRewriteContext context) {
        match = rewriteChildren(match, context);
        Optional<Expression> checkedExpr = preProcess(match);
        if (checkedExpr.isPresent()) {
            return checkedExpr.get();
        }
        return super.visitMatch(match, context);
    }

    @Override
    public Expression visitUnboundVariable(UnboundVariable unboundVariable, ExpressionRewriteContext context) {
        Variable variable = ExpressionAnalyzer.resolveUnboundVariable(unboundVariable);
        return variable.getRealExpression();
    }

    @Override
    public Expression visitEncryptKeyRef(EncryptKeyRef encryptKeyRef, ExpressionRewriteContext context) {
        String dbName = encryptKeyRef.getDbName();
        ConnectContext connectContext = context.cascadesContext.getConnectContext();
        if (Strings.isNullOrEmpty(dbName)) {
            dbName = connectContext.getDatabase();
        }
        if ("".equals(dbName)) {
            throw new AnalysisException("DB " + dbName + "not found");
        }
        if (!Env.getCurrentEnv().getAccessManager()
                .checkDbPriv(ConnectContext.get(), InternalCatalog.INTERNAL_CATALOG_NAME,
                        dbName, PrivPredicate.SHOW)) {
            String message = ErrorCode.ERR_DB_ACCESS_DENIED_ERROR.formatErrorMsg(
                    PrivPredicate.SHOW.getPrivs().toString(), dbName);
            throw new AnalysisException(message);
        }
        org.apache.doris.catalog.Database database =
                Env.getCurrentEnv().getInternalCatalog().getDbNullable(dbName);
        if (database == null) {
            throw new AnalysisException("DB " + dbName + "not found");
        }
        EncryptKey encryptKey = database.getEncryptKey(encryptKeyRef.getEncryptKeyName());
        if (encryptKey == null) {
            throw new AnalysisException("Can not found encryptKey" + encryptKeyRef.getEncryptKeyName());
        }
        return new StringLiteral(encryptKey.getKeyString());
    }

    @Override
    public Expression visitEqualTo(EqualTo equalTo, ExpressionRewriteContext context) {
        equalTo = rewriteChildren(equalTo, context);
        Optional<Expression> checkedExpr = preProcess(equalTo);
        if (checkedExpr.isPresent()) {
            return checkedExpr.get();
        }
        if (equalTo.left() instanceof ComparableLiteral && equalTo.right() instanceof ComparableLiteral) {
            return BooleanLiteral.of(((ComparableLiteral) equalTo.left())
                    .compareTo((ComparableLiteral) equalTo.right()) == 0);
        } else {
            return BooleanLiteral.of(equalTo.left().equals(equalTo.right()));
        }
    }

    @Override
    public Expression visitGreaterThan(GreaterThan greaterThan, ExpressionRewriteContext context) {
        greaterThan = rewriteChildren(greaterThan, context);
        Optional<Expression> checkedExpr = preProcess(greaterThan);
        if (checkedExpr.isPresent()) {
            return checkedExpr.get();
        }
        return BooleanLiteral.of(((ComparableLiteral) greaterThan.left())
                .compareTo((ComparableLiteral) greaterThan.right()) > 0);
    }

    @Override
    public Expression visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, ExpressionRewriteContext context) {
        greaterThanEqual = rewriteChildren(greaterThanEqual, context);
        Optional<Expression> checkedExpr = preProcess(greaterThanEqual);
        if (checkedExpr.isPresent()) {
            return checkedExpr.get();
        }
        return BooleanLiteral.of(((ComparableLiteral) greaterThanEqual.left())
                .compareTo((ComparableLiteral) greaterThanEqual.right()) >= 0);
    }

    @Override
    public Expression visitLessThan(LessThan lessThan, ExpressionRewriteContext context) {
        lessThan = rewriteChildren(lessThan, context);
        Optional<Expression> checkedExpr = preProcess(lessThan);
        if (checkedExpr.isPresent()) {
            return checkedExpr.get();
        }
        return BooleanLiteral.of(((ComparableLiteral) lessThan.left())
                .compareTo((ComparableLiteral) lessThan.right()) < 0);
    }

    @Override
    public Expression visitLessThanEqual(LessThanEqual lessThanEqual, ExpressionRewriteContext context) {
        lessThanEqual = rewriteChildren(lessThanEqual, context);
        Optional<Expression> checkedExpr = preProcess(lessThanEqual);
        if (checkedExpr.isPresent()) {
            return checkedExpr.get();
        }
        return BooleanLiteral.of(((ComparableLiteral) lessThanEqual.left())
                .compareTo((ComparableLiteral) lessThanEqual.right()) <= 0);
    }

    @Override
    public Expression visitNullSafeEqual(NullSafeEqual nullSafeEqual, ExpressionRewriteContext context) {
        nullSafeEqual = rewriteChildren(nullSafeEqual, context);
        Optional<Expression> checkedExpr = preProcess(nullSafeEqual);
        if (checkedExpr.isPresent()) {
            return checkedExpr.get();
        }
        Literal l = (Literal) nullSafeEqual.left();
        Literal r = (Literal) nullSafeEqual.right();
        if (l.isNullLiteral() && r.isNullLiteral()) {
            return BooleanLiteral.TRUE;
        } else if (!l.isNullLiteral() && !r.isNullLiteral()) {
            if (nullSafeEqual.left() instanceof ComparableLiteral
                    && nullSafeEqual.right() instanceof ComparableLiteral) {
                return BooleanLiteral.of(((ComparableLiteral) nullSafeEqual.left())
                        .compareTo((ComparableLiteral) nullSafeEqual.right()) == 0);
            } else {
                return BooleanLiteral.of(l.equals(r));
            }
        } else {
            return BooleanLiteral.FALSE;
        }
    }

    @Override
    public Expression visitNot(Not not, ExpressionRewriteContext context) {
        not = rewriteChildren(not, context);
        Optional<Expression> checkedExpr = preProcess(not);
        if (checkedExpr.isPresent()) {
            return checkedExpr.get();
        }
        return BooleanLiteral.of(!((BooleanLiteral) not.child()).getValue());
    }

    @Override
    public Expression visitDatabase(Database database, ExpressionRewriteContext context) {
        String res = ClusterNamespace.getNameFromFullName(context.cascadesContext.getConnectContext().getDatabase());
        return new VarcharLiteral(res);
    }

    @Override
    public Expression visitCurrentUser(CurrentUser currentUser, ExpressionRewriteContext context) {
        String res = context.cascadesContext.getConnectContext().getCurrentUserIdentity().toString();
        return new VarcharLiteral(res);
    }

    @Override
    public Expression visitCurrentCatalog(CurrentCatalog currentCatalog, ExpressionRewriteContext context) {
        String res = context.cascadesContext.getConnectContext().getDefaultCatalog();
        return new VarcharLiteral(res);
    }

    @Override
    public Expression visitUser(User user, ExpressionRewriteContext context) {
        String res = context.cascadesContext.getConnectContext().getUserWithLoginRemoteIpString();
        return new VarcharLiteral(res);
    }

    @Override
    public Expression visitSessionUser(SessionUser user, ExpressionRewriteContext context) {
        String res = context.cascadesContext.getConnectContext().getUserWithLoginRemoteIpString();
        return new VarcharLiteral(res);
    }

    @Override
    public Expression visitLastQueryId(LastQueryId queryId, ExpressionRewriteContext context) {
        String res = "Not Available";
        TUniqueId id = context.cascadesContext.getConnectContext().getLastQueryId();
        if (id != null) {
            res = DebugUtil.printId(id);
        }
        return new VarcharLiteral(res);
    }

    @Override
    public Expression visitConnectionId(ConnectionId connectionId, ExpressionRewriteContext context) {
        return new BigIntLiteral(context.cascadesContext.getConnectContext().getConnectionId());
    }

    @Override
    public Expression visitAnd(And and, ExpressionRewriteContext context) {
        List<Expression> nonTrueLiteral = Lists.newArrayList();
        int nullCount = 0;
        for (Expression e : and.children()) {
            e = deepRewrite ? e.accept(this, context) : e;
            if (BooleanLiteral.FALSE.equals(e)) {
                return BooleanLiteral.FALSE;
            } else if (e instanceof NullLiteral) {
                nullCount++;
                nonTrueLiteral.add(e);
            } else if (!BooleanLiteral.TRUE.equals(e)) {
                nonTrueLiteral.add(e);
            }
        }

        if (nullCount == 0) {
            switch (nonTrueLiteral.size()) {
                case 0:
                    // true and true
                    return BooleanLiteral.TRUE;
                case 1:
                    // true and x
                    return nonTrueLiteral.get(0);
                default:
                    // x and y
                    return and.withChildren(nonTrueLiteral);
            }
        } else if (nullCount < and.children().size()) {
            if (nonTrueLiteral.size() == 1) {
                return nonTrueLiteral.get(0);
            } else {
                // null and x
                return and.withChildren(nonTrueLiteral);
            }
        } else {
            // null and null and null and ...
            return new NullLiteral(BooleanType.INSTANCE);
        }
    }

    @Override
    public Expression visitOr(Or or, ExpressionRewriteContext context) {
        List<Expression> nonFalseLiteral = Lists.newArrayList();
        int nullCount = 0;
        for (Expression e : or.children()) {
            e = deepRewrite ? e.accept(this, context) : e;
            if (BooleanLiteral.TRUE.equals(e)) {
                return BooleanLiteral.TRUE;
            } else if (e instanceof NullLiteral) {
                nullCount++;
                nonFalseLiteral.add(e);
            } else if (!BooleanLiteral.FALSE.equals(e)) {
                nonFalseLiteral.add(e);
            }
        }

        if (nullCount == 0) {
            switch (nonFalseLiteral.size()) {
                case 0:
                    // false or false
                    return BooleanLiteral.FALSE;
                case 1:
                    // false or x
                    return nonFalseLiteral.get(0);
                default:
                    // x or y
                    return or.withChildren(nonFalseLiteral);
            }
        } else if (nullCount < nonFalseLiteral.size()) {
            if (nonFalseLiteral.size() == 1) {
                // null or false
                return nonFalseLiteral.get(0);
            }
            // null or x
            return or.withChildren(nonFalseLiteral);
        } else {
            // null or null
            return new NullLiteral(BooleanType.INSTANCE);
        }
    }

    @Override
    public Expression visitCast(Cast cast, ExpressionRewriteContext context) {
        cast = rewriteChildren(cast, context);
        Optional<Expression> checkedExpr = preProcess(cast);
        if (checkedExpr.isPresent()) {
            return checkedExpr.get();
        }
        Expression child = cast.child();
        DataType dataType = cast.getDataType();
        if (!safeToCast(cast)) {
            return cast;
        }
        // todo: process other null case
        if (child.isNullLiteral()) {
            return new NullLiteral(dataType);
        } else if (child instanceof StringLikeLiteral && dataType instanceof DateLikeType) {
            String dateStr = ((StringLikeLiteral) child).getStringValue();
            if (!DateTimeChecker.isValidDateTime(dateStr)) {
                return cast;
            }
            try {
                return ((DateLikeType) dataType).fromString(dateStr);
            } catch (Exception t) {
                return cast;
            }
        }
        try {
            Expression castResult = child.checkedCastTo(dataType);
            if (!Objects.equals(castResult, cast) && !Objects.equals(castResult, child)) {
                castResult = rewrite(castResult, context);
            }
            return castResult;
        } catch (Throwable t) {
            return cast;
        }
    }

    // Check if the given literal value is safe to cast to the targetType.
    // We need to guarantee FE cast result is identical with BE cast result.
    // Otherwise, it's not safe.
    protected boolean safeToCast(Cast cast) {
        if (cast == null || cast.child() == null || cast.getDataType() == null) {
            return true;
        }
        // Check double type.
        if (cast.child() instanceof DoubleLiteral && cast.getDataType().isStringLikeType()) {
            Double value = ((DoubleLiteral) cast.child()).getValue();
            if (value.isInfinite() || value.isNaN()) {
                return true;
            }
            return -1E16 < value && value < 1E16;
        }
        // Check other types if needed.
        return true;
    }

    @Override
    public Expression visitBoundFunction(BoundFunction boundFunction, ExpressionRewriteContext context) {
        if (!boundFunction.foldable()) {
            return boundFunction;
        }
        boundFunction = rewriteChildren(boundFunction, context);
        Optional<Expression> checkedExpr = preProcess(boundFunction);
        if (checkedExpr.isPresent()) {
            return checkedExpr.get();
        }
        return ExpressionEvaluator.INSTANCE.eval(boundFunction);
    }

    @Override
    public Expression visitBinaryArithmetic(BinaryArithmetic binaryArithmetic, ExpressionRewriteContext context) {
        binaryArithmetic = rewriteChildren(binaryArithmetic, context);
        Optional<Expression> checkedExpr = preProcess(binaryArithmetic);
        if (checkedExpr.isPresent()) {
            return checkedExpr.get();
        }
        return ExpressionEvaluator.INSTANCE.eval(binaryArithmetic);
    }

    @Override
    public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext context) {
        CaseWhen originCaseWhen = caseWhen;
        caseWhen = rewriteChildren(caseWhen, context);
        Expression newDefault = null;
        boolean foundNewDefault = false;

        List<WhenClause> whenClauses = new ArrayList<>();
        for (WhenClause whenClause : caseWhen.getWhenClauses()) {
            Expression whenOperand = whenClause.getOperand();

            if (!(whenOperand.isLiteral())) {
                whenClauses.add(new WhenClause(whenOperand, whenClause.getResult()));
            } else if (BooleanLiteral.TRUE.equals(whenOperand)) {
                foundNewDefault = true;
                newDefault = whenClause.getResult();
                break;
            }
        }

        Expression defaultResult = null;
        if (caseWhen.getDefaultValue().isPresent()) {
            defaultResult = caseWhen.getDefaultValue().get();
        }
        if (foundNewDefault) {
            defaultResult = newDefault;
        }
        if (whenClauses.isEmpty()) {
            return TypeCoercionUtils.ensureSameResultType(
                    originCaseWhen, defaultResult == null ? new NullLiteral(caseWhen.getDataType()) : defaultResult,
                    context
            );
        }
        if (defaultResult == null) {
            if (caseWhen.getDataType().isNullType()) {
                // if caseWhen's type is NULL_TYPE, means all possible return values are nulls
                // it's safe to return null literal here
                return new NullLiteral();
            } else {
                return TypeCoercionUtils.ensureSameResultType(originCaseWhen, new CaseWhen(whenClauses), context);
            }
        }
        return TypeCoercionUtils.ensureSameResultType(
                originCaseWhen, new CaseWhen(whenClauses, defaultResult), context
        );
    }

    @Override
    public Expression visitIf(If ifExpr, ExpressionRewriteContext context) {
        If originIf = ifExpr;
        ifExpr = rewriteChildren(ifExpr, context);
        if (ifExpr.child(0) instanceof NullLiteral || ifExpr.child(0).equals(BooleanLiteral.FALSE)) {
            return TypeCoercionUtils.ensureSameResultType(originIf, ifExpr.child(2), context);
        } else if (ifExpr.child(0).equals(BooleanLiteral.TRUE)) {
            return TypeCoercionUtils.ensureSameResultType(originIf, ifExpr.child(1), context);
        }
        return TypeCoercionUtils.ensureSameResultType(originIf, ifExpr, context);
    }

    @Override
    public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) {
        inPredicate = rewriteChildren(inPredicate, context);
        Optional<Expression> checkedExpr = preProcess(inPredicate);
        if (checkedExpr.isPresent()) {
            return checkedExpr.get();
        }
        // now the inPredicate contains literal only.
        Expression value = inPredicate.child(0);
        if (value.isNullLiteral()) {
            return new NullLiteral(BooleanType.INSTANCE);
        }
        boolean isOptionContainsNull = false;
        for (Expression item : inPredicate.getOptions()) {
            if (value.equals(item)) {
                return BooleanLiteral.TRUE;
            } else if (item.isNullLiteral()) {
                isOptionContainsNull = true;
            }
        }
        return isOptionContainsNull
                ? new NullLiteral(BooleanType.INSTANCE)
                : BooleanLiteral.FALSE;
    }

    @Override
    public Expression visitIsNull(IsNull isNull, ExpressionRewriteContext context) {
        isNull = rewriteChildren(isNull, context);
        Optional<Expression> checkedExpr = preProcess(isNull);
        if (checkedExpr.isPresent()) {
            return checkedExpr.get();
        }
        return Literal.of(isNull.child().nullable());
    }

    @Override
    public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, ExpressionRewriteContext context) {
        arithmetic = rewriteChildren(arithmetic, context);
        Optional<Expression> checkedExpr = preProcess(arithmetic);
        if (checkedExpr.isPresent()) {
            return checkedExpr.get();
        }
        return ExpressionEvaluator.INSTANCE.eval(arithmetic);
    }

    @Override
    public Expression visitPassword(Password password, ExpressionRewriteContext context) {
        Preconditions.checkArgument(password.child(0) instanceof StringLikeLiteral,
                "argument of password must be string literal");
        String s = ((StringLikeLiteral) password.child()).value;
        return new StringLiteral("*" + DigestUtils.sha1Hex(
                DigestUtils.sha1(s.getBytes())).toUpperCase());
    }

    @Override
    public Expression visitArray(Array array, ExpressionRewriteContext context) {
        array = rewriteChildren(array, context);
        Optional<Expression> checkedExpr = preProcess(array);
        if (checkedExpr.isPresent()) {
            return checkedExpr.get();
        }
        List<Literal> arguments = (List) array.getArguments();
        // we should pass dataType to constructor because arguments maybe empty
        return new ArrayLiteral(arguments, array.getDataType());
    }

    @Override
    public Expression visitDate(Date date, ExpressionRewriteContext context) {
        date = rewriteChildren(date, context);
        Optional<Expression> checkedExpr = preProcess(date);
        if (checkedExpr.isPresent()) {
            return checkedExpr.get();
        }
        Literal child = (Literal) date.child();
        if (child instanceof NullLiteral) {
            return new NullLiteral(date.getDataType());
        }
        DataType dataType = child.getDataType();
        if (dataType.isDateTimeType()) {
            DateTimeLiteral dateTimeLiteral = (DateTimeLiteral) child;
            return new DateLiteral(dateTimeLiteral.getYear(), dateTimeLiteral.getMonth(), dateTimeLiteral.getDay());
        } else if (dataType.isDateTimeV2Type()) {
            DateTimeV2Literal dateTimeLiteral = (DateTimeV2Literal) child;
            return new DateV2Literal(dateTimeLiteral.getYear(), dateTimeLiteral.getMonth(), dateTimeLiteral.getDay());
        }
        return date;
    }

    @Override
    public Expression visitVersion(Version version, ExpressionRewriteContext context) {
        return new StringLiteral(GlobalVariable.version);
    }

    @Override
    public Expression visitNvl(Nvl nvl, ExpressionRewriteContext context) {
        Nvl originNvl = nvl;
        nvl = rewriteChildren(nvl, context);

        for (Expression expr : nvl.children()) {
            if (expr.isLiteral()) {
                if (!expr.isNullLiteral()) {
                    return TypeCoercionUtils.ensureSameResultType(originNvl, expr, context);
                }
            } else {
                return TypeCoercionUtils.ensureSameResultType(originNvl, nvl, context);
            }
        }
        // all nulls
        return TypeCoercionUtils.ensureSameResultType(originNvl, nvl.child(0), context);
    }

    private <E extends Expression> E rewriteChildren(E expr, ExpressionRewriteContext context) {
        if (!deepRewrite) {
            return expr;
        }
        switch (expr.arity()) {
            case 1: {
                Expression originChild = expr.child(0);
                Expression newChild = originChild.accept(this, context);
                return (originChild != newChild) ? (E) expr.withChildren(ImmutableList.of(newChild)) : expr;
            }
            case 2: {
                Expression originLeft = expr.child(0);
                Expression newLeft = originLeft.accept(this, context);
                Expression originRight = expr.child(1);
                Expression newRight = originRight.accept(this, context);
                return (originLeft != newLeft || originRight != newRight)
                        ? (E) expr.withChildren(ImmutableList.of(newLeft, newRight))
                        : expr;
            }
            case 0: {
                return expr;
            }
            default: {
                boolean hasNewChildren = false;
                Builder<Expression> newChildren = ImmutableList.builderWithExpectedSize(expr.arity());
                for (Expression child : expr.children()) {
                    Expression newChild = child.accept(this, context);
                    if (newChild != child) {
                        hasNewChildren = true;
                    }
                    newChildren.add(newChild);
                }
                return hasNewChildren ? (E) expr.withChildren(newChildren.build()) : expr;
            }
        }
    }

    private Optional<Expression> preProcess(Expression expression) {
        if (expression instanceof AggregateFunction || expression instanceof TableGeneratingFunction) {
            return Optional.of(expression);
        }
        if (ExpressionUtils.hasNullLiteral(expression.getArguments())
                && (expression instanceof PropagateNullLiteral || expression instanceof PropagateNullable)) {
            return Optional.of(new NullLiteral(expression.getDataType()));
        }
        if (!ExpressionUtils.isAllLiteral(expression.getArguments())) {
            return Optional.of(expression);
        }
        return Optional.empty();
    }

    private static class ListenAggDistinct implements ExpressionTraverseListener<Expression> {
        @Override
        public void onEnter(ExpressionMatchingContext<Expression> context) {
            context.cascadesContext.incrementDistinctAggLevel();
        }

        @Override
        public void onExit(ExpressionMatchingContext<Expression> context, Expression rewritten) {
            context.cascadesContext.decrementDistinctAggLevel();
        }
    }

    private static class CheckWhetherUnderAggDistinct implements Predicate<ExpressionMatchingContext<Expression>> {
        @Override
        public boolean test(ExpressionMatchingContext<Expression> context) {
            return context.cascadesContext.getDistinctAggLevel() == 0;
        }

        public <E extends Expression> Predicate<ExpressionMatchingContext<E>> as() {
            return (Predicate) this;
        }
    }

    private <E extends Expression> ExpressionPatternMatcher<? extends Expression> matches(
            Class<E> clazz, BiFunction<E, ExpressionRewriteContext, Expression> visitMethod) {
        return matchesType(clazz)
                .whenCtx(ctx -> !ctx.cascadesContext.getConnectContext().getSessionVariable()
                        .isDebugSkipFoldConstant())
                .whenCtx(NOT_UNDER_AGG_DISTINCT.as())
                .thenApply(ctx -> visitMethod.apply(ctx.expr, ctx.rewriteContext))
                .toRule(ExpressionRuleType.FOLD_CONSTANT_ON_FE);
    }
}