be/src/exprs/function/uniform.cpp
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #include <fmt/format.h> |
19 | | #include <glog/logging.h> |
20 | | |
21 | | #include <boost/iterator/iterator_facade.hpp> |
22 | | #include <cstdint> |
23 | | #include <cstdlib> |
24 | | #include <memory> |
25 | | #include <random> |
26 | | #include <utility> |
27 | | |
28 | | #include "common/status.h" |
29 | | #include "core/assert_cast.h" |
30 | | #include "core/block/block.h" |
31 | | #include "core/block/column_numbers.h" |
32 | | #include "core/column/column.h" |
33 | | #include "core/column/column_vector.h" |
34 | | #include "core/data_type/data_type_number.h" // IWYU pragma: keep |
35 | | #include "core/data_type/primitive_type.h" |
36 | | #include "core/types.h" |
37 | | #include "exprs/aggregate/aggregate_function.h" |
38 | | #include "exprs/function/function.h" |
39 | | #include "exprs/function/simple_function_factory.h" |
40 | | #include "exprs/function_context.h" |
41 | | |
42 | | namespace doris { |
43 | | #include "common/compile_check_begin.h" |
44 | | |
45 | | // Integer uniform implementation |
46 | | struct UniformIntImpl { |
47 | 12 | static DataTypes get_variadic_argument_types() { |
48 | 12 | return {std::make_shared<DataTypeInt64>(), std::make_shared<DataTypeInt64>(), |
49 | 12 | std::make_shared<DataTypeInt64>()}; |
50 | 12 | } |
51 | | |
52 | 4 | static DataTypePtr get_return_type_impl(const ColumnsWithTypeAndName& arguments) { |
53 | 4 | return std::make_shared<DataTypeInt64>(); |
54 | 4 | } |
55 | | |
56 | | static Status execute_impl(FunctionContext* context, Block& block, |
57 | | const ColumnNumbers& arguments, uint32_t result, |
58 | 7 | size_t input_rows_count) { |
59 | 7 | auto res_column = ColumnInt64::create(input_rows_count); |
60 | 7 | auto& res_data = static_cast<ColumnInt64&>(*res_column).get_data(); |
61 | | |
62 | | // Get min and max values (constants) |
63 | 7 | const auto& left = |
64 | 7 | assert_cast<const ColumnConst&>(*block.get_by_position(arguments[0]).column) |
65 | 7 | .get_data_column(); |
66 | 7 | const auto& right = |
67 | 7 | assert_cast<const ColumnConst&>(*block.get_by_position(arguments[1]).column) |
68 | 7 | .get_data_column(); |
69 | 7 | Int64 min = assert_cast<const ColumnInt64&>(left).get_element(0); |
70 | 7 | Int64 max = assert_cast<const ColumnInt64&>(right).get_element(0); |
71 | | |
72 | 7 | if (min >= max) { |
73 | 1 | return Status::InvalidArgument( |
74 | 1 | "uniform's min should be less than max, but got [{}, {})", min, max); |
75 | 1 | } |
76 | | |
77 | | // Get gen column (seed values) |
78 | 6 | const auto& gen_column = block.get_by_position(arguments[2]).column; |
79 | | |
80 | 39 | for (int i = 0; i < input_rows_count; i++) { |
81 | | // Use gen value as seed for each row |
82 | 33 | auto seed = (*gen_column)[i].get<TYPE_BIGINT>(); |
83 | 33 | std::mt19937_64 generator(seed); |
84 | 33 | std::uniform_int_distribution<int64_t> distribution(min, max); |
85 | 33 | res_data[i] = distribution(generator); |
86 | 33 | } |
87 | | |
88 | 6 | block.replace_by_position(result, std::move(res_column)); |
89 | 6 | return Status::OK(); |
90 | 7 | } |
91 | | }; |
92 | | |
93 | | // Double uniform implementation |
94 | | struct UniformDoubleImpl { |
95 | 8 | static DataTypes get_variadic_argument_types() { |
96 | 8 | return {std::make_shared<DataTypeFloat64>(), std::make_shared<DataTypeFloat64>(), |
97 | 8 | std::make_shared<DataTypeInt64>()}; |
98 | 8 | } |
99 | | |
100 | 0 | static DataTypePtr get_return_type_impl(const ColumnsWithTypeAndName& arguments) { |
101 | 0 | return std::make_shared<DataTypeFloat64>(); |
102 | 0 | } |
103 | | |
104 | | static Status execute_impl(FunctionContext* context, Block& block, |
105 | | const ColumnNumbers& arguments, uint32_t result, |
106 | 0 | size_t input_rows_count) { |
107 | 0 | auto res_column = ColumnFloat64::create(input_rows_count); |
108 | 0 | auto& res_data = static_cast<ColumnFloat64&>(*res_column).get_data(); |
109 | | |
110 | | // Get min and max values (constants) |
111 | 0 | const auto& left = |
112 | 0 | assert_cast<const ColumnConst&>(*block.get_by_position(arguments[0]).column) |
113 | 0 | .get_data_column(); |
114 | 0 | const auto& right = |
115 | 0 | assert_cast<const ColumnConst&>(*block.get_by_position(arguments[1]).column) |
116 | 0 | .get_data_column(); |
117 | 0 | double min = assert_cast<const ColumnFloat64&>(left).get_element(0); |
118 | 0 | double max = assert_cast<const ColumnFloat64&>(right).get_element(0); |
119 | |
|
120 | 0 | if (min >= max) { |
121 | 0 | return Status::InvalidArgument( |
122 | 0 | "uniform's min should be less than max, but got [{}, {})", min, max); |
123 | 0 | } |
124 | | |
125 | | // Get gen column (seed values) |
126 | 0 | const auto& gen_column = block.get_by_position(arguments[2]).column; |
127 | |
|
128 | 0 | for (int i = 0; i < input_rows_count; i++) { |
129 | | // Use gen value as seed for each row |
130 | 0 | auto seed = (*gen_column)[i].get<TYPE_BIGINT>(); |
131 | 0 | std::mt19937_64 generator(seed); |
132 | 0 | std::uniform_real_distribution<double> distribution(min, max); |
133 | 0 | res_data[i] = distribution(generator); |
134 | 0 | } |
135 | |
|
136 | 0 | block.replace_by_position(result, std::move(res_column)); |
137 | 0 | return Status::OK(); |
138 | 0 | } |
139 | | }; |
140 | | |
141 | | template <typename Impl> |
142 | | class FunctionUniform : public IFunction { |
143 | | public: |
144 | | static constexpr auto name = "uniform"; |
145 | | |
146 | 22 | static FunctionPtr create() { return std::make_shared<FunctionUniform<Impl>>(); }_ZN5doris15FunctionUniformINS_14UniformIntImplEE6createEv Line | Count | Source | 146 | 13 | static FunctionPtr create() { return std::make_shared<FunctionUniform<Impl>>(); } |
_ZN5doris15FunctionUniformINS_17UniformDoubleImplEE6createEv Line | Count | Source | 146 | 9 | static FunctionPtr create() { return std::make_shared<FunctionUniform<Impl>>(); } |
|
147 | 2 | String get_name() const override { return name; }_ZNK5doris15FunctionUniformINS_14UniformIntImplEE8get_nameB5cxx11Ev Line | Count | Source | 147 | 1 | String get_name() const override { return name; } |
_ZNK5doris15FunctionUniformINS_17UniformDoubleImplEE8get_nameB5cxx11Ev Line | Count | Source | 147 | 1 | String get_name() const override { return name; } |
|
148 | | |
149 | 4 | size_t get_number_of_arguments() const override { |
150 | 4 | return get_variadic_argument_types_impl().size(); |
151 | 4 | } _ZNK5doris15FunctionUniformINS_14UniformIntImplEE23get_number_of_argumentsEv Line | Count | Source | 149 | 4 | size_t get_number_of_arguments() const override { | 150 | 4 | return get_variadic_argument_types_impl().size(); | 151 | 4 | } |
Unexecuted instantiation: _ZNK5doris15FunctionUniformINS_17UniformDoubleImplEE23get_number_of_argumentsEv |
152 | | |
153 | 4 | DataTypePtr get_return_type_impl(const ColumnsWithTypeAndName& arguments) const override { |
154 | 4 | return Impl::get_return_type_impl(arguments); |
155 | 4 | } _ZNK5doris15FunctionUniformINS_14UniformIntImplEE20get_return_type_implERKSt6vectorINS_21ColumnWithTypeAndNameESaIS4_EE Line | Count | Source | 153 | 4 | DataTypePtr get_return_type_impl(const ColumnsWithTypeAndName& arguments) const override { | 154 | 4 | return Impl::get_return_type_impl(arguments); | 155 | 4 | } |
Unexecuted instantiation: _ZNK5doris15FunctionUniformINS_17UniformDoubleImplEE20get_return_type_implERKSt6vectorINS_21ColumnWithTypeAndNameESaIS4_EE |
156 | | |
157 | 20 | DataTypes get_variadic_argument_types_impl() const override { |
158 | 20 | return Impl::get_variadic_argument_types(); |
159 | 20 | } _ZNK5doris15FunctionUniformINS_14UniformIntImplEE32get_variadic_argument_types_implEv Line | Count | Source | 157 | 12 | DataTypes get_variadic_argument_types_impl() const override { | 158 | 12 | return Impl::get_variadic_argument_types(); | 159 | 12 | } |
_ZNK5doris15FunctionUniformINS_17UniformDoubleImplEE32get_variadic_argument_types_implEv Line | Count | Source | 157 | 8 | DataTypes get_variadic_argument_types_impl() const override { | 158 | 8 | return Impl::get_variadic_argument_types(); | 159 | 8 | } |
|
160 | | |
161 | 16 | Status open(FunctionContext* context, FunctionContext::FunctionStateScope scope) override { |
162 | | // init_function_context do set_constant_cols for FRAGMENT_LOCAL scope |
163 | 16 | if (scope == FunctionContext::FRAGMENT_LOCAL) { |
164 | 4 | if (!context->is_col_constant(0)) { |
165 | 0 | return Status::InvalidArgument( |
166 | 0 | "The first parameter (min) of uniform function must be literal"); |
167 | 0 | } |
168 | 4 | if (!context->is_col_constant(1)) { |
169 | 0 | return Status::InvalidArgument( |
170 | 0 | "The second parameter (max) of uniform function must be literal"); |
171 | 0 | } |
172 | 4 | } |
173 | 16 | return Status::OK(); |
174 | 16 | } _ZN5doris15FunctionUniformINS_14UniformIntImplEE4openEPNS_15FunctionContextENS3_18FunctionStateScopeE Line | Count | Source | 161 | 16 | Status open(FunctionContext* context, FunctionContext::FunctionStateScope scope) override { | 162 | | // init_function_context do set_constant_cols for FRAGMENT_LOCAL scope | 163 | 16 | if (scope == FunctionContext::FRAGMENT_LOCAL) { | 164 | 4 | if (!context->is_col_constant(0)) { | 165 | 0 | return Status::InvalidArgument( | 166 | 0 | "The first parameter (min) of uniform function must be literal"); | 167 | 0 | } | 168 | 4 | if (!context->is_col_constant(1)) { | 169 | 0 | return Status::InvalidArgument( | 170 | 0 | "The second parameter (max) of uniform function must be literal"); | 171 | 0 | } | 172 | 4 | } | 173 | 16 | return Status::OK(); | 174 | 16 | } |
Unexecuted instantiation: _ZN5doris15FunctionUniformINS_17UniformDoubleImplEE4openEPNS_15FunctionContextENS3_18FunctionStateScopeE |
175 | | |
176 | | Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, |
177 | 7 | uint32_t result, size_t input_rows_count) const override { |
178 | 7 | return Impl::execute_impl(context, block, arguments, result, input_rows_count); |
179 | 7 | } _ZNK5doris15FunctionUniformINS_14UniformIntImplEE12execute_implEPNS_15FunctionContextERNS_5BlockERKSt6vectorIjSaIjEEjm Line | Count | Source | 177 | 7 | uint32_t result, size_t input_rows_count) const override { | 178 | 7 | return Impl::execute_impl(context, block, arguments, result, input_rows_count); | 179 | 7 | } |
Unexecuted instantiation: _ZNK5doris15FunctionUniformINS_17UniformDoubleImplEE12execute_implEPNS_15FunctionContextERNS_5BlockERKSt6vectorIjSaIjEEjm |
180 | | }; |
181 | | |
182 | 8 | void register_function_uniform(SimpleFunctionFactory& factory) { |
183 | 8 | factory.register_function<FunctionUniform<UniformIntImpl>>(); |
184 | 8 | factory.register_function<FunctionUniform<UniformDoubleImpl>>(); |
185 | 8 | } |
186 | | #include "common/compile_check_end.h" |
187 | | } // namespace doris |