/root/doris/contrib/faiss/faiss/impl/DistanceComputer.h
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #pragma once |
9 | | |
10 | | #include <faiss/Index.h> |
11 | | |
12 | | namespace faiss { |
13 | | |
14 | | /*********************************************************** |
15 | | * The distance computer maintains a current query and computes |
16 | | * distances to elements in an index that supports random access. |
17 | | * |
18 | | * The DistanceComputer is not intended to be thread-safe (eg. because |
19 | | * it maintains counters) so the distance functions are not const, |
20 | | * instantiate one from each thread if needed. |
21 | | * |
22 | | * Note that the equivalent for IVF indexes is the InvertedListScanner, |
23 | | * that has additional methods to handle the inverted list context. |
24 | | ***********************************************************/ |
25 | | struct DistanceComputer { |
26 | | /// called before computing distances. Pointer x should remain valid |
27 | | /// while operator () is called |
28 | | virtual void set_query(const float* x) = 0; |
29 | | |
30 | | /// compute distance of vector i to current query |
31 | | virtual float operator()(idx_t i) = 0; |
32 | | |
33 | | /// compute distances of current query to 4 stored vectors. |
34 | | /// certain DistanceComputer implementations may benefit |
35 | | /// heavily from this. |
36 | | virtual void distances_batch_4( |
37 | | const idx_t idx0, |
38 | | const idx_t idx1, |
39 | | const idx_t idx2, |
40 | | const idx_t idx3, |
41 | | float& dis0, |
42 | | float& dis1, |
43 | | float& dis2, |
44 | 0 | float& dis3) { |
45 | | // compute first, assign next |
46 | 0 | const float d0 = this->operator()(idx0); |
47 | 0 | const float d1 = this->operator()(idx1); |
48 | 0 | const float d2 = this->operator()(idx2); |
49 | 0 | const float d3 = this->operator()(idx3); |
50 | 0 | dis0 = d0; |
51 | 0 | dis1 = d1; |
52 | 0 | dis2 = d2; |
53 | 0 | dis3 = d3; |
54 | 0 | } |
55 | | |
56 | | /// compute distance between two stored vectors |
57 | | virtual float symmetric_dis(idx_t i, idx_t j) = 0; |
58 | | |
59 | 21.8k | virtual ~DistanceComputer() {} |
60 | | }; |
61 | | |
62 | | /* Wrap the distance computer into one that negates the |
63 | | distances. This makes supporting INNER_PRODUCE search easier */ |
64 | | |
65 | | struct NegativeDistanceComputer : DistanceComputer { |
66 | | /// owned by this |
67 | | DistanceComputer* basedis; |
68 | | |
69 | | explicit NegativeDistanceComputer(DistanceComputer* basedis) |
70 | 1.23k | : basedis(basedis) {} |
71 | | |
72 | 6.47k | void set_query(const float* x) override { |
73 | 6.47k | basedis->set_query(x); |
74 | 6.47k | } |
75 | | |
76 | | /// compute distance of vector i to current query |
77 | 202k | float operator()(idx_t i) override { |
78 | 202k | return -(*basedis)(i); |
79 | 202k | } |
80 | | |
81 | | void distances_batch_4( |
82 | | const idx_t idx0, |
83 | | const idx_t idx1, |
84 | | const idx_t idx2, |
85 | | const idx_t idx3, |
86 | | float& dis0, |
87 | | float& dis1, |
88 | | float& dis2, |
89 | 270k | float& dis3) override { |
90 | 270k | basedis->distances_batch_4( |
91 | 270k | idx0, idx1, idx2, idx3, dis0, dis1, dis2, dis3); |
92 | 270k | dis0 = -dis0; |
93 | 270k | dis1 = -dis1; |
94 | 270k | dis2 = -dis2; |
95 | 270k | dis3 = -dis3; |
96 | 270k | } |
97 | | |
98 | | /// compute distance between two stored vectors |
99 | 1.14M | float symmetric_dis(idx_t i, idx_t j) override { |
100 | 1.14M | return -basedis->symmetric_dis(i, j); |
101 | 1.14M | } |
102 | | |
103 | 1.23k | virtual ~NegativeDistanceComputer() { |
104 | 1.23k | delete basedis; |
105 | 1.23k | } |
106 | | }; |
107 | | |
108 | | /************************************************************* |
109 | | * Specialized version of the DistanceComputer when we know that codes are |
110 | | * laid out in a flat index. |
111 | | */ |
112 | | struct FlatCodesDistanceComputer : DistanceComputer { |
113 | | const uint8_t* codes; |
114 | | size_t code_size; |
115 | | |
116 | | FlatCodesDistanceComputer(const uint8_t* codes, size_t code_size) |
117 | 20.6k | : codes(codes), code_size(code_size) {} |
118 | | |
119 | 0 | FlatCodesDistanceComputer() : codes(nullptr), code_size(0) {} |
120 | | |
121 | 1.21M | float operator()(idx_t i) override { |
122 | 1.21M | return distance_to_code(codes + i * code_size); |
123 | 1.21M | } |
124 | | |
125 | | /// compute distance of current query to an encoded vector |
126 | | virtual float distance_to_code(const uint8_t* code) = 0; |
127 | | |
128 | 20.6k | virtual ~FlatCodesDistanceComputer() {} |
129 | | }; |
130 | | |
131 | | } // namespace faiss |