/root/doris/contrib/faiss/faiss/IndexLSH.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #include <faiss/IndexLSH.h> |
9 | | |
10 | | #include <cstdio> |
11 | | #include <cstring> |
12 | | |
13 | | #include <algorithm> |
14 | | #include <memory> |
15 | | |
16 | | #include <faiss/impl/FaissAssert.h> |
17 | | #include <faiss/utils/hamming.h> |
18 | | |
19 | | namespace faiss { |
20 | | |
21 | | /*************************************************************** |
22 | | * IndexLSH |
23 | | ***************************************************************/ |
24 | | |
25 | | IndexLSH::IndexLSH(idx_t d, int nbits, bool rotate_data, bool train_thresholds) |
26 | 0 | : IndexFlatCodes((nbits + 7) / 8, d), |
27 | 0 | nbits(nbits), |
28 | 0 | rotate_data(rotate_data), |
29 | 0 | train_thresholds(train_thresholds), |
30 | 0 | rrot(d, nbits) { |
31 | 0 | is_trained = !train_thresholds; |
32 | |
|
33 | 0 | if (rotate_data) { |
34 | 0 | rrot.init(5); |
35 | 0 | } else { |
36 | 0 | FAISS_THROW_IF_NOT(d >= nbits); |
37 | 0 | } |
38 | 0 | } |
39 | | |
40 | 0 | IndexLSH::IndexLSH() : nbits(0), rotate_data(false), train_thresholds(false) {} |
41 | | |
42 | 0 | const float* IndexLSH::apply_preprocess(idx_t n, const float* x) const { |
43 | 0 | float* xt = nullptr; |
44 | 0 | if (rotate_data) { |
45 | | // also applies bias if exists |
46 | 0 | xt = rrot.apply(n, x); |
47 | 0 | } else if (d != nbits) { |
48 | 0 | assert(nbits < d); |
49 | 0 | xt = new float[nbits * n]; |
50 | 0 | float* xp = xt; |
51 | 0 | for (idx_t i = 0; i < n; i++) { |
52 | 0 | const float* xl = x + i * d; |
53 | 0 | for (int j = 0; j < nbits; j++) |
54 | 0 | *xp++ = xl[j]; |
55 | 0 | } |
56 | 0 | } |
57 | | |
58 | 0 | if (train_thresholds) { |
59 | 0 | if (xt == nullptr) { |
60 | 0 | xt = new float[nbits * n]; |
61 | 0 | memcpy(xt, x, sizeof(*x) * n * nbits); |
62 | 0 | } |
63 | |
|
64 | 0 | float* xp = xt; |
65 | 0 | for (idx_t i = 0; i < n; i++) |
66 | 0 | for (int j = 0; j < nbits; j++) |
67 | 0 | *xp++ -= thresholds[j]; |
68 | 0 | } |
69 | |
|
70 | 0 | return xt ? xt : x; |
71 | 0 | } |
72 | | |
73 | 0 | void IndexLSH::train(idx_t n, const float* x) { |
74 | 0 | if (train_thresholds) { |
75 | 0 | thresholds.resize(nbits); |
76 | 0 | train_thresholds = false; |
77 | 0 | const float* xt = apply_preprocess(n, x); |
78 | 0 | std::unique_ptr<const float[]> del(xt == x ? nullptr : xt); |
79 | 0 | train_thresholds = true; |
80 | |
|
81 | 0 | std::unique_ptr<float[]> transposed_x(new float[n * nbits]); |
82 | |
|
83 | 0 | for (idx_t i = 0; i < n; i++) |
84 | 0 | for (idx_t j = 0; j < nbits; j++) |
85 | 0 | transposed_x[j * n + i] = xt[i * nbits + j]; |
86 | |
|
87 | 0 | for (idx_t i = 0; i < nbits; i++) { |
88 | 0 | float* xi = transposed_x.get() + i * n; |
89 | | // std::nth_element |
90 | 0 | std::sort(xi, xi + n); |
91 | 0 | if (n % 2 == 1) |
92 | 0 | thresholds[i] = xi[n / 2]; |
93 | 0 | else |
94 | 0 | thresholds[i] = (xi[n / 2 - 1] + xi[n / 2]) / 2; |
95 | 0 | } |
96 | 0 | } |
97 | 0 | is_trained = true; |
98 | 0 | } |
99 | | |
100 | | void IndexLSH::search( |
101 | | idx_t n, |
102 | | const float* x, |
103 | | idx_t k, |
104 | | float* distances, |
105 | | idx_t* labels, |
106 | 0 | const SearchParameters* params) const { |
107 | 0 | FAISS_THROW_IF_NOT_MSG( |
108 | 0 | !params, "search params not supported for this index"); |
109 | 0 | FAISS_THROW_IF_NOT(k > 0); |
110 | 0 | FAISS_THROW_IF_NOT(is_trained); |
111 | 0 | const float* xt = apply_preprocess(n, x); |
112 | 0 | std::unique_ptr<const float[]> del(xt == x ? nullptr : xt); |
113 | |
|
114 | 0 | std::unique_ptr<uint8_t[]> qcodes(new uint8_t[n * code_size]); |
115 | |
|
116 | 0 | fvecs2bitvecs(xt, qcodes.get(), nbits, n); |
117 | |
|
118 | 0 | std::unique_ptr<int[]> idistances(new int[n * k]); |
119 | |
|
120 | 0 | int_maxheap_array_t res = {size_t(n), size_t(k), labels, idistances.get()}; |
121 | |
|
122 | 0 | hammings_knn_hc(&res, qcodes.get(), codes.data(), ntotal, code_size, true); |
123 | | |
124 | | // convert distances to floats |
125 | 0 | for (int i = 0; i < k * n; i++) |
126 | 0 | distances[i] = idistances[i]; |
127 | 0 | } |
128 | | |
129 | 0 | void IndexLSH::transfer_thresholds(LinearTransform* vt) { |
130 | 0 | if (!train_thresholds) |
131 | 0 | return; |
132 | 0 | FAISS_THROW_IF_NOT(nbits == vt->d_out); |
133 | 0 | if (!vt->have_bias) { |
134 | 0 | vt->b.resize(nbits, 0); |
135 | 0 | vt->have_bias = true; |
136 | 0 | } |
137 | 0 | for (int i = 0; i < nbits; i++) |
138 | 0 | vt->b[i] -= thresholds[i]; |
139 | 0 | train_thresholds = false; |
140 | 0 | thresholds.clear(); |
141 | 0 | } |
142 | | |
143 | 0 | void IndexLSH::sa_encode(idx_t n, const float* x, uint8_t* bytes) const { |
144 | 0 | FAISS_THROW_IF_NOT(is_trained); |
145 | 0 | const float* xt = apply_preprocess(n, x); |
146 | 0 | std::unique_ptr<const float[]> del(xt == x ? nullptr : xt); |
147 | 0 | fvecs2bitvecs(xt, bytes, nbits, n); |
148 | 0 | } |
149 | | |
150 | 0 | void IndexLSH::sa_decode(idx_t n, const uint8_t* bytes, float* x) const { |
151 | 0 | float* xt = x; |
152 | 0 | std::unique_ptr<float[]> del; |
153 | 0 | if (rotate_data || nbits != d) { |
154 | 0 | xt = new float[n * nbits]; |
155 | 0 | del.reset(xt); |
156 | 0 | } |
157 | 0 | bitvecs2fvecs(bytes, xt, nbits, n); |
158 | |
|
159 | 0 | if (train_thresholds) { |
160 | 0 | float* xp = xt; |
161 | 0 | for (idx_t i = 0; i < n; i++) { |
162 | 0 | for (int j = 0; j < nbits; j++) { |
163 | 0 | *xp++ += thresholds[j]; |
164 | 0 | } |
165 | 0 | } |
166 | 0 | } |
167 | |
|
168 | 0 | if (rotate_data) { |
169 | 0 | rrot.reverse_transform(n, xt, x); |
170 | 0 | } else if (nbits != d) { |
171 | 0 | for (idx_t i = 0; i < n; i++) { |
172 | 0 | memcpy(x + i * d, xt + i * nbits, nbits * sizeof(xt[0])); |
173 | 0 | } |
174 | 0 | } |
175 | 0 | } |
176 | | |
177 | | } // namespace faiss |