CreateFunctionStmt.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.analysis;

import org.apache.doris.catalog.AggregateFunction;
import org.apache.doris.catalog.AliasFunction;
import org.apache.doris.catalog.ArrayType;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.Function;
import org.apache.doris.catalog.Function.NullableMode;
import org.apache.doris.catalog.FunctionUtil;
import org.apache.doris.catalog.MapType;
import org.apache.doris.catalog.ScalarFunction;
import org.apache.doris.catalog.ScalarType;
import org.apache.doris.catalog.StructType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.Config;
import org.apache.doris.common.ErrorCode;
import org.apache.doris.common.ErrorReport;
import org.apache.doris.common.FeConstants;
import org.apache.doris.common.UserException;
import org.apache.doris.common.util.URI;
import org.apache.doris.common.util.Util;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.proto.FunctionService;
import org.apache.doris.proto.PFunctionServiceGrpc;
import org.apache.doris.proto.Types;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.thrift.TFunctionBinaryType;

import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSortedMap;
import io.grpc.ManagedChannel;
import io.grpc.netty.NettyChannelBuilder;
import org.apache.commons.codec.binary.Hex;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Parameter;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

// create a user define function
public class CreateFunctionStmt extends DdlStmt implements NotFallbackInParser {
    @Deprecated
    public static final String OBJECT_FILE_KEY = "object_file";
    public static final String FILE_KEY = "file";
    public static final String SYMBOL_KEY = "symbol";
    public static final String PREPARE_SYMBOL_KEY = "prepare_fn";
    public static final String CLOSE_SYMBOL_KEY = "close_fn";
    public static final String MD5_CHECKSUM = "md5";
    public static final String INIT_KEY = "init_fn";
    public static final String UPDATE_KEY = "update_fn";
    public static final String MERGE_KEY = "merge_fn";
    public static final String SERIALIZE_KEY = "serialize_fn";
    public static final String FINALIZE_KEY = "finalize_fn";
    public static final String GET_VALUE_KEY = "get_value_fn";
    public static final String REMOVE_KEY = "remove_fn";
    public static final String BINARY_TYPE = "type";
    public static final String EVAL_METHOD_KEY = "evaluate";
    public static final String CREATE_METHOD_NAME = "create";
    public static final String DESTROY_METHOD_NAME = "destroy";
    public static final String ADD_METHOD_NAME = "add";
    public static final String SERIALIZE_METHOD_NAME = "serialize";
    public static final String MERGE_METHOD_NAME = "merge";
    public static final String GETVALUE_METHOD_NAME = "getValue";
    public static final String STATE_CLASS_NAME = "State";
    // add for java udf check return type nullable mode, always_nullable or always_not_nullable
    public static final String IS_RETURN_NULL = "always_nullable";
    // iff is static load, BE will be cache the udf class load, so only need load once
    public static final String IS_STATIC_LOAD = "static_load";
    public static final String EXPIRATION_TIME = "expiration_time";
    private static final Logger LOG = LogManager.getLogger(CreateFunctionStmt.class);

    private SetType type = SetType.DEFAULT;
    private final boolean ifNotExists;
    private final FunctionName functionName;
    private final boolean isAggregate;
    private final boolean isAlias;
    private boolean isTableFunction;
    private final FunctionArgsDef argsDef;
    private final TypeDef returnType;
    private TypeDef intermediateType;
    private final Map<String, String> properties;
    private final List<String> parameters;
    private final Expr originFunction;
    TFunctionBinaryType binaryType = TFunctionBinaryType.JAVA_UDF;

    // needed item set after analyzed
    private String userFile;
    private Function function;
    private String checksum = "";
    private boolean isStaticLoad = false;
    private long expirationTime = 360; // default 6 hours = 360 minutes
    // now set udf default NullableMode is ALWAYS_NULLABLE
    // if not, will core dump when input is not null column, but need return null
    // like https://github.com/apache/doris/pull/14002/files
    private NullableMode returnNullMode = NullableMode.ALWAYS_NULLABLE;

    // timeout for both connection and read. 10 seconds is long enough.
    private static final int HTTP_TIMEOUT_MS = 10000;

    public CreateFunctionStmt(SetType type, boolean ifNotExists, boolean isAggregate, FunctionName functionName,
                              FunctionArgsDef argsDef,
                              TypeDef returnType, TypeDef intermediateType, Map<String, String> properties) {
        this.type = type;
        this.ifNotExists = ifNotExists;
        this.functionName = functionName;
        this.isAggregate = isAggregate;
        this.argsDef = argsDef;
        this.returnType = returnType;
        this.intermediateType = intermediateType;
        if (properties == null) {
            this.properties = ImmutableSortedMap.of();
        } else {
            this.properties = ImmutableSortedMap.copyOf(properties, String.CASE_INSENSITIVE_ORDER);
        }
        this.isAlias = false;
        this.isTableFunction = false;
        this.parameters = ImmutableList.of();
        this.originFunction = null;
    }

