be/src/exprs/function/function_agg_state.h
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #pragma once |
19 | | |
20 | | #include <fmt/format.h> |
21 | | |
22 | | #include "common/status.h" |
23 | | #include "core/arena.h" |
24 | | #include "core/block/block.h" |
25 | | #include "core/column/column.h" |
26 | | #include "core/column/column_nullable.h" |
27 | | #include "core/data_type/data_type.h" |
28 | | #include "core/data_type/data_type_agg_state.h" |
29 | | #include "core/types.h" |
30 | | #include "exprs/aggregate/aggregate_function.h" |
31 | | #include "exprs/function/function.h" |
32 | | |
33 | | namespace doris { |
34 | | |
35 | | class FunctionAggState : public IFunction { |
36 | | public: |
37 | | FunctionAggState(const DataTypes& argument_types, const DataTypePtr& return_type, |
38 | | AggregateFunctionPtr agg_function) |
39 | 750 | : _argument_types(argument_types), |
40 | 750 | _return_type(return_type), |
41 | 750 | _agg_function(agg_function) {} |
42 | | |
43 | | static FunctionBasePtr create(const DataTypes& argument_types, const DataTypePtr& return_type, |
44 | 750 | AggregateFunctionPtr agg_function) { |
45 | 750 | if (agg_function == nullptr) { |
46 | 0 | return nullptr; |
47 | 0 | } |
48 | 750 | return std::make_shared<DefaultFunction>( |
49 | 750 | std::make_shared<FunctionAggState>(argument_types, return_type, agg_function), |
50 | 750 | argument_types, return_type); |
51 | 750 | } |
52 | | |
53 | 0 | size_t get_number_of_arguments() const override { return _argument_types.size(); } |
54 | | |
55 | 1.26k | bool use_default_implementation_for_nulls() const override { return false; } |
56 | | |
57 | 0 | String get_name() const override { return fmt::format("{}_state", _agg_function->get_name()); } |
58 | | |
59 | 0 | DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { |
60 | 0 | return _return_type; |
61 | 0 | } |
62 | | |
63 | | Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, |
64 | 1.26k | uint32_t result, size_t input_rows_count) const override { |
65 | 1.26k | auto col = _agg_function->create_serialize_column(); |
66 | 1.26k | std::vector<const IColumn*> agg_columns; |
67 | 1.26k | std::vector<ColumnPtr> save_columns; |
68 | | |
69 | 3.02k | for (size_t i = 0; i < arguments.size(); i++) { |
70 | 1.76k | DataTypePtr signature = |
71 | 1.76k | assert_cast<const DataTypeAggState*>(_return_type.get())->get_sub_types()[i]; |
72 | 1.76k | ColumnPtr column = |
73 | 1.76k | block.get_by_position(arguments[i]).column->convert_to_full_column_if_const(); |
74 | 1.76k | save_columns.push_back(column); |
75 | | |
76 | 1.76k | if (!signature->is_nullable() && column->is_nullable()) { |
77 | 0 | return Status::InternalError( |
78 | 0 | "State function meet input nullable column, but signature is not nullable"); |
79 | 0 | } |
80 | 1.76k | if (!column->is_nullable() && signature->is_nullable()) { |
81 | 0 | column = make_nullable(column); |
82 | 0 | save_columns.push_back(column); |
83 | 0 | } |
84 | | |
85 | 1.76k | agg_columns.push_back(column.get()); |
86 | 1.76k | } |
87 | 1.26k | _agg_function->streaming_agg_serialize_to_column(agg_columns.data(), col, input_rows_count, |
88 | 1.26k | context->get_arena()); |
89 | 1.26k | block.replace_by_position(result, std::move(col)); |
90 | 1.26k | return Status::OK(); |
91 | 1.26k | } |
92 | | |
93 | | private: |
94 | | DataTypes _argument_types; |
95 | | DataTypePtr _return_type; |
96 | | AggregateFunctionPtr _agg_function; |
97 | | }; |
98 | | |
99 | | } // namespace doris |