/root/doris/contrib/faiss/faiss/IndexPQFastScan.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #include <faiss/IndexPQFastScan.h> |
9 | | |
10 | | #include <cassert> |
11 | | #include <memory> |
12 | | |
13 | | #include <faiss/impl/FaissAssert.h> |
14 | | #include <faiss/impl/pq4_fast_scan.h> |
15 | | #include <faiss/utils/utils.h> |
16 | | |
17 | | namespace faiss { |
18 | | |
19 | 0 | inline size_t roundup(size_t a, size_t b) { |
20 | 0 | return (a + b - 1) / b * b; |
21 | 0 | } |
22 | | |
23 | | IndexPQFastScan::IndexPQFastScan( |
24 | | int d, |
25 | | size_t M, |
26 | | size_t nbits, |
27 | | MetricType metric, |
28 | | int bbs) |
29 | 0 | : pq(d, M, nbits) { |
30 | 0 | init_fastscan(d, M, nbits, metric, bbs); |
31 | 0 | } |
32 | | |
33 | 0 | IndexPQFastScan::IndexPQFastScan(const IndexPQ& orig, int bbs) : pq(orig.pq) { |
34 | 0 | init_fastscan(orig.d, pq.M, pq.nbits, orig.metric_type, bbs); |
35 | 0 | ntotal = orig.ntotal; |
36 | 0 | ntotal2 = roundup(ntotal, bbs); |
37 | 0 | is_trained = orig.is_trained; |
38 | 0 | orig_codes = orig.codes.data(); |
39 | | |
40 | | // pack the codes |
41 | 0 | codes.resize(ntotal2 * M2 / 2); |
42 | 0 | pq4_pack_codes(orig.codes.data(), ntotal, M, ntotal2, bbs, M2, codes.get()); |
43 | 0 | } |
44 | | |
45 | 0 | void IndexPQFastScan::train(idx_t n, const float* x) { |
46 | 0 | if (is_trained) { |
47 | 0 | return; |
48 | 0 | } |
49 | 0 | pq.train(n, x); |
50 | 0 | is_trained = true; |
51 | 0 | } |
52 | | |
53 | | void IndexPQFastScan::compute_codes(uint8_t* codes, idx_t n, const float* x) |
54 | 0 | const { |
55 | 0 | pq.compute_codes(x, codes, n); |
56 | 0 | } |
57 | | |
58 | | void IndexPQFastScan::compute_float_LUT(float* lut, idx_t n, const float* x) |
59 | 0 | const { |
60 | 0 | if (metric_type == METRIC_L2) { |
61 | 0 | pq.compute_distance_tables(n, x, lut); |
62 | 0 | } else { |
63 | 0 | pq.compute_inner_prod_tables(n, x, lut); |
64 | 0 | } |
65 | 0 | } |
66 | | |
67 | 0 | void IndexPQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x) const { |
68 | 0 | pq.decode(bytes, x, n); |
69 | 0 | } |
70 | | |
71 | | } // namespace faiss |