    public CreateFunctionStmt(SetType type, boolean ifNotExists, FunctionName functionName,
            FunctionArgsDef argsDef,
            TypeDef returnType, TypeDef intermediateType, Map<String, String> properties) {
        this(type, ifNotExists, false, functionName, argsDef, returnType, intermediateType, properties);
        this.isTableFunction = true;
    }

    public CreateFunctionStmt(SetType type, boolean ifNotExists, FunctionName functionName, FunctionArgsDef argsDef,
            List<String> parameters, Expr originFunction) {
        this.type = type;
        this.ifNotExists = ifNotExists;
        this.functionName = functionName;
        this.isAlias = true;
        this.argsDef = argsDef;
        if (parameters == null) {
            this.parameters = ImmutableList.of();
        } else {
            this.parameters = ImmutableList.copyOf(parameters);
        }
        this.originFunction = originFunction;
        this.isAggregate = false;
        this.isTableFunction = false;
        this.returnType = new TypeDef(Type.VARCHAR);
        this.properties = ImmutableSortedMap.of();
    }

    public SetType getType() {
        return type;
    }

    public boolean isIfNotExists() {
        return ifNotExists;
    }

    public FunctionName getFunctionName() {
        return functionName;
    }

    public Function getFunction() {
        return function;
    }

    public Expr getOriginFunction() {
        return originFunction;
    }

    @Override
    public void analyze(Analyzer analyzer) throws UserException {
        super.analyze(analyzer);

        // https://github.com/apache/doris/issues/17810
        // this error report in P0 test, so we suspect that it is related to concurrency
        // add this change to test it.
        if (Config.use_fuzzy_session_variable) {
            synchronized (CreateFunctionStmt.class) {
                analyzeCommon(analyzer);
                // check
                if (isAggregate) {
                    analyzeUda();
                } else if (isAlias) {
                    analyzeAliasFunction();
                } else if (isTableFunction) {
                    analyzeTableFunction();
                } else {
                    analyzeUdf();
                }
            }
        } else {
            analyzeCommon(analyzer);
            // check
            if (isAggregate) {
                analyzeUda();
            } else if (isAlias) {
                analyzeAliasFunction();
            } else if (isTableFunction) {
                analyzeTableFunction();
            } else {
                analyzeUdf();
            }
        }
    }

    private void analyzeCommon(Analyzer analyzer) throws AnalysisException {
        // check function name
        functionName.analyze(analyzer, this.type);

        // check operation privilege
        if (!Env.getCurrentEnv().getAccessManager().checkGlobalPriv(ConnectContext.get(), PrivPredicate.ADMIN)) {
            ErrorReport.reportAnalysisException(ErrorCode.ERR_SPECIFIC_ACCESS_DENIED_ERROR, "ADMIN");
        }
        // check argument
        argsDef.analyze(analyzer);

        // alias function does not need analyze following params
        if (isAlias) {
            return;
        }

        returnType.analyze(analyzer);
        if (intermediateType != null) {
            intermediateType.analyze(analyzer);
        } else {
            intermediateType = returnType;
        }

        String type = properties.getOrDefault(BINARY_TYPE, "JAVA_UDF");
        binaryType = getFunctionBinaryType(type);
        if (binaryType == null) {
            throw new AnalysisException("unknown function type");
        }
        if (type.equals("NATIVE")) {
            throw new AnalysisException("do not support 'NATIVE' udf type after doris version 1.2.0,"
                                    + "please use JAVA_UDF or RPC instead");
        }

        userFile = properties.getOrDefault(FILE_KEY, properties.get(OBJECT_FILE_KEY));
        if (!Strings.isNullOrEmpty(userFile) && binaryType != TFunctionBinaryType.RPC) {
            try {
                computeObjectChecksum();
            } catch (IOException | NoSuchAlgorithmException e) {
                throw new AnalysisException("cannot to compute object's checksum. err: " + e.getMessage());
            }
            String md5sum = properties.get(MD5_CHECKSUM);
            if (md5sum != null && !md5sum.equalsIgnoreCase(checksum)) {
                throw new AnalysisException("library's checksum is not equal with input, checksum=" + checksum);
            }
        }
        if (binaryType == TFunctionBinaryType.JAVA_UDF) {
            FunctionUtil.checkEnableJavaUdf();

            // always_nullable the default value is true, equal null means true
            Boolean isReturnNull = parseBooleanFromProperties(IS_RETURN_NULL);
            if (isReturnNull != null && !isReturnNull) {
                returnNullMode = NullableMode.ALWAYS_NOT_NULLABLE;
            }
            // static_load the default value is false, equal null means false
            Boolean staticLoad = parseBooleanFromProperties(IS_STATIC_LOAD);
            if (staticLoad != null && staticLoad) {
                isStaticLoad = true;
            }
            String expirationTimeString = properties.get(EXPIRATION_TIME);
            if (expirationTimeString != null) {
                long timeMinutes = Long.parseLong(expirationTimeString);
                if (timeMinutes <= 0) {
                    throw new AnalysisException("expirationTime should greater than zero: ");
                }
                this.expirationTime = timeMinutes;
            }
        }
    }

