/root/doris/contrib/faiss/faiss/impl/AuxIndexStructures.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | // -*- c++ -*- |
9 | | |
10 | | #include <algorithm> |
11 | | #include <cstring> |
12 | | |
13 | | #include <faiss/impl/AuxIndexStructures.h> |
14 | | |
15 | | #include <faiss/impl/FaissAssert.h> |
16 | | |
17 | | namespace faiss { |
18 | | |
19 | | /*********************************************************************** |
20 | | * RangeSearchResult |
21 | | ***********************************************************************/ |
22 | | |
23 | 115 | RangeSearchResult::RangeSearchResult(size_t nq, bool alloc_lims) : nq(nq) { |
24 | 115 | if (alloc_lims) { |
25 | 115 | lims = new size_t[nq + 1]; |
26 | 115 | memset(lims, 0, sizeof(*lims) * (nq + 1)); |
27 | 115 | } else { |
28 | 0 | lims = nullptr; |
29 | 0 | } |
30 | 115 | labels = nullptr; |
31 | 115 | distances = nullptr; |
32 | 115 | buffer_size = 1024 * 256; |
33 | 115 | } |
34 | | |
35 | | /// called when lims contains the nb of elements result entries |
36 | | /// for each query |
37 | 115 | void RangeSearchResult::do_allocation() { |
38 | | // works only if all the partial results are aggregated |
39 | | // simulatenously |
40 | 115 | FAISS_THROW_IF_NOT(labels == nullptr && distances == nullptr); |
41 | 115 | size_t ofs = 0; |
42 | 230 | for (int i = 0; i < nq; i++) { |
43 | 115 | size_t n = lims[i]; |
44 | 115 | lims[i] = ofs; |
45 | 115 | ofs += n; |
46 | 115 | } |
47 | 115 | lims[nq] = ofs; |
48 | 115 | labels = new idx_t[ofs]; |
49 | 115 | distances = new float[ofs]; |
50 | 115 | } |
51 | | |
52 | 115 | RangeSearchResult::~RangeSearchResult() { |
53 | 115 | delete[] labels; |
54 | 115 | delete[] distances; |
55 | 115 | delete[] lims; |
56 | 115 | } |
57 | | |
58 | | /*********************************************************************** |
59 | | * BufferList |
60 | | ***********************************************************************/ |
61 | | |
62 | 115 | BufferList::BufferList(size_t buffer_size) : buffer_size(buffer_size) { |
63 | 115 | wp = buffer_size; |
64 | 115 | } |
65 | | |
66 | 115 | BufferList::~BufferList() { |
67 | 214 | for (int i = 0; i < buffers.size(); i++) { |
68 | 99 | delete[] buffers[i].ids; |
69 | 99 | delete[] buffers[i].dis; |
70 | 99 | } |
71 | 115 | } |
72 | | |
73 | 6.72k | void BufferList::add(idx_t id, float dis) { |
74 | 6.72k | if (wp == buffer_size) { // need new buffer |
75 | 99 | append_buffer(); |
76 | 99 | } |
77 | 6.72k | Buffer& buf = buffers.back(); |
78 | 6.72k | buf.ids[wp] = id; |
79 | 6.72k | buf.dis[wp] = dis; |
80 | 6.72k | wp++; |
81 | 6.72k | } |
82 | | |
83 | 99 | void BufferList::append_buffer() { |
84 | 99 | Buffer buf = {new idx_t[buffer_size], new float[buffer_size]}; |
85 | 99 | buffers.push_back(buf); |
86 | 99 | wp = 0; |
87 | 99 | } |
88 | | |
89 | | /// copy elemnts ofs:ofs+n-1 seen as linear data in the buffers to |
90 | | /// tables dest_ids, dest_dis |
91 | | void BufferList::copy_range( |
92 | | size_t ofs, |
93 | | size_t n, |
94 | | idx_t* dest_ids, |
95 | 115 | float* dest_dis) { |
96 | 115 | size_t bno = ofs / buffer_size; |
97 | 115 | ofs -= bno * buffer_size; |
98 | 214 | while (n > 0) { |
99 | 99 | size_t ncopy = ofs + n < buffer_size ? n : buffer_size - ofs; |
100 | 99 | Buffer buf = buffers[bno]; |
101 | 99 | memcpy(dest_ids, buf.ids + ofs, ncopy * sizeof(*dest_ids)); |
102 | 99 | memcpy(dest_dis, buf.dis + ofs, ncopy * sizeof(*dest_dis)); |
103 | 99 | dest_ids += ncopy; |
104 | 99 | dest_dis += ncopy; |
105 | 99 | ofs = 0; |
106 | 99 | bno++; |
107 | 99 | n -= ncopy; |
108 | 99 | } |
109 | 115 | } |
110 | | |
111 | | /*********************************************************************** |
112 | | * RangeSearchPartialResult |
113 | | ***********************************************************************/ |
114 | | |
115 | 6.72k | void RangeQueryResult::add(float dis, idx_t id) { |
116 | 6.72k | nres++; |
117 | 6.72k | pres->add(id, dis); |
118 | 6.72k | } |
119 | | |
120 | | RangeSearchPartialResult::RangeSearchPartialResult(RangeSearchResult* res_in) |
121 | 115 | : BufferList(res_in->buffer_size), res(res_in) {} |
122 | | |
123 | | /// begin a new result |
124 | 115 | RangeQueryResult& RangeSearchPartialResult::new_result(idx_t qno) { |
125 | 115 | RangeQueryResult qres = {qno, 0, this}; |
126 | 115 | queries.push_back(qres); |
127 | 115 | return queries.back(); |
128 | 115 | } |
129 | | |
130 | 115 | void RangeSearchPartialResult::finalize() { |
131 | 115 | set_lims(); |
132 | 115 | #pragma omp barrier |
133 | | |
134 | 115 | #pragma omp single |
135 | 115 | res->do_allocation(); |
136 | | |
137 | 115 | #pragma omp barrier |
138 | 115 | copy_result(); |
139 | 115 | } |
140 | | |
141 | | /// called by range_search before do_allocation |
142 | 115 | void RangeSearchPartialResult::set_lims() { |
143 | 230 | for (int i = 0; i < queries.size(); i++) { |
144 | 115 | RangeQueryResult& qres = queries[i]; |
145 | 115 | res->lims[qres.qno] = qres.nres; |
146 | 115 | } |
147 | 115 | } |
148 | | |
149 | | /// called by range_search after do_allocation |
150 | 115 | void RangeSearchPartialResult::copy_result(bool incremental) { |
151 | 115 | size_t ofs = 0; |
152 | 230 | for (int i = 0; i < queries.size(); i++) { |
153 | 115 | RangeQueryResult& qres = queries[i]; |
154 | | |
155 | 115 | copy_range( |
156 | 115 | ofs, |
157 | 115 | qres.nres, |
158 | 115 | res->labels + res->lims[qres.qno], |
159 | 115 | res->distances + res->lims[qres.qno]); |
160 | 115 | if (incremental) { |
161 | 0 | res->lims[qres.qno] += qres.nres; |
162 | 0 | } |
163 | 115 | ofs += qres.nres; |
164 | 115 | } |
165 | 115 | } |
166 | | |
167 | | void RangeSearchPartialResult::merge( |
168 | | std::vector<RangeSearchPartialResult*>& partial_results, |
169 | 0 | bool do_delete) { |
170 | 0 | int npres = partial_results.size(); |
171 | 0 | if (npres == 0) |
172 | 0 | return; |
173 | 0 | RangeSearchResult* result = partial_results[0]->res; |
174 | 0 | size_t nx = result->nq; |
175 | | |
176 | | // count |
177 | 0 | for (const RangeSearchPartialResult* pres : partial_results) { |
178 | 0 | if (!pres) |
179 | 0 | continue; |
180 | 0 | for (const RangeQueryResult& qres : pres->queries) { |
181 | 0 | result->lims[qres.qno] += qres.nres; |
182 | 0 | } |
183 | 0 | } |
184 | 0 | result->do_allocation(); |
185 | 0 | for (int j = 0; j < npres; j++) { |
186 | 0 | if (!partial_results[j]) |
187 | 0 | continue; |
188 | 0 | partial_results[j]->copy_result(true); |
189 | 0 | if (do_delete) { |
190 | 0 | delete partial_results[j]; |
191 | 0 | partial_results[j] = nullptr; |
192 | 0 | } |
193 | 0 | } |
194 | | |
195 | | // reset the limits |
196 | 0 | for (size_t i = nx; i > 0; i--) { |
197 | 0 | result->lims[i] = result->lims[i - 1]; |
198 | 0 | } |
199 | 0 | result->lims[0] = 0; |
200 | 0 | } |
201 | | |
202 | | /*********************************************************** |
203 | | * Interrupt callback |
204 | | ***********************************************************/ |
205 | | |
206 | | std::unique_ptr<InterruptCallback> InterruptCallback::instance; |
207 | | |
208 | | std::mutex InterruptCallback::lock; |
209 | | |
210 | 0 | void InterruptCallback::clear_instance() { |
211 | 0 | delete instance.release(); |
212 | 0 | } |
213 | | |
214 | 963 | void InterruptCallback::check() { |
215 | 963 | if (!instance.get()) { |
216 | 963 | return; |
217 | 963 | } |
218 | 0 | if (instance->want_interrupt()) { |
219 | 0 | FAISS_THROW_MSG("computation interrupted"); |
220 | 0 | } |
221 | 0 | } |
222 | | |
223 | 17.3k | bool InterruptCallback::is_interrupted() { |
224 | 17.3k | if (!instance.get()) { |
225 | 17.3k | return false; |
226 | 17.3k | } |
227 | 0 | std::lock_guard<std::mutex> guard(lock); |
228 | 0 | return instance->want_interrupt(); |
229 | 17.3k | } |
230 | | |
231 | 17.3k | size_t InterruptCallback::get_period_hint(size_t flops) { |
232 | 17.3k | if (!instance.get()) { |
233 | 17.3k | return (size_t)1 << 30; // never check |
234 | 17.3k | } |
235 | | // for 10M flops, it is reasonable to check once every 10 iterations |
236 | 0 | return std::max((size_t)10 * 10 * 1000 * 1000 / (flops + 1), (size_t)1); |
237 | 17.3k | } |
238 | | |
239 | 0 | void TimeoutCallback::set_timeout(double timeout_in_seconds) { |
240 | 0 | timeout = timeout_in_seconds; |
241 | 0 | start = std::chrono::steady_clock::now(); |
242 | 0 | } |
243 | | |
244 | 0 | bool TimeoutCallback::want_interrupt() { |
245 | 0 | if (timeout == 0) { |
246 | 0 | return false; |
247 | 0 | } |
248 | 0 | auto end = std::chrono::steady_clock::now(); |
249 | 0 | std::chrono::duration<float, std::milli> duration = end - start; |
250 | 0 | float elapsed_in_seconds = duration.count() / 1000.0; |
251 | 0 | if (elapsed_in_seconds > timeout) { |
252 | 0 | timeout = 0; |
253 | 0 | return true; |
254 | 0 | } |
255 | 0 | return false; |
256 | 0 | } |
257 | | |
258 | 0 | void TimeoutCallback::reset(double timeout_in_seconds) { |
259 | 0 | auto tc(new faiss::TimeoutCallback()); |
260 | 0 | faiss::InterruptCallback::instance.reset(tc); |
261 | 0 | tc->set_timeout(timeout_in_seconds); |
262 | 0 | } |
263 | | |
264 | | } // namespace faiss |