/root/doris/contrib/faiss/faiss/impl/ResidualQuantizer.cpp
| Line | Count | Source | 
| 1 |  | /* | 
| 2 |  |  * Copyright (c) Meta Platforms, Inc. and affiliates. | 
| 3 |  |  * | 
| 4 |  |  * This source code is licensed under the MIT license found in the | 
| 5 |  |  * LICENSE file in the root directory of this source tree. | 
| 6 |  |  */ | 
| 7 |  |  | 
| 8 |  | #include <faiss/impl/ResidualQuantizer.h> | 
| 9 |  |  | 
| 10 |  | #include <algorithm> | 
| 11 |  | #include <cmath> | 
| 12 |  | #include <cstddef> | 
| 13 |  | #include <cstdio> | 
| 14 |  | #include <cstring> | 
| 15 |  | #include <memory> | 
| 16 |  |  | 
| 17 |  | #include <faiss/IndexFlat.h> | 
| 18 |  | #include <faiss/VectorTransform.h> | 
| 19 |  | #include <faiss/impl/FaissAssert.h> | 
| 20 |  | #include <faiss/impl/residual_quantizer_encode_steps.h> | 
| 21 |  | #include <faiss/utils/distances.h> | 
| 22 |  | #include <faiss/utils/hamming.h> | 
| 23 |  | #include <faiss/utils/utils.h> | 
| 24 |  |  | 
extern "C" {

// BLAS/LAPACK prototypes declared locally to avoid depending on a
// vendor-specific header. FINTEGER is the Fortran integer type used by the
// linked BLAS/LAPACK build (defined elsewhere in faiss).

// general matrix multiplication
int sgemm_(
        const char* transa, // "N"/"T": op applied to A
        const char* transb, // "N"/"T": op applied to B
        FINTEGER* m,
        FINTEGER* n,
        FINTEGER* k,
        const float* alpha,
        const float* a,
        FINTEGER* lda,
        const float* b,
        FINTEGER* ldb,
        float* beta,
        float* c,
        FINTEGER* ldc);

// http://www.netlib.org/clapack/old/single/sgels.c
// solve least squares

int sgelsd_(
        FINTEGER* m,    // rows of A
        FINTEGER* n,    // columns of A
        FINTEGER* nrhs, // number of right-hand sides
        float* a,
        FINTEGER* lda,
        float* b,       // in: RHS; out: solution (and residual rows)
        FINTEGER* ldb,
        float* s,       // out: singular values of A
        float* rcond,   // threshold used to determine the effective rank
        FINTEGER* rank, // out: effective rank of A
        float* work,
        FINTEGER* lwork, // lwork = -1 performs a workspace-size query
        FINTEGER* iwork,
        FINTEGER* info); // out: 0 on success
}
| 62 |  |  | 
| 63 |  | namespace faiss { | 
| 64 |  |  | 
| 65 | 0 | ResidualQuantizer::ResidualQuantizer() { | 
| 66 | 0 |     d = 0; | 
| 67 | 0 |     M = 0; | 
| 68 | 0 |     verbose = false; | 
| 69 | 0 | } | 
| 70 |  |  | 
| 71 |  | ResidualQuantizer::ResidualQuantizer( | 
| 72 |  |         size_t d, | 
| 73 |  |         const std::vector<size_t>& nbits, | 
| 74 |  |         Search_type_t search_type) | 
| 75 | 0 |         : ResidualQuantizer() { | 
| 76 | 0 |     this->search_type = search_type; | 
| 77 | 0 |     this->d = d; | 
| 78 | 0 |     M = nbits.size(); | 
| 79 | 0 |     this->nbits = nbits; | 
| 80 | 0 |     set_derived_values(); | 
| 81 | 0 | } | 
| 82 |  |  | 
ResidualQuantizer::ResidualQuantizer(
        size_t d,
        size_t M,
        size_t nbits,
        Search_type_t search_type)
        // convenience overload: M codebooks that all use the same bit count
        : ResidualQuantizer(d, std::vector<size_t>(M, nbits), search_type) {}
| 89 |  |  | 
| 90 |  | void ResidualQuantizer::initialize_from( | 
| 91 |  |         const ResidualQuantizer& other, | 
| 92 | 0 |         int skip_M) { | 
| 93 | 0 |     FAISS_THROW_IF_NOT(M + skip_M <= other.M); | 
| 94 | 0 |     FAISS_THROW_IF_NOT(skip_M >= 0); | 
| 95 |  |  | 
| 96 | 0 |     Search_type_t this_search_type = search_type; | 
| 97 | 0 |     int this_M = M; | 
| 98 |  |  | 
| 99 |  |     // a first good approximation: override everything | 
| 100 | 0 |     *this = other; | 
| 101 |  |  | 
| 102 |  |     // adjust derived values | 
| 103 | 0 |     M = this_M; | 
| 104 | 0 |     search_type = this_search_type; | 
| 105 | 0 |     nbits.resize(M); | 
| 106 | 0 |     memcpy(nbits.data(), | 
| 107 | 0 |            other.nbits.data() + skip_M, | 
| 108 | 0 |            nbits.size() * sizeof(nbits[0])); | 
| 109 |  | 
 | 
| 110 | 0 |     set_derived_values(); | 
| 111 |  |  | 
| 112 |  |     // resize codebooks if trained | 
| 113 | 0 |     if (codebooks.size() > 0) { | 
| 114 | 0 |         FAISS_THROW_IF_NOT(codebooks.size() == other.total_codebook_size * d); | 
| 115 | 0 |         codebooks.resize(total_codebook_size * d); | 
| 116 | 0 |         memcpy(codebooks.data(), | 
| 117 | 0 |                other.codebooks.data() + other.codebook_offsets[skip_M] * d, | 
| 118 | 0 |                codebooks.size() * sizeof(codebooks[0])); | 
| 119 |  |         // TODO: norm_tabs? | 
| 120 | 0 |     } | 
| 121 | 0 | } | 
| 122 |  |  | 
| 123 |  | /**************************************************************** | 
| 124 |  |  * Training | 
| 125 |  |  ****************************************************************/ | 
| 126 |  |  | 
void ResidualQuantizer::train(size_t n, const float* x) {
    // Train the M codebooks sequentially: at step m, run k-means (or
    // progressive-dim k-means) on the residuals left by the first m steps,
    // then re-encode the training set with a beam search to produce the
    // residuals for step m + 1.
    codebooks.resize(d * codebook_offsets.back());

    if (verbose) {
        printf("Training ResidualQuantizer, with %zd steps on %zd %zdD vectors\n",
               M,
               n,
               size_t(d));
    }

    // the beam search keeps cur_beam_size candidate encodings per vector
    int cur_beam_size = 1;
    // residuals of the current beam, layout n * cur_beam_size * d
    std::vector<float> residuals(x, x + n * d);
    // codes of the current beam, layout n * cur_beam_size * m
    std::vector<int32_t> codes;
    // distance of each beam entry to its training vector
    std::vector<float> distances;
    double t0 = getmillisecs();
    double clustering_time = 0;

    for (int m = 0; m < M; m++) {
        int K = 1 << nbits[m];

        // on which residuals to train
        std::vector<float>& train_residuals = residuals;
        std::vector<float> residuals1;
        if (train_type & Train_top_beam) {
            // keep only the best beam entry of each training vector
            residuals1.resize(n * d);
            for (size_t j = 0; j < n; j++) {
                memcpy(residuals1.data() + j * d,
                       residuals.data() + j * d * cur_beam_size,
                       sizeof(residuals[0]) * d);
            }
            // NOTE(review): train_residuals is a *reference* to residuals, so
            // this assignment copies residuals1 into residuals rather than
            // rebinding the reference — confirm the clobbering is intended.
            train_residuals = residuals1;
        }
        // codebook for this step only (shadows the member `codebooks`)
        std::vector<float> codebooks;
        float obj = 0;

        std::unique_ptr<Index> assign_index;
        if (assign_index_factory) {
            assign_index.reset((*assign_index_factory)(d));
        } else {
            assign_index.reset(new IndexFlatL2(d));
        }

        double t1 = getmillisecs();

        if (!(train_type & Train_progressive_dim)) { // regular kmeans
            Clustering clus(d, K, cp);
            clus.train(
                    train_residuals.size() / d,
                    train_residuals.data(),
                    *assign_index.get());
            codebooks.swap(clus.centroids);
            assign_index->reset();
            obj = clus.iteration_stats.back().obj;
        } else { // progressive dim clustering
            ProgressiveDimClustering clus(d, K, cp);
            ProgressiveDimIndexFactory default_fac;
            clus.train(
                    train_residuals.size() / d,
                    train_residuals.data(),
                    assign_index_factory ? *assign_index_factory : default_fac);
            codebooks.swap(clus.centroids);
            obj = clus.iteration_stats.back().obj;
        }
        clustering_time += (getmillisecs() - t1) / 1000;

        // install the step-m codebook into the member codebook table
        memcpy(this->codebooks.data() + codebook_offsets[m] * d,
               codebooks.data(),
               codebooks.size() * sizeof(codebooks[0]));

        // quantize using the new codebooks

        int new_beam_size = std::min(cur_beam_size * K, max_beam_size);
        std::vector<int32_t> new_codes(n * new_beam_size * (m + 1));
        std::vector<float> new_residuals(n * new_beam_size * d);
        std::vector<float> new_distances(n * new_beam_size);

        size_t bs;
        { // determine batch size
            size_t mem = memory_per_point();
            if (n > 1 && mem * n > max_mem_distances) {
                // then split queries to reduce temp memory
                bs = std::max(max_mem_distances / mem, size_t(1));
            } else {
                bs = n;
            }
        }

        for (size_t i0 = 0; i0 < n; i0 += bs) {
            size_t i1 = std::min(i0 + bs, n);

            /* printf("i0: %ld i1: %ld K %d ntotal assign index %ld\n",
                i0, i1, K, assign_index->ntotal); */

            // extend each beam entry with the step-m codebook, keeping the
            // new_beam_size best expansions per training vector
            beam_search_encode_step(
                    d,
                    K,
                    codebooks.data(),
                    i1 - i0,
                    cur_beam_size,
                    residuals.data() + i0 * cur_beam_size * d,
                    m,
                    codes.data() + i0 * cur_beam_size * m,
                    new_beam_size,
                    new_codes.data() + i0 * new_beam_size * (m + 1),
                    new_residuals.data() + i0 * new_beam_size * d,
                    new_distances.data() + i0 * new_beam_size,
                    assign_index.get(),
                    approx_topk_mode);
        }
        codes.swap(new_codes);
        residuals.swap(new_residuals);
        distances.swap(new_distances);

        float sum_distances = 0;
        for (int j = 0; j < distances.size(); j++) {
            sum_distances += distances[j];
        }

        if (verbose) {
            printf("[%.3f s, %.3f s clustering] train stage %d, %d bits, kmeans objective %g, "
                   "total distance %g, beam_size %d->%d (batch size %zd)\n",
                   (getmillisecs() - t0) / 1000,
                   clustering_time,
                   m,
                   int(nbits[m]),
                   obj,
                   sum_distances,
                   cur_beam_size,
                   new_beam_size,
                   bs);
        }
        cur_beam_size = new_beam_size;
    }

    is_trained = true;

    if (train_type & Train_refine_codebook) {
        // jointly re-estimate all codebooks while keeping the codes fixed
        for (int iter = 0; iter < niter_codebook_refine; iter++) {
            if (verbose) {
                printf("re-estimating the codebooks to minimize "
                       "quantization errors (iter %d).\n",
                       iter);
            }
            retrain_AQ_codebook(n, x);
        }
    }

    // find min and max norms
    std::vector<float> norms(n);

    for (size_t i = 0; i < n; i++) {
        // ||x_i - residual_i||^2, i.e. the squared norm of the
        // reconstruction of x_i from its best beam entry
        norms[i] = fvec_L2sqr(
                x + i * d, residuals.data() + i * cur_beam_size * d, d);
    }

    // fvec_norms_L2sqr(norms.data(), x, d, n);
    train_norm(n, norms.data());

    if (!(train_type & Skip_codebook_tables)) {
        compute_codebook_tables();
    }
}
| 289 |  |  | 
| 290 | 0 | float ResidualQuantizer::retrain_AQ_codebook(size_t n, const float* x) { | 
| 291 | 0 |     FAISS_THROW_IF_NOT_MSG(n >= total_codebook_size, "too few training points"); | 
| 292 |  |  | 
| 293 | 0 |     if (verbose) { | 
| 294 | 0 |         printf("  encoding %zd training vectors\n", n); | 
| 295 | 0 |     } | 
| 296 | 0 |     std::vector<uint8_t> codes(n * code_size); | 
| 297 | 0 |     compute_codes(x, codes.data(), n); | 
| 298 |  |  | 
| 299 |  |     // compute reconstruction error | 
| 300 | 0 |     float input_recons_error; | 
| 301 | 0 |     { | 
| 302 | 0 |         std::vector<float> x_recons(n * d); | 
| 303 | 0 |         decode(codes.data(), x_recons.data(), n); | 
| 304 | 0 |         input_recons_error = fvec_L2sqr(x, x_recons.data(), n * d); | 
| 305 | 0 |         if (verbose) { | 
| 306 | 0 |             printf("  input quantization error %g\n", input_recons_error); | 
| 307 | 0 |         } | 
| 308 | 0 |     } | 
| 309 |  |  | 
| 310 |  |     // build matrix of the linear system | 
| 311 | 0 |     std::vector<float> C(n * total_codebook_size); | 
| 312 | 0 |     for (size_t i = 0; i < n; i++) { | 
| 313 | 0 |         BitstringReader bsr(codes.data() + i * code_size, code_size); | 
| 314 | 0 |         for (int m = 0; m < M; m++) { | 
| 315 | 0 |             int idx = bsr.read(nbits[m]); | 
| 316 | 0 |             C[i + (codebook_offsets[m] + idx) * n] = 1; | 
| 317 | 0 |         } | 
| 318 | 0 |     } | 
| 319 |  |  | 
| 320 |  |     // transpose training vectors | 
| 321 | 0 |     std::vector<float> xt(n * d); | 
| 322 |  | 
 | 
| 323 | 0 |     for (size_t i = 0; i < n; i++) { | 
| 324 | 0 |         for (size_t j = 0; j < d; j++) { | 
| 325 | 0 |             xt[j * n + i] = x[i * d + j]; | 
| 326 | 0 |         } | 
| 327 | 0 |     } | 
| 328 |  | 
 | 
| 329 | 0 |     { // solve least squares | 
| 330 | 0 |         FINTEGER lwork = -1; | 
| 331 | 0 |         FINTEGER di = d, ni = n, tcsi = total_codebook_size; | 
| 332 | 0 |         FINTEGER info = -1, rank = -1; | 
| 333 |  | 
 | 
| 334 | 0 |         float rcond = 1e-4; // this is an important parameter because the code | 
| 335 |  |                             // matrix can be rank deficient for small problems, | 
| 336 |  |                             // the default rcond=-1 does not work | 
| 337 | 0 |         float worksize; | 
| 338 | 0 |         std::vector<float> sing_vals(total_codebook_size); | 
| 339 | 0 |         FINTEGER nlvl = 1000; // formula is a bit convoluted so let's take an | 
| 340 |  |                               // upper bound | 
| 341 | 0 |         std::vector<FINTEGER> iwork( | 
| 342 | 0 |                 3 * total_codebook_size * nlvl + 11 * total_codebook_size); | 
| 343 |  |  | 
| 344 |  |         // worksize query | 
| 345 | 0 |         sgelsd_(&ni, | 
| 346 | 0 |                 &tcsi, | 
| 347 | 0 |                 &di, | 
| 348 | 0 |                 C.data(), | 
| 349 | 0 |                 &ni, | 
| 350 | 0 |                 xt.data(), | 
| 351 | 0 |                 &ni, | 
| 352 | 0 |                 sing_vals.data(), | 
| 353 | 0 |                 &rcond, | 
| 354 | 0 |                 &rank, | 
| 355 | 0 |                 &worksize, | 
| 356 | 0 |                 &lwork, | 
| 357 | 0 |                 iwork.data(), | 
| 358 | 0 |                 &info); | 
| 359 | 0 |         FAISS_THROW_IF_NOT(info == 0); | 
| 360 |  |  | 
| 361 | 0 |         lwork = worksize; | 
| 362 | 0 |         std::vector<float> work(lwork); | 
| 363 |  |         // actual call | 
| 364 | 0 |         sgelsd_(&ni, | 
| 365 | 0 |                 &tcsi, | 
| 366 | 0 |                 &di, | 
| 367 | 0 |                 C.data(), | 
| 368 | 0 |                 &ni, | 
| 369 | 0 |                 xt.data(), | 
| 370 | 0 |                 &ni, | 
| 371 | 0 |                 sing_vals.data(), | 
| 372 | 0 |                 &rcond, | 
| 373 | 0 |                 &rank, | 
| 374 | 0 |                 work.data(), | 
| 375 | 0 |                 &lwork, | 
| 376 | 0 |                 iwork.data(), | 
| 377 | 0 |                 &info); | 
| 378 | 0 |         FAISS_THROW_IF_NOT_FMT(info == 0, "SGELS returned info=%d", int(info)); | 
| 379 | 0 |         if (verbose) { | 
| 380 | 0 |             printf("   sgelsd rank=%d/%d\n", | 
| 381 | 0 |                    int(rank), | 
| 382 | 0 |                    int(total_codebook_size)); | 
| 383 | 0 |         } | 
| 384 | 0 |     } | 
| 385 |  |  | 
| 386 |  |     // result is in xt, re-transpose to codebook | 
| 387 |  |  | 
| 388 | 0 |     for (size_t i = 0; i < total_codebook_size; i++) { | 
| 389 | 0 |         for (size_t j = 0; j < d; j++) { | 
| 390 | 0 |             codebooks[i * d + j] = xt[j * n + i]; | 
| 391 | 0 |             FAISS_THROW_IF_NOT(std::isfinite(codebooks[i * d + j])); | 
| 392 | 0 |         } | 
| 393 | 0 |     } | 
| 394 |  |  | 
| 395 | 0 |     float output_recons_error = 0; | 
| 396 | 0 |     for (size_t j = 0; j < d; j++) { | 
| 397 | 0 |         output_recons_error += fvec_norm_L2sqr( | 
| 398 | 0 |                 xt.data() + total_codebook_size + n * j, | 
| 399 | 0 |                 n - total_codebook_size); | 
| 400 | 0 |     } | 
| 401 | 0 |     if (verbose) { | 
| 402 | 0 |         printf("  output quantization error %g\n", output_recons_error); | 
| 403 | 0 |     } | 
| 404 | 0 |     return output_recons_error; | 
| 405 | 0 | } | 
| 406 |  |  | 
| 407 | 0 | size_t ResidualQuantizer::memory_per_point(int beam_size) const { | 
| 408 | 0 |     if (beam_size < 0) { | 
| 409 | 0 |         beam_size = max_beam_size; | 
| 410 | 0 |     } | 
| 411 | 0 |     size_t mem; | 
| 412 | 0 |     mem = beam_size * d * 2 * sizeof(float); // size for 2 beams at a time | 
| 413 | 0 |     mem += beam_size * beam_size * | 
| 414 | 0 |             (sizeof(float) + sizeof(idx_t)); // size for 1 beam search result | 
| 415 | 0 |     return mem; | 
| 416 | 0 | } | 
| 417 |  |  | 
| 418 |  | /**************************************************************** | 
| 419 |  |  * Encoding | 
| 420 |  |  ****************************************************************/ | 
| 421 |  |  | 
| 422 |  | using namespace rq_encode_steps; | 
| 423 |  |  | 
| 424 |  | void ResidualQuantizer::compute_codes_add_centroids( | 
| 425 |  |         const float* x, | 
| 426 |  |         uint8_t* codes_out, | 
| 427 |  |         size_t n, | 
| 428 | 0 |         const float* centroids) const { | 
| 429 | 0 |     FAISS_THROW_IF_NOT_MSG(is_trained, "RQ is not trained yet."); | 
| 430 |  |  | 
| 431 |  |     // | 
| 432 | 0 |     size_t mem = memory_per_point(); | 
| 433 |  | 
 | 
| 434 | 0 |     size_t bs = max_mem_distances / mem; | 
| 435 | 0 |     if (bs == 0) { | 
| 436 | 0 |         bs = 1; // otherwise we can't do much | 
| 437 | 0 |     } | 
| 438 |  |  | 
| 439 |  |     // prepare memory pools | 
| 440 | 0 |     ComputeCodesAddCentroidsLUT0MemoryPool pool0; | 
| 441 | 0 |     ComputeCodesAddCentroidsLUT1MemoryPool pool1; | 
| 442 |  | 
 | 
| 443 | 0 |     for (size_t i0 = 0; i0 < n; i0 += bs) { | 
| 444 | 0 |         size_t i1 = std::min(n, i0 + bs); | 
| 445 | 0 |         const float* cent = nullptr; | 
| 446 | 0 |         if (centroids != nullptr) { | 
| 447 | 0 |             cent = centroids + i0 * d; | 
| 448 | 0 |         } | 
| 449 |  | 
 | 
| 450 | 0 |         if (use_beam_LUT == 0) { | 
| 451 | 0 |             compute_codes_add_centroids_mp_lut0( | 
| 452 | 0 |                     *this, | 
| 453 | 0 |                     x + i0 * d, | 
| 454 | 0 |                     codes_out + i0 * code_size, | 
| 455 | 0 |                     i1 - i0, | 
| 456 | 0 |                     cent, | 
| 457 | 0 |                     pool0); | 
| 458 | 0 |         } else if (use_beam_LUT == 1) { | 
| 459 | 0 |             compute_codes_add_centroids_mp_lut1( | 
| 460 | 0 |                     *this, | 
| 461 | 0 |                     x + i0 * d, | 
| 462 | 0 |                     codes_out + i0 * code_size, | 
| 463 | 0 |                     i1 - i0, | 
| 464 | 0 |                     cent, | 
| 465 | 0 |                     pool1); | 
| 466 | 0 |         } | 
| 467 | 0 |     } | 
| 468 | 0 | } | 
| 469 |  |  | 
void ResidualQuantizer::refine_beam(
        size_t n,
        size_t beam_size,
        const float* x,
        int out_beam_size,
        int32_t* out_codes,
        float* out_residuals,
        float* out_distances) const {
    // Thin wrapper around rq_encode_steps::refine_beam_mp that supplies a
    // fresh scratch-memory pool for the duration of the call.
    RefineBeamMemoryPool pool;
    refine_beam_mp(
            *this,
            n,
            beam_size,
            x,
            out_beam_size,
            out_codes,
            out_residuals,
            out_distances,
            pool);
}
| 490 |  |  | 
| 491 |  | /******************************************************************* | 
| 492 |  |  * Functions using the dot products between codebook entries | 
| 493 |  |  *******************************************************************/ | 
| 494 |  |  | 
void ResidualQuantizer::refine_beam_LUT(
        size_t n,
        const float* query_norms, // size n
        const float* query_cp,    //
        int out_beam_size,
        int32_t* out_codes,
        float* out_distances) const {
    // Thin wrapper around rq_encode_steps::refine_beam_LUT_mp (the
    // look-up-table-based variant) with a fresh scratch-memory pool.
    RefineBeamLUTMemoryPool pool;
    refine_beam_LUT_mp(
            *this,
            n,
            query_norms,
            query_cp,
            out_beam_size,
            out_codes,
            out_distances,
            pool);
}
| 513 |  |  | 
| 514 |  | } // namespace faiss |