    private Boolean parseBooleanFromProperties(String propertyString) throws AnalysisException {
        String valueOfString = properties.get(propertyString);
        if (valueOfString == null) {
            return null;
        }
        if (!valueOfString.equalsIgnoreCase("false") && !valueOfString.equalsIgnoreCase("true")) {
            throw new AnalysisException(propertyString + " in properties, you should set it false or true");
        }
        return Boolean.parseBoolean(valueOfString);
    }

    private void computeObjectChecksum() throws IOException, NoSuchAlgorithmException {
        if (FeConstants.runningUnitTest) {
            // skip checking checksum when running ut
            return;
        }

        try (InputStream inputStream = Util.getInputStreamFromUrl(userFile, null, HTTP_TIMEOUT_MS, HTTP_TIMEOUT_MS)) {
            MessageDigest digest = MessageDigest.getInstance("MD5");
            byte[] buf = new byte[4096];
            int bytesRead = 0;
            do {
                bytesRead = inputStream.read(buf);
                if (bytesRead < 0) {
                    break;
                }
                digest.update(buf, 0, bytesRead);
            } while (true);

            checksum = Hex.encodeHexString(digest.digest());
        }
    }

    private void analyzeTableFunction() throws AnalysisException {
        String symbol = properties.get(SYMBOL_KEY);
        if (Strings.isNullOrEmpty(symbol)) {
            throw new AnalysisException("No 'symbol' in properties");
        }
        if (!returnType.getType().isArrayType()) {
            throw new AnalysisException("JAVA_UDF OF UDTF return type must be array type");
        }
        analyzeJavaUdf(symbol);
        URI location;
        if (!Strings.isNullOrEmpty(userFile)) {
            location = URI.create(userFile);
        } else {
            location = null;
        }
        function = ScalarFunction.createUdf(binaryType,
                functionName, argsDef.getArgTypes(),
                ((ArrayType) (returnType.getType())).getItemType(), argsDef.isVariadic(),
                location, symbol, null, null);
        function.setChecksum(checksum);
        function.setNullableMode(returnNullMode);
        function.setStaticLoad(isStaticLoad);
        function.setExpirationTime(expirationTime);
        function.setUDTFunction(true);
        // Todo: maybe in create tables function, need register two function, one is
        // normal and one is outer as those have different result when result is NULL.
    }

    private void analyzeUda() throws AnalysisException {
        AggregateFunction.AggregateFunctionBuilder builder
                = AggregateFunction.AggregateFunctionBuilder.createUdfBuilder();
        URI location;
        if (!Strings.isNullOrEmpty(userFile)) {
            location = URI.create(userFile);
        } else {
            location = null;
        }
        builder.name(functionName).argsType(argsDef.getArgTypes()).retType(returnType.getType())
                .hasVarArgs(argsDef.isVariadic()).intermediateType(intermediateType.getType())
                .location(location);
        String initFnSymbol = properties.get(INIT_KEY);
        if (initFnSymbol == null && !(binaryType == TFunctionBinaryType.JAVA_UDF
                || binaryType == TFunctionBinaryType.RPC)) {
            throw new AnalysisException("No 'init_fn' in properties");
        }
        String updateFnSymbol = properties.get(UPDATE_KEY);
        if (updateFnSymbol == null && !(binaryType == TFunctionBinaryType.JAVA_UDF)) {
            throw new AnalysisException("No 'update_fn' in properties");
        }
        String mergeFnSymbol = properties.get(MERGE_KEY);
        if (mergeFnSymbol == null && !(binaryType == TFunctionBinaryType.JAVA_UDF)) {
            throw new AnalysisException("No 'merge_fn' in properties");
        }
        String serializeFnSymbol = properties.get(SERIALIZE_KEY);
        String finalizeFnSymbol = properties.get(FINALIZE_KEY);
        String getValueFnSymbol = properties.get(GET_VALUE_KEY);
        String removeFnSymbol = properties.get(REMOVE_KEY);
        String symbol = properties.get(SYMBOL_KEY);
        if (binaryType == TFunctionBinaryType.RPC && !userFile.contains("://")) {
            if (initFnSymbol != null) {
                checkRPCUdf(initFnSymbol);
            }
            checkRPCUdf(updateFnSymbol);
            checkRPCUdf(mergeFnSymbol);
            if (serializeFnSymbol != null) {
                checkRPCUdf(serializeFnSymbol);
            }
            if (finalizeFnSymbol != null) {
                checkRPCUdf(finalizeFnSymbol);
            }
            if (getValueFnSymbol != null) {
                checkRPCUdf(getValueFnSymbol);
            }
            if (removeFnSymbol != null) {
                checkRPCUdf(removeFnSymbol);
            }
        } else if (binaryType == TFunctionBinaryType.JAVA_UDF) {
            if (Strings.isNullOrEmpty(symbol)) {
                throw new AnalysisException("No 'symbol' in properties of java-udaf");
            }
            analyzeJavaUdaf(symbol);
        }
        function = builder.initFnSymbol(initFnSymbol).updateFnSymbol(updateFnSymbol).mergeFnSymbol(mergeFnSymbol)
                .serializeFnSymbol(serializeFnSymbol).finalizeFnSymbol(finalizeFnSymbol)
                .getValueFnSymbol(getValueFnSymbol).removeFnSymbol(removeFnSymbol).symbolName(symbol).build();
        function.setLocation(location);
        function.setBinaryType(binaryType);
        function.setChecksum(checksum);
        function.setNullableMode(returnNullMode);
        function.setStaticLoad(isStaticLoad);
        function.setExpirationTime(expirationTime);
    }

