Coverage Report

Created: 2026-05-22 08:16

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
contrib/faiss/faiss/impl/AuxIndexStructures.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
// -*- c++ -*-
9
10
#include <algorithm>
11
#include <cstring>
12
13
#include <faiss/impl/AuxIndexStructures.h>
14
15
#include <faiss/impl/FaissAssert.h>
16
17
namespace faiss {
18
19
/***********************************************************************
20
 * RangeSearchResult
21
 ***********************************************************************/
22
23
149
// Creates an empty range-search result for `nq` queries.
// When `alloc_lims` is true, the (nq + 1)-entry `lims` table is allocated and
// zero-filled; otherwise `lims` is left as nullptr. `labels`/`distances` stay
// nullptr until do_allocation() runs.
RangeSearchResult::RangeSearchResult(size_t nq, bool alloc_lims) : nq(nq) {
    // `new T[n]()` value-initializes, i.e. every entry starts at 0
    lims = alloc_lims ? new size_t[nq + 1]() : nullptr;
    labels = nullptr;
    distances = nullptr;
    buffer_size = 1024 * 256; // entries per BufferList buffer
}
34
35
/// called when lims contains the nb of elements result entries
36
/// for each query
37
149
void RangeSearchResult::do_allocation() {
38
    // works only if all the partial results are aggregated
39
    // simulatenously
40
149
    FAISS_THROW_IF_NOT(labels == nullptr && distances == nullptr);
41
149
    size_t ofs = 0;
42
298
    for (int i = 0; i < nq; i++) {
43
149
        size_t n = lims[i];
44
149
        lims[i] = ofs;
45
149
        ofs += n;
46
149
    }
47
149
    lims[nq] = ofs;
48
149
    labels = new idx_t[ofs];
49
149
    distances = new float[ofs];
50
149
}
51
52
149
// Frees the offset table and the result arrays. All three start as nullptr
// (or a valid new[] array), and delete[] on nullptr is a no-op, so this is
// safe even if do_allocation() never ran.
RangeSearchResult::~RangeSearchResult() {
    delete[] lims;
    delete[] distances;
    delete[] labels;
}
57
58
/***********************************************************************
59
 * BufferList
60
 ***********************************************************************/
61
62
149
// Creates an empty buffer list whose buffers hold `buffer_size` entries each.
// The write position starts at "buffer full", so the first add() allocates
// the first buffer.
BufferList::BufferList(size_t buffer_size) : buffer_size(buffer_size) {
    wp = this->buffer_size;
}
65
66
149
// Releases the id and distance arrays of every buffer created by
// append_buffer().
BufferList::~BufferList() {
    // range-for avoids comparing a signed int index against the unsigned
    // buffers.size()
    for (const Buffer& buf : buffers) {
        delete[] buf.ids;
        delete[] buf.dis;
    }
}
72
73
7.47k
// Appends one (id, distance) pair at the write position, opening a new buffer
// first when the current one is full (or when none exists yet).
void BufferList::add(idx_t id, float dis) {
    if (wp == buffer_size) {
        append_buffer(); // resets wp to 0
    }
    Buffer& last = buffers.back();
    last.ids[wp] = id;
    last.dis[wp] = dis;
    ++wp;
}
82
83
126
void BufferList::append_buffer() {
84
126
    Buffer buf = {new idx_t[buffer_size], new float[buffer_size]};
85
126
    buffers.push_back(buf);
86
126
    wp = 0;
87
126
}
88
89
/// copy elemnts ofs:ofs+n-1 seen as linear data in the buffers to
90
/// tables dest_ids, dest_dis
91
void BufferList::copy_range(
92
        size_t ofs,
93
        size_t n,
94
        idx_t* dest_ids,
95
149
        float* dest_dis) {
96
149
    size_t bno = ofs / buffer_size;
97
149
    ofs -= bno * buffer_size;
98
275
    while (n > 0) {
99
126
        size_t ncopy = ofs + n < buffer_size ? n : buffer_size - ofs;
100
126
        Buffer buf = buffers[bno];
101
126
        memcpy(dest_ids, buf.ids + ofs, ncopy * sizeof(*dest_ids));
102
126
        memcpy(dest_dis, buf.dis + ofs, ncopy * sizeof(*dest_dis));
103
126
        dest_ids += ncopy;
104
126
        dest_dis += ncopy;
105
126
        ofs = 0;
106
126
        bno++;
107
126
        n -= ncopy;
108
126
    }
109
149
}
110
111
/***********************************************************************
112
 * RangeSearchPartialResult
113
 ***********************************************************************/
114
115
7.47k
void RangeQueryResult::add(float dis, idx_t id) {
116
7.47k
    nres++;
117
7.47k
    pres->add(id, dis);
118
7.47k
}
119
120
RangeSearchPartialResult::RangeSearchPartialResult(RangeSearchResult* res_in)
121
149
        : BufferList(res_in->buffer_size), res(res_in) {}
122
123
/// begin a new result
124
149
RangeQueryResult& RangeSearchPartialResult::new_result(idx_t qno) {
125
149
    RangeQueryResult qres = {qno, 0, this};
126
149
    queries.push_back(qres);
127
149
    return queries.back();
128
149
}
129
130
149
// Completes a range search whose hits were collected in partial results.
// The OpenMP barriers and `single` construct mean this is meant to be called
// by every thread of an enclosing parallel region (presumably each with its
// own partial result sharing the same `res`):
//  1. each thread writes its per-query counts into res->lims (set_lims);
//  2. after a barrier, one thread converts counts to offsets and allocates
//     res->labels / res->distances (`single` ends with an implicit barrier);
//  3. after a further barrier, each thread copies its own entries into place.
void RangeSearchPartialResult::finalize() {
    set_lims();
    // all counts must be published before allocation
#pragma omp barrier

#pragma omp single
    res->do_allocation();

    // allocation must be visible to every thread before copying
#pragma omp barrier
    copy_result();
}
140
141
/// called by range_search before do_allocation
142
149
void RangeSearchPartialResult::set_lims() {
143
298
    for (int i = 0; i < queries.size(); i++) {
144
149
        RangeQueryResult& qres = queries[i];
145
149
        res->lims[qres.qno] = qres.nres;
146
149
    }
147
149
}
148
149
/// called by range_search after do_allocation
150
149
void RangeSearchPartialResult::copy_result(bool incremental) {
151
149
    size_t ofs = 0;
152
298
    for (int i = 0; i < queries.size(); i++) {
153
149
        RangeQueryResult& qres = queries[i];
154
155
149
        copy_range(
156
149
                ofs,
157
149
                qres.nres,
158
149
                res->labels + res->lims[qres.qno],
159
149
                res->distances + res->lims[qres.qno]);
160
149
        if (incremental) {
161
0
            res->lims[qres.qno] += qres.nres;
162
0
        }
163
149
        ofs += qres.nres;
164
149
    }
165
149
}
166
167
void RangeSearchPartialResult::merge(
168
        std::vector<RangeSearchPartialResult*>& partial_results,
169
0
        bool do_delete) {
170
0
    int npres = partial_results.size();
171
0
    if (npres == 0)
172
0
        return;
173
0
    RangeSearchResult* result = partial_results[0]->res;
174
0
    size_t nx = result->nq;
175
176
    // count
177
0
    for (const RangeSearchPartialResult* pres : partial_results) {
178
0
        if (!pres)
179
0
            continue;
180
0
        for (const RangeQueryResult& qres : pres->queries) {
181
0
            result->lims[qres.qno] += qres.nres;
182
0
        }
183
0
    }
184
0
    result->do_allocation();
185
0
    for (int j = 0; j < npres; j++) {
186
0
        if (!partial_results[j])
187
0
            continue;
188
0
        partial_results[j]->copy_result(true);
189
0
        if (do_delete) {
190
0
            delete partial_results[j];
191
0
            partial_results[j] = nullptr;
192
0
        }
193
0
    }
194
195
    // reset the limits
196
0
    for (size_t i = nx; i > 0; i--) {
197
0
        result->lims[i] = result->lims[i - 1];
198
0
    }
199
0
    result->lims[0] = 0;
200
0
}
201
202
/***********************************************************
203
 * Interrupt callback
204
 ***********************************************************/
205
206
// Process-wide interrupt callback (none installed by default), and the mutex
// that is_interrupted() holds while querying it.
std::unique_ptr<InterruptCallback> InterruptCallback::instance{};

std::mutex InterruptCallback::lock{};
209
210
0
void InterruptCallback::clear_instance() {
211
0
    delete instance.release();
212
0
}
213
214
1.57k
void InterruptCallback::check() {
215
1.57k
    if (!instance.get()) {
216
1.57k
        return;
217
1.57k
    }
218
0
    if (instance->want_interrupt()) {
219
0
        FAISS_THROW_MSG("computation interrupted");
220
0
    }
221
0
}
222
223
18.7k
bool InterruptCallback::is_interrupted() {
224
18.7k
    if (!instance.get()) {
225
18.7k
        return false;
226
18.7k
    }
227
1
    std::lock_guard<std::mutex> guard(lock);
228
1
    return instance->want_interrupt();
229
18.7k
}
230
231
18.8k
size_t InterruptCallback::get_period_hint(size_t flops) {
232
18.8k
    if (!instance.get()) {
233
18.8k
        return (size_t)1 << 30; // never check
234
18.8k
    }
235
    // for 10M flops, it is reasonable to check once every 10 iterations
236
0
    return std::max((size_t)10 * 10 * 1000 * 1000 / (flops + 1), (size_t)1);
237
18.8k
}
238
239
0
void TimeoutCallback::set_timeout(double timeout_in_seconds) {
240
0
    timeout = timeout_in_seconds;
241
0
    start = std::chrono::steady_clock::now();
242
0
}
243
244
0
bool TimeoutCallback::want_interrupt() {
245
0
    if (timeout == 0) {
246
0
        return false;
247
0
    }
248
0
    auto end = std::chrono::steady_clock::now();
249
0
    std::chrono::duration<float, std::milli> duration = end - start;
250
0
    float elapsed_in_seconds = duration.count() / 1000.0;
251
0
    if (elapsed_in_seconds > timeout) {
252
0
        timeout = 0;
253
0
        return true;
254
0
    }
255
0
    return false;
256
0
}
257
258
0
void TimeoutCallback::reset(double timeout_in_seconds) {
259
0
    auto tc(new faiss::TimeoutCallback());
260
0
    faiss::InterruptCallback::instance.reset(tc);
261
0
    tc->set_timeout(timeout_in_seconds);
262
0
}
263
264
} // namespace faiss