FoldConstantsRule.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/apache/impala/blob/branch-2.9.0/fe/src/main/java/org/apache/impala/FoldConstantsRule.java
// and modified by Doris

package org.apache.doris.rewrite;


import org.apache.doris.analysis.Analyzer;
import org.apache.doris.analysis.ArithmeticExpr;
import org.apache.doris.analysis.BetweenPredicate;
import org.apache.doris.analysis.CaseExpr;
import org.apache.doris.analysis.CastExpr;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.FunctionCallExpr;
import org.apache.doris.analysis.InformationFunction;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.analysis.NullLiteral;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.analysis.TimestampArithmeticExpr;
import org.apache.doris.analysis.VariableExpr;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.ScalarType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.LoadException;
import org.apache.doris.common.util.DebugUtil;
import org.apache.doris.common.util.TimeUtils;
import org.apache.doris.proto.InternalService;
import org.apache.doris.proto.Types.PScalarType;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.VariableMgr;
import org.apache.doris.rpc.BackendServiceProxy;
import org.apache.doris.system.Backend;
import org.apache.doris.thrift.TExpr;
import org.apache.doris.thrift.TFoldConstantParams;
import org.apache.doris.thrift.TNetworkAddress;
import org.apache.doris.thrift.TPrimitiveType;
import org.apache.doris.thrift.TQueryGlobals;
import org.apache.doris.thrift.TQueryOptions;

import com.google.common.base.Preconditions;
import com.google.common.base.Predicates;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.time.LocalDateTime;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

/**
 * This rule replaces a constant Expr with its equivalent LiteralExpr by evaluating the
 * Expr in the BE. Exprs that are already LiteralExprs are not changed.
 *
 * TODO: Expressions fed into this rule are currently not required to be analyzed
 * in order to support constant folding in expressions that contain unresolved
 * references to select-list aliases (such expressions cannot be analyzed).
 * The cross-dependencies between rule transformations and analysis are vague at the
 * moment and make rule application overly complicated.
 *
 * Examples:
 * 1 + 1 + 1 --> 3
 * toupper('abc') --> 'ABC'
 * cast('2016-11-09' as timestamp) --> TIMESTAMP '2016-11-09 00:00:00'
 */
public class FoldConstantsRule implements ExprRewriteRule {
    private static final Logger LOG = LogManager.getLogger(FoldConstantsRule.class);

    public static ExprRewriteRule INSTANCE = new FoldConstantsRule();

    @Override
    public Expr apply(Expr expr, Analyzer analyzer, ExprRewriter.ClauseType clauseType) throws AnalysisException {
        // evaluate `case when expr` when possible
        if (expr instanceof CaseExpr) {
            return CaseExpr.computeCaseExpr((CaseExpr) expr);
        }

        // Avoid calling Expr.isConstant() because that would lead to repeated traversals
        // of the Expr tree. Assumes the bottom-up application of this rule. Constant
        // children should have been folded at this point.
        for (Expr child : expr.getChildren()) {
            if (!child.isLiteral() && !(child instanceof CastExpr) && !((child instanceof FunctionCallExpr
                    || child instanceof ArithmeticExpr || child instanceof TimestampArithmeticExpr
                    || child instanceof VariableExpr))) {
                return expr;
            }
        }

        if (expr.isLiteral() || !expr.isConstant()) {
            return expr;
        }

        // Do not constant fold cast(null as dataType) because we cannot preserve the
        // cast-to-types and that can lead to query failures, e.g., CTAS
        if (expr instanceof CastExpr) {
            CastExpr castExpr = (CastExpr) expr;
            if (castExpr.isNotFold()) {
                return castExpr;
            }
            if (castExpr.getChild(0) instanceof NullLiteral) {
                return castExpr.getChild(0);
            }
        }

        // Analyze constant exprs, if necessary. Note that the 'expr' may become non-constant
        // after analysis (e.g., aggregate functions).
        if (!expr.isAnalyzed()) {
            expr.analyze(analyzer);
            if (!expr.isConstant()) {
                return expr;
            }
        }
        // it may be wrong to fold constant value in inline view
        // so pass the info to getResultValue method to let predicate itself
        // to decide if it can fold constant value safely
        return expr.getResultValue(expr instanceof SlotRef ? false : analyzer.isInlineViewAnalyzer());
    }