    private void analyzeUdf() throws AnalysisException {
        String symbol = properties.get(SYMBOL_KEY);
        if (Strings.isNullOrEmpty(symbol)) {
            throw new AnalysisException("No 'symbol' in properties");
        }
        String prepareFnSymbol = properties.get(PREPARE_SYMBOL_KEY);
        String closeFnSymbol = properties.get(CLOSE_SYMBOL_KEY);
        // TODO(yangzhg) support check function in FE when function service behind load balancer
        // the format for load balance can ref https://github.com/apache/incubator-brpc/blob/master/docs/en/client.md#connect-to-a-cluster
        if (binaryType == TFunctionBinaryType.RPC && !userFile.contains("://")) {
            if (StringUtils.isNotBlank(prepareFnSymbol) || StringUtils.isNotBlank(closeFnSymbol)) {
                throw new AnalysisException("prepare and close in RPC UDF are not supported.");
            }
            checkRPCUdf(symbol);
        } else if (binaryType == TFunctionBinaryType.JAVA_UDF) {
            analyzeJavaUdf(symbol);
        }
        URI location;
        if (!Strings.isNullOrEmpty(userFile)) {
            location = URI.create(userFile);
        } else {
            location = null;
        }
        function = ScalarFunction.createUdf(binaryType,
                functionName, argsDef.getArgTypes(),
                returnType.getType(), argsDef.isVariadic(),
                location, symbol, prepareFnSymbol, closeFnSymbol);
        function.setChecksum(checksum);
        function.setNullableMode(returnNullMode);
        function.setStaticLoad(isStaticLoad);
        function.setExpirationTime(expirationTime);
    }

    private void analyzeJavaUdaf(String clazz) throws AnalysisException {
        HashMap<String, Method> allMethods = new HashMap<>();

        try {
            if (Strings.isNullOrEmpty(userFile)) {
                try {
                    ClassLoader cl = this.getClass().getClassLoader();
                    checkUdafClass(clazz, cl, allMethods);
                    return;
                } catch (ClassNotFoundException e) {
                    throw new AnalysisException("Class [" + clazz + "] not found in classpath");
                }
            }
            URL[] urls = {new URL("jar:" + userFile + "!/")};
            try (URLClassLoader cl = URLClassLoader.newInstance(urls)) {
                checkUdafClass(clazz, cl, allMethods);
            } catch (ClassNotFoundException e) {
                throw new AnalysisException(
                        "Class [" + clazz + "] or inner class [State] not found in file :" + userFile);
            } catch (IOException e) {
                throw new AnalysisException("Failed to load file: " + userFile);
            }
        } catch (MalformedURLException e) {
            throw new AnalysisException("Failed to load file: " + userFile);
        }
    }

