Coverage Report

Created: 2025-11-21 14:14

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/impl/PolysemousTraining.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
// -*- c++ -*-
9
10
#include <faiss/impl/PolysemousTraining.h>
11
12
#include <omp.h>
13
#include <stdint.h>
14
15
#include <algorithm>
16
#include <cmath>
17
#include <cstdlib>
18
#include <cstring>
19
#include <memory>
20
21
#include <faiss/utils/distances.h>
22
#include <faiss/utils/hamming.h>
23
#include <faiss/utils/random.h>
24
#include <faiss/utils/utils.h>
25
26
#include <faiss/impl/FaissAssert.h>
27
28
/*****************************************
29
 * Mixed PQ / Hamming
30
 ******************************************/
31
32
namespace faiss {
33
34
/****************************************************
35
 * Optimization code
36
 ****************************************************/
37
38
// what would the cost update be if iw and jw were swapped?
39
// default implementation just computes both and computes the difference
40
double PermutationObjective::cost_update(const int* perm, int iw, int jw)
41
0
        const {
42
0
    double orig_cost = compute_cost(perm);
43
44
0
    std::vector<int> perm2(n);
45
0
    for (int i = 0; i < n; i++)
46
0
        perm2[i] = perm[i];
47
0
    perm2[iw] = perm[jw];
48
0
    perm2[jw] = perm[iw];
49
50
0
    double new_cost = compute_cost(perm2.data());
51
0
    return new_cost - orig_cost;
52
0
}
53
54
/// Builds an optimizer for the objective `obj`, copying the annealing
/// parameters from `p`.
///
/// @param obj  objective to minimize; only the pointer is stored
/// @param p    annealing parameters; p.seed seeds the random generator
/// Throws (via FAISS_THROW_IF_NOT) if the permutation size obj->n is
/// negative or >= 100000.
SimulatedAnnealingOptimizer::SimulatedAnnealingOptimizer(
        PermutationObjective* obj,
        const SimulatedAnnealingParameters& p)
        : SimulatedAnnealingParameters(p),
          obj(obj),
          n(obj->n),
          logfile(nullptr) {
    // Validate before allocating: if the check threw after the `new`, the
    // destructor would not run for the partially constructed object and the
    // generator would leak.
    FAISS_THROW_IF_NOT(n < 100000 && n >= 0);
    rnd = new RandomGenerator(p.seed);
}

/// Releases the random generator allocated by the constructor.
SimulatedAnnealingOptimizer::~SimulatedAnnealingOptimizer() {
    delete rnd;
}
68
69
// run the optimization and return the best result in best_perm
70
0
double SimulatedAnnealingOptimizer::run_optimization(int* best_perm) {
71
0
    double min_cost = 1e30;
72
73
    // just do a few runs of the annealing and keep the lowest output cost
74
0
    for (int it = 0; it < n_redo; it++) {
75
0
        std::vector<int> perm(n);
76
0
        for (int i = 0; i < n; i++)
77
0
            perm[i] = i;
78
0
        if (init_random) {
79
0
            for (int i = 0; i < n; i++) {
80
0
                int j = i + rnd->rand_int(n - i);
81
0
                std::swap(perm[i], perm[j]);
82
0
            }
83
0
        }
84
0
        float cost = optimize(perm.data());
85
0
        if (logfile)
86
0
            fprintf(logfile, "\n");
87
0
        if (verbose > 1) {
88
0
            printf("    optimization run %d: cost=%g %s\n",
89
0
                   it,
90
0
                   cost,
91
0
                   cost < min_cost ? "keep" : "");
92
0
        }
93
0
        if (cost < min_cost) {
94
0
            memcpy(best_perm, perm.data(), sizeof(perm[0]) * n);
95
0
            min_cost = cost;
96
0
        }
97
0
    }
98
0
    return min_cost;
99
0
}
100
101
// perform the optimization loop, starting from and modifying
102
// permutation in-place
103
0
double SimulatedAnnealingOptimizer::optimize(int* perm) {
104
0
    double cost = init_cost = obj->compute_cost(perm);
105
0
    int log2n = 0;
106
0
    while (!(n <= (1 << log2n)))
107
0
        log2n++;
108
0
    double temperature = init_temperature;
109
0
    int n_swap = 0, n_hot = 0;
110
0
    for (int it = 0; it < n_iter; it++) {
111
0
        temperature = temperature * temperature_decay;
112
0
        int iw, jw;
113
0
        if (only_bit_flips) {
114
0
            iw = rnd->rand_int(n);
115
0
            jw = iw ^ (1 << rnd->rand_int(log2n));
116
0
        } else {
117
0
            iw = rnd->rand_int(n);
118
0
            jw = rnd->rand_int(n - 1);
119
0
            if (jw == iw)
120
0
                jw++;
121
0
        }
122
0
        double delta_cost = obj->cost_update(perm, iw, jw);
123
0
        if (delta_cost < 0 || rnd->rand_float() < temperature) {
124
0
            std::swap(perm[iw], perm[jw]);
125
0
            cost += delta_cost;
126
0
            n_swap++;
127
0
            if (delta_cost >= 0)
128
0
                n_hot++;
129
0
        }
130
0
        if (verbose > 2 || (verbose > 1 && it % 10000 == 0)) {
131
0
            printf("      iteration %d cost %g temp %g n_swap %d "
132
0
                   "(%d hot)     \r",
133
0
                   it,
134
0
                   cost,
135
0
                   temperature,
136
0
                   n_swap,
137
0
                   n_hot);
138
0
            fflush(stdout);
139
0
        }
140
0
        if (logfile) {
141
0
            fprintf(logfile,
142
0
                    "%d %g %g %d %d\n",
143
0
                    it,
144
0
                    cost,
145
0
                    temperature,
146
0
                    n_swap,
147
0
                    n_hot);
148
0
        }
149
0
    }
150
0
    if (verbose > 1)
151
0
        printf("\n");
152
0
    return cost;
153
0
}
154
155
/****************************************************
156
 * Cost functions: ReproduceDistanceTable
157
 ****************************************************/
158
159
0
/// Hamming distance between two 64-bit words: number of bit positions in
/// which a and b differ.
static inline int hamming_dis(uint64_t a, uint64_t b) {
    // __builtin_popcountll takes an unsigned long long (at least 64 bits),
    // whereas __builtin_popcountl takes an unsigned long, which is only 32
    // bits wide on LLP64 targets and would drop the upper half of the word.
    return __builtin_popcountll(a ^ b);
}
162
163
namespace {
164
165
/// optimize permutation to reproduce a distance table with Hamming distances
166
struct ReproduceWithHammingObjective : PermutationObjective {
167
    int nbits;
168
    double dis_weight_factor;
169
170
0
    static double sqr(double x) {
171
0
        return x * x;
172
0
    }
173
174
    // weihgting of distances: it is more important to reproduce small
175
    // distances well
176
0
    double dis_weight(double x) const {
177
0
        return exp(-dis_weight_factor * x);
178
0
    }
179
180
    std::vector<double> target_dis; // wanted distances (size n^2)
181
    std::vector<double> weights;    // weights for each distance (size n^2)
182
183
    // cost = quadratic difference between actual distance and Hamming distance
184
0
    double compute_cost(const int* perm) const override {
185
0
        double cost = 0;
186
0
        for (int i = 0; i < n; i++) {
187
0
            for (int j = 0; j < n; j++) {
188
0
                double wanted = target_dis[i * n + j];
189
0
                double w = weights[i * n + j];
190
0
                double actual = hamming_dis(perm[i], perm[j]);
191
0
                cost += w * sqr(wanted - actual);
192
0
            }
193
0
        }
194
0
        return cost;
195
0
    }
196
197
    // what would the cost update be if iw and jw were swapped?
198
    // computed in O(n) instead of O(n^2) for the full re-computation
199
0
    double cost_update(const int* perm, int iw, int jw) const override {
200
0
        double delta_cost = 0;
201
202
0
        for (int i = 0; i < n; i++) {
203
0
            if (i == iw) {
204
0
                for (int j = 0; j < n; j++) {
205
0
                    double wanted = target_dis[i * n + j],
206
0
                           w = weights[i * n + j];
207
0
                    double actual = hamming_dis(perm[i], perm[j]);
208
0
                    delta_cost -= w * sqr(wanted - actual);
209
0
                    double new_actual = hamming_dis(
210
0
                            perm[jw],
211
0
                            perm[j == iw           ? jw
212
0
                                         : j == jw ? iw
213
0
                                                   : j]);
214
0
                    delta_cost += w * sqr(wanted - new_actual);
215
0
                }
216
0
            } else if (i == jw) {
217
0
                for (int j = 0; j < n; j++) {
218
0
                    double wanted = target_dis[i * n + j],
219
0
                           w = weights[i * n + j];
220
0
                    double actual = hamming_dis(perm[i], perm[j]);
221
0
                    delta_cost -= w * sqr(wanted - actual);
222
0
                    double new_actual = hamming_dis(
223
0
                            perm[iw],
224
0
                            perm[j == iw           ? jw
225
0
                                         : j == jw ? iw
226
0
                                                   : j]);
227
0
                    delta_cost += w * sqr(wanted - new_actual);
228
0
                }
229
0
            } else {
230
0
                int j = iw;
231
0
                {
232
0
                    double wanted = target_dis[i * n + j],
233
0
                           w = weights[i * n + j];
234
0
                    double actual = hamming_dis(perm[i], perm[j]);
235
0
                    delta_cost -= w * sqr(wanted - actual);
236
0
                    double new_actual = hamming_dis(perm[i], perm[jw]);
237
0
                    delta_cost += w * sqr(wanted - new_actual);
238
0
                }
239
0
                j = jw;
240
0
                {
241
0
                    double wanted = target_dis[i * n + j],
242
0
                           w = weights[i * n + j];
243
0
                    double actual = hamming_dis(perm[i], perm[j]);
244
0
                    delta_cost -= w * sqr(wanted - actual);
245
0
                    double new_actual = hamming_dis(perm[i], perm[iw]);
246
0
                    delta_cost += w * sqr(wanted - new_actual);
247
0
                }
248
0
            }
249
0
        }
250
251
0
        return delta_cost;
252
0
    }
253
254
    ReproduceWithHammingObjective(
255
            int nbits,
256
            const std::vector<double>& dis_table,
257
            double dis_weight_factor)
258
0
            : nbits(nbits), dis_weight_factor(dis_weight_factor) {
259
0
        n = 1 << nbits;
260
0
        FAISS_THROW_IF_NOT(dis_table.size() == n * n);
261
0
        set_affine_target_dis(dis_table);
262
0
    }
263
264
0
    void set_affine_target_dis(const std::vector<double>& dis_table) {
265
0
        double sum = 0, sum2 = 0;
266
0
        int n2 = n * n;
267
0
        for (int i = 0; i < n2; i++) {
268
0
            sum += dis_table[i];
269
0
            sum2 += dis_table[i] * dis_table[i];
270
0
        }
271
0
        double mean = sum / n2;
272
0
        double stddev = sqrt(sum2 / n2 - (sum / n2) * (sum / n2));
273
274
0
        target_dis.resize(n2);
275
276
0
        for (int i = 0; i < n2; i++) {
277
            // the mapping function
278
0
            double td = (dis_table[i] - mean) / stddev * sqrt(nbits / 4) +
279
0
                    nbits / 2;
280
0
            target_dis[i] = td;
281
            // compute a weight
282
0
            weights.push_back(dis_weight(td));
283
0
        }
284
0
    }
285
286
0
    ~ReproduceWithHammingObjective() override {}
287
};
288
289
} // anonymous namespace
290
291
// weihgting of distances: it is more important to reproduce small
292
// distances well
293
0
double ReproduceDistancesObjective::dis_weight(double x) const {
294
0
    return exp(-dis_weight_factor * x);
295
0
}
296
297
0
double ReproduceDistancesObjective::get_source_dis(int i, int j) const {
298
0
    return source_dis[i * n + j];
299
0
}
300
301
// cost = quadratic difference between actual distance and Hamming distance
302
0
double ReproduceDistancesObjective::compute_cost(const int* perm) const {
303
0
    double cost = 0;
304
0
    for (int i = 0; i < n; i++) {
305
0
        for (int j = 0; j < n; j++) {
306
0
            double wanted = target_dis[i * n + j];
307
0
            double w = weights[i * n + j];
308
0
            double actual = get_source_dis(perm[i], perm[j]);
309
0
            cost += w * sqr(wanted - actual);
310
0
        }
311
0
    }
312
0
    return cost;
313
0
}
314
315
// what would the cost update be if iw and jw were swapped?
316
// computed in O(n) instead of O(n^2) for the full re-computation
317
double ReproduceDistancesObjective::cost_update(const int* perm, int iw, int jw)
318
0
        const {
319
0
    double delta_cost = 0;
320
0
    for (int i = 0; i < n; i++) {
321
0
        if (i == iw) {
322
0
            for (int j = 0; j < n; j++) {
323
0
                double wanted = target_dis[i * n + j], w = weights[i * n + j];
324
0
                double actual = get_source_dis(perm[i], perm[j]);
325
0
                delta_cost -= w * sqr(wanted - actual);
326
0
                double new_actual = get_source_dis(
327
0
                        perm[jw],
328
0
                        perm[j == iw           ? jw
329
0
                                     : j == jw ? iw
330
0
                                               : j]);
331
0
                delta_cost += w * sqr(wanted - new_actual);
332
0
            }
333
0
        } else if (i == jw) {
334
0
            for (int j = 0; j < n; j++) {
335
0
                double wanted = target_dis[i * n + j], w = weights[i * n + j];
336
0
                double actual = get_source_dis(perm[i], perm[j]);
337
0
                delta_cost -= w * sqr(wanted - actual);
338
0
                double new_actual = get_source_dis(
339
0
                        perm[iw],
340
0
                        perm[j == iw           ? jw
341
0
                                     : j == jw ? iw
342
0
                                               : j]);
343
0
                delta_cost += w * sqr(wanted - new_actual);
344
0
            }
345
0
        } else {
346
0
            int j = iw;
347
0
            {
348
0
                double wanted = target_dis[i * n + j], w = weights[i * n + j];
349
0
                double actual = get_source_dis(perm[i], perm[j]);
350
0
                delta_cost -= w * sqr(wanted - actual);
351
0
                double new_actual = get_source_dis(perm[i], perm[jw]);
352
0
                delta_cost += w * sqr(wanted - new_actual);
353
0
            }
354
0
            j = jw;
355
0
            {
356
0
                double wanted = target_dis[i * n + j], w = weights[i * n + j];
357
0
                double actual = get_source_dis(perm[i], perm[j]);
358
0
                delta_cost -= w * sqr(wanted - actual);
359
0
                double new_actual = get_source_dis(perm[i], perm[iw]);
360
0
                delta_cost += w * sqr(wanted - new_actual);
361
0
            }
362
0
        }
363
0
    }
364
0
    return delta_cost;
365
0
}
366
367
// Stores the permutation size, the weighting factor and the pointer to the
// target distances, then derives the rescaled source distances and the
// weights from source_dis_in (see set_affine_target_dis).
ReproduceDistancesObjective::ReproduceDistancesObjective(
        int n,
        const double* source_dis_in,
        const double* target_dis_in,
        double dis_weight_factor)
        : dis_weight_factor(dis_weight_factor), target_dis(target_dis_in) {
    // the parameter shadows the member of the same name
    this->n = n;

    // relies on n, target_dis and dis_weight_factor being set already
    set_affine_target_dis(source_dis_in);
}
376
377
/// Mean and (population) standard deviation of the n2 values in tab.
///
/// @param tab         input values
/// @param n2          number of values
/// @param mean_out    receives the mean
/// @param stddev_out  receives sqrt(E[x^2] - E[x]^2)
/// With n2 == 0 both outputs are NaN (0 / 0).
void ReproduceDistancesObjective::compute_mean_stdev(
        const double* tab,
        size_t n2,
        double* mean_out,
        double* stddev_out) {
    double sum = 0, sum2 = 0;
    // size_t index: an int counter compared against a size_t bound would
    // overflow for tables of more than INT_MAX entries
    for (size_t i = 0; i < n2; i++) {
        sum += tab[i];
        sum2 += tab[i] * tab[i];
    }
    const double mean = sum / n2;
    // For (nearly) constant tables, rounding can make this difference
    // slightly negative; clamp it so that the square root is 0, not NaN.
    const double variance = std::max(sum2 / n2 - mean * mean, 0.0);
    *mean_out = mean;
    *stddev_out = sqrt(variance);
}
392
393
void ReproduceDistancesObjective::set_affine_target_dis(
394
0
        const double* source_dis_in) {
395
0
    int n2 = n * n;
396
397
0
    double mean_src, stddev_src;
398
0
    compute_mean_stdev(source_dis_in, n2, &mean_src, &stddev_src);
399
400
0
    double mean_target, stddev_target;
401
0
    compute_mean_stdev(target_dis, n2, &mean_target, &stddev_target);
402
403
0
    printf("map mean %g std %g -> mean %g std %g\n",
404
0
           mean_src,
405
0
           stddev_src,
406
0
           mean_target,
407
0
           stddev_target);
408
409
0
    source_dis.resize(n2);
410
0
    weights.resize(n2);
411
412
0
    for (int i = 0; i < n2; i++) {
413
        // the mapping function
414
0
        source_dis[i] =
415
0
                (source_dis_in[i] - mean_src) / stddev_src * stddev_target +
416
0
                mean_target;
417
418
        // compute a weight
419
0
        weights[i] = dis_weight(target_dis[i]);
420
0
    }
421
0
}
422
423
/****************************************************
424
 * Cost functions: RankingScore
425
 ****************************************************/
426
427
/// Maintains a 3D table of elementary costs.
428
/// Accumulates elements based on Hamming distance comparisons
429
template <typename Ttab, typename Taccu>
430
struct Score3Computer : PermutationObjective {
431
    int nc;
432
433
    // cost matrix of size nc * nc *nc
434
    // n_gt (i,j,k) = count of d_gt(x, y-) < d_gt(x, y+)
435
    // where x has PQ code i, y- PQ code j and y+ PQ code k
436
    std::vector<Ttab> n_gt;
437
438
    /// the cost is a triple loop on the nc * nc * nc matrix of entries.
439
    ///
440
0
    Taccu compute(const int* perm) const {
441
0
        Taccu accu = 0;
442
0
        const Ttab* p = n_gt.data();
443
0
        for (int i = 0; i < nc; i++) {
444
0
            int ip = perm[i];
445
0
            for (int j = 0; j < nc; j++) {
446
0
                int jp = perm[j];
447
0
                for (int k = 0; k < nc; k++) {
448
0
                    int kp = perm[k];
449
0
                    if (hamming_dis(ip, jp) < hamming_dis(ip, kp)) {
450
0
                        accu += *p; // n_gt [ ( i * nc + j) * nc + k];
451
0
                    }
452
0
                    p++;
453
0
                }
454
0
            }
455
0
        }
456
0
        return accu;
457
0
    }
458
459
    /** cost update if entries iw and jw of the permutation would be
460
     * swapped.
461
     *
462
     * The computation is optimized by avoiding elements in the
463
     * nc*nc*nc cube that are known not to change. For nc=256, this
464
     * reduces the nb of cells to visit to about 6/256 th of the
465
     * cells. Practical speedup is about 8x, and the code is quite
466
     * complex :-/
467
     */
468
0
    Taccu compute_update(const int* perm, int iw, int jw) const {
469
0
        assert(iw != jw);
470
0
        if (iw > jw)
471
0
            std::swap(iw, jw);
472
473
0
        Taccu accu = 0;
474
0
        const Ttab* n_gt_i = n_gt.data();
475
0
        for (int i = 0; i < nc; i++) {
476
0
            int ip0 = perm[i];
477
0
            int ip = perm[i == iw ? jw : i == jw ? iw : i];
478
479
            // accu += update_i (perm, iw, jw, ip0, ip, n_gt_i);
480
481
0
            accu += update_i_cross(perm, iw, jw, ip0, ip, n_gt_i);
482
483
0
            if (ip != ip0)
484
0
                accu += update_i_plane(perm, iw, jw, ip0, ip, n_gt_i);
485
486
0
            n_gt_i += nc * nc;
487
0
        }
488
489
0
        return accu;
490
0
    }
491
492
    Taccu update_i(
493
            const int* perm,
494
            int iw,
495
            int jw,
496
            int ip0,
497
            int ip,
498
            const Ttab* n_gt_i) const {
499
        Taccu accu = 0;
500
        const Ttab* n_gt_ij = n_gt_i;
501
        for (int j = 0; j < nc; j++) {
502
            int jp0 = perm[j];
503
            int jp = perm[j == iw ? jw : j == jw ? iw : j];
504
            for (int k = 0; k < nc; k++) {
505
                int kp0 = perm[k];
506
                int kp = perm[k == iw ? jw : k == jw ? iw : k];
507
                int ng = n_gt_ij[k];
508
                if (hamming_dis(ip, jp) < hamming_dis(ip, kp)) {
509
                    accu += ng;
510
                }
511
                if (hamming_dis(ip0, jp0) < hamming_dis(ip0, kp0)) {
512
                    accu -= ng;
513
                }
514
            }
515
            n_gt_ij += nc;
516
        }
517
        return accu;
518
    }
519
520
    // 2 inner loops for the case ip0 != ip
521
    Taccu update_i_plane(
522
            const int* perm,
523
            int iw,
524
            int jw,
525
            int ip0,
526
            int ip,
527
0
            const Ttab* n_gt_i) const {
528
0
        Taccu accu = 0;
529
0
        const Ttab* n_gt_ij = n_gt_i;
530
531
0
        for (int j = 0; j < nc; j++) {
532
0
            if (j != iw && j != jw) {
533
0
                int jp = perm[j];
534
0
                for (int k = 0; k < nc; k++) {
535
0
                    if (k != iw && k != jw) {
536
0
                        int kp = perm[k];
537
0
                        Ttab ng = n_gt_ij[k];
538
0
                        if (hamming_dis(ip, jp) < hamming_dis(ip, kp)) {
539
0
                            accu += ng;
540
0
                        }
541
0
                        if (hamming_dis(ip0, jp) < hamming_dis(ip0, kp)) {
542
0
                            accu -= ng;
543
0
                        }
544
0
                    }
545
0
                }
546
0
            }
547
0
            n_gt_ij += nc;
548
0
        }
549
0
        return accu;
550
0
    }
551
552
    /// used for the 8 cells were the 3 indices are swapped
553
    inline Taccu update_k(
554
            const int* perm,
555
            int iw,
556
            int jw,
557
            int ip0,
558
            int ip,
559
            int jp0,
560
            int jp,
561
            int k,
562
0
            const Ttab* n_gt_ij) const {
563
0
        Taccu accu = 0;
564
0
        int kp0 = perm[k];
565
0
        int kp = perm[k == iw ? jw : k == jw ? iw : k];
566
0
        Ttab ng = n_gt_ij[k];
567
0
        if (hamming_dis(ip, jp) < hamming_dis(ip, kp)) {
568
0
            accu += ng;
569
0
        }
570
0
        if (hamming_dis(ip0, jp0) < hamming_dis(ip0, kp0)) {
571
0
            accu -= ng;
572
0
        }
573
0
        return accu;
574
0
    }
575
576
    /// compute update on a line of k's, where i and j are swapped
577
    Taccu update_j_line(
578
            const int* perm,
579
            int iw,
580
            int jw,
581
            int ip0,
582
            int ip,
583
            int jp0,
584
            int jp,
585
0
            const Ttab* n_gt_ij) const {
586
0
        Taccu accu = 0;
587
0
        for (int k = 0; k < nc; k++) {
588
0
            if (k == iw || k == jw)
589
0
                continue;
590
0
            int kp = perm[k];
591
0
            Ttab ng = n_gt_ij[k];
592
0
            if (hamming_dis(ip, jp) < hamming_dis(ip, kp)) {
593
0
                accu += ng;
594
0
            }
595
0
            if (hamming_dis(ip0, jp0) < hamming_dis(ip0, kp)) {
596
0
                accu -= ng;
597
0
            }
598
0
        }
599
0
        return accu;
600
0
    }
601
602
    /// considers the 2 pairs of crossing lines j=iw or jw and k = iw or kw
603
    Taccu update_i_cross(
604
            const int* perm,
605
            int iw,
606
            int jw,
607
            int ip0,
608
            int ip,
609
0
            const Ttab* n_gt_i) const {
610
0
        Taccu accu = 0;
611
0
        const Ttab* n_gt_ij = n_gt_i;
612
613
0
        for (int j = 0; j < nc; j++) {
614
0
            int jp0 = perm[j];
615
0
            int jp = perm[j == iw ? jw : j == jw ? iw : j];
616
617
0
            accu += update_k(perm, iw, jw, ip0, ip, jp0, jp, iw, n_gt_ij);
618
0
            accu += update_k(perm, iw, jw, ip0, ip, jp0, jp, jw, n_gt_ij);
619
620
0
            if (jp != jp0)
621
0
                accu += update_j_line(perm, iw, jw, ip0, ip, jp0, jp, n_gt_ij);
622
623
0
            n_gt_ij += nc;
624
0
        }
625
0
        return accu;
626
0
    }
627
628
    /// PermutationObjective implementeation (just negates the scores
629
    /// for minimization)
630
631
0
    double compute_cost(const int* perm) const override {
632
0
        return -compute(perm);
633
0
    }
634
635
0
    double cost_update(const int* perm, int iw, int jw) const override {
636
0
        double ret = -compute_update(perm, iw, jw);
637
0
        return ret;
638
0
    }
639
640
0
    ~Score3Computer() override {}
641
};
642
643
/// Comparator that orders integer indices by the value they refer to in tab
/// (ascending).
struct IndirectSort {
    const float* tab; ///< values the indices point into (not owned)

