be/src/exprs/vectorized_fn_call.cpp
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #include "exprs/vectorized_fn_call.h" |
19 | | |
20 | | #include <fmt/compile.h> |
21 | | #include <fmt/format.h> |
22 | | #include <fmt/ranges.h> // IWYU pragma: keep |
23 | | #include <gen_cpp/Opcodes_types.h> |
24 | | #include <gen_cpp/Types_types.h> |
25 | | |
26 | | #include <memory> |
27 | | #include <ostream> |
28 | | |
29 | | #include "common/config.h" |
30 | | #include "common/exception.h" |
31 | | #include "common/logging.h" |
32 | | #include "common/status.h" |
33 | | #include "common/utils.h" |
34 | | #include "core/assert_cast.h" |
35 | | #include "core/block/block.h" |
36 | | #include "core/block/column_numbers.h" |
37 | | #include "core/column/column.h" |
38 | | #include "core/column/column_array.h" |
39 | | #include "core/column/column_nullable.h" |
40 | | #include "core/column/column_vector.h" |
41 | | #include "core/data_type/data_type.h" |
42 | | #include "core/data_type/data_type_agg_state.h" |
43 | | #include "core/types.h" |
44 | | #include "exec/common/util.hpp" |
45 | | #include "exec/pipeline/pipeline_task.h" |
46 | | #include "exprs/function/array/function_array_distance.h" |
47 | | #include "exprs/function/function_agg_state.h" |
48 | | #include "exprs/function/function_fake.h" |
49 | | #include "exprs/function/function_java_udf.h" |
50 | | #include "exprs/function/function_python_udf.h" |
51 | | #include "exprs/function/function_rpc.h" |
52 | | #include "exprs/function/simple_function_factory.h" |
53 | | #include "exprs/function_context.h" |
54 | | #include "exprs/varray_literal.h" |
55 | | #include "exprs/vcast_expr.h" |
56 | | #include "exprs/vexpr_context.h" |
57 | | #include "exprs/virtual_slot_ref.h" |
58 | | #include "exprs/vliteral.h" |
59 | | #include "runtime/runtime_state.h" |
60 | | #include "storage/index/ann/ann_index.h" |
61 | | #include "storage/index/ann/ann_index_iterator.h" |
62 | | #include "storage/index/ann/ann_search_params.h" |
63 | | #include "storage/index/index_reader.h" |
64 | | #include "storage/segment/column_reader.h" |
65 | | #include "storage/segment/virtual_column_iterator.h" |
66 | | |
67 | | namespace doris { |
68 | | class RowDescriptor; |
69 | | class RuntimeState; |
70 | | class TExprNode; |
71 | | } // namespace doris |
72 | | |
73 | | namespace doris { |
74 | | #include "common/compile_check_begin.h" |
75 | | |
76 | | const std::string AGG_STATE_SUFFIX = "_state"; |
77 | | |
78 | | // Now left child is a function call, we need to check if it is a distance function |
79 | | const static std::set<std::string> DISTANCE_FUNCS = {L2DistanceApproximate::name, |
80 | | InnerProductApproximate::name}; |
81 | | const static std::set<TExprOpcode::type> OPS_FOR_ANN_RANGE_SEARCH = { |
82 | | TExprOpcode::GE, TExprOpcode::LE, TExprOpcode::LE, TExprOpcode::GT, TExprOpcode::LT}; |
83 | | |
84 | 527k | VectorizedFnCall::VectorizedFnCall(const TExprNode& node) : VExpr(node) {} |
85 | | |
86 | | Status VectorizedFnCall::prepare(RuntimeState* state, const RowDescriptor& desc, |
87 | 527k | VExprContext* context) { |
88 | 527k | RETURN_IF_ERROR_OR_PREPARED(VExpr::prepare(state, desc, context)); |
89 | 527k | ColumnsWithTypeAndName argument_template; |
90 | 527k | argument_template.reserve(_children.size()); |
91 | 1.02M | for (auto child : _children) { |
92 | 1.02M | if (child->is_literal()) { |
93 | | // For some functions, he needs some literal columns to derive the return type. |
94 | 473k | auto literal_node = std::dynamic_pointer_cast<VLiteral>(child); |
95 | 473k | argument_template.emplace_back(literal_node->get_column_ptr(), child->data_type(), |
96 | 473k | child->expr_name()); |
97 | 548k | } else { |
98 | 548k | argument_template.emplace_back(nullptr, child->data_type(), child->expr_name()); |
99 | 548k | } |
100 | 1.02M | } |
101 | | |
102 | 527k | _expr_name = fmt::format("VectorizedFnCall[{}](arguments={},return={})", _fn.name.function_name, |
103 | 527k | get_child_names(), _data_type->get_name()); |
104 | 527k | if (_fn.binary_type == TFunctionBinaryType::RPC) { |
105 | 0 | _function = FunctionRPC::create(_fn, argument_template, _data_type); |
106 | 527k | } else if (_fn.binary_type == TFunctionBinaryType::JAVA_UDF) { |
107 | 520 | if (config::enable_java_support) { |
108 | 520 | if (_fn.is_udtf_function) { |
109 | | // fake function. it's no use and can't execute. |
110 | 56 | auto builder = |
111 | 56 | std::make_shared<DefaultFunctionBuilder>(FunctionFake<UDTFImpl>::create()); |
112 | 56 | _function = builder->build(argument_template, std::make_shared<DataTypeUInt8>()); |
113 | 464 | } else { |
114 | 464 | _function = JavaFunctionCall::create(_fn, argument_template, _data_type); |
115 | 464 | } |
116 | 520 | } else { |
117 | 0 | return Status::InternalError( |
118 | 0 | "Java UDF is not enabled, you can change be config enable_java_support to true " |
119 | 0 | "and restart be."); |
120 | 0 | } |
121 | 526k | } else if (_fn.binary_type == TFunctionBinaryType::PYTHON_UDF) { |
122 | 615 | if (config::enable_python_udf_support) { |
123 | 615 | if (_fn.is_udtf_function) { |
124 | | // fake function. it's no use and can't execute. |
125 | | // Python UDTF is executed via PythonUDTFFunction in table function path |
126 | 271 | auto builder = |
127 | 271 | std::make_shared<DefaultFunctionBuilder>(FunctionFake<UDTFImpl>::create()); |
128 | 271 | _function = builder->build(argument_template, std::make_shared<DataTypeUInt8>()); |
129 | 344 | } else { |
130 | 344 | _function = PythonFunctionCall::create(_fn, argument_template, _data_type); |
131 | 344 | LOG(INFO) << fmt::format( |
132 | 344 | "create python function call: {}, runtime version: {}, function code: {}", |
133 | 344 | _fn.name.function_name, _fn.runtime_version, _fn.function_code); |
134 | 344 | } |
135 | 615 | } else { |
136 | 0 | return Status::InternalError( |
137 | 0 | "Python UDF is not enabled, you can change be config enable_python_udf_support " |
138 | 0 | "to true and restart be."); |
139 | 0 | } |
140 | 526k | } else if (_fn.binary_type == TFunctionBinaryType::AGG_STATE) { |
141 | 750 | DataTypes argument_types; |
142 | 1.10k | for (auto column : argument_template) { |
143 | 1.10k | argument_types.emplace_back(column.type); |
144 | 1.10k | } |
145 | | |
146 | 750 | if (match_suffix(_fn.name.function_name, AGG_STATE_SUFFIX)) { |
147 | 750 | if (_data_type->is_nullable()) { |
148 | 0 | return Status::InternalError("State function's return type must be not nullable"); |
149 | 0 | } |
150 | 750 | if (_data_type->get_primitive_type() != PrimitiveType::TYPE_AGG_STATE) { |
151 | 0 | return Status::InternalError( |
152 | 0 | "State function's return type must be agg_state but get {}", |
153 | 0 | _data_type->get_family_name()); |
154 | 0 | } |
155 | 750 | _function = FunctionAggState::create( |
156 | 750 | argument_types, _data_type, |
157 | 750 | assert_cast<const DataTypeAggState*>(_data_type.get())->get_nested_function()); |
158 | 750 | } else { |
159 | 0 | return Status::InternalError("Function {} is not endwith '_state'", _fn.signature); |
160 | 0 | } |
161 | 525k | } else { |
162 | | // get the function. won't prepare function. |
163 | 525k | _function = SimpleFunctionFactory::instance().get_function( |
164 | 525k | _fn.name.function_name, argument_template, _data_type, |
165 | 525k | {.new_version_unix_timestamp = state->query_options().new_version_unix_timestamp}, |
166 | 525k | state->be_exec_version()); |
167 | 525k | } |
168 | 527k | if (_function == nullptr) { |
169 | 2 | return Status::InternalError("Could not find function {}, arg {} return {} ", |
170 | 2 | _fn.name.function_name, get_child_type_names(), |
171 | 2 | _data_type->get_name()); |
172 | 2 | } |
173 | 527k | VExpr::register_function_context(state, context); |
174 | 527k | _function_name = _fn.name.function_name; |
175 | 527k | _prepare_finished = true; |
176 | | |
177 | 527k | FunctionContext* fn_ctx = context->fn_context(_fn_context_index); |
178 | 527k | if (fn().__isset.dict_function) { |
179 | 95 | fn_ctx->set_dict_function(fn().dict_function); |
180 | 95 | } |
181 | 527k | return Status::OK(); |
182 | 527k | } |
183 | | |
184 | | Status VectorizedFnCall::open(RuntimeState* state, VExprContext* context, |
185 | 1.41M | FunctionContext::FunctionStateScope scope) { |
186 | 1.41M | DCHECK(_prepare_finished); |
187 | 2.68M | for (auto& i : _children) { |
188 | 2.68M | RETURN_IF_ERROR(i->open(state, context, scope)); |
189 | 2.68M | } |
190 | 1.41M | RETURN_IF_ERROR(VExpr::init_function_context(state, context, scope, _function)); |
191 | 1.41M | if (scope == FunctionContext::FRAGMENT_LOCAL) { |
192 | 525k | RETURN_IF_ERROR(VExpr::get_const_col(context, nullptr)); |
193 | 525k | } |
194 | 1.41M | _open_finished = true; |
195 | 1.41M | return Status::OK(); |
196 | 1.41M | } |
197 | | |
198 | 1.41M | void VectorizedFnCall::close(VExprContext* context, FunctionContext::FunctionStateScope scope) { |
199 | 1.41M | VExpr::close_function_context(context, scope, _function); |
200 | 1.41M | VExpr::close(context, scope); |
201 | 1.41M | } |
202 | | |
203 | 7.86k | Status VectorizedFnCall::evaluate_inverted_index(VExprContext* context, uint32_t segment_num_rows) { |
204 | 7.86k | if (get_num_children() < 1) { |
205 | | // score() and similar 0-children virtual column functions don't need |
206 | | // inverted index evaluation; return OK to skip gracefully. |
207 | 35 | return Status::OK(); |
208 | 35 | } |
209 | 7.83k | return _evaluate_inverted_index(context, _function, segment_num_rows); |
210 | 7.86k | } |
211 | | |
212 | | Status VectorizedFnCall::_do_execute(VExprContext* context, const Block* block, Selector* selector, |
213 | | size_t count, ColumnPtr& result_column, |
214 | 299k | ColumnPtr* arg_column) const { |
215 | 299k | if (is_const_and_have_executed()) { // const have executed in open function |
216 | 15.7k | result_column = get_result_from_const(count); |
217 | 15.7k | return Status::OK(); |
218 | 15.7k | } |
219 | 284k | if (fast_execute(context, selector, count, result_column)) { |
220 | 810 | return Status::OK(); |
221 | 810 | } |
222 | 283k | DBUG_EXECUTE_IF("VectorizedFnCall.must_in_slow_path", { |
223 | 283k | if (get_child(0)->is_slot_ref()) { |
224 | 283k | auto debug_col_name = DebugPoints::instance()->get_debug_param_or_default<std::string>( |
225 | 283k | "VectorizedFnCall.must_in_slow_path", "column_name", ""); |
226 | | |
227 | 283k | std::vector<std::string> column_names; |
228 | 283k | boost::split(column_names, debug_col_name, boost::algorithm::is_any_of(",")); |
229 | | |
230 | 283k | auto* column_slot_ref = assert_cast<VSlotRef*>(get_child(0).get()); |
231 | 283k | std::string column_name = column_slot_ref->expr_name(); |
232 | 283k | auto it = std::find(column_names.begin(), column_names.end(), column_name); |
233 | 283k | if (it == column_names.end()) { |
234 | 283k | return Status::Error<ErrorCode::INTERNAL_ERROR>( |
235 | 283k | "column {} should in slow path while VectorizedFnCall::execute.", |
236 | 283k | column_name); |
237 | 283k | } |
238 | 283k | } |
239 | 283k | }) |
240 | 283k | DCHECK(_open_finished || block == nullptr) << debug_string(); |
241 | | |
242 | 283k | Block temp_block; |
243 | 283k | ColumnNumbers args(_children.size()); |
244 | | |
245 | 821k | for (int i = 0; i < _children.size(); ++i) { |
246 | 537k | ColumnPtr tmp_arg_column; |
247 | 537k | RETURN_IF_ERROR( |
248 | 537k | _children[i]->execute_column(context, block, selector, count, tmp_arg_column)); |
249 | 537k | auto arg_type = _children[i]->execute_type(block); |
250 | 537k | temp_block.insert({tmp_arg_column, arg_type, _children[i]->expr_name()}); |
251 | 537k | args[i] = i; |
252 | | |
253 | 537k | if (arg_column != nullptr && i == 0) { |
254 | 1.93k | *arg_column = tmp_arg_column; |
255 | 1.93k | } |
256 | 537k | } |
257 | | |
258 | 283k | uint32_t num_columns_without_result = temp_block.columns(); |
259 | | // prepare a column to save result |
260 | 283k | temp_block.insert({nullptr, _data_type, _expr_name}); |
261 | | |
262 | 283k | DBUG_EXECUTE_IF("VectorizedFnCall.wait_before_execute", { |
263 | 283k | auto possibility = DebugPoints::instance()->get_debug_param_or_default<double>( |
264 | 283k | "VectorizedFnCall.wait_before_execute", "possibility", 0); |
265 | 283k | if (random_bool_slow(possibility)) { |
266 | 283k | LOG(WARNING) << "VectorizedFnCall::execute sleep 30s"; |
267 | 283k | sleep(30); |
268 | 283k | } |
269 | 283k | }); |
270 | | |
271 | 283k | RETURN_IF_ERROR(_function->execute(context->fn_context(_fn_context_index), temp_block, args, |
272 | 283k | num_columns_without_result, count)); |
273 | 282k | result_column = temp_block.get_by_position(num_columns_without_result).column; |
274 | 282k | DCHECK_EQ(result_column->size(), count); |
275 | 282k | RETURN_IF_ERROR(result_column->column_self_check()); |
276 | 282k | return Status::OK(); |
277 | 282k | } |
278 | | |
279 | 0 | size_t VectorizedFnCall::estimate_memory(const size_t rows) { |
280 | 0 | if (is_const_and_have_executed()) { // const have execute in open function |
281 | 0 | return 0; |
282 | 0 | } |
283 | | |
284 | 0 | size_t estimate_size = 0; |
285 | 0 | for (auto& child : _children) { |
286 | 0 | estimate_size += child->estimate_memory(rows); |
287 | 0 | } |
288 | |
|
289 | 0 | if (_data_type->have_maximum_size_of_value()) { |
290 | 0 | estimate_size += rows * _data_type->get_size_of_value_in_memory(); |
291 | 0 | } else { |
292 | 0 | estimate_size += rows * 512; /// FIXME: estimated value... |
293 | 0 | } |
294 | 0 | return estimate_size; |
295 | 0 | } |
296 | | |
297 | | Status VectorizedFnCall::execute_runtime_filter(VExprContext* context, const Block* block, |
298 | | const uint8_t* __restrict filter, size_t count, |
299 | | ColumnPtr& result_column, |
300 | 1.93k | ColumnPtr* arg_column) const { |
301 | 1.93k | return _do_execute(context, block, nullptr, count, result_column, arg_column); |
302 | 1.93k | } |
303 | | |
304 | | Status VectorizedFnCall::execute_column(VExprContext* context, const Block* block, |
305 | | Selector* selector, size_t count, |
306 | 297k | ColumnPtr& result_column) const { |
307 | 297k | return _do_execute(context, block, selector, count, result_column, nullptr); |
308 | 297k | } |
309 | | |
310 | 269k | const std::string& VectorizedFnCall::expr_name() const { |
311 | 269k | return _expr_name; |
312 | 269k | } |
313 | | |
314 | 98 | std::string VectorizedFnCall::function_name() const { |
315 | 98 | return _function_name; |
316 | 98 | } |
317 | | |
318 | 149 | std::string VectorizedFnCall::debug_string() const { |
319 | 149 | std::stringstream out; |
320 | 149 | out << "VectorizedFn["; |
321 | 149 | out << _expr_name; |
322 | 149 | out << "]{"; |
323 | 149 | bool first = true; |
324 | 266 | for (const auto& input_expr : children()) { |
325 | 266 | if (first) { |
326 | 147 | first = false; |
327 | 147 | } else { |
328 | 119 | out << ","; |
329 | 119 | } |
330 | 266 | out << "\n" << input_expr->debug_string(); |
331 | 266 | } |
332 | 149 | out << "}"; |
333 | 149 | return out.str(); |
334 | 149 | } |
335 | | |
336 | 0 | std::string VectorizedFnCall::debug_string(const std::vector<VectorizedFnCall*>& agg_fns) { |
337 | 0 | std::stringstream out; |
338 | 0 | out << "["; |
339 | 0 | for (int i = 0; i < agg_fns.size(); ++i) { |
340 | 0 | out << (i == 0 ? "" : " ") << agg_fns[i]->debug_string(); |
341 | 0 | } |
342 | 0 | out << "]"; |
343 | 0 | return out.str(); |
344 | 0 | } |
345 | | |
346 | 935 | bool VectorizedFnCall::can_push_down_to_index() const { |
347 | 935 | return _function->can_push_down_to_index(); |
348 | 935 | } |
349 | | |
350 | 0 | bool VectorizedFnCall::equals(const VExpr& other) { |
351 | 0 | const auto* other_ptr = dynamic_cast<const VectorizedFnCall*>(&other); |
352 | 0 | if (!other_ptr) { |
353 | 0 | return false; |
354 | 0 | } |
355 | 0 | if (this->_function_name != other_ptr->_function_name) { |
356 | 0 | return false; |
357 | 0 | } |
358 | 0 | if (get_num_children() != other_ptr->get_num_children()) { |
359 | 0 | return false; |
360 | 0 | } |
361 | 0 | for (uint16_t i = 0; i < get_num_children(); i++) { |
362 | 0 | if (!this->get_child(i)->equals(*other_ptr->get_child(i))) { |
363 | 0 | return false; |
364 | 0 | } |
365 | 0 | } |
366 | 0 | return true; |
367 | 0 | } |
368 | | |
369 | | /* |
370 | | * For ANN range search we expect a comparison expression (LE/LT/GE/GT) whose left side is either: |
371 | | * 1) a vector distance function call, or |
372 | | * 2) a cast/virtual slot that unwraps to the function call when the planner promotes float to |
373 | | * double literals. |
374 | | * |
375 | | * Visually the logical tree looks like: |
376 | | * |
377 | | * FunctionCall(LE/LT/GE/GT) |
378 | | * |---------------- |
379 | | * | | |
380 | | * | | |
381 | | * VirtualSlotRef* Float32Literal/Float64Literal |
382 | | * | |
383 | | * | |
384 | | * Cast(Float -> Double)* |
385 | | * | |
386 | | * FunctionCall(distance) |
387 | | * |---------------- |
388 | | * | | |
389 | | * | | |
390 | | * SlotRef ArrayLiteral/Cast(String as Array<FLOAT>) |
391 | | * |
392 | | * Items marked with * are optional and depend on literal types/virtual column usage. The helper |
393 | | * below normalizes the shape and validates distance function, slot, and constant vector inputs. |
394 | | */ |
395 | | |
396 | | void VectorizedFnCall::prepare_ann_range_search( |
397 | | const doris::VectorSearchUserParams& user_params, |
398 | 8.09k | segment_v2::AnnRangeSearchRuntime& range_search_runtime, bool& suitable_for_ann_index) { |
399 | 8.09k | if (!suitable_for_ann_index) { |
400 | 0 | return; |
401 | 0 | } |
402 | | |
403 | 8.09k | if (OPS_FOR_ANN_RANGE_SEARCH.find(this->op()) == OPS_FOR_ANN_RANGE_SEARCH.end()) { |
404 | 6.63k | suitable_for_ann_index = false; |
405 | 6.63k | return; |
406 | 6.63k | } |
407 | | |
408 | 1.45k | auto mark_unsuitable = [&](const std::string& reason) { |
409 | 1.39k | suitable_for_ann_index = false; |
410 | 18.4E | VLOG_DEBUG << "ANN range search skipped: " << reason; |
411 | 1.39k | }; |
412 | | |
413 | 1.45k | range_search_runtime.is_le_or_lt = |
414 | 1.45k | (this->op() == TExprOpcode::LE || this->op() == TExprOpcode::LT); |
415 | | |
416 | 1.45k | DCHECK(_children.size() == 2); |
417 | | |
418 | 1.45k | auto left_child = get_child(0); |
419 | 1.45k | auto right_child = get_child(1); |
420 | | |
421 | | // ========== Step 1: Check left child - must be a distance function ========== |
422 | 1.45k | auto get_virtual_expr = [&](const VExprSPtr& expr, |
423 | 1.94k | std::shared_ptr<VirtualSlotRef>& slot_ref) -> VExprSPtr { |
424 | 1.94k | auto virtual_ref = std::dynamic_pointer_cast<VirtualSlotRef>(expr); |
425 | 1.94k | if (virtual_ref != nullptr) { |
426 | 20 | DCHECK(virtual_ref->get_virtual_column_expr() != nullptr); |
427 | 20 | slot_ref = virtual_ref; |
428 | 20 | return virtual_ref->get_virtual_column_expr(); |
429 | 20 | } |
430 | 1.92k | return expr; |
431 | 1.94k | }; |
432 | | |
433 | 1.45k | std::shared_ptr<VirtualSlotRef> vir_slot_ref; |
434 | 1.45k | auto normalized_left = get_virtual_expr(left_child, vir_slot_ref); |
435 | | |
436 | | // Try to find the distance function call, it may be wrapped in a Cast(Float->Double) |
437 | 1.45k | std::shared_ptr<VectorizedFnCall> function_call = |
438 | 1.45k | std::dynamic_pointer_cast<VectorizedFnCall>(normalized_left); |
439 | 1.45k | bool has_float_to_double_cast = false; |
440 | | |
441 | 1.45k | if (function_call == nullptr) { |
442 | | // Check if it's a Cast expression wrapping a function call |
443 | 670 | auto cast_expr = std::dynamic_pointer_cast<VCastExpr>(normalized_left); |
444 | 670 | if (cast_expr == nullptr) { |
445 | 182 | mark_unsuitable("Left child is neither a function call nor a cast expression."); |
446 | 182 | return; |
447 | 182 | } |
448 | 488 | has_float_to_double_cast = true; |
449 | 488 | auto normalized_cast_child = get_virtual_expr(cast_expr->get_child(0), vir_slot_ref); |
450 | 488 | function_call = std::dynamic_pointer_cast<VectorizedFnCall>(normalized_cast_child); |
451 | 488 | if (function_call == nullptr) { |
452 | 422 | mark_unsuitable("Left child of cast is not a function call."); |
453 | 422 | return; |
454 | 422 | } |
455 | 488 | } |
456 | | |
457 | | // Check if it's a supported distance function |
458 | 852 | if (DISTANCE_FUNCS.find(function_call->_function_name) == DISTANCE_FUNCS.end()) { |
459 | 785 | mark_unsuitable(fmt::format("Left child is not a supported distance function: {}", |
460 | 785 | function_call->_function_name)); |
461 | 785 | return; |
462 | 785 | } |
463 | | |
464 | | // Strip the _approximate suffix to get metric type |
465 | 67 | std::string metric_name = function_call->_function_name; |
466 | 67 | metric_name = metric_name.substr(0, metric_name.size() - 12); |
467 | 67 | range_search_runtime.metric_type = segment_v2::string_to_metric(metric_name); |
468 | | |
469 | | // ========== Step 2: Validate distance function arguments ========== |
470 | | // Identify the slot ref child and the constant query array child (ArrayLiteral or CAST to array) |
471 | 67 | Int32 idx_of_slot_ref = -1; |
472 | 67 | Int32 idx_of_array_expr = -1; |
473 | 145 | auto classify_child = [&](const VExprSPtr& child, UInt16 index) { |
474 | 145 | if (idx_of_slot_ref == -1 && std::dynamic_pointer_cast<VSlotRef>(child) != nullptr) { |
475 | 73 | idx_of_slot_ref = index; |
476 | 73 | return; |
477 | 73 | } |
478 | 72 | if (idx_of_array_expr == -1 && |
479 | 73 | (std::dynamic_pointer_cast<VArrayLiteral>(child) != nullptr || |
480 | 73 | std::dynamic_pointer_cast<VCastExpr>(child) != nullptr)) { |
481 | 66 | idx_of_array_expr = index; |
482 | 66 | } |
483 | 72 | }; |
484 | | |
485 | 211 | for (UInt16 i = 0; i < function_call->get_num_children(); ++i) { |
486 | 144 | classify_child(function_call->get_child(i), i); |
487 | 144 | } |
488 | | |
489 | 72 | if (idx_of_slot_ref == -1 || idx_of_array_expr == -1) { |
490 | 6 | mark_unsuitable("slot ref or array literal/cast is missing."); |
491 | 6 | return; |
492 | 6 | } |
493 | | |
494 | 61 | auto slot_ref = std::dynamic_pointer_cast<VSlotRef>( |
495 | 61 | function_call->get_child(static_cast<UInt16>(idx_of_slot_ref))); |
496 | 61 | range_search_runtime.src_col_idx = slot_ref->column_id(); |
497 | 61 | range_search_runtime.dst_col_idx = vir_slot_ref == nullptr ? -1 : vir_slot_ref->column_id(); |
498 | | |
499 | | // Materialize the constant array expression and validate its shape and types |
500 | 61 | auto array_expr = function_call->get_child(static_cast<UInt16>(idx_of_array_expr)); |
501 | 61 | auto extract_result = extract_query_vector(array_expr); |
502 | 61 | if (!extract_result.has_value()) { |
503 | 0 | mark_unsuitable("Failed to extract query vector from constant array expression."); |
504 | 0 | return; |
505 | 0 | } |
506 | 61 | range_search_runtime.query_value = extract_result.value(); |
507 | 61 | range_search_runtime.dim = range_search_runtime.query_value->size(); |
508 | | |
509 | | // ========== Step 3: Check right child - must be a float/double literal ========== |
510 | 61 | auto right_literal = std::dynamic_pointer_cast<VLiteral>(right_child); |
511 | 61 | if (right_literal == nullptr) { |
512 | 1 | mark_unsuitable("Right child is not a literal."); |
513 | 1 | return; |
514 | 1 | } |
515 | | |
516 | | // Handle nullable literal gracefully - just mark as unsuitable instead of crash |
517 | 60 | if (right_literal->is_nullable()) { |
518 | 0 | mark_unsuitable("Right literal is nullable, not supported for ANN range search."); |
519 | 0 | return; |
520 | 0 | } |
521 | | |
522 | 60 | auto right_type = right_literal->get_data_type(); |
523 | 60 | PrimitiveType right_primitive = right_type->get_primitive_type(); |
524 | 60 | const bool float32_literal = right_primitive == PrimitiveType::TYPE_FLOAT; |
525 | 60 | const bool float64_literal = right_primitive == PrimitiveType::TYPE_DOUBLE; |
526 | | |
527 | 60 | if (!float32_literal && !float64_literal) { |
528 | 0 | mark_unsuitable("Right child is not a Float32Literal or Float64Literal."); |
529 | 0 | return; |
530 | 0 | } |
531 | | |
532 | | // Validate consistency: if we have Cast(Float->Double), right must be double literal |
533 | 60 | if (has_float_to_double_cast && !float64_literal) { |
534 | 0 | mark_unsuitable("Cast expression expects double literal on right side."); |
535 | 0 | return; |
536 | 0 | } |
537 | | |
538 | | // Extract radius value |
539 | 60 | auto right_col = right_literal->get_column_ptr()->convert_to_full_column_if_const(); |
540 | 60 | if (float32_literal) { |
541 | 7 | const ColumnFloat32* cf32_right = assert_cast<const ColumnFloat32*>(right_col.get()); |
542 | 7 | range_search_runtime.radius = cf32_right->get_data()[0]; |
543 | 53 | } else { |
544 | 53 | const ColumnFloat64* cf64_right = assert_cast<const ColumnFloat64*>(right_col.get()); |
545 | 53 | range_search_runtime.radius = static_cast<float>(cf64_right->get_data()[0]); |
546 | 53 | } |
547 | | |
548 | | // ========== Done: Mark as suitable for ANN range search ========== |
549 | 60 | range_search_runtime.is_ann_range_search = true; |
550 | 60 | range_search_runtime.user_params = user_params; |
551 | 18.4E | VLOG_DEBUG << fmt::format("Ann range search params: {}", range_search_runtime.to_string()); |
552 | 60 | return; |
553 | 60 | } |
554 | | |
555 | | Status VectorizedFnCall::evaluate_ann_range_search( |
556 | | const segment_v2::AnnRangeSearchRuntime& range_search_runtime, |
557 | | const std::vector<std::unique_ptr<segment_v2::IndexIterator>>& cid_to_index_iterators, |
558 | | const std::vector<ColumnId>& idx_to_cid, |
559 | | const std::vector<std::unique_ptr<segment_v2::ColumnIterator>>& column_iterators, |
560 | 7.13k | roaring::Roaring& row_bitmap, segment_v2::AnnIndexStats& ann_index_stats) { |
561 | 7.13k | if (range_search_runtime.is_ann_range_search == false) { |
562 | 7.07k | return Status::OK(); |
563 | 7.07k | } |
564 | | |
565 | 18.4E | VLOG_DEBUG << fmt::format("Try apply ann range search. Local search params: {}", |
566 | 18.4E | range_search_runtime.to_string()); |
567 | 60 | size_t origin_num = row_bitmap.cardinality(); |
568 | | |
569 | 60 | int idx_in_block = static_cast<int>(range_search_runtime.src_col_idx); |
570 | 18.4E | DCHECK(idx_in_block < idx_to_cid.size()) |
571 | 18.4E | << "idx_in_block: " << idx_in_block << ", idx_to_cid.size(): " << idx_to_cid.size(); |
572 | | |
573 | 60 | ColumnId src_col_cid = idx_to_cid[idx_in_block]; |
574 | 60 | DCHECK(src_col_cid < cid_to_index_iterators.size()); |
575 | 60 | segment_v2::IndexIterator* index_iterator = cid_to_index_iterators[src_col_cid].get(); |
576 | 60 | if (index_iterator == nullptr) { |
577 | 1 | VLOG_DEBUG << "ANN range search skipped: " |
578 | 0 | << fmt::format("No index iterator for column cid {}", src_col_cid); |
579 | 1 | ; |
580 | 1 | return Status::OK(); |
581 | 1 | } |
582 | | |
583 | 59 | segment_v2::AnnIndexIterator* ann_index_iterator = |
584 | 59 | dynamic_cast<segment_v2::AnnIndexIterator*>(index_iterator); |
585 | 59 | if (ann_index_iterator == nullptr) { |
586 | 0 | VLOG_DEBUG << "ANN range search skipped: " |
587 | 0 | << fmt::format("Column cid {} has no ANN index iterator", src_col_cid); |
588 | 0 | return Status::OK(); |
589 | 0 | } |
590 | 18.4E | DCHECK(ann_index_iterator->get_reader(AnnIndexReaderType::ANN) != nullptr) |
591 | 18.4E | << "Ann index iterator should have reader. Column cid: " << src_col_cid; |
592 | 59 | std::shared_ptr<AnnIndexReader> ann_index_reader = std::dynamic_pointer_cast<AnnIndexReader>( |
593 | 59 | ann_index_iterator->get_reader(segment_v2::AnnIndexReaderType::ANN)); |
594 | 18.4E | DCHECK(ann_index_reader != nullptr) |
595 | 18.4E | << "Ann index reader should not be null. Column cid: " << src_col_cid; |
596 | | // Check if metrics type is match. |
597 | 59 | if (ann_index_reader->get_metric_type() != range_search_runtime.metric_type) { |
598 | 0 | VLOG_DEBUG << "ANN range search skipped: " |
599 | 0 | << fmt::format("Metric type mismatch. Index={} Query={}", |
600 | 0 | segment_v2::metric_to_string(ann_index_reader->get_metric_type()), |
601 | 0 | segment_v2::metric_to_string(range_search_runtime.metric_type)); |
602 | 0 | return Status::OK(); |
603 | 0 | } |
604 | | |
605 | | // Check dimension if available (>0) |
606 | 59 | const size_t index_dim = ann_index_reader->get_dimension(); |
607 | 61 | if (index_dim > 0 && index_dim != range_search_runtime.dim) { |
608 | 6 | return Status::InvalidArgument( |
609 | 6 | "Ann range search query dimension {} does not match index dimension {}", |
610 | 6 | range_search_runtime.dim, index_dim); |
611 | 6 | } |
612 | | |
613 | 53 | AnnRangeSearchParams params = range_search_runtime.to_range_search_params(); |
614 | | |
615 | 53 | params.roaring = &row_bitmap; |
616 | 53 | DCHECK(params.roaring != nullptr); |
617 | 53 | DCHECK(params.query_value != nullptr); |
618 | 53 | segment_v2::AnnRangeSearchResult result; |
619 | 53 | auto stats = std::make_unique<segment_v2::AnnIndexStats>(); |
620 | 53 | RETURN_IF_ERROR(ann_index_iterator->range_search(params, range_search_runtime.user_params, |
621 | 53 | &result, stats.get())); |
622 | | |
623 | 53 | #ifndef NDEBUG |
624 | 53 | if (range_search_runtime.is_le_or_lt == false && |
625 | 53 | ann_index_reader->get_metric_type() == AnnIndexMetric::L2) { |
626 | 16 | DCHECK(result.distance == nullptr) << "Should not have distance"; |
627 | 16 | } |
628 | 53 | if (range_search_runtime.is_le_or_lt == true && |
629 | 53 | ann_index_reader->get_metric_type() == AnnIndexMetric::IP) { |
630 | 6 | DCHECK(result.distance == nullptr); |
631 | 6 | } |
632 | 53 | #endif |
633 | 53 | DCHECK(result.roaring != nullptr); |
634 | 53 | row_bitmap = *result.roaring; |
635 | | |
636 | | // Process virtual column |
637 | 53 | if (range_search_runtime.dst_col_idx >= 0) { |
638 | | // Prepare materialization if we can use result from index. |
639 | | // Typical situation: range search and operator is LE or LT. |
640 | 22 | if (result.distance != nullptr) { |
641 | 13 | DCHECK(result.row_ids != nullptr); |
642 | 13 | ColumnId dst_col_cid = idx_to_cid[range_search_runtime.dst_col_idx]; |
643 | 13 | DCHECK(dst_col_cid < column_iterators.size()); |
644 | 13 | DCHECK(column_iterators[dst_col_cid] != nullptr); |
645 | 13 | segment_v2::ColumnIterator* column_iterator = column_iterators[dst_col_cid].get(); |
646 | 13 | DCHECK(column_iterator != nullptr); |
647 | 13 | segment_v2::VirtualColumnIterator* virtual_column_iterator = |
648 | 13 | dynamic_cast<segment_v2::VirtualColumnIterator*>(column_iterator); |
649 | 13 | DCHECK(virtual_column_iterator != nullptr); |
650 | | // Now convert distance to column |
651 | 13 | size_t size = result.roaring->cardinality(); |
652 | 13 | auto distance_col = ColumnFloat32::create(size); |
653 | 13 | const float* src = result.distance.get(); |
654 | 13 | float* dst = distance_col->get_data().data(); |
655 | 128 | for (size_t i = 0; i < size; ++i) { |
656 | 115 | dst[i] = src[i]; |
657 | 115 | } |
658 | 13 | virtual_column_iterator->prepare_materialization(std::move(distance_col), |
659 | 13 | std::move(result.row_ids)); |
660 | 13 | _virtual_column_is_fulfilled = true; |
661 | 13 | } else { |
662 | | // Whether the ANN index should have produced distance depends on metric and operator: |
663 | | // - L2: distance is produced for LE/LT; not produced for GE/GT |
664 | | // - IP: distance is produced for GE/GT; not produced for LE/LT |
665 | 9 | #ifndef NDEBUG |
666 | 9 | const bool should_have_distance = |
667 | 9 | (range_search_runtime.is_le_or_lt && |
668 | 9 | range_search_runtime.metric_type == AnnIndexMetric::L2) || |
669 | 9 | (!range_search_runtime.is_le_or_lt && |
670 | 9 | range_search_runtime.metric_type == AnnIndexMetric::IP); |
671 | | // If we expected distance but didn't get it, assert in debug to catch logic errors. |
672 | 9 | DCHECK(!should_have_distance) << "Expected distance from ANN index but got none"; |
673 | 9 | #endif |
674 | 9 | _virtual_column_is_fulfilled = false; |
675 | 9 | } |
676 | 31 | } else { |
677 | | // Dest is not virtual column. |
678 | 31 | _virtual_column_is_fulfilled = true; |
679 | 31 | } |
680 | | |
681 | 53 | _has_been_executed = true; |
682 | 18.4E | VLOG_DEBUG << fmt::format( |
683 | 18.4E | "Ann range search filtered {} rows, origin {} rows, virtual column is full-filled: {}", |
684 | 18.4E | origin_num - row_bitmap.cardinality(), origin_num, _virtual_column_is_fulfilled); |
685 | | |
686 | 53 | ann_index_stats = *stats; |
687 | 53 | return Status::OK(); |
688 | 53 | } |
689 | | |
690 | 683k | double VectorizedFnCall::execute_cost() const { |
691 | 683k | if (!_function) { |
692 | 0 | throw Exception( |
693 | 0 | Status::InternalError("Function is null in expression: {}", this->debug_string())); |
694 | 0 | } |
695 | 683k | double cost = _function->execute_cost(); |
696 | 1.36M | for (const auto& child : _children) { |
697 | 1.36M | cost += child->execute_cost(); |
698 | 1.36M | } |
699 | 683k | return cost; |
700 | 683k | } |
701 | | |
702 | | #include "common/compile_check_end.h" |
703 | | } // namespace doris |