    /**
     * fold constant expr by BE
     * SysVariableDesc and InformationFunction need handling specially
     * @param exprMap
     * @param analyzer
     * @return
     * @throws AnalysisException
     */
    public boolean apply(Map<String, Expr> exprMap, Analyzer analyzer, boolean changed, TQueryOptions tQueryOptions)
            throws AnalysisException {
        // root_expr_id_string:
        //     child_expr_id_string : texpr
        //     child_expr_id_string : texpr
        Map<String, Map<String, TExpr>> paramMap = new HashMap<>();
        Map<String, Expr> allConstMap = new HashMap<>();
        // map to collect SysVariableDesc
        Map<String, Map<String, Expr>> sysVarsMap = new HashMap<>();
        // map to collect InformationFunction
        Map<String, Map<String, Expr>> infoFnsMap = new HashMap<>();
        for (Map.Entry<String, Expr> entry : exprMap.entrySet()) {
            Map<String, TExpr> constMap = new HashMap<>();
            Map<String, Expr> oriConstMap = new HashMap<>();
            Map<String, Expr> sysVarMap = new HashMap<>();
            Map<String, Expr> infoFnMap = new HashMap<>();
            getConstExpr(entry.getValue(), constMap, oriConstMap, analyzer, sysVarMap, infoFnMap);

            if (!constMap.isEmpty()) {
                paramMap.put(entry.getKey(), constMap);
                allConstMap.putAll(oriConstMap);
            }
            if (!sysVarMap.isEmpty()) {
                sysVarsMap.put(entry.getKey(), sysVarMap);
            }
            if (!infoFnMap.isEmpty()) {
                infoFnsMap.put(entry.getKey(), infoFnMap);
            }
        }

        if (!sysVarsMap.isEmpty()) {
            putBackConstExpr(exprMap, sysVarsMap);
            changed = true;
        }

        if (!infoFnsMap.isEmpty()) {
            putBackConstExpr(exprMap, infoFnsMap);
            changed = true;
        }

        if (!paramMap.isEmpty()) {
            Map<String, Map<String, Expr>> resultMap = calcConstExpr(paramMap, allConstMap, analyzer.getContext(),
                    tQueryOptions);

            if (!resultMap.isEmpty()) {
                putBackConstExpr(exprMap, resultMap);
                changed = true;

            }

        }
        return changed;
    }

    /**
     * get all constant children expr from a expr
     * @param expr
     * @param constExprMap
     * @param analyzer
     * @throws AnalysisException
     */
    // public only for unit test
    public void getConstExpr(Expr expr, Map<String, TExpr> constExprMap, Map<String, Expr> oriConstMap,
            Analyzer analyzer, Map<String, Expr> sysVarMap, Map<String, Expr> infoFnMap)
            throws AnalysisException {
        if (expr.isConstant()) {
            // Do not constant fold cast(null as dataType) because we cannot preserve the
            // cast-to-types and that can lead to query failures, e.g., CTAS
            if (expr instanceof CastExpr) {
                CastExpr castExpr = (CastExpr) expr;
                if (castExpr.getChild(0) instanceof NullLiteral) {
                    return;
                }
            }
            // skip literal expr
            if (expr instanceof LiteralExpr) {
                return;
            }
            // skip BetweenPredicate need to be rewrite to CompoundPredicate
            if (expr instanceof BetweenPredicate) {
                return;
            }
            // collect sysVariableDesc expr
            if (expr.contains(Predicates.instanceOf(VariableExpr.class))) {
                getSysVarDescExpr(expr, sysVarMap);
                return;
            }
            // collect InformationFunction
            if (expr.contains(Predicates.instanceOf(InformationFunction.class))) {
                getInfoFnExpr(expr, infoFnMap);
                return;
            }
            constExprMap.put(expr.getId().toString(), expr.treeToThrift());
            oriConstMap.put(expr.getId().toString(), expr);
        } else {
            recursiveGetChildrenConstExpr(expr, constExprMap, oriConstMap, analyzer, sysVarMap, infoFnMap);

        }
    }

    private void recursiveGetChildrenConstExpr(Expr expr, Map<String, TExpr> constExprMap,
            Map<String, Expr> oriConstMap, Analyzer analyzer, Map<String, Expr> sysVarMap, Map<String, Expr> infoFnMap)
            throws AnalysisException {
        for (int i = 0; i < expr.getChildren().size(); i++) {
            final Expr child = expr.getChildren().get(i);
            getConstExpr(child, constExprMap, oriConstMap, analyzer, sysVarMap, infoFnMap);
        }

    }

