/root/doris/contrib/faiss/faiss/Clustering.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | // -*- c++ -*- |
9 | | |
10 | | #include <faiss/Clustering.h> |
11 | | #include <faiss/VectorTransform.h> |
12 | | #include <faiss/impl/AuxIndexStructures.h> |
13 | | |
14 | | #include <chrono> |
15 | | #include <cinttypes> |
16 | | #include <cmath> |
17 | | #include <cstdio> |
18 | | #include <cstring> |
19 | | |
20 | | #include <omp.h> |
21 | | |
22 | | #include <faiss/IndexFlat.h> |
23 | | #include <faiss/impl/FaissAssert.h> |
24 | | #include <faiss/impl/kmeans1d.h> |
25 | | #include <faiss/utils/distances.h> |
26 | | #include <faiss/utils/random.h> |
27 | | #include <faiss/utils/utils.h> |
28 | | |
29 | | namespace faiss { |
30 | | |
31 | 0 | Clustering::Clustering(int d, int k) : d(d), k(k) {} |
32 | | |
33 | | Clustering::Clustering(int d, int k, const ClusteringParameters& cp) |
34 | 0 | : ClusteringParameters(cp), d(d), k(k) {} |
35 | | |
36 | 0 | void Clustering::post_process_centroids() { |
37 | 0 | if (spherical) { |
38 | 0 | fvec_renorm_L2(d, k, centroids.data()); |
39 | 0 | } |
40 | |
|
41 | 0 | if (int_centroids) { |
42 | 0 | for (size_t i = 0; i < centroids.size(); i++) |
43 | 0 | centroids[i] = roundf(centroids[i]); |
44 | 0 | } |
45 | 0 | } |
46 | | |
47 | | void Clustering::train( |
48 | | idx_t nx, |
49 | | const float* x_in, |
50 | | Index& index, |
51 | 0 | const float* weights) { |
52 | 0 | train_encoded( |
53 | 0 | nx, |
54 | 0 | reinterpret_cast<const uint8_t*>(x_in), |
55 | 0 | nullptr, |
56 | 0 | index, |
57 | 0 | weights); |
58 | 0 | } |
59 | | |
60 | | namespace { |
61 | | |
62 | 0 | uint64_t get_actual_rng_seed(const int seed) { |
63 | 0 | return (seed >= 0) |
64 | 0 | ? seed |
65 | 0 | : static_cast<uint64_t>(std::chrono::high_resolution_clock::now() |
66 | 0 | .time_since_epoch() |
67 | 0 | .count()); |
68 | 0 | } |
69 | | |
70 | | idx_t subsample_training_set( |
71 | | const Clustering& clus, |
72 | | idx_t nx, |
73 | | const uint8_t* x, |
74 | | size_t line_size, |
75 | | const float* weights, |
76 | | uint8_t** x_out, |
77 | 0 | float** weights_out) { |
78 | 0 | if (clus.verbose) { |
79 | 0 | printf("Sampling a subset of %zd / %" PRId64 " for training\n", |
80 | 0 | clus.k * clus.max_points_per_centroid, |
81 | 0 | nx); |
82 | 0 | } |
83 | |
|
84 | 0 | const uint64_t actual_seed = get_actual_rng_seed(clus.seed); |
85 | |
|
86 | 0 | std::vector<int> perm; |
87 | 0 | if (clus.use_faster_subsampling) { |
88 | | // use subsampling with splitmix64 rng |
89 | 0 | SplitMix64RandomGenerator rng(actual_seed); |
90 | |
|
91 | 0 | const idx_t new_nx = clus.k * clus.max_points_per_centroid; |
92 | 0 | perm.resize(new_nx); |
93 | 0 | for (idx_t i = 0; i < new_nx; i++) { |
94 | 0 | perm[i] = rng.rand_int(nx); |
95 | 0 | } |
96 | 0 | } else { |
97 | | // use subsampling with a default std rng |
98 | 0 | perm.resize(nx); |
99 | 0 | rand_perm(perm.data(), nx, actual_seed); |
100 | 0 | } |
101 | |
|
102 | 0 | nx = clus.k * clus.max_points_per_centroid; |
103 | 0 | uint8_t* x_new = new uint8_t[nx * line_size]; |
104 | 0 | *x_out = x_new; |
105 | | |
106 | | // might be worth omp-ing as well |
107 | 0 | for (idx_t i = 0; i < nx; i++) { |
108 | 0 | memcpy(x_new + i * line_size, x + perm[i] * line_size, line_size); |
109 | 0 | } |
110 | 0 | if (weights) { |
111 | 0 | float* weights_new = new float[nx]; |
112 | 0 | for (idx_t i = 0; i < nx; i++) { |
113 | 0 | weights_new[i] = weights[perm[i]]; |
114 | 0 | } |
115 | 0 | *weights_out = weights_new; |
116 | 0 | } else { |
117 | 0 | *weights_out = nullptr; |
118 | 0 | } |
119 | 0 | return nx; |
120 | 0 | } |
121 | | |
122 | | /** compute centroids as (weighted) sum of training points |
123 | | * |
124 | | * @param x training vectors, size n * code_size (from codec) |
125 | | * @param codec how to decode the vectors (if NULL then cast to float*) |
126 | | * @param weights per-training vector weight, size n (or NULL) |
127 | | * @param assign nearest centroid for each training vector, size n |
128 | | * @param k_frozen do not update the k_frozen first centroids |
129 | | * @param centroids centroid vectors (output only), size k * d |
130 | | * @param hassign histogram of assignments per centroid (size k), |
131 | | * should be 0 on input |
132 | | * |
133 | | */ |
134 | | |
135 | | void compute_centroids( |
136 | | size_t d, |
137 | | size_t k, |
138 | | size_t n, |
139 | | size_t k_frozen, |
140 | | const uint8_t* x, |
141 | | const Index* codec, |
142 | | const int64_t* assign, |
143 | | const float* weights, |
144 | | float* hassign, |
145 | 0 | float* centroids) { |
146 | 0 | k -= k_frozen; |
147 | 0 | centroids += k_frozen * d; |
148 | |
|
149 | 0 | memset(centroids, 0, sizeof(*centroids) * d * k); |
150 | |
|
151 | 0 | size_t line_size = codec ? codec->sa_code_size() : d * sizeof(float); |
152 | |
|
153 | 0 | #pragma omp parallel |
154 | 0 | { |
155 | 0 | int nt = omp_get_num_threads(); |
156 | 0 | int rank = omp_get_thread_num(); |
157 | | |
158 | | // this thread is taking care of centroids c0:c1 |
159 | 0 | size_t c0 = (k * rank) / nt; |
160 | 0 | size_t c1 = (k * (rank + 1)) / nt; |
161 | 0 | std::vector<float> decode_buffer(d); |
162 | |
|
163 | 0 | for (size_t i = 0; i < n; i++) { |
164 | 0 | int64_t ci = assign[i]; |
165 | 0 | assert(ci >= 0 && ci < k + k_frozen); |
166 | 0 | ci -= k_frozen; |
167 | 0 | if (ci >= c0 && ci < c1) { |
168 | 0 | float* c = centroids + ci * d; |
169 | 0 | const float* xi; |
170 | 0 | if (!codec) { |
171 | 0 | xi = reinterpret_cast<const float*>(x + i * line_size); |
172 | 0 | } else { |
173 | 0 | float* xif = decode_buffer.data(); |
174 | 0 | codec->sa_decode(1, x + i * line_size, xif); |
175 | 0 | xi = xif; |
176 | 0 | } |
177 | 0 | if (weights) { |
178 | 0 | float w = weights[i]; |
179 | 0 | hassign[ci] += w; |
180 | 0 | for (size_t j = 0; j < d; j++) { |
181 | 0 | c[j] += xi[j] * w; |
182 | 0 | } |
183 | 0 | } else { |
184 | 0 | hassign[ci] += 1.0; |
185 | 0 | for (size_t j = 0; j < d; j++) { |
186 | 0 | c[j] += xi[j]; |
187 | 0 | } |
188 | 0 | } |
189 | 0 | } |
190 | 0 | } |
191 | 0 | } |
192 | |
|
193 | 0 | #pragma omp parallel for |
194 | 0 | for (idx_t ci = 0; ci < k; ci++) { |
195 | 0 | if (hassign[ci] == 0) { |
196 | 0 | continue; |
197 | 0 | } |
198 | 0 | float norm = 1 / hassign[ci]; |
199 | 0 | float* c = centroids + ci * d; |
200 | 0 | for (size_t j = 0; j < d; j++) { |
201 | 0 | c[j] *= norm; |
202 | 0 | } |
203 | 0 | } |
204 | 0 | } |
205 | | |
206 | | // a bit above machine epsilon for float16 |
207 | 0 | #define EPS (1 / 1024.) |
208 | | |
209 | | /** Handle empty clusters by splitting larger ones. |
210 | | * |
211 | | * It works by slightly changing the centroids to make 2 clusters from |
212 | | * a single one. Takes the same arguments as compute_centroids. |
213 | | * |
214 | | * @return nb of spliting operations (larger is worse) |
215 | | */ |
216 | | int split_clusters( |
217 | | size_t d, |
218 | | size_t k, |
219 | | size_t n, |
220 | | size_t k_frozen, |
221 | | float* hassign, |
222 | 0 | float* centroids) { |
223 | 0 | k -= k_frozen; |
224 | 0 | centroids += k_frozen * d; |
225 | | |
226 | | /* Take care of void clusters */ |
227 | 0 | size_t nsplit = 0; |
228 | 0 | RandomGenerator rng(1234); |
229 | 0 | for (size_t ci = 0; ci < k; ci++) { |
230 | 0 | if (hassign[ci] == 0) { /* need to redefine a centroid */ |
231 | 0 | size_t cj; |
232 | 0 | for (cj = 0; true; cj = (cj + 1) % k) { |
233 | | /* probability to pick this cluster for split */ |
234 | 0 | float p = (hassign[cj] - 1.0) / (float)(n - k); |
235 | 0 | float r = rng.rand_float(); |
236 | 0 | if (r < p) { |
237 | 0 | break; /* found our cluster to be split */ |
238 | 0 | } |
239 | 0 | } |
240 | 0 | memcpy(centroids + ci * d, |
241 | 0 | centroids + cj * d, |
242 | 0 | sizeof(*centroids) * d); |
243 | | |
244 | | /* small symmetric pertubation */ |
245 | 0 | for (size_t j = 0; j < d; j++) { |
246 | 0 | if (j % 2 == 0) { |
247 | 0 | centroids[ci * d + j] *= 1 + EPS; |
248 | 0 | centroids[cj * d + j] *= 1 - EPS; |
249 | 0 | } else { |
250 | 0 | centroids[ci * d + j] *= 1 - EPS; |
251 | 0 | centroids[cj * d + j] *= 1 + EPS; |
252 | 0 | } |
253 | 0 | } |
254 | | |
255 | | /* assume even split of the cluster */ |
256 | 0 | hassign[ci] = hassign[cj] / 2; |
257 | 0 | hassign[cj] -= hassign[ci]; |
258 | 0 | nsplit++; |
259 | 0 | } |
260 | 0 | } |
261 | |
|
262 | 0 | return nsplit; |
263 | 0 | } |
264 | | |
265 | | } // namespace |
266 | | |
267 | | void Clustering::train_encoded( |
268 | | idx_t nx, |
269 | | const uint8_t* x_in, |
270 | | const Index* codec, |
271 | | Index& index, |
272 | 0 | const float* weights) { |
273 | 0 | FAISS_THROW_IF_NOT_FMT( |
274 | 0 | nx >= k, |
275 | 0 | "Number of training points (%" PRId64 |
276 | 0 | ") should be at least " |
277 | 0 | "as large as number of clusters (%zd)", |
278 | 0 | nx, |
279 | 0 | k); |
280 | | |
281 | 0 | FAISS_THROW_IF_NOT_FMT( |
282 | 0 | (!codec || codec->d == d), |
283 | 0 | "Codec dimension %d not the same as data dimension %d", |
284 | 0 | int(codec->d), |
285 | 0 | int(d)); |
286 | | |
287 | 0 | FAISS_THROW_IF_NOT_FMT( |
288 | 0 | index.d == d, |
289 | 0 | "Index dimension %d not the same as data dimension %d", |
290 | 0 | int(index.d), |
291 | 0 | int(d)); |
292 | | |
293 | 0 | double t0 = getmillisecs(); |
294 | |
|
295 | 0 | if (!codec && check_input_data_for_NaNs) { |
296 | | // Check for NaNs in input data. Normally it is the user's |
297 | | // responsibility, but it may spare us some hard-to-debug |
298 | | // reports. |
299 | 0 | const float* x = reinterpret_cast<const float*>(x_in); |
300 | 0 | for (size_t i = 0; i < nx * d; i++) { |
301 | 0 | FAISS_THROW_IF_NOT_MSG( |
302 | 0 | std::isfinite(x[i]), "input contains NaN's or Inf's"); |
303 | 0 | } |
304 | 0 | } |
305 | | |
306 | 0 | const uint8_t* x = x_in; |
307 | 0 | std::unique_ptr<uint8_t[]> del1; |
308 | 0 | std::unique_ptr<float[]> del3; |
309 | 0 | size_t line_size = codec ? codec->sa_code_size() : sizeof(float) * d; |
310 | |
|
311 | 0 | if (nx > k * max_points_per_centroid) { |
312 | 0 | uint8_t* x_new; |
313 | 0 | float* weights_new; |
314 | 0 | nx = subsample_training_set( |
315 | 0 | *this, nx, x, line_size, weights, &x_new, &weights_new); |
316 | 0 | del1.reset(x_new); |
317 | 0 | x = x_new; |
318 | 0 | del3.reset(weights_new); |
319 | 0 | weights = weights_new; |
320 | 0 | } else if (nx < k * min_points_per_centroid) { |
321 | 0 | fprintf(stderr, |
322 | 0 | "WARNING clustering %" PRId64 |
323 | 0 | " points to %zd centroids: " |
324 | 0 | "please provide at least %" PRId64 " training points\n", |
325 | 0 | nx, |
326 | 0 | k, |
327 | 0 | idx_t(k) * min_points_per_centroid); |
328 | 0 | } |
329 | |
|
330 | 0 | if (nx == k) { |
331 | | // this is a corner case, just copy training set to clusters |
332 | 0 | if (verbose) { |
333 | 0 | printf("Number of training points (%" PRId64 |
334 | 0 | ") same as number of " |
335 | 0 | "clusters, just copying\n", |
336 | 0 | nx); |
337 | 0 | } |
338 | 0 | centroids.resize(d * k); |
339 | 0 | if (!codec) { |
340 | 0 | memcpy(centroids.data(), x_in, sizeof(float) * d * k); |
341 | 0 | } else { |
342 | 0 | codec->sa_decode(nx, x_in, centroids.data()); |
343 | 0 | } |
344 | | |
345 | | // one fake iteration... |
346 | 0 | ClusteringIterationStats stats = {0.0, 0.0, 0.0, 1.0, 0}; |
347 | 0 | iteration_stats.push_back(stats); |
348 | |
|
349 | 0 | index.reset(); |
350 | 0 | index.add(k, centroids.data()); |
351 | 0 | return; |
352 | 0 | } |
353 | | |
354 | 0 | if (verbose) { |
355 | 0 | printf("Clustering %" PRId64 |
356 | 0 | " points in %zdD to %zd clusters, " |
357 | 0 | "redo %d times, %d iterations\n", |
358 | 0 | nx, |
359 | 0 | d, |
360 | 0 | k, |
361 | 0 | nredo, |
362 | 0 | niter); |
363 | 0 | if (codec) { |
364 | 0 | printf("Input data encoded in %zd bytes per vector\n", |
365 | 0 | codec->sa_code_size()); |
366 | 0 | } |
367 | 0 | } |
368 | |
|
369 | 0 | std::unique_ptr<idx_t[]> assign(new idx_t[nx]); |
370 | 0 | std::unique_ptr<float[]> dis(new float[nx]); |
371 | | |
372 | | // remember best iteration for redo |
373 | 0 | bool lower_is_better = !is_similarity_metric(index.metric_type); |
374 | 0 | float best_obj = lower_is_better ? HUGE_VALF : -HUGE_VALF; |
375 | 0 | std::vector<ClusteringIterationStats> best_iteration_stats; |
376 | 0 | std::vector<float> best_centroids; |
377 | | |
378 | | // support input centroids |
379 | |
|
380 | 0 | FAISS_THROW_IF_NOT_MSG( |
381 | 0 | centroids.size() % d == 0, |
382 | 0 | "size of provided input centroids not a multiple of dimension"); |
383 | | |
384 | 0 | size_t n_input_centroids = centroids.size() / d; |
385 | |
|
386 | 0 | if (verbose && n_input_centroids > 0) { |
387 | 0 | printf(" Using %zd centroids provided as input (%sfrozen)\n", |
388 | 0 | n_input_centroids, |
389 | 0 | frozen_centroids ? "" : "not "); |
390 | 0 | } |
391 | |
|
392 | 0 | double t_search_tot = 0; |
393 | 0 | if (verbose) { |
394 | 0 | printf(" Preprocessing in %.2f s\n", (getmillisecs() - t0) / 1000.); |
395 | 0 | } |
396 | 0 | t0 = getmillisecs(); |
397 | | |
398 | | // initialize seed |
399 | 0 | const uint64_t actual_seed = get_actual_rng_seed(seed); |
400 | | |
401 | | // temporary buffer to decode vectors during the optimization |
402 | 0 | std::vector<float> decode_buffer(codec ? d * decode_block_size : 0); |
403 | |
|
404 | 0 | for (int redo = 0; redo < nredo; redo++) { |
405 | 0 | if (verbose && nredo > 1) { |
406 | 0 | printf("Outer iteration %d / %d\n", redo, nredo); |
407 | 0 | } |
408 | | |
409 | | // initialize (remaining) centroids with random points from the dataset |
410 | 0 | centroids.resize(d * k); |
411 | 0 | std::vector<int> perm(nx); |
412 | |
|
413 | 0 | rand_perm(perm.data(), nx, actual_seed + 1 + redo * 15486557L); |
414 | |
|
415 | 0 | if (!codec) { |
416 | 0 | for (int i = n_input_centroids; i < k; i++) { |
417 | 0 | memcpy(¢roids[i * d], x + perm[i] * line_size, line_size); |
418 | 0 | } |
419 | 0 | } else { |
420 | 0 | for (int i = n_input_centroids; i < k; i++) { |
421 | 0 | codec->sa_decode(1, x + perm[i] * line_size, ¢roids[i * d]); |
422 | 0 | } |
423 | 0 | } |
424 | |
|
425 | 0 | post_process_centroids(); |
426 | | |
427 | | // prepare the index |
428 | |
|
429 | 0 | if (index.ntotal != 0) { |
430 | 0 | index.reset(); |
431 | 0 | } |
432 | |
|
433 | 0 | if (!index.is_trained) { |
434 | 0 | index.train(k, centroids.data()); |
435 | 0 | } |
436 | |
|
437 | 0 | index.add(k, centroids.data()); |
438 | | |
439 | | // k-means iterations |
440 | |
|
441 | 0 | float obj = 0; |
442 | 0 | for (int i = 0; i < niter; i++) { |
443 | 0 | double t0s = getmillisecs(); |
444 | |
|
445 | 0 | if (!codec) { |
446 | 0 | index.search( |
447 | 0 | nx, |
448 | 0 | reinterpret_cast<const float*>(x), |
449 | 0 | 1, |
450 | 0 | dis.get(), |
451 | 0 | assign.get()); |
452 | 0 | } else { |
453 | | // search by blocks of decode_block_size vectors |
454 | 0 | size_t code_size = codec->sa_code_size(); |
455 | 0 | for (size_t i0 = 0; i0 < nx; i0 += decode_block_size) { |
456 | 0 | size_t i1 = i0 + decode_block_size; |
457 | 0 | if (i1 > nx) { |
458 | 0 | i1 = nx; |
459 | 0 | } |
460 | 0 | codec->sa_decode( |
461 | 0 | i1 - i0, x + code_size * i0, decode_buffer.data()); |
462 | 0 | index.search( |
463 | 0 | i1 - i0, |
464 | 0 | decode_buffer.data(), |
465 | 0 | 1, |
466 | 0 | dis.get() + i0, |
467 | 0 | assign.get() + i0); |
468 | 0 | } |
469 | 0 | } |
470 | |
|
471 | 0 | InterruptCallback::check(); |
472 | 0 | t_search_tot += getmillisecs() - t0s; |
473 | | |
474 | | // accumulate objective |
475 | 0 | obj = 0; |
476 | 0 | for (int j = 0; j < nx; j++) { |
477 | 0 | obj += dis[j]; |
478 | 0 | } |
479 | | |
480 | | // update the centroids |
481 | 0 | std::vector<float> hassign(k); |
482 | |
|
483 | 0 | size_t k_frozen = frozen_centroids ? n_input_centroids : 0; |
484 | 0 | compute_centroids( |
485 | 0 | d, |
486 | 0 | k, |
487 | 0 | nx, |
488 | 0 | k_frozen, |
489 | 0 | x, |
490 | 0 | codec, |
491 | 0 | assign.get(), |
492 | 0 | weights, |
493 | 0 | hassign.data(), |
494 | 0 | centroids.data()); |
495 | |
|
496 | 0 | int nsplit = split_clusters( |
497 | 0 | d, k, nx, k_frozen, hassign.data(), centroids.data()); |
498 | | |
499 | | // collect statistics |
500 | 0 | ClusteringIterationStats stats = { |
501 | 0 | obj, |
502 | 0 | (getmillisecs() - t0) / 1000.0, |
503 | 0 | t_search_tot / 1000, |
504 | 0 | imbalance_factor(nx, k, assign.get()), |
505 | 0 | nsplit}; |
506 | 0 | iteration_stats.push_back(stats); |
507 | |
|
508 | 0 | if (verbose) { |
509 | 0 | printf(" Iteration %d (%.2f s, search %.2f s): " |
510 | 0 | "objective=%g imbalance=%.3f nsplit=%d \r", |
511 | 0 | i, |
512 | 0 | stats.time, |
513 | 0 | stats.time_search, |
514 | 0 | stats.obj, |
515 | 0 | stats.imbalance_factor, |
516 | 0 | nsplit); |
517 | 0 | fflush(stdout); |
518 | 0 | } |
519 | |
|
520 | 0 | post_process_centroids(); |
521 | | |
522 | | // add centroids to index for the next iteration (or for output) |
523 | |
|
524 | 0 | index.reset(); |
525 | 0 | if (update_index) { |
526 | 0 | index.train(k, centroids.data()); |
527 | 0 | } |
528 | |
|
529 | 0 | index.add(k, centroids.data()); |
530 | 0 | InterruptCallback::check(); |
531 | 0 | } |
532 | |
|
533 | 0 | if (verbose) |
534 | 0 | printf("\n"); |
535 | 0 | if (nredo > 1) { |
536 | 0 | if ((lower_is_better && obj < best_obj) || |
537 | 0 | (!lower_is_better && obj > best_obj)) { |
538 | 0 | if (verbose) { |
539 | 0 | printf("Objective improved: keep new clusters\n"); |
540 | 0 | } |
541 | 0 | best_centroids = centroids; |
542 | 0 | best_iteration_stats = iteration_stats; |
543 | 0 | best_obj = obj; |
544 | 0 | } |
545 | 0 | index.reset(); |
546 | 0 | } |
547 | 0 | } |
548 | 0 | if (nredo > 1) { |
549 | 0 | centroids = best_centroids; |
550 | 0 | iteration_stats = best_iteration_stats; |
551 | 0 | index.reset(); |
552 | 0 | index.add(k, best_centroids.data()); |
553 | 0 | } |
554 | 0 | } |
555 | | |
556 | 0 | Clustering1D::Clustering1D(int k) : Clustering(1, k) {} |
557 | | |
558 | | Clustering1D::Clustering1D(int k, const ClusteringParameters& cp) |
559 | 0 | : Clustering(1, k, cp) {} |
560 | | |
561 | 0 | void Clustering1D::train_exact(idx_t n, const float* x) { |
562 | 0 | const float* xt = x; |
563 | |
|
564 | 0 | std::unique_ptr<uint8_t[]> del; |
565 | 0 | if (n > k * max_points_per_centroid) { |
566 | 0 | uint8_t* x_new; |
567 | 0 | float* weights_new; |
568 | 0 | n = subsample_training_set( |
569 | 0 | *this, |
570 | 0 | n, |
571 | 0 | (uint8_t*)x, |
572 | 0 | sizeof(float) * d, |
573 | 0 | nullptr, |
574 | 0 | &x_new, |
575 | 0 | &weights_new); |
576 | 0 | del.reset(x_new); |
577 | 0 | xt = (float*)x_new; |
578 | 0 | } |
579 | |
|
580 | 0 | centroids.resize(k); |
581 | 0 | double uf = kmeans1d(xt, n, k, centroids.data()); |
582 | |
|
583 | 0 | ClusteringIterationStats stats = {0.0, 0.0, 0.0, uf, 0}; |
584 | 0 | iteration_stats.push_back(stats); |
585 | 0 | } |
586 | | |
587 | | float kmeans_clustering( |
588 | | size_t d, |
589 | | size_t n, |
590 | | size_t k, |
591 | | const float* x, |
592 | 0 | float* centroids) { |
593 | 0 | Clustering clus(d, k); |
594 | 0 | clus.verbose = d * n * k > (size_t(1) << 30); |
595 | | // display logs if > 1Gflop per iteration |
596 | 0 | IndexFlatL2 index(d); |
597 | 0 | clus.train(n, x, index); |
598 | 0 | memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k); |
599 | 0 | return clus.iteration_stats.back().obj; |
600 | 0 | } |
601 | | |
602 | | /****************************************************************************** |
603 | | * ProgressiveDimClustering implementation |
604 | | ******************************************************************************/ |
605 | | |
606 | 0 | ProgressiveDimClusteringParameters::ProgressiveDimClusteringParameters() { |
607 | 0 | progressive_dim_steps = 10; |
608 | 0 | apply_pca = true; // seems a good idea to do this by default |
609 | 0 | niter = 10; // reduce nb of iterations per step |
610 | 0 | } |
611 | | |
612 | 0 | Index* ProgressiveDimIndexFactory::operator()(int dim) { |
613 | 0 | return new IndexFlatL2(dim); |
614 | 0 | } |
615 | | |
616 | 0 | ProgressiveDimClustering::ProgressiveDimClustering(int d, int k) : d(d), k(k) {} |
617 | | |
618 | | ProgressiveDimClustering::ProgressiveDimClustering( |
619 | | int d, |
620 | | int k, |
621 | | const ProgressiveDimClusteringParameters& cp) |
622 | 0 | : ProgressiveDimClusteringParameters(cp), d(d), k(k) {} |
623 | | |
624 | | namespace { |
625 | | |
626 | 0 | void copy_columns(idx_t n, idx_t d1, const float* src, idx_t d2, float* dest) { |
627 | 0 | idx_t d = std::min(d1, d2); |
628 | 0 | for (idx_t i = 0; i < n; i++) { |
629 | 0 | memcpy(dest, src, sizeof(float) * d); |
630 | 0 | src += d1; |
631 | 0 | dest += d2; |
632 | 0 | } |
633 | 0 | } |
634 | | |
635 | | } // namespace |
636 | | |
637 | | void ProgressiveDimClustering::train( |
638 | | idx_t n, |
639 | | const float* x, |
640 | 0 | ProgressiveDimIndexFactory& factory) { |
641 | 0 | int d_prev = 0; |
642 | |
|
643 | 0 | PCAMatrix pca(d, d); |
644 | |
|
645 | 0 | std::vector<float> xbuf; |
646 | 0 | if (apply_pca) { |
647 | 0 | if (verbose) { |
648 | 0 | printf("Training PCA transform\n"); |
649 | 0 | } |
650 | 0 | pca.train(n, x); |
651 | 0 | if (verbose) { |
652 | 0 | printf("Apply PCA\n"); |
653 | 0 | } |
654 | 0 | xbuf.resize(n * d); |
655 | 0 | pca.apply_noalloc(n, x, xbuf.data()); |
656 | 0 | x = xbuf.data(); |
657 | 0 | } |
658 | |
|
659 | 0 | for (int iter = 0; iter < progressive_dim_steps; iter++) { |
660 | 0 | int di = int(pow(d, (1. + iter) / progressive_dim_steps)); |
661 | 0 | if (verbose) { |
662 | 0 | printf("Progressive dim step %d: cluster in dimension %d\n", |
663 | 0 | iter, |
664 | 0 | di); |
665 | 0 | } |
666 | 0 | std::unique_ptr<Index> clustering_index(factory(di)); |
667 | |
|
668 | 0 | Clustering clus(di, k, *this); |
669 | 0 | if (d_prev > 0) { |
670 | | // copy warm-start centroids (padded with 0s) |
671 | 0 | clus.centroids.resize(k * di); |
672 | 0 | copy_columns( |
673 | 0 | k, d_prev, centroids.data(), di, clus.centroids.data()); |
674 | 0 | } |
675 | 0 | std::vector<float> xsub(n * di); |
676 | 0 | copy_columns(n, d, x, di, xsub.data()); |
677 | |
|
678 | 0 | clus.train(n, xsub.data(), *clustering_index.get()); |
679 | |
|
680 | 0 | centroids = clus.centroids; |
681 | 0 | iteration_stats.insert( |
682 | 0 | iteration_stats.end(), |
683 | 0 | clus.iteration_stats.begin(), |
684 | 0 | clus.iteration_stats.end()); |
685 | |
|
686 | 0 | d_prev = di; |
687 | 0 | } |
688 | |
|
689 | 0 | if (apply_pca) { |
690 | 0 | if (verbose) { |
691 | 0 | printf("Revert PCA transform on centroids\n"); |
692 | 0 | } |
693 | 0 | std::vector<float> cent_transformed(d * k); |
694 | 0 | pca.reverse_transform(k, centroids.data(), cent_transformed.data()); |
695 | 0 | cent_transformed.swap(centroids); |
696 | 0 | } |
697 | 0 | } |
698 | | |
699 | | } // namespace faiss |