// ExpressionEvaluator.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions;

import org.apache.doris.nereids.exceptions.NotSupportedException;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeAcquire;
import org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeArithmetic;
import org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeExtractAndTransform;
import org.apache.doris.nereids.trees.expressions.functions.executable.NumericArithmetic;
import org.apache.doris.nereids.trees.expressions.functions.executable.StringArithmetic;
import org.apache.doris.nereids.trees.expressions.functions.executable.TimeRoundSeries;
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMultimap;

import java.lang.reflect.Array;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Collection;
import java.util.List;

/**
 * An expression evaluator that evaluates the value of an expression.
 *
 * <p>On construction the evaluator scans a fixed list of classes for methods annotated with
 * {@link ExecFunction} or {@link ExecFunctionList} and indexes them by the annotation's name.
 * {@link #eval(Expression)} then resolves a function name from the expression, picks the first
 * registered method whose parameter types accept the expression's children, and invokes it
 * reflectively with a {@code null} receiver (so registered methods are expected to be static).
 *
 * <p>Evaluation is best effort: whenever no method matches or the reflective call fails, the
 * input expression is returned unchanged. The only failure that is propagated is a
 * {@link NotSupportedException} thrown by the invoked method.
 */
public enum ExpressionEvaluator {

    INSTANCE;

    /** Registered executable methods keyed by function name; built once and never modified. */
    private final ImmutableMultimap<String, Method> functions;

    ExpressionEvaluator() {
        this.functions = registerFunctions();
    }

    /**
     * Evaluate the value of the expression.
     *
     * @param expression the expression to evaluate
     * @return the literal produced by the matching registered method, or {@code expression}
     *         itself when it is neither constant nor foldable, is an aggregate function, is of a
     *         kind that carries no function name, has no matching registered method, or when the
     *         invocation fails
     * @throws NotSupportedException if the invoked method throws it
     */
    public Expression eval(Expression expression) {
        if (!(expression.isConstant() || expression.foldable()) || expression instanceof AggregateFunction) {
            return expression;
        }

        String fnName = null;
        if (expression instanceof BinaryArithmetic) {
            fnName = ((BinaryArithmetic) expression).getLegacyOperator().getName();
        } else if (expression instanceof TimestampArithmetic) {
            fnName = ((TimestampArithmetic) expression).getFuncName();
        } else if (expression instanceof BoundFunction) {
            fnName = ((BoundFunction) expression).getName();
        }
        if (fnName == null) {
            // No function name could be derived, so nothing can be registered for this
            // expression; skip the registry lookup instead of querying it with a null key.
            return expression;
        }

        return invoke(expression, fnName);
    }

    /**
     * Looks up the method registered under {@code fnName} that accepts the children of
     * {@code expression} and invokes it.
     *
     * @return the result of the invocation, or {@code expression} unchanged when no method
     *         matches or the invocation fails for any reason other than a
     *         {@link NotSupportedException}
     */
    private Expression invoke(Expression expression, String fnName) {
        List<Expression> arguments = expression.children();
        Method method = getFunction(fnName, arguments);
        if (method == null) {
            return expression;
        }
        try {
            // Argument packing stays inside the try block: Array.set may raise
            // IllegalArgumentException, which is treated like any other failed invocation.
            return (Literal) method.invoke(null, buildInvokeArguments(method, arguments));
        } catch (InvocationTargetException e) {
            Throwable cause = e.getTargetException();
            if (cause instanceof NotSupportedException) {
                // Rethrow the original exception so its stack trace still points at the
                // function that rejected the input.
                throw (NotSupportedException) cause;
            }
            // Any other failure inside the function: leave the expression unevaluated.
            return expression;
        } catch (IllegalAccessException | IllegalArgumentException e) {
            // The method could not be called with these arguments: leave the expression as is.
            return expression;
        }
    }

    /**
     * Builds the argument array handed to {@link Method#invoke(Object, Object...)}.
     *
     * <p>When the last declared parameter is an array it is treated as a variable-arity
     * parameter: the leading arguments are passed one-to-one and every remaining argument is
     * collected into a new array of the parameter's component type. Otherwise the arguments are
     * passed through as they are.
     */
    private Object[] buildInvokeArguments(Method method, List<Expression> arguments) {
        Class<?>[] parameterTypes = method.getParameterTypes();
        int parameterCount = parameterTypes.length;
        if (parameterCount == 0 || !parameterTypes[parameterCount - 1].isArray()) {
            return arguments.toArray();
        }

        int fixedArgsSize = parameterCount - 1;
        int inputSize = arguments.size();
        Class<?> componentType = parameterTypes[fixedArgsSize].getComponentType();
        Object varArgs = Array.newInstance(componentType, inputSize - fixedArgsSize);
        for (int i = fixedArgsSize; i < inputSize; i++) {
            Array.set(varArgs, i - fixedArgsSize, arguments.get(i));
        }

        Object[] packed = new Object[parameterCount];
        for (int i = 0; i < fixedArgsSize; i++) {
            packed[i] = arguments.get(i);
        }
        packed[fixedArgsSize] = varArgs;
        return packed;
    }

