ScalarFunction.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.catalog;
import org.apache.doris.analysis.CreateFunctionStmt;
import org.apache.doris.analysis.FunctionName;
import org.apache.doris.common.io.Text;
import org.apache.doris.common.util.URI;
import org.apache.doris.thrift.TDictFunction;
import org.apache.doris.thrift.TFunction;
import org.apache.doris.thrift.TFunctionBinaryType;
import org.apache.doris.thrift.TScalarFunction;
import com.google.common.collect.Maps;
import com.google.gson.Gson;
import com.google.gson.annotations.SerializedName;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.io.DataInput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
// import org.apache.doris.thrift.TSymbolType;
/**
 * Internal representation of a scalar function.
 *
 * <p>In addition to the state kept by {@link Function}, a scalar function records the
 * symbol to invoke, optional prepare/close symbols, and an optional
 * {@link TDictFunction} descriptor that {@link #toThrift} forwards when it is set.
 */
public class ScalarFunction extends Function {
    private static final Logger LOG = LogManager.getLogger(ScalarFunction.class);

    // The name inside the binary at location that contains this particular
    // function, e.g. org.example.MyUdf.class.
    @SerializedName("sn")
    private String symbolName;

    // Optional prepare-function symbol; null when not provided.
    @SerializedName("pfs")
    private String prepareFnSymbol;

    // Optional close-function symbol; null when not provided.
    @SerializedName("cfs")
    private String closeFnSymbol;

    // Dictionary descriptor attached to the thrift form in toThrift() when non-null.
    // NOTE(review): it has no @SerializedName, so presumably it is not persisted with the
    // function metadata -- confirm against the Gson configuration.
    TDictFunction dictFunction = null;

    // Only used for serialization
    protected ScalarFunction() {
    }

    /**
     * Creates a BUILTIN scalar function with isVec = true and
     * {@link NullableMode#DEPEND_ON_ARGUMENT}.
     */
    public ScalarFunction(FunctionName fnName, List<Type> argTypes, Type retType, boolean hasVarArgs,
            boolean userVisible) {
        this(fnName, argTypes, retType, hasVarArgs, TFunctionBinaryType.BUILTIN, userVisible, true);
    }

    /**
     * Creates a BUILTIN scalar function with {@link NullableMode#DEPEND_ON_ARGUMENT}
     * and an explicit isVec flag.
     */
    public ScalarFunction(FunctionName fnName, List<Type> argTypes, Type retType, boolean hasVarArgs,
            boolean userVisible, boolean isVec) {
        this(fnName, argTypes, retType, hasVarArgs, TFunctionBinaryType.BUILTIN, userVisible, isVec);
    }

    /**
     * Creates a scalar function of the given binary type with
     * {@link NullableMode#DEPEND_ON_ARGUMENT}. No symbols are set.
     */
    public ScalarFunction(FunctionName fnName, List<Type> argTypes, Type retType, boolean hasVarArgs,
            TFunctionBinaryType binaryType, boolean userVisible, boolean isVec) {
        super(0, fnName, argTypes, retType, hasVarArgs, binaryType, userVisible, isVec,
                NullableMode.DEPEND_ON_ARGUMENT);
    }

    /**
     * Creates a Nereids custom scalar function with an explicit symbol name and nullable mode.
     */
    public ScalarFunction(FunctionName fnName, List<Type> argTypes, Type retType, boolean hasVarArgs, String symbolName,
            TFunctionBinaryType binaryType, boolean userVisible, boolean isVec, NullableMode nullableMode) {
        super(0, fnName, argTypes, retType, hasVarArgs, binaryType, userVisible, isVec, nullableMode);
        this.symbolName = symbolName;
    }

    /**
     * Creates a scalar function backed by a binary at {@code location}.
     *
     * <p>Delegates to {@code Function(fnName, argTypes, retType, false)}; the trailing
     * {@code false} is presumably hasVarArgs (TODO confirm).
     *
     * @param location     location of the binary; stored via {@code setLocation}
     * @param symbolName   symbol of the function itself
     * @param initFnSymbol symbol stored as the prepare-function symbol
     * @param closeFnSymbol symbol of the close function
     */
    public ScalarFunction(FunctionName fnName, List<Type> argTypes,
            Type retType, URI location, String symbolName, String initFnSymbol,
            String closeFnSymbol) {
        super(fnName, argTypes, retType, false);
        setLocation(location);
        setSymbolName(symbolName);
        setPrepareFnSymbol(initFnSymbol);
        setCloseFnSymbol(closeFnSymbol);
    }