    private void checkUdafClass(String clazz, ClassLoader cl, HashMap<String, Method> allMethods)
            throws ClassNotFoundException, AnalysisException {
        Class udfClass = cl.loadClass(clazz);
        String udfClassName = udfClass.getCanonicalName();
        String stateClassName = udfClassName + "$" + STATE_CLASS_NAME;
        Class stateClass = cl.loadClass(stateClassName);

        for (Method m : udfClass.getMethods()) {
            if (!m.getDeclaringClass().equals(udfClass)) {
                continue;
            }
            String name = m.getName();
            if (allMethods.containsKey(name)) {
                throw new AnalysisException(
                        String.format("UDF class '%s' has multiple methods with name '%s' ", udfClassName,
                                name));
            }
            allMethods.put(name, m);
        }

        if (allMethods.get(CREATE_METHOD_NAME) == null) {
            throw new AnalysisException(
                    String.format("No method '%s' in class '%s'!", CREATE_METHOD_NAME, udfClassName));
        } else {
            checkMethodNonStaticAndPublic(CREATE_METHOD_NAME, allMethods.get(CREATE_METHOD_NAME), udfClassName);
            checkArgumentCount(allMethods.get(CREATE_METHOD_NAME), 0, udfClassName);
            checkReturnJavaType(udfClassName, allMethods.get(CREATE_METHOD_NAME), stateClass);
        }

        if (allMethods.get(DESTROY_METHOD_NAME) == null) {
            throw new AnalysisException(
                    String.format("No method '%s' in class '%s'!", DESTROY_METHOD_NAME, udfClassName));
        } else {
            checkMethodNonStaticAndPublic(DESTROY_METHOD_NAME, allMethods.get(DESTROY_METHOD_NAME),
                    udfClassName);
            checkArgumentCount(allMethods.get(DESTROY_METHOD_NAME), 1, udfClassName);
            checkReturnJavaType(udfClassName, allMethods.get(DESTROY_METHOD_NAME), void.class);
        }

        if (allMethods.get(ADD_METHOD_NAME) == null) {
            throw new AnalysisException(
                    String.format("No method '%s' in class '%s'!", ADD_METHOD_NAME, udfClassName));
        } else {
            checkMethodNonStaticAndPublic(ADD_METHOD_NAME, allMethods.get(ADD_METHOD_NAME), udfClassName);
            checkArgumentCount(allMethods.get(ADD_METHOD_NAME), argsDef.getArgTypes().length + 1, udfClassName);
            checkReturnJavaType(udfClassName, allMethods.get(ADD_METHOD_NAME), void.class);
            for (int i = 0; i < argsDef.getArgTypes().length; i++) {
                Parameter p = allMethods.get(ADD_METHOD_NAME).getParameters()[i + 1];
                checkUdfType(udfClass, allMethods.get(ADD_METHOD_NAME), argsDef.getArgTypes()[i], p.getType(),
                        p.getName());
            }
        }

        if (allMethods.get(SERIALIZE_METHOD_NAME) == null) {
            throw new AnalysisException(
                    String.format("No method '%s' in class '%s'!", SERIALIZE_METHOD_NAME, udfClassName));
        } else {
            checkMethodNonStaticAndPublic(SERIALIZE_METHOD_NAME, allMethods.get(SERIALIZE_METHOD_NAME),
                    udfClassName);
            checkArgumentCount(allMethods.get(SERIALIZE_METHOD_NAME), 2, udfClassName);
            checkReturnJavaType(udfClassName, allMethods.get(SERIALIZE_METHOD_NAME), void.class);
        }

        if (allMethods.get(MERGE_METHOD_NAME) == null) {
            throw new AnalysisException(
                    String.format("No method '%s' in class '%s'!", MERGE_METHOD_NAME, udfClassName));
        } else {
            checkMethodNonStaticAndPublic(MERGE_METHOD_NAME, allMethods.get(MERGE_METHOD_NAME), udfClassName);
            checkArgumentCount(allMethods.get(MERGE_METHOD_NAME), 2, udfClassName);
            checkReturnJavaType(udfClassName, allMethods.get(MERGE_METHOD_NAME), void.class);
        }

        if (allMethods.get(GETVALUE_METHOD_NAME) == null) {
            throw new AnalysisException(
                    String.format("No method '%s' in class '%s'!", GETVALUE_METHOD_NAME, udfClassName));
        } else {
            checkMethodNonStaticAndPublic(GETVALUE_METHOD_NAME, allMethods.get(GETVALUE_METHOD_NAME),
                    udfClassName);
            checkArgumentCount(allMethods.get(GETVALUE_METHOD_NAME), 1, udfClassName);
            checkReturnUdfType(udfClass, allMethods.get(GETVALUE_METHOD_NAME), returnType.getType());
        }

        if (!Modifier.isPublic(stateClass.getModifiers()) || !Modifier.isStatic(stateClass.getModifiers())) {
            throw new AnalysisException(
                    String.format(
                            "UDAF '%s' should have one public & static 'State' class to Construction data ",
                            udfClassName));
        }
    }

