/root/doris/contrib/faiss/faiss/IndexBinaryFromFloat.cpp
| Line | Count | Source | 
| 1 |  | /* | 
| 2 |  |  * Copyright (c) Meta Platforms, Inc. and affiliates. | 
| 3 |  |  * | 
| 4 |  |  * This source code is licensed under the MIT license found in the | 
| 5 |  |  * LICENSE file in the root directory of this source tree. | 
| 6 |  |  */ | 
| 7 |  |  | 
| 8 |  | // -*- c++ -*- | 
| 9 |  |  | 
| 10 |  | #include <faiss/IndexBinaryFromFloat.h> | 
| 11 |  |  | 
| 12 |  | #include <faiss/impl/FaissAssert.h> | 
| 13 |  | #include <faiss/utils/utils.h> | 
| 14 |  | #include <algorithm> | 
| 15 |  | #include <memory> | 
| 16 |  |  | 
| 17 |  | namespace faiss { | 
| 18 |  |  | 
| 19 | 0 | IndexBinaryFromFloat::IndexBinaryFromFloat() = default; | 
| 20 |  |  | 
| 21 |  | IndexBinaryFromFloat::IndexBinaryFromFloat(Index* index) | 
| 22 | 0 |         : IndexBinary(index->d), index(index), own_fields(false) { | 
| 23 | 0 |     is_trained = index->is_trained; | 
| 24 | 0 |     ntotal = index->ntotal; | 
| 25 | 0 | } | 
| 26 |  |  | 
| 27 | 0 | IndexBinaryFromFloat::~IndexBinaryFromFloat() { | 
| 28 | 0 |     if (own_fields) { | 
| 29 | 0 |         delete index; | 
| 30 | 0 |     } | 
| 31 | 0 | } | 
| 32 |  |  | 
| 33 | 0 | void IndexBinaryFromFloat::add(idx_t n, const uint8_t* x) { | 
| 34 | 0 |     constexpr idx_t bs = 32768; | 
| 35 | 0 |     std::unique_ptr<float[]> xf(new float[bs * d]); | 
| 36 |  | 
 | 
| 37 | 0 |     for (idx_t b = 0; b < n; b += bs) { | 
| 38 | 0 |         idx_t bn = std::min(bs, n - b); | 
| 39 | 0 |         binary_to_real(bn * d, x + b * code_size, xf.get()); | 
| 40 |  | 
 | 
| 41 | 0 |         index->add(bn, xf.get()); | 
| 42 | 0 |     } | 
| 43 | 0 |     ntotal = index->ntotal; | 
| 44 | 0 | } | 
| 45 |  |  | 
| 46 | 0 | void IndexBinaryFromFloat::reset() { | 
| 47 | 0 |     index->reset(); | 
| 48 | 0 |     ntotal = index->ntotal; | 
| 49 | 0 | } | 
| 50 |  |  | 
| 51 |  | void IndexBinaryFromFloat::search( | 
| 52 |  |         idx_t n, | 
| 53 |  |         const uint8_t* x, | 
| 54 |  |         idx_t k, | 
| 55 |  |         int32_t* distances, | 
| 56 |  |         idx_t* labels, | 
| 57 | 0 |         const SearchParameters* params) const { | 
| 58 | 0 |     FAISS_THROW_IF_NOT_MSG( | 
| 59 | 0 |             !params, "search params not supported for this index"); | 
| 60 | 0 |     FAISS_THROW_IF_NOT(k > 0); | 
| 61 |  |  | 
| 62 | 0 |     constexpr idx_t bs = 32768; | 
| 63 | 0 |     std::unique_ptr<float[]> xf(new float[bs * d]); | 
| 64 | 0 |     std::unique_ptr<float[]> df(new float[bs * k]); | 
| 65 |  | 
 | 
| 66 | 0 |     for (idx_t b = 0; b < n; b += bs) { | 
| 67 | 0 |         idx_t bn = std::min(bs, n - b); | 
| 68 | 0 |         binary_to_real(bn * d, x + b * code_size, xf.get()); | 
| 69 |  | 
 | 
| 70 | 0 |         index->search(bn, xf.get(), k, df.get(), labels + b * k); | 
| 71 | 0 |         for (int i = 0; i < bn * k; ++i) { | 
| 72 | 0 |             distances[b * k + i] = int32_t(std::round(df[i] / 4.0)); | 
| 73 | 0 |         } | 
| 74 | 0 |     } | 
| 75 | 0 | } | 
| 76 |  |  | 
| 77 | 0 | void IndexBinaryFromFloat::train(idx_t n, const uint8_t* x) { | 
| 78 | 0 |     std::unique_ptr<float[]> xf(new float[n * d]); | 
| 79 | 0 |     binary_to_real(n * d, x, xf.get()); | 
| 80 |  | 
 | 
| 81 | 0 |     index->train(n, xf.get()); | 
| 82 | 0 |     is_trained = true; | 
| 83 | 0 |     ntotal = index->ntotal; | 
| 84 | 0 | } | 
| 85 |  |  | 
| 86 |  | } // namespace faiss |