Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #include "exprs/vectorized_fn_call.h" |
19 | | |
20 | | #include <fmt/compile.h> |
21 | | #include <fmt/format.h> |
22 | | #include <fmt/ranges.h> // IWYU pragma: keep |
23 | | #include <gen_cpp/Opcodes_types.h> |
24 | | #include <gen_cpp/Types_types.h> |
25 | | |
26 | | #include <memory> |
27 | | #include <ostream> |
28 | | |
29 | | #include "common/config.h" |
30 | | #include "common/exception.h" |
31 | | #include "common/logging.h" |
32 | | #include "common/status.h" |
33 | | #include "common/utils.h" |
34 | | #include "core/assert_cast.h" |
35 | | #include "core/block/block.h" |
36 | | #include "core/block/column_numbers.h" |
37 | | #include "core/column/column.h" |
38 | | #include "core/column/column_array.h" |
39 | | #include "core/column/column_nullable.h" |
40 | | #include "core/column/column_vector.h" |
41 | | #include "core/data_type/data_type.h" |
42 | | #include "core/data_type/data_type_agg_state.h" |
43 | | #include "core/types.h" |
44 | | #include "exec/common/util.hpp" |
45 | | #include "exec/pipeline/pipeline_task.h" |
46 | | #include "exprs/function/array/function_array_distance.h" |
47 | | #include "exprs/function/function_agg_state.h" |
48 | | #include "exprs/function/function_fake.h" |
49 | | #include "exprs/function/function_java_udf.h" |
50 | | #include "exprs/function/function_python_udf.h" |
51 | | #include "exprs/function/function_rpc.h" |
52 | | #include "exprs/function/simple_function_factory.h" |
53 | | #include "exprs/function_context.h" |
54 | | #include "exprs/varray_literal.h" |
55 | | #include "exprs/vcast_expr.h" |
56 | | #include "exprs/vexpr_context.h" |
57 | | #include "exprs/virtual_slot_ref.h" |
58 | | #include "exprs/vliteral.h" |
59 | | #include "runtime/runtime_state.h" |
60 | | #include "storage/index/ann/ann_index.h" |
61 | | #include "storage/index/ann/ann_index_iterator.h" |
62 | | #include "storage/index/ann/ann_search_params.h" |
63 | | #include "storage/index/index_reader.h" |
64 | | #include "storage/index/zone_map/zonemap_eval_context.h" |
65 | | #include "storage/segment/column_reader.h" |
66 | | #include "storage/segment/virtual_column_iterator.h" |
67 | | |
68 | | namespace doris { |
69 | | class RowDescriptor; |
70 | | class RuntimeState; |
71 | | class TExprNode; |
72 | | } // namespace doris |
73 | | |
74 | | namespace doris { |
75 | | |
76 | | const std::string AGG_STATE_SUFFIX = "_state"; |
77 | | |
78 | | // Now left child is a function call, we need to check if it is a distance function |
79 | | const static std::set<std::string> DISTANCE_FUNCS = {L2DistanceApproximate::name, |
80 | | InnerProductApproximate::name}; |
81 | | const static std::set<TExprOpcode::type> OPS_FOR_ANN_RANGE_SEARCH = { |
82 | | TExprOpcode::GE, TExprOpcode::LE, TExprOpcode::LE, TExprOpcode::GT, TExprOpcode::LT}; |
83 | | |
84 | 140 | VectorizedFnCall::VectorizedFnCall(const TExprNode& node) : VExpr(node) {} |
85 | | |
86 | | Status VectorizedFnCall::prepare(RuntimeState* state, const RowDescriptor& desc, |
87 | 101 | VExprContext* context) { |
88 | 101 | RETURN_IF_ERROR_OR_PREPARED(VExpr::prepare(state, desc, context)); |
89 | 101 | ColumnsWithTypeAndName argument_template; |
90 | 101 | argument_template.reserve(_children.size()); |
91 | 190 | for (auto child : _children) { |
92 | 190 | if (child->is_literal()) { |
93 | | // For some functions, he needs some literal columns to derive the return type. |
94 | 66 | auto literal_node = std::dynamic_pointer_cast<VLiteral>(child); |
95 | 66 | argument_template.emplace_back(literal_node->get_column_ptr(), child->data_type(), |
96 | 66 | child->expr_name()); |
97 | 124 | } else { |
98 | 124 | argument_template.emplace_back(nullptr, child->data_type(), child->expr_name()); |
99 | 124 | } |
100 | 190 | } |
101 | | |
102 | 101 | _expr_name = fmt::format("VectorizedFnCall[{}](arguments={},return={})", _fn.name.function_name, |
103 | 101 | get_child_names(), _data_type->get_name()); |
104 | 101 | if (_fn.binary_type == TFunctionBinaryType::RPC) { |
105 | 0 | _function = FunctionRPC::create(_fn, argument_template, _data_type); |
106 | 101 | } else if (_fn.binary_type == TFunctionBinaryType::JAVA_UDF) { |
107 | 0 | if (config::enable_java_support) { |
108 | 0 | if (_fn.is_udtf_function) { |
109 | | // fake function. it's no use and can't execute. |
110 | 0 | auto builder = |
111 | 0 | std::make_shared<DefaultFunctionBuilder>(FunctionFake<UDTFImpl>::create()); |
112 | 0 | _function = builder->build(argument_template, std::make_shared<DataTypeUInt8>()); |
113 | 0 | } else { |
114 | 0 | _function = JavaFunctionCall::create(_fn, argument_template, _data_type); |
115 | 0 | } |
116 | 0 | } else { |
117 | 0 | return Status::InternalError( |
118 | 0 | "Java UDF is not enabled, you can change be config enable_java_support to true " |
119 | 0 | "and restart be."); |
120 | 0 | } |
121 | 101 | } else if (_fn.binary_type == TFunctionBinaryType::PYTHON_UDF) { |
122 | 0 | if (config::enable_python_udf_support) { |
123 | 0 | if (_fn.is_udtf_function) { |
124 | | // fake function. it's no use and can't execute. |
125 | | // Python UDTF is executed via PythonUDTFFunction in table function path |
126 | 0 | auto builder = |
127 | 0 | std::make_shared<DefaultFunctionBuilder>(FunctionFake<UDTFImpl>::create()); |
128 | 0 | _function = builder->build(argument_template, std::make_shared<DataTypeUInt8>()); |
129 | 0 | } else { |
130 | 0 | _function = PythonFunctionCall::create(_fn, argument_template, _data_type); |
131 | 0 | LOG(INFO) << fmt::format( |
132 | 0 | "create python function call: {}, runtime version: {}, function code: {}", |
133 | 0 | _fn.name.function_name, _fn.runtime_version, _fn.function_code); |
134 | 0 | } |
135 | 0 | } else { |
136 | 0 | return Status::InternalError( |
137 | 0 | "Python UDF is not enabled, you can change be config enable_python_udf_support " |
138 | 0 | "to true and restart be."); |
139 | 0 | } |
140 | 101 | } else if (_fn.binary_type == TFunctionBinaryType::AGG_STATE) { |
141 | 0 | DataTypes argument_types; |
142 | 0 | for (auto column : argument_template) { |
143 | 0 | argument_types.emplace_back(column.type); |
144 | 0 | } |
145 | |
|
146 | 0 | if (match_suffix(_fn.name.function_name, AGG_STATE_SUFFIX)) { |
147 | 0 | if (_data_type->is_nullable()) { |
148 | 0 | return Status::InternalError("State function's return type must be not nullable"); |
149 | 0 | } |
150 | 0 | if (_data_type->get_primitive_type() != PrimitiveType::TYPE_AGG_STATE) { |
151 | 0 | return Status::InternalError( |
152 | 0 | "State function's return type must be agg_state but get {}", |
153 | 0 | _data_type->get_family_name()); |
154 | 0 | } |
155 | 0 | _function = FunctionAggState::create( |
156 | 0 | argument_types, _data_type, |
157 | 0 | assert_cast<const DataTypeAggState*>(_data_type.get())->get_nested_function()); |
158 | 0 | } else { |
159 | 0 | return Status::InternalError("Function {} is not endwith '_state'", _fn.signature); |
160 | 0 | } |
161 | 101 | } else { |
162 | | // get the function. won't prepare function. |
163 | 101 | _function = SimpleFunctionFactory::instance().get_function( |
164 | 101 | _fn.name.function_name, argument_template, _data_type, |
165 | 101 | {.new_version_unix_timestamp = state->query_options().new_version_unix_timestamp}, |
166 | 101 | state->be_exec_version()); |
167 | 101 | } |
168 | 101 | if (_function == nullptr) { |
169 | 0 | return Status::InternalError("Could not find function {}, arg {} return {} ", |
170 | 0 | _fn.name.function_name, get_child_type_names(), |
171 | 0 | _data_type->get_name()); |
172 | 0 | } |
173 | 101 | VExpr::register_function_context(state, context); |
174 | 101 | _function_name = _fn.name.function_name; |
175 | 101 | _prepare_finished = true; |
176 | | |
177 | 101 | FunctionContext* fn_ctx = context->fn_context(_fn_context_index); |
178 | 101 | if (fn().__isset.dict_function) { |
179 | 0 | fn_ctx->set_dict_function(fn().dict_function); |
180 | 0 | } |
181 | 101 | return Status::OK(); |
182 | 101 | } |
183 | | |
184 | | Status VectorizedFnCall::open(RuntimeState* state, VExprContext* context, |
185 | 78 | FunctionContext::FunctionStateScope scope) { |
186 | 78 | DCHECK(_prepare_finished); |
187 | 144 | for (auto& i : _children) { |
188 | 144 | RETURN_IF_ERROR(i->open(state, context, scope)); |
189 | 144 | } |
190 | 78 | RETURN_IF_ERROR(VExpr::init_function_context(state, context, scope, _function)); |
191 | 78 | if (scope == FunctionContext::FRAGMENT_LOCAL) { |
192 | 55 | RETURN_IF_ERROR(VExpr::get_const_col(context, nullptr)); |
193 | 55 | } |
194 | 78 | _open_finished = true; |
195 | 78 | return Status::OK(); |
196 | 78 | } |
197 | | |
198 | 173 | void VectorizedFnCall::close(VExprContext* context, FunctionContext::FunctionStateScope scope) { |
199 | 173 | VExpr::close_function_context(context, scope, _function); |
200 | 173 | VExpr::close(context, scope); |
201 | 173 | } |
202 | | |
203 | 0 | Status VectorizedFnCall::evaluate_inverted_index(VExprContext* context, uint32_t segment_num_rows) { |
204 | 0 | if (get_num_children() < 1) { |
205 | | // score() and similar 0-children virtual column functions don't need |
206 | | // inverted index evaluation; return OK to skip gracefully. |
207 | 0 | return Status::OK(); |
208 | 0 | } |
209 | 0 | return _evaluate_inverted_index(context, _function, segment_num_rows); |
210 | 0 | } |
211 | | |
212 | 9 | ZoneMapFilterResult VectorizedFnCall::evaluate_zonemap_filter(const ZoneMapEvalContext& ctx) const { |
213 | 9 | return _function->evaluate_zonemap_filter(ctx, _children); |
214 | 9 | } |
215 | | |
216 | 22 | bool VectorizedFnCall::can_evaluate_zonemap_filter() const { |
217 | 22 | return _function != nullptr && !_function->is_blockable() && |
218 | 22 | _function->can_evaluate_zonemap_filter(_children); |
219 | 22 | } |
220 | | |
221 | | Status VectorizedFnCall::_do_execute(VExprContext* context, const Block* block, |
222 | | const Selector* selector, size_t count, |
223 | 67 | ColumnPtr& result_column, ColumnPtr* arg_column) const { |
224 | 67 | if (is_const_and_have_executed()) { // const have executed in open function |
225 | 0 | result_column = get_result_from_const(count); |
226 | 0 | return Status::OK(); |
227 | 0 | } |
228 | 67 | if (fast_execute(context, selector, count, result_column)) { |
229 | 0 | return Status::OK(); |
230 | 0 | } |
231 | 67 | DBUG_EXECUTE_IF("VectorizedFnCall.must_in_slow_path", { |
232 | 67 | if (get_child(0)->is_slot_ref()) { |
233 | 67 | auto debug_col_name = DebugPoints::instance()->get_debug_param_or_default<std::string>( |
234 | 67 | "VectorizedFnCall.must_in_slow_path", "column_name", ""); |
235 | | |
236 | 67 | std::vector<std::string> column_names; |
237 | 67 | boost::split(column_names, debug_col_name, boost::algorithm::is_any_of(",")); |
238 | | |
239 | 67 | auto* column_slot_ref = assert_cast<VSlotRef*>(get_child(0).get()); |
240 | 67 | std::string column_name = column_slot_ref->expr_name(); |
241 | 67 | auto it = std::find(column_names.begin(), column_names.end(), column_name); |
242 | 67 | if (it == column_names.end()) { |
243 | 67 | return Status::Error<ErrorCode::INTERNAL_ERROR>( |
244 | 67 | "column {} should in slow path while VectorizedFnCall::execute.", |
245 | 67 | column_name); |
246 | 67 | } |
247 | 67 | } |
248 | 67 | }) |
249 | 67 | DCHECK(_open_finished || block == nullptr) << debug_string(); |
250 | | |
251 | 67 | Block temp_block; |
252 | 67 | ColumnNumbers args(_children.size()); |
253 | | |
254 | 201 | for (int i = 0; i < _children.size(); ++i) { |
255 | 134 | ColumnPtr tmp_arg_column; |
256 | 134 | RETURN_IF_ERROR( |
257 | 134 | _children[i]->execute_column(context, block, selector, count, tmp_arg_column)); |
258 | 134 | auto arg_type = _children[i]->execute_type(block); |
259 | 134 | temp_block.insert({tmp_arg_column, arg_type, _children[i]->expr_name()}); |
260 | 134 | args[i] = i; |
261 | | |
262 | 134 | if (arg_column != nullptr && i == 0) { |
263 | 0 | *arg_column = tmp_arg_column; |
264 | 0 | } |
265 | 134 | } |
266 | | |
267 | 67 | uint32_t num_columns_without_result = temp_block.columns(); |
268 | | // prepare a column to save result |
269 | 67 | temp_block.insert({nullptr, _data_type, _expr_name}); |
270 | | |
271 | 67 | DBUG_EXECUTE_IF("VectorizedFnCall.wait_before_execute", { |
272 | 67 | auto possibility = DebugPoints::instance()->get_debug_param_or_default<double>( |
273 | 67 | "VectorizedFnCall.wait_before_execute", "possibility", 0); |
274 | 67 | if (random_bool_slow(possibility)) { |
275 | 67 | LOG(WARNING) << "VectorizedFnCall::execute sleep 30s"; |
276 | 67 | sleep(30); |
277 | 67 | } |
278 | 67 | }); |
279 | | |
280 | 67 | RETURN_IF_ERROR(_function->execute(context->fn_context(_fn_context_index), temp_block, args, |
281 | 67 | num_columns_without_result, count)); |
282 | 67 | result_column = temp_block.get_by_position(num_columns_without_result).column; |
283 | 67 | DCHECK_EQ(result_column->size(), count); |
284 | 67 | RETURN_IF_ERROR(result_column->column_self_check()); |
285 | 67 | return Status::OK(); |
286 | 67 | } |
287 | | |
288 | 0 | size_t VectorizedFnCall::estimate_memory(const size_t rows) { |
289 | 0 | if (is_const_and_have_executed()) { // const have execute in open function |
290 | 0 | return 0; |
291 | 0 | } |
292 | | |
293 | 0 | size_t estimate_size = 0; |
294 | 0 | for (auto& child : _children) { |
295 | 0 | estimate_size += child->estimate_memory(rows); |
296 | 0 | } |
297 | |
|
298 | 0 | if (_data_type->have_maximum_size_of_value()) { |
299 | 0 | estimate_size += rows * _data_type->get_size_of_value_in_memory(); |
300 | 0 | } else { |
301 | 0 | estimate_size += rows * 512; /// FIXME: estimated value... |
302 | 0 | } |
303 | 0 | return estimate_size; |
304 | 0 | } |
305 | | |
306 | | Status VectorizedFnCall::execute_runtime_filter(VExprContext* context, const Block* block, |
307 | | const uint8_t* __restrict filter, size_t count, |
308 | | ColumnPtr& result_column, |
309 | 0 | ColumnPtr* arg_column) const { |
310 | 0 | return _do_execute(context, block, nullptr, count, result_column, arg_column); |
311 | 0 | } |
312 | | |
313 | | Status VectorizedFnCall::execute_column_impl(VExprContext* context, const Block* block, |
314 | | const Selector* selector, size_t count, |
315 | 67 | ColumnPtr& result_column) const { |
316 | 67 | return _do_execute(context, block, selector, count, result_column, nullptr); |
317 | 67 | } |
318 | | |
319 | 48 | const std::string& VectorizedFnCall::expr_name() const { |
320 | 48 | return _expr_name; |
321 | 48 | } |
322 | | |
323 | 6 | std::string VectorizedFnCall::function_name() const { |
324 | 6 | return _function_name; |
325 | 6 | } |
326 | | |
327 | 2 | std::string VectorizedFnCall::debug_string() const { |
328 | 2 | std::stringstream out; |
329 | 2 | out << "VectorizedFn["; |
330 | 2 | out << _expr_name; |
331 | 2 | out << "]{"; |
332 | 2 | bool first = true; |
333 | 2 | for (const auto& input_expr : children()) { |
334 | 0 | if (first) { |
335 | 0 | first = false; |
336 | 0 | } else { |
337 | 0 | out << ","; |
338 | 0 | } |
339 | 0 | out << "\n" << input_expr->debug_string(); |
340 | 0 | } |
341 | 2 | out << "}"; |
342 | 2 | return out.str(); |
343 | 2 | } |
344 | | |
345 | 0 | std::string VectorizedFnCall::debug_string(const std::vector<VectorizedFnCall*>& agg_fns) { |
346 | 0 | std::stringstream out; |
347 | 0 | out << "["; |
348 | 0 | for (int i = 0; i < agg_fns.size(); ++i) { |
349 | 0 | out << (i == 0 ? "" : " ") << agg_fns[i]->debug_string(); |
350 | 0 | } |
351 | 0 | out << "]"; |
352 | 0 | return out.str(); |
353 | 0 | } |
354 | | |
355 | 0 | bool VectorizedFnCall::can_push_down_to_index() const { |
356 | 0 | return _function->can_push_down_to_index(); |
357 | 0 | } |
358 | | |
359 | 0 | bool VectorizedFnCall::equals(const VExpr& other) { |
360 | 0 | const auto* other_ptr = dynamic_cast<const VectorizedFnCall*>(&other); |
361 | 0 | if (!other_ptr) { |
362 | 0 | return false; |
363 | 0 | } |
364 | 0 | if (this->_function_name != other_ptr->_function_name) { |
365 | 0 | return false; |
366 | 0 | } |
367 | 0 | if (get_num_children() != other_ptr->get_num_children()) { |
368 | 0 | return false; |
369 | 0 | } |
370 | 0 | for (uint16_t i = 0; i < get_num_children(); i++) { |
371 | 0 | if (!this->get_child(i)->equals(*other_ptr->get_child(i))) { |
372 | 0 | return false; |
373 | 0 | } |
374 | 0 | } |
375 | 0 | return true; |
376 | 0 | } |
377 | | |
378 | | /* |
379 | | * For ANN range search we expect a comparison expression (LE/LT/GE/GT) whose left side is either: |
380 | | * 1) a vector distance function call, or |
381 | | * 2) a cast/virtual slot that unwraps to the function call when the planner promotes float to |
382 | | * double literals. |
383 | | * |
384 | | * Visually the logical tree looks like: |
385 | | * |
386 | | * FunctionCall(LE/LT/GE/GT) |
387 | | * |---------------- |
388 | | * | | |
389 | | * | | |
390 | | * VirtualSlotRef* Float32Literal/Float64Literal |
391 | | * | |
392 | | * | |
393 | | * Cast(Float -> Double)* |
394 | | * | |
395 | | * FunctionCall(distance) |
396 | | * |---------------- |
397 | | * | | |
398 | | * | | |
399 | | * SlotRef ArrayLiteral/Cast(String as Array<FLOAT>) |
400 | | * |
401 | | * Items marked with * are optional and depend on literal types/virtual column usage. The helper |
402 | | * below normalizes the shape and validates distance function, slot, and constant vector inputs. |
403 | | */ |
404 | | |
405 | | void VectorizedFnCall::prepare_ann_range_search( |
406 | | const doris::VectorSearchUserParams& user_params, |
407 | 7 | segment_v2::AnnRangeSearchRuntime& range_search_runtime, bool& suitable_for_ann_index) { |
408 | 7 | if (!suitable_for_ann_index) { |
409 | 0 | return; |
410 | 0 | } |
411 | | |
412 | 7 | if (OPS_FOR_ANN_RANGE_SEARCH.find(this->op()) == OPS_FOR_ANN_RANGE_SEARCH.end()) { |
413 | 0 | suitable_for_ann_index = false; |
414 | 0 | return; |
415 | 0 | } |
416 | | |
417 | 7 | auto mark_unsuitable = [&](const std::string& reason) { |
418 | 1 | suitable_for_ann_index = false; |
419 | 1 | VLOG_DEBUG << "ANN range search skipped: " << reason; |
420 | 1 | }; |
421 | | |
422 | 7 | range_search_runtime.is_le_or_lt = |
423 | 7 | (this->op() == TExprOpcode::LE || this->op() == TExprOpcode::LT); |
424 | | |
425 | 7 | DCHECK(_children.size() == 2); |
426 | | |
427 | 7 | auto left_child = get_child(0); |
428 | 7 | auto right_child = get_child(1); |
429 | | |
430 | | // ========== Step 1: Check left child - must be a distance function ========== |
431 | 7 | auto get_virtual_expr = [&](const VExprSPtr& expr, |
432 | 7 | std::shared_ptr<VirtualSlotRef>& slot_ref) -> VExprSPtr { |
433 | 7 | auto virtual_ref = std::dynamic_pointer_cast<VirtualSlotRef>(expr); |
434 | 7 | if (virtual_ref != nullptr) { |
435 | 7 | DCHECK(virtual_ref->get_virtual_column_expr() != nullptr); |
436 | 7 | slot_ref = virtual_ref; |
437 | 7 | return virtual_ref->get_virtual_column_expr(); |
438 | 7 | } |
439 | 0 | return expr; |
440 | 7 | }; |
441 | | |
442 | 7 | std::shared_ptr<VirtualSlotRef> vir_slot_ref; |
443 | 7 | auto normalized_left = get_virtual_expr(left_child, vir_slot_ref); |
444 | | |
445 | | // Try to find the distance function call, it may be wrapped in a Cast(Float->Double) |
446 | 7 | std::shared_ptr<VectorizedFnCall> function_call = |
447 | 7 | std::dynamic_pointer_cast<VectorizedFnCall>(normalized_left); |
448 | 7 | bool has_float_to_double_cast = false; |
449 | | |
450 | 7 | if (function_call == nullptr) { |
451 | | // Check if it's a Cast expression wrapping a function call |
452 | 0 | auto cast_expr = std::dynamic_pointer_cast<VCastExpr>(normalized_left); |
453 | 0 | if (cast_expr == nullptr) { |
454 | 0 | mark_unsuitable("Left child is neither a function call nor a cast expression."); |
455 | 0 | return; |
456 | 0 | } |
457 | 0 | has_float_to_double_cast = true; |
458 | 0 | auto normalized_cast_child = get_virtual_expr(cast_expr->get_child(0), vir_slot_ref); |
459 | 0 | function_call = std::dynamic_pointer_cast<VectorizedFnCall>(normalized_cast_child); |
460 | 0 | if (function_call == nullptr) { |
461 | 0 | mark_unsuitable("Left child of cast is not a function call."); |
462 | 0 | return; |
463 | 0 | } |
464 | 0 | } |
465 | | |
466 | | // Check if it's a supported distance function |
467 | 7 | if (DISTANCE_FUNCS.find(function_call->_function_name) == DISTANCE_FUNCS.end()) { |
468 | 0 | mark_unsuitable(fmt::format("Left child is not a supported distance function: {}", |
469 | 0 | function_call->_function_name)); |
470 | 0 | return; |
471 | 0 | } |
472 | | |
473 | | // Strip the _approximate suffix to get metric type |
474 | 7 | std::string metric_name = function_call->_function_name; |
475 | 7 | metric_name = metric_name.substr(0, metric_name.size() - 12); |
476 | 7 | range_search_runtime.metric_type = segment_v2::string_to_metric(metric_name); |
477 | | |
478 | | // ========== Step 2: Validate distance function arguments ========== |
479 | | // Identify the slot ref child and the constant query array child (ArrayLiteral or CAST to array) |
480 | 7 | Int32 idx_of_slot_ref = -1; |
481 | 7 | Int32 idx_of_array_expr = -1; |
482 | 14 | auto classify_child = [&](const VExprSPtr& child, UInt16 index) { |
483 | 14 | if (idx_of_slot_ref == -1 && std::dynamic_pointer_cast<VSlotRef>(child) != nullptr) { |
484 | 7 | idx_of_slot_ref = index; |
485 | 7 | return; |
486 | 7 | } |
487 | 7 | if (idx_of_array_expr == -1 && |
488 | 7 | (std::dynamic_pointer_cast<VArrayLiteral>(child) != nullptr || |
489 | 7 | std::dynamic_pointer_cast<VCastExpr>(child) != nullptr)) { |
490 | 7 | idx_of_array_expr = index; |
491 | 7 | } |
492 | 7 | }; |
493 | | |
494 | 21 | for (UInt16 i = 0; i < function_call->get_num_children(); ++i) { |
495 | 14 | classify_child(function_call->get_child(i), i); |
496 | 14 | } |
497 | | |
498 | 7 | if (idx_of_slot_ref == -1 || idx_of_array_expr == -1) { |
499 | 0 | mark_unsuitable("slot ref or array literal/cast is missing."); |
500 | 0 | return; |
501 | 0 | } |
502 | | |
503 | 7 | auto slot_ref = std::dynamic_pointer_cast<VSlotRef>( |
504 | 7 | function_call->get_child(static_cast<UInt16>(idx_of_slot_ref))); |
505 | 7 | range_search_runtime.src_col_idx = slot_ref->column_id(); |
506 | 7 | range_search_runtime.dst_col_idx = vir_slot_ref == nullptr ? -1 : vir_slot_ref->column_id(); |
507 | | |
508 | | // Materialize the constant array expression and validate its shape and types |
509 | 7 | auto array_expr = function_call->get_child(static_cast<UInt16>(idx_of_array_expr)); |
510 | 7 | auto extract_result = extract_query_vector(array_expr); |
511 | 7 | if (!extract_result.has_value()) { |
512 | 0 | mark_unsuitable("Failed to extract query vector from constant array expression."); |
513 | 0 | return; |
514 | 0 | } |
515 | 7 | range_search_runtime.query_value = extract_result.value(); |
516 | 7 | range_search_runtime.dim = range_search_runtime.query_value->size(); |
517 | | |
518 | | // ========== Step 3: Check right child - must be a float/double literal ========== |
519 | 7 | auto right_literal = std::dynamic_pointer_cast<VLiteral>(right_child); |
520 | 7 | if (right_literal == nullptr) { |
521 | 1 | mark_unsuitable("Right child is not a literal."); |
522 | 1 | return; |
523 | 1 | } |
524 | | |
525 | | // Handle nullable literal gracefully - just mark as unsuitable instead of crash |
526 | 6 | if (right_literal->is_nullable()) { |
527 | 0 | mark_unsuitable("Right literal is nullable, not supported for ANN range search."); |
528 | 0 | return; |
529 | 0 | } |
530 | | |
531 | 6 | auto right_type = right_literal->get_data_type(); |
532 | 6 | PrimitiveType right_primitive = right_type->get_primitive_type(); |
533 | 6 | const bool float32_literal = right_primitive == PrimitiveType::TYPE_FLOAT; |
534 | 6 | const bool float64_literal = right_primitive == PrimitiveType::TYPE_DOUBLE; |
535 | | |
536 | 6 | if (!float32_literal && !float64_literal) { |
537 | 0 | mark_unsuitable("Right child is not a Float32Literal or Float64Literal."); |
538 | 0 | return; |
539 | 0 | } |
540 | | |
541 | | // Validate consistency: if we have Cast(Float->Double), right must be double literal |
542 | 6 | if (has_float_to_double_cast && !float64_literal) { |
543 | 0 | mark_unsuitable("Cast expression expects double literal on right side."); |
544 | 0 | return; |
545 | 0 | } |
546 | | |
547 | | // Extract radius value |
548 | 6 | auto right_col = right_literal->get_column_ptr()->convert_to_full_column_if_const(); |
549 | 6 | if (float32_literal) { |
550 | 6 | const ColumnFloat32* cf32_right = assert_cast<const ColumnFloat32*>(right_col.get()); |
551 | 6 | range_search_runtime.radius = cf32_right->get_data()[0]; |
552 | 6 | } else { |
553 | 0 | const ColumnFloat64* cf64_right = assert_cast<const ColumnFloat64*>(right_col.get()); |
554 | 0 | range_search_runtime.radius = static_cast<float>(cf64_right->get_data()[0]); |
555 | 0 | } |
556 | | |
557 | | // ========== Done: Mark as suitable for ANN range search ========== |
558 | 6 | range_search_runtime.is_ann_range_search = true; |
559 | 6 | range_search_runtime.user_params = user_params; |
560 | 6 | VLOG_DEBUG << fmt::format("Ann range search params: {}", range_search_runtime.to_string()); |
561 | 6 | return; |
562 | 6 | } |
563 | | |
564 | | Status VectorizedFnCall::evaluate_ann_range_search( |
565 | | const segment_v2::AnnRangeSearchRuntime& range_search_runtime, |
566 | | const std::vector<std::unique_ptr<segment_v2::IndexIterator>>& cid_to_index_iterators, |
567 | | const std::vector<ColumnId>& idx_to_cid, |
568 | | const std::vector<std::unique_ptr<segment_v2::ColumnIterator>>& column_iterators, |
569 | | roaring::Roaring& row_bitmap, segment_v2::AnnIndexStats& ann_index_stats, |
570 | 6 | bool enable_result_cache, AnnRangeSearchEvaluationResult& evaluation_result) { |
571 | 6 | evaluation_result = {}; |
572 | 6 | if (range_search_runtime.is_ann_range_search == false) { |
573 | 0 | return Status::OK(); |
574 | 0 | } |
575 | | |
576 | 6 | VLOG_DEBUG << fmt::format("Try apply ann range search. Local search params: {}", |
577 | 0 | range_search_runtime.to_string()); |
578 | 6 | size_t origin_num = row_bitmap.cardinality(); |
579 | | |
580 | 6 | const auto idx_in_block = range_search_runtime.src_col_idx; |
581 | 6 | DCHECK_LT(idx_in_block, idx_to_cid.size()) |
582 | 0 | << "idx_in_block: " << idx_in_block << ", idx_to_cid.size(): " << idx_to_cid.size(); |
583 | | |
584 | 6 | ColumnId src_col_cid = idx_to_cid[idx_in_block]; |
585 | 6 | DCHECK(src_col_cid < cid_to_index_iterators.size()); |
586 | 6 | segment_v2::IndexIterator* index_iterator = cid_to_index_iterators[src_col_cid].get(); |
587 | 6 | if (index_iterator == nullptr) { |
588 | 1 | VLOG_DEBUG << "ANN range search skipped: " |
589 | 0 | << fmt::format("No index iterator for column cid {}", src_col_cid); |
590 | 1 | ; |
591 | 1 | return Status::OK(); |
592 | 1 | } |
593 | | |
594 | 5 | segment_v2::AnnIndexIterator* ann_index_iterator = |
595 | 5 | dynamic_cast<segment_v2::AnnIndexIterator*>(index_iterator); |
596 | 5 | if (ann_index_iterator == nullptr) { |
597 | 0 | VLOG_DEBUG << "ANN range search skipped: " |
598 | 0 | << fmt::format("Column cid {} has no ANN index iterator", src_col_cid); |
599 | 0 | return Status::OK(); |
600 | 0 | } |
601 | 5 | DCHECK(ann_index_iterator->get_reader(AnnIndexReaderType::ANN) != nullptr) |
602 | 0 | << "Ann index iterator should have reader. Column cid: " << src_col_cid; |
603 | 5 | std::shared_ptr<AnnIndexReader> ann_index_reader = std::dynamic_pointer_cast<AnnIndexReader>( |
604 | 5 | ann_index_iterator->get_reader(segment_v2::AnnIndexReaderType::ANN)); |
605 | 5 | DCHECK(ann_index_reader != nullptr) |
606 | 0 | << "Ann index reader should not be null. Column cid: " << src_col_cid; |
607 | | // Check if metrics type is match. |
608 | 5 | if (ann_index_reader->get_metric_type() != range_search_runtime.metric_type) { |
609 | 0 | VLOG_DEBUG << "ANN range search skipped: " |
610 | 0 | << fmt::format("Metric type mismatch. Index={} Query={}", |
611 | 0 | segment_v2::metric_to_string(ann_index_reader->get_metric_type()), |
612 | 0 | segment_v2::metric_to_string(range_search_runtime.metric_type)); |
613 | 0 | return Status::OK(); |
614 | 0 | } |
615 | | |
616 | | // Check dimension if available (>0) |
617 | 5 | const size_t index_dim = ann_index_reader->get_dimension(); |
618 | 5 | if (index_dim > 0 && index_dim != range_search_runtime.dim) { |
619 | 1 | return Status::InvalidArgument( |
620 | 1 | "Ann range search query dimension {} does not match index dimension {}", |
621 | 1 | range_search_runtime.dim, index_dim); |
622 | 1 | } |
623 | | |
624 | 4 | auto stats = std::make_unique<segment_v2::AnnIndexStats>(); |
625 | | // Track load index timing |
626 | 4 | { |
627 | 4 | SCOPED_TIMER(&(stats->load_index_costs_ns)); |
628 | 4 | if (!ann_index_iterator->try_load_index()) { |
629 | 0 | VLOG_DEBUG << "ANN range search skipped: " |
630 | 0 | << fmt::format("Failed to load ANN index for column cid {}", src_col_cid); |
631 | 0 | ann_index_stats.fall_back_brute_force_cnt += 1; |
632 | 0 | return Status::OK(); |
633 | 0 | } |
634 | 4 | double load_costs_ms = static_cast<double>(stats->load_index_costs_ns.value()) / 1000000.0; |
635 | 4 | DorisMetrics::instance()->ann_index_load_costs_ms->increment( |
636 | 4 | static_cast<int64_t>(load_costs_ms)); |
637 | 4 | } |
638 | | |
639 | 0 | AnnRangeSearchParams params = range_search_runtime.to_range_search_params(); |
640 | | |
641 | 4 | params.roaring = &row_bitmap; |
642 | 4 | params.enable_result_cache = enable_result_cache; |
643 | 4 | DCHECK(params.roaring != nullptr); |
644 | 4 | DCHECK(params.query_value != nullptr); |
645 | 4 | segment_v2::AnnRangeSearchResult result; |
646 | 4 | RETURN_IF_ERROR(ann_index_iterator->range_search(params, range_search_runtime.user_params, |
647 | 4 | &result, stats.get())); |
648 | | |
649 | 4 | #ifndef NDEBUG |
650 | 4 | if (range_search_runtime.is_le_or_lt == false && |
651 | 4 | ann_index_reader->get_metric_type() == AnnIndexMetric::L2) { |
652 | 2 | DCHECK(result.distance == nullptr) << "Should not have distance"; |
653 | 2 | } |
654 | 4 | if (range_search_runtime.is_le_or_lt == true && |
655 | 4 | ann_index_reader->get_metric_type() == AnnIndexMetric::IP) { |
656 | 0 | DCHECK(result.distance == nullptr); |
657 | 0 | } |
658 | 4 | #endif |
659 | 4 | DCHECK(result.roaring != nullptr); |
660 | 4 | row_bitmap = *result.roaring; |
661 | | |
662 | | // Process virtual column |
663 | 4 | bool dist_fulfilled = false; |
664 | 4 | if (range_search_runtime.dst_col_idx >= 0) { |
665 | | // Prepare materialization if we can use result from index. |
666 | | // Typical situation: range search and operator is LE or LT. |
667 | 4 | if (result.distance != nullptr) { |
668 | 2 | DCHECK(result.row_ids != nullptr); |
669 | 2 | ColumnId dst_col_cid = idx_to_cid[range_search_runtime.dst_col_idx]; |
670 | 2 | DCHECK(dst_col_cid < column_iterators.size()); |
671 | 2 | DCHECK(column_iterators[dst_col_cid] != nullptr); |
672 | 2 | segment_v2::ColumnIterator* column_iterator = column_iterators[dst_col_cid].get(); |
673 | 2 | DCHECK(column_iterator != nullptr); |
674 | 2 | segment_v2::VirtualColumnIterator* virtual_column_iterator = |
675 | 2 | dynamic_cast<segment_v2::VirtualColumnIterator*>(column_iterator); |
676 | 2 | DCHECK(virtual_column_iterator != nullptr); |
677 | | // Now convert distance to column |
678 | 2 | size_t size = result.roaring->cardinality(); |
679 | 2 | auto distance_col = ColumnFloat32::create(size); |
680 | 2 | const float* src = result.distance.get(); |
681 | 2 | float* dst = distance_col->get_data().data(); |
682 | 15 | for (size_t i = 0; i < size; ++i) { |
683 | 13 | dst[i] = src[i]; |
684 | 13 | } |
685 | 2 | virtual_column_iterator->prepare_materialization(std::move(distance_col), |
686 | 2 | std::move(result.row_ids)); |
687 | 2 | dist_fulfilled = true; |
688 | 2 | } else { |
689 | | // Whether the ANN index should have produced distance depends on metric and operator: |
690 | | // - L2: distance is produced for LE/LT; not produced for GE/GT |
691 | | // - IP: distance is produced for GE/GT; not produced for LE/LT |
692 | 2 | #ifndef NDEBUG |
693 | 2 | const bool should_have_distance = |
694 | 2 | (range_search_runtime.is_le_or_lt && |
695 | 2 | range_search_runtime.metric_type == AnnIndexMetric::L2) || |
696 | 2 | (!range_search_runtime.is_le_or_lt && |
697 | 2 | range_search_runtime.metric_type == AnnIndexMetric::IP); |
698 | | // If we expected distance but didn't get it, assert in debug to catch logic errors. |
699 | 2 | DCHECK(!should_have_distance) << "Expected distance from ANN index but got none"; |
700 | 2 | #endif |
701 | 2 | } |
702 | 4 | } else { |
703 | | // Dest is not virtual column. |
704 | 0 | dist_fulfilled = true; |
705 | 0 | } |
706 | | |
707 | 4 | evaluation_result.executed = true; |
708 | 4 | evaluation_result.dist_fulfilled = dist_fulfilled; |
709 | 4 | VLOG_DEBUG << fmt::format( |
710 | 0 | "Ann range search filtered {} rows, origin {} rows, virtual column is full-filled: {}", |
711 | 0 | origin_num - row_bitmap.cardinality(), origin_num, dist_fulfilled); |
712 | | |
713 | 4 | ann_index_stats = *stats; |
714 | 4 | return Status::OK(); |
715 | 4 | } |
716 | | |
717 | 8 | double VectorizedFnCall::execute_cost() const { |
718 | 8 | if (!_function) { |
719 | 0 | throw Exception( |
720 | 0 | Status::InternalError("Function is null in expression: {}", this->debug_string())); |
721 | 0 | } |
722 | 8 | double cost = _function->execute_cost(); |
723 | 16 | for (const auto& child : _children) { |
724 | 16 | cost += child->execute_cost(); |
725 | 16 | } |
726 | 8 | return cost; |
727 | 8 | } |
728 | | |
729 | | } // namespace doris |