Coverage Report

Created: 2026-03-16 03:15

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
contrib/faiss/faiss/impl/AuxIndexStructures.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
// -*- c++ -*-
9
10
#include <algorithm>
#include <cstring>
#include <memory>
#include <utility>

#include <faiss/impl/AuxIndexStructures.h>

#include <faiss/impl/FaissAssert.h>
16
17
namespace faiss {
18
19
/***********************************************************************
20
 * RangeSearchResult
21
 ***********************************************************************/
22
23
170
/// Construct an empty result container for `nq` queries.
///
/// When `alloc_lims` is true, `lims` gets nq + 1 entries, all zero;
/// otherwise it stays nullptr. `labels` and `distances` start out
/// unallocated (they are allocated later by do_allocation()).
RangeSearchResult::RangeSearchResult(size_t nq, bool alloc_lims) : nq(nq) {
    // `new T[n]()` value-initializes, i.e. zero-fills, the array.
    lims = alloc_lims ? new size_t[nq + 1]() : nullptr;
    labels = nullptr;
    distances = nullptr;
    // Per-buffer capacity later passed to RangeSearchPartialResult's buffers.
    buffer_size = 1024 * 256;
}
34
35
/// called when lims contains the nb of elements result entries
36
/// for each query
37
170
void RangeSearchResult::do_allocation() {
38
    // works only if all the partial results are aggregated
39
    // simulatenously
40
170
    FAISS_THROW_IF_NOT(labels == nullptr && distances == nullptr);
41
170
    size_t ofs = 0;
42
340
    for (int i = 0; i < nq; i++) {
43
170
        size_t n = lims[i];
44
170
        lims[i] = ofs;
45
170
        ofs += n;
46
170
    }
47
170
    lims[nq] = ofs;
48
170
    labels = new idx_t[ofs];
49
170
    distances = new float[ofs];
50
170
}
51
52
170
/// Release the three result arrays. delete[] on nullptr is a no-op, so
/// arrays that were never allocated are handled as well.
RangeSearchResult::~RangeSearchResult() {
    delete[] lims;
    delete[] distances;
    delete[] labels;
}
57
58
/***********************************************************************
59
 * BufferList
60
 ***********************************************************************/
61
62
170
/// Create an empty buffer list whose buffers each hold `buffer_size`
/// entries. The write pointer starts at `buffer_size`, i.e. "current buffer
/// full", so the first add() allocates the first buffer.
BufferList::BufferList(size_t buffer_size)
        : buffer_size(buffer_size), wp(buffer_size) {}
65
66
170
/// Free the id and distance arrays of every buffer.
BufferList::~BufferList() {
    // Range-for: the previous `int` index was compared against the unsigned
    // buffers.size().
    for (const Buffer& buf : buffers) {
        delete[] buf.ids;
        delete[] buf.dis;
    }
}
72
73
6.93k
/// Append one (id, dis) pair at the write pointer, starting a new buffer
/// first when the current one is full.
void BufferList::add(idx_t id, float dis) {
    if (wp == buffer_size) {
        append_buffer(); // also rewinds wp to 0
    }
    Buffer& tail = buffers.back();
    const size_t slot = wp++;
    tail.ids[slot] = id;
    tail.dis[slot] = dis;
}
82
83
147
void BufferList::append_buffer() {
84
147
    Buffer buf = {new idx_t[buffer_size], new float[buffer_size]};
85
147
    buffers.push_back(buf);
86
147
    wp = 0;
87
147
}
88
89
/// copy elemnts ofs:ofs+n-1 seen as linear data in the buffers to
90
/// tables dest_ids, dest_dis
91
void BufferList::copy_range(
92
        size_t ofs,
93
        size_t n,
94
        idx_t* dest_ids,
95
170
        float* dest_dis) {
96
170
    size_t bno = ofs / buffer_size;
97
170
    ofs -= bno * buffer_size;
98
317
    while (n > 0) {
99
147
        size_t ncopy = ofs + n < buffer_size ? n : buffer_size - ofs;
100
147
        Buffer buf = buffers[bno];
101
147
        memcpy(dest_ids, buf.ids + ofs, ncopy * sizeof(*dest_ids));
102
147
        memcpy(dest_dis, buf.dis + ofs, ncopy * sizeof(*dest_dis));
103
147
        dest_ids += ncopy;
104
147
        dest_dis += ncopy;
105
147
        ofs = 0;
106
147
        bno++;
107
147
        n -= ncopy;
108
147
    }
109
170
}
110
111
/***********************************************************************
112
 * RangeSearchPartialResult
113
 ***********************************************************************/
114
115
6.93k
void RangeQueryResult::add(float dis, idx_t id) {
116
6.93k
    nres++;
117
6.93k
    pres->add(id, dis);
118
6.93k
}
119
120
/// Bind a partial result to its destination `res_in`. The underlying
/// buffers use the buffer size configured on the destination.
RangeSearchPartialResult::RangeSearchPartialResult(RangeSearchResult* res_in)
        : BufferList(res_in->buffer_size),
          res(res_in) {}
122
123
/// begin a new result
124
170
RangeQueryResult& RangeSearchPartialResult::new_result(idx_t qno) {
125
170
    RangeQueryResult qres = {qno, 0, this};
126
170
    queries.push_back(qres);
127
170
    return queries.back();
128
170
}
129
130
170
void RangeSearchPartialResult::finalize() {
131
170
    set_lims();
132
170
#pragma omp barrier
133
134
170
#pragma omp single
135
170
    res->do_allocation();
136
137
170
#pragma omp barrier
138
170
    copy_result();
139
170
}
140
141
/// called by range_search before do_allocation
142
170
void RangeSearchPartialResult::set_lims() {
143
340
    for (int i = 0; i < queries.size(); i++) {
144
170
        RangeQueryResult& qres = queries[i];
145
170
        res->lims[qres.qno] = qres.nres;
146
170
    }
147
170
}
148
149
/// called by range_search after do_allocation
150
170
void RangeSearchPartialResult::copy_result(bool incremental) {
151
170
    size_t ofs = 0;
152
340
    for (int i = 0; i < queries.size(); i++) {
153
170
        RangeQueryResult& qres = queries[i];
154
155
170
        copy_range(
156
170
                ofs,
157
170
                qres.nres,
158
170
                res->labels + res->lims[qres.qno],
159
170
                res->distances + res->lims[qres.qno]);
160
170
        if (incremental) {
161
0
            res->lims[qres.qno] += qres.nres;
162
0
        }
163
170
        ofs += qres.nres;
164
170
    }
165
170
}
166
167
void RangeSearchPartialResult::merge(
168
        std::vector<RangeSearchPartialResult*>& partial_results,
169
0
        bool do_delete) {
170
0
    int npres = partial_results.size();
171
0
    if (npres == 0)
172
0
        return;
173
0
    RangeSearchResult* result = partial_results[0]->res;
174
0
    size_t nx = result->nq;
175
176
    // count
177
0
    for (const RangeSearchPartialResult* pres : partial_results) {
178
0
        if (!pres)
179
0
            continue;
180
0
        for (const RangeQueryResult& qres : pres->queries) {
181
0
            result->lims[qres.qno] += qres.nres;
182
0
        }
183
0
    }
184
0
    result->do_allocation();
185
0
    for (int j = 0; j < npres; j++) {
186
0
        if (!partial_results[j])
187
0
            continue;
188
0
        partial_results[j]->copy_result(true);
189
0
        if (do_delete) {
190
0
            delete partial_results[j];
191
0
            partial_results[j] = nullptr;
192
0
        }
193
0
    }
194
195
    // reset the limits
196
0
    for (size_t i = nx; i > 0; i--) {
197
0
        result->lims[i] = result->lims[i - 1];
198
0
    }
199
0
    result->lims[0] = 0;
200
0
}
201
202
/***********************************************************
203
 * Interrupt callback
204
 ***********************************************************/
205
206
// Process-wide interrupt callback. Empty by default, in which case check()
// never throws and is_interrupted() returns false.
std::unique_ptr<InterruptCallback> InterruptCallback::instance{};

// Serializes the want_interrupt() call made by is_interrupted().
std::mutex InterruptCallback::lock{};
209
210
0
void InterruptCallback::clear_instance() {
211
0
    delete instance.release();
212
0
}
213
214
1.09k
void InterruptCallback::check() {
215
1.09k
    if (!instance.get()) {
216
1.09k
        return;
217
1.09k
    }
218
0
    if (instance->want_interrupt()) {
219
0
        FAISS_THROW_MSG("computation interrupted");
220
0
    }
221
0
}
222
223
18.7k
bool InterruptCallback::is_interrupted() {
224
18.7k
    if (!instance.get()) {
225
18.7k
        return false;
226
18.7k
    }
227
0
    std::lock_guard<std::mutex> guard(lock);
228
0
    return instance->want_interrupt();
229
18.7k
}
230
231
18.9k
size_t InterruptCallback::get_period_hint(size_t flops) {
232
18.9k
    if (!instance.get()) {
233
18.9k
        return (size_t)1 << 30; // never check
234
18.9k
    }
235
    // for 10M flops, it is reasonable to check once every 10 iterations
236
0
    return std::max((size_t)10 * 10 * 1000 * 1000 / (flops + 1), (size_t)1);
237
18.9k
}
238
239
0
void TimeoutCallback::set_timeout(double timeout_in_seconds) {
240
0
    timeout = timeout_in_seconds;
241
0
    start = std::chrono::steady_clock::now();
242
0
}
243
244
0
bool TimeoutCallback::want_interrupt() {
245
0
    if (timeout == 0) {
246
0
        return false;
247
0
    }
248
0
    auto end = std::chrono::steady_clock::now();
249
0
    std::chrono::duration<float, std::milli> duration = end - start;
250
0
    float elapsed_in_seconds = duration.count() / 1000.0;
251
0
    if (elapsed_in_seconds > timeout) {
252
0
        timeout = 0;
253
0
        return true;
254
0
    }
255
0
    return false;
256
0
}
257
258
0
void TimeoutCallback::reset(double timeout_in_seconds) {
259
0
    auto tc(new faiss::TimeoutCallback());
260
0
    faiss::InterruptCallback::instance.reset(tc);
261
0
    tc->set_timeout(timeout_in_seconds);
262
0
}
263
264
} // namespace faiss