Coverage Report

Created: 2026-03-20 13:45

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
contrib/faiss/faiss/utils/sorting.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
// -*- c++ -*-
9
10
#include <faiss/utils/sorting.h>
11
12
#include <omp.h>
13
#include <algorithm>
14
15
#include <faiss/impl/FaissAssert.h>
16
#include <faiss/utils/utils.h>
17
18
namespace faiss {
19
20
/*****************************************************************************
21
 * Argsort
22
 ****************************************************************************/
23
24
namespace {
25
/// Strict weak ordering on permutation indices: index `lhs` sorts before
/// index `rhs` when the float it refers to is smaller. `vals` is borrowed,
/// not owned, and must outlive the comparator.
struct ArgsortComparator {
    const float* vals;

    bool operator()(const size_t lhs, const size_t rhs) const {
        const float lhs_val = vals[lhs];
        const float rhs_val = vals[rhs];
        return lhs_val < rhs_val;
    }
};
31
32
/// Half-open range [i0, i1) of positions in a permutation array.
struct SegmentS {
    size_t i0; // first position of the range (inclusive)
    size_t i1; // one past the last position of the range

    /// Number of positions covered by the range.
    size_t len() const {
        const size_t count = i1 - i0;
        return count;
    }
};
39
40
// see https://en.wikipedia.org/wiki/Merge_algorithm#Parallel_merge
41
// extended to > 1 merge thread
42
43
// merges 2 ranges that should be consecutive on the source into
44
// the union of the two on the destination
45
template <typename T>
46
void parallel_merge(
47
        const T* src,
48
        T* dst,
49
        SegmentS& s1,
50
        SegmentS& s2,
51
        int nt,
52
0
        const ArgsortComparator& comp) {
53
0
    if (s2.len() > s1.len()) { // make sure that s1 larger than s2
54
0
        std::swap(s1, s2);
55
0
    }
56
57
    // compute sub-ranges for each thread
58
0
    std::vector<SegmentS> s1s(nt), s2s(nt), sws(nt);
59
0
    s2s[0].i0 = s2.i0;
60
0
    s2s[nt - 1].i1 = s2.i1;
61
62
    // not sure parallel actually helps here
63
0
#pragma omp parallel for num_threads(nt)
64
0
    for (int t = 0; t < nt; t++) {
65
0
        s1s[t].i0 = s1.i0 + s1.len() * t / nt;
66
0
        s1s[t].i1 = s1.i0 + s1.len() * (t + 1) / nt;
67
68
0
        if (t + 1 < nt) {
69
0
            T pivot = src[s1s[t].i1];
70
0
            size_t i0 = s2.i0, i1 = s2.i1;
71
0
            while (i0 + 1 < i1) {
72
0
                size_t imed = (i1 + i0) / 2;
73
0
                if (comp(pivot, src[imed])) {
74
0
                    i1 = imed;
75
0
                } else {
76
0
                    i0 = imed;
77
0
                }
78
0
            }
79
0
            s2s[t].i1 = s2s[t + 1].i0 = i1;
80
0
        }
81
0
    }
82
0
    s1.i0 = std::min(s1.i0, s2.i0);
83
0
    s1.i1 = std::max(s1.i1, s2.i1);
84
0
    s2 = s1;
85
0
    sws[0].i0 = s1.i0;
86
0
    for (int t = 0; t < nt; t++) {
87
0
        sws[t].i1 = sws[t].i0 + s1s[t].len() + s2s[t].len();
88
0
        if (t + 1 < nt) {
89
0
            sws[t + 1].i0 = sws[t].i1;
90
0
        }
91
0
    }
92
0
    assert(sws[nt - 1].i1 == s1.i1);
93
94
    // do the actual merging
95
0
#pragma omp parallel for num_threads(nt)
96
0
    for (int t = 0; t < nt; t++) {
97
0
        SegmentS sw = sws[t];
98
0
        SegmentS s1t = s1s[t];
99
0
        SegmentS s2t = s2s[t];
100
0
        if (s1t.i0 < s1t.i1 && s2t.i0 < s2t.i1) {
101
0
            for (;;) {
102
                // assert (sw.len() == s1t.len() + s2t.len());
103
0
                if (comp(src[s1t.i0], src[s2t.i0])) {
104
0
                    dst[sw.i0++] = src[s1t.i0++];
105
0
                    if (s1t.i0 == s1t.i1) {
106
0
                        break;
107
0
                    }
108
0
                } else {
109
0
                    dst[sw.i0++] = src[s2t.i0++];
110
0
                    if (s2t.i0 == s2t.i1) {
111
0
                        break;
112
0
                    }
113
0
                }
114
0
            }
115
0
        }
116
0
        if (s1t.len() > 0) {
117
0
            assert(s1t.len() == sw.len());
118
0
            memcpy(dst + sw.i0, src + s1t.i0, s1t.len() * sizeof(dst[0]));
119
0
        } else if (s2t.len() > 0) {
120
0
            assert(s2t.len() == sw.len());
121
0
            memcpy(dst + sw.i0, src + s2t.i0, s2t.len() * sizeof(dst[0]));
122
0
        }
123
0
    }
124
0
}
125
126
} // namespace
127
128
0
/// Argsort of `n` floats: fills perm[0..n) with the indices 0..n-1 ordered
/// so that vals[perm[0]] <= vals[perm[1]] <= ... (the relative order of
/// equal values is not specified, std::sort is not stable).
void fvec_argsort(size_t n, const float* vals, size_t* perm) {
    size_t* const perm_end = perm + n;

    // identity permutation
    size_t next_index = 0;
    for (size_t* p = perm; p != perm_end; ++p) {
        *p = next_index++;
    }

    std::sort(perm, perm_end, [vals](const size_t a, const size_t b) {
        return vals[a] < vals[b];
    });
}
135
136
0
void fvec_argsort_parallel(size_t n, const float* vals, size_t* perm) {
137
0
    size_t* perm2 = new size_t[n];
138
    // 2 result tables, during merging, flip between them
139
0
    size_t *permB = perm2, *permA = perm;
140
141
0
    int nt = omp_get_max_threads();
142
0
    { // prepare correct permutation so that the result ends in perm
143
      // at final iteration
144
0
        int nseg = nt;
145
0
        while (nseg > 1) {
146
0
            nseg = (nseg + 1) / 2;
147
0
            std::swap(permA, permB);
148
0
        }
149
0
    }
150
151
0
#pragma omp parallel
152
0
    for (size_t i = 0; i < n; i++) {
153
0
        permA[i] = i;
154
0
    }
155
156
0
    ArgsortComparator comp = {vals};
157
158
0
    std::vector<SegmentS> segs(nt);
159
160
    // independent sorts
161
0
#pragma omp parallel for
162
0
    for (int t = 0; t < nt; t++) {
163
0
        size_t i0 = t * n / nt;
164
0
        size_t i1 = (t + 1) * n / nt;
165
0
        SegmentS seg = {i0, i1};
166
0
        std::sort(permA + seg.i0, permA + seg.i1, comp);
167
0
        segs[t] = seg;
168
0
    }
169
0
    int prev_nested = omp_get_nested();
170
0
    omp_set_nested(1);
171
172
0
    int nseg = nt;
173
0
    while (nseg > 1) {
174
0
        int nseg1 = (nseg + 1) / 2;
175
0
        int sub_nt = nseg % 2 == 0 ? nt : nt - 1;
176
0
        int sub_nseg1 = nseg / 2;
177
178
0
#pragma omp parallel for num_threads(nseg1)
179
0
        for (int s = 0; s < nseg; s += 2) {
180
0
            if (s + 1 == nseg) { // otherwise isolated segment
181
0
                memcpy(permB + segs[s].i0,
182
0
                       permA + segs[s].i0,
183
0
                       segs[s].len() * sizeof(size_t));
184
0
            } else {
185
0
                int t0 = s * sub_nt / sub_nseg1;
186
0
                int t1 = (s + 1) * sub_nt / sub_nseg1;
187
0
                printf("merge %d %d, %d threads\n", s, s + 1, t1 - t0);
188
0
                parallel_merge(
189
0
                        permA, permB, segs[s], segs[s + 1], t1 - t0, comp);
190
0
            }
191
0
        }
192
0
        for (int s = 0; s < nseg; s += 2) {
193
0
            segs[s / 2] = segs[s];
194
0
        }
195
0
        nseg = nseg1;
196
0
        std::swap(permA, permB);
197
0
    }
198
0
    assert(permA == perm);
199
0
    omp_set_nested(prev_nested);
200
0
    delete[] perm2;
201
0
}
202
203
/*****************************************************************************
204
 * Bucket sort
205
 ****************************************************************************/
206
207
// extern symbol in the .h
208
// Verbosity switch read by the bucket-sort routines in this file.
// bucket_sort_ref, bucket_sort_parallel and bucket_sort_inplace_ref print
// per-phase timings with printf when it is nonzero.
// bucket_sort_inplace_parallel prints progress for values >= 1 and a dump
// of its internal state for values > 2.
int bucket_sort_verbose = 0;
209
210
namespace {
211
212
void bucket_sort_ref(
213
        size_t nval,
214
        const uint64_t* vals,
215
        uint64_t vmax,
216
        int64_t* lims,
217
0
        int64_t* perm) {
218
0
    double t0 = getmillisecs();
219
0
    memset(lims, 0, sizeof(*lims) * (vmax + 1));
220
0
    for (size_t i = 0; i < nval; i++) {
221
0
        FAISS_THROW_IF_NOT(vals[i] < vmax);
222
0
        lims[vals[i] + 1]++;
223
0
    }
224
0
    double t1 = getmillisecs();
225
    // compute cumulative sum
226
0
    for (size_t i = 0; i < vmax; i++) {
227
0
        lims[i + 1] += lims[i];
228
0
    }
229
0
    FAISS_THROW_IF_NOT(lims[vmax] == nval);
230
0
    double t2 = getmillisecs();
231
    // populate buckets
232
0
    for (size_t i = 0; i < nval; i++) {
233
0
        perm[lims[vals[i]]++] = i;
234
0
    }
235
0
    double t3 = getmillisecs();
236
    // reset pointers
237
0
    for (size_t i = vmax; i > 0; i--) {
238
0
        lims[i] = lims[i - 1];
239
0
    }
240
0
    lims[0] = 0;
241
0
    double t4 = getmillisecs();
242
0
    if (bucket_sort_verbose) {
243
0
        printf("times %.3f %.3f %.3f %.3f\n",
244
0
               t1 - t0,
245
0
               t2 - t1,
246
0
               t3 - t2,
247
0
               t4 - t3);
248
0
    }
249
0
}
250
251
void bucket_sort_parallel(
252
        size_t nval,
253
        const uint64_t* vals,
254
        uint64_t vmax,
255
        int64_t* lims,
256
        int64_t* perm,
257
0
        int nt_in) {
258
0
    memset(lims, 0, sizeof(*lims) * (vmax + 1));
259
0
#pragma omp parallel num_threads(nt_in)
260
0
    {
261
0
        int nt = omp_get_num_threads(); // might be different from nt_in
262
0
        int rank = omp_get_thread_num();
263
0
        std::vector<int64_t> local_lims(vmax + 1);
264
265
        // range of indices handled by this thread
266
0
        size_t i0 = nval * rank / nt;
267
0
        size_t i1 = nval * (rank + 1) / nt;
268
269
        // build histogram in local lims
270
0
        double t0 = getmillisecs();
271
0
        for (size_t i = i0; i < i1; i++) {
272
0
            local_lims[vals[i]]++;
273
0
        }
274
0
#pragma omp critical
275
0
        { // accumulate histograms (not shifted indices to prepare cumsum)
276
0
            for (size_t i = 0; i < vmax; i++) {
277
0
                lims[i + 1] += local_lims[i];
278
0
            }
279
0
        }
280
0
#pragma omp barrier
281
282
0
        double t1 = getmillisecs();
283
0
#pragma omp master
284
0
        {
285
            // compute cumulative sum
286
0
            for (size_t i = 0; i < vmax; i++) {
287
0
                lims[i + 1] += lims[i];
288
0
            }
289
0
            FAISS_THROW_IF_NOT(lims[vmax] == nval);
290
0
        }
291
0
#pragma omp barrier
292
293
0
#pragma omp critical
294
0
        { // current thread grabs a slot in the buckets
295
0
            for (size_t i = 0; i < vmax; i++) {
296
0
                size_t nv = local_lims[i];
297
0
                local_lims[i] = lims[i]; // where we should start writing
298
0
                lims[i] += nv;
299
0
            }
300
0
        }
301
302
0
        double t2 = getmillisecs();
303
0
#pragma omp barrier
304
0
        { // populate buckets, this is the slowest operation
305
0
            for (size_t i = i0; i < i1; i++) {
306
0
                perm[local_lims[vals[i]]++] = i;
307
0
            }
308
0
        }
309
0
#pragma omp barrier
310
0
        double t3 = getmillisecs();
311
312
0
#pragma omp master
313
0
        { // shift back lims
314
0
            for (size_t i = vmax; i > 0; i--) {
315
0
                lims[i] = lims[i - 1];
316
0
            }
317
0
            lims[0] = 0;
318
0
            double t4 = getmillisecs();
319
0
            if (bucket_sort_verbose) {
320
0
                printf("times %.3f %.3f %.3f %.3f\n",
321
0
                       t1 - t0,
322
0
                       t2 - t1,
323
0
                       t3 - t2,
324
0
                       t4 - t3);
325
0
            }
326
0
        }
327
0
    }
328
0
}
329
330
/***********************************************
331
 * in-place bucket sort
332
 */
333
334
template <class TI>
335
void bucket_sort_inplace_ref(
336
        size_t nrow,
337
        size_t ncol,
338
        TI* vals,
339
        TI nbucket,
340
0
        int64_t* lims) {
341
0
    double t0 = getmillisecs();
342
0
    size_t nval = nrow * ncol;
343
0
    FAISS_THROW_IF_NOT(
344
0
            nbucket < nval); // unclear what would happen in this case...
345
346
0
    memset(lims, 0, sizeof(*lims) * (nbucket + 1));
347
0
    for (size_t i = 0; i < nval; i++) {
348
0
        FAISS_THROW_IF_NOT(vals[i] < nbucket);
349
0
        lims[vals[i] + 1]++;
350
0
    }
351
0
    double t1 = getmillisecs();
352
    // compute cumulative sum
353
0
    for (size_t i = 0; i < nbucket; i++) {
354
0
        lims[i + 1] += lims[i];
355
0
    }
356
0
    FAISS_THROW_IF_NOT(lims[nbucket] == nval);
357
0
    double t2 = getmillisecs();
358
359
0
    std::vector<size_t> ptrs(nbucket);
360
0
    for (size_t i = 0; i < nbucket; i++) {
361
0
        ptrs[i] = lims[i];
362
0
    }
363
364
    // find loops in the permutation and follow them
365
0
    TI row = -1;
366
0
    TI init_bucket_no = 0, bucket_no = 0;
367
0
    for (;;) {
368
0
        size_t idx = ptrs[bucket_no];
369
0
        if (row >= 0) {
370
0
            ptrs[bucket_no] += 1;
371
0
        }
372
0
        assert(idx < lims[bucket_no + 1]);
373
0
        TI next_bucket_no = vals[idx];
374
0
        vals[idx] = row;
375
0
        if (next_bucket_no != -1) {
376
0
            row = idx / ncol;
377
0
            bucket_no = next_bucket_no;
378
0
        } else {
379
            // start new loop
380
0
            for (; init_bucket_no < nbucket; init_bucket_no++) {
381
0
                if (ptrs[init_bucket_no] < lims[init_bucket_no + 1]) {
382
0
                    break;
383
0
                }
384
0
            }
385
0
            if (init_bucket_no == nbucket) { // we're done
386
0
                break;
387
0
            }
388
0
            bucket_no = init_bucket_no;
389
0
            row = -1;
390
0
        }
391
0
    }
392
393
0
    for (size_t i = 0; i < nbucket; i++) {
394
0
        assert(ptrs[i] == lims[i + 1]);
395
0
    }
396
0
    double t3 = getmillisecs();
397
0
    if (bucket_sort_verbose) {
398
0
        printf("times %.3f %.3f %.3f\n", t1 - t0, t2 - t1, t3 - t2);
399
0
    }
400
0
}
Unexecuted instantiation: sorting.cpp:_ZN5faiss12_GLOBAL__N_123bucket_sort_inplace_refIiEEvmmPT_S2_Pl
Unexecuted instantiation: sorting.cpp:_ZN5faiss12_GLOBAL__N_123bucket_sort_inplace_refIlEEvmmPT_S2_Pl
401
402
// collects row numbers to write into buckets
403
/// Buffer of (row, bucket) pairs that still have to be written into bucket
/// slots, as used by bucket_sort_inplace_parallel. After bucket_sort(),
/// the rows destined for bucket b are rows[lims[b] .. lims[b + 1]).
template <class TI>
struct ToWrite {
    TI nbucket;
    std::vector<TI> buckets;  // target bucket of each pending row
    std::vector<TI> rows;     // pending row numbers
    std::vector<size_t> lims; // per-bucket offsets into rows (nbucket + 1)