    private void checkMethodNonStaticAndPublic(String methoName, Method method, String udfClassName)
            throws AnalysisException {
        if (Modifier.isStatic(method.getModifiers())) {
            throw new AnalysisException(
                    String.format("Method '%s' in class '%s' should be non-static", methoName, udfClassName));
        }
        if (!Modifier.isPublic(method.getModifiers())) {
            throw new AnalysisException(
                    String.format("Method '%s' in class '%s' should be public", methoName, udfClassName));
        }
    }

    private void checkArgumentCount(Method method, int argumentCount, String udfClassName) throws AnalysisException {
        if (method.getParameters().length != argumentCount) {
            throw new AnalysisException(
                    String.format("The number of parameters for method '%s' in class '%s' should be %d",
                            method.getName(), udfClassName, argumentCount));
        }
    }

    private void checkReturnJavaType(String udfClassName, Method method, Class expType) throws AnalysisException {
        checkJavaType(udfClassName, method, expType, method.getReturnType(), "return");
    }

    private void checkJavaType(String udfClassName, Method method, Class expType, Class ptype, String pname)
            throws AnalysisException {
        if (!expType.equals(ptype)) {
            throw new AnalysisException(
                    String.format("UDF class '%s' method '%s' parameter %s[%s] expect type %s", udfClassName,
                            method.getName(), pname, ptype.getCanonicalName(), expType.getCanonicalName()));
        }
    }

    private void checkReturnUdfType(Class clazz, Method method, Type expType) throws AnalysisException {
        checkUdfType(clazz, method, expType, method.getReturnType(), "return");
    }

    private void analyzeJavaUdf(String clazz) throws AnalysisException {
        try {
            if (Strings.isNullOrEmpty(userFile)) {
                try {
                    ClassLoader cl = this.getClass().getClassLoader();
                    checkUdfClass(clazz, cl);
                    return;
                } catch (ClassNotFoundException e) {
                    throw new AnalysisException("Class [" + clazz + "] not found in classpath");
                }
            }
            URL[] urls = {new URL("jar:" + userFile + "!/")};
            try (URLClassLoader cl = URLClassLoader.newInstance(urls)) {
                checkUdfClass(clazz, cl);
            } catch (ClassNotFoundException e) {
                throw new AnalysisException("Class [" + clazz + "] not found in file :" + userFile);
            } catch (IOException e) {
                throw new AnalysisException("Failed to load file: " + userFile);
            }
        } catch (MalformedURLException e) {
            throw new AnalysisException("Failed to load file: " + userFile);
        }
    }

    private void checkUdfClass(String clazz, ClassLoader cl) throws ClassNotFoundException, AnalysisException {
        Class udfClass = cl.loadClass(clazz);
        List<Method> evalList = Arrays.stream(udfClass.getMethods())
                .filter(m -> m.getDeclaringClass().equals(udfClass) && EVAL_METHOD_KEY.equals(m.getName()))
                .collect(Collectors.toList());
        if (evalList.size() == 0) {
            throw new AnalysisException(String.format(
                "No method '%s' in class '%s'!", EVAL_METHOD_KEY, udfClass.getCanonicalName()));
        }
        List<Method> evalNonStaticAndPublicList = evalList.stream()
                .filter(m -> !Modifier.isStatic(m.getModifiers()) && Modifier.isPublic(m.getModifiers()))
                .collect(Collectors.toList());
        if (evalNonStaticAndPublicList.size() == 0) {
            throw new AnalysisException(
                String.format("Method '%s' in class '%s' should be non-static and public", EVAL_METHOD_KEY,
                    udfClass.getCanonicalName()));
        }
        List<Method> evalArgLengthMatchList = evalNonStaticAndPublicList.stream().filter(
                m -> m.getParameters().length == argsDef.getArgTypes().length).collect(Collectors.toList());
        if (evalArgLengthMatchList.size() == 0) {
            throw new AnalysisException(
                String.format("The number of parameters for method '%s' in class '%s' should be %d",
                    EVAL_METHOD_KEY, udfClass.getCanonicalName(), argsDef.getArgTypes().length));
        } else if (evalArgLengthMatchList.size() == 1) {
            Method method = evalArgLengthMatchList.get(0);
            checkUdfType(udfClass, method, returnType.getType(), method.getReturnType(), "return");
            for (int i = 0; i < method.getParameters().length; i++) {
                Parameter p = method.getParameters()[i];
                checkUdfType(udfClass, method, argsDef.getArgTypes()[i], p.getType(), p.getName());
            }
        } else {
            // If multiple methods have the same parameters,
            // the error message returned cannot be as specific as a single method
            boolean hasError = false;
            for (Method method : evalArgLengthMatchList) {
                try {
                    checkUdfType(udfClass, method, returnType.getType(), method.getReturnType(), "return");
                    for (int i = 0; i < method.getParameters().length; i++) {
                        Parameter p = method.getParameters()[i];
                        checkUdfType(udfClass, method, argsDef.getArgTypes()[i], p.getType(), p.getName());
                    }
                    hasError = false;
                    break;
                } catch (AnalysisException e) {
                    hasError = true;
                }
            }
            if (hasError) {
                throw new AnalysisException(String.format(
                    "Multi methods '%s' in class '%s' and no one passed parameter matching verification",
                    EVAL_METHOD_KEY, udfClass.getCanonicalName()));
            }
        }
    }