    private void getSysVarDescExpr(Expr expr, Map<String, Expr> sysVarMap) {
        if (expr instanceof VariableExpr) {
            Expr literalExpr = ((VariableExpr) expr).getLiteralExpr();
            if (literalExpr == null) {
                try {
                    VariableMgr.fillValue(ConnectContext.get().getSessionVariable(), (VariableExpr) expr);
                    literalExpr = ((VariableExpr) expr).getLiteralExpr();
                } catch (AnalysisException e) {
                    if (ConnectContext.get() != null) {
                        ConnectContext.get().getState().reset();
                    }
                    LOG.warn("failed to get session variable value: " + ((VariableExpr) expr).getName());
                }
            }
            sysVarMap.put(expr.getId().toString(), literalExpr);
        } else {
            for (Expr child : expr.getChildren()) {
                getSysVarDescExpr(child, sysVarMap);
            }
        }
    }

    private void getInfoFnExpr(Expr expr, Map<String, Expr> infoFnMap) {
        if (expr instanceof InformationFunction) {
            Type type = expr.getType();
            LiteralExpr literalExpr = null;
            try {
                String str = null;
                if (type.equals(Type.VARCHAR)) {
                    str = ((InformationFunction) expr).getStrValue();
                } else if (type.equals(Type.BIGINT)) {
                    str = ((InformationFunction) expr).getIntValue();
                }
                Preconditions.checkNotNull(str);
                literalExpr = LiteralExpr.create(str, type);
                infoFnMap.put(expr.getId().toString(), literalExpr);
            } catch (AnalysisException e) {
                if (ConnectContext.get() != null) {
                    ConnectContext.get().getState().reset();
                }
                LOG.warn("failed to get const expr value from InformationFunction: {}", e.getMessage());
            }

        } else {
            for (Expr child : expr.getChildren()) {
                getInfoFnExpr(child, infoFnMap);
            }
        }
    }

    /**
     * put all rewritten expr back to ori expr map
     * @param exprMap
     * @param resultMap
     */
    private void putBackConstExpr(Map<String, Expr> exprMap, Map<String, Map<String, Expr>> resultMap) {
        for (Map.Entry<String, Map<String, Expr>> entry : resultMap.entrySet()) {
            Expr rewrittenExpr = putBackConstExpr(exprMap.get(entry.getKey()), entry.getValue());
            exprMap.put(entry.getKey(), rewrittenExpr);
        }
    }

    private Expr putBackConstExpr(Expr expr, Map<String, Expr> resultMap) {
        for (Map.Entry<String, Expr> entry : resultMap.entrySet()) {
            if (entry.getValue() instanceof LiteralExpr) {
                expr = replaceExpr(expr, entry.getKey(), (LiteralExpr) entry.getValue());

            }
        }
        return expr;
    }

    /**
     * find and replace constant child expr of a expr by literal expr
     * @param expr
     * @param key
     * @param literalExpr
     * @return
     */
    private Expr replaceExpr(Expr expr, String key, LiteralExpr literalExpr) {
        if (expr.getId().toString().equals(key)) {
            return literalExpr;
        }
        // ATTN: make sure the child order of expr keep unchanged
        for (int i = 0; i < expr.getChildren().size(); i++) {
            Expr child = expr.getChild(i);
            if (!(child instanceof LiteralExpr) && literalExpr.equals(replaceExpr(child, key, literalExpr))) {
                literalExpr.setId(child.getId());
                expr.setChild(i, literalExpr);
                break;
            }
        }
        return expr;
    }

