UdfUtils.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.common.jni.utils;

import org.apache.doris.catalog.ArrayType;
import org.apache.doris.catalog.MapType;
import org.apache.doris.catalog.ScalarType;
import org.apache.doris.catalog.StructField;
import org.apache.doris.catalog.StructType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.Pair;
import org.apache.doris.common.exception.InternalException;
import org.apache.doris.thrift.TPrimitiveType;

import org.apache.log4j.Logger;
import sun.misc.Unsafe;

import java.io.File;
import java.io.FileNotFoundException;
import java.lang.reflect.Field;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.Set;

public class UdfUtils {
    public static final Logger LOG = Logger.getLogger(UdfUtils.class);
    public static final Unsafe UNSAFE;
    private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L;
    public static final long BYTE_ARRAY_OFFSET;
    public static final long INT_ARRAY_OFFSET;

    static {
        UNSAFE = (Unsafe) AccessController.doPrivileged(
                (PrivilegedAction<Object>) () -> {
                    try {
                        Field f = Unsafe.class.getDeclaredField("theUnsafe");
                        f.setAccessible(true);
                        return f.get(null);
                    } catch (NoSuchFieldException | IllegalAccessException e) {
                        throw new Error();
                    }
                });
        BYTE_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(byte[].class);
        INT_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(int[].class);
    }