    private void checkUdfType(Class clazz, Method method, Type expType, Class pType, String pname)
            throws AnalysisException {
        Set<Class> javaTypes;
        if (expType instanceof ScalarType) {
            ScalarType scalarType = (ScalarType) expType;
            javaTypes = Type.PrimitiveTypeToJavaClassType.get(scalarType.getPrimitiveType());
        } else if (expType instanceof ArrayType) {
            ArrayType arrayType = (ArrayType) expType;
            javaTypes = Type.PrimitiveTypeToJavaClassType.get(arrayType.getPrimitiveType());
        } else if (expType instanceof MapType) {
            MapType mapType = (MapType) expType;
            javaTypes = Type.PrimitiveTypeToJavaClassType.get(mapType.getPrimitiveType());
        } else if (expType instanceof StructType) {
            StructType structType = (StructType) expType;
            javaTypes = Type.PrimitiveTypeToJavaClassType.get(structType.getPrimitiveType());
        } else {
            throw new AnalysisException(
                    String.format("Method '%s' in class '%s' does not support type '%s'",
                            method.getName(), clazz.getCanonicalName(), expType));
        }

        if (javaTypes == null) {
            throw new AnalysisException(
                    String.format("Method '%s' in class '%s' does not support type '%s'",
                            method.getName(), clazz.getCanonicalName(), expType.toString()));
        }
        if (!javaTypes.contains(pType)) {
            throw new AnalysisException(
                    String.format("UDF class '%s' method '%s' %s[%s] type is not supported!",
                            clazz.getCanonicalName(), method.getName(), pname, pType.getCanonicalName()));
        }
    }

    private void checkRPCUdf(String symbol) throws AnalysisException {
        // TODO(yangzhg) support check function in FE when function service behind load balancer
        // the format for load balance can ref https://github.com/apache/incubator-brpc/blob/master/docs/en/client.md#connect-to-a-cluster
        String[] url = userFile.split(":");
        if (url.length != 2) {
            throw new AnalysisException("function server address invalid.");
        }
        String host = url[0];
        int port = Integer.valueOf(url[1]);
        ManagedChannel channel = NettyChannelBuilder.forAddress(host, port)
                .flowControlWindow(Config.grpc_max_message_size_bytes)
                .maxInboundMessageSize(Config.grpc_max_message_size_bytes)
                .enableRetry().maxRetryAttempts(3)
                .usePlaintext().build();
        PFunctionServiceGrpc.PFunctionServiceBlockingStub stub = PFunctionServiceGrpc.newBlockingStub(channel);
        FunctionService.PCheckFunctionRequest.Builder builder = FunctionService.PCheckFunctionRequest.newBuilder();
        builder.getFunctionBuilder().setFunctionName(symbol);
        for (Type arg : argsDef.getArgTypes()) {
            builder.getFunctionBuilder().addInputs(convertToPParameterType(arg));
        }
        builder.getFunctionBuilder().setOutput(convertToPParameterType(returnType.getType()));
        FunctionService.PCheckFunctionResponse response = stub.checkFn(builder.build());
        if (response == null || !response.hasStatus()) {
            throw new AnalysisException("cannot access function server");
        }
        if (response.getStatus().getStatusCode() != 0) {
            throw new AnalysisException("check function [" + symbol + "] failed: " + response.getStatus());
        }
    }

