/root/doris/contrib/faiss/faiss/impl/RaBitQuantizer.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #include <faiss/impl/RaBitQuantizer.h> |
9 | | |
10 | | #include <algorithm> |
11 | | #include <cmath> |
12 | | #include <cstring> |
13 | | #include <limits> |
14 | | #include <memory> |
15 | | #include <vector> |
16 | | |
17 | | #include <faiss/impl/FaissAssert.h> |
18 | | #include <faiss/utils/distances.h> |
19 | | |
20 | | namespace faiss { |
21 | | |
22 | | struct FactorsData { |
23 | | // ||or - c||^2 - ((metric==IP) ? ||or||^2 : 0) |
24 | | float or_minus_c_l2sqr = 0; |
25 | | float dp_multiplier = 0; |
26 | | }; |
27 | | |
28 | | struct QueryFactorsData { |
29 | | float c1 = 0; |
30 | | float c2 = 0; |
31 | | float c34 = 0; |
32 | | |
33 | | float qr_to_c_L2sqr = 0; |
34 | | float qr_norm_L2sqr = 0; |
35 | | }; |
36 | | |
37 | 0 | static size_t get_code_size(const size_t d) { |
38 | 0 | return (d + 7) / 8 + sizeof(FactorsData); |
39 | 0 | } |
40 | | |
41 | | RaBitQuantizer::RaBitQuantizer(size_t d, MetricType metric) |
42 | 0 | : Quantizer(d, get_code_size(d)), metric_type{metric} {} |
43 | | |
44 | 0 | void RaBitQuantizer::train(size_t n, const float* x) { |
45 | | // does nothing |
46 | 0 | } |
47 | | |
48 | | void RaBitQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n) |
49 | 0 | const { |
50 | 0 | compute_codes_core(x, codes, n, centroid); |
51 | 0 | } |
52 | | |
53 | | void RaBitQuantizer::compute_codes_core( |
54 | | const float* x, |
55 | | uint8_t* codes, |
56 | | size_t n, |
57 | 0 | const float* centroid_in) const { |
58 | 0 | FAISS_ASSERT(codes != nullptr); |
59 | 0 | FAISS_ASSERT(x != nullptr); |
60 | 0 | FAISS_ASSERT( |
61 | 0 | (metric_type == MetricType::METRIC_L2 || |
62 | 0 | metric_type == MetricType::METRIC_INNER_PRODUCT)); |
63 | | |
64 | 0 | if (n == 0) { |
65 | 0 | return; |
66 | 0 | } |
67 | | |
68 | | // compute some helper constants |
69 | 0 | const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d)); |
70 | | |
71 | | // compute codes |
72 | 0 | #pragma omp parallel for if (n > 1000) |
73 | 0 | for (int64_t i = 0; i < n; i++) { |
74 | | // ||or - c||^2 |
75 | 0 | float norm_L2sqr = 0; |
76 | | // ||or||^2, which is equal to ||P(or)||^2 and ||P^(-1)(or)||^2 |
77 | 0 | float or_L2sqr = 0; |
78 | | // dot product |
79 | 0 | float dp_oO = 0; |
80 | | |
81 | | // the code |
82 | 0 | uint8_t* code = codes + i * code_size; |
83 | 0 | FactorsData* fac = reinterpret_cast<FactorsData*>(code + (d + 7) / 8); |
84 | | |
85 | | // cleanup it |
86 | 0 | if (code != nullptr) { |
87 | 0 | memset(code, 0, code_size); |
88 | 0 | } |
89 | |
|
90 | 0 | for (size_t j = 0; j < d; j++) { |
91 | 0 | const float or_minus_c = x[i * d + j] - |
92 | 0 | ((centroid_in == nullptr) ? 0 : centroid_in[j]); |
93 | 0 | norm_L2sqr += or_minus_c * or_minus_c; |
94 | 0 | or_L2sqr += x[i * d + j] * x[i * d + j]; |
95 | |
|
96 | 0 | const bool xb = (or_minus_c > 0); |
97 | |
|
98 | 0 | dp_oO += xb ? or_minus_c : (-or_minus_c); |
99 | | |
100 | | // store the output data |
101 | 0 | if (code != nullptr) { |
102 | 0 | if (xb) { |
103 | | // enable a particular bit |
104 | 0 | code[j / 8] |= (1 << (j % 8)); |
105 | 0 | } |
106 | 0 | } |
107 | 0 | } |
108 | | |
109 | | // compute factors |
110 | | |
111 | | // compute the inverse norm |
112 | 0 | const float inv_norm_L2 = |
113 | 0 | (std::abs(norm_L2sqr) < std::numeric_limits<float>::epsilon()) |
114 | 0 | ? 1.0f |
115 | 0 | : (1.0f / std::sqrt(norm_L2sqr)); |
116 | 0 | dp_oO *= inv_norm_L2; |
117 | 0 | dp_oO *= inv_d_sqrt; |
118 | |
|
119 | 0 | const float inv_dp_oO = |
120 | 0 | (std::abs(dp_oO) < std::numeric_limits<float>::epsilon()) |
121 | 0 | ? 1.0f |
122 | 0 | : (1.0f / dp_oO); |
123 | |
|
124 | 0 | fac->or_minus_c_l2sqr = norm_L2sqr; |
125 | 0 | if (metric_type == MetricType::METRIC_INNER_PRODUCT) { |
126 | 0 | fac->or_minus_c_l2sqr -= or_L2sqr; |
127 | 0 | } |
128 | |
|
129 | 0 | fac->dp_multiplier = inv_dp_oO * std::sqrt(norm_L2sqr); |
130 | 0 | } |
131 | 0 | } |
132 | | |
133 | 0 | void RaBitQuantizer::decode(const uint8_t* codes, float* x, size_t n) const { |
134 | 0 | decode_core(codes, x, n, centroid); |
135 | 0 | } |
136 | | |
137 | | void RaBitQuantizer::decode_core( |
138 | | const uint8_t* codes, |
139 | | float* x, |
140 | | size_t n, |
141 | 0 | const float* centroid_in) const { |
142 | 0 | FAISS_ASSERT(codes != nullptr); |
143 | 0 | FAISS_ASSERT(x != nullptr); |
144 | | |
145 | 0 | const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d)); |
146 | |
|
147 | 0 | #pragma omp parallel for if (n > 1000) |
148 | 0 | for (int64_t i = 0; i < n; i++) { |
149 | 0 | const uint8_t* code = codes + i * code_size; |
150 | | |
151 | | // split the code into parts |
152 | 0 | const uint8_t* binary_data = code; |
153 | 0 | const FactorsData* fac = |
154 | 0 | reinterpret_cast<const FactorsData*>(code + (d + 7) / 8); |
155 | | |
156 | | // |
157 | 0 | for (size_t j = 0; j < d; j++) { |
158 | | // extract i-th bit |
159 | 0 | const uint8_t masker = (1 << (j % 8)); |
160 | 0 | const float bit = ((binary_data[j / 8] & masker) == masker) ? 1 : 0; |
161 | | |
162 | | // compute the output code |
163 | 0 | x[i * d + j] = (bit - 0.5f) * fac->dp_multiplier * 2 * inv_d_sqrt + |
164 | 0 | ((centroid_in == nullptr) ? 0 : centroid_in[j]); |
165 | 0 | } |
166 | 0 | } |
167 | 0 | } |
168 | | |
169 | | struct RaBitDistanceComputer : FlatCodesDistanceComputer { |
170 | | // dimensionality |
171 | | size_t d = 0; |
172 | | // a centroid to use |
173 | | const float* centroid = nullptr; |
174 | | |
175 | | // the metric |
176 | | MetricType metric_type = MetricType::METRIC_L2; |
177 | | |
178 | | RaBitDistanceComputer(); |
179 | | |
180 | | float symmetric_dis(idx_t i, idx_t j) override; |
181 | | }; |
182 | | |
183 | 0 | RaBitDistanceComputer::RaBitDistanceComputer() = default; |
184 | | |
185 | 0 | float RaBitDistanceComputer::symmetric_dis(idx_t i, idx_t j) { |
186 | 0 | FAISS_THROW_MSG("Not implemented"); |
187 | 0 | } |
188 | | |
189 | | struct RaBitDistanceComputerNotQ : RaBitDistanceComputer { |
190 | | // the rotated query (qr - c) |
191 | | std::vector<float> rotated_q; |
192 | | // some additional numbers for the query |
193 | | QueryFactorsData query_fac; |
194 | | |
195 | | RaBitDistanceComputerNotQ(); |
196 | | |
197 | | float distance_to_code(const uint8_t* code) override; |
198 | | |
199 | | void set_query(const float* x) override; |
200 | | }; |
201 | | |
202 | 0 | RaBitDistanceComputerNotQ::RaBitDistanceComputerNotQ() = default; |
203 | | |
204 | 0 | float RaBitDistanceComputerNotQ::distance_to_code(const uint8_t* code) { |
205 | 0 | FAISS_ASSERT(code != nullptr); |
206 | 0 | FAISS_ASSERT( |
207 | 0 | (metric_type == MetricType::METRIC_L2 || |
208 | 0 | metric_type == MetricType::METRIC_INNER_PRODUCT)); |
209 | | |
210 | | // split the code into parts |
211 | 0 | const uint8_t* binary_data = code; |
212 | 0 | const FactorsData* fac = |
213 | 0 | reinterpret_cast<const FactorsData*>(code + (d + 7) / 8); |
214 | | |
215 | | // this is the baseline code |
216 | | // |
217 | | // compute <q,o> using floats |
218 | 0 | float dot_qo = 0; |
219 | | // It was a willful decision (after the discussion) to not to pre-cache |
220 | | // the sum of all bits, just in order to reduce the overhead per vector. |
221 | 0 | uint64_t sum_q = 0; |
222 | 0 | for (size_t i = 0; i < d; i++) { |
223 | | // extract i-th bit |
224 | 0 | const uint8_t masker = (1 << (i % 8)); |
225 | 0 | const bool b_bit = ((binary_data[i / 8] & masker) == masker); |
226 | | |
227 | | // accumulate dp |
228 | 0 | dot_qo += (b_bit) ? rotated_q[i] : 0; |
229 | | // accumulate sum-of-bits |
230 | 0 | sum_q += (b_bit) ? 1 : 0; |
231 | 0 | } |
232 | |
|
233 | 0 | float final_dot = 0; |
234 | | // dot-product itself |
235 | 0 | final_dot += query_fac.c1 * dot_qo; |
236 | | // normalizer coefficients |
237 | 0 | final_dot += query_fac.c2 * sum_q; |
238 | | // normalizer coefficients |
239 | 0 | final_dot -= query_fac.c34; |
240 | | |
241 | | // this is ||or - c||^2 - (IP ? ||or||^2 : 0) |
242 | 0 | const float or_c_l2sqr = fac->or_minus_c_l2sqr; |
243 | | |
244 | | // pre_dist = ||or - c||^2 + ||qr - c||^2 - |
245 | | // 2 * ||or - c|| * ||qr - c|| * <q,o> - (IP ? ||or||^2 : 0) |
246 | 0 | const float pre_dist = or_c_l2sqr + query_fac.qr_to_c_L2sqr - |
247 | 0 | 2 * fac->dp_multiplier * final_dot; |
248 | |
|
249 | 0 | if (metric_type == MetricType::METRIC_L2) { |
250 | | // ||or - q||^ 2 |
251 | 0 | return pre_dist; |
252 | 0 | } else { |
253 | | // metric == MetricType::METRIC_INNER_PRODUCT |
254 | | |
255 | | // this is ||q||^2 |
256 | 0 | const float query_norm_sqr = query_fac.qr_norm_L2sqr; |
257 | | |
258 | | // 2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2) |
259 | 0 | return -0.5f * (pre_dist - query_norm_sqr); |
260 | 0 | } |
261 | 0 | } |
262 | | |
263 | 0 | void RaBitDistanceComputerNotQ::set_query(const float* x) { |
264 | 0 | FAISS_ASSERT(x != nullptr); |
265 | 0 | FAISS_ASSERT( |
266 | 0 | (metric_type == MetricType::METRIC_L2 || |
267 | 0 | metric_type == MetricType::METRIC_INNER_PRODUCT)); |
268 | | |
269 | | // compute the distance from the query to the centroid |
270 | 0 | if (centroid != nullptr) { |
271 | 0 | query_fac.qr_to_c_L2sqr = fvec_L2sqr(x, centroid, d); |
272 | 0 | } else { |
273 | 0 | query_fac.qr_to_c_L2sqr = fvec_norm_L2sqr(x, d); |
274 | 0 | } |
275 | | |
276 | | // subtract c, obtain P^(-1)(qr - c) |
277 | 0 | rotated_q.resize(d); |
278 | 0 | for (size_t i = 0; i < d; i++) { |
279 | 0 | rotated_q[i] = x[i] - ((centroid == nullptr) ? 0 : centroid[i]); |
280 | 0 | } |
281 | | |
282 | | // compute some numbers |
283 | 0 | const float inv_d = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d)); |
284 | | |
285 | | // do not quantize the query |
286 | 0 | float sum_q = 0; |
287 | 0 | for (size_t i = 0; i < d; i++) { |
288 | 0 | sum_q += rotated_q[i]; |
289 | 0 | } |
290 | |
|
291 | 0 | query_fac.c1 = 2 * inv_d; |
292 | 0 | query_fac.c2 = 0; |
293 | 0 | query_fac.c34 = sum_q * inv_d; |
294 | |
|
295 | 0 | if (metric_type == MetricType::METRIC_INNER_PRODUCT) { |
296 | | // precompute if needed |
297 | 0 | query_fac.qr_norm_L2sqr = fvec_norm_L2sqr(x, d); |
298 | 0 | } |
299 | 0 | } |
300 | | |
301 | | // |
302 | | struct RaBitDistanceComputerQ : RaBitDistanceComputer { |
303 | | // the rotated and quantized query (qr - c) |
304 | | std::vector<uint8_t> rotated_qq; |
305 | | // we're using the proposed relayout-ed scheme from 3.3 that allows |
306 | | // using popcounts for computing the distance. |
307 | | std::vector<uint8_t> rearranged_rotated_qq; |
308 | | // some additional numbers for the query |
309 | | QueryFactorsData query_fac; |
310 | | |
311 | | // the number of bits for SQ quantization of the query (qb > 0) |
312 | | uint8_t qb = 8; |
313 | | // the smallest value divisible by 8 that is not smaller than dim |
314 | | size_t popcount_aligned_dim = 0; |
315 | | |
316 | | RaBitDistanceComputerQ(); |
317 | | |
318 | | float distance_to_code(const uint8_t* code) override; |
319 | | |
320 | | void set_query(const float* x) override; |
321 | | }; |
322 | | |
323 | 0 | RaBitDistanceComputerQ::RaBitDistanceComputerQ() = default; |
324 | | |
325 | 0 | float RaBitDistanceComputerQ::distance_to_code(const uint8_t* code) { |
326 | 0 | FAISS_ASSERT(code != nullptr); |
327 | 0 | FAISS_ASSERT( |
328 | 0 | (metric_type == MetricType::METRIC_L2 || |
329 | 0 | metric_type == MetricType::METRIC_INNER_PRODUCT)); |
330 | | |
331 | | // split the code into parts |
332 | 0 | const uint8_t* binary_data = code; |
333 | 0 | const FactorsData* fac = |
334 | 0 | reinterpret_cast<const FactorsData*>(code + (d + 7) / 8); |
335 | | |
336 | | // // this is the baseline code |
337 | | // // |
338 | | // // compute <q,o> using integers |
339 | | // size_t dot_qo = 0; |
340 | | // for (size_t i = 0; i < d; i++) { |
341 | | // // extract i-th bit |
342 | | // const uint8_t masker = (1 << (i % 8)); |
343 | | // const uint8_t bit = ((binary_data[i / 8] & masker) == masker) ? 1 : |
344 | | // 0; |
345 | | // |
346 | | // // accumulate dp |
347 | | // dot_qo += bit * rotated_qq[i]; |
348 | | // } |
349 | | |
350 | | // this is the scheme for popcount |
351 | 0 | const size_t di_8b = (d + 7) / 8; |
352 | 0 | const size_t di_64b = (di_8b / 8) * 8; |
353 | |
|
354 | 0 | uint64_t dot_qo = 0; |
355 | 0 | for (size_t j = 0; j < qb; j++) { |
356 | 0 | const uint8_t* query_j = rearranged_rotated_qq.data() + j * di_8b; |
357 | | |
358 | | // process 64-bit popcounts |
359 | 0 | uint64_t count_dot = 0; |
360 | 0 | for (size_t i = 0; i < di_64b; i += 8) { |
361 | 0 | const auto qv = *(const uint64_t*)(query_j + i); |
362 | 0 | const auto yv = *(const uint64_t*)(binary_data + i); |
363 | 0 | count_dot += __builtin_popcountll(qv & yv); |
364 | 0 | } |
365 | | |
366 | | // process leftovers |
367 | 0 | for (size_t i = di_64b; i < di_8b; i++) { |
368 | 0 | const auto qv = *(query_j + i); |
369 | 0 | const auto yv = *(binary_data + i); |
370 | 0 | count_dot += __builtin_popcount(qv & yv); |
371 | 0 | } |
372 | |
|
373 | 0 | dot_qo += (count_dot << j); |
374 | 0 | } |
375 | | |
376 | | // It was a willful decision (after the discussion) to not to pre-cache |
377 | | // the sum of all bits, just in order to reduce the overhead per vector. |
378 | 0 | uint64_t sum_q = 0; |
379 | 0 | { |
380 | | // process 64-bit popcounts |
381 | 0 | for (size_t i = 0; i < di_64b; i += 8) { |
382 | 0 | const auto yv = *(const uint64_t*)(binary_data + i); |
383 | 0 | sum_q += __builtin_popcountll(yv); |
384 | 0 | } |
385 | | |
386 | | // process leftovers |
387 | 0 | for (size_t i = di_64b; i < di_8b; i++) { |
388 | 0 | const auto yv = *(binary_data + i); |
389 | 0 | sum_q += __builtin_popcount(yv); |
390 | 0 | } |
391 | 0 | } |
392 | |
|
393 | 0 | float final_dot = 0; |
394 | | // dot-product itself |
395 | 0 | final_dot += query_fac.c1 * dot_qo; |
396 | | // normalizer coefficients |
397 | 0 | final_dot += query_fac.c2 * sum_q; |
398 | | // normalizer coefficients |
399 | 0 | final_dot -= query_fac.c34; |
400 | | |
401 | | // this is ||or - c||^2 - (IP ? ||or||^2 : 0) |
402 | 0 | const float or_c_l2sqr = fac->or_minus_c_l2sqr; |
403 | | |
404 | | // pre_dist = ||or - c||^2 + ||qr - c||^2 - |
405 | | // 2 * ||or - c|| * ||qr - c|| * <q,o> - (IP ? ||or||^2 : 0) |
406 | 0 | const float pre_dist = or_c_l2sqr + query_fac.qr_to_c_L2sqr - |
407 | 0 | 2 * fac->dp_multiplier * final_dot; |
408 | |
|
409 | 0 | if (metric_type == MetricType::METRIC_L2) { |
410 | | // ||or - q||^ 2 |
411 | 0 | return pre_dist; |
412 | 0 | } else { |
413 | | // metric == MetricType::METRIC_INNER_PRODUCT |
414 | | |
415 | | // this is ||q||^2 |
416 | 0 | const float query_norm_sqr = query_fac.qr_norm_L2sqr; |
417 | | |
418 | | // 2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2) |
419 | 0 | return -0.5f * (pre_dist - query_norm_sqr); |
420 | 0 | } |
421 | 0 | } |
422 | | |
423 | 0 | void RaBitDistanceComputerQ::set_query(const float* x) { |
424 | 0 | FAISS_ASSERT(x != nullptr); |
425 | 0 | FAISS_ASSERT( |
426 | 0 | (metric_type == MetricType::METRIC_L2 || |
427 | 0 | metric_type == MetricType::METRIC_INNER_PRODUCT)); |
428 | | |
429 | | // compute the distance from the query to the centroid |
430 | 0 | if (centroid != nullptr) { |
431 | 0 | query_fac.qr_to_c_L2sqr = fvec_L2sqr(x, centroid, d); |
432 | 0 | } else { |
433 | 0 | query_fac.qr_to_c_L2sqr = fvec_norm_L2sqr(x, d); |
434 | 0 | } |
435 | | |
436 | | // allocate space |
437 | 0 | rotated_qq.resize(d); |
438 | | |
439 | | // rotate the query |
440 | 0 | std::vector<float> rotated_q(d); |
441 | 0 | for (size_t i = 0; i < d; i++) { |
442 | 0 | rotated_q[i] = x[i] - ((centroid == nullptr) ? 0 : centroid[i]); |
443 | 0 | } |
444 | | |
445 | | // compute some numbers |
446 | 0 | const float inv_d = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d)); |
447 | | |
448 | | // quantize the query. compute min and max |
449 | 0 | float v_min = std::numeric_limits<float>::max(); |
450 | 0 | float v_max = std::numeric_limits<float>::lowest(); |
451 | 0 | for (size_t i = 0; i < d; i++) { |
452 | 0 | const float v_q = rotated_q[i]; |
453 | 0 | v_min = std::min(v_min, v_q); |
454 | 0 | v_max = std::max(v_max, v_q); |
455 | 0 | } |
456 | |
|
457 | 0 | const float pow_2_qb = 1 << qb; |
458 | |
|
459 | 0 | const float delta = (v_max - v_min) / (pow_2_qb - 1); |
460 | 0 | const float inv_delta = 1.0f / delta; |
461 | |
|
462 | 0 | size_t sum_qq = 0; |
463 | 0 | for (int32_t i = 0; i < d; i++) { |
464 | 0 | const float v_q = rotated_q[i]; |
465 | | |
466 | | // a default non-randomized SQ |
467 | 0 | const int v_qq = std::round((v_q - v_min) * inv_delta); |
468 | |
|
469 | 0 | rotated_qq[i] = std::min(255, std::max(0, v_qq)); |
470 | 0 | sum_qq += v_qq; |
471 | 0 | } |
472 | | |
473 | | // rearrange the query vector |
474 | 0 | popcount_aligned_dim = ((d + 7) / 8) * 8; |
475 | 0 | size_t offset = (d + 7) / 8; |
476 | |
|
477 | 0 | rearranged_rotated_qq.resize(offset * qb); |
478 | 0 | std::fill(rearranged_rotated_qq.begin(), rearranged_rotated_qq.end(), 0); |
479 | |
|
480 | 0 | for (size_t idim = 0; idim < d; idim++) { |
481 | 0 | for (size_t iv = 0; iv < qb; iv++) { |
482 | 0 | const bool bit = ((rotated_qq[idim] & (1 << iv)) != 0); |
483 | 0 | rearranged_rotated_qq[iv * offset + idim / 8] |= |
484 | 0 | bit ? (1 << (idim % 8)) : 0; |
485 | 0 | } |
486 | 0 | } |
487 | |
|
488 | 0 | query_fac.c1 = 2 * delta * inv_d; |
489 | 0 | query_fac.c2 = 2 * v_min * inv_d; |
490 | 0 | query_fac.c34 = inv_d * (delta * sum_qq + d * v_min); |
491 | |
|
492 | 0 | if (metric_type == MetricType::METRIC_INNER_PRODUCT) { |
493 | | // precompute if needed |
494 | 0 | query_fac.qr_norm_L2sqr = fvec_norm_L2sqr(x, d); |
495 | 0 | } |
496 | 0 | } |
497 | | |
498 | | FlatCodesDistanceComputer* RaBitQuantizer::get_distance_computer( |
499 | | uint8_t qb, |
500 | 0 | const float* centroid_in) const { |
501 | 0 | if (qb == 0) { |
502 | 0 | auto dc = std::make_unique<RaBitDistanceComputerNotQ>(); |
503 | 0 | dc->metric_type = metric_type; |
504 | 0 | dc->d = d; |
505 | 0 | dc->centroid = centroid_in; |
506 | |
|
507 | 0 | return dc.release(); |
508 | 0 | } else { |
509 | 0 | auto dc = std::make_unique<RaBitDistanceComputerQ>(); |
510 | 0 | dc->metric_type = metric_type; |
511 | 0 | dc->d = d; |
512 | 0 | dc->centroid = centroid_in; |
513 | 0 | dc->qb = qb; |
514 | |
|
515 | 0 | return dc.release(); |
516 | 0 | } |
517 | 0 | } |
518 | | |
519 | | } // namespace faiss |