FunctionRegistry.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.catalog;
import org.apache.doris.common.Config;
import org.apache.doris.datasource.InternalCatalog;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.BuiltinFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdafBuilder;
import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdfBuilder;
import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdtfBuilder;
import org.apache.doris.nereids.trees.expressions.functions.udf.UdfBuilder;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.qe.ConnectContext;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import java.lang.reflect.Constructor;
import java.lang.reflect.Modifier;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import javax.annotation.concurrent.ThreadSafe;
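/**
* New function registry for Nereids.
*
* It keeps the builders of builtin functions and of user-defined functions (UDFs) and resolves
* a function name plus its arguments to a matching builder.
* This class is still under development and will support more functions over time.
*/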
@Developing
@ThreadSafe
public class FunctionRegistry {
// Pseudo database name used as the scope key for UDFs registered without a database
// (the global alias functions and other global UDFs).
private static final String GLOBAL_FUNCTION = "__GLOBAL_FUNCTION__";
// Builtin function builders keyed by function name; filled in by the constructor.
// findBuiltinFunctionBuilder and isAggregateFunction look names up in lower case.
private final Map<String, List<FunctionBuilder>> name2BuiltinBuilders;
// UDF builders keyed first by scope (a database name or GLOBAL_FUNCTION), then by function name.
// Accesses made inside this class are guarded by synchronized (name2UdfBuilders).
private final Map<String, Map<String, List<FunctionBuilder>>> name2UdfBuilders;
/**
 * Creates a registry whose builtin map is populated with all builtin functions and whose
 * UDF map starts out empty.
 */
public FunctionRegistry() {
    this.name2BuiltinBuilders = new ConcurrentHashMap<>();
    this.name2UdfBuilders = new ConcurrentHashMap<>();
    // Register the builtins first, then give subclasses a chance to adjust the result.
    registerBuiltinFunctions(this.name2BuiltinBuilders);
    afterRegisterBuiltinFunctions(this.name2BuiltinBuilders);
}
/**
 * Returns the builtin registry map itself (not a copy), keyed by function name.
 *
 * @return the map from function name to its builtin builders
 */
public Map<String, List<FunctionBuilder>> getName2BuiltinBuilders() {
    return this.name2BuiltinBuilders;
}
/**
 * Returns the pseudo database name under which global UDFs are stored.
 *
 * @return the value of {@code GLOBAL_FUNCTION}
 */
public String getGlobalFunctionDbName() {
    final String globalScope = GLOBAL_FUNCTION;
    return globalScope;
}
/**
 * Returns the UDF registry map itself (not a copy): scope name to function name to builders.
 *
 * @return the nested map holding all registered UDF builders
 */
public Map<String, Map<String, List<FunctionBuilder>>> getName2UdfBuilders() {
    return this.name2UdfBuilders;
}
/**
 * Test hook that the constructor invokes right after the builtin functions are registered.
 *
 * <p>For example, a child class of FunctionRegistry can override this method to clear the
 * builtin functions or to add more functions. The default implementation does nothing.
 *
 * @param name2Builders the builtin registry map, from function name to builders
 */
@VisibleForTesting
protected void afterRegisterBuiltinFunctions(Map<String, List<FunctionBuilder>> name2Builders) {
    // intentionally empty: extension point for tests
}
/**
 * Finds the builder for an unqualified function name.
 *
 * <p>Delegates to the three-argument overload with a {@code null} database name.
 *
 * @param name function name
 * @param arguments arguments of the call
 * @return the matching builder
 */
public FunctionBuilder findFunctionBuilder(String name, List<?> arguments) {
    final String noDatabase = null;
    return findFunctionBuilder(noDatabase, name, arguments);
}
/**
 * Finds the builder for an unqualified function name that is called with a single argument.
 *
 * <p>The argument is wrapped in an immutable one-element list and passed to the
 * three-argument overload with a {@code null} database name.
 *
 * @param name function name
 * @param argument the only argument of the call
 * @return the matching builder
 */
public FunctionBuilder findFunctionBuilder(String name, Object argument) {
    List<Object> singleArgument = ImmutableList.of(argument);
    return findFunctionBuilder(null, name, singleArgument);
}
/**
 * Looks up the builtin builders registered under {@code name}.
 *
 * <p>The name is used exactly as given; no case normalization is applied here.
 *
 * @param name key to look up in the builtin registry
 * @return an immutable snapshot of the registered builders, or {@link Optional#empty()} when
 *         nothing is registered under {@code name}
 */
public Optional<List<FunctionBuilder>> tryGetBuiltinBuilders(String name) {
    // Read the map exactly once. Deciding on a second lookup could observe a different value
    // than the one being copied if the registry changes in between, and a null would then be
    // handed to ImmutableList.copyOf.
    List<FunctionBuilder> builders = name2BuiltinBuilders.get(name);
    if (builders == null) {
        return Optional.empty();
    }
    return Optional.of(ImmutableList.copyOf(builders));
}
/**
 * Tells whether {@code name} refers to an aggregate function.
 *
 * <p>Builtin functions are only consulted when {@code dbName} is empty. UDFs are consulted
 * whenever no builtin aggregate function was found.
 *
 * @param dbName database qualifier of the function; may be null or empty
 * @param name function name, matched in lower case
 * @return true if any builder found for the name produces an AggregateFunction subclass
 */
public boolean isAggregateFunction(String dbName, String name) {
    String lowerName = name.toLowerCase();
    if (StringUtils.isEmpty(dbName)) {
        List<FunctionBuilder> builtinBuilders = name2BuiltinBuilders.get(lowerName);
        if (builtinBuilders != null && hasAggregateBuilder(builtinBuilders)) {
            return true;
        }
    }
    return hasAggregateBuilder(findUdfBuilder(dbName, lowerName));
}

// Returns true if at least one builder's function class is an AggregateFunction.
private static boolean hasAggregateBuilder(List<FunctionBuilder> builders) {
    Class<?> aggregateBase = org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction.class;
    return builders.stream()
            .anyMatch(candidate -> aggregateBase.isAssignableFrom(candidate.functionClass()));
}
/**
 * Finds the single function builder matching a (possibly database-qualified) function name
 * and the given arguments. Currently functions are matched by name, arity and argument types.
 *
 * <p>Lookup order depends on the session variable {@code preferUdfOverBuiltin}:
 * <ul>
 *   <li>when set, UDFs are searched first and builtin functions are the fallback;</li>
 *   <li>otherwise builtin functions are searched first and UDFs are the fallback.</li>
 * </ul>
 * Builtin functions are only considered when {@code dbName} is empty.
 *
 * @param dbName database qualifier of the function; null or empty means unqualified
 * @param name function name
 * @param arguments arguments of the call, checked with {@code FunctionBuilder.canApply}
 * @return the matching builder
 * @throws AnalysisException if no function with this name exists, if no candidate accepts the
 *         arguments, if only Java UDFs match while {@code Config.enable_java_udf} is off, or
 *         if several candidates remain and no unique UDF signature can be selected
 */
public FunctionBuilder findFunctionBuilder(String dbName, String name, List<?> arguments) {
    List<FunctionBuilder> functionBuilders = null;
    int arity = arguments.size();
    String qualifiedName = StringUtils.isEmpty(dbName) ? name : dbName + "." + name;
    ConnectContext connectContext = ConnectContext.get();
    boolean preferUdfOverBuiltin = connectContext != null
            && connectContext.getSessionVariable().preferUdfOverBuiltin;

    if (preferUdfOverBuiltin) {
        // find udf first, then find builtin function
        functionBuilders = findUdfBuilder(dbName, name);
        if (CollectionUtils.isEmpty(functionBuilders) && StringUtils.isEmpty(dbName)) {
            // builtin functions are only eligible when the name is not database-qualified
            functionBuilders = findBuiltinFunctionBuilder(name, arguments);
        }
    } else {
        // find builtin function first, then find udf
        if (StringUtils.isEmpty(dbName)) {
            functionBuilders = findBuiltinFunctionBuilder(name, arguments);
        }
        if (CollectionUtils.isEmpty(functionBuilders)) {
            functionBuilders = findUdfBuilder(dbName, name);
        }
    }

    if (functionBuilders == null || functionBuilders.isEmpty()) {
        throw new AnalysisException("Can not found function '" + qualifiedName + "'");
    }

    // check the arity and type
    List<FunctionBuilder> candidateBuilders = Lists.newArrayListWithCapacity(functionBuilders.size());
    for (FunctionBuilder functionBuilder : functionBuilders) {
        if (functionBuilder.canApply(arguments)) {
            candidateBuilders.add(functionBuilder);
        }
    }
    if (candidateBuilders.isEmpty()) {
        String candidateHints = getCandidateHint(name, functionBuilders);
        throw new AnalysisException("Can not found function '" + qualifiedName
                + "' which has " + arity + " arity. Candidate functions are: " + candidateHints);
    }

    if (!Config.enable_java_udf) {
        candidateBuilders = candidateBuilders.stream()
                .filter(fb -> !(fb instanceof JavaUdfBuilder || fb instanceof JavaUdafBuilder
                        || fb instanceof JavaUdtfBuilder))
                .collect(Collectors.toList());
        if (candidateBuilders.isEmpty()) {
            throw new AnalysisException("java_udf has been disabled.");
        }
    }

    if (candidateBuilders.size() > 1) {
        FunctionBuilder chosen = chooseUdfBySignature(candidateBuilders, arguments);
        if (chosen != null) {
            return chosen;
        }
        String candidateHints = getCandidateHint(name, candidateBuilders);
        throw new AnalysisException("Function '" + qualifiedName + "' is ambiguous: " + candidateHints);
    }
    return candidateBuilders.get(0);
}

// Picks the UDF builder whose signature best matches the arguments, or returns null when the
// candidates are not all UDFs, the arguments are not all expressions, or nothing matches.
private FunctionBuilder chooseUdfBySignature(List<FunctionBuilder> candidateBuilders, List<?> arguments) {
    List<FunctionSignature> signatures = Lists.newArrayListWithCapacity(candidateBuilders.size());
    // signatureOwners.get(i) is the builder that contributed signatures.get(i). Tracking the
    // owner explicitly keeps the mapping correct even if a builder exposes zero or several
    // signatures, in which case signature positions no longer line up with builder positions.
    List<FunctionBuilder> signatureOwners = Lists.newArrayListWithCapacity(candidateBuilders.size());
    for (FunctionBuilder candidate : candidateBuilders) {
        if (!(candidate instanceof UdfBuilder)) {
            return null;
        }
        for (FunctionSignature signature : ((UdfBuilder) candidate).getSignatures()) {
            signatures.add(signature);
            signatureOwners.add(candidate);
        }
    }
    for (Object argument : arguments) {
        if (!(argument instanceof Expression)) {
            return null;
        }
    }
    // Safe: every element was verified to be an Expression by the loop above.
    @SuppressWarnings("unchecked")
    List<Expression> expressions = (List<Expression>) arguments;
    FunctionSignature matched = new UdfSignatureSearcher(signatures, expressions).getSignature();
    for (int i = 0; i < signatures.size(); i++) {
        if (signatures.get(i).equals(matched)) {
            return signatureOwners.get(i);
        }
    }
    return null;
}
/**
 * Looks up builtin builders by lower-cased name, falling back to aggregate-state combinators.
 *
 * <p>If nothing is registered under the name and the name is recognized by
 * {@code AggCombinerFunctionBuilder.isAggStateCombinator}, the builders of the nested function
 * are wrapped in combiner builders and only those that can apply to the arguments are kept.
 *
 * @param name function name
 * @param arguments arguments of the call, used to filter combiner builders
 * @return the builders found; may be null or empty when nothing matches
 */
private List<FunctionBuilder> findBuiltinFunctionBuilder(String name, List<?> arguments) {
    List<FunctionBuilder> directBuilders = name2BuiltinBuilders.get(name.toLowerCase());
    boolean tryCombinator = CollectionUtils.isEmpty(directBuilders)
            && AggCombinerFunctionBuilder.isAggStateCombinator(name);
    if (!tryCombinator) {
        return directBuilders;
    }

    String nestedName = AggCombinerFunctionBuilder.getNestedName(name);
    String combinatorSuffix = AggCombinerFunctionBuilder.getCombinatorSuffix(name);
    List<FunctionBuilder> nestedBuilders = name2BuiltinBuilders.get(nestedName.toLowerCase());
    if (nestedBuilders == null) {
        // the nested function is unknown as well
        return null;
    }

    List<FunctionBuilder> applicableCombiners = Lists.newArrayListWithCapacity(nestedBuilders.size());
    for (FunctionBuilder nestedBuilder : nestedBuilders) {
        AggCombinerFunctionBuilder combiner = new AggCombinerFunctionBuilder(combinatorSuffix, nestedBuilder);
        if (combiner.canApply(arguments)) {
            applicableCombiners.add(combiner);
        }
    }
    return applicableCombiners;
}
/**
 * Finds the UDF builders registered under {@code name}. Public for test.
 *
 * <p>The global scope is always searched. When a connect context exists, the database
 * ({@code dbName}, or the context's current database if {@code dbName} is null) is searched
 * before the global scope, provided the context holds SELECT privilege on that database.
 * The first scope with a non-empty builder list wins.
 *
 * @param dbName database to search first; null means the current database of the context
 * @param name function name, matched in lower case
 * @return an immutable snapshot of the matching builders; empty if none are found
 */
public List<FunctionBuilder> findUdfBuilder(String dbName, String name) {
    List<String> scopes = ImmutableList.of(GLOBAL_FUNCTION);
    ConnectContext connectContext = ConnectContext.get();
    if (connectContext != null) {
        String effectiveDb = dbName == null ? connectContext.getDatabase() : dbName;
        if (effectiveDb != null && Env.getCurrentEnv().getAccessManager()
                .checkDbPriv(connectContext, InternalCatalog.INTERNAL_CATALOG_NAME, effectiveDb,
                        PrivPredicate.SELECT)) {
            scopes = ImmutableList.of(effectiveDb, GLOBAL_FUNCTION);
        }
    }
    synchronized (name2UdfBuilders) {
        for (String scope : scopes) {
            List<FunctionBuilder> candidate = name2UdfBuilders.getOrDefault(scope, ImmutableMap.of())
                    .get(name.toLowerCase());
            if (candidate != null && !candidate.isEmpty()) {
                // Copy while the lock is held: the stored list is mutated by addUdf/dropUdf,
                // so handing it out directly would let callers iterate it unsynchronized.
                return ImmutableList.copyOf(candidate);
            }
        }
    }
    return ImmutableList.of();
}
/**
* Registers the builtin functions into the given map through FunctionHelper.addFunctions,
* one call per function family, in this order: scalar, aggregate, table-valued,
* table-generating, window.
*
* @param name2Builders the map that receives the builtin function builders
*/
private void registerBuiltinFunctions(Map<String, List<FunctionBuilder>> name2Builders) {
FunctionHelper.addFunctions(name2Builders, BuiltinScalarFunctions.INSTANCE.scalarFunctions);
FunctionHelper.addFunctions(name2Builders, BuiltinAggregateFunctions.INSTANCE.aggregateFunctions);
FunctionHelper.addFunctions(name2Builders, BuiltinTableValuedFunctions.INSTANCE.tableValuedFunctions);
FunctionHelper.addFunctions(name2Builders, BuiltinTableGeneratingFunctions.INSTANCE.tableGeneratingFunctions);
FunctionHelper.addFunctions(name2Builders, BuiltinWindowFunctions.INSTANCE.windowFunctions);
}
/**
 * Renders the candidate functions as a hint such as {@code [f(...), f(...)]} for error messages.
 *
 * <p>Each entry is {@code name} followed by the builder's parameter display string. Builtin
 * builders whose constructor is abstract, is not public, or takes a parameter that is neither
 * an Expression nor an array of Expression are left out.
 *
 * @param name function name placed in front of every parameter list
 * @param candidateBuilders builders to describe
 * @return the comma-separated candidates, enclosed in square brackets
 */
public String getCandidateHint(String name, List<FunctionBuilder> candidateBuilders) {
    return candidateBuilders.stream()
            .filter(FunctionRegistry::isHintableBuilder)
            .map(builder -> name + builder.parameterDisplayString())
            .collect(Collectors.joining(", ", "[", "]"));
}

// Decides whether a builder should be listed in a candidate hint; non-builtin builders always are.
private static boolean isHintableBuilder(FunctionBuilder builder) {
    if (!(builder instanceof BuiltinFunctionBuilder)) {
        return true;
    }
    Constructor<BoundFunction> constructor = ((BuiltinFunctionBuilder) builder).getBuilderMethod();
    int modifiers = constructor.getModifiers();
    if (Modifier.isAbstract(modifiers) || !Modifier.isPublic(modifiers)) {
        return false;
    }
    for (Class<?> parameterType : constructor.getParameterTypes()) {
        boolean isExpression = Expression.class.isAssignableFrom(parameterType);
        boolean isExpressionArray = parameterType.isArray()
                && Expression.class.isAssignableFrom(parameterType.getComponentType());
        if (!isExpression && !isExpressionArray) {
            return false;
        }
    }
    return true;
}
/**
 * Registers a UDF builder under the given database and function name.
 *
 * @param dbName database that owns the function; null registers it in the global scope
 * @param name function name, stored exactly as given
 * @param builder builder to append to the overloads already registered under the name
 */
public void addUdf(String dbName, String name, UdfBuilder builder) {
    String scope = (dbName != null) ? dbName : GLOBAL_FUNCTION;
    synchronized (name2UdfBuilders) {
        Map<String, List<FunctionBuilder>> scopedBuilders =
                name2UdfBuilders.computeIfAbsent(scope, ignored -> Maps.newHashMap());
        List<FunctionBuilder> overloads =
                scopedBuilders.computeIfAbsent(name, ignored -> Lists.newArrayList());
        overloads.add(builder);
    }
}
/**
 * Removes the UDF overloads of {@code name} whose argument types equal {@code argTypes}.
 *
 * <p>Dropping a function that is not registered, or from a database that has no registered
 * UDFs at all, is a no-op.
 *
 * @param dbName database that owns the function; null means the global scope
 * @param name function name, matched exactly as given
 * @param argTypes argument types identifying the overload to remove
 */
public void dropUdf(String dbName, String name, List<DataType> argTypes) {
    String scope = dbName == null ? GLOBAL_FUNCTION : dbName;
    synchronized (name2UdfBuilders) {
        Map<String, List<FunctionBuilder>> builders = name2UdfBuilders.get(scope);
        if (builders == null) {
            // Nothing is registered for this scope. Do not substitute an immutable empty map
            // here: the cleanup below would call remove() on it, which throws
            // UnsupportedOperationException.
            return;
        }
        List<FunctionBuilder> overloads = builders.get(name);
        if (overloads == null) {
            return;
        }
        overloads.removeIf(builder -> ((UdfBuilder) builder).getArgTypes().equals(argTypes));
        // the name will be used when show functions, so remove the name when it's dropped
        if (overloads.isEmpty()) {
            builders.remove(name);
        }
    }
}
/**
 * Helper used to search the appropriate signature for UDFs when more than one candidate remains.
 *
 * <p>It exposes a fixed list of candidate signatures and the call arguments (as children), and
 * {@link #getSignature()} returns the result of {@code searchSignature} over those candidates.
 * NOTE(review): {@code searchSignature} is not defined in this class; presumably it is inherited
 * through ExplicitlyCastableSignature. Confirm before relying on its matching rules.
 */
static class UdfSignatureSearcher implements ExplicitlyCastableSignature {
    private final List<FunctionSignature> candidateSignatures;
    private final List<Expression> argumentExpressions;

    /**
     * @param signatures candidate signatures to choose from
     * @param arguments arguments of the call, exposed as the children of this searcher
     */
    public UdfSignatureSearcher(List<FunctionSignature> signatures, List<Expression> arguments) {
        candidateSignatures = signatures;
        argumentExpressions = arguments;
    }

    // ---- signature search ----

    @Override
    public List<FunctionSignature> getSignatures() {
        return candidateSignatures;
    }

    @Override
    public FunctionSignature getSignature() {
        return searchSignature(candidateSignatures);
    }

    // ---- arguments exposed as children ----

    @Override
    public List<Expression> children() {
        return argumentExpressions;
    }

    @Override
    public Expression child(int index) {
        return argumentExpressions.get(index);
    }

    @Override
    public int arity() {
        return argumentExpressions.size();
    }

    // ---- mutable state: nothing is ever stored ----

    @Override
    public <T> Optional<T> getMutableState(String key) {
        return Optional.empty();
    }

    @Override
    public void setMutableState(String key, Object value) {
        // no-op: this searcher keeps no mutable state
    }

    // ---- operations that are not supported on a searcher ----

    @Override
    public boolean nullable() {
        throw new AnalysisException("could not call nullable on UdfSignatureSearcher");
    }

    @Override
    public Expression withChildren(List<Expression> children) {
        throw new AnalysisException("could not call withChildren on UdfSignatureSearcher");
    }
}
}