UdafExecutor.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.udf;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.Pair;
import org.apache.doris.common.exception.InternalException;
import org.apache.doris.common.exception.UdfRuntimeException;
import org.apache.doris.common.jni.utils.JavaUdfDataType;
import org.apache.doris.common.jni.utils.OffHeap;
import org.apache.doris.common.jni.utils.UdfClassCache;
import org.apache.doris.common.jni.utils.UdfUtils;
import org.apache.doris.common.jni.vec.VectorTable;
import org.apache.doris.thrift.TJavaUdfExecutorCtorParams;
import com.google.common.base.Joiner;
import com.google.common.collect.Lists;
import org.apache.log4j.Logger;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.lang.reflect.Array;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
/**
* udaf executor.
*/
public class UdafExecutor extends BaseExecutor {
private static final Logger LOG = Logger.getLogger(UdafExecutor.class);
private static final String UDAF_CREATE_FUNCTION = "create";
private static final String UDAF_DESTROY_FUNCTION = "destroy";
private static final String UDAF_ADD_FUNCTION = "add";
private static final String UDAF_RESET_FUNCTION = "reset";
private static final String UDAF_SERIALIZE_FUNCTION = "serialize";
private static final String UDAF_DESERIALIZE_FUNCTION = "deserialize";
private static final String UDAF_MERGE_FUNCTION = "merge";
private static final String UDAF_RESULT_FUNCTION = "getValue";
private HashMap<Long, Object> stateObjMap;
/**
* Constructor to create an object.
*/
public UdafExecutor(byte[] thriftParams) throws Exception {
super(thriftParams);
}
/**
* close and invoke destroy function.
*/
@Override
public void close() {
if (!isStaticLoad) {
super.close();
}
stateObjMap = null;
}
public void addBatch(boolean isSinglePlace, int rowStart, int rowEnd, long placeAddr, int offset,
Map<String, String> inputParams) throws UdfRuntimeException {
try {
VectorTable inputTable = VectorTable.createReadableTable(inputParams);
Object[][] inputs = inputTable.getMaterializedData(rowStart, rowEnd,
getInputConverters(inputTable.getNumColumns(), true));
if (isSinglePlace) {
addBatchSingle(rowStart, rowEnd, placeAddr, inputs);
} else {
addBatchPlaces(rowStart, rowEnd, placeAddr, offset, inputs);
}
} catch (Exception e) {
LOG.warn("evaluate exception: " + debugString(), e);
throw new UdfRuntimeException("UDAF failed to evaluate", e);
}
}
public void addBatchSingle(int rowStart, int rowEnd, long placeAddr, Object[][] inputs) throws UdfRuntimeException {
Long curPlace = placeAddr;
Object[] inputArgs = new Object[objCache.argTypes.length + 1];
Object state = stateObjMap.get(curPlace);
if (state != null) {
inputArgs[0] = state;
} else {
Object newState = createAggState();
stateObjMap.put(curPlace, newState);
inputArgs[0] = newState;
}
int numColumns = inputs.length;
int numRows = rowEnd - rowStart;
for (int i = 0; i < numRows; ++i) {
for (int j = 0; j < numColumns; ++j) {
inputArgs[j + 1] = inputs[j][i];
}
objCache.methodAccess.invoke(udf, objCache.methodIndex, inputArgs);
}
}
public void addBatchPlaces(int rowStart, int rowEnd, long placeAddr, int offset, Object[][] inputs)
throws UdfRuntimeException {
int numColumns = inputs.length;
int numRows = rowEnd - rowStart;
Object[] placeState = new Object[numRows];
for (int row = rowStart; row < rowEnd; ++row) {
Long curPlace = OffHeap.UNSAFE.getLong(null, placeAddr + (8L * row)) + offset;
Object state = stateObjMap.get(curPlace);
if (state != null) {
placeState[row - rowStart] = state;
} else {
Object newState = createAggState();
stateObjMap.put(curPlace, newState);
placeState[row - rowStart] = newState;
}
}
// spilt into two for loop
Object[] inputArgs = new Object[objCache.argTypes.length + 1];
for (int row = 0; row < numRows; ++row) {
inputArgs[0] = placeState[row];
for (int j = 0; j < numColumns; ++j) {
inputArgs[j + 1] = inputs[j][row];
}
objCache.methodAccess.invoke(udf, objCache.methodIndex, inputArgs);
}
}
/**
* invoke user create function to get obj.
*/
public Object createAggState() throws UdfRuntimeException {
try {
return objCache.allMethods.get(UDAF_CREATE_FUNCTION).invoke(udf, null);
} catch (Exception e) {
LOG.warn("invoke createAggState function meet some error: ", e);
throw new UdfRuntimeException("UDAF failed to create: ", e);
}
}
/**
* invoke destroy before colse. Here we destroy all data at once
*/
public void destroy() throws UdfRuntimeException {
try {
for (Object obj : stateObjMap.values()) {
objCache.allMethods.get(UDAF_DESTROY_FUNCTION).invoke(udf, obj);
}
stateObjMap.clear();
} catch (Exception e) {
LOG.warn("invoke destroy function meet some error: ", e);
throw new UdfRuntimeException("UDAF failed to destroy: ", e);
}
}
/**
* invoke serialize function and return byte[] to backends.
*/
public byte[] serialize(long place) throws UdfRuntimeException {
try {
Object[] args = new Object[2];
ByteArrayOutputStream baos = new ByteArrayOutputStream();
args[0] = stateObjMap.get(place);
args[1] = new DataOutputStream(baos);
objCache.allMethods.get(UDAF_SERIALIZE_FUNCTION).invoke(udf, args);
return baos.toByteArray();
} catch (Exception e) {
LOG.info("evaluate exception debug: " + debugString());
LOG.warn("invoke serialize function meet some error: ", e);
throw new UdfRuntimeException("UDAF failed to serialize: ", e);
}
}
/**
* invoke reset function and reset the state to init.
*/
public void reset(long place) throws UdfRuntimeException {
try {
Object[] args = new Object[1];
args[0] = stateObjMap.get(place);
if (args[0] == null) {
return;
}
objCache.allMethods.get(UDAF_RESET_FUNCTION).invoke(udf, args);
} catch (Exception e) {
LOG.info("evaluate exception debug: " + debugString());
LOG.warn("invoke reset function meet some error: ", e);
throw new UdfRuntimeException("UDAF failed to reset: ", e);
}
}
/**
* invoke merge function and it's have done deserialze.
* here call deserialize first, and call merge.
*/
public void merge(long place, byte[] data) throws UdfRuntimeException {
try {
Object[] args = new Object[2];
ByteArrayInputStream bins = new ByteArrayInputStream(data);
args[0] = createAggState();
args[1] = new DataInputStream(bins);
objCache.allMethods.get(UDAF_DESERIALIZE_FUNCTION).invoke(udf, args);
args[1] = args[0];
Long curPlace = place;
Object state = stateObjMap.get(curPlace);
if (state != null) {
args[0] = state;
} else {
Object newState = createAggState();
stateObjMap.put(curPlace, newState);
args[0] = newState;
}
objCache.allMethods.get(UDAF_MERGE_FUNCTION).invoke(udf, args);
} catch (Exception e) {
LOG.info("evaluate exception debug: " + debugString());
LOG.warn("invoke merge function meet some error: ", e);
throw new UdfRuntimeException("UDAF failed to merge: ", e);
}
}
/**
* invoke getValue to return finally result.
*/
public long getValue(long place, Map<String, String> outputParams) throws UdfRuntimeException {
try {
if (outputTable != null) {
outputTable.close();
}
outputTable = VectorTable.createWritableTable(outputParams, 1);
if (stateObjMap.get(place) == null) {
stateObjMap.put(place, createAggState());
}
Object value = objCache.allMethods.get(UDAF_RESULT_FUNCTION).invoke(udf, stateObjMap.get(place));
// If the return type is primitive, we can't cast the array of primitive type as array of Object,
// so we have to new its wrapped Object.
Object[] result = outputTable.getColumnType(0).isPrimitive()
? outputTable.getColumn(0).newObjectContainerArray(1)
: (Object[]) Array.newInstance(objCache.retClass, 1);
result[0] = value;
boolean isNullable = Boolean.parseBoolean(outputParams.getOrDefault("is_nullable", "true"));
outputTable.appendData(0, result, getOutputConverter(), isNullable);
return outputTable.getMetaAddress();
} catch (Exception e) {
LOG.info("evaluate exception debug: " + debugString());
LOG.warn("invoke getValue function meet some error: ", e);
throw new UdfRuntimeException("UDAF failed to result", e);
}
}
protected void init(TJavaUdfExecutorCtorParams request, String jarPath, Type funcRetType,
Type... parameterTypes) throws UdfRuntimeException {
className = fn.aggregate_fn.symbol;
super.init(request, jarPath, funcRetType, parameterTypes);
stateObjMap = new HashMap<>();
}
@Override
protected void checkAndCacheUdfClass(UdfClassCache cache, Type funcRetType,
Type... parameterTypes) throws InternalException, UdfRuntimeException {
ArrayList<String> signatures = Lists.newArrayList();
Class<?> c = cache.udfClass;
Method[] methods = c.getMethods();
int idx = 0;
for (idx = 0; idx < methods.length; ++idx) {
signatures.add(methods[idx].toGenericString());
switch (methods[idx].getName()) {
case UDAF_DESTROY_FUNCTION:
case UDAF_CREATE_FUNCTION:
case UDAF_MERGE_FUNCTION:
case UDAF_SERIALIZE_FUNCTION:
case UDAF_RESET_FUNCTION:
case UDAF_DESERIALIZE_FUNCTION: {
cache.allMethods.put(methods[idx].getName(), methods[idx]);
break;
}
case UDAF_RESULT_FUNCTION: {
cache.allMethods.put(methods[idx].getName(), methods[idx]);
Pair<Boolean, JavaUdfDataType> returnType = UdfUtils.setReturnType(funcRetType,
methods[idx].getReturnType());
if (!returnType.first) {
if (LOG.isDebugEnabled()) {
LOG.debug("result function set return parameterTypes has error");
}
} else {
cache.retType = returnType.second;
cache.retClass = methods[idx].getReturnType();
}
break;
}
case UDAF_ADD_FUNCTION: {
cache.allMethods.put(methods[idx].getName(), methods[idx]);
cache.methodIndex = cache.methodAccess.getIndex(UDAF_ADD_FUNCTION);
cache.argClass = methods[idx].getParameterTypes();
if (cache.argClass.length != parameterTypes.length + 1) {
if (LOG.isDebugEnabled()) {
LOG.debug("add function parameterTypes length not equal " + cache.argClass.length + " "
+ parameterTypes.length + " " + methods[idx].getName());
}
}
if (!(parameterTypes.length == 0)) {
Pair<Boolean, JavaUdfDataType[]> inputType = UdfUtils.setArgTypes(parameterTypes,
cache.argClass, true);
if (!inputType.first) {
if (LOG.isDebugEnabled()) {
LOG.debug("add function set arg parameterTypes has error");
}
} else {
cache.argTypes = inputType.second;
}
} else {
// Special case where the UDF doesn't take any input args
cache.argTypes = new JavaUdfDataType[0];
}
break;
}
default:
break;
}
}
if (idx == methods.length) {
return;
}
StringBuilder sb = new StringBuilder();
sb.append("Unable to find evaluate function with the correct signature: ")
.append(className)
.append(".evaluate(")
.append(Joiner.on(", ").join(parameterTypes)).append(")\n").append("UDF contains: \n ")
.append(Joiner.on("\n ").join(signatures));
throw new UdfRuntimeException(sb.toString());
}
}