    public static void copyMemory(
            Object src, long srcOffset, Object dst, long dstOffset, long length) {
        // Check if dstOffset is before or after srcOffset to determine if we should copy
        // forward or backwards. This is necessary in case src and dst overlap.
        if (dstOffset < srcOffset) {
            while (length > 0) {
                long size = Math.min(length, UNSAFE_COPY_THRESHOLD);
                UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size);
                length -= size;
                srcOffset += size;
                dstOffset += size;
            }
        } else {
            srcOffset += length;
            dstOffset += length;
            while (length > 0) {
                long size = Math.min(length, UNSAFE_COPY_THRESHOLD);
                srcOffset -= size;
                dstOffset -= size;
                UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size);
                length -= size;
            }

        }
    }

    public static URLClassLoader getClassLoader(String jarPath, ClassLoader parent)
            throws MalformedURLException, FileNotFoundException {
        File file = new File(jarPath);
        if (!file.exists()) {
            throw new FileNotFoundException("Can not find local file: " + jarPath);
        }
        URL url = file.toURI().toURL();
        return URLClassLoader.newInstance(new URL[] {url}, parent);
    }

    /**
     * Sets the return type of a Java UDF. Returns true if the return type is compatible
     * with the return type from the function definition. Throws an UdfRuntimeException
     * if the return type is not supported.
     */
    public static Pair<Boolean, JavaUdfDataType> setReturnType(Type retType, Class<?> udfReturnType)
            throws InternalException {
        if (!JavaUdfDataType.isSupported(retType)) {
            throw new InternalException("Unsupported return type: " + retType.toSql());
        }
        Set<JavaUdfDataType> javaTypes = JavaUdfDataType.getCandidateTypes(udfReturnType);
        // Check if the evaluate method return type is compatible with the return type from
        // the function definition. This happens when both of them map to the same primitive
        // type.
        Object[] res = javaTypes.stream().filter(t -> {
            TPrimitiveType t1 = t.getPrimitiveType();
            TPrimitiveType ret = retType.getPrimitiveType().toThrift();
            return (t1 == ret) || (t1 == TPrimitiveType.STRING && ret == TPrimitiveType.VARCHAR);
        }).toArray();

        JavaUdfDataType result = new JavaUdfDataType(
                res.length == 0 ? javaTypes.iterator().next() : (JavaUdfDataType) res[0]);
        if (retType.isDecimalV3() || retType.isDatetimeV2()) {
            result.setPrecision(retType.getPrecision());
            result.setScale(((ScalarType) retType).getScalarScale());
        } else if (retType.isArrayType()) {
            ArrayType arrType = (ArrayType) retType;
            result = new JavaUdfArrayType(arrType.getItemType());
            if (arrType.getItemType().isDatetimeV2() || arrType.getItemType().isDecimalV3()) {
                result.setPrecision(arrType.getItemType().getPrecision());
                result.setScale(((ScalarType) arrType.getItemType()).getScalarScale());
            }
        } else if (retType.isMapType()) {
            MapType mapType = (MapType) retType;
            Type keyType = mapType.getKeyType();
            Type valuType = mapType.getValueType();
            result = new JavaUdfMapType(keyType, valuType);
            JavaUdfMapType udfMapType = ((JavaUdfMapType) result);
            if (keyType.isDatetimeV2() || keyType.isDecimalV3()) {
                udfMapType.setKeyScale(((ScalarType) keyType).getScalarScale());
            }
            if (valuType.isDatetimeV2() || valuType.isDecimalV3()) {
                udfMapType.setValueScale(((ScalarType) valuType).getScalarScale());
            }
        } else if (retType.isStructType()) {
            StructType structType = (StructType) retType;
            result = new JavaUdfStructType(structType.getFields());
        }
        return Pair.of(res.length != 0, result);
    }

    /**
     * Sets the argument types of a Java UDF or UDAF. Returns true if the argument types specified
     * in the UDF are compatible with the argument types of the evaluate() function loaded
     * from the associated JAR file.
     *
     * @throws InternalException
     */
    public static Pair<Boolean, JavaUdfDataType[]> setArgTypes(Type[] parameterTypes, Class<?>[] udfArgTypes,
            boolean isUdaf) throws InternalException {
        JavaUdfDataType[] inputArgTypes = new JavaUdfDataType[parameterTypes.length];
        int firstPos = isUdaf ? 1 : 0;
        for (int i = 0; i < parameterTypes.length; ++i) {
            Set<JavaUdfDataType> javaTypes = JavaUdfDataType.getCandidateTypes(udfArgTypes[i + firstPos]);
            int finalI = i;
            Object[] res = javaTypes.stream().filter(t -> {
                TPrimitiveType t1 = t.getPrimitiveType();
                TPrimitiveType param = parameterTypes[finalI].getPrimitiveType().toThrift();
                return (t1 == param) || (t1 == TPrimitiveType.STRING && param == TPrimitiveType.VARCHAR);
            }).toArray();
            inputArgTypes[i] = new JavaUdfDataType(
                    res.length == 0 ? javaTypes.iterator().next() : (JavaUdfDataType) res[0]);
            if (parameterTypes[finalI].isDecimalV3() || parameterTypes[finalI].isDatetimeV2()) {
                inputArgTypes[i].setPrecision(parameterTypes[finalI].getPrecision());
                inputArgTypes[i].setScale(((ScalarType) parameterTypes[finalI]).getScalarScale());
            } else if (parameterTypes[finalI].isArrayType()) {
                ArrayType arrType = (ArrayType) parameterTypes[finalI];
                inputArgTypes[i] = new JavaUdfArrayType(arrType.getItemType());
                if (arrType.getItemType().isDatetimeV2() || arrType.getItemType().isDecimalV3()) {
                    inputArgTypes[i].setPrecision(arrType.getItemType().getPrecision());
                    inputArgTypes[i].setScale(((ScalarType) arrType.getItemType()).getScalarScale());
                }
            } else if (parameterTypes[finalI].isMapType()) {
                MapType mapType = (MapType) parameterTypes[finalI];
                Type keyType = mapType.getKeyType();
                Type valuType = mapType.getValueType();
                inputArgTypes[i] = new JavaUdfMapType(keyType, valuType);
                JavaUdfMapType udfMapType = ((JavaUdfMapType) inputArgTypes[i]);
                if (keyType.isDatetimeV2() || keyType.isDecimalV3()) {
                    udfMapType.setKeyScale(((ScalarType) keyType).getScalarScale());
                }
                if (valuType.isDatetimeV2() || valuType.isDecimalV3()) {
                    udfMapType.setValueScale(((ScalarType) valuType).getScalarScale());
                }
            } else if (parameterTypes[finalI].isStructType()) {
                StructType structType = (StructType) parameterTypes[finalI];
                ArrayList<StructField> fields = structType.getFields();
                inputArgTypes[i] = new JavaUdfStructType(fields);
            } else if (parameterTypes[finalI].isIP()) {
                if (parameterTypes[finalI].isIPv4()) {
                    inputArgTypes[i] = new JavaUdfDataType(JavaUdfDataType.IPV4);
                } else {
                    inputArgTypes[i] = new JavaUdfDataType(JavaUdfDataType.IPV6);
                }
            }
            if (res.length == 0) {
                return Pair.of(false, inputArgTypes);
            }
        }
        return Pair.of(true, inputArgTypes);
    }

    public static long convertToDateTime(int year, int month, int day, int hour, int minute, int second,
            boolean isDate) {
        long time = 0;
        time = time + year;
        time = (time << 8) + month;
        time = (time << 8) + day;
        time = (time << 8) + hour;
        time = (time << 8) + minute;
        time = (time << 12) + second;
        int type = isDate ? 2 : 3;
        time = (time << 3) + type;
        //this bit is int neg = 0;
        time = (time << 1);
        return time;
    }

    public static long convertToDateTimeV2(
            int year, int month, int day, int hour, int minute, int second, int microsecond) {
        return (long) microsecond | (long) second << 20 | (long) minute << 26 | (long) hour << 32
                | (long) day << 37 | (long) month << 42 | (long) year << 46;
    }

    public static int convertToDateV2(int year, int month, int day) {
        return (int) (day | (long) month << 5 | (long) year << 9);
    }

    /**
     * Change the order of the bytes, Because JVM is Big-Endian , x86 is Little-Endian.
     */
    public static byte[] convertByteOrder(byte[] bytes) {
        int length = bytes.length;
        for (int i = 0; i < length / 2; ++i) {
            byte temp = bytes[i];
            bytes[i] = bytes[length - 1 - i];
            bytes[length - 1 - i] = temp;
        }
        return bytes;
    }
}