Coverage Report

Created: 2026-07-01 06:56

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/vectorized_fn_call.cpp
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#include "exprs/vectorized_fn_call.h"
19
20
#include <fmt/compile.h>
21
#include <fmt/format.h>
22
#include <fmt/ranges.h> // IWYU pragma: keep
23
#include <gen_cpp/Opcodes_types.h>
24
#include <gen_cpp/Types_types.h>
25
26
#include <memory>
27
#include <ostream>
28
29
#include "common/config.h"
30
#include "common/exception.h"
31
#include "common/logging.h"
32
#include "common/status.h"
33
#include "common/utils.h"
34
#include "core/assert_cast.h"
35
#include "core/block/block.h"
36
#include "core/block/column_numbers.h"
37
#include "core/column/column.h"
38
#include "core/column/column_array.h"
39
#include "core/column/column_nullable.h"
40
#include "core/column/column_vector.h"
41
#include "core/data_type/data_type.h"
42
#include "core/data_type/data_type_agg_state.h"
43
#include "core/types.h"
44
#include "exec/common/util.hpp"
45
#include "exec/pipeline/pipeline_task.h"
46
#include "exprs/function/array/function_array_distance.h"
47
#include "exprs/function/function_agg_state.h"
48
#include "exprs/function/function_fake.h"
49
#include "exprs/function/function_java_udf.h"
50
#include "exprs/function/function_python_udf.h"
51
#include "exprs/function/function_rpc.h"
52
#include "exprs/function/simple_function_factory.h"
53
#include "exprs/function_context.h"
54
#include "exprs/varray_literal.h"
55
#include "exprs/vcast_expr.h"
56
#include "exprs/vexpr_context.h"
57
#include "exprs/virtual_slot_ref.h"
58
#include "exprs/vliteral.h"
59
#include "runtime/runtime_state.h"
60
#include "storage/index/ann/ann_index.h"
61
#include "storage/index/ann/ann_index_iterator.h"
62
#include "storage/index/ann/ann_search_params.h"
63
#include "storage/index/index_reader.h"
64
#include "storage/index/zone_map/zonemap_eval_context.h"
65
#include "storage/segment/column_reader.h"
66
#include "storage/segment/virtual_column_iterator.h"
67
68
namespace doris {
69
class RowDescriptor;
70
class RuntimeState;
71
class TExprNode;
72
} // namespace doris
73
74
namespace doris {
75
76
const std::string AGG_STATE_SUFFIX = "_state";
77
78
// Now left child is a function call, we need to check if it is a distance function
79
const static std::set<std::string> DISTANCE_FUNCS = {L2DistanceApproximate::name,
80
                                                     InnerProductApproximate::name};
81
const static std::set<TExprOpcode::type> OPS_FOR_ANN_RANGE_SEARCH = {
82
        TExprOpcode::GE, TExprOpcode::LE, TExprOpcode::LE, TExprOpcode::GT, TExprOpcode::LT};
83
84
744k
// Forwards `node` to the VExpr base constructor; the constructor body itself does nothing.
VectorizedFnCall::VectorizedFnCall(const TExprNode& node)
        : VExpr(node) {}
85
86
/// Resolves the function implementation behind this call and registers its function context.
///
/// Steps:
///   1. Run VExpr::prepare through RETURN_IF_ERROR_OR_PREPARED.
///   2. Build an argument template from the children. Literal children contribute their
///      column, because some functions need the literal value to derive the return type;
///      all other children contribute only type and name.
///   3. Select `_function` according to `_fn.binary_type`: RPC, Java UDF, Python UDF,
///      AGG_STATE, or (default) a lookup in SimpleFunctionFactory.
///   4. Register the function context and forward `dict_function` when it is set.
///
/// @param state   supplies the query options and BE exec version used for the factory lookup.
/// @param desc    forwarded to VExpr::prepare.
/// @param context expression context in which the function context is registered.
/// @return Status::OK() on success. InternalError when Java/Python UDF support is disabled
///         in the BE config, when an AGG_STATE function has a wrong name or return type,
///         when a child reporting is_literal() is not a VLiteral, or when no function
///         implementation could be resolved.
Status VectorizedFnCall::prepare(RuntimeState* state, const RowDescriptor& desc,
                                 VExprContext* context) {
    RETURN_IF_ERROR_OR_PREPARED(VExpr::prepare(state, desc, context));
    ColumnsWithTypeAndName argument_template;
    argument_template.reserve(_children.size());
    // Bind by const reference so that no shared_ptr is copied per child.
    for (const auto& child : _children) {
        if (child->is_literal()) {
            // Some functions need the literal column itself to derive the return type.
            auto literal_node = std::dynamic_pointer_cast<VLiteral>(child);
            if (literal_node == nullptr) {
                // Previously this would have been a null dereference.
                return Status::InternalError(
                        "Literal argument {} of function {} is not a VLiteral",
                        child->expr_name(), _fn.name.function_name);
            }
            argument_template.emplace_back(literal_node->get_column_ptr(), child->data_type(),
                                           child->expr_name());
        } else {
            argument_template.emplace_back(nullptr, child->data_type(), child->expr_name());
        }
    }

    _expr_name = fmt::format("VectorizedFnCall[{}](arguments={},return={})", _fn.name.function_name,
                             get_child_names(), _data_type->get_name());
    if (_fn.binary_type == TFunctionBinaryType::RPC) {
        _function = FunctionRPC::create(_fn, argument_template, _data_type);
    } else if (_fn.binary_type == TFunctionBinaryType::JAVA_UDF) {
        if (!config::enable_java_support) {
            return Status::InternalError(
                    "Java UDF is not enabled, you can change be config enable_java_support to true "
                    "and restart be.");
        }
        if (_fn.is_udtf_function) {
            // fake function. it's no use and can't execute.
            auto builder =
                    std::make_shared<DefaultFunctionBuilder>(FunctionFake<UDTFImpl>::create());
            _function = builder->build(argument_template, std::make_shared<DataTypeUInt8>());
        } else {
            _function = JavaFunctionCall::create(_fn, argument_template, _data_type);
        }
    } else if (_fn.binary_type == TFunctionBinaryType::PYTHON_UDF) {
        if (!config::enable_python_udf_support) {
            return Status::InternalError(
                    "Python UDF is not enabled, you can change be config enable_python_udf_support "
                    "to true and restart be.");
        }
        if (_fn.is_udtf_function) {
            // fake function. it's no use and can't execute.
            // Python UDTF is executed via PythonUDTFFunction in table function path
            auto builder =
                    std::make_shared<DefaultFunctionBuilder>(FunctionFake<UDTFImpl>::create());
            _function = builder->build(argument_template, std::make_shared<DataTypeUInt8>());
        } else {
            _function = PythonFunctionCall::create(_fn, argument_template, _data_type);
            // NOTE(review): this logs the complete function code at INFO level on every
            // prepare; confirm that this verbosity is intended.
            LOG(INFO) << fmt::format(
                    "create python function call: {}, runtime version: {}, function code: {}",
                    _fn.name.function_name, _fn.runtime_version, _fn.function_code);
        }
    } else if (_fn.binary_type == TFunctionBinaryType::AGG_STATE) {
        DataTypes argument_types;
        // Bind by const reference: copying a whole ColumnWithTypeAndName is not needed.
        for (const auto& column : argument_template) {
            argument_types.emplace_back(column.type);
        }

        if (!match_suffix(_fn.name.function_name, AGG_STATE_SUFFIX)) {
            return Status::InternalError("Function {} is not endwith '_state'", _fn.signature);
        }
        if (_data_type->is_nullable()) {
            return Status::InternalError("State function's return type must be not nullable");
        }
        if (_data_type->get_primitive_type() != PrimitiveType::TYPE_AGG_STATE) {
            return Status::InternalError(
                    "State function's return type must be agg_state but get {}",
                    _data_type->get_family_name());
        }
        _function = FunctionAggState::create(
                argument_types, _data_type,
                assert_cast<const DataTypeAggState*>(_data_type.get())->get_nested_function());
    } else {
        // get the function. won't prepare function.
        _function = SimpleFunctionFactory::instance().get_function(
                _fn.name.function_name, argument_template, _data_type,
                {.new_version_unix_timestamp = state->query_options().new_version_unix_timestamp},
                state->be_exec_version());
    }
    if (_function == nullptr) {
        return Status::InternalError("Could not find function {}, arg {} return {} ",
                                     _fn.name.function_name, get_child_type_names(),
                                     _data_type->get_name());
    }
    VExpr::register_function_context(state, context);
    _function_name = _fn.name.function_name;
    _prepare_finished = true;

    FunctionContext* fn_ctx = context->fn_context(_fn_context_index);
    if (fn().__isset.dict_function) {
        fn_ctx->set_dict_function(fn().dict_function);
    }
    return Status::OK();
}
183
184
/// Opens all children, then initialises the function context of this call for `scope`.
/// For FRAGMENT_LOCAL scope, VExpr::get_const_col is invoked as well. Expects prepare() to
/// have completed (checked with DCHECK only).
///
/// @return the first error reported by a child or by a VExpr helper, otherwise OK.
Status VectorizedFnCall::open(RuntimeState* state, VExprContext* context,
                              FunctionContext::FunctionStateScope scope) {
    DCHECK(_prepare_finished);
    for (auto& child : _children) {
        RETURN_IF_ERROR(child->open(state, context, scope));
    }
    RETURN_IF_ERROR(VExpr::init_function_context(state, context, scope, _function));

    const bool is_fragment_local = (scope == FunctionContext::FRAGMENT_LOCAL);
    if (is_fragment_local) {
        RETURN_IF_ERROR(VExpr::get_const_col(context, nullptr));
    }

    _open_finished = true;
    return Status::OK();
}
197
198
2.04M
/// Closes the function context of this call for `scope` first, then delegates to
/// VExpr::close with the same arguments.
void VectorizedFnCall::close(VExprContext* context,
                             FunctionContext::FunctionStateScope scope) {
    VExpr::close_function_context(context, scope, _function);
    VExpr::close(context, scope);
}
202
203
12.8k
/// Evaluates this call against the inverted index by delegating to
/// _evaluate_inverted_index with the resolved `_function`.
///
/// @return OK without doing anything when the call has no children; otherwise the status
///         of the delegated evaluation.
Status VectorizedFnCall::evaluate_inverted_index(VExprContext* context, uint32_t segment_num_rows) {
    const bool has_children = !(get_num_children() < 1);
    if (has_children) {
        return _evaluate_inverted_index(context, _function, segment_num_rows);
    }
    // score() and similar 0-children virtual column functions don't need
    // inverted index evaluation; return OK to skip gracefully.
    return Status::OK();
}
211
212
38.3k
/// Delegates zone-map filter evaluation to the resolved function, passing this call's
/// children. `_function` is dereferenced unconditionally here, whereas
/// can_evaluate_zonemap_filter() checks it for null.
ZoneMapFilterResult VectorizedFnCall::evaluate_zonemap_filter(
        const ZoneMapEvalContext& ctx) const {
    return _function->evaluate_zonemap_filter(ctx, _children);
}
215
216
101k
bool VectorizedFnCall::can_evaluate_zonemap_filter() const {
217
101k
    return _function != nullptr && !_function->is_blockable() &&
218
101k
           _function->can_evaluate_zonemap_filter(_children);
219
101k
}
220
221
/// Shared execution path behind execute_column_impl() and execute_runtime_filter().
///
/// Order of attempts:
///   1. If the expression is constant and was already executed, reuse that result.
///   2. Try fast_execute(); when it succeeds its result is returned.
///   3. Otherwise evaluate every child into a temporary block, append an empty result slot,
///      run the function, and take the result column from that slot.
///
/// @param context       expression context; supplies the function context.
/// @param block         input block handed to the children (the DCHECK below allows it to be
///                      null before open() has finished).
/// @param selector      forwarded unchanged to fast_execute() and to the children.
/// @param count         expected number of rows in the result (checked with DCHECK_EQ).
/// @param result_column receives the computed column.
/// @param arg_column    optional; when non-null it receives the evaluated column of child 0.
/// @return the first error from a child, from the function, or from column_self_check();
///         otherwise OK.
Status VectorizedFnCall::_do_execute(VExprContext* context, const Block* block,
                                     const Selector* selector, size_t count,
                                     ColumnPtr& result_column, ColumnPtr* arg_column) const {
    if (is_const_and_have_executed()) { // const have executed in open function
        result_column = get_result_from_const(count);
        return Status::OK();
    }
    if (fast_execute(context, selector, count, result_column)) {
        return Status::OK();
    }
    DBUG_EXECUTE_IF("VectorizedFnCall.must_in_slow_path", {
        if (get_child(0)->is_slot_ref()) {
            auto debug_col_name = DebugPoints::instance()->get_debug_param_or_default<std::string>(
                    "VectorizedFnCall.must_in_slow_path", "column_name", "");

            std::vector<std::string> column_names;
            boost::split(column_names, debug_col_name, boost::algorithm::is_any_of(","));

            auto* column_slot_ref = assert_cast<VSlotRef*>(get_child(0).get());
            std::string column_name = column_slot_ref->expr_name();
            auto it = std::find(column_names.begin(), column_names.end(), column_name);
            if (it == column_names.end()) {
                return Status::Error<ErrorCode::INTERNAL_ERROR>(
                        "column {} should in slow path while VectorizedFnCall::execute.",
                        column_name);
            }
        }
    })
    DCHECK(_open_finished || block == nullptr) << debug_string();

    Block temp_block;
    ColumnNumbers args(_children.size());

    // size_t index: avoids the signed/unsigned comparison against _children.size().
    for (size_t i = 0; i < _children.size(); ++i) {
        const auto& child = _children[i];
        ColumnPtr tmp_arg_column;
        RETURN_IF_ERROR(child->execute_column(context, block, selector, count, tmp_arg_column));
        auto arg_type = child->execute_type(block);
        temp_block.insert({tmp_arg_column, arg_type, child->expr_name()});
        // Argument i lives at position i of the temporary block.
        args[i] = static_cast<ColumnNumbers::value_type>(i);

        if (arg_column != nullptr && i == 0) {
            *arg_column = tmp_arg_column;
        }
    }

    uint32_t num_columns_without_result = temp_block.columns();
    // prepare a column to save result
    temp_block.insert({nullptr, _data_type, _expr_name});

    DBUG_EXECUTE_IF("VectorizedFnCall.wait_before_execute", {
        auto possibility = DebugPoints::instance()->get_debug_param_or_default<double>(
                "VectorizedFnCall.wait_before_execute", "possibility", 0);
        if (random_bool_slow(possibility)) {
            LOG(WARNING) << "VectorizedFnCall::execute sleep 30s";
            sleep(30);
        }
    });

    RETURN_IF_ERROR(_function->execute(context->fn_context(_fn_context_index), temp_block, args,
                                       num_columns_without_result, count));
    result_column = temp_block.get_by_position(num_columns_without_result).column;
    DCHECK_EQ(result_column->size(), count);
    RETURN_IF_ERROR(result_column->column_self_check());
    return Status::OK();
}
287
288
0
size_t VectorizedFnCall::estimate_memory(const size_t rows) {
289
0
    if (is_const_and_have_executed()) { // const have execute in open function
290
0
        return 0;
291
0
    }
292
293
0
    size_t estimate_size = 0;
294
0
    for (auto& child : _children) {
295
0
        estimate_size += child->estimate_memory(rows);
296
0
    }
297
298
0
    if (_data_type->have_maximum_size_of_value()) {
299
0
        estimate_size += rows * _data_type->get_size_of_value_in_memory();
300
0
    } else {
301
0
        estimate_size += rows * 512; /// FIXME: estimated value...
302
0
    }
303
0
    return estimate_size;
304
0
}
305
306
/// Runtime-filter entry point: runs _do_execute() without a selector and additionally
/// returns the evaluated first argument through `arg_column`.
/// `filter` is accepted for the interface but is not read by this implementation.
Status VectorizedFnCall::execute_runtime_filter(VExprContext* context, const Block* block,
                                                const uint8_t* __restrict filter, size_t count,
                                                ColumnPtr& result_column,
                                                ColumnPtr* arg_column) const {
    const Selector* const no_selector = nullptr;
    return _do_execute(context, block, no_selector, count, result_column, arg_column);
}
312
313
/// Regular column evaluation: runs _do_execute() with the given selector and without
/// requesting the first argument column.
Status VectorizedFnCall::execute_column_impl(VExprContext* context, const Block* block,
                                             const Selector* selector, size_t count,
                                             ColumnPtr& result_column) const {
    ColumnPtr* const no_arg_column = nullptr;
    return _do_execute(context, block, selector, count, result_column, no_arg_column);
}
318
319
370k
const std::string& VectorizedFnCall::expr_name() const {
320
370k
    return _expr_name;
321
370k
}
322
323
78
std::string VectorizedFnCall::function_name() const {
324
78
    return _function_name;
325
78
}
326
327
521
std::string VectorizedFnCall::debug_string() const {
328
521
    std::stringstream out;
329
521
    out << "VectorizedFn[";
330
521
    out << _expr_name;
331
521
    out << "]{";
332
521
    bool first = true;
333
1.03k
    for (const auto& input_expr : children()) {
334
1.03k
        if (first) {
335
519
            first = false;
336
519
        } else {
337
517
            out << ",";
338
517
        }
339
1.03k
        out << "\n" << input_expr->debug_string();
340
1.03k
    }
341
521
    out << "}";
342
521
    return out.str();
343
521
}
344
345
0
/// Renders a list of function calls as "[a b c]": the debug strings of the entries,
/// separated by single spaces.
std::string VectorizedFnCall::debug_string(const std::vector<VectorizedFnCall*>& agg_fns) {
    std::stringstream out;
    out << "[";
    // Empty before the first entry, " " before every later one.
    const char* separator = "";
    for (const VectorizedFnCall* agg_fn : agg_fns) {
        out << separator << agg_fn->debug_string();
        separator = " ";
    }
    out << "]";
    return out.str();
}
354
355
2.59k
bool VectorizedFnCall::can_push_down_to_index() const {
356
2.59k
    return _function->can_push_down_to_index();
357
2.59k
}
358
359
0
bool VectorizedFnCall::equals(const VExpr& other) {
360
0
    const auto* other_ptr = dynamic_cast<const VectorizedFnCall*>(&other);
361
0
    if (!other_ptr) {
362
0
        return false;
363
0
    }
364
0
    if (this->_function_name != other_ptr->_function_name) {
365
0
        return false;
366
0
    }
367
0
    if (get_num_children() != other_ptr->get_num_children()) {
368
0
        return false;
369
0
    }
370
0
    for (uint16_t i = 0; i < get_num_children(); i++) {
371
0
        if (!this->get_child(i)->equals(*other_ptr->get_child(i))) {
372
0
            return false;
373
0
        }
374
0
    }
375
0
    return true;
376
0
}
377
378
/*
379
 * For ANN range search we expect a comparison expression (LE/LT/GE/GT) whose left side is either:
380
 *   1) a vector distance function call, or
381
 *   2) a cast/virtual slot that unwraps to the function call when the planner promotes float to
382
 *      double literals.
383
 *
384
 * Visually the logical tree looks like:
385
 *
386
 *   FunctionCall(LE/LT/GE/GT)
387
 *   |----------------
388
 *   |               |
389
 *   |               |
390
 *   VirtualSlotRef* Float32Literal/Float64Literal
391
 *   |
392
 *   |
393
 *   Cast(Float -> Double)*
394
 *   |
395
 *   FunctionCall(distance)
396
 *   |----------------
397
 *   |               |
398
 *   |               |
399
 *   SlotRef         ArrayLiteral/Cast(String as Array<FLOAT>)
400
 *
401
 * Items marked with * are optional and depend on literal types/virtual column usage. The helper
402
 * below normalizes the shape and validates distance function, slot, and constant vector inputs.
403
 */
404
405
void VectorizedFnCall::prepare_ann_range_search(
406
        const doris::VectorSearchUserParams& user_params,
407
15.0k
        segment_v2::AnnRangeSearchRuntime& range_search_runtime, bool& suitable_for_ann_index) {
408
15.0k
    if (!suitable_for_ann_index) {
409
0
        return;
410
0
    }
411
412
15.0k
    if (OPS_FOR_ANN_RANGE_SEARCH.find(this->op()) == OPS_FOR_ANN_RANGE_SEARCH.end()) {
413
11.1k
        suitable_for_ann_index = false;
414
11.1k
        return;
415
11.1k
    }
416
417
3.89k
    auto mark_unsuitable = [&](const std::string& reason) {
418
3.85k
        suitable_for_ann_index = false;
419
18.4E
        VLOG_DEBUG << "ANN range search skipped: " << reason;
420
3.85k
    };
421
422
3.89k
    range_search_runtime.is_le_or_lt =
423
3.89k
            (this->op() == TExprOpcode::LE || this->op() == TExprOpcode::LT);
424
425
3.89k
    DCHECK(_children.size() == 2);
426
427
3.89k
    auto left_child = get_child(0);
428
3.89k
    auto right_child = get_child(1);
429
430
    // ========== Step 1: Check left child - must be a distance function ==========
431
3.89k
    auto get_virtual_expr = [&](const VExprSPtr& expr,
432
4.96k
                                std::shared_ptr<VirtualSlotRef>& slot_ref) -> VExprSPtr {
433
4.96k
        auto virtual_ref = std::dynamic_pointer_cast<VirtualSlotRef>(expr);
434
4.96k
        if (virtual_ref != nullptr) {
435
234
            DCHECK(virtual_ref->get_virtual_column_expr() != nullptr);
436
234
            slot_ref = virtual_ref;
437
234
            return virtual_ref->get_virtual_column_expr();
438
234
        }
439
4.73k
        return expr;
440
4.96k
    };
441
442
3.89k
    std::shared_ptr<VirtualSlotRef> vir_slot_ref;
443
3.89k
    auto normalized_left = get_virtual_expr(left_child, vir_slot_ref);
444
445
    // Try to find the distance function call, it may be wrapped in a Cast(Float->Double)
446
3.89k
    std::shared_ptr<VectorizedFnCall> function_call =
447
3.89k
            std::dynamic_pointer_cast<VectorizedFnCall>(normalized_left);
448
3.89k
    bool has_float_to_double_cast = false;
449
450
3.89k
    if (function_call == nullptr) {
451
        // Check if it's a Cast expression wrapping a function call
452
1.52k
        auto cast_expr = std::dynamic_pointer_cast<VCastExpr>(normalized_left);
453
1.52k
        if (cast_expr == nullptr) {
454
443
            mark_unsuitable("Left child is neither a function call nor a cast expression.");
455
443
            return;
456
443
        }
457
1.08k
        has_float_to_double_cast = true;
458
1.08k
        auto normalized_cast_child = get_virtual_expr(cast_expr->get_child(0), vir_slot_ref);
459
1.08k
        function_call = std::dynamic_pointer_cast<VectorizedFnCall>(normalized_cast_child);
460
1.08k
        if (function_call == nullptr) {
461
1.03k
            mark_unsuitable("Left child of cast is not a function call.");
462
1.03k
            return;
463
1.03k
        }
464
1.08k
    }
465
466
    // Check if it's a supported distance function
467
2.41k
    if (DISTANCE_FUNCS.find(function_call->_function_name) == DISTANCE_FUNCS.end()) {
468
2.36k
        mark_unsuitable(fmt::format("Left child is not a supported distance function: {}",
469
2.36k
                                    function_call->_function_name));
470
2.36k
        return;
471
2.36k
    }
472
473
    // Strip the _approximate suffix to get metric type
474
47
    std::string metric_name = function_call->_function_name;
475
47
    metric_name = metric_name.substr(0, metric_name.size() - 12);
476
47
    range_search_runtime.metric_type = segment_v2::string_to_metric(metric_name);
477
478
    // ========== Step 2: Validate distance function arguments ==========
479
    // Identify the slot ref child and the constant query array child (ArrayLiteral or CAST to array)
480
47
    Int32 idx_of_slot_ref = -1;
481
47
    Int32 idx_of_array_expr = -1;
482
92
    auto classify_child = [&](const VExprSPtr& child, UInt16 index) {
483
92
        if (idx_of_slot_ref == -1 && std::dynamic_pointer_cast<VSlotRef>(child) != nullptr) {
484
46
            idx_of_slot_ref = index;
485
46
            return;
486
46
        }
487
46
        if (idx_of_array_expr == -1 &&
488
46
            (std::dynamic_pointer_cast<VArrayLiteral>(child) != nullptr ||
489
46
             std::dynamic_pointer_cast<VCastExpr>(child) != nullptr)) {
490
39
            idx_of_array_expr = index;
491
39
        }
492
46
    };
493
494
139
    for (UInt16 i = 0; i < function_call->get_num_children(); ++i) {
495
92
        classify_child(function_call->get_child(i), i);
496
92
    }
497
498
47
    if (idx_of_slot_ref == -1 || idx_of_array_expr == -1) {
499
7
        mark_unsuitable("slot ref or array literal/cast is missing.");
500
7
        return;
501
7
    }
502
503
40
    auto slot_ref = std::dynamic_pointer_cast<VSlotRef>(
504
40
            function_call->get_child(static_cast<UInt16>(idx_of_slot_ref)));
505
40
    range_search_runtime.src_col_idx = slot_ref->column_id();
506
40
    range_search_runtime.dst_col_idx = vir_slot_ref == nullptr ? -1 : vir_slot_ref->column_id();
507
508
    // Materialize the constant array expression and validate its shape and types
509
40
    auto array_expr = function_call->get_child(static_cast<UInt16>(idx_of_array_expr));
510
40
    auto extract_result = extract_query_vector(array_expr);
511
40
    if (!extract_result.has_value()) {
512
0
        mark_unsuitable("Failed to extract query vector from constant array expression.");
513
0
        return;
514
0
    }
515
40
    range_search_runtime.query_value = extract_result.value();
516
40
    range_search_runtime.dim = range_search_runtime.query_value->size();
517
518
    // ========== Step 3: Check right child - must be a float/double literal ==========
519
40
    auto right_literal = std::dynamic_pointer_cast<VLiteral>(right_child);
520
40
    if (right_literal == nullptr) {
521
1
        mark_unsuitable("Right child is not a literal.");
522
1
        return;
523
1
    }
524
525
    // Handle nullable literal gracefully - just mark as unsuitable instead of crash
526
39
    if (right_literal->is_nullable()) {
527
0
        mark_unsuitable("Right literal is nullable, not supported for ANN range search.");
528
0
        return;
529
0
    }
530
531
39
    auto right_type = right_literal->get_data_type();
532
39
    PrimitiveType right_primitive = right_type->get_primitive_type();
533
39
    const bool float32_literal = right_primitive == PrimitiveType::TYPE_FLOAT;
534
39
    const bool float64_literal = right_primitive == PrimitiveType::TYPE_DOUBLE;
535
536
39
    if (!float32_literal && !float64_literal) {
537
0
        mark_unsuitable("Right child is not a Float32Literal or Float64Literal.");
538
0
        return;
539
0
    }
540
541
    // Validate consistency: if we have Cast(Float->Double), right must be double literal
542
39
    if (has_float_to_double_cast && !float64_literal) {
543
0
        mark_unsuitable("Cast expression expects double literal on right side.");
544
0
        return;
545
0
    }
546
547
    // Extract radius value
548
39
    auto right_col = right_literal->get_column_ptr()->convert_to_full_column_if_const();
549
39
    if (float32_literal) {
550
7
        const ColumnFloat32* cf32_right = assert_cast<const ColumnFloat32*>(right_col.get());
551
7
        range_search_runtime.radius = cf32_right->get_data()[0];
552
32
    } else {
553
32
        const ColumnFloat64* cf64_right = assert_cast<const ColumnFloat64*>(right_col.get());
554
32
        range_search_runtime.radius = static_cast<float>(cf64_right->get_data()[0]);
555
32
    }
556
557
    // ========== Done: Mark as suitable for ANN range search ==========
558
39
    range_search_runtime.is_ann_range_search = true;
559
39
    range_search_runtime.user_params = user_params;
560
39
    VLOG_DEBUG << fmt::format("Ann range search params: {}", range_search_runtime.to_string());
561
39
    return;
562
39
}
563
564
Status VectorizedFnCall::evaluate_ann_range_search(
565
        const segment_v2::AnnRangeSearchRuntime& range_search_runtime,
566
        const std::vector<std::unique_ptr<segment_v2::IndexIterator>>& cid_to_index_iterators,
567
        const std::vector<ColumnId>& idx_to_cid,
568
        const std::vector<std::unique_ptr<segment_v2::ColumnIterator>>& column_iterators,
569
        size_t rows_of_segment, roaring::Roaring& row_bitmap,
570
        segment_v2::AnnIndexStats& ann_index_stats, bool enable_result_cache,
571
11.9k
        AnnRangeSearchEvaluationResult& evaluation_result) {
572
11.9k
    evaluation_result = {};
573
11.9k
    if (range_search_runtime.is_ann_range_search == false) {
574
11.8k
        return Status::OK();
575
11.8k
    }
576
577
56
    VLOG_DEBUG << fmt::format("Try apply ann range search. Local search params: {}",
578
20
                              range_search_runtime.to_string());
579
56
    size_t origin_num = row_bitmap.cardinality();
580
581
56
    const auto idx_in_block = range_search_runtime.src_col_idx;
582
56
    DCHECK_LT(idx_in_block, idx_to_cid.size())
583
0
            << "idx_in_block: " << idx_in_block << ", idx_to_cid.size(): " << idx_to_cid.size();
584
585
56
    ColumnId src_col_cid = idx_to_cid[idx_in_block];
586
56
    DCHECK(src_col_cid < cid_to_index_iterators.size());
587
56
    segment_v2::IndexIterator* index_iterator = cid_to_index_iterators[src_col_cid].get();
588
56
    if (index_iterator == nullptr) {
589
1
        VLOG_DEBUG << "ANN range search skipped: "
590
0
                   << fmt::format("No index iterator for column cid {}", src_col_cid);
591
1
        ;
592
1
        return Status::OK();
593
1
    }
594
595
55
    segment_v2::AnnIndexIterator* ann_index_iterator =
596
55
            dynamic_cast<segment_v2::AnnIndexIterator*>(index_iterator);
597
55
    if (ann_index_iterator == nullptr) {
598
0
        VLOG_DEBUG << "ANN range search skipped: "
599
0
                   << fmt::format("Column cid {} has no ANN index iterator", src_col_cid);
600
0
        return Status::OK();
601
0
    }
602
55
    DCHECK(ann_index_iterator->get_reader(AnnIndexReaderType::ANN) != nullptr)
603
20
            << "Ann index iterator should have reader. Column cid: " << src_col_cid;
604
55
    std::shared_ptr<AnnIndexReader> ann_index_reader = std::dynamic_pointer_cast<AnnIndexReader>(
605
55
            ann_index_iterator->get_reader(segment_v2::AnnIndexReaderType::ANN));
606
55
    DCHECK(ann_index_reader != nullptr)
607
20
            << "Ann index reader should not be null. Column cid: " << src_col_cid;
608
    // Check if metrics type is match.
609
55
    if (ann_index_reader->get_metric_type() != range_search_runtime.metric_type) {
610
0
        VLOG_DEBUG << "ANN range search skipped: "
611
0
                   << fmt::format("Metric type mismatch. Index={} Query={}",
612
0
                                  segment_v2::metric_to_string(ann_index_reader->get_metric_type()),
613
0
                                  segment_v2::metric_to_string(range_search_runtime.metric_type));
614
0
        return Status::OK();
615
0
    }
616
617
    // Check dimension if available (>0)
618
55
    const size_t index_dim = ann_index_reader->get_dimension();
619
55
    if (index_dim > 0 && index_dim != range_search_runtime.dim) {
620
6
        return Status::InvalidArgument(
621
6
                "Ann range search query dimension {} does not match index dimension {}",
622
6
                range_search_runtime.dim, index_dim);
623
6
    }
624
625
49
    const auto& user_params = range_search_runtime.user_params;
626
49
    if (user_params.should_fallback_ann_index_by_small_candidate(origin_num, rows_of_segment)) {
627
0
        VLOG_DEBUG << fmt::format(
628
0
                "Ann range search input rows {} reach small candidate threshold, "
629
0
                "rows_of_segment: {}, absolute_threshold: {}, percent_threshold: {}, "
630
0
                "will not use ann index to filter",
631
0
                origin_num, rows_of_segment, user_params.ann_index_candidate_rows_threshold,
632
0
                user_params.ann_index_candidate_rows_percent_threshold);
633
0
        ann_index_stats.fall_back_brute_force_cnt += 1;
634
0
        ann_index_stats.range_fallback_by_small_candidate_cnt += 1;
635
0
        ann_index_stats.range_fallback_small_candidate_rows += origin_num;
636
0
        return Status::OK();
637
0
    }
638
639
49
    auto stats = std::make_unique<segment_v2::AnnIndexStats>();
640
    // Track load index timing
641
49
    {
642
49
        SCOPED_TIMER(&(stats->load_index_costs_ns));
643
49
        if (!ann_index_iterator->try_load_index()) {
644
2
            VLOG_DEBUG << "ANN range search skipped: "
645
0
                       << fmt::format("Failed to load ANN index for column cid {}", src_col_cid);
646
2
            ann_index_stats.fall_back_brute_force_cnt += 1;
647
2
            return Status::OK();
648
2
        }
649
47
        double load_costs_ms = static_cast<double>(stats->load_index_costs_ns.value()) / 1000000.0;
650
47
        DorisMetrics::instance()->ann_index_load_costs_ms->increment(
651
47
                static_cast<int64_t>(load_costs_ms));
652
47
    }
653
654
0
    AnnRangeSearchParams params = range_search_runtime.to_range_search_params();
655
656
47
    params.roaring = &row_bitmap;
657
47
    params.enable_result_cache = enable_result_cache;
658
47
    DCHECK(params.roaring != nullptr);
659
47
    DCHECK(params.query_value != nullptr);
660
47
    segment_v2::AnnRangeSearchResult result;
661
47
    RETURN_IF_ERROR(ann_index_iterator->range_search(params, range_search_runtime.user_params,
662
47
                                                     &result, stats.get()));
663
664
47
#ifndef NDEBUG
665
47
    if (range_search_runtime.is_le_or_lt == false &&
666
47
        ann_index_reader->get_metric_type() == AnnIndexMetric::L2) {
667
7
        DCHECK(result.distance == nullptr) << "Should not have distance";
668
7
    }
669
47
    if (range_search_runtime.is_le_or_lt == true &&
670
47
        ann_index_reader->get_metric_type() == AnnIndexMetric::IP) {
671
4
        DCHECK(result.distance == nullptr);
672
4
    }
673
47
#endif
674
47
    DCHECK(result.roaring != nullptr);
675
47
    row_bitmap = *result.roaring;
676
677
    // Process virtual column
678
47
    bool dist_fulfilled = false;
679
47
    if (range_search_runtime.dst_col_idx >= 0) {
680
        // Prepare materialization if we can use result from index.
681
        // Typical situation: range search and operator is LE or LT.
682
4
        if (result.distance != nullptr) {
683
2
            DCHECK(result.row_ids != nullptr);
684
2
            ColumnId dst_col_cid = idx_to_cid[range_search_runtime.dst_col_idx];
685
2
            DCHECK(dst_col_cid < column_iterators.size());
686
2
            DCHECK(column_iterators[dst_col_cid] != nullptr);
687
2
            segment_v2::ColumnIterator* column_iterator = column_iterators[dst_col_cid].get();
688
2
            DCHECK(column_iterator != nullptr);
689
2
            segment_v2::VirtualColumnIterator* virtual_column_iterator =
690
2
                    dynamic_cast<segment_v2::VirtualColumnIterator*>(column_iterator);
691
2
            DCHECK(virtual_column_iterator != nullptr);
692
            // Now convert distance to column
693
2
            size_t size = result.roaring->cardinality();
694
2
            auto distance_col = ColumnFloat32::create(size);
695
2
            const float* src = result.distance.get();
696
2
            float* dst = distance_col->get_data().data();
697
15
            for (size_t i = 0; i < size; ++i) {
698
13
                dst[i] = src[i];
699
13
            }
700
2
            virtual_column_iterator->prepare_materialization(std::move(distance_col),
701
2
                                                             std::move(result.row_ids));
702
2
            dist_fulfilled = true;
703
2
        } else {
704
            // Whether the ANN index should have produced distance depends on metric and operator:
705
            //  - L2: distance is produced for LE/LT; not produced for GE/GT
706
            //  - IP: distance is produced for GE/GT; not produced for LE/LT
707
2
#ifndef NDEBUG
708
2
            const bool should_have_distance =
709
2
                    (range_search_runtime.is_le_or_lt &&
710
2
                     range_search_runtime.metric_type == AnnIndexMetric::L2) ||
711
2
                    (!range_search_runtime.is_le_or_lt &&
712
2
                     range_search_runtime.metric_type == AnnIndexMetric::IP);
713
            // If we expected distance but didn't get it, assert in debug to catch logic errors.
714
2
            DCHECK(!should_have_distance) << "Expected distance from ANN index but got none";
715
2
#endif
716
2
        }
717
43
    } else {
718
        // Dest is not virtual column.
719
43
        dist_fulfilled = true;
720
43
    }
721
722
47
    evaluation_result.executed = true;
723
47
    evaluation_result.dist_fulfilled = dist_fulfilled;
724
47
    VLOG_DEBUG << fmt::format(
725
20
            "Ann range search filtered {} rows, origin {} rows, virtual column is full-filled: {}",
726
20
            origin_num - row_bitmap.cardinality(), origin_num, dist_fulfilled);
727
728
47
    ann_index_stats = *stats;
729
47
    return Status::OK();
730
47
}
731
732
1.22M
double VectorizedFnCall::execute_cost() const {
733
1.22M
    if (!_function) {
734
0
        throw Exception(
735
0
                Status::InternalError("Function is null in expression: {}", this->debug_string()));
736
0
    }
737
1.22M
    double cost = _function->execute_cost();
738
2.43M
    for (const auto& child : _children) {
739
2.43M
        cost += child->execute_cost();
740
2.43M
    }
741
1.22M
    return cost;
742
1.22M
}
743
744
} // namespace doris