BaseExecutor.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.udf;
import org.apache.doris.catalog.ArrayType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.classloader.ScannerLoader;
import org.apache.doris.common.exception.InternalException;
import org.apache.doris.common.exception.UdfRuntimeException;
import org.apache.doris.common.jni.utils.JavaUdfDataType;
import org.apache.doris.common.jni.utils.JavaUdfStructType;
import org.apache.doris.common.jni.utils.UdfClassCache;
import org.apache.doris.common.jni.utils.UdfUtils;
import org.apache.doris.common.jni.vec.ColumnValueConverter;
import org.apache.doris.common.jni.vec.VectorTable;
import org.apache.doris.thrift.TFunction;
import org.apache.doris.thrift.TJavaUdfExecutorCtorParams;
import org.apache.doris.thrift.TPrimitiveType;
import com.esotericsoftware.reflectasm.MethodAccess;
import com.google.common.base.Strings;
import org.apache.log4j.Logger;
import org.apache.thrift.TDeserializer;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.net.MalformedURLException;
import java.net.URLClassLoader;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
public abstract class BaseExecutor {
// Object to deserialize ctor params from BE.
protected static final TBinaryProtocol.Factory PROTOCOL_FACTORY = new TBinaryProtocol.Factory();
private static final Logger LOG = Logger.getLogger(BaseExecutor.class);
protected Object udf;
// setup by init() and cleared by close()
protected URLClassLoader classLoader;
protected UdfClassCache objCache;
protected TFunction fn;
protected boolean isStaticLoad = false;
protected VectorTable outputTable = null;
String className;
/**
* Create a UdfExecutor, using parameters from a serialized thrift object. Used
* by
* the backend.
*/
public BaseExecutor(byte[] thriftParams) throws Exception {
TJavaUdfExecutorCtorParams request = new TJavaUdfExecutorCtorParams();
TDeserializer deserializer = new TDeserializer(PROTOCOL_FACTORY);
try {
deserializer.deserialize(request, thriftParams);
} catch (TException e) {
throw new InternalException(e.getMessage());
}
Type[] parameterTypes = new Type[request.fn.arg_types.size()];
for (int i = 0; i < request.fn.arg_types.size(); ++i) {
parameterTypes[i] = Type.fromThrift(request.fn.arg_types.get(i));
}
fn = request.fn;
String jarFile = request.location;
Type funcRetType = Type.fromThrift(request.fn.ret_type);
if (request.fn.is_udtf_function) {
funcRetType = ArrayType.create(funcRetType, true);
}
init(request, jarFile, funcRetType, parameterTypes);
}
public String debugString() {
StringBuilder res = new StringBuilder();
for (JavaUdfDataType type : objCache.argTypes) {
res.append(type.toString());
}
res.append(" return type: ").append(objCache.retType.toString());
res.append(" methodAccess: ").append(objCache.methodAccess.toString());
res.append(" fn.toString(): ").append(fn.toString());
return res.toString();
}
protected void init(TJavaUdfExecutorCtorParams request, String jarPath,
Type funcRetType, Type... parameterTypes) throws UdfRuntimeException {
try {
isStaticLoad = request.getFn().isSetIsStaticLoad() && request.getFn().is_static_load;
long expirationTime = 360L; // default is 6 hours
if (request.getFn().isSetExpirationTime()) {
expirationTime = request.getFn().getExpirationTime();
}
objCache = getClassCache(jarPath, request.getFn().getSignature(), expirationTime,
funcRetType, parameterTypes);
Constructor<?> ctor = objCache.udfClass.getConstructor();
udf = ctor.newInstance();
} catch (MalformedURLException e) {
throw new UdfRuntimeException("Unable to load jar.", e);
} catch (SecurityException e) {
throw new UdfRuntimeException("Unable to load function.", e);
} catch (ClassNotFoundException e) {
throw new UdfRuntimeException("Unable to find class.", e);
} catch (NoSuchMethodException e) {
throw new UdfRuntimeException(
"Unable to find constructor with no arguments.", e);
} catch (IllegalArgumentException e) {
throw new UdfRuntimeException(
"Unable to call UDF constructor with no arguments.", e);
} catch (Exception e) {
throw new UdfRuntimeException("Unable to call create UDF instance.", e);
}
}
public UdfClassCache getClassCache(String jarPath, String signature, long expirationTime,
Type funcRetType, Type... parameterTypes)
throws MalformedURLException, FileNotFoundException, ClassNotFoundException, InternalException,
UdfRuntimeException {
UdfClassCache cache = null;
if (isStaticLoad) {
cache = ScannerLoader.getUdfClassLoader(signature);
}
if (cache == null) {
ClassLoader loader;
if (Strings.isNullOrEmpty(jarPath)) {
// if jarPath is empty, which means the UDF jar is located in custom_lib
// and already be loaded when BE start.
// so here we use system class loader to load UDF class.
loader = ClassLoader.getSystemClassLoader();
} else {
ClassLoader parent = getClass().getClassLoader();
classLoader = UdfUtils.getClassLoader(jarPath, parent);
loader = classLoader;
}
cache = new UdfClassCache();
cache.allMethods = new HashMap<>();
cache.udfClass = Class.forName(className, true, loader);
cache.methodAccess = MethodAccess.get(cache.udfClass);
checkAndCacheUdfClass(cache, funcRetType, parameterTypes);
if (isStaticLoad) {
ScannerLoader.cacheClassLoader(signature, cache, expirationTime);
}
}
return cache;
}
protected abstract void checkAndCacheUdfClass(UdfClassCache cache, Type funcRetType, Type... parameterTypes)
throws InternalException, UdfRuntimeException;
/**
* Close the class loader we may have created.
*/
public void close() {
if (classLoader != null) {
try {
classLoader.close();
} catch (IOException e) {
// Log and ignore.
if (LOG.isDebugEnabled()) {
LOG.debug("Error closing the URLClassloader.", e);
}
}
}
// Close the output table if it exists.
if (outputTable != null) {
outputTable.close();
}
// We are now un-usable (because the class loader has been
// closed), so null out method_ and classLoader_.
classLoader = null;
objCache.methodAccess = null;
}
protected ColumnValueConverter getInputConverter(TPrimitiveType primitiveType, Class clz)
throws UdfRuntimeException {
switch (primitiveType) {
case DATE:
case DATEV2: {
if (java.util.Date.class.equals(clz)) {
return (Object[] columnData) -> {
Object[] result = new java.util.Date[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
LocalDate v = (LocalDate) columnData[i];
result[i] = new java.util.Date(v.getYear() - 1900, v.getMonthValue() - 1,
v.getDayOfMonth());
}
}
return result;
};
} else if (org.joda.time.LocalDate.class.equals(clz)) {
return (Object[] columnData) -> {
Object[] result = new org.joda.time.LocalDate[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
LocalDate v = (LocalDate) columnData[i];
result[i] = new org.joda.time.LocalDate(v.getYear(), v.getMonthValue(),
v.getDayOfMonth());
}
}
return result;
};
} else if (!LocalDate.class.equals(clz)) {
throw new UdfRuntimeException("Unsupported date type: " + clz.getCanonicalName());
}
break;
}
case DATETIME:
case DATETIMEV2: {
if (org.joda.time.DateTime.class.equals(clz)) {
return (Object[] columnData) -> {
Object[] result = new org.joda.time.DateTime[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
LocalDateTime v = (LocalDateTime) columnData[i];
result[i] = new org.joda.time.DateTime(v.getYear(), v.getMonthValue(),
v.getDayOfMonth(), v.getHour(),
v.getMinute(), v.getSecond(), v.getNano() / 1000000);
}
}
return result;
};
} else if (org.joda.time.LocalDateTime.class.equals(clz)) {
return (Object[] columnData) -> {
Object[] result = new org.joda.time.LocalDateTime[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
LocalDateTime v = (LocalDateTime) columnData[i];
result[i] = new org.joda.time.LocalDateTime(v.getYear(), v.getMonthValue(),
v.getDayOfMonth(), v.getHour(),
v.getMinute(), v.getSecond(), v.getNano() / 1000000);
}
}
return result;
};
} else if (!LocalDateTime.class.equals(clz)) {
throw new UdfRuntimeException("Unsupported date type: " + clz.getCanonicalName());
}
break;
}
case STRUCT: {
return (Object[] columnData) -> {
Object[] result = new ArrayList[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
HashMap<String, Object> value = (HashMap<String, Object>) columnData[i];
ArrayList<Object> elements = new ArrayList<>();
for (Entry<String, Object> entry : value.entrySet()) {
elements.add(entry.getValue());
}
result[i] = elements;
}
}
return result;
};
}
default:
break;
}
return null;
}
protected ColumnValueConverter getOutputConverter(JavaUdfDataType returnType, Class clz)
throws UdfRuntimeException {
switch (returnType.getPrimitiveType()) {
case DATE:
case DATEV2: {
if (java.util.Date.class.equals(clz)) {
return (Object[] columnData) -> {
Object[] result = new LocalDate[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
java.util.Date v = (java.util.Date) columnData[i];
result[i] = LocalDate.of(v.getYear() + 1900, v.getMonth() + 1, v.getDate());
}
}
return result;
};
} else if (org.joda.time.LocalDate.class.equals(clz)) {
return (Object[] columnData) -> {
Object[] result = new LocalDate[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
org.joda.time.LocalDate v = (org.joda.time.LocalDate) columnData[i];
result[i] = LocalDate.of(v.getYear(), v.getMonthOfYear(), v.getDayOfMonth());
}
}
return result;
};
} else if (!LocalDate.class.equals(clz)) {
throw new UdfRuntimeException("Unsupported date type: " + clz.getCanonicalName());
}
break;
}
case DATETIME:
case DATETIMEV2: {
if (org.joda.time.DateTime.class.equals(clz)) {
return (Object[] columnData) -> {
Object[] result = new LocalDateTime[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
org.joda.time.DateTime v = (org.joda.time.DateTime) columnData[i];
result[i] = LocalDateTime.of(v.getYear(), v.getMonthOfYear(), v.getDayOfMonth(),
v.getHourOfDay(),
v.getMinuteOfHour(), v.getSecondOfMinute(), v.getMillisOfSecond() * 1000000);
}
}
return result;
};
} else if (org.joda.time.LocalDateTime.class.equals(clz)) {
return (Object[] columnData) -> {
Object[] result = new LocalDateTime[columnData.length];
for (int i = 0; i < columnData.length; ++i) {
if (columnData[i] != null) {
org.joda.time.LocalDateTime v = (org.joda.time.LocalDateTime) columnData[i];
result[i] = LocalDateTime.of(v.getYear(), v.getMonthOfYear(), v.getDayOfMonth(),
v.getHourOfDay(),
v.getMinuteOfHour(), v.getSecondOfMinute(), v.getMillisOfSecond() * 1000000);
}
}
return result;
};
} else if (!LocalDateTime.class.equals(clz)) {
throw new RuntimeException("Unsupported date type: " + clz.getCanonicalName());
}
break;
}
case STRUCT: {
return (Object[] columnData) -> {
Object[] result = (HashMap<String, Object>[]) new HashMap<?, ?>[columnData.length];
ArrayList<String> names = ((JavaUdfStructType) returnType).getFieldNames();
for (int i = 0; i < columnData.length; ++i) {
HashMap<String, Object> elements = new HashMap<String, Object>();
if (columnData[i] != null) {
ArrayList<Object> v = (ArrayList<Object>) columnData[i];
for (int k = 0; k < v.size(); ++k) {
elements.put(names.get(k), v.get(k));
}
result[i] = elements;
}
}
return result;
};
}
default:
break;
}
return null;
}
// Add unified converter methods
protected Map<Integer, ColumnValueConverter> getInputConverters(int numColumns, boolean isUdaf)
throws UdfRuntimeException {
Map<Integer, ColumnValueConverter> converters = new HashMap<>();
for (int j = 0; j < numColumns; ++j) {
// For UDAF, we need to offset by 1 since first arg is state
int argIndex = isUdaf ? j + 1 : j;
ColumnValueConverter converter = getInputConverter(objCache.argTypes[j].getPrimitiveType(),
objCache.argClass[argIndex]);
if (converter != null) {
converters.put(j, converter);
}
}
return converters;
}
protected ColumnValueConverter getOutputConverter() throws UdfRuntimeException {
return getOutputConverter(objCache.retType, objCache.retClass);
}
}