contrib/faiss/faiss/utils/quantize_lut.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #include <faiss/utils/quantize_lut.h> |
9 | | |
10 | | #include <algorithm> |
11 | | #include <cmath> |
12 | | #include <cstring> |
13 | | #include <vector> |
14 | | |
15 | | #include <faiss/impl/FaissAssert.h> |
16 | | |
17 | | namespace faiss { |
18 | | |
19 | | namespace quantize_lut { |
20 | | |
21 | | /****************************************************** |
22 | | * Quantize look-up tables |
23 | | ******************************************************/ |
24 | | |
25 | | namespace { |
26 | | |
27 | | // there can be NaNs in tables, they should be ignored |
28 | 0 | float tab_min(const float* tab, size_t n) { |
29 | 0 | float min = HUGE_VAL; |
30 | 0 | for (int i = 0; i < n; i++) { |
31 | 0 | if (tab[i] < min) |
32 | 0 | min = tab[i]; |
33 | 0 | } |
34 | 0 | return min; |
35 | 0 | } |
36 | | |
37 | 0 | float tab_max(const float* tab, size_t n) { |
38 | 0 | float max = -HUGE_VAL; |
39 | 0 | for (int i = 0; i < n; i++) { |
40 | 0 | if (tab[i] > max) |
41 | 0 | max = tab[i]; |
42 | 0 | } |
43 | 0 | return max; |
44 | 0 | } |
45 | | |
46 | 0 | void round_tab(float* tab, size_t n, float a, float bi) { |
47 | 0 | for (int i = 0; i < n; i++) { |
48 | 0 | tab[i] = floorf((tab[i] - bi) * a + 0.5); |
49 | 0 | } |
50 | 0 | } |
51 | | |
52 | | template <typename T> |
53 | 0 | void round_tab(const float* tab, size_t n, float a, float bi, T* tab_out) { |
54 | 0 | for (int i = 0; i < n; i++) { |
55 | 0 | tab_out[i] = (T)floorf((tab[i] - bi) * a + 0.5); |
56 | 0 | } |
57 | 0 | } Unexecuted instantiation: quantize_lut.cpp:_ZN5faiss12quantize_lut12_GLOBAL__N_19round_tabIhEEvPKfmffPT_ Unexecuted instantiation: quantize_lut.cpp:_ZN5faiss12quantize_lut12_GLOBAL__N_19round_tabItEEvPKfmffPT_ |
58 | | |
59 | | } // anonymous namespace |
60 | | |
61 | | void round_uint8_per_column( |
62 | | float* tab, |
63 | | size_t n, |
64 | | size_t d, |
65 | | float* a_out, |
66 | 0 | float* b_out) { |
67 | 0 | float max_span = 0; |
68 | 0 | std::vector<float> mins(n); |
69 | 0 | for (int i = 0; i < n; i++) { |
70 | 0 | mins[i] = tab_min(tab + i * d, d); |
71 | 0 | float span = tab_max(tab + i * d, d) - mins[i]; |
72 | 0 | if (span > max_span) { |
73 | 0 | max_span = span; |
74 | 0 | } |
75 | 0 | } |
76 | 0 | float a = 255 / max_span; |
77 | 0 | float b = 0; |
78 | 0 | for (int i = 0; i < n; i++) { |
79 | 0 | b += mins[i]; |
80 | 0 | round_tab(tab + i * d, d, a, mins[i]); |
81 | 0 | } |
82 | 0 | if (a_out) |
83 | 0 | *a_out = a; |
84 | 0 | if (b_out) |
85 | 0 | *b_out = b; |
86 | 0 | } |
87 | | |
88 | | void round_uint8_per_column_multi( |
89 | | float* tab, |
90 | | size_t m, |
91 | | size_t n, |
92 | | size_t d, |
93 | | float* a_out, |
94 | 0 | float* b_out) { |
95 | 0 | float max_span = 0; |
96 | 0 | std::vector<float> mins(n); |
97 | 0 | for (int i = 0; i < n; i++) { |
98 | 0 | float min_i = HUGE_VAL; |
99 | 0 | float max_i = -HUGE_VAL; |
100 | 0 | for (int j = 0; j < m; j++) { |
101 | 0 | min_i = std::min(min_i, tab_min(tab + (j * n + i) * d, d)); |
102 | 0 | max_i = std::max(max_i, tab_max(tab + (j * n + i) * d, d)); |
103 | 0 | } |
104 | 0 | mins[i] = min_i; |
105 | 0 | float span = max_i - min_i; |
106 | 0 | if (span > max_span) { |
107 | 0 | max_span = span; |
108 | 0 | } |
109 | 0 | } |
110 | 0 | float a = 255 / max_span; |
111 | 0 | float b = 0; |
112 | 0 | for (int i = 0; i < n; i++) { |
113 | 0 | b += mins[i]; |
114 | 0 | for (int j = 0; j < m; j++) { |
115 | 0 | round_tab(tab + (j * n + i) * d, d, a, mins[i]); |
116 | 0 | } |
117 | 0 | } |
118 | 0 | if (a_out) |
119 | 0 | *a_out = a; |
120 | 0 | if (b_out) |
121 | 0 | *b_out = b; |
122 | 0 | } |
123 | | |
124 | | // translation of |
125 | | // https://github.com/fairinternal/faiss_improvements/blob/7122c3cc6ddb0a371d8aa6f1309cd8bcf2335e61/LUT_quantization.ipynb |
126 | | void quantize_LUT_and_bias( |
127 | | size_t nprobe, |
128 | | size_t M, |
129 | | size_t ksub, |
130 | | bool lut_is_3d, |
131 | | const float* LUT, |
132 | | const float* bias, |
133 | | uint8_t* LUTq, |
134 | | size_t M2, |
135 | | uint16_t* biasq, |
136 | | float* a_out, |
137 | 0 | float* b_out) { |
138 | 0 | float a, b; |
139 | 0 | if (!bias) { |
140 | 0 | FAISS_THROW_IF_NOT(!lut_is_3d); |
141 | 0 | std::vector<float> mins(M); |
142 | 0 | float max_span_LUT = -HUGE_VAL, max_span_dis = 0; |
143 | 0 | b = 0; |
144 | 0 | for (int i = 0; i < M; i++) { |
145 | 0 | mins[i] = tab_min(LUT + i * ksub, ksub); |
146 | 0 | float span = tab_max(LUT + i * ksub, ksub) - mins[i]; |
147 | 0 | max_span_LUT = std::max(max_span_LUT, span); |
148 | 0 | max_span_dis += span; |
149 | 0 | b += mins[i]; |
150 | 0 | } |
151 | 0 | a = std::min(255 / max_span_LUT, 65535 / max_span_dis); |
152 | |
|
153 | 0 | for (int i = 0; i < M; i++) { |
154 | 0 | round_tab(LUT + i * ksub, ksub, a, mins[i], LUTq + i * ksub); |
155 | 0 | } |
156 | 0 | memset(LUTq + M * ksub, 0, ksub * (M2 - M)); |
157 | 0 | } else if (!lut_is_3d) { |
158 | 0 | std::vector<float> mins(M); |
159 | 0 | float max_span_LUT = -HUGE_VAL, max_span_dis; |
160 | 0 | float bias_min = tab_min(bias, nprobe); |
161 | 0 | float bias_max = tab_max(bias, nprobe); |
162 | 0 | max_span_dis = bias_max - bias_min; |
163 | 0 | b = 0; |
164 | 0 | for (int i = 0; i < M; i++) { |
165 | 0 | mins[i] = tab_min(LUT + i * ksub, ksub); |
166 | 0 | float span = tab_max(LUT + i * ksub, ksub) - mins[i]; |
167 | 0 | max_span_LUT = std::max(max_span_LUT, span); |
168 | 0 | max_span_dis += span; |
169 | 0 | b += mins[i]; |
170 | 0 | } |
171 | 0 | a = std::min(255 / max_span_LUT, 65535 / max_span_dis); |
172 | 0 | b += bias_min; |
173 | |
|
174 | 0 | for (int i = 0; i < M; i++) { |
175 | 0 | round_tab(LUT + i * ksub, ksub, a, mins[i], LUTq + i * ksub); |
176 | 0 | } |
177 | 0 | memset(LUTq + M * ksub, 0, ksub * (M2 - M)); |
178 | 0 | round_tab(bias, nprobe, a, bias_min, biasq); |
179 | |
|
180 | 0 | } else if (biasq) { |
181 | | // LUT is 3D |
182 | 0 | std::vector<float> mins(nprobe * M); |
183 | 0 | std::vector<float> bias2(nprobe); |
184 | 0 | float bias_min = tab_min(bias, nprobe); |
185 | 0 | float max_span_LUT = -HUGE_VAL, max_span_dis = -HUGE_VAL; |
186 | |
|
187 | 0 | b = HUGE_VAL; |
188 | 0 | size_t ij = 0; |
189 | 0 | for (int j = 0; j < nprobe; j++) { |
190 | 0 | float max_span_dis_j = bias[j] - bias_min; |
191 | 0 | float b2j = bias[j]; |
192 | 0 | for (int i = 0; i < M; i++) { |
193 | 0 | mins[ij] = tab_min(LUT + ij * ksub, ksub); |
194 | 0 | float span = tab_max(LUT + ij * ksub, ksub) - mins[ij]; |
195 | 0 | max_span_LUT = std::max(max_span_LUT, span); |
196 | 0 | max_span_dis_j += span; |
197 | 0 | b2j += mins[ij]; |
198 | 0 | ij++; |
199 | 0 | } |
200 | 0 | max_span_dis = std::max(max_span_dis, max_span_dis_j); |
201 | 0 | bias2[j] = b2j; |
202 | 0 | b = std::min(b, b2j); |
203 | 0 | } |
204 | |
|
205 | 0 | a = std::min(255 / max_span_LUT, 65535 / max_span_dis); |
206 | |
|
207 | 0 | ij = 0; |
208 | 0 | size_t ij_2 = 0; |
209 | 0 | for (int j = 0; j < nprobe; j++) { |
210 | 0 | for (int i = 0; i < M; i++) { |
211 | 0 | round_tab( |
212 | 0 | LUT + ij * ksub, ksub, a, mins[ij], LUTq + ij_2 * ksub); |
213 | 0 | ij++; |
214 | 0 | ij_2++; |
215 | 0 | } |
216 | 0 | memset(LUTq + ij_2 * ksub, 0, ksub * (M2 - M)); |
217 | 0 | ij_2 += M2 - M; |
218 | 0 | } |
219 | |
|
220 | 0 | round_tab(bias2.data(), nprobe, a, b, biasq); |
221 | |
|
222 | 0 | } else { // !biasq |
223 | | // then we integrate the bias into the LUTs |
224 | 0 | std::vector<float> LUT2_storage(nprobe * M * ksub); |
225 | 0 | float* LUT2 = LUT2_storage.data(); |
226 | 0 | size_t ijc = 0; |
227 | 0 | for (int j = 0; j < nprobe; j++) { |
228 | 0 | float bias_j = bias[j] / M; |
229 | 0 | for (int i = 0; i < M; i++) { |
230 | 0 | for (int c = 0; c < ksub; c++) { |
231 | 0 | LUT2[ijc] = LUT[ijc] + bias_j; |
232 | 0 | ijc++; |
233 | 0 | } |
234 | 0 | } |
235 | 0 | } |
236 | 0 | std::vector<float> mins(M, HUGE_VAL), maxs(M, -HUGE_VAL); |
237 | 0 | size_t ij = 0; |
238 | 0 | for (int j = 0; j < nprobe; j++) { |
239 | 0 | for (int i = 0; i < M; i++) { |
240 | 0 | mins[i] = std::min(mins[i], tab_min(LUT2 + ij * ksub, ksub)); |
241 | 0 | maxs[i] = std::max(maxs[i], tab_max(LUT2 + ij * ksub, ksub)); |
242 | 0 | ij++; |
243 | 0 | } |
244 | 0 | } |
245 | |
|
246 | 0 | float max_span = -HUGE_VAL; |
247 | 0 | b = 0; |
248 | 0 | for (int i = 0; i < M; i++) { |
249 | 0 | float span = maxs[i] - mins[i]; |
250 | 0 | max_span = std::max(max_span, span); |
251 | 0 | b += mins[i]; |
252 | 0 | } |
253 | 0 | a = 255 / max_span; |
254 | 0 | ij = 0; |
255 | 0 | size_t ij_2 = 0; |
256 | 0 | for (int j = 0; j < nprobe; j++) { |
257 | 0 | for (int i = 0; i < M; i++) { |
258 | 0 | round_tab( |
259 | 0 | LUT2 + ij * ksub, ksub, a, mins[i], LUTq + ij_2 * ksub); |
260 | 0 | ij++; |
261 | 0 | ij_2++; |
262 | 0 | } |
263 | 0 | memset(LUTq + ij_2 * ksub, 0, ksub * (M2 - M)); |
264 | 0 | ij_2 += M2 - M; |
265 | 0 | } |
266 | 0 | } |
267 | 0 | if (a_out) |
268 | 0 | *a_out = a; |
269 | 0 | if (b_out) |
270 | 0 | *b_out = b; |
271 | 0 | } |
272 | | |
273 | | void aq_quantize_LUT_and_bias( |
274 | | size_t nprobe, |
275 | | size_t M, |
276 | | size_t ksub, |
277 | | const float* LUT, |
278 | | const float* bias, |
279 | | size_t M_norm, |
280 | | int norm_scale, |
281 | | uint8_t* LUTq, |
282 | | size_t M2, |
283 | | uint16_t* biasq, |
284 | | float* a_out, |
285 | 0 | float* b_out) { |
286 | 0 | float a, b; |
287 | 0 | std::vector<float> mins(M); |
288 | 0 | float max_span_LUT = -HUGE_VAL, max_span_dis; |
289 | 0 | float bias_min = tab_min(bias, nprobe); |
290 | 0 | float bias_max = tab_max(bias, nprobe); |
291 | 0 | max_span_dis = bias_max - bias_min; |
292 | 0 | b = 0; |
293 | 0 | for (int i = 0; i < M; i++) { |
294 | 0 | mins[i] = tab_min(LUT + i * ksub, ksub); |
295 | 0 | float span = tab_max(LUT + i * ksub, ksub) - mins[i]; |
296 | 0 | max_span_LUT = std::max(max_span_LUT, span); |
297 | 0 | max_span_dis += (i >= M - M_norm ? span * norm_scale : span); |
298 | 0 | b += mins[i]; |
299 | 0 | } |
300 | 0 | a = std::min(255 / max_span_LUT, 65535 / max_span_dis); |
301 | 0 | b += bias_min; |
302 | |
|
303 | 0 | for (int i = 0; i < M; i++) { |
304 | 0 | round_tab(LUT + i * ksub, ksub, a, mins[i], LUTq + i * ksub); |
305 | 0 | } |
306 | 0 | memset(LUTq + M * ksub, 0, ksub * (M2 - M)); |
307 | 0 | round_tab(bias, nprobe, a, bias_min, biasq); |
308 | |
|
309 | 0 | *a_out = a; |
310 | 0 | *b_out = b; |
311 | 0 | } |
312 | | |
313 | | float aq_estimate_norm_scale( |
314 | | size_t M, |
315 | | size_t ksub, |
316 | | size_t M_norm, |
317 | 0 | const float* LUT) { |
318 | 0 | float max_span_LUT = -HUGE_VAL; |
319 | 0 | for (int i = 0; i < M - M_norm; i++) { |
320 | 0 | float min = tab_min(LUT + i * ksub, ksub); |
321 | 0 | float span = tab_max(LUT + i * ksub, ksub) - min; |
322 | 0 | max_span_LUT = std::max(max_span_LUT, span); |
323 | 0 | } |
324 | |
|
325 | 0 | float max_span_LUT_norm = -HUGE_VAL; |
326 | 0 | for (int i = M - M_norm; i < M; i++) { |
327 | 0 | float min = tab_min(LUT + i * ksub, ksub); |
328 | 0 | float span = tab_max(LUT + i * ksub, ksub) - min; |
329 | 0 | max_span_LUT_norm = std::max(max_span_LUT_norm, span); |
330 | 0 | } |
331 | |
|
332 | 0 | return max_span_LUT_norm / max_span_LUT; |
333 | 0 | } |
334 | | |
335 | | } // namespace quantize_lut |
336 | | |
337 | | } // namespace faiss |