Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #include "exprs/vectorized_fn_call.h" |
19 | | |
20 | | #include <fmt/compile.h> |
21 | | #include <fmt/format.h> |
22 | | #include <fmt/ranges.h> // IWYU pragma: keep |
23 | | #include <gen_cpp/Opcodes_types.h> |
24 | | #include <gen_cpp/Types_types.h> |
25 | | |
26 | | #include <memory> |
27 | | #include <ostream> |
28 | | |
29 | | #include "common/config.h" |
30 | | #include "common/exception.h" |
31 | | #include "common/logging.h" |
32 | | #include "common/status.h" |
33 | | #include "common/utils.h" |
34 | | #include "core/assert_cast.h" |
35 | | #include "core/block/block.h" |
36 | | #include "core/block/column_numbers.h" |
37 | | #include "core/column/column.h" |
38 | | #include "core/column/column_array.h" |
39 | | #include "core/column/column_nullable.h" |
40 | | #include "core/column/column_vector.h" |
41 | | #include "core/data_type/data_type.h" |
42 | | #include "core/data_type/data_type_agg_state.h" |
43 | | #include "core/types.h" |
44 | | #include "exec/common/util.hpp" |
45 | | #include "exec/pipeline/pipeline_task.h" |
46 | | #include "exprs/function/array/function_array_distance.h" |
47 | | #include "exprs/function/function_agg_state.h" |
48 | | #include "exprs/function/function_fake.h" |
49 | | #include "exprs/function/function_java_udf.h" |
50 | | #include "exprs/function/function_python_udf.h" |
51 | | #include "exprs/function/function_rpc.h" |
52 | | #include "exprs/function/simple_function_factory.h" |
53 | | #include "exprs/function_context.h" |
54 | | #include "exprs/varray_literal.h" |
55 | | #include "exprs/vcast_expr.h" |
56 | | #include "exprs/vexpr_context.h" |
57 | | #include "exprs/virtual_slot_ref.h" |
58 | | #include "exprs/vliteral.h" |
59 | | #include "runtime/runtime_state.h" |
60 | | #include "storage/index/ann/ann_index.h" |
61 | | #include "storage/index/ann/ann_index_iterator.h" |
62 | | #include "storage/index/ann/ann_search_params.h" |
63 | | #include "storage/index/index_reader.h" |
64 | | #include "storage/segment/column_reader.h" |
65 | | #include "storage/segment/virtual_column_iterator.h" |
66 | | |
67 | | namespace doris { |
68 | | class RowDescriptor; |
69 | | class RuntimeState; |
70 | | class TExprNode; |
71 | | } // namespace doris |
72 | | |
73 | | namespace doris { |
74 | | |
75 | | const std::string AGG_STATE_SUFFIX = "_state"; |
76 | | |
77 | | // Now left child is a function call, we need to check if it is a distance function |
78 | | const static std::set<std::string> DISTANCE_FUNCS = {L2DistanceApproximate::name, |
79 | | InnerProductApproximate::name}; |
80 | | const static std::set<TExprOpcode::type> OPS_FOR_ANN_RANGE_SEARCH = { |
81 | | TExprOpcode::GE, TExprOpcode::LE, TExprOpcode::LE, TExprOpcode::GT, TExprOpcode::LT}; |
82 | | |
83 | 159 | VectorizedFnCall::VectorizedFnCall(const TExprNode& node) : VExpr(node) { |
84 | 159 | _function_name = _fn.name.function_name; |
85 | 159 | } |
86 | | |
87 | | Status VectorizedFnCall::prepare(RuntimeState* state, const RowDescriptor& desc, |
88 | 136 | VExprContext* context) { |
89 | 136 | RETURN_IF_ERROR_OR_PREPARED(VExpr::prepare(state, desc, context)); |
90 | 136 | ColumnsWithTypeAndName argument_template; |
91 | 136 | argument_template.reserve(_children.size()); |
92 | 260 | for (auto child : _children) { |
93 | 260 | if (child->is_literal()) { |
94 | | // For some functions, he needs some literal columns to derive the return type. |
95 | 99 | auto literal_node = std::dynamic_pointer_cast<VLiteral>(child); |
96 | 99 | argument_template.emplace_back(literal_node->get_column_ptr(), child->data_type(), |
97 | 99 | child->expr_name()); |
98 | 161 | } else { |
99 | 161 | argument_template.emplace_back(nullptr, child->data_type(), child->expr_name()); |
100 | 161 | } |
101 | 260 | } |
102 | | |
103 | 136 | _expr_name = fmt::format("VectorizedFnCall[{}](arguments={},return={})", _fn.name.function_name, |
104 | 136 | get_child_names(), _data_type->get_name()); |
105 | 136 | if (_fn.binary_type == TFunctionBinaryType::RPC) { |
106 | 0 | _function = FunctionRPC::create(_fn, argument_template, _data_type); |
107 | 136 | } else if (_fn.binary_type == TFunctionBinaryType::JAVA_UDF) { |
108 | 0 | if (config::enable_java_support) { |
109 | 0 | if (_fn.is_udtf_function) { |
110 | | // fake function. it's no use and can't execute. |
111 | 0 | auto builder = |
112 | 0 | std::make_shared<DefaultFunctionBuilder>(FunctionFake<UDTFImpl>::create()); |
113 | 0 | _function = builder->build(argument_template, std::make_shared<DataTypeUInt8>()); |
114 | 0 | } else { |
115 | 0 | _function = JavaFunctionCall::create(_fn, argument_template, _data_type); |
116 | 0 | } |
117 | 0 | } else { |
118 | 0 | return Status::InternalError( |
119 | 0 | "Java UDF is not enabled, you can change be config enable_java_support to true " |
120 | 0 | "and restart be."); |
121 | 0 | } |
122 | 136 | } else if (_fn.binary_type == TFunctionBinaryType::PYTHON_UDF) { |
123 | 0 | if (config::enable_python_udf_support) { |
124 | 0 | if (_fn.is_udtf_function) { |
125 | | // fake function. it's no use and can't execute. |
126 | | // Python UDTF is executed via PythonUDTFFunction in table function path |
127 | 0 | auto builder = |
128 | 0 | std::make_shared<DefaultFunctionBuilder>(FunctionFake<UDTFImpl>::create()); |
129 | 0 | _function = builder->build(argument_template, std::make_shared<DataTypeUInt8>()); |
130 | 0 | } else { |
131 | 0 | _function = PythonFunctionCall::create(_fn, argument_template, _data_type); |
132 | 0 | LOG(INFO) << fmt::format( |
133 | 0 | "create python function call: {}, runtime version: {}, function code: {}", |
134 | 0 | _fn.name.function_name, _fn.runtime_version, _fn.function_code); |
135 | 0 | } |
136 | 0 | } else { |
137 | 0 | return Status::InternalError( |
138 | 0 | "Python UDF is not enabled, you can change be config enable_python_udf_support " |
139 | 0 | "to true and restart be."); |
140 | 0 | } |
141 | 136 | } else if (_fn.binary_type == TFunctionBinaryType::AGG_STATE) { |
142 | 0 | DataTypes argument_types; |
143 | 0 | for (auto column : argument_template) { |
144 | 0 | argument_types.emplace_back(column.type); |
145 | 0 | } |
146 | |
|
147 | 0 | if (match_suffix(_fn.name.function_name, AGG_STATE_SUFFIX)) { |
148 | 0 | if (_data_type->is_nullable()) { |
149 | 0 | return Status::InternalError("State function's return type must be not nullable"); |
150 | 0 | } |
151 | 0 | if (_data_type->get_primitive_type() != PrimitiveType::TYPE_AGG_STATE) { |
152 | 0 | return Status::InternalError( |
153 | 0 | "State function's return type must be agg_state but get {}", |
154 | 0 | _data_type->get_family_name()); |
155 | 0 | } |
156 | 0 | _function = FunctionAggState::create( |
157 | 0 | argument_types, _data_type, |
158 | 0 | assert_cast<const DataTypeAggState*>(_data_type.get())->get_nested_function()); |
159 | 0 | } else { |
160 | 0 | return Status::InternalError("Function {} is not endwith '_state'", _fn.signature); |
161 | 0 | } |
162 | 136 | } else { |
163 | | // get the function. won't prepare function. |
164 | 136 | _function = SimpleFunctionFactory::instance().get_function( |
165 | 136 | _fn.name.function_name, argument_template, _data_type, |
166 | 136 | {.new_version_unix_timestamp = state->query_options().new_version_unix_timestamp}, |
167 | 136 | state->be_exec_version()); |
168 | 136 | } |
169 | 136 | if (_function == nullptr) { |
170 | 0 | return Status::InternalError("Could not find function {}, arg {} return {} ", |
171 | 0 | _fn.name.function_name, get_child_type_names(), |
172 | 0 | _data_type->get_name()); |
173 | 0 | } |
174 | 136 | VExpr::register_function_context(state, context); |
175 | 136 | _function_name = _fn.name.function_name; |
176 | 136 | _prepare_finished = true; |
177 | | |
178 | 136 | FunctionContext* fn_ctx = context->fn_context(_fn_context_index); |
179 | 136 | if (fn().__isset.dict_function) { |
180 | 0 | fn_ctx->set_dict_function(fn().dict_function); |
181 | 0 | } |
182 | 136 | return Status::OK(); |
183 | 136 | } |
184 | | |
185 | | Status VectorizedFnCall::open(RuntimeState* state, VExprContext* context, |
186 | 113 | FunctionContext::FunctionStateScope scope) { |
187 | 113 | DCHECK(_prepare_finished); |
188 | 214 | for (auto& i : _children) { |
189 | 214 | RETURN_IF_ERROR(i->open(state, context, scope)); |
190 | 214 | } |
191 | 113 | RETURN_IF_ERROR(VExpr::init_function_context(state, context, scope, _function)); |
192 | 113 | if (scope == FunctionContext::FRAGMENT_LOCAL) { |
193 | 90 | RETURN_IF_ERROR(VExpr::get_const_col(context, nullptr)); |
194 | 90 | } |
195 | 113 | _open_finished = true; |
196 | 113 | return Status::OK(); |
197 | 113 | } |
198 | | |
199 | 208 | void VectorizedFnCall::close(VExprContext* context, FunctionContext::FunctionStateScope scope) { |
200 | 208 | VExpr::close_function_context(context, scope, _function); |
201 | 208 | VExpr::close(context, scope); |
202 | 208 | } |
203 | | |
204 | 0 | Status VectorizedFnCall::evaluate_inverted_index(VExprContext* context, uint32_t segment_num_rows) { |
205 | 0 | if (get_num_children() < 1) { |
206 | | // score() and similar 0-children virtual column functions don't need |
207 | | // inverted index evaluation; return OK to skip gracefully. |
208 | 0 | return Status::OK(); |
209 | 0 | } |
210 | 0 | return _evaluate_inverted_index(context, _function, segment_num_rows); |
211 | 0 | } |
212 | | |
213 | | Status VectorizedFnCall::_do_execute(VExprContext* context, const Block* block, |
214 | | const Selector* selector, size_t count, |
215 | 106 | ColumnPtr& result_column, ColumnPtr* arg_column) const { |
216 | 106 | if (is_const_and_have_executed()) { // const have executed in open function |
217 | 0 | result_column = get_result_from_const(count); |
218 | 0 | return Status::OK(); |
219 | 0 | } |
220 | 106 | if (fast_execute(context, selector, count, result_column)) { |
221 | 0 | return Status::OK(); |
222 | 0 | } |
223 | 106 | DBUG_EXECUTE_IF("VectorizedFnCall.must_in_slow_path", { |
224 | 106 | if (get_child(0)->is_slot_ref()) { |
225 | 106 | auto debug_col_name = DebugPoints::instance()->get_debug_param_or_default<std::string>( |
226 | 106 | "VectorizedFnCall.must_in_slow_path", "column_name", ""); |
227 | | |
228 | 106 | std::vector<std::string> column_names; |
229 | 106 | boost::split(column_names, debug_col_name, boost::algorithm::is_any_of(",")); |
230 | | |
231 | 106 | auto* column_slot_ref = assert_cast<VSlotRef*>(get_child(0).get()); |
232 | 106 | std::string column_name = column_slot_ref->expr_name(); |
233 | 106 | auto it = std::find(column_names.begin(), column_names.end(), column_name); |
234 | 106 | if (it == column_names.end()) { |
235 | 106 | return Status::Error<ErrorCode::INTERNAL_ERROR>( |
236 | 106 | "column {} should in slow path while VectorizedFnCall::execute.", |
237 | 106 | column_name); |
238 | 106 | } |
239 | 106 | } |
240 | 106 | }) |
241 | 106 | DCHECK(_open_finished || block == nullptr) << debug_string(); |
242 | | |
243 | 106 | Block temp_block; |
244 | 106 | ColumnNumbers args(_children.size()); |
245 | | |
246 | 318 | for (int i = 0; i < _children.size(); ++i) { |
247 | 212 | ColumnPtr tmp_arg_column; |
248 | 212 | RETURN_IF_ERROR( |
249 | 212 | _children[i]->execute_column(context, block, selector, count, tmp_arg_column)); |
250 | 212 | auto arg_type = _children[i]->execute_type(block); |
251 | 212 | temp_block.insert({tmp_arg_column, arg_type, _children[i]->expr_name()}); |
252 | 212 | args[i] = i; |
253 | | |
254 | 212 | if (arg_column != nullptr && i == 0) { |
255 | 0 | *arg_column = tmp_arg_column; |
256 | 0 | } |
257 | 212 | } |
258 | | |
259 | 106 | uint32_t num_columns_without_result = temp_block.columns(); |
260 | | // prepare a column to save result |
261 | 106 | temp_block.insert({nullptr, _data_type, _expr_name}); |
262 | | |
263 | 106 | DBUG_EXECUTE_IF("VectorizedFnCall.wait_before_execute", { |
264 | 106 | auto possibility = DebugPoints::instance()->get_debug_param_or_default<double>( |
265 | 106 | "VectorizedFnCall.wait_before_execute", "possibility", 0); |
266 | 106 | if (random_bool_slow(possibility)) { |
267 | 106 | LOG(WARNING) << "VectorizedFnCall::execute sleep 30s"; |
268 | 106 | sleep(30); |
269 | 106 | } |
270 | 106 | }); |
271 | | |
272 | 106 | RETURN_IF_ERROR(_function->execute(context->fn_context(_fn_context_index), temp_block, args, |
273 | 106 | num_columns_without_result, count)); |
274 | 106 | result_column = temp_block.get_by_position(num_columns_without_result).column; |
275 | 106 | DCHECK_EQ(result_column->size(), count); |
276 | 106 | RETURN_IF_ERROR(result_column->column_self_check()); |
277 | 106 | return Status::OK(); |
278 | 106 | } |
279 | | |
280 | 0 | size_t VectorizedFnCall::estimate_memory(const size_t rows) { |
281 | 0 | if (is_const_and_have_executed()) { // const have execute in open function |
282 | 0 | return 0; |
283 | 0 | } |
284 | | |
285 | 0 | size_t estimate_size = 0; |
286 | 0 | for (auto& child : _children) { |
287 | 0 | estimate_size += child->estimate_memory(rows); |
288 | 0 | } |
289 | |
|
290 | 0 | if (_data_type->have_maximum_size_of_value()) { |
291 | 0 | estimate_size += rows * _data_type->get_size_of_value_in_memory(); |
292 | 0 | } else { |
293 | 0 | estimate_size += rows * 512; /// FIXME: estimated value... |
294 | 0 | } |
295 | 0 | return estimate_size; |
296 | 0 | } |
297 | | |
298 | | Status VectorizedFnCall::execute_runtime_filter(VExprContext* context, const Block* block, |
299 | | const uint8_t* __restrict filter, size_t count, |
300 | | ColumnPtr& result_column, |
301 | 0 | ColumnPtr* arg_column) const { |
302 | 0 | return _do_execute(context, block, nullptr, count, result_column, arg_column); |
303 | 0 | } |
304 | | |
305 | | Status VectorizedFnCall::execute_column_impl(VExprContext* context, const Block* block, |
306 | | const Selector* selector, size_t count, |
307 | 106 | ColumnPtr& result_column) const { |
308 | 106 | return _do_execute(context, block, selector, count, result_column, nullptr); |
309 | 106 | } |
310 | | |
311 | 60 | const std::string& VectorizedFnCall::expr_name() const { |
312 | 60 | return _expr_name; |
313 | 60 | } |
314 | | |
315 | 6 | std::string VectorizedFnCall::function_name() const { |
316 | 6 | return _function_name; |
317 | 6 | } |
318 | | |
319 | 2 | std::string VectorizedFnCall::debug_string() const { |
320 | 2 | std::stringstream out; |
321 | 2 | out << "VectorizedFn["; |
322 | 2 | out << _expr_name; |
323 | 2 | out << "]{"; |
324 | 2 | bool first = true; |
325 | 2 | for (const auto& input_expr : children()) { |
326 | 0 | if (first) { |
327 | 0 | first = false; |
328 | 0 | } else { |
329 | 0 | out << ","; |
330 | 0 | } |
331 | 0 | out << "\n" << input_expr->debug_string(); |
332 | 0 | } |
333 | 2 | out << "}"; |
334 | 2 | return out.str(); |
335 | 2 | } |
336 | | |
337 | 0 | std::string VectorizedFnCall::debug_string(const std::vector<VectorizedFnCall*>& agg_fns) { |
338 | 0 | std::stringstream out; |
339 | 0 | out << "["; |
340 | 0 | for (int i = 0; i < agg_fns.size(); ++i) { |
341 | 0 | out << (i == 0 ? "" : " ") << agg_fns[i]->debug_string(); |
342 | 0 | } |
343 | 0 | out << "]"; |
344 | 0 | return out.str(); |
345 | 0 | } |
346 | | |
347 | 0 | bool VectorizedFnCall::can_push_down_to_index() const { |
348 | 0 | return _function->can_push_down_to_index(); |
349 | 0 | } |
350 | | |
351 | 0 | bool VectorizedFnCall::equals(const VExpr& other) { |
352 | 0 | const auto* other_ptr = dynamic_cast<const VectorizedFnCall*>(&other); |
353 | 0 | if (!other_ptr) { |
354 | 0 | return false; |
355 | 0 | } |
356 | 0 | if (this->_function_name != other_ptr->_function_name) { |
357 | 0 | return false; |
358 | 0 | } |
359 | 0 | if (get_num_children() != other_ptr->get_num_children()) { |
360 | 0 | return false; |
361 | 0 | } |
362 | 0 | for (uint16_t i = 0; i < get_num_children(); i++) { |
363 | 0 | if (!this->get_child(i)->equals(*other_ptr->get_child(i))) { |
364 | 0 | return false; |
365 | 0 | } |
366 | 0 | } |
367 | 0 | return true; |
368 | 0 | } |
369 | | |
370 | | /* |
371 | | * For ANN range search we expect a comparison expression (LE/LT/GE/GT) whose left side is either: |
372 | | * 1) a vector distance function call, or |
373 | | * 2) a cast/virtual slot that unwraps to the function call when the planner promotes float to |
374 | | * double literals. |
375 | | * |
376 | | * Visually the logical tree looks like: |
377 | | * |
378 | | * FunctionCall(LE/LT/GE/GT) |
379 | | * |---------------- |
380 | | * | | |
381 | | * | | |
382 | | * VirtualSlotRef* Float32Literal/Float64Literal |
383 | | * | |
384 | | * | |
385 | | * Cast(Float -> Double)* |
386 | | * | |
387 | | * FunctionCall(distance) |
388 | | * |---------------- |
389 | | * | | |
390 | | * | | |
391 | | * SlotRef ArrayLiteral/Cast(String as Array<FLOAT>) |
392 | | * |
393 | | * Items marked with * are optional and depend on literal types/virtual column usage. The helper |
394 | | * below normalizes the shape and validates distance function, slot, and constant vector inputs. |
395 | | */ |
396 | | |
397 | | void VectorizedFnCall::prepare_ann_range_search( |
398 | | const doris::VectorSearchUserParams& user_params, |
399 | 7 | segment_v2::AnnRangeSearchRuntime& range_search_runtime, bool& suitable_for_ann_index) { |
400 | 7 | if (!suitable_for_ann_index) { |
401 | 0 | return; |
402 | 0 | } |
403 | | |
404 | 7 | if (OPS_FOR_ANN_RANGE_SEARCH.find(this->op()) == OPS_FOR_ANN_RANGE_SEARCH.end()) { |
405 | 0 | suitable_for_ann_index = false; |
406 | 0 | return; |
407 | 0 | } |
408 | | |
409 | 7 | auto mark_unsuitable = [&](const std::string& reason) { |
410 | 1 | suitable_for_ann_index = false; |
411 | 1 | VLOG_DEBUG << "ANN range search skipped: " << reason; |
412 | 1 | }; |
413 | | |
414 | 7 | range_search_runtime.is_le_or_lt = |
415 | 7 | (this->op() == TExprOpcode::LE || this->op() == TExprOpcode::LT); |
416 | | |
417 | 7 | DCHECK(_children.size() == 2); |
418 | | |
419 | 7 | auto left_child = get_child(0); |
420 | 7 | auto right_child = get_child(1); |
421 | | |
422 | | // ========== Step 1: Check left child - must be a distance function ========== |
423 | 7 | auto get_virtual_expr = [&](const VExprSPtr& expr, |
424 | 7 | std::shared_ptr<VirtualSlotRef>& slot_ref) -> VExprSPtr { |
425 | 7 | auto virtual_ref = std::dynamic_pointer_cast<VirtualSlotRef>(expr); |
426 | 7 | if (virtual_ref != nullptr) { |
427 | 7 | DCHECK(virtual_ref->get_virtual_column_expr() != nullptr); |
428 | 7 | slot_ref = virtual_ref; |
429 | 7 | return virtual_ref->get_virtual_column_expr(); |
430 | 7 | } |
431 | 0 | return expr; |
432 | 7 | }; |
433 | | |
434 | 7 | std::shared_ptr<VirtualSlotRef> vir_slot_ref; |
435 | 7 | auto normalized_left = get_virtual_expr(left_child, vir_slot_ref); |
436 | | |
437 | | // Try to find the distance function call, it may be wrapped in a Cast(Float->Double) |
438 | 7 | std::shared_ptr<VectorizedFnCall> function_call = |
439 | 7 | std::dynamic_pointer_cast<VectorizedFnCall>(normalized_left); |
440 | 7 | bool has_float_to_double_cast = false; |
441 | | |
442 | 7 | if (function_call == nullptr) { |
443 | | // Check if it's a Cast expression wrapping a function call |
444 | 0 | auto cast_expr = std::dynamic_pointer_cast<VCastExpr>(normalized_left); |
445 | 0 | if (cast_expr == nullptr) { |
446 | 0 | mark_unsuitable("Left child is neither a function call nor a cast expression."); |
447 | 0 | return; |
448 | 0 | } |
449 | 0 | has_float_to_double_cast = true; |
450 | 0 | auto normalized_cast_child = get_virtual_expr(cast_expr->get_child(0), vir_slot_ref); |
451 | 0 | function_call = std::dynamic_pointer_cast<VectorizedFnCall>(normalized_cast_child); |
452 | 0 | if (function_call == nullptr) { |
453 | 0 | mark_unsuitable("Left child of cast is not a function call."); |
454 | 0 | return; |
455 | 0 | } |
456 | 0 | } |
457 | | |
458 | | // Check if it's a supported distance function |
459 | 7 | if (DISTANCE_FUNCS.find(function_call->_function_name) == DISTANCE_FUNCS.end()) { |
460 | 0 | mark_unsuitable(fmt::format("Left child is not a supported distance function: {}", |
461 | 0 | function_call->_function_name)); |
462 | 0 | return; |
463 | 0 | } |
464 | | |
465 | | // Strip the _approximate suffix to get metric type |
466 | 7 | std::string metric_name = function_call->_function_name; |
467 | 7 | metric_name = metric_name.substr(0, metric_name.size() - 12); |
468 | 7 | range_search_runtime.metric_type = segment_v2::string_to_metric(metric_name); |
469 | | |
470 | | // ========== Step 2: Validate distance function arguments ========== |
471 | | // Identify the slot ref child and the constant query array child (ArrayLiteral or CAST to array) |
472 | 7 | Int32 idx_of_slot_ref = -1; |
473 | 7 | Int32 idx_of_array_expr = -1; |
474 | 14 | auto classify_child = [&](const VExprSPtr& child, UInt16 index) { |
475 | 14 | if (idx_of_slot_ref == -1 && std::dynamic_pointer_cast<VSlotRef>(child) != nullptr) { |
476 | 7 | idx_of_slot_ref = index; |
477 | 7 | return; |
478 | 7 | } |
479 | 7 | if (idx_of_array_expr == -1 && |
480 | 7 | (std::dynamic_pointer_cast<VArrayLiteral>(child) != nullptr || |
481 | 7 | std::dynamic_pointer_cast<VCastExpr>(child) != nullptr)) { |
482 | 7 | idx_of_array_expr = index; |
483 | 7 | } |
484 | 7 | }; |
485 | | |
486 | 21 | for (UInt16 i = 0; i < function_call->get_num_children(); ++i) { |
487 | 14 | classify_child(function_call->get_child(i), i); |
488 | 14 | } |
489 | | |
490 | 7 | if (idx_of_slot_ref == -1 || idx_of_array_expr == -1) { |
491 | 0 | mark_unsuitable("slot ref or array literal/cast is missing."); |
492 | 0 | return; |
493 | 0 | } |
494 | | |
495 | 7 | auto slot_ref = std::dynamic_pointer_cast<VSlotRef>( |
496 | 7 | function_call->get_child(static_cast<UInt16>(idx_of_slot_ref))); |
497 | 7 | range_search_runtime.src_col_idx = slot_ref->column_id(); |
498 | 7 | range_search_runtime.dst_col_idx = vir_slot_ref == nullptr ? -1 : vir_slot_ref->column_id(); |
499 | | |
500 | | // Materialize the constant array expression and validate its shape and types |
501 | 7 | auto array_expr = function_call->get_child(static_cast<UInt16>(idx_of_array_expr)); |
502 | 7 | auto extract_result = extract_query_vector(array_expr); |
503 | 7 | if (!extract_result.has_value()) { |
504 | 0 | mark_unsuitable("Failed to extract query vector from constant array expression."); |
505 | 0 | return; |
506 | 0 | } |
507 | 7 | range_search_runtime.query_value = extract_result.value(); |
508 | 7 | range_search_runtime.dim = range_search_runtime.query_value->size(); |
509 | | |
510 | | // ========== Step 3: Check right child - must be a float/double literal ========== |
511 | 7 | auto right_literal = std::dynamic_pointer_cast<VLiteral>(right_child); |
512 | 7 | if (right_literal == nullptr) { |
513 | 1 | mark_unsuitable("Right child is not a literal."); |
514 | 1 | return; |
515 | 1 | } |
516 | | |
517 | | // Handle nullable literal gracefully - just mark as unsuitable instead of crash |
518 | 6 | if (right_literal->is_nullable()) { |
519 | 0 | mark_unsuitable("Right literal is nullable, not supported for ANN range search."); |
520 | 0 | return; |
521 | 0 | } |
522 | | |
523 | 6 | auto right_type = right_literal->get_data_type(); |
524 | 6 | PrimitiveType right_primitive = right_type->get_primitive_type(); |
525 | 6 | const bool float32_literal = right_primitive == PrimitiveType::TYPE_FLOAT; |
526 | 6 | const bool float64_literal = right_primitive == PrimitiveType::TYPE_DOUBLE; |
527 | | |
528 | 6 | if (!float32_literal && !float64_literal) { |
529 | 0 | mark_unsuitable("Right child is not a Float32Literal or Float64Literal."); |
530 | 0 | return; |
531 | 0 | } |
532 | | |
533 | | // Validate consistency: if we have Cast(Float->Double), right must be double literal |
534 | 6 | if (has_float_to_double_cast && !float64_literal) { |
535 | 0 | mark_unsuitable("Cast expression expects double literal on right side."); |
536 | 0 | return; |
537 | 0 | } |
538 | | |
539 | | // Extract radius value |
540 | 6 | auto right_col = right_literal->get_column_ptr()->convert_to_full_column_if_const(); |
541 | 6 | if (float32_literal) { |
542 | 6 | const ColumnFloat32* cf32_right = assert_cast<const ColumnFloat32*>(right_col.get()); |
543 | 6 | range_search_runtime.radius = cf32_right->get_data()[0]; |
544 | 6 | } else { |
545 | 0 | const ColumnFloat64* cf64_right = assert_cast<const ColumnFloat64*>(right_col.get()); |
546 | 0 | range_search_runtime.radius = static_cast<float>(cf64_right->get_data()[0]); |
547 | 0 | } |
548 | | |
549 | | // ========== Done: Mark as suitable for ANN range search ========== |
550 | 6 | range_search_runtime.is_ann_range_search = true; |
551 | 6 | range_search_runtime.user_params = user_params; |
552 | 6 | VLOG_DEBUG << fmt::format("Ann range search params: {}", range_search_runtime.to_string()); |
553 | 6 | return; |
554 | 6 | } |
555 | | |
556 | | Status VectorizedFnCall::evaluate_ann_range_search( |
557 | | const segment_v2::AnnRangeSearchRuntime& range_search_runtime, |
558 | | const std::vector<std::unique_ptr<segment_v2::IndexIterator>>& cid_to_index_iterators, |
559 | | const std::vector<ColumnId>& idx_to_cid, |
560 | | const std::vector<std::unique_ptr<segment_v2::ColumnIterator>>& column_iterators, |
561 | | size_t rows_of_segment, roaring::Roaring& row_bitmap, |
562 | | segment_v2::AnnIndexStats& ann_index_stats, bool enable_result_cache, |
563 | 6 | AnnRangeSearchEvaluationResult& evaluation_result) { |
564 | 6 | evaluation_result = {}; |
565 | 6 | if (range_search_runtime.is_ann_range_search == false) { |
566 | 0 | return Status::OK(); |
567 | 0 | } |
568 | | |
569 | 6 | VLOG_DEBUG << fmt::format("Try apply ann range search. Local search params: {}", |
570 | 0 | range_search_runtime.to_string()); |
571 | 6 | size_t origin_num = row_bitmap.cardinality(); |
572 | | |
573 | 6 | const auto idx_in_block = range_search_runtime.src_col_idx; |
574 | 6 | DCHECK_LT(idx_in_block, idx_to_cid.size()) |
575 | 0 | << "idx_in_block: " << idx_in_block << ", idx_to_cid.size(): " << idx_to_cid.size(); |
576 | | |
577 | 6 | ColumnId src_col_cid = idx_to_cid[idx_in_block]; |
578 | 6 | DCHECK(src_col_cid < cid_to_index_iterators.size()); |
579 | 6 | segment_v2::IndexIterator* index_iterator = cid_to_index_iterators[src_col_cid].get(); |
580 | 6 | if (index_iterator == nullptr) { |
581 | 1 | VLOG_DEBUG << "ANN range search skipped: " |
582 | 0 | << fmt::format("No index iterator for column cid {}", src_col_cid); |
583 | 1 | ; |
584 | 1 | return Status::OK(); |
585 | 1 | } |
586 | | |
587 | 5 | segment_v2::AnnIndexIterator* ann_index_iterator = |
588 | 5 | dynamic_cast<segment_v2::AnnIndexIterator*>(index_iterator); |
589 | 5 | if (ann_index_iterator == nullptr) { |
590 | 0 | VLOG_DEBUG << "ANN range search skipped: " |
591 | 0 | << fmt::format("Column cid {} has no ANN index iterator", src_col_cid); |
592 | 0 | return Status::OK(); |
593 | 0 | } |
594 | 5 | DCHECK(ann_index_iterator->get_reader(AnnIndexReaderType::ANN) != nullptr) |
595 | 0 | << "Ann index iterator should have reader. Column cid: " << src_col_cid; |
596 | 5 | std::shared_ptr<AnnIndexReader> ann_index_reader = std::dynamic_pointer_cast<AnnIndexReader>( |
597 | 5 | ann_index_iterator->get_reader(segment_v2::AnnIndexReaderType::ANN)); |
598 | 5 | DCHECK(ann_index_reader != nullptr) |
599 | 0 | << "Ann index reader should not be null. Column cid: " << src_col_cid; |
600 | | // Check if metrics type is match. |
601 | 5 | if (ann_index_reader->get_metric_type() != range_search_runtime.metric_type) { |
602 | 0 | VLOG_DEBUG << "ANN range search skipped: " |
603 | 0 | << fmt::format("Metric type mismatch. Index={} Query={}", |
604 | 0 | segment_v2::metric_to_string(ann_index_reader->get_metric_type()), |
605 | 0 | segment_v2::metric_to_string(range_search_runtime.metric_type)); |
606 | 0 | return Status::OK(); |
607 | 0 | } |
608 | | |
609 | | // Check dimension if available (>0) |
610 | 5 | const size_t index_dim = ann_index_reader->get_dimension(); |
611 | 5 | if (index_dim > 0 && index_dim != range_search_runtime.dim) { |
612 | 1 | return Status::InvalidArgument( |
613 | 1 | "Ann range search query dimension {} does not match index dimension {}", |
614 | 1 | range_search_runtime.dim, index_dim); |
615 | 1 | } |
616 | | |
617 | 4 | const auto& user_params = range_search_runtime.user_params; |
618 | 4 | if (user_params.should_fallback_ann_index_by_small_candidate(origin_num, rows_of_segment)) { |
619 | 0 | VLOG_DEBUG << fmt::format( |
620 | 0 | "Ann range search input rows {} reach small candidate threshold, " |
621 | 0 | "rows_of_segment: {}, absolute_threshold: {}, percent_threshold: {}, " |
622 | 0 | "will not use ann index to filter", |
623 | 0 | origin_num, rows_of_segment, user_params.ann_index_candidate_rows_threshold, |
624 | 0 | user_params.ann_index_candidate_rows_percent_threshold); |
625 | 0 | ann_index_stats.fall_back_brute_force_cnt += 1; |
626 | 0 | ann_index_stats.range_fallback_by_small_candidate_cnt += 1; |
627 | 0 | ann_index_stats.range_fallback_small_candidate_rows += origin_num; |
628 | 0 | return Status::OK(); |
629 | 0 | } |
630 | | |
631 | 4 | auto stats = std::make_unique<segment_v2::AnnIndexStats>(); |
632 | | // Track load index timing |
633 | 4 | { |
634 | 4 | SCOPED_TIMER(&(stats->load_index_costs_ns)); |
635 | 4 | if (!ann_index_iterator->try_load_index()) { |
636 | 0 | VLOG_DEBUG << "ANN range search skipped: " |
637 | 0 | << fmt::format("Failed to load ANN index for column cid {}", src_col_cid); |
638 | 0 | ann_index_stats.fall_back_brute_force_cnt += 1; |
639 | 0 | return Status::OK(); |
640 | 0 | } |
641 | 4 | double load_costs_ms = static_cast<double>(stats->load_index_costs_ns.value()) / 1000000.0; |
642 | 4 | DorisMetrics::instance()->ann_index_load_costs_ms->increment( |
643 | 4 | static_cast<int64_t>(load_costs_ms)); |
644 | 4 | } |
645 | | |
646 | 0 | AnnRangeSearchParams params = range_search_runtime.to_range_search_params(); |
647 | | |
648 | 4 | params.roaring = &row_bitmap; |
649 | 4 | params.enable_result_cache = enable_result_cache; |
650 | 4 | DCHECK(params.roaring != nullptr); |
651 | 4 | DCHECK(params.query_value != nullptr); |
652 | 4 | segment_v2::AnnRangeSearchResult result; |
653 | 4 | RETURN_IF_ERROR(ann_index_iterator->range_search(params, range_search_runtime.user_params, |
654 | 4 | &result, stats.get())); |
655 | | |
656 | 4 | #ifndef NDEBUG |
657 | 4 | if (range_search_runtime.is_le_or_lt == false && |
658 | 4 | ann_index_reader->get_metric_type() == AnnIndexMetric::L2) { |
659 | 2 | DCHECK(result.distance == nullptr) << "Should not have distance"; |
660 | 2 | } |
661 | 4 | if (range_search_runtime.is_le_or_lt == true && |
662 | 4 | ann_index_reader->get_metric_type() == AnnIndexMetric::IP) { |
663 | 0 | DCHECK(result.distance == nullptr); |
664 | 0 | } |
665 | 4 | #endif |
666 | 4 | DCHECK(result.roaring != nullptr); |
667 | 4 | row_bitmap = *result.roaring; |
668 | | |
669 | | // Process virtual column |
670 | 4 | bool dist_fulfilled = false; |
671 | 4 | if (range_search_runtime.dst_col_idx >= 0) { |
672 | | // Prepare materialization if we can use result from index. |
673 | | // Typical situation: range search and operator is LE or LT. |
674 | 4 | if (result.distance != nullptr) { |
675 | 2 | DCHECK(result.row_ids != nullptr); |
676 | 2 | ColumnId dst_col_cid = idx_to_cid[range_search_runtime.dst_col_idx]; |
677 | 2 | DCHECK(dst_col_cid < column_iterators.size()); |
678 | 2 | DCHECK(column_iterators[dst_col_cid] != nullptr); |
679 | 2 | segment_v2::ColumnIterator* column_iterator = column_iterators[dst_col_cid].get(); |
680 | 2 | DCHECK(column_iterator != nullptr); |
681 | 2 | segment_v2::VirtualColumnIterator* virtual_column_iterator = |
682 | 2 | dynamic_cast<segment_v2::VirtualColumnIterator*>(column_iterator); |
683 | 2 | DCHECK(virtual_column_iterator != nullptr); |
684 | | // Now convert distance to column |
685 | 2 | size_t size = result.roaring->cardinality(); |
686 | 2 | auto distance_col = ColumnFloat32::create(size); |
687 | 2 | const float* src = result.distance.get(); |
688 | 2 | float* dst = distance_col->get_data().data(); |
689 | 15 | for (size_t i = 0; i < size; ++i) { |
690 | 13 | dst[i] = src[i]; |
691 | 13 | } |
692 | 2 | virtual_column_iterator->prepare_materialization(std::move(distance_col), |
693 | 2 | std::move(result.row_ids)); |
694 | 2 | dist_fulfilled = true; |
695 | 2 | } else { |
696 | | // Whether the ANN index should have produced distance depends on metric and operator: |
697 | | // - L2: distance is produced for LE/LT; not produced for GE/GT |
698 | | // - IP: distance is produced for GE/GT; not produced for LE/LT |
699 | 2 | #ifndef NDEBUG |
700 | 2 | const bool should_have_distance = |
701 | 2 | (range_search_runtime.is_le_or_lt && |
702 | 2 | range_search_runtime.metric_type == AnnIndexMetric::L2) || |
703 | 2 | (!range_search_runtime.is_le_or_lt && |
704 | 2 | range_search_runtime.metric_type == AnnIndexMetric::IP); |
705 | | // If we expected distance but didn't get it, assert in debug to catch logic errors. |
706 | 2 | DCHECK(!should_have_distance) << "Expected distance from ANN index but got none"; |
707 | 2 | #endif |
708 | 2 | } |
709 | 4 | } else { |
710 | | // Dest is not virtual column. |
711 | 0 | dist_fulfilled = true; |
712 | 0 | } |
713 | | |
714 | 4 | evaluation_result.executed = true; |
715 | 4 | evaluation_result.dist_fulfilled = dist_fulfilled; |
716 | 4 | VLOG_DEBUG << fmt::format( |
717 | 0 | "Ann range search filtered {} rows, origin {} rows, virtual column is full-filled: {}", |
718 | 0 | origin_num - row_bitmap.cardinality(), origin_num, dist_fulfilled); |
719 | | |
720 | 4 | ann_index_stats = *stats; |
721 | 4 | return Status::OK(); |
722 | 4 | } |
723 | | |
724 | 8 | double VectorizedFnCall::execute_cost() const { |
725 | 8 | if (!_function) { |
726 | 0 | throw Exception( |
727 | 0 | Status::InternalError("Function is null in expression: {}", this->debug_string())); |
728 | 0 | } |
729 | 8 | double cost = _function->execute_cost(); |
730 | 16 | for (const auto& child : _children) { |
731 | 16 | cost += child->execute_cost(); |
732 | 16 | } |
733 | 8 | return cost; |
734 | 8 | } |
735 | | |
736 | | } // namespace doris |