AggCombinerFunctionBuilder.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.combinator.ForEachCombinator;
import org.apache.doris.nereids.trees.expressions.functions.combinator.MergeCombinator;
import org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator;
import org.apache.doris.nereids.trees.expressions.functions.combinator.UnionCombinator;
import org.apache.doris.nereids.types.AggStateType;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;

import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

/**
 * This class used to resolve AggState's combinators
 */
public class AggCombinerFunctionBuilder extends FunctionBuilder {
    public static final String COMBINATOR_LINKER = "_";
    public static final String STATE = "state";
    public static final String MERGE = "merge";
    public static final String UNION = "union";
    public static final String FOREACH = "foreach";

    public static final String STATE_SUFFIX = COMBINATOR_LINKER + STATE;
    public static final String MERGE_SUFFIX = COMBINATOR_LINKER + MERGE;
    public static final String UNION_SUFFIX = COMBINATOR_LINKER + UNION;
    public static final String FOREACH_SUFFIX = COMBINATOR_LINKER + FOREACH;

    private final FunctionBuilder nestedBuilder;

    private final String combinatorSuffix;

    public AggCombinerFunctionBuilder(String combinatorSuffix, FunctionBuilder nestedBuilder) {
        this.combinatorSuffix = Objects.requireNonNull(combinatorSuffix, "combinatorSuffix can not be null");
        this.nestedBuilder = Objects.requireNonNull(nestedBuilder, "nestedBuilder can not be null");
    }

    @Override
    public Class<? extends BoundFunction> functionClass() {
        return nestedBuilder.functionClass();
    }

    @Override
    public boolean canApply(List<?> arguments) {
        if (combinatorSuffix.equalsIgnoreCase(STATE) || combinatorSuffix.equalsIgnoreCase(FOREACH)) {
            return nestedBuilder.canApply(arguments);
        } else {
            if (arguments.size() != 1) {
                return false;
            }
            Expression argument = (Expression) arguments.get(0);
            if (!argument.getDataType().isAggStateType()) {
                return false;
            }

            return nestedBuilder.canApply(((AggStateType) argument.getDataType()).getMockedExpressions());
        }
    }

    private AggregateFunction buildState(String nestedName, List<? extends Object> arguments) {
        return (AggregateFunction) nestedBuilder.build(nestedName, arguments).first;
    }

    private AggregateFunction buildForEach(String nestedName, List<? extends Object> arguments) {
        List<Expression> forEachargs = arguments.stream().map(expr -> {
            if (!(expr instanceof SlotReference)) {
                throw new IllegalStateException(
                        "Can not build foreach nested function: '" + nestedName);
            }
            DataType arrayType = (((Expression) expr).getDataType());
            if (!(arrayType instanceof ArrayType)) {
                throw new IllegalStateException(
                        "foreach must be input array type: '" + nestedName);
            }
            DataType itemType = ((ArrayType) arrayType).getItemType();
            return new SlotReference("mocked", itemType, (((ArrayType) arrayType).containsNull()));
        }).collect(Collectors.toList());
        return (AggregateFunction) nestedBuilder.build(nestedName, forEachargs).first;
    }

    private AggregateFunction buildMergeOrUnion(String nestedName, List<? extends Object> arguments) {
        if (arguments.size() != 1 || !(arguments.get(0) instanceof Expression)
                || !((Expression) arguments.get(0)).getDataType().isAggStateType()) {
            String argString = arguments.stream().map(arg -> {
                if (arg == null) {
                    return "null";
                } else if (arg instanceof Expression) {
                    return ((Expression) arg).toSql();
                } else {
                    return arg.toString();
                }
            }).collect(Collectors.joining(", ", "(", ")"));
            throw new IllegalStateException("Can not build AggState nested function: '" + nestedName + "', expression: "
                    + nestedName + argString);
        }

        Expression arg = (Expression) arguments.get(0);
        AggStateType type = (AggStateType) arg.getDataType();

        return (AggregateFunction) nestedBuilder.build(nestedName, type.getMockedExpressions()).first;
    }

    @Override
    public Pair<BoundFunction, AggregateFunction> build(String name, List<?> arguments) {
        String nestedName = getNestedName(name);
        if (combinatorSuffix.equalsIgnoreCase(STATE)) {
            AggregateFunction nestedFunction = buildState(nestedName, arguments);
            // distinct will be passed as 1st boolean true arg. remove it
            if (!arguments.isEmpty() && arguments.get(0) instanceof Boolean && (Boolean) arguments.get(0)) {
                arguments = arguments.subList(1, arguments.size());
            }
            return Pair.of(new StateCombinator((List<Expression>) arguments, nestedFunction), nestedFunction);
        } else if (combinatorSuffix.equalsIgnoreCase(MERGE)) {
            AggregateFunction nestedFunction = buildMergeOrUnion(nestedName, arguments);
            return Pair.of(new MergeCombinator((List<Expression>) arguments, nestedFunction), nestedFunction);
        } else if (combinatorSuffix.equalsIgnoreCase(UNION)) {
            AggregateFunction nestedFunction = buildMergeOrUnion(nestedName, arguments);
            return Pair.of(new UnionCombinator((List<Expression>) arguments, nestedFunction), nestedFunction);
        } else if (combinatorSuffix.equalsIgnoreCase(FOREACH)) {
            AggregateFunction nestedFunction = buildForEach(nestedName, arguments);
            return Pair.of(new ForEachCombinator((List<Expression>) arguments, nestedFunction), nestedFunction);
        }
        return null;
    }

    @Override
    public String parameterDisplayString() {
        return nestedBuilder.parameterDisplayString();
    }

    public static boolean isAggStateCombinator(String name) {
        return name.toLowerCase().endsWith(STATE_SUFFIX) || name.toLowerCase().endsWith(MERGE_SUFFIX)
                || name.toLowerCase().endsWith(UNION_SUFFIX) || name.toLowerCase().endsWith(FOREACH_SUFFIX);
    }

    public static String getNestedName(String name) {
        return name.substring(0, name.length() - getCombinatorSuffix(name).length() - 1);
    }

    public static String getCombinatorSuffix(String name) {
        if (!name.contains(COMBINATOR_LINKER)) {
            throw new IllegalStateException(name + " call getCombinatorSuffix must contains " + COMBINATOR_LINKER);
        }
        return name.substring(name.lastIndexOf(COMBINATOR_LINKER) + 1);
    }
}