    /**
     * Tells whether an argument of runtime class {@code input} may be passed to a parameter
     * declared as {@code expect}.
     *
     * <p>Parameters declared as {@link DateLiteral}, {@link DateTimeLiteral} or one of their
     * subclasses require an exact class match; every other parameter accepts any assignable
     * class.
     */
    private boolean canDownCastTo(Class<?> expect, Class<?> input) {
        if (DateLiteral.class.isAssignableFrom(expect)
                || DateTimeLiteral.class.isAssignableFrom(expect)) {
            return expect.equals(input);
        }
        return expect.isAssignableFrom(input);
    }

    /**
     * Returns the first method registered under {@code fnName} whose parameters accept
     * {@code inputs}, or {@code null} when there is none.
     */
    private Method getFunction(String fnName, List<Expression> inputs) {
        Collection<Method> candidates = functions.get(fnName);
        for (Method candidate : candidates) {
            if (acceptsInputs(candidate.getParameterTypes(), inputs)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Checks whether a method with the given parameter types can be called with {@code inputs}.
     *
     * <p>A trailing array parameter is treated as variable-arity and must receive at least one
     * input; otherwise the number of inputs must equal the number of parameters. Each input's
     * runtime class is checked with {@link #canDownCastTo(Class, Class)}.
     */
    private boolean acceptsInputs(Class<?>[] parameterTypes, List<Expression> inputs) {
        int parameterCount = parameterTypes.length;
        int inputSize = inputs.size();
        if (parameterCount == 0) {
            return inputSize == 0;
        }

        boolean hasVarArgs = parameterTypes[parameterCount - 1].isArray();
        int fixedArgsSize;
        if (hasVarArgs) {
            fixedArgsSize = parameterCount - 1;
            if (inputSize <= fixedArgsSize) {
                return false;
            }
        } else {
            fixedArgsSize = parameterCount;
            if (inputSize != fixedArgsSize) {
                return false;
            }
        }

        for (int i = 0; i < fixedArgsSize; i++) {
            if (!canDownCastTo(parameterTypes[i], inputs.get(i).getClass())) {
                return false;
            }
        }
        if (hasVarArgs) {
            Class<?> varArgType = parameterTypes[fixedArgsSize].getComponentType();
            for (int i = fixedArgsSize; i < inputSize; i++) {
                if (!canDownCastTo(varArgType, inputs.get(i).getClass())) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Collects every method of the known executable-function classes that carries an
     * {@link ExecFunction} annotation, either directly or wrapped in an
     * {@link ExecFunctionList}, keyed by the name given in the annotation.
     */
    private ImmutableMultimap<String, Method> registerFunctions() {
        ImmutableMultimap.Builder<String, Method> mapBuilder = new ImmutableMultimap.Builder<>();
        List<Class<?>> classes = ImmutableList.of(
                DateTimeAcquire.class,
                DateTimeExtractAndTransform.class,
                DateLiteral.class,
                DateTimeArithmetic.class,
                NumericArithmetic.class,
                StringArithmetic.class,
                TimeRoundSeries.class
        );
        for (Class<?> cls : classes) {
            for (Method method : cls.getDeclaredMethods()) {
                ExecFunctionList annotationList = method.getAnnotation(ExecFunctionList.class);
                if (annotationList != null) {
                    for (ExecFunction f : annotationList.value()) {
                        registerFEFunction(mapBuilder, method, f);
                    }
                }
                registerFEFunction(mapBuilder, method, method.getAnnotation(ExecFunction.class));
            }
        }
        return mapBuilder.build();
    }

    /** Adds {@code method} under the annotation's name; does nothing when the annotation is absent. */
    private void registerFEFunction(ImmutableMultimap.Builder<String, Method> mapBuilder,
            Method method, ExecFunction annotation) {
        if (annotation != null) {
            mapBuilder.put(annotation.name(), method);
        }
    }
}