    // const: comparing does not modify the comparator, and a const call
    // operator lets it be used through const objects / references as well
    bool operator()(int a, int b) const {
        return tab[a] < tab[b];
    }
};
649
650
/// Ranking objective: fills the n_gt table of Score3Computer from ground
/// truth distances between nq queries and nb database elements, each of
/// which is represented by a code in [0, nc).
struct RankingScore2 : Score3Computer<float, double> {
    int nbits;
    int nq, nb;
    const uint32_t *qcodes, *bcodes; // codes of the queries / db elements
    const float* gt_distances;       // nq * nb ground-truth distances

    /// The pointers are stored, not copied; they are read while the
    /// constructor builds the n_gt table.
    RankingScore2(
            int nbits,
            int nq,
            int nb,
            const uint32_t* qcodes,
            const uint32_t* bcodes,
            const float* gt_distances)
            : nbits(nbits),
              nq(nq),
              nb(nb),
              qcodes(qcodes),
              bcodes(bcodes),
              gt_distances(gt_distances) {
        n = nc = 1 << nbits;
        // size computed in size_t: nc * nc * nc overflows an int as soon as
        // nbits > 10
        n_gt.resize(size_t(nc) * nc * nc);
        init_n_gt();
    }

    /// weight of an element of rank r (or of a rank difference r)
    double rank_weight(int r) const {
        return 1.0 / (r + 1);
    }

