be/src/exprs/vectorized_fn_call.cpp
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #include "exprs/vectorized_fn_call.h" |
19 | | |
20 | | #include <fmt/compile.h> |
21 | | #include <fmt/format.h> |
22 | | #include <fmt/ranges.h> // IWYU pragma: keep |
23 | | #include <gen_cpp/Opcodes_types.h> |
24 | | #include <gen_cpp/Types_types.h> |
25 | | |
26 | | #include <memory> |
27 | | #include <ostream> |
28 | | |
29 | | #include "common/config.h" |
30 | | #include "common/exception.h" |
31 | | #include "common/logging.h" |
32 | | #include "common/status.h" |
33 | | #include "common/utils.h" |
34 | | #include "core/assert_cast.h" |
35 | | #include "core/block/block.h" |
36 | | #include "core/block/column_numbers.h" |
37 | | #include "core/column/column.h" |
38 | | #include "core/column/column_array.h" |
39 | | #include "core/column/column_nullable.h" |
40 | | #include "core/column/column_vector.h" |
41 | | #include "core/data_type/data_type.h" |
42 | | #include "core/data_type/data_type_agg_state.h" |
43 | | #include "core/types.h" |
44 | | #include "exec/common/util.hpp" |
45 | | #include "exec/pipeline/pipeline_task.h" |
46 | | #include "exprs/function/array/function_array_distance.h" |
47 | | #include "exprs/function/function_agg_state.h" |
48 | | #include "exprs/function/function_fake.h" |
49 | | #include "exprs/function/function_java_udf.h" |
50 | | #include "exprs/function/function_python_udf.h" |
51 | | #include "exprs/function/function_rpc.h" |
52 | | #include "exprs/function/simple_function_factory.h" |
53 | | #include "exprs/function_context.h" |
54 | | #include "exprs/varray_literal.h" |
55 | | #include "exprs/vcast_expr.h" |
56 | | #include "exprs/vexpr_context.h" |
57 | | #include "exprs/virtual_slot_ref.h" |
58 | | #include "exprs/vliteral.h" |
59 | | #include "runtime/runtime_state.h" |
60 | | #include "storage/index/ann/ann_index.h" |
61 | | #include "storage/index/ann/ann_index_iterator.h" |
62 | | #include "storage/index/ann/ann_search_params.h" |
63 | | #include "storage/index/index_reader.h" |
64 | | #include "storage/index/zone_map/zonemap_eval_context.h" |
65 | | #include "storage/segment/column_reader.h" |
66 | | #include "storage/segment/virtual_column_iterator.h" |
67 | | |
68 | | namespace doris { |
69 | | class RowDescriptor; |
70 | | class RuntimeState; |
71 | | class TExprNode; |
72 | | } // namespace doris |
73 | | |
74 | | namespace doris { |
75 | | |
76 | | const std::string AGG_STATE_SUFFIX = "_state"; |
77 | | |
78 | | // Now left child is a function call, we need to check if it is a distance function |
79 | | const static std::set<std::string> DISTANCE_FUNCS = {L2DistanceApproximate::name, |
80 | | InnerProductApproximate::name}; |
81 | | const static std::set<TExprOpcode::type> OPS_FOR_ANN_RANGE_SEARCH = { |
82 | | TExprOpcode::GE, TExprOpcode::LE, TExprOpcode::LE, TExprOpcode::GT, TExprOpcode::LT}; |
83 | | |
84 | 744k | VectorizedFnCall::VectorizedFnCall(const TExprNode& node) : VExpr(node) {} |
85 | | |
86 | | Status VectorizedFnCall::prepare(RuntimeState* state, const RowDescriptor& desc, |
87 | 738k | VExprContext* context) { |
88 | 738k | RETURN_IF_ERROR_OR_PREPARED(VExpr::prepare(state, desc, context)); |
89 | 738k | ColumnsWithTypeAndName argument_template; |
90 | 738k | argument_template.reserve(_children.size()); |
91 | 1.45M | for (auto child : _children) { |
92 | 1.45M | if (child->is_literal()) { |
93 | | // For some functions, he needs some literal columns to derive the return type. |
94 | 695k | auto literal_node = std::dynamic_pointer_cast<VLiteral>(child); |
95 | 695k | argument_template.emplace_back(literal_node->get_column_ptr(), child->data_type(), |
96 | 695k | child->expr_name()); |
97 | 756k | } else { |
98 | 756k | argument_template.emplace_back(nullptr, child->data_type(), child->expr_name()); |
99 | 756k | } |
100 | 1.45M | } |
101 | | |
102 | 738k | _expr_name = fmt::format("VectorizedFnCall[{}](arguments={},return={})", _fn.name.function_name, |
103 | 738k | get_child_names(), _data_type->get_name()); |
104 | 738k | if (_fn.binary_type == TFunctionBinaryType::RPC) { |
105 | 0 | _function = FunctionRPC::create(_fn, argument_template, _data_type); |
106 | 738k | } else if (_fn.binary_type == TFunctionBinaryType::JAVA_UDF) { |
107 | 529 | if (config::enable_java_support) { |
108 | 529 | if (_fn.is_udtf_function) { |
109 | | // fake function. it's no use and can't execute. |
110 | 56 | auto builder = |
111 | 56 | std::make_shared<DefaultFunctionBuilder>(FunctionFake<UDTFImpl>::create()); |
112 | 56 | _function = builder->build(argument_template, std::make_shared<DataTypeUInt8>()); |
113 | 473 | } else { |
114 | 473 | _function = JavaFunctionCall::create(_fn, argument_template, _data_type); |
115 | 473 | } |
116 | 529 | } else { |
117 | 0 | return Status::InternalError( |
118 | 0 | "Java UDF is not enabled, you can change be config enable_java_support to true " |
119 | 0 | "and restart be."); |
120 | 0 | } |
121 | 737k | } else if (_fn.binary_type == TFunctionBinaryType::PYTHON_UDF) { |
122 | 670 | if (config::enable_python_udf_support) { |
123 | 670 | if (_fn.is_udtf_function) { |
124 | | // fake function. it's no use and can't execute. |
125 | | // Python UDTF is executed via PythonUDTFFunction in table function path |
126 | 286 | auto builder = |
127 | 286 | std::make_shared<DefaultFunctionBuilder>(FunctionFake<UDTFImpl>::create()); |
128 | 286 | _function = builder->build(argument_template, std::make_shared<DataTypeUInt8>()); |
129 | 384 | } else { |
130 | 384 | _function = PythonFunctionCall::create(_fn, argument_template, _data_type); |
131 | 384 | LOG(INFO) << fmt::format( |
132 | 384 | "create python function call: {}, runtime version: {}, function code: {}", |
133 | 384 | _fn.name.function_name, _fn.runtime_version, _fn.function_code); |
134 | 384 | } |
135 | 670 | } else { |
136 | 0 | return Status::InternalError( |
137 | 0 | "Python UDF is not enabled, you can change be config enable_python_udf_support " |
138 | 0 | "to true and restart be."); |
139 | 0 | } |
140 | 736k | } else if (_fn.binary_type == TFunctionBinaryType::AGG_STATE) { |
141 | 751 | DataTypes argument_types; |
142 | 1.09k | for (auto column : argument_template) { |
143 | 1.09k | argument_types.emplace_back(column.type); |
144 | 1.09k | } |
145 | | |
146 | 751 | if (match_suffix(_fn.name.function_name, AGG_STATE_SUFFIX)) { |
147 | 751 | if (_data_type->is_nullable()) { |
148 | 0 | return Status::InternalError("State function's return type must be not nullable"); |
149 | 0 | } |
150 | 751 | if (_data_type->get_primitive_type() != PrimitiveType::TYPE_AGG_STATE) { |
151 | 0 | return Status::InternalError( |
152 | 0 | "State function's return type must be agg_state but get {}", |
153 | 0 | _data_type->get_family_name()); |
154 | 0 | } |
155 | 751 | _function = FunctionAggState::create( |
156 | 751 | argument_types, _data_type, |
157 | 751 | assert_cast<const DataTypeAggState*>(_data_type.get())->get_nested_function()); |
158 | 751 | } else { |
159 | 0 | return Status::InternalError("Function {} is not endwith '_state'", _fn.signature); |
160 | 0 | } |
161 | 736k | } else { |
162 | | // get the function. won't prepare function. |
163 | 736k | _function = SimpleFunctionFactory::instance().get_function( |
164 | 736k | _fn.name.function_name, argument_template, _data_type, |
165 | 736k | {.new_version_unix_timestamp = state->query_options().new_version_unix_timestamp}, |
166 | 736k | state->be_exec_version()); |
167 | 736k | } |
168 | 738k | if (_function == nullptr) { |
169 | 2 | return Status::InternalError("Could not find function {}, arg {} return {} ", |
170 | 2 | _fn.name.function_name, get_child_type_names(), |
171 | 2 | _data_type->get_name()); |
172 | 2 | } |
173 | 738k | VExpr::register_function_context(state, context); |
174 | 738k | _function_name = _fn.name.function_name; |
175 | 738k | _prepare_finished = true; |
176 | | |
177 | 738k | FunctionContext* fn_ctx = context->fn_context(_fn_context_index); |
178 | 738k | if (fn().__isset.dict_function) { |
179 | 95 | fn_ctx->set_dict_function(fn().dict_function); |
180 | 95 | } |
181 | 738k | return Status::OK(); |
182 | 738k | } |
183 | | |
184 | | Status VectorizedFnCall::open(RuntimeState* state, VExprContext* context, |
185 | 2.03M | FunctionContext::FunctionStateScope scope) { |
186 | 2.03M | DCHECK(_prepare_finished); |
187 | 3.88M | for (auto& i : _children) { |
188 | 3.88M | RETURN_IF_ERROR(i->open(state, context, scope)); |
189 | 3.88M | } |
190 | 2.03M | RETURN_IF_ERROR(VExpr::init_function_context(state, context, scope, _function)); |
191 | 2.03M | if (scope == FunctionContext::FRAGMENT_LOCAL) { |
192 | 737k | RETURN_IF_ERROR(VExpr::get_const_col(context, nullptr)); |
193 | 737k | } |
194 | 2.03M | _open_finished = true; |
195 | 2.03M | return Status::OK(); |
196 | 2.03M | } |
197 | | |
198 | 2.04M | void VectorizedFnCall::close(VExprContext* context, FunctionContext::FunctionStateScope scope) { |
199 | 2.04M | VExpr::close_function_context(context, scope, _function); |
200 | 2.04M | VExpr::close(context, scope); |
201 | 2.04M | } |
202 | | |
203 | 12.8k | Status VectorizedFnCall::evaluate_inverted_index(VExprContext* context, uint32_t segment_num_rows) { |
204 | 12.8k | if (get_num_children() < 1) { |
205 | | // score() and similar 0-children virtual column functions don't need |
206 | | // inverted index evaluation; return OK to skip gracefully. |
207 | 28 | return Status::OK(); |
208 | 28 | } |
209 | 12.8k | return _evaluate_inverted_index(context, _function, segment_num_rows); |
210 | 12.8k | } |
211 | | |
212 | 38.3k | ZoneMapFilterResult VectorizedFnCall::evaluate_zonemap_filter(const ZoneMapEvalContext& ctx) const { |
213 | 38.3k | return _function->evaluate_zonemap_filter(ctx, _children); |
214 | 38.3k | } |
215 | | |
216 | 101k | bool VectorizedFnCall::can_evaluate_zonemap_filter() const { |
217 | 101k | return _function != nullptr && !_function->is_blockable() && |
218 | 101k | _function->can_evaluate_zonemap_filter(_children); |
219 | 101k | } |
220 | | |
221 | | Status VectorizedFnCall::_do_execute(VExprContext* context, const Block* block, |
222 | | const Selector* selector, size_t count, |
223 | 832k | ColumnPtr& result_column, ColumnPtr* arg_column) const { |
224 | 832k | if (is_const_and_have_executed()) { // const have executed in open function |
225 | 27.8k | result_column = get_result_from_const(count); |
226 | 27.8k | return Status::OK(); |
227 | 27.8k | } |
228 | 804k | if (fast_execute(context, selector, count, result_column)) { |
229 | 649 | return Status::OK(); |
230 | 649 | } |
231 | 803k | DBUG_EXECUTE_IF("VectorizedFnCall.must_in_slow_path", { |
232 | 803k | if (get_child(0)->is_slot_ref()) { |
233 | 803k | auto debug_col_name = DebugPoints::instance()->get_debug_param_or_default<std::string>( |
234 | 803k | "VectorizedFnCall.must_in_slow_path", "column_name", ""); |
235 | | |
236 | 803k | std::vector<std::string> column_names; |
237 | 803k | boost::split(column_names, debug_col_name, boost::algorithm::is_any_of(",")); |
238 | | |
239 | 803k | auto* column_slot_ref = assert_cast<VSlotRef*>(get_child(0).get()); |
240 | 803k | std::string column_name = column_slot_ref->expr_name(); |
241 | 803k | auto it = std::find(column_names.begin(), column_names.end(), column_name); |
242 | 803k | if (it == column_names.end()) { |
243 | 803k | return Status::Error<ErrorCode::INTERNAL_ERROR>( |
244 | 803k | "column {} should in slow path while VectorizedFnCall::execute.", |
245 | 803k | column_name); |
246 | 803k | } |
247 | 803k | } |
248 | 803k | }) |
249 | 803k | DCHECK(_open_finished || block == nullptr) << debug_string(); |
250 | | |
251 | 803k | Block temp_block; |
252 | 803k | ColumnNumbers args(_children.size()); |
253 | | |
254 | 2.19M | for (int i = 0; i < _children.size(); ++i) { |
255 | 1.38M | ColumnPtr tmp_arg_column; |
256 | 1.38M | RETURN_IF_ERROR( |
257 | 1.38M | _children[i]->execute_column(context, block, selector, count, tmp_arg_column)); |
258 | 1.38M | auto arg_type = _children[i]->execute_type(block); |
259 | 1.38M | temp_block.insert({tmp_arg_column, arg_type, _children[i]->expr_name()}); |
260 | 1.38M | args[i] = i; |
261 | | |
262 | 1.38M | if (arg_column != nullptr && i == 0) { |
263 | 25.9k | *arg_column = tmp_arg_column; |
264 | 25.9k | } |
265 | 1.38M | } |
266 | | |
267 | 803k | uint32_t num_columns_without_result = temp_block.columns(); |
268 | | // prepare a column to save result |
269 | 803k | temp_block.insert({nullptr, _data_type, _expr_name}); |
270 | | |
271 | 803k | DBUG_EXECUTE_IF("VectorizedFnCall.wait_before_execute", { |
272 | 803k | auto possibility = DebugPoints::instance()->get_debug_param_or_default<double>( |
273 | 803k | "VectorizedFnCall.wait_before_execute", "possibility", 0); |
274 | 803k | if (random_bool_slow(possibility)) { |
275 | 803k | LOG(WARNING) << "VectorizedFnCall::execute sleep 30s"; |
276 | 803k | sleep(30); |
277 | 803k | } |
278 | 803k | }); |
279 | | |
280 | 803k | RETURN_IF_ERROR(_function->execute(context->fn_context(_fn_context_index), temp_block, args, |
281 | 803k | num_columns_without_result, count)); |
282 | 803k | result_column = temp_block.get_by_position(num_columns_without_result).column; |
283 | 803k | DCHECK_EQ(result_column->size(), count); |
284 | 803k | RETURN_IF_ERROR(result_column->column_self_check()); |
285 | 803k | return Status::OK(); |
286 | 803k | } |
287 | | |
288 | 0 | size_t VectorizedFnCall::estimate_memory(const size_t rows) { |
289 | 0 | if (is_const_and_have_executed()) { // const have execute in open function |
290 | 0 | return 0; |
291 | 0 | } |
292 | | |
293 | 0 | size_t estimate_size = 0; |
294 | 0 | for (auto& child : _children) { |
295 | 0 | estimate_size += child->estimate_memory(rows); |
296 | 0 | } |
297 | |
|
298 | 0 | if (_data_type->have_maximum_size_of_value()) { |
299 | 0 | estimate_size += rows * _data_type->get_size_of_value_in_memory(); |
300 | 0 | } else { |
301 | 0 | estimate_size += rows * 512; /// FIXME: estimated value... |
302 | 0 | } |
303 | 0 | return estimate_size; |
304 | 0 | } |
305 | | |
306 | | Status VectorizedFnCall::execute_runtime_filter(VExprContext* context, const Block* block, |
307 | | const uint8_t* __restrict filter, size_t count, |
308 | | ColumnPtr& result_column, |
309 | 25.9k | ColumnPtr* arg_column) const { |
310 | 25.9k | return _do_execute(context, block, nullptr, count, result_column, arg_column); |
311 | 25.9k | } |
312 | | |
313 | | Status VectorizedFnCall::execute_column_impl(VExprContext* context, const Block* block, |
314 | | const Selector* selector, size_t count, |
315 | 806k | ColumnPtr& result_column) const { |
316 | 806k | return _do_execute(context, block, selector, count, result_column, nullptr); |
317 | 806k | } |
318 | | |
319 | 370k | const std::string& VectorizedFnCall::expr_name() const { |
320 | 370k | return _expr_name; |
321 | 370k | } |
322 | | |
323 | 78 | std::string VectorizedFnCall::function_name() const { |
324 | 78 | return _function_name; |
325 | 78 | } |
326 | | |
327 | 521 | std::string VectorizedFnCall::debug_string() const { |
328 | 521 | std::stringstream out; |
329 | 521 | out << "VectorizedFn["; |
330 | 521 | out << _expr_name; |
331 | 521 | out << "]{"; |
332 | 521 | bool first = true; |
333 | 1.03k | for (const auto& input_expr : children()) { |
334 | 1.03k | if (first) { |
335 | 519 | first = false; |
336 | 519 | } else { |
337 | 517 | out << ","; |
338 | 517 | } |
339 | 1.03k | out << "\n" << input_expr->debug_string(); |
340 | 1.03k | } |
341 | 521 | out << "}"; |
342 | 521 | return out.str(); |
343 | 521 | } |
344 | | |
345 | 0 | std::string VectorizedFnCall::debug_string(const std::vector<VectorizedFnCall*>& agg_fns) { |
346 | 0 | std::stringstream out; |
347 | 0 | out << "["; |
348 | 0 | for (int i = 0; i < agg_fns.size(); ++i) { |
349 | 0 | out << (i == 0 ? "" : " ") << agg_fns[i]->debug_string(); |
350 | 0 | } |
351 | 0 | out << "]"; |
352 | 0 | return out.str(); |
353 | 0 | } |
354 | | |
355 | 2.59k | bool VectorizedFnCall::can_push_down_to_index() const { |
356 | 2.59k | return _function->can_push_down_to_index(); |
357 | 2.59k | } |
358 | | |
359 | 0 | bool VectorizedFnCall::equals(const VExpr& other) { |
360 | 0 | const auto* other_ptr = dynamic_cast<const VectorizedFnCall*>(&other); |
361 | 0 | if (!other_ptr) { |
362 | 0 | return false; |
363 | 0 | } |
364 | 0 | if (this->_function_name != other_ptr->_function_name) { |
365 | 0 | return false; |
366 | 0 | } |
367 | 0 | if (get_num_children() != other_ptr->get_num_children()) { |
368 | 0 | return false; |
369 | 0 | } |
370 | 0 | for (uint16_t i = 0; i < get_num_children(); i++) { |
371 | 0 | if (!this->get_child(i)->equals(*other_ptr->get_child(i))) { |
372 | 0 | return false; |
373 | 0 | } |
374 | 0 | } |
375 | 0 | return true; |
376 | 0 | } |
377 | | |
378 | | /* |
379 | | * For ANN range search we expect a comparison expression (LE/LT/GE/GT) whose left side is either: |
380 | | * 1) a vector distance function call, or |
381 | | * 2) a cast/virtual slot that unwraps to the function call when the planner promotes float to |
382 | | * double literals. |
383 | | * |
384 | | * Visually the logical tree looks like: |
385 | | * |
386 | | * FunctionCall(LE/LT/GE/GT) |
387 | | * |---------------- |
388 | | * | | |
389 | | * | | |
390 | | * VirtualSlotRef* Float32Literal/Float64Literal |
391 | | * | |
392 | | * | |
393 | | * Cast(Float -> Double)* |
394 | | * | |
395 | | * FunctionCall(distance) |
396 | | * |---------------- |
397 | | * | | |
398 | | * | | |
399 | | * SlotRef ArrayLiteral/Cast(String as Array<FLOAT>) |
400 | | * |
401 | | * Items marked with * are optional and depend on literal types/virtual column usage. The helper |
402 | | * below normalizes the shape and validates distance function, slot, and constant vector inputs. |
403 | | */ |
404 | | |
405 | | void VectorizedFnCall::prepare_ann_range_search( |
406 | | const doris::VectorSearchUserParams& user_params, |
407 | 15.0k | segment_v2::AnnRangeSearchRuntime& range_search_runtime, bool& suitable_for_ann_index) { |
408 | 15.0k | if (!suitable_for_ann_index) { |
409 | 0 | return; |
410 | 0 | } |
411 | | |
412 | 15.0k | if (OPS_FOR_ANN_RANGE_SEARCH.find(this->op()) == OPS_FOR_ANN_RANGE_SEARCH.end()) { |
413 | 11.1k | suitable_for_ann_index = false; |
414 | 11.1k | return; |
415 | 11.1k | } |
416 | | |
417 | 3.89k | auto mark_unsuitable = [&](const std::string& reason) { |
418 | 3.85k | suitable_for_ann_index = false; |
419 | 18.4E | VLOG_DEBUG << "ANN range search skipped: " << reason; |
420 | 3.85k | }; |
421 | | |
422 | 3.89k | range_search_runtime.is_le_or_lt = |
423 | 3.89k | (this->op() == TExprOpcode::LE || this->op() == TExprOpcode::LT); |
424 | | |
425 | 3.89k | DCHECK(_children.size() == 2); |
426 | | |
427 | 3.89k | auto left_child = get_child(0); |
428 | 3.89k | auto right_child = get_child(1); |
429 | | |
430 | | // ========== Step 1: Check left child - must be a distance function ========== |
431 | 3.89k | auto get_virtual_expr = [&](const VExprSPtr& expr, |
432 | 4.96k | std::shared_ptr<VirtualSlotRef>& slot_ref) -> VExprSPtr { |
433 | 4.96k | auto virtual_ref = std::dynamic_pointer_cast<VirtualSlotRef>(expr); |
434 | 4.96k | if (virtual_ref != nullptr) { |
435 | 234 | DCHECK(virtual_ref->get_virtual_column_expr() != nullptr); |
436 | 234 | slot_ref = virtual_ref; |
437 | 234 | return virtual_ref->get_virtual_column_expr(); |
438 | 234 | } |
439 | 4.73k | return expr; |
440 | 4.96k | }; |
441 | | |
442 | 3.89k | std::shared_ptr<VirtualSlotRef> vir_slot_ref; |
443 | 3.89k | auto normalized_left = get_virtual_expr(left_child, vir_slot_ref); |
444 | | |
445 | | // Try to find the distance function call, it may be wrapped in a Cast(Float->Double) |
446 | 3.89k | std::shared_ptr<VectorizedFnCall> function_call = |
447 | 3.89k | std::dynamic_pointer_cast<VectorizedFnCall>(normalized_left); |
448 | 3.89k | bool has_float_to_double_cast = false; |
449 | | |
450 | 3.89k | if (function_call == nullptr) { |
451 | | // Check if it's a Cast expression wrapping a function call |
452 | 1.52k | auto cast_expr = std::dynamic_pointer_cast<VCastExpr>(normalized_left); |
453 | 1.52k | if (cast_expr == nullptr) { |
454 | 443 | mark_unsuitable("Left child is neither a function call nor a cast expression."); |
455 | 443 | return; |
456 | 443 | } |
457 | 1.08k | has_float_to_double_cast = true; |
458 | 1.08k | auto normalized_cast_child = get_virtual_expr(cast_expr->get_child(0), vir_slot_ref); |
459 | 1.08k | function_call = std::dynamic_pointer_cast<VectorizedFnCall>(normalized_cast_child); |
460 | 1.08k | if (function_call == nullptr) { |
461 | 1.03k | mark_unsuitable("Left child of cast is not a function call."); |
462 | 1.03k | return; |
463 | 1.03k | } |
464 | 1.08k | } |
465 | | |
466 | | // Check if it's a supported distance function |
467 | 2.41k | if (DISTANCE_FUNCS.find(function_call->_function_name) == DISTANCE_FUNCS.end()) { |
468 | 2.36k | mark_unsuitable(fmt::format("Left child is not a supported distance function: {}", |
469 | 2.36k | function_call->_function_name)); |
470 | 2.36k | return; |
471 | 2.36k | } |
472 | | |
473 | | // Strip the _approximate suffix to get metric type |
474 | 47 | std::string metric_name = function_call->_function_name; |
475 | 47 | metric_name = metric_name.substr(0, metric_name.size() - 12); |
476 | 47 | range_search_runtime.metric_type = segment_v2::string_to_metric(metric_name); |
477 | | |
478 | | // ========== Step 2: Validate distance function arguments ========== |
479 | | // Identify the slot ref child and the constant query array child (ArrayLiteral or CAST to array) |
480 | 47 | Int32 idx_of_slot_ref = -1; |
481 | 47 | Int32 idx_of_array_expr = -1; |
482 | 92 | auto classify_child = [&](const VExprSPtr& child, UInt16 index) { |
483 | 92 | if (idx_of_slot_ref == -1 && std::dynamic_pointer_cast<VSlotRef>(child) != nullptr) { |
484 | 46 | idx_of_slot_ref = index; |
485 | 46 | return; |
486 | 46 | } |
487 | 46 | if (idx_of_array_expr == -1 && |
488 | 46 | (std::dynamic_pointer_cast<VArrayLiteral>(child) != nullptr || |
489 | 46 | std::dynamic_pointer_cast<VCastExpr>(child) != nullptr)) { |
490 | 39 | idx_of_array_expr = index; |
491 | 39 | } |
492 | 46 | }; |
493 | | |
494 | 139 | for (UInt16 i = 0; i < function_call->get_num_children(); ++i) { |
495 | 92 | classify_child(function_call->get_child(i), i); |
496 | 92 | } |
497 | | |
498 | 47 | if (idx_of_slot_ref == -1 || idx_of_array_expr == -1) { |
499 | 7 | mark_unsuitable("slot ref or array literal/cast is missing."); |
500 | 7 | return; |
501 | 7 | } |
502 | | |
503 | 40 | auto slot_ref = std::dynamic_pointer_cast<VSlotRef>( |
504 | 40 | function_call->get_child(static_cast<UInt16>(idx_of_slot_ref))); |
505 | 40 | range_search_runtime.src_col_idx = slot_ref->column_id(); |
506 | 40 | range_search_runtime.dst_col_idx = vir_slot_ref == nullptr ? -1 : vir_slot_ref->column_id(); |
507 | | |
508 | | // Materialize the constant array expression and validate its shape and types |
509 | 40 | auto array_expr = function_call->get_child(static_cast<UInt16>(idx_of_array_expr)); |
510 | 40 | auto extract_result = extract_query_vector(array_expr); |
511 | 40 | if (!extract_result.has_value()) { |
512 | 0 | mark_unsuitable("Failed to extract query vector from constant array expression."); |
513 | 0 | return; |
514 | 0 | } |
515 | 40 | range_search_runtime.query_value = extract_result.value(); |
516 | 40 | range_search_runtime.dim = range_search_runtime.query_value->size(); |
517 | | |
518 | | // ========== Step 3: Check right child - must be a float/double literal ========== |
519 | 40 | auto right_literal = std::dynamic_pointer_cast<VLiteral>(right_child); |
520 | 40 | if (right_literal == nullptr) { |
521 | 1 | mark_unsuitable("Right child is not a literal."); |
522 | 1 | return; |
523 | 1 | } |
524 | | |
525 | | // Handle nullable literal gracefully - just mark as unsuitable instead of crash |
526 | 39 | if (right_literal->is_nullable()) { |
527 | 0 | mark_unsuitable("Right literal is nullable, not supported for ANN range search."); |
528 | 0 | return; |
529 | 0 | } |
530 | | |
531 | 39 | auto right_type = right_literal->get_data_type(); |
532 | 39 | PrimitiveType right_primitive = right_type->get_primitive_type(); |
533 | 39 | const bool float32_literal = right_primitive == PrimitiveType::TYPE_FLOAT; |
534 | 39 | const bool float64_literal = right_primitive == PrimitiveType::TYPE_DOUBLE; |
535 | | |
536 | 39 | if (!float32_literal && !float64_literal) { |
537 | 0 | mark_unsuitable("Right child is not a Float32Literal or Float64Literal."); |
538 | 0 | return; |
539 | 0 | } |
540 | | |
541 | | // Validate consistency: if we have Cast(Float->Double), right must be double literal |
542 | 39 | if (has_float_to_double_cast && !float64_literal) { |
543 | 0 | mark_unsuitable("Cast expression expects double literal on right side."); |
544 | 0 | return; |
545 | 0 | } |
546 | | |
547 | | // Extract radius value |
548 | 39 | auto right_col = right_literal->get_column_ptr()->convert_to_full_column_if_const(); |
549 | 39 | if (float32_literal) { |
550 | 7 | const ColumnFloat32* cf32_right = assert_cast<const ColumnFloat32*>(right_col.get()); |
551 | 7 | range_search_runtime.radius = cf32_right->get_data()[0]; |
552 | 32 | } else { |
553 | 32 | const ColumnFloat64* cf64_right = assert_cast<const ColumnFloat64*>(right_col.get()); |
554 | 32 | range_search_runtime.radius = static_cast<float>(cf64_right->get_data()[0]); |
555 | 32 | } |
556 | | |
557 | | // ========== Done: Mark as suitable for ANN range search ========== |
558 | 39 | range_search_runtime.is_ann_range_search = true; |
559 | 39 | range_search_runtime.user_params = user_params; |
560 | 39 | VLOG_DEBUG << fmt::format("Ann range search params: {}", range_search_runtime.to_string()); |
561 | 39 | return; |
562 | 39 | } |
563 | | |
564 | | Status VectorizedFnCall::evaluate_ann_range_search( |
565 | | const segment_v2::AnnRangeSearchRuntime& range_search_runtime, |
566 | | const std::vector<std::unique_ptr<segment_v2::IndexIterator>>& cid_to_index_iterators, |
567 | | const std::vector<ColumnId>& idx_to_cid, |
568 | | const std::vector<std::unique_ptr<segment_v2::ColumnIterator>>& column_iterators, |
569 | | size_t rows_of_segment, roaring::Roaring& row_bitmap, |
570 | | segment_v2::AnnIndexStats& ann_index_stats, bool enable_result_cache, |
571 | 11.9k | AnnRangeSearchEvaluationResult& evaluation_result) { |
572 | 11.9k | evaluation_result = {}; |
573 | 11.9k | if (range_search_runtime.is_ann_range_search == false) { |
574 | 11.8k | return Status::OK(); |
575 | 11.8k | } |
576 | | |
577 | 56 | VLOG_DEBUG << fmt::format("Try apply ann range search. Local search params: {}", |
578 | 20 | range_search_runtime.to_string()); |
579 | 56 | size_t origin_num = row_bitmap.cardinality(); |
580 | | |
581 | 56 | const auto idx_in_block = range_search_runtime.src_col_idx; |
582 | 56 | DCHECK_LT(idx_in_block, idx_to_cid.size()) |
583 | 0 | << "idx_in_block: " << idx_in_block << ", idx_to_cid.size(): " << idx_to_cid.size(); |
584 | | |
585 | 56 | ColumnId src_col_cid = idx_to_cid[idx_in_block]; |
586 | 56 | DCHECK(src_col_cid < cid_to_index_iterators.size()); |
587 | 56 | segment_v2::IndexIterator* index_iterator = cid_to_index_iterators[src_col_cid].get(); |
588 | 56 | if (index_iterator == nullptr) { |
589 | 1 | VLOG_DEBUG << "ANN range search skipped: " |
590 | 0 | << fmt::format("No index iterator for column cid {}", src_col_cid); |
591 | 1 | ; |
592 | 1 | return Status::OK(); |
593 | 1 | } |
594 | | |
595 | 55 | segment_v2::AnnIndexIterator* ann_index_iterator = |
596 | 55 | dynamic_cast<segment_v2::AnnIndexIterator*>(index_iterator); |
597 | 55 | if (ann_index_iterator == nullptr) { |
598 | 0 | VLOG_DEBUG << "ANN range search skipped: " |
599 | 0 | << fmt::format("Column cid {} has no ANN index iterator", src_col_cid); |
600 | 0 | return Status::OK(); |
601 | 0 | } |
602 | 55 | DCHECK(ann_index_iterator->get_reader(AnnIndexReaderType::ANN) != nullptr) |
603 | 20 | << "Ann index iterator should have reader. Column cid: " << src_col_cid; |
604 | 55 | std::shared_ptr<AnnIndexReader> ann_index_reader = std::dynamic_pointer_cast<AnnIndexReader>( |
605 | 55 | ann_index_iterator->get_reader(segment_v2::AnnIndexReaderType::ANN)); |
606 | 55 | DCHECK(ann_index_reader != nullptr) |
607 | 20 | << "Ann index reader should not be null. Column cid: " << src_col_cid; |
608 | | // Check if metrics type is match. |
609 | 55 | if (ann_index_reader->get_metric_type() != range_search_runtime.metric_type) { |
610 | 0 | VLOG_DEBUG << "ANN range search skipped: " |
611 | 0 | << fmt::format("Metric type mismatch. Index={} Query={}", |
612 | 0 | segment_v2::metric_to_string(ann_index_reader->get_metric_type()), |
613 | 0 | segment_v2::metric_to_string(range_search_runtime.metric_type)); |
614 | 0 | return Status::OK(); |
615 | 0 | } |
616 | | |
617 | | // Check dimension if available (>0) |
618 | 55 | const size_t index_dim = ann_index_reader->get_dimension(); |
619 | 55 | if (index_dim > 0 && index_dim != range_search_runtime.dim) { |
620 | 6 | return Status::InvalidArgument( |
621 | 6 | "Ann range search query dimension {} does not match index dimension {}", |
622 | 6 | range_search_runtime.dim, index_dim); |
623 | 6 | } |
624 | | |
625 | 49 | const auto& user_params = range_search_runtime.user_params; |
626 | 49 | if (user_params.should_fallback_ann_index_by_small_candidate(origin_num, rows_of_segment)) { |
627 | 0 | VLOG_DEBUG << fmt::format( |
628 | 0 | "Ann range search input rows {} reach small candidate threshold, " |
629 | 0 | "rows_of_segment: {}, absolute_threshold: {}, percent_threshold: {}, " |
630 | 0 | "will not use ann index to filter", |
631 | 0 | origin_num, rows_of_segment, user_params.ann_index_candidate_rows_threshold, |
632 | 0 | user_params.ann_index_candidate_rows_percent_threshold); |
633 | 0 | ann_index_stats.fall_back_brute_force_cnt += 1; |
634 | 0 | ann_index_stats.range_fallback_by_small_candidate_cnt += 1; |
635 | 0 | ann_index_stats.range_fallback_small_candidate_rows += origin_num; |
636 | 0 | return Status::OK(); |
637 | 0 | } |
638 | | |
639 | 49 | auto stats = std::make_unique<segment_v2::AnnIndexStats>(); |
640 | | // Track load index timing |
641 | 49 | { |
642 | 49 | SCOPED_TIMER(&(stats->load_index_costs_ns)); |
643 | 49 | if (!ann_index_iterator->try_load_index()) { |
644 | 2 | VLOG_DEBUG << "ANN range search skipped: " |
645 | 0 | << fmt::format("Failed to load ANN index for column cid {}", src_col_cid); |
646 | 2 | ann_index_stats.fall_back_brute_force_cnt += 1; |
647 | 2 | return Status::OK(); |
648 | 2 | } |
649 | 47 | double load_costs_ms = static_cast<double>(stats->load_index_costs_ns.value()) / 1000000.0; |
650 | 47 | DorisMetrics::instance()->ann_index_load_costs_ms->increment( |
651 | 47 | static_cast<int64_t>(load_costs_ms)); |
652 | 47 | } |
653 | | |
654 | 0 | AnnRangeSearchParams params = range_search_runtime.to_range_search_params(); |
655 | | |
656 | 47 | params.roaring = &row_bitmap; |
657 | 47 | params.enable_result_cache = enable_result_cache; |
658 | 47 | DCHECK(params.roaring != nullptr); |
659 | 47 | DCHECK(params.query_value != nullptr); |
660 | 47 | segment_v2::AnnRangeSearchResult result; |
661 | 47 | RETURN_IF_ERROR(ann_index_iterator->range_search(params, range_search_runtime.user_params, |
662 | 47 | &result, stats.get())); |
663 | | |
664 | 47 | #ifndef NDEBUG |
665 | 47 | if (range_search_runtime.is_le_or_lt == false && |
666 | 47 | ann_index_reader->get_metric_type() == AnnIndexMetric::L2) { |
667 | 7 | DCHECK(result.distance == nullptr) << "Should not have distance"; |
668 | 7 | } |
669 | 47 | if (range_search_runtime.is_le_or_lt == true && |
670 | 47 | ann_index_reader->get_metric_type() == AnnIndexMetric::IP) { |
671 | 4 | DCHECK(result.distance == nullptr); |
672 | 4 | } |
673 | 47 | #endif |
674 | 47 | DCHECK(result.roaring != nullptr); |
675 | 47 | row_bitmap = *result.roaring; |
676 | | |
677 | | // Process virtual column |
678 | 47 | bool dist_fulfilled = false; |
679 | 47 | if (range_search_runtime.dst_col_idx >= 0) { |
680 | | // Prepare materialization if we can use result from index. |
681 | | // Typical situation: range search and operator is LE or LT. |
682 | 4 | if (result.distance != nullptr) { |
683 | 2 | DCHECK(result.row_ids != nullptr); |
684 | 2 | ColumnId dst_col_cid = idx_to_cid[range_search_runtime.dst_col_idx]; |
685 | 2 | DCHECK(dst_col_cid < column_iterators.size()); |
686 | 2 | DCHECK(column_iterators[dst_col_cid] != nullptr); |
687 | 2 | segment_v2::ColumnIterator* column_iterator = column_iterators[dst_col_cid].get(); |
688 | 2 | DCHECK(column_iterator != nullptr); |
689 | 2 | segment_v2::VirtualColumnIterator* virtual_column_iterator = |
690 | 2 | dynamic_cast<segment_v2::VirtualColumnIterator*>(column_iterator); |
691 | 2 | DCHECK(virtual_column_iterator != nullptr); |
692 | | // Now convert distance to column |
693 | 2 | size_t size = result.roaring->cardinality(); |
694 | 2 | auto distance_col = ColumnFloat32::create(size); |
695 | 2 | const float* src = result.distance.get(); |
696 | 2 | float* dst = distance_col->get_data().data(); |
697 | 15 | for (size_t i = 0; i < size; ++i) { |
698 | 13 | dst[i] = src[i]; |
699 | 13 | } |
700 | 2 | virtual_column_iterator->prepare_materialization(std::move(distance_col), |
701 | 2 | std::move(result.row_ids)); |
702 | 2 | dist_fulfilled = true; |
703 | 2 | } else { |
704 | | // Whether the ANN index should have produced distance depends on metric and operator: |
705 | | // - L2: distance is produced for LE/LT; not produced for GE/GT |
706 | | // - IP: distance is produced for GE/GT; not produced for LE/LT |
707 | 2 | #ifndef NDEBUG |
708 | 2 | const bool should_have_distance = |
709 | 2 | (range_search_runtime.is_le_or_lt && |
710 | 2 | range_search_runtime.metric_type == AnnIndexMetric::L2) || |
711 | 2 | (!range_search_runtime.is_le_or_lt && |
712 | 2 | range_search_runtime.metric_type == AnnIndexMetric::IP); |
713 | | // If we expected distance but didn't get it, assert in debug to catch logic errors. |
714 | 2 | DCHECK(!should_have_distance) << "Expected distance from ANN index but got none"; |
715 | 2 | #endif |
716 | 2 | } |
717 | 43 | } else { |
718 | | // Dest is not virtual column. |
719 | 43 | dist_fulfilled = true; |
720 | 43 | } |
721 | | |
722 | 47 | evaluation_result.executed = true; |
723 | 47 | evaluation_result.dist_fulfilled = dist_fulfilled; |
724 | 47 | VLOG_DEBUG << fmt::format( |
725 | 20 | "Ann range search filtered {} rows, origin {} rows, virtual column is full-filled: {}", |
726 | 20 | origin_num - row_bitmap.cardinality(), origin_num, dist_fulfilled); |
727 | | |
728 | 47 | ann_index_stats = *stats; |
729 | 47 | return Status::OK(); |
730 | 47 | } |
731 | | |
732 | 1.22M | double VectorizedFnCall::execute_cost() const { |
733 | 1.22M | if (!_function) { |
734 | 0 | throw Exception( |
735 | 0 | Status::InternalError("Function is null in expression: {}", this->debug_string())); |
736 | 0 | } |
737 | 1.22M | double cost = _function->execute_cost(); |
738 | 2.43M | for (const auto& child : _children) { |
739 | 2.43M | cost += child->execute_cost(); |
740 | 2.43M | } |
741 | 1.22M | return cost; |
742 | 1.22M | } |
743 | | |
744 | | } // namespace doris |