    /**
     * calc all constant exprs by BE
     * @param map
     * @param context
     * @return
     */
    private Map<String, Map<String, Expr>> calcConstExpr(Map<String, Map<String, TExpr>> map,
            Map<String, Expr> allConstMap,
            ConnectContext context, TQueryOptions tQueryOptions) {
        TNetworkAddress brpcAddress = null;
        Map<String, Map<String, Expr>> resultMap = new HashMap<>();
        try {
            List<Long> backendIds = Env.getCurrentSystemInfo().getAllBackendByCurrentCluster(true);
            if (backendIds.isEmpty()) {
                throw new LoadException("Failed to get all partitions. No alive backends");
            }
            Collections.shuffle(backendIds);
            Backend be = Env.getCurrentSystemInfo().getBackend(backendIds.get(0));
            brpcAddress = new TNetworkAddress(be.getHost(), be.getBrpcPort());

            TQueryGlobals queryGlobals = new TQueryGlobals();
            queryGlobals.setNowString(TimeUtils.getDatetimeFormatWithTimeZone().format(LocalDateTime.now()));
            queryGlobals.setTimestampMs(System.currentTimeMillis());
            queryGlobals.setNanoSeconds(LocalDateTime.now().getNano());
            queryGlobals.setTimeZone(TimeUtils.DEFAULT_TIME_ZONE);
            if (context.getSessionVariable().getTimeZone().equals("CST")) {
                queryGlobals.setTimeZone(TimeUtils.DEFAULT_TIME_ZONE);
            } else {
                queryGlobals.setTimeZone(context.getSessionVariable().getTimeZone());
            }

            TFoldConstantParams tParams = new TFoldConstantParams(map, queryGlobals);
            tParams.setVecExec(true);
            tParams.setQueryOptions(tQueryOptions);
            tParams.setQueryId(context.queryId());
            tParams.setIsNereids(false);

            Future<InternalService.PConstantExprResult> future
                    = BackendServiceProxy.getInstance().foldConstantExpr(brpcAddress, tParams);
            InternalService.PConstantExprResult result = future.get(5, TimeUnit.SECONDS);

            if (result.getStatus().getStatusCode() == 0) {
                for (Map.Entry<String, InternalService.PExprResultMap> entry
                        : result.getExprResultMapMap().entrySet()) {
                    Map<String, Expr> tmp = new HashMap<>();
                    for (Map.Entry<String, InternalService.PExprResult> entry1
                            : entry.getValue().getMapMap().entrySet()) {
                        PScalarType scalarType = entry1.getValue().getType();
                        TPrimitiveType ttype = TPrimitiveType.findByValue(scalarType.getType());
                        Expr retExpr = null;
                        if (entry1.getValue().getSuccess()) {
                            Type type = null;
                            if (ttype == TPrimitiveType.CHAR) {
                                Preconditions.checkState(scalarType.hasLen());
                                type = ScalarType.createCharType(scalarType.getLen());
                            } else if (ttype == TPrimitiveType.VARCHAR) {
                                Preconditions.checkState(scalarType.hasLen());
                                type = ScalarType.createVarcharType(scalarType.getLen());
                            } else if (ttype == TPrimitiveType.DECIMALV2) {
                                type = ScalarType.createDecimalType(scalarType.getPrecision(),
                                        scalarType.getScale());
                            } else if (ttype == TPrimitiveType.DATETIMEV2) {
                                type = ScalarType.createDatetimeV2Type(scalarType.getScale());
                            } else if (ttype == TPrimitiveType.DECIMAL32
                                    || ttype == TPrimitiveType.DECIMAL64
                                    || ttype == TPrimitiveType.DECIMAL128I
                                    || ttype == TPrimitiveType.DECIMAL256) {
                                type = ScalarType.createDecimalV3Type(scalarType.getPrecision(),
                                        scalarType.getScale());
                            } else {
                                type = ScalarType.createType(
                                        PrimitiveType.fromThrift(ttype));
                            }
                            retExpr = LiteralExpr.create(entry1.getValue().getContent(),
                                    type);
                        } else {
                            retExpr = allConstMap.get(entry1.getKey());
                        }
                        if (LOG.isDebugEnabled()) {
                            LOG.debug("retExpr: " + retExpr.toString());
                        }
                        tmp.put(entry1.getKey(), retExpr);
                    }
                    if (!tmp.isEmpty()) {
                        resultMap.put(entry.getKey(), tmp);
                    }
                }

            } else {
                LOG.warn("failed_fold_context.queryId(): " + DebugUtil.printId(context.queryId()));
                LOG.warn("failed to get const expr value from be: {}", result.getStatus().getErrorMsgsList());
            }
        } catch (Exception e) {
            LOG.warn("failed_fold_context.queryId(): " + DebugUtil.printId(context.queryId()));
            LOG.warn("failed to get const expr value from be: {}", e.getMessage());
        }
        return resultMap;
    }
}