UdfExecutor.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.udf;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.Pair;
import org.apache.doris.common.exception.InternalException;
import org.apache.doris.common.exception.UdfRuntimeException;
import org.apache.doris.common.jni.utils.JavaUdfDataType;
import org.apache.doris.common.jni.utils.UdfClassCache;
import org.apache.doris.common.jni.utils.UdfUtils;
import org.apache.doris.common.jni.vec.VectorTable;
import org.apache.doris.thrift.TJavaUdfExecutorCtorParams;
import com.google.common.base.Joiner;
import com.google.common.collect.Lists;
import org.apache.log4j.Logger;
import java.lang.reflect.Array;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Map;
public class UdfExecutor extends BaseExecutor {
public static final Logger LOG = Logger.getLogger(UdfExecutor.class);
private static final String UDF_PREPARE_FUNCTION_NAME = "prepare";
private static final String UDF_FUNCTION_NAME = "evaluate";
/**
* Create a UdfExecutor, using parameters from a serialized thrift object. Used by
* the backend.
*/
public UdfExecutor(byte[] thriftParams) throws Exception {
super(thriftParams);
}
/**
* Close the class loader we may have created.
*/
@Override
public void close() {
// We are now un-usable (because the class loader has been
// closed), so null out method_ and classLoader_.
if (!isStaticLoad) {
super.close();
} else if (outputTable != null) {
outputTable.close();
}
}
public long evaluate(Map<String, String> inputParams, Map<String, String> outputParams) throws UdfRuntimeException {
try {
VectorTable inputTable = VectorTable.createReadableTable(inputParams);
int numRows = inputTable.getNumRows();
int numColumns = inputTable.getNumColumns();
if (outputTable != null) {
outputTable.close();
}
outputTable = VectorTable.createWritableTable(outputParams, numRows);
// If the return type is primitive, we can't cast the array of primitive type as array of Object,
// so we have to new its wrapped Object.
Object[] result = outputTable.getColumnType(0).isPrimitive()
? outputTable.getColumn(0).newObjectContainerArray(numRows)
: (Object[]) Array.newInstance(objCache.retClass, numRows);
Object[][] inputs = inputTable.getMaterializedData(getInputConverters(numColumns, false));
Object[] parameters = new Object[numColumns];
for (int i = 0; i < numRows; ++i) {
for (int j = 0; j < numColumns; ++j) {
int row = inputTable.isConstColumn(j) ? 0 : i;
parameters[j] = inputs[j][row];
}
result[i] = objCache.methodAccess.invoke(udf, objCache.methodIndex, parameters);
}
boolean isNullable = Boolean.parseBoolean(outputParams.getOrDefault("is_nullable", "true"));
outputTable.appendData(0, result, getOutputConverter(), isNullable);
return outputTable.getMetaAddress();
} catch (Exception e) {
LOG.warn("evaluate exception: " + debugString(), e);
throw new UdfRuntimeException("UDF failed to evaluate", e);
}
}
private Method findPrepareMethod(Method[] methods) {
for (Method method : methods) {
if (method.getName().equals(UDF_PREPARE_FUNCTION_NAME) && method.getReturnType().equals(void.class)
&& method.getParameterCount() == 0) {
return method;
}
}
return null; // Method not found
}
// Preallocate the input objects that will be passed to the underlying UDF.
// These objects are allocated once and reused across calls to evaluate()
@Override
protected void init(TJavaUdfExecutorCtorParams request, String jarPath, Type funcRetType,
Type... parameterTypes) throws UdfRuntimeException {
className = fn.scalar_fn.symbol;
super.init(request, jarPath, funcRetType, parameterTypes);
Method prepareMethod = objCache.allMethods.get(UDF_PREPARE_FUNCTION_NAME);
if (prepareMethod != null) {
try {
prepareMethod.invoke(udf);
} catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
throw new UdfRuntimeException("Unable to call UDF prepare function.", e);
}
}
}
@Override
protected void checkAndCacheUdfClass(UdfClassCache cache, Type funcRetType, Type... parameterTypes)
throws InternalException, UdfRuntimeException {
ArrayList<String> signatures = Lists.newArrayList();
Class<?> c = cache.udfClass;
Method[] methods = c.getMethods();
Method prepareMethod = findPrepareMethod(methods);
if (prepareMethod != null) {
cache.allMethods.put(UDF_PREPARE_FUNCTION_NAME, prepareMethod);
}
for (Method m : methods) {
// By convention, the udf must contain the function "evaluate"
if (!m.getName().equals(UDF_FUNCTION_NAME)) {
continue;
}
signatures.add(m.toGenericString());
cache.argClass = m.getParameterTypes();
// Try to match the arguments
if (cache.argClass.length != parameterTypes.length) {
continue;
}
cache.allMethods.put(UDF_FUNCTION_NAME, m);
cache.methodIndex = cache.methodAccess.getIndex(UDF_FUNCTION_NAME, cache.argClass);
Pair<Boolean, JavaUdfDataType> returnType;
cache.retClass = m.getReturnType();
if (cache.argClass.length == 0 && parameterTypes.length == 0) {
// Special case where the UDF doesn't take any input args
returnType = UdfUtils.setReturnType(funcRetType, m.getReturnType());
if (!returnType.first) {
continue;
} else {
cache.retType = returnType.second;
}
cache.argTypes = new JavaUdfDataType[0];
return;
}
returnType = UdfUtils.setReturnType(funcRetType, m.getReturnType());
if (!returnType.first) {
continue;
} else {
cache.retType = returnType.second;
}
Pair<Boolean, JavaUdfDataType[]> inputType = UdfUtils.setArgTypes(parameterTypes, cache.argClass, false);
if (!inputType.first) {
continue;
} else {
cache.argTypes = inputType.second;
}
return;
}
StringBuilder sb = new StringBuilder();
sb.append("Unable to find evaluate function with the correct signature: ")
.append(className)
.append(".evaluate(")
.append(Joiner.on(", ").join(parameterTypes))
.append(")\n")
.append("UDF contains: \n ")
.append(Joiner.on("\n ").join(signatures));
throw new UdfRuntimeException(sb.toString());
}
}