be/src/exprs/vectorized_fn_call.cpp
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #include "exprs/vectorized_fn_call.h" |
19 | | |
20 | | #include <fmt/compile.h> |
21 | | #include <fmt/format.h> |
22 | | #include <fmt/ranges.h> // IWYU pragma: keep |
23 | | #include <gen_cpp/Opcodes_types.h> |
24 | | #include <gen_cpp/Types_types.h> |
25 | | |
26 | | #include <memory> |
27 | | #include <ostream> |
28 | | |
29 | | #include "common/config.h" |
30 | | #include "common/exception.h" |
31 | | #include "common/logging.h" |
32 | | #include "common/status.h" |
33 | | #include "common/utils.h" |
34 | | #include "core/assert_cast.h" |
35 | | #include "core/block/block.h" |
36 | | #include "core/block/column_numbers.h" |
37 | | #include "core/column/column.h" |
38 | | #include "core/column/column_array.h" |
39 | | #include "core/column/column_nullable.h" |
40 | | #include "core/column/column_vector.h" |
41 | | #include "core/data_type/data_type.h" |
42 | | #include "core/data_type/data_type_agg_state.h" |
43 | | #include "core/types.h" |
44 | | #include "exec/common/util.hpp" |
45 | | #include "exec/pipeline/pipeline_task.h" |
46 | | #include "exprs/function/array/function_array_distance.h" |
47 | | #include "exprs/function/function_agg_state.h" |
48 | | #include "exprs/function/function_fake.h" |
49 | | #include "exprs/function/function_java_udf.h" |
50 | | #include "exprs/function/function_python_udf.h" |
51 | | #include "exprs/function/function_rpc.h" |
52 | | #include "exprs/function/simple_function_factory.h" |
53 | | #include "exprs/function_context.h" |
54 | | #include "exprs/varray_literal.h" |
55 | | #include "exprs/vcast_expr.h" |
56 | | #include "exprs/vexpr_context.h" |
57 | | #include "exprs/virtual_slot_ref.h" |
58 | | #include "exprs/vliteral.h" |
59 | | #include "runtime/runtime_state.h" |
60 | | #include "storage/index/ann/ann_index.h" |
61 | | #include "storage/index/ann/ann_index_iterator.h" |
62 | | #include "storage/index/ann/ann_search_params.h" |
63 | | #include "storage/index/index_reader.h" |
64 | | #include "storage/index/zone_map/zonemap_eval_context.h" |
65 | | #include "storage/segment/column_reader.h" |
66 | | #include "storage/segment/virtual_column_iterator.h" |
67 | | |
68 | | namespace doris { |
69 | | class RowDescriptor; |
70 | | class RuntimeState; |
71 | | class TExprNode; |
72 | | } // namespace doris |
73 | | |
74 | | namespace doris { |
75 | | |
76 | | const std::string AGG_STATE_SUFFIX = "_state"; |
77 | | |
78 | | // Now left child is a function call, we need to check if it is a distance function |
79 | | const static std::set<std::string> DISTANCE_FUNCS = {L2DistanceApproximate::name, |
80 | | InnerProductApproximate::name}; |
81 | | const static std::set<TExprOpcode::type> OPS_FOR_ANN_RANGE_SEARCH = { |
82 | | TExprOpcode::GE, TExprOpcode::LE, TExprOpcode::LE, TExprOpcode::GT, TExprOpcode::LT}; |
83 | | |
84 | 179 | VectorizedFnCall::VectorizedFnCall(const TExprNode& node) : VExpr(node) { |
85 | 179 | _function_name = _fn.name.function_name; |
86 | 179 | } |
87 | | |
88 | | Status VectorizedFnCall::prepare(RuntimeState* state, const RowDescriptor& desc, |
89 | 151 | VExprContext* context) { |
90 | 151 | RETURN_IF_ERROR_OR_PREPARED(VExpr::prepare(state, desc, context)); |
91 | 151 | ColumnsWithTypeAndName argument_template; |
92 | 151 | argument_template.reserve(_children.size()); |
93 | 284 | for (auto child : _children) { |
94 | 284 | if (child->is_literal()) { |
95 | | // For some functions, he needs some literal columns to derive the return type. |
96 | 108 | auto literal_node = std::dynamic_pointer_cast<VLiteral>(child); |
97 | 108 | argument_template.emplace_back(literal_node->get_column_ptr(), child->data_type(), |
98 | 108 | child->expr_name()); |
99 | 176 | } else { |
100 | 176 | argument_template.emplace_back(nullptr, child->data_type(), child->expr_name()); |
101 | 176 | } |
102 | 284 | } |
103 | | |
104 | 151 | _expr_name = fmt::format("VectorizedFnCall[{}](arguments={},return={})", _fn.name.function_name, |
105 | 151 | get_child_names(), _data_type->get_name()); |
106 | 151 | if (_fn.binary_type == TFunctionBinaryType::RPC) { |
107 | 0 | _function = FunctionRPC::create(_fn, argument_template, _data_type); |
108 | 151 | } else if (_fn.binary_type == TFunctionBinaryType::JAVA_UDF) { |
109 | 0 | if (config::enable_java_support) { |
110 | 0 | if (_fn.is_udtf_function) { |
111 | | // fake function. it's no use and can't execute. |
112 | 0 | auto builder = |
113 | 0 | std::make_shared<DefaultFunctionBuilder>(FunctionFake<UDTFImpl>::create()); |
114 | 0 | _function = builder->build(argument_template, std::make_shared<DataTypeUInt8>()); |
115 | 0 | } else { |
116 | 0 | _function = JavaFunctionCall::create(_fn, argument_template, _data_type); |
117 | 0 | } |
118 | 0 | } else { |
119 | 0 | return Status::InternalError( |
120 | 0 | "Java UDF is not enabled, you can change be config enable_java_support to true " |
121 | 0 | "and restart be."); |
122 | 0 | } |
123 | 151 | } else if (_fn.binary_type == TFunctionBinaryType::PYTHON_UDF) { |
124 | 0 | if (config::enable_python_udf_support) { |
125 | 0 | if (_fn.is_udtf_function) { |
126 | | // fake function. it's no use and can't execute. |
127 | | // Python UDTF is executed via PythonUDTFFunction in table function path |
128 | 0 | auto builder = |
129 | 0 | std::make_shared<DefaultFunctionBuilder>(FunctionFake<UDTFImpl>::create()); |
130 | 0 | _function = builder->build(argument_template, std::make_shared<DataTypeUInt8>()); |
131 | 0 | } else { |
132 | 0 | _function = PythonFunctionCall::create(_fn, argument_template, _data_type); |
133 | 0 | LOG(INFO) << fmt::format( |
134 | 0 | "create python function call: {}, runtime version: {}, function code: {}", |
135 | 0 | _fn.name.function_name, _fn.runtime_version, _fn.function_code); |
136 | 0 | } |
137 | 0 | } else { |
138 | 0 | return Status::InternalError( |
139 | 0 | "Python UDF is not enabled, you can change be config enable_python_udf_support " |
140 | 0 | "to true and restart be."); |
141 | 0 | } |
142 | 151 | } else if (_fn.binary_type == TFunctionBinaryType::AGG_STATE) { |
143 | 0 | DataTypes argument_types; |
144 | 0 | for (auto column : argument_template) { |
145 | 0 | argument_types.emplace_back(column.type); |
146 | 0 | } |
147 | |
|
148 | 0 | if (match_suffix(_fn.name.function_name, AGG_STATE_SUFFIX)) { |
149 | 0 | if (_data_type->is_nullable()) { |
150 | 0 | return Status::InternalError("State function's return type must be not nullable"); |
151 | 0 | } |
152 | 0 | if (_data_type->get_primitive_type() != PrimitiveType::TYPE_AGG_STATE) { |
153 | 0 | return Status::InternalError( |
154 | 0 | "State function's return type must be agg_state but get {}", |
155 | 0 | _data_type->get_family_name()); |
156 | 0 | } |
157 | 0 | _function = FunctionAggState::create( |
158 | 0 | argument_types, _data_type, |
159 | 0 | assert_cast<const DataTypeAggState*>(_data_type.get())->get_nested_function()); |
160 | 0 | } else { |
161 | 0 | return Status::InternalError("Function {} is not endwith '_state'", _fn.signature); |
162 | 0 | } |
163 | 151 | } else { |
164 | | // get the function. won't prepare function. |
165 | 151 | _function = SimpleFunctionFactory::instance().get_function( |
166 | 151 | _fn.name.function_name, argument_template, _data_type, |
167 | 151 | {.new_version_unix_timestamp = state->query_options().new_version_unix_timestamp}, |
168 | 151 | state->be_exec_version()); |
169 | 151 | } |
170 | 151 | if (_function == nullptr) { |
171 | 0 | return Status::InternalError("Could not find function {}, arg {} return {} ", |
172 | 0 | _fn.name.function_name, get_child_type_names(), |
173 | 0 | _data_type->get_name()); |
174 | 0 | } |
175 | 151 | VExpr::register_function_context(state, context); |
176 | 151 | _function_name = _fn.name.function_name; |
177 | 151 | _prepare_finished = true; |
178 | | |
179 | 151 | FunctionContext* fn_ctx = context->fn_context(_fn_context_index); |
180 | 151 | if (fn().__isset.dict_function) { |
181 | 0 | fn_ctx->set_dict_function(fn().dict_function); |
182 | 0 | } |
183 | 151 | return Status::OK(); |
184 | 151 | } |
185 | | |
186 | | Status VectorizedFnCall::open(RuntimeState* state, VExprContext* context, |
187 | 148 | FunctionContext::FunctionStateScope scope) { |
188 | 148 | DCHECK(_prepare_finished); |
189 | 268 | for (auto& i : _children) { |
190 | 268 | RETURN_IF_ERROR(i->open(state, context, scope)); |
191 | 268 | } |
192 | 148 | RETURN_IF_ERROR(VExpr::init_function_context(state, context, scope, _function)); |
193 | 148 | if (scope == FunctionContext::FRAGMENT_LOCAL) { |
194 | 105 | RETURN_IF_ERROR(VExpr::get_const_col(context, nullptr)); |
195 | 105 | } |
196 | 148 | _open_finished = true; |
197 | 148 | return Status::OK(); |
198 | 148 | } |
199 | | |
200 | 243 | void VectorizedFnCall::close(VExprContext* context, FunctionContext::FunctionStateScope scope) { |
201 | 243 | VExpr::close_function_context(context, scope, _function); |
202 | 243 | VExpr::close(context, scope); |
203 | 243 | } |
204 | | |
205 | 18 | Status VectorizedFnCall::evaluate_inverted_index(VExprContext* context, uint32_t segment_num_rows) { |
206 | 18 | if (get_num_children() < 1) { |
207 | | // score() and similar 0-children virtual column functions don't need |
208 | | // inverted index evaluation; return OK to skip gracefully. |
209 | 0 | return Status::OK(); |
210 | 0 | } |
211 | 18 | return _evaluate_inverted_index(context, _function, segment_num_rows); |
212 | 18 | } |
213 | | |
214 | 19 | ZoneMapFilterResult VectorizedFnCall::evaluate_zonemap_filter(const ZoneMapEvalContext& ctx) const { |
215 | 19 | return _function->evaluate_zonemap_filter(ctx, _children); |
216 | 19 | } |
217 | | |
218 | 68 | bool VectorizedFnCall::can_evaluate_zonemap_filter() const { |
219 | 68 | return _function != nullptr && !_function->is_blockable() && |
220 | 68 | _function->can_evaluate_zonemap_filter(_children); |
221 | 68 | } |
222 | | |
223 | | Status VectorizedFnCall::_do_execute(VExprContext* context, const Block* block, |
224 | | const Selector* selector, size_t count, |
225 | 88 | ColumnPtr& result_column, ColumnPtr* arg_column) const { |
226 | 88 | if (is_const_and_have_executed()) { // const have executed in open function |
227 | 0 | result_column = get_result_from_const(count); |
228 | 0 | return Status::OK(); |
229 | 0 | } |
230 | 88 | if (fast_execute(context, selector, count, result_column)) { |
231 | 0 | return Status::OK(); |
232 | 0 | } |
233 | 88 | DBUG_EXECUTE_IF("VectorizedFnCall.must_in_slow_path", { |
234 | 88 | if (get_child(0)->is_slot_ref()) { |
235 | 88 | auto debug_col_name = DebugPoints::instance()->get_debug_param_or_default<std::string>( |
236 | 88 | "VectorizedFnCall.must_in_slow_path", "column_name", ""); |
237 | | |
238 | 88 | std::vector<std::string> column_names; |
239 | 88 | boost::split(column_names, debug_col_name, boost::algorithm::is_any_of(",")); |
240 | | |
241 | 88 | auto* column_slot_ref = assert_cast<VSlotRef*>(get_child(0).get()); |
242 | 88 | std::string column_name = column_slot_ref->expr_name(); |
243 | 88 | auto it = std::find(column_names.begin(), column_names.end(), column_name); |
244 | 88 | if (it == column_names.end()) { |
245 | 88 | return Status::Error<ErrorCode::INTERNAL_ERROR>( |
246 | 88 | "column {} should in slow path while VectorizedFnCall::execute.", |
247 | 88 | column_name); |
248 | 88 | } |
249 | 88 | } |
250 | 88 | }) |
251 | 88 | DCHECK(_open_finished || block == nullptr) << debug_string(); |
252 | | |
253 | 88 | Block temp_block; |
254 | 88 | ColumnNumbers args(_children.size()); |
255 | | |
256 | 260 | for (int i = 0; i < _children.size(); ++i) { |
257 | 172 | ColumnPtr tmp_arg_column; |
258 | 172 | RETURN_IF_ERROR( |
259 | 172 | _children[i]->execute_column(context, block, selector, count, tmp_arg_column)); |
260 | 172 | auto arg_type = _children[i]->execute_type(block); |
261 | 172 | temp_block.insert({tmp_arg_column, arg_type, _children[i]->expr_name()}); |
262 | 172 | args[i] = i; |
263 | | |
264 | 172 | if (arg_column != nullptr && i == 0) { |
265 | 0 | *arg_column = tmp_arg_column; |
266 | 0 | } |
267 | 172 | } |
268 | | |
269 | 88 | uint32_t num_columns_without_result = temp_block.columns(); |
270 | | // prepare a column to save result |
271 | 88 | temp_block.insert({nullptr, _data_type, _expr_name}); |
272 | | |
273 | 88 | DBUG_EXECUTE_IF("VectorizedFnCall.wait_before_execute", { |
274 | 88 | auto possibility = DebugPoints::instance()->get_debug_param_or_default<double>( |
275 | 88 | "VectorizedFnCall.wait_before_execute", "possibility", 0); |
276 | 88 | if (random_bool_slow(possibility)) { |
277 | 88 | LOG(WARNING) << "VectorizedFnCall::execute sleep 30s"; |
278 | 88 | sleep(30); |
279 | 88 | } |
280 | 88 | }); |
281 | | |
282 | 88 | RETURN_IF_ERROR(_function->execute(context->fn_context(_fn_context_index), temp_block, args, |
283 | 88 | num_columns_without_result, count)); |
284 | 88 | result_column = temp_block.get_by_position(num_columns_without_result).column; |
285 | 88 | DCHECK_EQ(result_column->size(), count); |
286 | 88 | RETURN_IF_ERROR(result_column->column_self_check()); |
287 | 88 | return Status::OK(); |
288 | 88 | } |
289 | | |
290 | 0 | size_t VectorizedFnCall::estimate_memory(const size_t rows) { |
291 | 0 | if (is_const_and_have_executed()) { // const have execute in open function |
292 | 0 | return 0; |
293 | 0 | } |
294 | | |
295 | 0 | size_t estimate_size = 0; |
296 | 0 | for (auto& child : _children) { |
297 | 0 | estimate_size += child->estimate_memory(rows); |
298 | 0 | } |
299 | |
|
300 | 0 | if (_data_type->have_maximum_size_of_value()) { |
301 | 0 | estimate_size += rows * _data_type->get_size_of_value_in_memory(); |
302 | 0 | } else { |
303 | 0 | estimate_size += rows * 512; /// FIXME: estimated value... |
304 | 0 | } |
305 | 0 | return estimate_size; |
306 | 0 | } |
307 | | |
308 | | Status VectorizedFnCall::execute_runtime_filter(VExprContext* context, const Block* block, |
309 | | const uint8_t* __restrict filter, size_t count, |
310 | | ColumnPtr& result_column, |
311 | 0 | ColumnPtr* arg_column) const { |
312 | 0 | return _do_execute(context, block, nullptr, count, result_column, arg_column); |
313 | 0 | } |
314 | | |
315 | | Status VectorizedFnCall::execute_column_impl(VExprContext* context, const Block* block, |
316 | | const Selector* selector, size_t count, |
317 | 88 | ColumnPtr& result_column) const { |
318 | 88 | return _do_execute(context, block, selector, count, result_column, nullptr); |
319 | 88 | } |
320 | | |
321 | 62 | const std::string& VectorizedFnCall::expr_name() const { |
322 | 62 | return _expr_name; |
323 | 62 | } |
324 | | |
325 | 6 | std::string VectorizedFnCall::function_name() const { |
326 | 6 | return _function_name; |
327 | 6 | } |
328 | | |
329 | 4 | std::string VectorizedFnCall::debug_string() const { |
330 | 4 | std::stringstream out; |
331 | 4 | out << "VectorizedFn["; |
332 | 4 | out << _expr_name; |
333 | 4 | out << "]{"; |
334 | 4 | bool first = true; |
335 | 4 | for (const auto& input_expr : children()) { |
336 | 4 | if (first) { |
337 | 2 | first = false; |
338 | 2 | } else { |
339 | 2 | out << ","; |
340 | 2 | } |
341 | 4 | out << "\n" << input_expr->debug_string(); |
342 | 4 | } |
343 | 4 | out << "}"; |
344 | 4 | return out.str(); |
345 | 4 | } |
346 | | |
347 | 0 | std::string VectorizedFnCall::debug_string(const std::vector<VectorizedFnCall*>& agg_fns) { |
348 | 0 | std::stringstream out; |
349 | 0 | out << "["; |
350 | 0 | for (int i = 0; i < agg_fns.size(); ++i) { |
351 | 0 | out << (i == 0 ? "" : " ") << agg_fns[i]->debug_string(); |
352 | 0 | } |
353 | 0 | out << "]"; |
354 | 0 | return out.str(); |
355 | 0 | } |
356 | | |
357 | 0 | bool VectorizedFnCall::can_push_down_to_index() const { |
358 | 0 | return _function->can_push_down_to_index(); |
359 | 0 | } |
360 | | |
361 | 0 | bool VectorizedFnCall::equals(const VExpr& other) { |
362 | 0 | const auto* other_ptr = dynamic_cast<const VectorizedFnCall*>(&other); |
363 | 0 | if (!other_ptr) { |
364 | 0 | return false; |
365 | 0 | } |
366 | 0 | if (this->_function_name != other_ptr->_function_name) { |
367 | 0 | return false; |
368 | 0 | } |
369 | 0 | if (get_num_children() != other_ptr->get_num_children()) { |
370 | 0 | return false; |
371 | 0 | } |
372 | 0 | for (uint16_t i = 0; i < get_num_children(); i++) { |
373 | 0 | if (!this->get_child(i)->equals(*other_ptr->get_child(i))) { |
374 | 0 | return false; |
375 | 0 | } |
376 | 0 | } |
377 | 0 | return true; |
378 | 0 | } |
379 | | |
/*
 * For ANN range search we expect a comparison expression (LE/LT/GE/GT) whose left side is either:
 *   1) a vector distance function call, or
 *   2) a cast/virtual slot that unwraps to the function call when the planner promotes float to
 *      double literals.
 *
 * Visually the logical tree looks like:
 *
 *   FunctionCall(LE/LT/GE/GT)
 *   |----------------
 *   |               |
 *   |               |
 *   VirtualSlotRef* Float32Literal/Float64Literal
 *   |
 *   |
 *   Cast(Float -> Double)*
 *   |
 *   FunctionCall(distance)
 *   |----------------
 *   |               |
 *   |               |
 *   SlotRef         ArrayLiteral/Cast(String as Array<FLOAT>)
 *
 * Items marked with * are optional and depend on literal types/virtual column usage. The helper
 * below normalizes the shape and validates distance function, slot, and constant vector inputs.
 */
406 | | |
407 | | void VectorizedFnCall::prepare_ann_range_search( |
408 | | const doris::VectorSearchUserParams& user_params, |
409 | 7 | segment_v2::AnnRangeSearchRuntime& range_search_runtime, bool& suitable_for_ann_index) { |
410 | 7 | if (!suitable_for_ann_index) { |
411 | 0 | return; |
412 | 0 | } |
413 | | |
414 | 7 | if (OPS_FOR_ANN_RANGE_SEARCH.find(this->op()) == OPS_FOR_ANN_RANGE_SEARCH.end()) { |
415 | 0 | suitable_for_ann_index = false; |
416 | 0 | return; |
417 | 0 | } |
418 | | |
419 | 7 | auto mark_unsuitable = [&](const std::string& reason) { |
420 | 1 | suitable_for_ann_index = false; |
421 | 1 | VLOG_DEBUG << "ANN range search skipped: " << reason; |
422 | 1 | }; |
423 | | |
424 | 7 | range_search_runtime.is_le_or_lt = |
425 | 7 | (this->op() == TExprOpcode::LE || this->op() == TExprOpcode::LT); |
426 | | |
427 | 7 | DCHECK(_children.size() == 2); |
428 | | |
429 | 7 | auto left_child = get_child(0); |
430 | 7 | auto right_child = get_child(1); |
431 | | |
432 | | // ========== Step 1: Check left child - must be a distance function ========== |
433 | 7 | auto get_virtual_expr = [&](const VExprSPtr& expr, |
434 | 7 | std::shared_ptr<VirtualSlotRef>& slot_ref) -> VExprSPtr { |
435 | 7 | auto virtual_ref = std::dynamic_pointer_cast<VirtualSlotRef>(expr); |
436 | 7 | if (virtual_ref != nullptr) { |
437 | 7 | DCHECK(virtual_ref->get_virtual_column_expr() != nullptr); |
438 | 7 | slot_ref = virtual_ref; |
439 | 7 | return virtual_ref->get_virtual_column_expr(); |
440 | 7 | } |
441 | 0 | return expr; |
442 | 7 | }; |
443 | | |
444 | 7 | std::shared_ptr<VirtualSlotRef> vir_slot_ref; |
445 | 7 | auto normalized_left = get_virtual_expr(left_child, vir_slot_ref); |
446 | | |
447 | | // Try to find the distance function call, it may be wrapped in a Cast(Float->Double) |
448 | 7 | std::shared_ptr<VectorizedFnCall> function_call = |
449 | 7 | std::dynamic_pointer_cast<VectorizedFnCall>(normalized_left); |
450 | 7 | bool has_float_to_double_cast = false; |
451 | | |
452 | 7 | if (function_call == nullptr) { |
453 | | // Check if it's a Cast expression wrapping a function call |
454 | 0 | auto cast_expr = std::dynamic_pointer_cast<VCastExpr>(normalized_left); |
455 | 0 | if (cast_expr == nullptr) { |
456 | 0 | mark_unsuitable("Left child is neither a function call nor a cast expression."); |
457 | 0 | return; |
458 | 0 | } |
459 | 0 | has_float_to_double_cast = true; |
460 | 0 | auto normalized_cast_child = get_virtual_expr(cast_expr->get_child(0), vir_slot_ref); |
461 | 0 | function_call = std::dynamic_pointer_cast<VectorizedFnCall>(normalized_cast_child); |
462 | 0 | if (function_call == nullptr) { |
463 | 0 | mark_unsuitable("Left child of cast is not a function call."); |
464 | 0 | return; |
465 | 0 | } |
466 | 0 | } |
467 | | |
468 | | // Check if it's a supported distance function |
469 | 7 | if (DISTANCE_FUNCS.find(function_call->_function_name) == DISTANCE_FUNCS.end()) { |
470 | 0 | mark_unsuitable(fmt::format("Left child is not a supported distance function: {}", |
471 | 0 | function_call->_function_name)); |
472 | 0 | return; |
473 | 0 | } |
474 | | |
475 | | // Strip the _approximate suffix to get metric type |
476 | 7 | std::string metric_name = function_call->_function_name; |
477 | 7 | metric_name = metric_name.substr(0, metric_name.size() - 12); |
478 | 7 | range_search_runtime.metric_type = segment_v2::string_to_metric(metric_name); |
479 | | |
480 | | // ========== Step 2: Validate distance function arguments ========== |
481 | | // Identify the slot ref child and the constant query array child (ArrayLiteral or CAST to array) |
482 | 7 | Int32 idx_of_slot_ref = -1; |
483 | 7 | Int32 idx_of_array_expr = -1; |
484 | 14 | auto classify_child = [&](const VExprSPtr& child, UInt16 index) { |
485 | 14 | if (idx_of_slot_ref == -1 && std::dynamic_pointer_cast<VSlotRef>(child) != nullptr) { |
486 | 7 | idx_of_slot_ref = index; |
487 | 7 | return; |
488 | 7 | } |
489 | 7 | if (idx_of_array_expr == -1 && |
490 | 7 | (std::dynamic_pointer_cast<VArrayLiteral>(child) != nullptr || |
491 | 7 | std::dynamic_pointer_cast<VCastExpr>(child) != nullptr)) { |
492 | 7 | idx_of_array_expr = index; |
493 | 7 | } |
494 | 7 | }; |
495 | | |
496 | 21 | for (UInt16 i = 0; i < function_call->get_num_children(); ++i) { |
497 | 14 | classify_child(function_call->get_child(i), i); |
498 | 14 | } |
499 | | |
500 | 7 | if (idx_of_slot_ref == -1 || idx_of_array_expr == -1) { |
501 | 0 | mark_unsuitable("slot ref or array literal/cast is missing."); |
502 | 0 | return; |
503 | 0 | } |
504 | | |
505 | 7 | auto slot_ref = std::dynamic_pointer_cast<VSlotRef>( |
506 | 7 | function_call->get_child(static_cast<UInt16>(idx_of_slot_ref))); |
507 | 7 | range_search_runtime.src_col_idx = slot_ref->column_id(); |
508 | 7 | range_search_runtime.dst_col_idx = vir_slot_ref == nullptr ? -1 : vir_slot_ref->column_id(); |
509 | | |
510 | | // Materialize the constant array expression and validate its shape and types |
511 | 7 | auto array_expr = function_call->get_child(static_cast<UInt16>(idx_of_array_expr)); |
512 | 7 | auto extract_result = extract_query_vector(array_expr); |
513 | 7 | if (!extract_result.has_value()) { |
514 | 0 | mark_unsuitable("Failed to extract query vector from constant array expression."); |
515 | 0 | return; |
516 | 0 | } |
517 | 7 | range_search_runtime.query_value = extract_result.value(); |
518 | 7 | range_search_runtime.dim = range_search_runtime.query_value->size(); |
519 | | |
520 | | // ========== Step 3: Check right child - must be a float/double literal ========== |
521 | 7 | auto right_literal = std::dynamic_pointer_cast<VLiteral>(right_child); |
522 | 7 | if (right_literal == nullptr) { |
523 | 1 | mark_unsuitable("Right child is not a literal."); |
524 | 1 | return; |
525 | 1 | } |
526 | | |
527 | | // Handle nullable literal gracefully - just mark as unsuitable instead of crash |
528 | 6 | if (right_literal->is_nullable()) { |
529 | 0 | mark_unsuitable("Right literal is nullable, not supported for ANN range search."); |
530 | 0 | return; |
531 | 0 | } |
532 | | |
533 | 6 | auto right_type = right_literal->get_data_type(); |
534 | 6 | PrimitiveType right_primitive = right_type->get_primitive_type(); |
535 | 6 | const bool float32_literal = right_primitive == PrimitiveType::TYPE_FLOAT; |
536 | 6 | const bool float64_literal = right_primitive == PrimitiveType::TYPE_DOUBLE; |
537 | | |
538 | 6 | if (!float32_literal && !float64_literal) { |
539 | 0 | mark_unsuitable("Right child is not a Float32Literal or Float64Literal."); |
540 | 0 | return; |
541 | 0 | } |
542 | | |
543 | | // Validate consistency: if we have Cast(Float->Double), right must be double literal |
544 | 6 | if (has_float_to_double_cast && !float64_literal) { |
545 | 0 | mark_unsuitable("Cast expression expects double literal on right side."); |
546 | 0 | return; |
547 | 0 | } |
548 | | |
549 | | // Extract radius value |
550 | 6 | auto right_col = right_literal->get_column_ptr()->convert_to_full_column_if_const(); |
551 | 6 | if (float32_literal) { |
552 | 6 | const ColumnFloat32* cf32_right = assert_cast<const ColumnFloat32*>(right_col.get()); |
553 | 6 | range_search_runtime.radius = cf32_right->get_data()[0]; |
554 | 6 | } else { |
555 | 0 | const ColumnFloat64* cf64_right = assert_cast<const ColumnFloat64*>(right_col.get()); |
556 | 0 | range_search_runtime.radius = static_cast<float>(cf64_right->get_data()[0]); |
557 | 0 | } |
558 | | |
559 | | // ========== Done: Mark as suitable for ANN range search ========== |
560 | 6 | range_search_runtime.is_ann_range_search = true; |
561 | 6 | range_search_runtime.user_params = user_params; |
562 | 6 | VLOG_DEBUG << fmt::format("Ann range search params: {}", range_search_runtime.to_string()); |
563 | 6 | return; |
564 | 6 | } |
565 | | |
566 | | Status VectorizedFnCall::evaluate_ann_range_search( |
567 | | const segment_v2::AnnRangeSearchRuntime& range_search_runtime, |
568 | | const std::vector<std::unique_ptr<segment_v2::IndexIterator>>& cid_to_index_iterators, |
569 | | const std::vector<ColumnId>& idx_to_cid, |
570 | | const std::vector<std::unique_ptr<segment_v2::ColumnIterator>>& column_iterators, |
571 | | size_t rows_of_segment, roaring::Roaring& row_bitmap, |
572 | | segment_v2::AnnIndexStats& ann_index_stats, bool enable_result_cache, |
573 | 24 | AnnRangeSearchEvaluationResult& evaluation_result) { |
574 | 24 | evaluation_result = {}; |
575 | 24 | if (range_search_runtime.is_ann_range_search == false) { |
576 | 18 | return Status::OK(); |
577 | 18 | } |
578 | | |
579 | 6 | VLOG_DEBUG << fmt::format("Try apply ann range search. Local search params: {}", |
580 | 0 | range_search_runtime.to_string()); |
581 | 6 | size_t origin_num = row_bitmap.cardinality(); |
582 | | |
583 | 6 | const auto idx_in_block = range_search_runtime.src_col_idx; |
584 | 6 | DCHECK_LT(idx_in_block, idx_to_cid.size()) |
585 | 0 | << "idx_in_block: " << idx_in_block << ", idx_to_cid.size(): " << idx_to_cid.size(); |
586 | | |
587 | 6 | ColumnId src_col_cid = idx_to_cid[idx_in_block]; |
588 | 6 | DCHECK(src_col_cid < cid_to_index_iterators.size()); |
589 | 6 | segment_v2::IndexIterator* index_iterator = cid_to_index_iterators[src_col_cid].get(); |
590 | 6 | if (index_iterator == nullptr) { |
591 | 1 | VLOG_DEBUG << "ANN range search skipped: " |
592 | 0 | << fmt::format("No index iterator for column cid {}", src_col_cid); |
593 | 1 | ; |
594 | 1 | return Status::OK(); |
595 | 1 | } |
596 | | |
597 | 5 | segment_v2::AnnIndexIterator* ann_index_iterator = |
598 | 5 | dynamic_cast<segment_v2::AnnIndexIterator*>(index_iterator); |
599 | 5 | if (ann_index_iterator == nullptr) { |
600 | 0 | VLOG_DEBUG << "ANN range search skipped: " |
601 | 0 | << fmt::format("Column cid {} has no ANN index iterator", src_col_cid); |
602 | 0 | return Status::OK(); |
603 | 0 | } |
604 | 5 | DCHECK(ann_index_iterator->get_reader(AnnIndexReaderType::ANN) != nullptr) |
605 | 0 | << "Ann index iterator should have reader. Column cid: " << src_col_cid; |
606 | 5 | std::shared_ptr<AnnIndexReader> ann_index_reader = std::dynamic_pointer_cast<AnnIndexReader>( |
607 | 5 | ann_index_iterator->get_reader(segment_v2::AnnIndexReaderType::ANN)); |
608 | 5 | DCHECK(ann_index_reader != nullptr) |
609 | 0 | << "Ann index reader should not be null. Column cid: " << src_col_cid; |
610 | | // Check if metrics type is match. |
611 | 5 | if (ann_index_reader->get_metric_type() != range_search_runtime.metric_type) { |
612 | 0 | VLOG_DEBUG << "ANN range search skipped: " |
613 | 0 | << fmt::format("Metric type mismatch. Index={} Query={}", |
614 | 0 | segment_v2::metric_to_string(ann_index_reader->get_metric_type()), |
615 | 0 | segment_v2::metric_to_string(range_search_runtime.metric_type)); |
616 | 0 | return Status::OK(); |
617 | 0 | } |
618 | | |
619 | | // Check dimension if available (>0) |
620 | 5 | const size_t index_dim = ann_index_reader->get_dimension(); |
621 | 5 | if (index_dim > 0 && index_dim != range_search_runtime.dim) { |
622 | 1 | return Status::InvalidArgument( |
623 | 1 | "Ann range search query dimension {} does not match index dimension {}", |
624 | 1 | range_search_runtime.dim, index_dim); |
625 | 1 | } |
626 | | |
627 | 4 | const auto& user_params = range_search_runtime.user_params; |
628 | 4 | if (user_params.should_fallback_ann_index_by_small_candidate(origin_num, rows_of_segment)) { |
629 | 0 | VLOG_DEBUG << fmt::format( |
630 | 0 | "Ann range search input rows {} reach small candidate threshold, " |
631 | 0 | "rows_of_segment: {}, absolute_threshold: {}, percent_threshold: {}, " |
632 | 0 | "will not use ann index to filter", |
633 | 0 | origin_num, rows_of_segment, user_params.ann_index_candidate_rows_threshold, |
634 | 0 | user_params.ann_index_candidate_rows_percent_threshold); |
635 | 0 | ann_index_stats.fall_back_brute_force_cnt += 1; |
636 | 0 | ann_index_stats.range_fallback_by_small_candidate_cnt += 1; |
637 | 0 | ann_index_stats.range_fallback_small_candidate_rows += origin_num; |
638 | 0 | return Status::OK(); |
639 | 0 | } |
640 | | |
641 | 4 | auto stats = std::make_unique<segment_v2::AnnIndexStats>(); |
642 | | // Track load index timing |
643 | 4 | { |
644 | 4 | SCOPED_TIMER(&(stats->load_index_costs_ns)); |
645 | 4 | if (!ann_index_iterator->try_load_index()) { |
646 | 0 | VLOG_DEBUG << "ANN range search skipped: " |
647 | 0 | << fmt::format("Failed to load ANN index for column cid {}", src_col_cid); |
648 | 0 | ann_index_stats.fall_back_brute_force_cnt += 1; |
649 | 0 | return Status::OK(); |
650 | 0 | } |
651 | 4 | double load_costs_ms = static_cast<double>(stats->load_index_costs_ns.value()) / 1000000.0; |
652 | 4 | DorisMetrics::instance()->ann_index_load_costs_ms->increment( |
653 | 4 | static_cast<int64_t>(load_costs_ms)); |
654 | 4 | } |
655 | | |
656 | 0 | AnnRangeSearchParams params = range_search_runtime.to_range_search_params(); |
657 | | |
658 | 4 | params.roaring = &row_bitmap; |
659 | 4 | params.enable_result_cache = enable_result_cache; |
660 | 4 | DCHECK(params.roaring != nullptr); |
661 | 4 | DCHECK(params.query_value != nullptr); |
662 | 4 | segment_v2::AnnRangeSearchResult result; |
663 | 4 | RETURN_IF_ERROR(ann_index_iterator->range_search(params, range_search_runtime.user_params, |
664 | 4 | &result, stats.get())); |
665 | | |
666 | 4 | #ifndef NDEBUG |
667 | 4 | if (range_search_runtime.is_le_or_lt == false && |
668 | 4 | ann_index_reader->get_metric_type() == AnnIndexMetric::L2) { |
669 | 2 | DCHECK(result.distance == nullptr) << "Should not have distance"; |
670 | 2 | } |
671 | 4 | if (range_search_runtime.is_le_or_lt == true && |
672 | 4 | ann_index_reader->get_metric_type() == AnnIndexMetric::IP) { |
673 | 0 | DCHECK(result.distance == nullptr); |
674 | 0 | } |
675 | 4 | #endif |
676 | 4 | DCHECK(result.roaring != nullptr); |
677 | 4 | row_bitmap = *result.roaring; |
678 | | |
679 | | // Process virtual column |
680 | 4 | bool dist_fulfilled = false; |
681 | 4 | if (range_search_runtime.dst_col_idx >= 0) { |
682 | | // Prepare materialization if we can use result from index. |
683 | | // Typical situation: range search and operator is LE or LT. |
684 | 4 | if (result.distance != nullptr) { |
685 | 2 | DCHECK(result.row_ids != nullptr); |
686 | 2 | ColumnId dst_col_cid = idx_to_cid[range_search_runtime.dst_col_idx]; |
687 | 2 | DCHECK(dst_col_cid < column_iterators.size()); |
688 | 2 | DCHECK(column_iterators[dst_col_cid] != nullptr); |
689 | 2 | segment_v2::ColumnIterator* column_iterator = column_iterators[dst_col_cid].get(); |
690 | 2 | DCHECK(column_iterator != nullptr); |
691 | 2 | segment_v2::VirtualColumnIterator* virtual_column_iterator = |
692 | 2 | dynamic_cast<segment_v2::VirtualColumnIterator*>(column_iterator); |
693 | 2 | DCHECK(virtual_column_iterator != nullptr); |
694 | | // Now convert distance to column |
695 | 2 | size_t size = result.roaring->cardinality(); |
696 | 2 | auto distance_col = ColumnFloat32::create(size); |
697 | 2 | const float* src = result.distance.get(); |
698 | 2 | float* dst = distance_col->get_data().data(); |
699 | 15 | for (size_t i = 0; i < size; ++i) { |
700 | 13 | dst[i] = src[i]; |
701 | 13 | } |
702 | 2 | virtual_column_iterator->prepare_materialization(std::move(distance_col), |
703 | 2 | std::move(result.row_ids)); |
704 | 2 | dist_fulfilled = true; |
705 | 2 | } else { |
706 | | // Whether the ANN index should have produced distance depends on metric and operator: |
707 | | // - L2: distance is produced for LE/LT; not produced for GE/GT |
708 | | // - IP: distance is produced for GE/GT; not produced for LE/LT |
709 | 2 | #ifndef NDEBUG |
710 | 2 | const bool should_have_distance = |
711 | 2 | (range_search_runtime.is_le_or_lt && |
712 | 2 | range_search_runtime.metric_type == AnnIndexMetric::L2) || |
713 | 2 | (!range_search_runtime.is_le_or_lt && |
714 | 2 | range_search_runtime.metric_type == AnnIndexMetric::IP); |
715 | | // If we expected distance but didn't get it, assert in debug to catch logic errors. |
716 | 2 | DCHECK(!should_have_distance) << "Expected distance from ANN index but got none"; |
717 | 2 | #endif |
718 | 2 | } |
719 | 4 | } else { |
720 | | // Dest is not virtual column. |
721 | 0 | dist_fulfilled = true; |
722 | 0 | } |
723 | | |
724 | 4 | evaluation_result.executed = true; |
725 | 4 | evaluation_result.dist_fulfilled = dist_fulfilled; |
726 | 4 | VLOG_DEBUG << fmt::format( |
727 | 0 | "Ann range search filtered {} rows, origin {} rows, virtual column is full-filled: {}", |
728 | 0 | origin_num - row_bitmap.cardinality(), origin_num, dist_fulfilled); |
729 | | |
730 | 4 | ann_index_stats = *stats; |
731 | 4 | return Status::OK(); |
732 | 4 | } |
733 | | |
734 | 4 | double VectorizedFnCall::execute_cost() const { |
735 | 4 | if (!_function) { |
736 | 0 | throw Exception( |
737 | 0 | Status::InternalError("Function is null in expression: {}", this->debug_string())); |
738 | 0 | } |
739 | 4 | double cost = _function->execute_cost(); |
740 | 8 | for (const auto& child : _children) { |
741 | 8 | cost += child->execute_cost(); |
742 | 8 | } |
743 | 4 | return cost; |
744 | 4 | } |
745 | | |
746 | | } // namespace doris |