    explicit ToWrite(TI nbucket) : nbucket(nbucket), lims(nbucket + 1) {}

    /// Queue `row` for bucket `b`.
    void add(TI row, TI b) {
        assert(b >= 0 && b < nbucket);
        rows.push_back(row);
        buckets.push_back(b);
    }

    /// Group the queued rows by bucket (stable counting sort), fill `lims`
    /// and drop the per-row bucket ids.
    void bucket_sort() {
        FAISS_THROW_IF_NOT(buckets.size() == rows.size());
        lims.assign(nbucket + 1, size_t(0));

        // histogram shifted by one slot so the prefix sum gives bucket starts
        for (const TI b : buckets) {
            assert(b >= 0 && b < nbucket);
            lims[b + 1]++;
        }
        // compute cumulative sum
        for (size_t b = 1; b <= static_cast<size_t>(nbucket); b++) {
            lims[b] += lims[b - 1];
        }
        FAISS_THROW_IF_NOT(lims[nbucket] == buckets.size());

        // could also do a circular perm...
        std::vector<TI> sorted_rows(rows.size());
        std::vector<size_t> cursor(lims);
        for (size_t k = 0; k < buckets.size(); k++) {
            size_t& pos = cursor[buckets[k]];
            assert(pos < lims[buckets[k] + 1]);
            sorted_rows[pos++] = rows[k];
        }
        buckets.clear();
        rows.swap(sorted_rows);
    }

