/root/doris/contrib/faiss/faiss/IndexBinaryFromFloat.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | // -*- c++ -*- |
9 | | |
10 | | #include <faiss/IndexBinaryFromFloat.h> |
11 | | |
12 | | #include <faiss/impl/FaissAssert.h> |
13 | | #include <faiss/utils/utils.h> |
14 | | #include <algorithm> |
15 | | #include <memory> |
16 | | |
17 | | namespace faiss { |
18 | | |
19 | 0 | IndexBinaryFromFloat::IndexBinaryFromFloat() = default; |
20 | | |
21 | | IndexBinaryFromFloat::IndexBinaryFromFloat(Index* index) |
22 | 0 | : IndexBinary(index->d), index(index), own_fields(false) { |
23 | 0 | is_trained = index->is_trained; |
24 | 0 | ntotal = index->ntotal; |
25 | 0 | } |
26 | | |
27 | 0 | IndexBinaryFromFloat::~IndexBinaryFromFloat() { |
28 | 0 | if (own_fields) { |
29 | 0 | delete index; |
30 | 0 | } |
31 | 0 | } |
32 | | |
33 | 0 | void IndexBinaryFromFloat::add(idx_t n, const uint8_t* x) { |
34 | 0 | constexpr idx_t bs = 32768; |
35 | 0 | std::unique_ptr<float[]> xf(new float[bs * d]); |
36 | |
|
37 | 0 | for (idx_t b = 0; b < n; b += bs) { |
38 | 0 | idx_t bn = std::min(bs, n - b); |
39 | 0 | binary_to_real(bn * d, x + b * code_size, xf.get()); |
40 | |
|
41 | 0 | index->add(bn, xf.get()); |
42 | 0 | } |
43 | 0 | ntotal = index->ntotal; |
44 | 0 | } |
45 | | |
46 | 0 | void IndexBinaryFromFloat::reset() { |
47 | 0 | index->reset(); |
48 | 0 | ntotal = index->ntotal; |
49 | 0 | } |
50 | | |
51 | | void IndexBinaryFromFloat::search( |
52 | | idx_t n, |
53 | | const uint8_t* x, |
54 | | idx_t k, |
55 | | int32_t* distances, |
56 | | idx_t* labels, |
57 | 0 | const SearchParameters* params) const { |
58 | 0 | FAISS_THROW_IF_NOT_MSG( |
59 | 0 | !params, "search params not supported for this index"); |
60 | 0 | FAISS_THROW_IF_NOT(k > 0); |
61 | | |
62 | 0 | constexpr idx_t bs = 32768; |
63 | 0 | std::unique_ptr<float[]> xf(new float[bs * d]); |
64 | 0 | std::unique_ptr<float[]> df(new float[bs * k]); |
65 | |
|
66 | 0 | for (idx_t b = 0; b < n; b += bs) { |
67 | 0 | idx_t bn = std::min(bs, n - b); |
68 | 0 | binary_to_real(bn * d, x + b * code_size, xf.get()); |
69 | |
|
70 | 0 | index->search(bn, xf.get(), k, df.get(), labels + b * k); |
71 | 0 | for (int i = 0; i < bn * k; ++i) { |
72 | 0 | distances[b * k + i] = int32_t(std::round(df[i] / 4.0)); |
73 | 0 | } |
74 | 0 | } |
75 | 0 | } |
76 | | |
77 | 0 | void IndexBinaryFromFloat::train(idx_t n, const uint8_t* x) { |
78 | 0 | std::unique_ptr<float[]> xf(new float[n * d]); |
79 | 0 | binary_to_real(n * d, x, xf.get()); |
80 | |
|
81 | 0 | index->train(n, xf.get()); |
82 | 0 | is_trained = true; |
83 | 0 | ntotal = index->ntotal; |
84 | 0 | } |
85 | | |
86 | | } // namespace faiss |