Coverage Report

Created: 2026-04-01 04:10

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
contrib/faiss/faiss/impl/AuxIndexStructures.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
// -*- c++ -*-
9
10
#include <algorithm>
#include <cstring>
#include <memory>

#include <faiss/impl/AuxIndexStructures.h>

#include <faiss/impl/FaissAssert.h>
16
17
namespace faiss {
18
19
/***********************************************************************
20
 * RangeSearchResult
21
 ***********************************************************************/
22
23
168
RangeSearchResult::RangeSearchResult(size_t nq, bool alloc_lims) : nq(nq) {
    // lims holds nq + 1 entries. It starts zero-filled (value-initialised
    // new[]) so per-query counts can be accumulated into it before
    // do_allocation() rewrites them as offsets. Without alloc_lims it stays
    // null.
    lims = alloc_lims ? new size_t[nq + 1]() : nullptr;

    // Result arrays are only allocated later, by do_allocation().
    labels = nullptr;
    distances = nullptr;

    // Default per-buffer capacity used by the BufferLists that feed this
    // result.
    buffer_size = 1024 * 256;
}
34
35
/// called when lims contains the nb of elements result entries
36
/// for each query
37
168
void RangeSearchResult::do_allocation() {
38
    // works only if all the partial results are aggregated
39
    // simulatenously
40
168
    FAISS_THROW_IF_NOT(labels == nullptr && distances == nullptr);
41
168
    size_t ofs = 0;
42
336
    for (int i = 0; i < nq; i++) {
43
168
        size_t n = lims[i];
44
168
        lims[i] = ofs;
45
168
        ofs += n;
46
168
    }
47
168
    lims[nq] = ofs;
48
168
    labels = new idx_t[ofs];
49
168
    distances = new float[ofs];
50
168
}
51
52
168
RangeSearchResult::~RangeSearchResult() {
    // Release the three owned arrays. delete[] on nullptr is a no-op, so this
    // is correct even if lims or the result arrays were never allocated.
    delete[] lims;
    delete[] labels;
    delete[] distances;
}
57
58
/***********************************************************************
59
 * BufferList
60
 ***********************************************************************/
61
62
168
// Start with the write position "at the end of a full buffer", so the first
// add() allocates the first buffer.
BufferList::BufferList(size_t buffer_size)
        : buffer_size(buffer_size), wp(buffer_size) {}
65
66
168
/// Frees every buffer that append_buffer() allocated with new[].
BufferList::~BufferList() {
    // Range-for instead of an int index compared against buffers.size()
    // (signed/unsigned mix).
    for (const Buffer& buf : buffers) {
        delete[] buf.ids;
        delete[] buf.dis;
    }
}
72
73
7.07k
/// Appends one (id, dis) pair at the current write position, opening a new
/// buffer first when the current one is full (or none exists yet).
void BufferList::add(idx_t id, float dis) {
    const bool current_is_full = (wp == buffer_size);
    if (current_is_full) {
        append_buffer();
    }
    Buffer& tail = buffers.back();
    tail.ids[wp] = id;
    tail.dis[wp] = dis;
    ++wp;
}
82
83
145
void BufferList::append_buffer() {
84
145
    Buffer buf = {new idx_t[buffer_size], new float[buffer_size]};
85
145
    buffers.push_back(buf);
86
145
    wp = 0;
87
145
}
88
89
/// copy elemnts ofs:ofs+n-1 seen as linear data in the buffers to
90
/// tables dest_ids, dest_dis
91
void BufferList::copy_range(
92
        size_t ofs,
93
        size_t n,
94
        idx_t* dest_ids,
95
168
        float* dest_dis) {
96
168
    size_t bno = ofs / buffer_size;
97
168
    ofs -= bno * buffer_size;
98
313
    while (n > 0) {
99
145
        size_t ncopy = ofs + n < buffer_size ? n : buffer_size - ofs;
100
145
        Buffer buf = buffers[bno];
101
145
        memcpy(dest_ids, buf.ids + ofs, ncopy * sizeof(*dest_ids));
102
145
        memcpy(dest_dis, buf.dis + ofs, ncopy * sizeof(*dest_dis));
103
145
        dest_ids += ncopy;
104
145
        dest_dis += ncopy;
105
145
        ofs = 0;
106
145
        bno++;
107
145
        n -= ncopy;
108
145
    }
109
168
}
110
111
/***********************************************************************
112
 * RangeSearchPartialResult
113
 ***********************************************************************/
114
115
7.07k
void RangeQueryResult::add(float dis, idx_t id) {
116
7.07k
    nres++;
117
7.07k
    pres->add(id, dis);
118
7.07k
}
119
120
RangeSearchPartialResult::RangeSearchPartialResult(RangeSearchResult* res_in)
121
168
        : BufferList(res_in->buffer_size), res(res_in) {}
122
123
/// begin a new result
124
168
RangeQueryResult& RangeSearchPartialResult::new_result(idx_t qno) {
125
168
    RangeQueryResult qres = {qno, 0, this};
126
168
    queries.push_back(qres);
127
168
    return queries.back();
128
168
}
129
130
168
void RangeSearchPartialResult::finalize() {
131
168
    set_lims();
132
168
#pragma omp barrier
133
134
168
#pragma omp single
135
168
    res->do_allocation();
136
137
168
#pragma omp barrier
138
168
    copy_result();
139
168
}
140
141
/// called by range_search before do_allocation
142
168
void RangeSearchPartialResult::set_lims() {
143
336
    for (int i = 0; i < queries.size(); i++) {
144
168
        RangeQueryResult& qres = queries[i];
145
168
        res->lims[qres.qno] = qres.nres;
146
168
    }
147
168
}
148
149
/// called by range_search after do_allocation
150
168
void RangeSearchPartialResult::copy_result(bool incremental) {
151
168
    size_t ofs = 0;
152
336
    for (int i = 0; i < queries.size(); i++) {
153
168
        RangeQueryResult& qres = queries[i];
154
155
168
        copy_range(
156
168
                ofs,
157
168
                qres.nres,
158
168
                res->labels + res->lims[qres.qno],
159
168
                res->distances + res->lims[qres.qno]);
160
168
        if (incremental) {
161
0
            res->lims[qres.qno] += qres.nres;
162
0
        }
163
168
        ofs += qres.nres;
164
168
    }
165
168
}
166
167
void RangeSearchPartialResult::merge(
168
        std::vector<RangeSearchPartialResult*>& partial_results,
169
0
        bool do_delete) {
170
0
    int npres = partial_results.size();
171
0
    if (npres == 0)
172
0
        return;
173
0
    RangeSearchResult* result = partial_results[0]->res;
174
0
    size_t nx = result->nq;
175
176
    // count
177
0
    for (const RangeSearchPartialResult* pres : partial_results) {
178
0
        if (!pres)
179
0
            continue;
180
0
        for (const RangeQueryResult& qres : pres->queries) {
181
0
            result->lims[qres.qno] += qres.nres;
182
0
        }
183
0
    }
184
0
    result->do_allocation();
185
0
    for (int j = 0; j < npres; j++) {
186
0
        if (!partial_results[j])
187
0
            continue;
188
0
        partial_results[j]->copy_result(true);
189
0
        if (do_delete) {
190
0
            delete partial_results[j];
191
0
            partial_results[j] = nullptr;
192
0
        }
193
0
    }
194
195
    // reset the limits
196
0
    for (size_t i = nx; i > 0; i--) {
197
0
        result->lims[i] = result->lims[i - 1];
198
0
    }
199
0
    result->lims[0] = 0;
200
0
}
201
202
/***********************************************************
203
 * Interrupt callback
204
 ***********************************************************/
205
206
// Process-wide interrupt callback. Starts empty (null); TimeoutCallback::reset
// installs one.
std::unique_ptr<InterruptCallback> InterruptCallback::instance{};

// Mutex that is_interrupted() holds while calling want_interrupt().
std::mutex InterruptCallback::lock{};
209
210
0
void InterruptCallback::clear_instance() {
211
0
    delete instance.release();
212
0
}
213
214
1.44k
void InterruptCallback::check() {
215
1.44k
    if (!instance.get()) {
216
1.44k
        return;
217
1.44k
    }
218
0
    if (instance->want_interrupt()) {
219
0
        FAISS_THROW_MSG("computation interrupted");
220
0
    }
221
0
}
222
223
17.4k
bool InterruptCallback::is_interrupted() {
224
17.4k
    if (!instance.get()) {
225
17.4k
        return false;
226
17.4k
    }
227
18.4E
    std::lock_guard<std::mutex> guard(lock);
228
18.4E
    return instance->want_interrupt();
229
17.4k
}
230
231
17.5k
size_t InterruptCallback::get_period_hint(size_t flops) {
232
17.5k
    if (!instance.get()) {
233
17.5k
        return (size_t)1 << 30; // never check
234
17.5k
    }
235
    // for 10M flops, it is reasonable to check once every 10 iterations
236
0
    return std::max((size_t)10 * 10 * 1000 * 1000 / (flops + 1), (size_t)1);
237
17.5k
}
238
239
0
void TimeoutCallback::set_timeout(double timeout_in_seconds) {
240
0
    timeout = timeout_in_seconds;
241
0
    start = std::chrono::steady_clock::now();
242
0
}
243
244
0
bool TimeoutCallback::want_interrupt() {
245
0
    if (timeout == 0) {
246
0
        return false;
247
0
    }
248
0
    auto end = std::chrono::steady_clock::now();
249
0
    std::chrono::duration<float, std::milli> duration = end - start;
250
0
    float elapsed_in_seconds = duration.count() / 1000.0;
251
0
    if (elapsed_in_seconds > timeout) {
252
0
        timeout = 0;
253
0
        return true;
254
0
    }
255
0
    return false;
256
0
}
257
258
0
void TimeoutCallback::reset(double timeout_in_seconds) {
259
0
    auto tc(new faiss::TimeoutCallback());
260
0
    faiss::InterruptCallback::instance.reset(tc);
261
0
    tc->set_timeout(timeout_in_seconds);
262
0
}
263
264
} // namespace faiss