Coverage Report

Created: 2025-12-04 16:12

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/impl/AuxIndexStructures.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
// -*- c++ -*-
9
10
#include <algorithm>
11
#include <cstring>
12
13
#include <faiss/impl/AuxIndexStructures.h>
14
15
#include <faiss/impl/FaissAssert.h>
16
17
namespace faiss {
18
19
/***********************************************************************
20
 * RangeSearchResult
21
 ***********************************************************************/
22
23
115
/// Build an empty result container for `nq` queries.
///
/// @param nq          number of queries the result describes
/// @param alloc_lims  when true, `lims` is allocated with nq + 1 zeroed
///                    entries; when false, `lims` is left null
RangeSearchResult::RangeSearchResult(size_t nq, bool alloc_lims) : nq(nq) {
    // the trailing "()" value-initialises the array, i.e. fills it with zeros
    lims = alloc_lims ? new size_t[nq + 1]() : nullptr;

    // the result arrays are created later, by do_allocation()
    labels = nullptr;
    distances = nullptr;

    // capacity that RangeSearchPartialResult passes on to its BufferList
    buffer_size = 1024 * 256;
}
34
35
/// called when lims contains the nb of elements result entries
36
/// for each query
37
115
void RangeSearchResult::do_allocation() {
38
    // works only if all the partial results are aggregated
39
    // simulatenously
40
115
    FAISS_THROW_IF_NOT(labels == nullptr && distances == nullptr);
41
115
    size_t ofs = 0;
42
230
    for (int i = 0; i < nq; i++) {
43
115
        size_t n = lims[i];
44
115
        lims[i] = ofs;
45
115
        ofs += n;
46
115
    }
47
115
    lims[nq] = ofs;
48
115
    labels = new idx_t[ofs];
49
115
    distances = new float[ofs];
50
115
}
51
52
115
/// Release the three arrays held by the result. Each pointer is either null
/// or was obtained from new[] (constructor / do_allocation()), and
/// delete[] on a null pointer is a no-op.
RangeSearchResult::~RangeSearchResult() {
    delete[] lims;
    delete[] distances;
    delete[] labels;
}
57
58
/***********************************************************************
59
 * BufferList
60
 ***********************************************************************/
61
62
115
/// Create an empty buffer list whose buffers will each hold `buffer_size`
/// entries.
///
/// The write position starts at `buffer_size`, i.e. "current buffer full",
/// so the first call to add() allocates the first buffer.
BufferList::BufferList(size_t buffer_size)
        : buffer_size(buffer_size), wp(buffer_size) {}
65
66
115
/// Free the id and distance arrays of every buffer created by
/// append_buffer().
BufferList::~BufferList() {
    // range-for instead of an int index: no signed/unsigned comparison
    // against buffers.size()
    for (const Buffer& buf : buffers) {
        delete[] buf.ids;
        delete[] buf.dis;
    }
}
72
73
6.72k
/// Append one (id, distance) pair at the current write position.
///
/// @param id   label to store
/// @param dis  distance to store
void BufferList::add(idx_t id, float dis) {
    // the current buffer is full (or none exists yet): start a new one,
    // which also rewinds wp to 0
    if (wp == buffer_size) {
        append_buffer();
    }

    Buffer& tail = buffers.back();
    const size_t slot = wp++;
    tail.ids[slot] = id;
    tail.dis[slot] = dis;
}
82
83
99
void BufferList::append_buffer() {
84
99
    Buffer buf = {new idx_t[buffer_size], new float[buffer_size]};
85
99
    buffers.push_back(buf);
86
99
    wp = 0;
87
99
}
88
89
/// copy elemnts ofs:ofs+n-1 seen as linear data in the buffers to
90
/// tables dest_ids, dest_dis
91
void BufferList::copy_range(
92
        size_t ofs,
93
        size_t n,
94
        idx_t* dest_ids,
95
115
        float* dest_dis) {
96
115
    size_t bno = ofs / buffer_size;
97
115
    ofs -= bno * buffer_size;
98
214
    while (n > 0) {
99
99
        size_t ncopy = ofs + n < buffer_size ? n : buffer_size - ofs;
100
99
        Buffer buf = buffers[bno];
101
99
        memcpy(dest_ids, buf.ids + ofs, ncopy * sizeof(*dest_ids));
102
99
        memcpy(dest_dis, buf.dis + ofs, ncopy * sizeof(*dest_dis));
103
99
        dest_ids += ncopy;
104
99
        dest_dis += ncopy;
105
99
        ofs = 0;
106
99
        bno++;
107
99
        n -= ncopy;
108
99
    }
109
115
}
110
111
/***********************************************************************
112
 * RangeSearchPartialResult
113
 ***********************************************************************/
114
115
6.72k
void RangeQueryResult::add(float dis, idx_t id) {
116
6.72k
    nres++;
117
6.72k
    pres->add(id, dis);
118
6.72k
}
119
120
/// Create a partial result that feeds into `res_in`.
///
/// The underlying BufferList uses the buffer size configured on `res_in`;
/// `res_in` must therefore be non-null.
RangeSearchPartialResult::RangeSearchPartialResult(RangeSearchResult* res_in)
        : BufferList(res_in->buffer_size),
          res(res_in) {}
122
123
/// begin a new result
124
115
RangeQueryResult& RangeSearchPartialResult::new_result(idx_t qno) {
125
115
    RangeQueryResult qres = {qno, 0, this};
126
115
    queries.push_back(qres);
127
115
    return queries.back();
128
115
}
129
130
115
void RangeSearchPartialResult::finalize() {
131
115
    set_lims();
132
115
#pragma omp barrier
133
134
115
#pragma omp single
135
115
    res->do_allocation();
136
137
115
#pragma omp barrier
138
115
    copy_result();
139
115
}
140
141
/// called by range_search before do_allocation
142
115
void RangeSearchPartialResult::set_lims() {
143
230
    for (int i = 0; i < queries.size(); i++) {
144
115
        RangeQueryResult& qres = queries[i];
145
115
        res->lims[qres.qno] = qres.nres;
146
115
    }
147
115
}
148
149
/// called by range_search after do_allocation
150
115
void RangeSearchPartialResult::copy_result(bool incremental) {
151
115
    size_t ofs = 0;
152
230
    for (int i = 0; i < queries.size(); i++) {
153
115
        RangeQueryResult& qres = queries[i];
154
155
115
        copy_range(
156
115
                ofs,
157
115
                qres.nres,
158
115
                res->labels + res->lims[qres.qno],
159
115
                res->distances + res->lims[qres.qno]);
160
115
        if (incremental) {
161
0
            res->lims[qres.qno] += qres.nres;
162
0
        }
163
115
        ofs += qres.nres;
164
115
    }
165
115
}
166
167
void RangeSearchPartialResult::merge(
168
        std::vector<RangeSearchPartialResult*>& partial_results,
169
0
        bool do_delete) {
170
0
    int npres = partial_results.size();
171
0
    if (npres == 0)
172
0
        return;
173
0
    RangeSearchResult* result = partial_results[0]->res;
174
0
    size_t nx = result->nq;
175
176
    // count
177
0
    for (const RangeSearchPartialResult* pres : partial_results) {
178
0
        if (!pres)
179
0
            continue;
180
0
        for (const RangeQueryResult& qres : pres->queries) {
181
0
            result->lims[qres.qno] += qres.nres;
182
0
        }
183
0
    }
184
0
    result->do_allocation();
185
0
    for (int j = 0; j < npres; j++) {
186
0
        if (!partial_results[j])
187
0
            continue;
188
0
        partial_results[j]->copy_result(true);
189
0
        if (do_delete) {
190
0
            delete partial_results[j];
191
0
            partial_results[j] = nullptr;
192
0
        }
193
0
    }
194
195
    // reset the limits
196
0
    for (size_t i = nx; i > 0; i--) {
197
0
        result->lims[i] = result->lims[i - 1];
198
0
    }
199
0
    result->lims[0] = 0;
200
0
}
201
202
/***********************************************************
203
 * Interrupt callback
204
 ***********************************************************/
205
206
/// Currently installed interrupt callback; empty when none is installed, in
/// which case check() and is_interrupted() report no interruption.
std::unique_ptr<InterruptCallback> InterruptCallback::instance;
207
208
/// Mutex held by is_interrupted() while it queries the installed callback.
std::mutex InterruptCallback::lock;
209
210
0
void InterruptCallback::clear_instance() {
211
0
    delete instance.release();
212
0
}
213
214
963
void InterruptCallback::check() {
215
963
    if (!instance.get()) {
216
963
        return;
217
963
    }
218
0
    if (instance->want_interrupt()) {
219
0
        FAISS_THROW_MSG("computation interrupted");
220
0
    }
221
0
}
222
223
17.3k
bool InterruptCallback::is_interrupted() {
224
17.3k
    if (!instance.get()) {
225
17.3k
        return false;
226
17.3k
    }
227
0
    std::lock_guard<std::mutex> guard(lock);
228
0
    return instance->want_interrupt();
229
17.3k
}
230
231
17.3k
size_t InterruptCallback::get_period_hint(size_t flops) {
232
17.3k
    if (!instance.get()) {
233
17.3k
        return (size_t)1 << 30; // never check
234
17.3k
    }
235
    // for 10M flops, it is reasonable to check once every 10 iterations
236
0
    return std::max((size_t)10 * 10 * 1000 * 1000 / (flops + 1), (size_t)1);
237
17.3k
}
238
239
0
void TimeoutCallback::set_timeout(double timeout_in_seconds) {
240
0
    timeout = timeout_in_seconds;
241
0
    start = std::chrono::steady_clock::now();
242
0
}
243
244
0
bool TimeoutCallback::want_interrupt() {
245
0
    if (timeout == 0) {
246
0
        return false;
247
0
    }
248
0
    auto end = std::chrono::steady_clock::now();
249
0
    std::chrono::duration<float, std::milli> duration = end - start;
250
0
    float elapsed_in_seconds = duration.count() / 1000.0;
251
0
    if (elapsed_in_seconds > timeout) {
252
0
        timeout = 0;
253
0
        return true;
254
0
    }
255
0
    return false;
256
0
}
257
258
0
void TimeoutCallback::reset(double timeout_in_seconds) {
259
0
    auto tc(new faiss::TimeoutCallback());
260
0
    faiss::InterruptCallback::instance.reset(tc);
261
0
    tc->set_timeout(timeout_in_seconds);
262
0
}
263
264
} // namespace faiss