    private Types.PGenericType convertToPParameterType(Type arg) throws AnalysisException {
        Types.PGenericType.Builder typeBuilder = Types.PGenericType.newBuilder();
        switch (arg.getPrimitiveType()) {
            case INVALID_TYPE:
                typeBuilder.setId(Types.PGenericType.TypeId.UNKNOWN);
                break;
            case BOOLEAN:
                typeBuilder.setId(Types.PGenericType.TypeId.BOOLEAN);
                break;
            case SMALLINT:
                typeBuilder.setId(Types.PGenericType.TypeId.INT16);
                break;
            case TINYINT:
                typeBuilder.setId(Types.PGenericType.TypeId.INT8);
                break;
            case INT:
                typeBuilder.setId(Types.PGenericType.TypeId.INT32);
                break;
            case BIGINT:
                typeBuilder.setId(Types.PGenericType.TypeId.INT64);
                break;
            case FLOAT:
                typeBuilder.setId(Types.PGenericType.TypeId.FLOAT);
                break;
            case DOUBLE:
                typeBuilder.setId(Types.PGenericType.TypeId.DOUBLE);
                break;
            case CHAR:
            case VARCHAR:
                typeBuilder.setId(Types.PGenericType.TypeId.STRING);
                break;
            case HLL:
                typeBuilder.setId(Types.PGenericType.TypeId.HLL);
                break;
            case BITMAP:
                typeBuilder.setId(Types.PGenericType.TypeId.BITMAP);
                break;
            case QUANTILE_STATE:
                typeBuilder.setId(Types.PGenericType.TypeId.QUANTILE_STATE);
                break;
            case AGG_STATE:
                typeBuilder.setId(Types.PGenericType.TypeId.AGG_STATE);
                break;
            case DATE:
                typeBuilder.setId(Types.PGenericType.TypeId.DATE);
                break;
            case DATEV2:
                typeBuilder.setId(Types.PGenericType.TypeId.DATEV2);
                break;
            case DATETIME:
            case TIME:
                typeBuilder.setId(Types.PGenericType.TypeId.DATETIME);
                break;
            case DATETIMEV2:
            case TIMEV2:
                typeBuilder.setId(Types.PGenericType.TypeId.DATETIMEV2);
                break;
            case DECIMALV2:
            case DECIMAL128:
                typeBuilder.setId(Types.PGenericType.TypeId.DECIMAL128)
                        .getDecimalTypeBuilder()
                        .setPrecision(((ScalarType) arg).getScalarPrecision())
                        .setScale(((ScalarType) arg).getScalarScale());
                break;
            case DECIMAL32:
                typeBuilder.setId(Types.PGenericType.TypeId.DECIMAL32)
                        .getDecimalTypeBuilder()
                        .setPrecision(((ScalarType) arg).getScalarPrecision())
                        .setScale(((ScalarType) arg).getScalarScale());
                break;
            case DECIMAL64:
                typeBuilder.setId(Types.PGenericType.TypeId.DECIMAL64)
                        .getDecimalTypeBuilder()
                        .setPrecision(((ScalarType) arg).getScalarPrecision())
                        .setScale(((ScalarType) arg).getScalarScale());
                break;
            case LARGEINT:
                typeBuilder.setId(Types.PGenericType.TypeId.INT128);
                break;
            default:
                throw new AnalysisException("type " + arg.getPrimitiveType().toString() + " is not supported");
        }
        return typeBuilder.build();
    }

    private TFunctionBinaryType getFunctionBinaryType(String type) {
        TFunctionBinaryType binaryType = null;
        try {
            binaryType = TFunctionBinaryType.valueOf(type);
        } catch (IllegalArgumentException e) {
            // ignore enum Exception
        }
        return binaryType;
    }

    private void analyzeAliasFunction() throws AnalysisException {
        function = AliasFunction.createFunction(functionName, argsDef.getArgTypes(),
                Type.VARCHAR, argsDef.isVariadic(), parameters, originFunction);
        ((AliasFunction) function).analyze();
    }

    @Override
    public String toSql() {
        StringBuilder stringBuilder = new StringBuilder();
        stringBuilder.append("CREATE ");
        if (isAggregate) {
            stringBuilder.append("AGGREGATE ");
        } else if (isAlias) {
            stringBuilder.append("ALIAS ");
        }

        stringBuilder.append("FUNCTION ");
        stringBuilder.append(functionName.toString());
        stringBuilder.append(argsDef.toSql());
        if (isAlias) {
            stringBuilder.append(" WITH PARAMETER (")
                    .append(parameters.toString())
                    .append(") AS ")
                    .append(originFunction.toSql());
        } else {
            stringBuilder.append(" RETURNS ");
            stringBuilder.append(returnType.toString());
        }
        if (properties.size() > 0) {
            stringBuilder.append(" PROPERTIES (");
            int i = 0;
            for (Map.Entry<String, String> entry : properties.entrySet()) {
                if (i != 0) {
                    stringBuilder.append(", ");
                }
                stringBuilder.append('"').append(entry.getKey()).append('"');
                stringBuilder.append("=");
                stringBuilder.append('"').append(entry.getValue()).append('"');
                i++;
            }
            stringBuilder.append(")");

        }
        return stringBuilder.toString();
    }

    @Override
    public RedirectStatus getRedirectStatus() {
        return RedirectStatus.FORWARD_WITH_SYNC;
    }

    @Override
    public StmtType stmtType() {
        return StmtType.CREATE;
    }
}