    /// count nb of i, j in a x b st. i < j
    /// a and b should be sorted on input
    /// they are the ranks of j and k respectively.
    /// specific version for diff-of-rank weighting, cannot be optimized
    /// with a cumulative table
    double accum_gt_weight_diff(
            const std::vector<int>& a,
            const std::vector<int>& b) const {
        const auto nb_2 = b.size();
        const auto na = a.size();

        double accu = 0;
        size_t j = 0;
        for (size_t i = 0; i < na; i++) {
            const auto ai = a[i];
            // skip the elements of b that do not rank after ai; j never
            // moves back because a is sorted
            while (j < nb_2 && ai >= b[j]) {
                j++;
            }

            double accu_i = 0;
            for (auto k = j; k < nb_2; k++) {
                accu_i += rank_weight(b[k] - ai);
            }

            accu += rank_weight(ai) * accu_i;
        }
        return accu;
    }

    /// Accumulate, for every query, the weighted rank comparisons between
    /// all pairs of code bins into the plane of n_gt selected by the code
    /// of the query.
    void init_n_gt() {
        for (int q = 0; q < nq; q++) {
            // offsets computed in size_t: q * nb overflows an int when
            // nq * nb exceeds INT_MAX
            const float* gtd = gt_distances + size_t(q) * nb;
            const uint32_t* cb = bcodes; // all same codes
            float* n_gt_q = &n_gt[size_t(qcodes[q]) * nc * nc];

            printf("init gt for q=%d/%d    \r", q, nq);
            fflush(stdout);

            std::vector<int> rankv(nb);
            int* ranks = rankv.data();

            // elements in each code bin, ordered by rank within each bin
            std::vector<std::vector<int>> tab(nc);

            { // build rank table: db indices sorted by increasing distance
                IndirectSort s = {gtd};
                for (int j = 0; j < nb; j++)
                    ranks[j] = j;
                std::sort(ranks, ranks + nb, s);
            }

            // ranks are visited in increasing order, so each bin ends up
            // sorted
            for (int rank = 0; rank < nb; rank++) {
                int i = ranks[rank];
                tab[cb[i]].push_back(rank);
            }

            // this is very expensive. Any suggestion for improvement
            // welcome.
            for (int i = 0; i < nc; i++) {
                std::vector<int>& di = tab[i];
                for (int j = 0; j < nc; j++) {
                    std::vector<int>& dj = tab[j];
                    n_gt_q[i * nc + j] += accum_gt_weight_diff(di, dj);
                }
            }
        }
    }
};
746
747
/*****************************************
748
 * PolysemousTraining
749
 ******************************************/
750
751
0
// Default settings of the polysemous training.
PolysemousTraining::PolysemousTraining() {
    // max 20 G RAM (20 * 2^30 bytes)
    max_memory = size_t(20) << 30;
    dis_weight_factor = log(2);
    ntrain_permutation = 0;
    optimization_type = OT_ReproduceDistances_affine;
}
758
759
void PolysemousTraining::optimize_reproduce_distances(
760
0
        ProductQuantizer& pq) const {
761
0
    int dsub = pq.dsub;
762
763
0
    int n = pq.ksub;
764
0
    int nbits = pq.nbits;
765
766
0
    size_t mem1 = memory_usage_per_thread(pq);
767
0
    int nt = std::min(omp_get_max_threads(), int(pq.M));
768
0
    FAISS_THROW_IF_NOT_FMT(
769
0
            mem1 < max_memory,
770
0
            "Polysemous training will use %zd bytes per thread, while the max is set to %zd",
771
0
            mem1,
772
0
            max_memory);
773
774
0
    if (mem1 * nt > max_memory) {
775
0
        nt = max_memory / mem1;
776
0
        fprintf(stderr,
777
0
                "Polysemous training: WARN, reducing number of threads to %d to save memory",
778
0
                nt);
779
0
    }
780
781
0
#pragma omp parallel for num_threads(nt)
782
0
    for (int m = 0; m < pq.M; m++) {
783
0
        std::vector<double> dis_table;
784
785
        // printf ("Optimizing quantizer %d\n", m);
786
787
0
        float* centroids = pq.get_centroids(m, 0);
788
789
0
        for (int i = 0; i < n; i++) {
790
0
            for (int j = 0; j < n; j++) {
791
0
                dis_table.push_back(fvec_L2sqr(
792
0
                        centroids + i * dsub, centroids + j * dsub, dsub));
793
0
            }
794
0
        }
795
796
0
        std::vector<int> perm(n);
797
0
        ReproduceWithHammingObjective obj(nbits, dis_table, dis_weight_factor);
798
799
0
        SimulatedAnnealingOptimizer optim(&obj, *this);
800
801
0
        if (log_pattern.size()) {
802
0
            char fname[256];
803
0
            snprintf(fname, 256, log_pattern.c_str(), m);
804
0
            printf("opening log file %s\n", fname);
805
0
            optim.logfile = fopen(fname, "w");
806
0
            FAISS_THROW_IF_NOT_MSG(optim.logfile, "could not open logfile");
807
0
        }
808
0
        double final_cost = optim.run_optimization(perm.data());
809
810
0
        if (verbose > 0) {
811
0
            printf("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n",
812
0
                   m,
813
0
                   optim.init_cost,
814
0
                   final_cost);
815
0
        }
816
817
0
        if (log_pattern.size())
818
0
            fclose(optim.logfile);
819
820
0
        std::vector<float> centroids_copy;
821
0
        for (int i = 0; i < dsub * n; i++)
822
0
            centroids_copy.push_back(centroids[i]);
823
824
0
        for (int i = 0; i < n; i++)
825
0
            memcpy(centroids + perm[i] * dsub,
826
0
                   centroids_copy.data() + i * dsub,
827
0
                   dsub * sizeof(centroids[0]));
828
0
    }
829
0
}
830
831
void PolysemousTraining::optimize_ranking(
832
        ProductQuantizer& pq,
833
        size_t n,
834
0
        const float* x) const {
835
0
    int dsub = pq.dsub;
836
0
    int nbits = pq.nbits;
837
838
0
    std::vector<uint8_t> all_codes(pq.code_size * n);
839
840
0
    pq.compute_codes(x, all_codes.data(), n);
841
842
0
    FAISS_THROW_IF_NOT(pq.nbits == 8);
843
844
0
    if (n == 0) {
845
0
        pq.compute_sdc_table();
846
0
    }
847
848
0
#pragma omp parallel for
849
0
    for (int m = 0; m < pq.M; m++) {
850
0
        size_t nq, nb;
851
0
        std::vector<uint32_t> codes;     // query codes, then db codes
852
0
        std::vector<float> gt_distances; // nq * nb matrix of distances
853
854
0
        if (n > 0) {
855
0
            std::vector<float> xtrain(n * dsub);
856
0
            for (int i = 0; i < n; i++)
857
0
                memcpy(xtrain.data() + i * dsub,
858
0
                       x + i * pq.d + m * dsub,
859
0
                       sizeof(float) * dsub);
860
861
0
            codes.resize(n);
862
0
            for (int i = 0; i < n; i++)
863
0
                codes[i] = all_codes[i * pq.code_size + m];
864
865
0
            nq = n / 4;
866
0
            nb = n - nq;
867
0
            const float* xq = xtrain.data();
868
0
            const float* xb = xq + nq * dsub;
869
870
0
            gt_distances.resize(nq * nb);
871
872
0
            pairwise_L2sqr(dsub, nq, xq, nb, xb, gt_distances.data());
873
0
        } else {
874
0
            nq = nb = pq.ksub;
875
0
            codes.resize(2 * nq);
876
0
            for (int i = 0; i < nq; i++)
877
0
                codes[i] = codes[i + nq] = i;
878
879
0
            gt_distances.resize(nq * nb);
880
881
0
            memcpy(gt_distances.data(),
882
0
                   pq.sdc_table.data() + m * nq * nb,
883
0
                   sizeof(float) * nq * nb);
884
0
        }
885
886
0
        double t0 = getmillisecs();
887
888
0
        std::unique_ptr<PermutationObjective> obj(new RankingScore2(
889
0
                nbits,
890
0
                nq,
891
0
                nb,
892
0
                codes.data(),
893
0
                codes.data() + nq,
894
0
                gt_distances.data()));
895
896
0
        if (verbose > 0) {
897
0
            printf("   m=%d, nq=%zd, nb=%zd, initialize RankingScore "
898
0
                   "in %.3f ms\n",
899
0
                   m,
900
0
                   nq,
901
0
                   nb,
902
0
                   getmillisecs() - t0);
903
0
        }
904
905
0
        SimulatedAnnealingOptimizer optim(obj.get(), *this);
906
907
0
        if (log_pattern.size()) {
908
0
            char fname[256];
909
0
            snprintf(fname, 256, log_pattern.c_str(), m);
910
0
            printf("opening log file %s\n", fname);
911
0
            optim.logfile = fopen(fname, "w");
912
0
            FAISS_THROW_IF_NOT_FMT(
913
0
                    optim.logfile, "could not open logfile %s", fname);
914
0
        }
915
916
0
        std::vector<int> perm(pq.ksub);
917
918
0
        double final_cost = optim.run_optimization(perm.data());
919
0
        printf("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n",
920
0
               m,
921
0
               optim.init_cost,
922
0
               final_cost);
923
924
0
        if (log_pattern.size())
925
0
            fclose(optim.logfile);
926
927
0
        float* centroids = pq.get_centroids(m, 0);
928
929
0
        std::vector<float> centroids_copy;
930
0
        for (int i = 0; i < dsub * pq.ksub; i++)
931
0
            centroids_copy.push_back(centroids[i]);
932
933
0
        for (int i = 0; i < pq.ksub; i++)
934
0
            memcpy(centroids + perm[i] * dsub,
935
0
                   centroids_copy.data() + i * dsub,
936
0
                   dsub * sizeof(centroids[0]));
937
0
    }
938
0
}
939
940
void PolysemousTraining::optimize_pq_for_hamming(
941
        ProductQuantizer& pq,
942
        size_t n,
943
0
        const float* x) const {
944
0
    if (optimization_type == OT_None) {
945
0
    } else if (optimization_type == OT_ReproduceDistances_affine) {
946
0
        optimize_reproduce_distances(pq);
947
0
    } else {
948
0
        optimize_ranking(pq, n, x);
949
0
    }
950
951
0
    pq.compute_sdc_table();
952
0
}
953
954
size_t PolysemousTraining::memory_usage_per_thread(
955
0
        const ProductQuantizer& pq) const {
956
0
    size_t n = pq.ksub;
957
958
0
    switch (optimization_type) {
959
0
        case OT_None:
960
0
            return 0;
961
0
        case OT_ReproduceDistances_affine:
962
0
            return n * n * sizeof(double) * 3;
963
0
        case OT_Ranking_weighted_diff:
964
0
            return n * n * n * sizeof(float);
965
0
    }
966
967
0
    FAISS_THROW_MSG("Invalid optmization type");
968
0
    return 0;
969
0
}
970
971
} // namespace faiss