    /**
     * Creates a builtin scalar function with {@link NullableMode#DEPEND_ON_ARGUMENT}.
     * This is a helper that wraps a few steps into one call.
     */
    public static ScalarFunction createBuiltin(String name, Type retType,
            ArrayList<Type> argTypes, boolean hasVarArgs,
            String symbol, String prepareFnSymbol, String closeFnSymbol,
            boolean userVisible) {
        return createBuiltin(name, retType, NullableMode.DEPEND_ON_ARGUMENT, argTypes, hasVarArgs,
                symbol, prepareFnSymbol, closeFnSymbol, userVisible);
    }

    /**
     * Creates a builtin scalar function with explicit symbol, prepare/close symbols
     * and nullable mode.
     */
    public static ScalarFunction createBuiltin(
            String name, Type retType, NullableMode nullableMode,
            ArrayList<Type> argTypes, boolean hasVarArgs,
            String symbol, String prepareFnSymbol, String closeFnSymbol, boolean userVisible) {
        ScalarFunction fn = new ScalarFunction(
                new FunctionName(name), argTypes, retType, hasVarArgs, userVisible);
        fn.symbolName = symbol;
        fn.prepareFnSymbol = prepareFnSymbol;
        fn.closeFnSymbol = closeFnSymbol;
        fn.nullableMode = nullableMode;
        return fn;
    }

    /**
     * Creates a non-variadic, non-user-visible builtin operator with no symbol and
     * {@link NullableMode#DEPEND_ON_ARGUMENT}.
     */
    public static ScalarFunction createBuiltinOperator(
            String name, ArrayList<Type> argTypes, Type retType) {
        return createBuiltinOperator(name, argTypes, retType, NullableMode.DEPEND_ON_ARGUMENT);
    }

    /**
     * Creates a builtin scalar operator function with no symbol. This is a helper
     * that wraps a few steps into one call.
     * TODO: this needs to be kept in sync with what generates the be operator
     * implementations (gen_functions.py). Is there a better way to coordinate this?
     */
    public static ScalarFunction createBuiltinOperator(
            String name, ArrayList<Type> argTypes, Type retType, NullableMode nullableMode) {
        return createBuiltinOperator(name, null, argTypes, retType, nullableMode);
    }

    /**
     * Creates a builtin operator with the given symbol and
     * {@link NullableMode#DEPEND_ON_ARGUMENT}.
     */
    public static ScalarFunction createBuiltinOperator(
            String name, String symbol, ArrayList<Type> argTypes, Type retType) {
        return createBuiltinOperator(name, symbol, argTypes, retType, NullableMode.DEPEND_ON_ARGUMENT);
    }

    /**
     * Creates a builtin operator: non-variadic and not user-visible.
     */
    public static ScalarFunction createBuiltinOperator(
            String name, String symbol, ArrayList<Type> argTypes, Type retType, NullableMode nullableMode) {
        return createBuiltin(name, symbol, argTypes, false, retType, false, nullableMode);
    }

    /**
     * Creates a builtin scalar function with a symbol and nullable mode but no
     * prepare/close symbols.
     */
    public static ScalarFunction createBuiltin(
            String name, String symbol, ArrayList<Type> argTypes,
            boolean hasVarArgs, Type retType, boolean userVisible, NullableMode nullableMode) {
        ScalarFunction fn = new ScalarFunction(
                new FunctionName(name), argTypes, retType, hasVarArgs, userVisible);
        fn.symbolName = symbol;
        fn.nullableMode = nullableMode;
        return fn;
    }

    /**
     * Creates a user-visible UDF with isVec = false.
     *
     * @param args argument types; wrapped with {@link Arrays#asList}, so the array is not copied here
     */
    public static ScalarFunction createUdf(
            TFunctionBinaryType binaryType,
            FunctionName name, Type[] args,
            Type returnType, boolean isVariadic,
            URI location, String symbol, String prepareFnSymbol, String closeFnSymbol) {
        ScalarFunction fn = new ScalarFunction(name, Arrays.asList(args), returnType, isVariadic, binaryType,
                true, false);
        fn.symbolName = symbol;
        fn.prepareFnSymbol = prepareFnSymbol;
        fn.closeFnSymbol = closeFnSymbol;
        fn.setLocation(location);
        return fn;
    }

    /**
     * Copy constructor used by {@link #clone()}.
     *
     * <p>Copies the symbol names and the dictionary descriptor reference on top of
     * whatever {@code Function}'s copy constructor copies.
     */
    public ScalarFunction(ScalarFunction other) {
        super(other);
        if (other == null) {
            return;
        }
        symbolName = other.symbolName;
        prepareFnSymbol = other.prepareFnSymbol;
        closeFnSymbol = other.closeFnSymbol;
        // Carry the dictionary descriptor over as well; otherwise a clone silently stops
        // attaching it in toThrift().
        dictFunction = other.dictFunction;
    }

    @Override
    public Function clone() {
        return new ScalarFunction(this);
    }

    public void setSymbolName(String s) {
        symbolName = s;
    }