    /// Exchange contents with `other` (both must use the same nbucket).
    void swap(ToWrite& other) {
        assert(nbucket == other.nbucket);
        std::swap(buckets, other.buckets);
        std::swap(rows, other.rows);
        std::swap(lims, other.lims);
    }
};
455
456
template <class TI>
457
void bucket_sort_inplace_parallel(
458
        size_t nrow,
459
        size_t ncol,
460
        TI* vals,
461
        TI nbucket,
462
        int64_t* lims,
463
0
        int nt_in) {
464
0
    int verbose = bucket_sort_verbose;
465
0
    memset(lims, 0, sizeof(*lims) * (nbucket + 1));
466
0
    std::vector<ToWrite<TI>> all_to_write;
467
0
    size_t nval = nrow * ncol;
468
0
    FAISS_THROW_IF_NOT(
469
0
            nbucket < nval); // unclear what would happen in this case...
470
471
    // try to keep size of all_to_write < 5GiB
472
    // but we need at least one element per bucket
473
0
    size_t init_to_write = std::max(
474
0
            size_t(nbucket),
475
0
            std::min(nval / 10, ((size_t)5 << 30) / (sizeof(TI) * 3 * nt_in)));
476
0
    if (verbose > 0) {
477
0
        printf("init_to_write=%zd\n", init_to_write);
478
0
    }
479
480
0
    std::vector<size_t> ptrs(nbucket); // ptrs is shared across all threads
481
0
    std::vector<char> did_wrap(
482
0
            nbucket); // DON'T use std::vector<bool> that cannot be accessed
483
                      // safely from multiple threads!!!
484
485
0
#pragma omp parallel num_threads(nt_in)
486
0
    {
487
0
        int nt = omp_get_num_threads(); // might be different from nt_in (?)
488
0
        int rank = omp_get_thread_num();
489
0
        std::vector<int64_t> local_lims(nbucket + 1);
490
491
        // range of indices handled by this thread
492
0
        size_t i0 = nval * rank / nt;
493
0
        size_t i1 = nval * (rank + 1) / nt;
494
495
        // build histogram in local lims
496
0
        for (size_t i = i0; i < i1; i++) {
497
0
            local_lims[vals[i]]++;
498
0
        }
499
0
#pragma omp critical
500
0
        { // accumulate histograms (not shifted indices to prepare cumsum)
501
0
            for (size_t i = 0; i < nbucket; i++) {
502
0
                lims[i + 1] += local_lims[i];
503
0
            }
504
0
            all_to_write.push_back(ToWrite<TI>(nbucket));
505
0
        }
506
507
0
#pragma omp barrier
508
        // this thread's things to write
509
0
        ToWrite<TI>& to_write = all_to_write[rank];
510
511
0
#pragma omp master
512
0
        {
513
            // compute cumulative sum
514
0
            for (size_t i = 0; i < nbucket; i++) {
515
0
                lims[i + 1] += lims[i];
516
0
            }
517
0
            FAISS_THROW_IF_NOT(lims[nbucket] == nval);
518
            // at this point lims is final (read only!)
519
520
0
            memcpy(ptrs.data(), lims, sizeof(lims[0]) * nbucket);
521
522
            // initial values to write (we write -1s to get the process running)
523
            // make sure at least one element per bucket
524
0
            size_t written = 0;
525
0
            for (TI b = 0; b < nbucket; b++) {
526
0
                size_t l0 = lims[b], l1 = lims[b + 1];
527
0
                size_t target_to_write = l1 * init_to_write / nval;
528
0
                do {
529
0
                    if (l0 == l1) {
530
0
                        break;
531
0
                    }
532
0
                    to_write.add(-1, b);
533
0
                    l0++;
534
0
                    written++;
535
0
                } while (written < target_to_write);
536
0
            }
537
538
0
            to_write.bucket_sort();
539
0
        }
540
541
        // this thread writes only buckets b0:b1
542
0
        size_t b0 = (rank * nbucket + nt - 1) / nt;
543
0
        size_t b1 = ((rank + 1) * nbucket + nt - 1) / nt;
544
545
        // in this loop, we write elements collected in the previous round
546
        // and collect the elements that are overwritten for the next round
547
0
        int round = 0;
548
0
        for (;;) {
549
0
#pragma omp barrier
550
551
0
            size_t n_to_write = 0;
552
0
            for (const ToWrite<TI>& to_write_2 : all_to_write) {
553
0
                n_to_write += to_write_2.lims.back();
554
0
            }
555
556
0
#pragma omp master
557
0
            {
558
0
                if (verbose >= 1) {
559
0
                    printf("ROUND %d n_to_write=%zd\n", round, n_to_write);
560
0
                }
561
0
                if (verbose > 2) {
562
0
                    for (size_t b = 0; b < nbucket; b++) {
563
0
                        printf("   b=%zd [", b);
564
0
                        for (size_t i = lims[b]; i < lims[b + 1]; i++) {
565
0
                            printf(" %s%d",
566
0
                                   ptrs[b] == i ? ">" : "",
567
0
                                   int(vals[i]));
568
0
                        }
569
0
                        printf(" %s] %s\n",
570
0
                               ptrs[b] == lims[b + 1] ? ">" : "",
571
0
                               did_wrap[b] ? "w" : "");
572
0
                    }
573
0
                    printf("To write\n");
574
0
                    for (size_t b = 0; b < nbucket; b++) {
575
0
                        printf("   b=%zd ", b);
576
0
                        const char* sep = "[";
577
0
                        for (const ToWrite<TI>& to_write_2 : all_to_write) {
578
0
                            printf("%s", sep);
579
0
                            sep = " |";
580
0
                            size_t l0 = to_write_2.lims[b];
581
0
                            size_t l1 = to_write_2.lims[b + 1];
582
0
                            for (size_t i = l0; i < l1; i++) {
583
0
                                printf(" %d", int(to_write_2.rows[i]));
584
0
                            }
585
0
                        }
586
0
                        printf(" ]\n");
587
0
                    }
588
0
                }
589
0
            }
590
0
            if (n_to_write == 0) {
591
0
                break;
592
0
            }
593
0
            round++;
594
595
0
#pragma omp barrier
596
597
0
            ToWrite<TI> next_to_write(nbucket);
598
599
0
            for (size_t b = b0; b < b1; b++) {
600
0
                for (const ToWrite<TI>& to_write_2 : all_to_write) {
601
0
                    size_t l0 = to_write_2.lims[b];
602
0
                    size_t l1 = to_write_2.lims[b + 1];
603
0
                    for (size_t i = l0; i < l1; i++) {
604
0
                        TI row = to_write_2.rows[i];
605
0
                        size_t idx = ptrs[b];
606
0
                        if (verbose > 2) {
607
0
                            printf("    bucket %d (rank %d) idx %zd\n",
608
0
                                   int(row),
609
0
                                   rank,
610
0
                                   idx);
611
0
                        }
612
0
                        if (idx < lims[b + 1]) {
613
0
                            ptrs[b]++;
614
0
                        } else {
615
                            // wrapping around
616
0
                            assert(!did_wrap[b]);
617
0
                            did_wrap[b] = true;
618
0
                            idx = lims[b];
619
0
                            ptrs[b] = idx + 1;
620
0
                        }
621
622
                        // check if we need to remember the overwritten number
623
0
                        if (vals[idx] >= 0) {
624
0
                            TI new_row = idx / ncol;
625
0
                            next_to_write.add(new_row, vals[idx]);
626
0
                            if (verbose > 2) {
627
0
                                printf("       new_row=%d\n", int(new_row));
628
0
                            }
629
0
                        } else {
630
0
                            assert(did_wrap[b]);
631
0
                        }
632
633
0
                        vals[idx] = row;
634
0
                    }
635
0
                }
636
0
            }
637
0
            next_to_write.bucket_sort();
638
0
#pragma omp barrier
639
0
            all_to_write[rank].swap(next_to_write);
640
0
        }
641
0
    }
Unexecuted instantiation: sorting.cpp:_ZN5faiss12_GLOBAL__N_128bucket_sort_inplace_parallelIiEEvmmPT_S2_Pli.omp_outlined_debug__
Unexecuted instantiation: sorting.cpp:_ZN5faiss12_GLOBAL__N_128bucket_sort_inplace_parallelIlEEvmmPT_S2_Pli.omp_outlined_debug__
642
0
}
Unexecuted instantiation: sorting.cpp:_ZN5faiss12_GLOBAL__N_128bucket_sort_inplace_parallelIiEEvmmPT_S2_Pli
Unexecuted instantiation: sorting.cpp:_ZN5faiss12_GLOBAL__N_128bucket_sort_inplace_parallelIlEEvmmPT_S2_Pli
643
644
} // anonymous namespace
645
646
void bucket_sort(
647
        size_t nval,
648
        const uint64_t* vals,
649
        uint64_t vmax,
650
        int64_t* lims,
651
        int64_t* perm,
652
0
        int nt) {
653
0
    if (nt == 0) {
654
0
        bucket_sort_ref(nval, vals, vmax, lims, perm);
655
0
    } else {
656
0
        bucket_sort_parallel(nval, vals, vmax, lims, perm, nt);
657
0
    }
658
0
}
659
660
void matrix_bucket_sort_inplace(
661
        size_t nrow,
662
        size_t ncol,
663
        int32_t* vals,
664
        int32_t vmax,
665
        int64_t* lims,
666
0
        int nt) {
667
0
    if (nt == 0) {
668
0
        bucket_sort_inplace_ref(nrow, ncol, vals, vmax, lims);
669
0
    } else {
670
0
        bucket_sort_inplace_parallel(nrow, ncol, vals, vmax, lims, nt);
671
0
    }
672
0
}
673
674
void matrix_bucket_sort_inplace(
675
        size_t nrow,
676
        size_t ncol,
677
        int64_t* vals,
678
        int64_t vmax,
679
        int64_t* lims,
680
0
        int nt) {
681
0
    if (nt == 0) {
682
0
        bucket_sort_inplace_ref(nrow, ncol, vals, vmax, lims);
683
0
    } else {
684
0
        bucket_sort_inplace_parallel(nrow, ncol, vals, vmax, lims, nt);
685
0
    }
686
0
}
687
688
/** Hashtable implementation for int64 -> int64 with external storage
689
 * implemented for speed and parallel processing.
690
 */
691
692
namespace {
693
694
0
/// log2 of the number of independently filled slot ranges ("buckets") for
/// a table of 2^log2_capacity slots:
/// - 0 (a single bucket) below 2^12 slots;
/// - log2_capacity - 12 from 2^12 to 2^19 slots, i.e. 2^12 slots per bucket;
/// - a fixed 10 (1024 buckets) from 2^20 slots on.
int log2_capacity_to_log2_nbucket(int log2_capacity) {
    if (log2_capacity >= 20) {
        return 10;
    }
    if (log2_capacity >= 12) {
        return log2_capacity - 12;
    }
    return 0;
}
699
700
// https://bigprimes.org/
701
// Large prime modulus used by hash_function. It is never modified, so it is
// declared constexpr. A compile-time-constant modulus also lets the compiler
// avoid a hardware division in the hash.
constexpr int64_t bigprime = 8955327411143;
702
703
0
inline int64_t hash_function(int64_t x) {
704
0
    return (x * 1000003) % bigprime;
705
0
}
706
707
} // anonymous namespace
708
709
0
/// Mark all 2^log2_capacity slots of `tab` as empty. Each slot is a
/// (key, value) pair of int64 stored at tab[2 * slot] and tab[2 * slot + 1],
/// so `tab` must hold 2 * capacity entries; -1 in both means "empty".
void hashtable_int64_to_int64_init(int log2_capacity, int64_t* tab) {
    const size_t capacity = size_t(1) << log2_capacity;
#pragma omp parallel for
    for (int64_t slot = 0; slot < capacity; slot++) {
        int64_t* entry = tab + 2 * slot;
        entry[0] = -1; // key
        entry[1] = -1; // value
    }
}
717
718
void hashtable_int64_to_int64_add(
719
        int log2_capacity,
720
        int64_t* tab,
721
        size_t n,
722
        const int64_t* keys,
723
0
        const int64_t* vals) {
724
0
    size_t capacity = (size_t)1 << log2_capacity;
725
0
    std::vector<int64_t> hk(n);
726
0
    std::vector<uint64_t> bucket_no(n);
727
0
    int64_t mask = capacity - 1;
728
0
    int log2_nbucket = log2_capacity_to_log2_nbucket(log2_capacity);
729
0
    size_t nbucket = (size_t)1 << log2_nbucket;
730
731
0
#pragma omp parallel for
732
0
    for (int64_t i = 0; i < n; i++) {
733
0
        hk[i] = hash_function(keys[i]) & mask;
734
0
        bucket_no[i] = hk[i] >> (log2_capacity - log2_nbucket);
735
0
    }
736
737
0
    std::vector<int64_t> lims(nbucket + 1);
738
0
    std::vector<int64_t> perm(n);
739
0
    bucket_sort(
740
0
            n,
741
0
            bucket_no.data(),
742
0
            nbucket,
743
0
            lims.data(),
744
0
            perm.data(),
745
0
            omp_get_max_threads());
746
747
0
    int num_errors = 0;
748
0
#pragma omp parallel for reduction(+ : num_errors)
749
0
    for (int64_t bucket = 0; bucket < nbucket; bucket++) {
750
0
        size_t k0 = bucket << (log2_capacity - log2_nbucket);
751
0
        size_t k1 = (bucket + 1) << (log2_capacity - log2_nbucket);
752
753
0
        for (size_t i = lims[bucket]; i < lims[bucket + 1]; i++) {
754
0
            int64_t j = perm[i];
755
0
            assert(bucket_no[j] == bucket);
756
0
            assert(hk[j] >= k0 && hk[j] < k1);
757
0
            size_t slot = hk[j];
758
0
            for (;;) {
759
0
                if (tab[slot * 2] == -1) { // found!
760
0
                    tab[slot * 2] = keys[j];
761
0
                    tab[slot * 2 + 1] = vals[j];
762
0
                    break;
763
0
                } else if (tab[slot * 2] == keys[j]) { // overwrite!
764
0
                    tab[slot * 2 + 1] = vals[j];
765
0
                    break;
766
0
                }
767
0
                slot++;
768
0
                if (slot == k1) {
769
0
                    slot = k0;
770
0
                }
771
0
                if (slot == hk[j]) { // no free slot left in bucket
772
0
                    num_errors++;
773
0
                    break;
774
0
                }
775
0
            }
776
0
            if (num_errors > 0) {
777
0
                break;
778
0
            }
779
0
        }
780
0
    }
781
0
    FAISS_THROW_IF_NOT_MSG(num_errors == 0, "hashtable capacity exhausted");
782
0
}
783
784
void hashtable_int64_to_int64_lookup(
785
        int log2_capacity,
786
        const int64_t* tab,
787
        size_t n,
788
        const int64_t* keys,
789
0
        int64_t* vals) {
790
0
    size_t capacity = (size_t)1 << log2_capacity;
791
0
    std::vector<int64_t> hk(n), bucket_no(n);
792
0
    int64_t mask = capacity - 1;
793
0
    int log2_nbucket = log2_capacity_to_log2_nbucket(log2_capacity);
794
795
0
#pragma omp parallel for
796
0
    for (int64_t i = 0; i < n; i++) {
797
0
        int64_t k = keys[i];
798
0
        int64_t hk = hash_function(k) & mask;
799
0
        size_t slot = hk;
800
801
0
        if (tab[2 * slot] == -1) { // not in table
802
0
            vals[i] = -1;
803
0
        } else if (tab[2 * slot] == k) { // found!
804
0
            vals[i] = tab[2 * slot + 1];
805
0
        } else { // need to search in [k0, k1)
806
0
            size_t bucket = hk >> (log2_capacity - log2_nbucket);
807
0
            size_t k0 = bucket << (log2_capacity - log2_nbucket);
808
0
            size_t k1 = (bucket + 1) << (log2_capacity - log2_nbucket);
809
0
            for (;;) {
810
0
                if (tab[slot * 2] == k) { // found!
811
0
                    vals[i] = tab[2 * slot + 1];
812
0
                    break;
813
0
                }
814
0
                slot++;
815
0
                if (slot == k1) {
816
0
                    slot = k0;
817
0
                }
818
0
                if (slot == hk) { // bucket is full and not found
819
0
                    vals[i] = -1;
820
0
                    break;
821
0
                }
822
0
            }
823
0
        }
824
0
    }
825
0
}
826
827
} // namespace faiss