contrib/faiss/faiss/utils/sorting.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | // -*- c++ -*- |
9 | | |
#include <faiss/utils/sorting.h>

#include <omp.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/utils.h>
17 | | |
18 | | namespace faiss { |
19 | | |
20 | | /***************************************************************************** |
21 | | * Argsort |
22 | | ****************************************************************************/ |
23 | | |
24 | | namespace { |
// Comparator for argsort: orders indices by the float values they point to
// (strict weak ordering via operator<).
struct ArgsortComparator {
    const float* vals; // not owned; must outlive the comparator
    bool operator()(const size_t lhs, const size_t rhs) const {
        return vals[lhs] < vals[rhs];
    }
};
31 | | |
// Half-open index range [i0, i1) into the permutation array being sorted.
struct SegmentS {
    size_t i0; // begin pointer in the permutation array
    size_t i1; // end
    // number of elements in the segment
    size_t len() const {
        return i1 - i0;
    }
};
39 | | |
40 | | // see https://en.wikipedia.org/wiki/Merge_algorithm#Parallel_merge |
41 | | // extended to > 1 merge thread |
42 | | |
43 | | // merges 2 ranges that should be consecutive on the source into |
44 | | // the union of the two on the destination |
// Merges two sorted ranges s1 and s2 (consecutive in src) into their union
// on dst, using nt merge threads. On return both s1 and s2 describe the
// merged output range. The split points in s2 are found by binary-searching
// the pivots that end each equal-size slice of s1 (see the Wikipedia
// reference above), so each thread merges a balanced pair of sub-ranges.
template <typename T>
void parallel_merge(
        const T* src,
        T* dst,
        SegmentS& s1,
        SegmentS& s2,
        int nt,
        const ArgsortComparator& comp) {
    if (s2.len() > s1.len()) { // make sure that s1 larger than s2
        std::swap(s1, s2);
    }

    // compute sub-ranges for each thread
    std::vector<SegmentS> s1s(nt), s2s(nt), sws(nt);
    s2s[0].i0 = s2.i0;
    s2s[nt - 1].i1 = s2.i1;

    // not sure parallel actually helps here
#pragma omp parallel for num_threads(nt)
    for (int t = 0; t < nt; t++) {
        // t-th equal slice of s1
        s1s[t].i0 = s1.i0 + s1.len() * t / nt;
        s1s[t].i1 = s1.i0 + s1.len() * (t + 1) / nt;

        if (t + 1 < nt) {
            // binary search in s2 for the first element greater than the
            // pivot that ends this s1 slice; that index splits s2 between
            // thread t and thread t + 1
            T pivot = src[s1s[t].i1];
            size_t i0 = s2.i0, i1 = s2.i1;
            while (i0 + 1 < i1) {
                size_t imed = (i1 + i0) / 2;
                if (comp(pivot, src[imed])) {
                    i1 = imed;
                } else {
                    i0 = imed;
                }
            }
            s2s[t].i1 = s2s[t + 1].i0 = i1;
        }
    }
    // the output covers the union of the two input segments
    s1.i0 = std::min(s1.i0, s2.i0);
    s1.i1 = std::max(s1.i1, s2.i1);
    s2 = s1;
    // per-thread destination sub-ranges, sized to hold both source slices
    sws[0].i0 = s1.i0;
    for (int t = 0; t < nt; t++) {
        sws[t].i1 = sws[t].i0 + s1s[t].len() + s2s[t].len();
        if (t + 1 < nt) {
            sws[t + 1].i0 = sws[t].i1;
        }
    }
    assert(sws[nt - 1].i1 == s1.i1);

    // do the actual merging
#pragma omp parallel for num_threads(nt)
    for (int t = 0; t < nt; t++) {
        SegmentS sw = sws[t];
        SegmentS s1t = s1s[t];
        SegmentS s2t = s2s[t];
        if (s1t.i0 < s1t.i1 && s2t.i0 < s2t.i1) {
            // standard two-way merge until one side runs out
            for (;;) {
                // assert (sw.len() == s1t.len() + s2t.len());
                if (comp(src[s1t.i0], src[s2t.i0])) {
                    dst[sw.i0++] = src[s1t.i0++];
                    if (s1t.i0 == s1t.i1) {
                        break;
                    }
                } else {
                    dst[sw.i0++] = src[s2t.i0++];
                    if (s2t.i0 == s2t.i1) {
                        break;
                    }
                }
            }
        }
        // bulk-copy whatever remains of the non-exhausted side
        if (s1t.len() > 0) {
            assert(s1t.len() == sw.len());
            memcpy(dst + sw.i0, src + s1t.i0, s1t.len() * sizeof(dst[0]));
        } else if (s2t.len() > 0) {
            assert(s2t.len() == sw.len());
            memcpy(dst + sw.i0, src + s2t.i0, s2t.len() * sizeof(dst[0]));
        }
    }
}
125 | | |
126 | | } // namespace |
127 | | |
// Fills perm with the permutation that sorts vals ascending:
// vals[perm[0]] <= vals[perm[1]] <= ... (single-threaded).
void fvec_argsort(size_t n, const float* vals, size_t* perm) {
    for (size_t i = 0; i < n; ++i) {
        perm[i] = i;
    }
    std::sort(perm, perm + n, [vals](size_t a, size_t b) {
        return vals[a] < vals[b];
    });
}
135 | | |
136 | 0 | void fvec_argsort_parallel(size_t n, const float* vals, size_t* perm) { |
137 | 0 | size_t* perm2 = new size_t[n]; |
138 | | // 2 result tables, during merging, flip between them |
139 | 0 | size_t *permB = perm2, *permA = perm; |
140 | |
|
141 | 0 | int nt = omp_get_max_threads(); |
142 | 0 | { // prepare correct permutation so that the result ends in perm |
143 | | // at final iteration |
144 | 0 | int nseg = nt; |
145 | 0 | while (nseg > 1) { |
146 | 0 | nseg = (nseg + 1) / 2; |
147 | 0 | std::swap(permA, permB); |
148 | 0 | } |
149 | 0 | } |
150 | |
|
151 | 0 | #pragma omp parallel |
152 | 0 | for (size_t i = 0; i < n; i++) { |
153 | 0 | permA[i] = i; |
154 | 0 | } |
155 | |
|
156 | 0 | ArgsortComparator comp = {vals}; |
157 | |
|
158 | 0 | std::vector<SegmentS> segs(nt); |
159 | | |
160 | | // independent sorts |
161 | 0 | #pragma omp parallel for |
162 | 0 | for (int t = 0; t < nt; t++) { |
163 | 0 | size_t i0 = t * n / nt; |
164 | 0 | size_t i1 = (t + 1) * n / nt; |
165 | 0 | SegmentS seg = {i0, i1}; |
166 | 0 | std::sort(permA + seg.i0, permA + seg.i1, comp); |
167 | 0 | segs[t] = seg; |
168 | 0 | } |
169 | 0 | int prev_nested = omp_get_nested(); |
170 | 0 | omp_set_nested(1); |
171 | |
|
172 | 0 | int nseg = nt; |
173 | 0 | while (nseg > 1) { |
174 | 0 | int nseg1 = (nseg + 1) / 2; |
175 | 0 | int sub_nt = nseg % 2 == 0 ? nt : nt - 1; |
176 | 0 | int sub_nseg1 = nseg / 2; |
177 | |
|
178 | 0 | #pragma omp parallel for num_threads(nseg1) |
179 | 0 | for (int s = 0; s < nseg; s += 2) { |
180 | 0 | if (s + 1 == nseg) { // otherwise isolated segment |
181 | 0 | memcpy(permB + segs[s].i0, |
182 | 0 | permA + segs[s].i0, |
183 | 0 | segs[s].len() * sizeof(size_t)); |
184 | 0 | } else { |
185 | 0 | int t0 = s * sub_nt / sub_nseg1; |
186 | 0 | int t1 = (s + 1) * sub_nt / sub_nseg1; |
187 | 0 | printf("merge %d %d, %d threads\n", s, s + 1, t1 - t0); |
188 | 0 | parallel_merge( |
189 | 0 | permA, permB, segs[s], segs[s + 1], t1 - t0, comp); |
190 | 0 | } |
191 | 0 | } |
192 | 0 | for (int s = 0; s < nseg; s += 2) { |
193 | 0 | segs[s / 2] = segs[s]; |
194 | 0 | } |
195 | 0 | nseg = nseg1; |
196 | 0 | std::swap(permA, permB); |
197 | 0 | } |
198 | 0 | assert(permA == perm); |
199 | 0 | omp_set_nested(prev_nested); |
200 | 0 | delete[] perm2; |
201 | 0 | } |
202 | | |
203 | | /***************************************************************************** |
204 | | * Bucket sort |
205 | | ****************************************************************************/ |
206 | | |
// extern symbol in the .h: verbosity for the bucket sort implementations
// below (0 = silent; higher values print timings and debug dumps)
int bucket_sort_verbose = 0;
209 | | |
210 | | namespace { |
211 | | |
212 | | void bucket_sort_ref( |
213 | | size_t nval, |
214 | | const uint64_t* vals, |
215 | | uint64_t vmax, |
216 | | int64_t* lims, |
217 | 0 | int64_t* perm) { |
218 | 0 | double t0 = getmillisecs(); |
219 | 0 | memset(lims, 0, sizeof(*lims) * (vmax + 1)); |
220 | 0 | for (size_t i = 0; i < nval; i++) { |
221 | 0 | FAISS_THROW_IF_NOT(vals[i] < vmax); |
222 | 0 | lims[vals[i] + 1]++; |
223 | 0 | } |
224 | 0 | double t1 = getmillisecs(); |
225 | | // compute cumulative sum |
226 | 0 | for (size_t i = 0; i < vmax; i++) { |
227 | 0 | lims[i + 1] += lims[i]; |
228 | 0 | } |
229 | 0 | FAISS_THROW_IF_NOT(lims[vmax] == nval); |
230 | 0 | double t2 = getmillisecs(); |
231 | | // populate buckets |
232 | 0 | for (size_t i = 0; i < nval; i++) { |
233 | 0 | perm[lims[vals[i]]++] = i; |
234 | 0 | } |
235 | 0 | double t3 = getmillisecs(); |
236 | | // reset pointers |
237 | 0 | for (size_t i = vmax; i > 0; i--) { |
238 | 0 | lims[i] = lims[i - 1]; |
239 | 0 | } |
240 | 0 | lims[0] = 0; |
241 | 0 | double t4 = getmillisecs(); |
242 | 0 | if (bucket_sort_verbose) { |
243 | 0 | printf("times %.3f %.3f %.3f %.3f\n", |
244 | 0 | t1 - t0, |
245 | 0 | t2 - t1, |
246 | 0 | t3 - t2, |
247 | 0 | t4 - t3); |
248 | 0 | } |
249 | 0 | } |
250 | | |
// Parallel counting sort: each thread histograms its slice of vals into a
// local array, the histograms are merged into lims (cumulative sum), then
// each thread reserves disjoint write slots and scatters its indices.
// NOTE(review): unlike bucket_sort_ref, this version does not check
// vals[i] < vmax before indexing local_lims -- assumes callers respect it.
void bucket_sort_parallel(
        size_t nval,
        const uint64_t* vals,
        uint64_t vmax,
        int64_t* lims,
        int64_t* perm,
        int nt_in) {
    memset(lims, 0, sizeof(*lims) * (vmax + 1));
#pragma omp parallel num_threads(nt_in)
    {
        int nt = omp_get_num_threads(); // might be different from nt_in
        int rank = omp_get_thread_num();
        std::vector<int64_t> local_lims(vmax + 1);

        // range of indices handled by this thread
        size_t i0 = nval * rank / nt;
        size_t i1 = nval * (rank + 1) / nt;

        // build histogram in local lims
        double t0 = getmillisecs();
        for (size_t i = i0; i < i1; i++) {
            local_lims[vals[i]]++;
        }
#pragma omp critical
        { // accumulate histograms (not shifted indices to prepare cumsum)
            for (size_t i = 0; i < vmax; i++) {
                lims[i + 1] += local_lims[i];
            }
        }
        // master has no implied barrier, so wait explicitly before/after
#pragma omp barrier

        double t1 = getmillisecs();
#pragma omp master
        {
            // compute cumulative sum
            for (size_t i = 0; i < vmax; i++) {
                lims[i + 1] += lims[i];
            }
            FAISS_THROW_IF_NOT(lims[vmax] == nval);
        }
#pragma omp barrier

#pragma omp critical
        { // current thread grabs a slot in the buckets
            for (size_t i = 0; i < vmax; i++) {
                size_t nv = local_lims[i];
                local_lims[i] = lims[i]; // where we should start writing
                lims[i] += nv;
            }
        }

        double t2 = getmillisecs();
#pragma omp barrier
        { // populate buckets, this is the slowest operation
          // (runs on all threads; each writes only its reserved slots)
            for (size_t i = i0; i < i1; i++) {
                perm[local_lims[vals[i]]++] = i;
            }
        }
#pragma omp barrier
        double t3 = getmillisecs();

#pragma omp master
        { // shift back lims (they were advanced by the slot-grabbing step)
            for (size_t i = vmax; i > 0; i--) {
                lims[i] = lims[i - 1];
            }
            lims[0] = 0;
            double t4 = getmillisecs();
            if (bucket_sort_verbose) {
                printf("times %.3f %.3f %.3f %.3f\n",
                       t1 - t0,
                       t2 - t1,
                       t3 - t2,
                       t4 - t3);
            }
        }
    }
}
329 | | |
330 | | /*********************************************** |
331 | | * in-place bucket sort |
332 | | */ |
333 | | |
// In-place bucket sort (reference implementation) of an nrow x ncol matrix
// of bucket ids in [0, nbucket). On output vals[lims[b] : lims[b + 1]]
// holds the row numbers of the elements that had bucket id b.
// Works in O(nbucket) extra memory by following the cycles of the implied
// permutation; -1 written into a slot marks the head of the current cycle.
template <class TI>
void bucket_sort_inplace_ref(
        size_t nrow,
        size_t ncol,
        TI* vals,
        TI nbucket,
        int64_t* lims) {
    double t0 = getmillisecs();
    size_t nval = nrow * ncol;
    FAISS_THROW_IF_NOT(
            nbucket < nval); // unclear what would happen in this case...

    // histogram, shifted by one for the prefix sum below
    memset(lims, 0, sizeof(*lims) * (nbucket + 1));
    for (size_t i = 0; i < nval; i++) {
        FAISS_THROW_IF_NOT(vals[i] < nbucket);
        lims[vals[i] + 1]++;
    }
    double t1 = getmillisecs();
    // compute cumulative sum
    for (size_t i = 0; i < nbucket; i++) {
        lims[i + 1] += lims[i];
    }
    FAISS_THROW_IF_NOT(lims[nbucket] == nval);
    double t2 = getmillisecs();

    // ptrs[b] = next slot to fill in bucket b
    std::vector<size_t> ptrs(nbucket);
    for (size_t i = 0; i < nbucket; i++) {
        ptrs[i] = lims[i];
    }

    // find loops in the permutation and follow them
    // row: row number displaced from the previously written slot,
    //      or -1 when (re)starting a cycle
    TI row = -1;
    TI init_bucket_no = 0, bucket_no = 0;
    for (;;) {
        size_t idx = ptrs[bucket_no];
        if (row >= 0) {
            // the slot is claimed only when a real row is deposited; the
            // cycle head (row == -1) is revisited when the cycle closes
            ptrs[bucket_no] += 1;
        }
        assert(idx < lims[bucket_no + 1]);
        TI next_bucket_no = vals[idx];
        vals[idx] = row;
        if (next_bucket_no != -1) {
            // continue the cycle with the element we just displaced
            row = idx / ncol;
            bucket_no = next_bucket_no;
        } else {
            // start new loop: find the first bucket that is not full yet
            for (; init_bucket_no < nbucket; init_bucket_no++) {
                if (ptrs[init_bucket_no] < lims[init_bucket_no + 1]) {
                    break;
                }
            }
            if (init_bucket_no == nbucket) { // we're done
                break;
            }
            bucket_no = init_bucket_no;
            row = -1;
        }
    }

    // every bucket must be exactly filled
    for (size_t i = 0; i < nbucket; i++) {
        assert(ptrs[i] == lims[i + 1]);
    }
    double t3 = getmillisecs();
    if (bucket_sort_verbose) {
        printf("times %.3f %.3f %.3f\n", t1 - t0, t2 - t1, t3 - t2);
    }
}
401 | | |
// Per-thread queue of (row, bucket) pairs still to be written into the
// in-place bucket sort output. bucket_sort() groups the pending rows by
// bucket so that lims[b] : lims[b + 1] indexes the rows headed to bucket b.
template <class TI>
struct ToWrite {
    TI nbucket;               // number of buckets
    std::vector<TI> buckets;  // pending: bucket id of each entry
    std::vector<TI> rows;     // pending: row number of each entry
    std::vector<size_t> lims; // bucket boundaries, valid after bucket_sort()

    explicit ToWrite(TI nbucket) : nbucket(nbucket) {
        lims.resize(nbucket + 1);
    }

    /// add one element (row) to write in bucket b
    void add(TI row, TI b) {
        assert(b >= 0 && b < nbucket);
        rows.push_back(row);
        buckets.push_back(b);
    }

    // group the pending rows by bucket: fills lims, reorders rows
    // accordingly and clears buckets
    void bucket_sort() {
        FAISS_THROW_IF_NOT(buckets.size() == rows.size());
        lims.assign(nbucket + 1, 0);

        // occurrence count per bucket, shifted by one
        for (TI b : buckets) {
            assert(b >= 0 && b < nbucket);
            lims[b + 1]++;
        }
        // prefix sum -> bucket offsets
        for (size_t i = 0; i < nbucket; i++) {
            lims[i + 1] += lims[i];
        }
        FAISS_THROW_IF_NOT(lims[nbucket] == buckets.size());

        // scatter rows to their buckets (could also do a circular perm...)
        std::vector<TI> sorted_rows(rows.size());
        std::vector<size_t> write_ptr = lims;
        for (size_t i = 0; i < buckets.size(); i++) {
            TI b = buckets[i];
            assert(write_ptr[b] < lims[b + 1]);
            sorted_rows[write_ptr[b]++] = rows[i];
        }
        buckets.resize(0);
        std::swap(rows, sorted_rows);
    }

    // exchange contents with another queue over the same bucket count
    void swap(ToWrite& other) {
        assert(nbucket == other.nbucket);
        buckets.swap(other.buckets);
        rows.swap(other.rows);
        lims.swap(other.lims);
    }
};
455 | | |
// Parallel in-place bucket sort of an nrow x ncol matrix of bucket ids.
// Each thread owns a contiguous range of buckets [b0, b1) and repeatedly:
// writes rows queued for its buckets (overwriting slots there), while
// saving the displaced values into a fresh ToWrite queue for the next
// round. Queues are exchanged at a barrier between rounds; the process
// stops when no queue has anything left. A slot holding -1 was already
// consumed (seeded by the initial -1 writes from the master thread).
template <class TI>
void bucket_sort_inplace_parallel(
        size_t nrow,
        size_t ncol,
        TI* vals,
        TI nbucket,
        int64_t* lims,
        int nt_in) {
    int verbose = bucket_sort_verbose;
    memset(lims, 0, sizeof(*lims) * (nbucket + 1));
    std::vector<ToWrite<TI>> all_to_write;
    size_t nval = nrow * ncol;
    FAISS_THROW_IF_NOT(
            nbucket < nval); // unclear what would happen in this case...

    // try to keep size of all_to_write < 5GiB
    // but we need at least one element per bucket
    size_t init_to_write = std::max(
            size_t(nbucket),
            std::min(nval / 10, ((size_t)5 << 30) / (sizeof(TI) * 3 * nt_in)));
    if (verbose > 0) {
        printf("init_to_write=%zd\n", init_to_write);
    }

    std::vector<size_t> ptrs(nbucket); // ptrs is shared across all threads
    std::vector<char> did_wrap(
            nbucket); // DON'T use std::vector<bool> that cannot be accessed
                      // safely from multiple threads!!!

#pragma omp parallel num_threads(nt_in)
    {
        int nt = omp_get_num_threads(); // might be different from nt_in (?)
        int rank = omp_get_thread_num();
        std::vector<int64_t> local_lims(nbucket + 1);

        // range of indices handled by this thread
        size_t i0 = nval * rank / nt;
        size_t i1 = nval * (rank + 1) / nt;

        // build histogram in local lims
        for (size_t i = i0; i < i1; i++) {
            local_lims[vals[i]]++;
        }
#pragma omp critical
        { // accumulate histograms (not shifted indices to prepare cumsum)
            for (size_t i = 0; i < nbucket; i++) {
                lims[i + 1] += local_lims[i];
            }
            all_to_write.push_back(ToWrite<TI>(nbucket));
        }

#pragma omp barrier
        // this thread's things to write
        ToWrite<TI>& to_write = all_to_write[rank];

#pragma omp master
        {
            // compute cumulative sum
            for (size_t i = 0; i < nbucket; i++) {
                lims[i + 1] += lims[i];
            }
            FAISS_THROW_IF_NOT(lims[nbucket] == nval);
            // at this point lims is final (read only!)

            memcpy(ptrs.data(), lims, sizeof(lims[0]) * nbucket);

            // initial values to write (we write -1s to get the process running)
            // make sure at least one element per bucket
            size_t written = 0;
            for (TI b = 0; b < nbucket; b++) {
                size_t l0 = lims[b], l1 = lims[b + 1];
                size_t target_to_write = l1 * init_to_write / nval;
                do {
                    if (l0 == l1) {
                        break;
                    }
                    to_write.add(-1, b);
                    l0++;
                    written++;
                } while (written < target_to_write);
            }

            to_write.bucket_sort();
        }

        // this thread writes only buckets b0:b1
        size_t b0 = (rank * nbucket + nt - 1) / nt;
        size_t b1 = ((rank + 1) * nbucket + nt - 1) / nt;

        // in this loop, we write elements collected in the previous round
        // and collect the elements that are overwritten for the next round
        int round = 0;
        for (;;) {
#pragma omp barrier

            // total pending writes across all threads (termination test)
            size_t n_to_write = 0;
            for (const ToWrite<TI>& to_write_2 : all_to_write) {
                n_to_write += to_write_2.lims.back();
            }

#pragma omp master
            {
                if (verbose >= 1) {
                    printf("ROUND %d n_to_write=%zd\n", round, n_to_write);
                }
                if (verbose > 2) {
                    // dump the current state of all buckets and queues
                    for (size_t b = 0; b < nbucket; b++) {
                        printf(" b=%zd [", b);
                        for (size_t i = lims[b]; i < lims[b + 1]; i++) {
                            printf(" %s%d",
                                   ptrs[b] == i ? ">" : "",
                                   int(vals[i]));
                        }
                        printf(" %s] %s\n",
                               ptrs[b] == lims[b + 1] ? ">" : "",
                               did_wrap[b] ? "w" : "");
                    }
                    printf("To write\n");
                    for (size_t b = 0; b < nbucket; b++) {
                        printf(" b=%zd ", b);
                        const char* sep = "[";
                        for (const ToWrite<TI>& to_write_2 : all_to_write) {
                            printf("%s", sep);
                            sep = " |";
                            size_t l0 = to_write_2.lims[b];
                            size_t l1 = to_write_2.lims[b + 1];
                            for (size_t i = l0; i < l1; i++) {
                                printf(" %d", int(to_write_2.rows[i]));
                            }
                        }
                        printf(" ]\n");
                    }
                }
            }
            if (n_to_write == 0) {
                break;
            }
            round++;

#pragma omp barrier

            ToWrite<TI> next_to_write(nbucket);

            // drain every thread's queue, but only for the buckets we own
            for (size_t b = b0; b < b1; b++) {
                for (const ToWrite<TI>& to_write_2 : all_to_write) {
                    size_t l0 = to_write_2.lims[b];
                    size_t l1 = to_write_2.lims[b + 1];
                    for (size_t i = l0; i < l1; i++) {
                        TI row = to_write_2.rows[i];
                        size_t idx = ptrs[b];
                        if (verbose > 2) {
                            printf(" bucket %d (rank %d) idx %zd\n",
                                   int(row),
                                   rank,
                                   idx);
                        }
                        if (idx < lims[b + 1]) {
                            ptrs[b]++;
                        } else {
                            // wrapping around
                            // (each bucket wraps at most once, per the assert)
                            assert(!did_wrap[b]);
                            did_wrap[b] = true;
                            idx = lims[b];
                            ptrs[b] = idx + 1;
                        }

                        // check if we need to remember the overwritten number
                        if (vals[idx] >= 0) {
                            TI new_row = idx / ncol;
                            next_to_write.add(new_row, vals[idx]);
                            if (verbose > 2) {
                                printf(" new_row=%d\n", int(new_row));
                            }
                        } else {
                            assert(did_wrap[b]);
                        }

                        vals[idx] = row;
                    }
                }
            }
            next_to_write.bucket_sort();
#pragma omp barrier
            // publish this thread's queue for the next round
            all_to_write[rank].swap(next_to_write);
        }
    }
}
643 | | |
644 | | } // anonymous namespace |
645 | | |
646 | | void bucket_sort( |
647 | | size_t nval, |
648 | | const uint64_t* vals, |
649 | | uint64_t vmax, |
650 | | int64_t* lims, |
651 | | int64_t* perm, |
652 | 0 | int nt) { |
653 | 0 | if (nt == 0) { |
654 | 0 | bucket_sort_ref(nval, vals, vmax, lims, perm); |
655 | 0 | } else { |
656 | 0 | bucket_sort_parallel(nval, vals, vmax, lims, perm, nt); |
657 | 0 | } |
658 | 0 | } |
659 | | |
660 | | void matrix_bucket_sort_inplace( |
661 | | size_t nrow, |
662 | | size_t ncol, |
663 | | int32_t* vals, |
664 | | int32_t vmax, |
665 | | int64_t* lims, |
666 | 0 | int nt) { |
667 | 0 | if (nt == 0) { |
668 | 0 | bucket_sort_inplace_ref(nrow, ncol, vals, vmax, lims); |
669 | 0 | } else { |
670 | 0 | bucket_sort_inplace_parallel(nrow, ncol, vals, vmax, lims, nt); |
671 | 0 | } |
672 | 0 | } |
673 | | |
674 | | void matrix_bucket_sort_inplace( |
675 | | size_t nrow, |
676 | | size_t ncol, |
677 | | int64_t* vals, |
678 | | int64_t vmax, |
679 | | int64_t* lims, |
680 | 0 | int nt) { |
681 | 0 | if (nt == 0) { |
682 | 0 | bucket_sort_inplace_ref(nrow, ncol, vals, vmax, lims); |
683 | 0 | } else { |
684 | 0 | bucket_sort_inplace_parallel(nrow, ncol, vals, vmax, lims, nt); |
685 | 0 | } |
686 | 0 | } |
687 | | |
688 | | /** Hashtable implementation for int64 -> int64 with external storage |
689 | | * implemented for speed and parallel processing. |
690 | | */ |
691 | | |
692 | | namespace { |
693 | | |
// Number of parallelization buckets (as log2) for a hash table of
// 2^log2_capacity slots: small tables get a single bucket, the bucket
// count grows with capacity and is capped at 2^10.
int log2_capacity_to_log2_nbucket(int log2_capacity) {
    if (log2_capacity < 12) {
        return 0;
    }
    if (log2_capacity < 20) {
        return log2_capacity - 12;
    }
    return 10;
}
699 | | |
// https://bigprimes.org/
int64_t bigprime = 8955327411143;

// Multiplicative hash into [0, bigprime). The multiply is done in
// uint64_t: the previous signed int64 multiply overflowed (UB) for
// |x| > ~9.2e12. For keys where the old multiply did not overflow the
// result is unchanged; wrap-around is fine for a hash as long as add
// and lookup agree (they both call this function).
inline int64_t hash_function(int64_t x) {
    return (int64_t)(((uint64_t)x * 1000003ULL) % (uint64_t)bigprime);
}
706 | | |
707 | | } // anonymous namespace |
708 | | |
// Initializes the hash table: 2^log2_capacity slots of two int64 each
// (key, value), all set to -1 (= empty).
void hashtable_int64_to_int64_init(int log2_capacity, int64_t* tab) {
    int64_t nslot = (int64_t)2 << log2_capacity; // 2 entries per slot
#pragma omp parallel for
    for (int64_t i = 0; i < nslot; i++) {
        tab[i] = -1;
    }
}
717 | | |
// Inserts n (key, value) pairs into the open-addressing table tab
// (2^log2_capacity slots of {key, value}; key == -1 means empty).
// The slot space is split into buckets (high bits of the hash) and the
// inputs are grouped per bucket with bucket_sort, so each bucket is
// filled by exactly one thread -- no locking on the table itself.
// Collisions are resolved by linear probing that wraps within the
// bucket; a full bucket raises "hashtable capacity exhausted".
void hashtable_int64_to_int64_add(
        int log2_capacity,
        int64_t* tab,
        size_t n,
        const int64_t* keys,
        const int64_t* vals) {
    size_t capacity = (size_t)1 << log2_capacity;
    std::vector<int64_t> hk(n);         // hash of each key, in [0, capacity)
    std::vector<uint64_t> bucket_no(n); // bucket = high bits of the hash
    int64_t mask = capacity - 1;
    int log2_nbucket = log2_capacity_to_log2_nbucket(log2_capacity);
    size_t nbucket = (size_t)1 << log2_nbucket;

#pragma omp parallel for
    for (int64_t i = 0; i < n; i++) {
        hk[i] = hash_function(keys[i]) & mask;
        bucket_no[i] = hk[i] >> (log2_capacity - log2_nbucket);
    }

    // group the input indices by bucket
    std::vector<int64_t> lims(nbucket + 1);
    std::vector<int64_t> perm(n);
    bucket_sort(
            n,
            bucket_no.data(),
            nbucket,
            lims.data(),
            perm.data(),
            omp_get_max_threads());

    int num_errors = 0;
#pragma omp parallel for reduction(+ : num_errors)
    for (int64_t bucket = 0; bucket < nbucket; bucket++) {
        // [k0, k1) = slot range owned by this bucket; probing wraps
        // around inside this range only
        size_t k0 = bucket << (log2_capacity - log2_nbucket);
        size_t k1 = (bucket + 1) << (log2_capacity - log2_nbucket);

        for (size_t i = lims[bucket]; i < lims[bucket + 1]; i++) {
            int64_t j = perm[i];
            assert(bucket_no[j] == bucket);
            assert(hk[j] >= k0 && hk[j] < k1);
            size_t slot = hk[j];
            // linear probing within the bucket
            for (;;) {
                if (tab[slot * 2] == -1) { // found!
                    tab[slot * 2] = keys[j];
                    tab[slot * 2 + 1] = vals[j];
                    break;
                } else if (tab[slot * 2] == keys[j]) { // overwrite!
                    tab[slot * 2 + 1] = vals[j];
                    break;
                }
                slot++;
                if (slot == k1) {
                    slot = k0;
                }
                if (slot == hk[j]) { // no free slot left in bucket
                    num_errors++;
                    break;
                }
            }
            if (num_errors > 0) {
                break;
            }
        }
    }
    FAISS_THROW_IF_NOT_MSG(num_errors == 0, "hashtable capacity exhausted");
}
783 | | |
784 | | void hashtable_int64_to_int64_lookup( |
785 | | int log2_capacity, |
786 | | const int64_t* tab, |
787 | | size_t n, |
788 | | const int64_t* keys, |
789 | 0 | int64_t* vals) { |
790 | 0 | size_t capacity = (size_t)1 << log2_capacity; |
791 | 0 | std::vector<int64_t> hk(n), bucket_no(n); |
792 | 0 | int64_t mask = capacity - 1; |
793 | 0 | int log2_nbucket = log2_capacity_to_log2_nbucket(log2_capacity); |
794 | |
|
795 | 0 | #pragma omp parallel for |
796 | 0 | for (int64_t i = 0; i < n; i++) { |
797 | 0 | int64_t k = keys[i]; |
798 | 0 | int64_t hk = hash_function(k) & mask; |
799 | 0 | size_t slot = hk; |
800 | |
|
801 | 0 | if (tab[2 * slot] == -1) { // not in table |
802 | 0 | vals[i] = -1; |
803 | 0 | } else if (tab[2 * slot] == k) { // found! |
804 | 0 | vals[i] = tab[2 * slot + 1]; |
805 | 0 | } else { // need to search in [k0, k1) |
806 | 0 | size_t bucket = hk >> (log2_capacity - log2_nbucket); |
807 | 0 | size_t k0 = bucket << (log2_capacity - log2_nbucket); |
808 | 0 | size_t k1 = (bucket + 1) << (log2_capacity - log2_nbucket); |
809 | 0 | for (;;) { |
810 | 0 | if (tab[slot * 2] == k) { // found! |
811 | 0 | vals[i] = tab[2 * slot + 1]; |
812 | 0 | break; |
813 | 0 | } |
814 | 0 | slot++; |
815 | 0 | if (slot == k1) { |
816 | 0 | slot = k0; |
817 | 0 | } |
818 | 0 | if (slot == hk) { // bucket is full and not found |
819 | 0 | vals[i] = -1; |
820 | 0 | break; |
821 | 0 | } |
822 | 0 | } |
823 | 0 | } |
824 | 0 | } |
825 | 0 | } |
826 | | |
827 | | } // namespace faiss |