    public void setPrepareFnSymbol(String s) {
        prepareFnSymbol = s;
    }

    public void setCloseFnSymbol(String s) {
        closeFnSymbol = s;
    }

    public String getSymbolName() {
        return symbolName;
    }

    public String getPrepareFnSymbol() {
        return prepareFnSymbol;
    }

    public String getCloseFnSymbol() {
        return closeFnSymbol;
    }

    /** Sets the dictionary descriptor that {@link #toThrift} attaches when non-null. */
    public void setDictFunction(TDictFunction dictFunction) {
        this.dictFunction = dictFunction;
    }

    /**
     * Renders this function as a {@code CREATE [GLOBAL] FUNCTION [IF NOT EXISTS]} statement.
     *
     * <p>JAVA_UDF functions emit FILE and ALWAYS_NULLABLE properties; all other binary
     * types emit OBJECT_FILE. PREPARE_FN and CLOSE_FN are emitted only when non-null.
     * Property values are not escaped.
     */
    @Override
    public String toSql(boolean ifNotExists) {
        StringBuilder sb = new StringBuilder("CREATE ");
        if (this.isGlobal) {
            sb.append("GLOBAL ");
        }
        sb.append("FUNCTION ");
        if (ifNotExists) {
            sb.append("IF NOT EXISTS ");
        }
        sb.append(signatureString())
                .append(" RETURNS " + getReturnType())
                .append(" PROPERTIES (");
        sb.append("\n \"SYMBOL\"=").append("\"" + getSymbolName() + "\"");
        if (getPrepareFnSymbol() != null) {
            sb.append(",\n \"PREPARE_FN\"=").append("\"" + getPrepareFnSymbol() + "\"");
        }
        if (getCloseFnSymbol() != null) {
            sb.append(",\n \"CLOSE_FN\"=").append("\"" + getCloseFnSymbol() + "\"");
        }
        if (getBinaryType() == TFunctionBinaryType.JAVA_UDF) {
            sb.append(",\n \"FILE\"=")
                    .append("\"" + (getLocation() == null ? "" : getLocation().toString()) + "\"");
            boolean isReturnNull = this.getNullableMode() == NullableMode.ALWAYS_NULLABLE;
            sb.append(",\n \"ALWAYS_NULLABLE\"=").append("\"" + isReturnNull + "\"");
        } else {
            sb.append(",\n \"OBJECT_FILE\"=")
                    .append("\"" + (getLocation() == null ? "" : getLocation().toString()) + "\"");
        }
        sb.append(",\n \"TYPE\"=").append("\"" + this.getBinaryType() + "\"");
        sb.append("\n);");
        return sb.toString();
    }

    /**
     * Builds the thrift form via {@code Function.toThrift} and adds a {@link TScalarFunction}.
     *
     * <p>The symbol is sent only for JAVA_UDF and RPC binaries; all other binary types
     * send an empty symbol. The dictionary descriptor is attached when it is set.
     */
    @Override
    public TFunction toThrift(Type realReturnType, Type[] realArgTypes, Boolean[] realArgTypeNullables) {
        TFunction fn = super.toThrift(realReturnType, realArgTypes, realArgTypeNullables);
        fn.setScalarFn(new TScalarFunction());
        if (getBinaryType() == TFunctionBinaryType.JAVA_UDF || getBinaryType() == TFunctionBinaryType.RPC) {
            fn.getScalarFn().setSymbol(symbolName);
        } else {
            fn.getScalarFn().setSymbol("");
        }
        if (dictFunction != null) {
            fn.setDictFunction(dictFunction);
        }
        return fn;
    }

    /**
     * Reads the binary image produced for this class.
     *
     * <p>Reads the {@code Function} fields first, then the symbol name, then the
     * prepare and close symbols. Each of the last two is preceded by a boolean
     * presence flag.
     *
     * @throws IOException if reading from {@code input} fails
     */
    @Override
    public void readFields(DataInput input) throws IOException {
        super.readFields(input);
        symbolName = Text.readString(input);
        if (input.readBoolean()) {
            prepareFnSymbol = Text.readString(input);
        }
        if (input.readBoolean()) {
            closeFnSymbol = Text.readString(input);
        }
    }

    /**
     * Returns a JSON object with the object-file location (empty string when unset),
     * the MD5 checksum and the symbol name.
     */
    @Override
    public String getProperties() {
        Map<String, String> properties = Maps.newHashMap();
        properties.put(CreateFunctionStmt.OBJECT_FILE_KEY, getLocation() == null ? "" : getLocation().toString());
        properties.put(CreateFunctionStmt.MD5_CHECKSUM, checksum);
        properties.put(CreateFunctionStmt.SYMBOL_KEY, symbolName);
        return new Gson().toJson(properties);
    }
}