/root/doris/contrib/faiss/faiss/impl/AuxIndexStructures.h
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | // Auxiliary index structures, that are used in indexes but that can |
9 | | // be forward-declared |
10 | | |
11 | | #ifndef FAISS_AUX_INDEX_STRUCTURES_H |
12 | | #define FAISS_AUX_INDEX_STRUCTURES_H |
13 | | |
14 | | #include <stdint.h> |
15 | | |
16 | | #include <cstring> |
17 | | #include <memory> |
18 | | #include <mutex> |
19 | | #include <vector> |
20 | | |
21 | | #include <faiss/MetricType.h> |
22 | | #include <faiss/impl/platform_macros.h> |
23 | | |
24 | | namespace faiss { |
25 | | |
26 | | /** Result of a range search over nq queries. The objective is to have a |
27 | | * simple result structure while minimizing the number of memory copies in |
28 | | * the result. The virtual method do_allocation can be overridden to allocate |
29 | | * the result tables in the matrix type of a scripting language like Lua or Python. */ |
30 | | struct RangeSearchResult { |
31 | | size_t nq; ///< number of queries |
32 | | size_t* lims; ///< size (nq + 1); delimits each query's slice of labels/distances |
33 | | |
34 | | idx_t* labels; ///< result for query i is labels[lims[i]:lims[i+1]] |
35 | | float* distances; ///< corresponding distances (not sorted) |
36 | | |
37 | | size_t buffer_size; ///< size of the result buffers used |
38 | | |
39 | | /// lims must be allocated on input to range_search (alloc_lims presumably makes the constructor allocate it -- confirm in the .cpp). |
40 | | explicit RangeSearchResult(size_t nq, bool alloc_lims = true); |
41 | | |
42 | | /// called when lims contains the number of result entries |
43 | | /// for each query |
44 | | virtual void do_allocation(); |
45 | | |
46 | | virtual ~RangeSearchResult(); |
47 | | }; |
48 | | |
49 | | /**************************************************************** |
50 | | * Result structures for range search. |
51 | | * |
52 | | * The main constraint here is that we want to support parallel |
53 | | * queries from different threads in various ways: 1 thread per query, |
54 | | * several threads per query. We store the actual results in blocks of |
55 | | * fixed size rather than exponentially increasing memory. At the end, |
56 | | * we copy the block content to a linear result array. |
57 | | *****************************************************************/ |
58 | | |
59 | | /** List of temporary buffers used to store results before they are |
60 | | * copied to the RangeSearchResult object. */ |
61 | | struct BufferList { |
62 | | // size of each buffer, in number of entries |
63 | | size_t buffer_size; |
64 | | |
65 | | struct Buffer { /* one block of results: parallel id / distance arrays */ |
66 | | idx_t* ids; /**< ids of the stored results */ |
67 | | float* dis; /**< distances of the stored results */ |
68 | | }; |
69 | | |
70 | | std::vector<Buffer> buffers; /**< buffers created so far, in order */ |
71 | | size_t wp; ///< write pointer in the last buffer. |
72 | | |
73 | | explicit BufferList(size_t buffer_size); |
74 | | |
75 | | ~BufferList(); |
76 | | |
77 | | /// create a new buffer |
78 | | void append_buffer(); |
79 | | |
80 | | /// add one result, possibly appending a new buffer if needed |
81 | | void add(idx_t id, float dis); |
82 | | |
83 | | /// copy elements ofs:ofs+n-1 seen as linear data in the buffers to |
84 | | /// tables dest_ids, dest_dis |
85 | | void copy_range(size_t ofs, size_t n, idx_t* dest_ids, float* dest_dis); |
86 | | }; |
87 | | |
88 | | struct RangeSearchPartialResult; |
89 | | |
90 | | /// result structure for a single query |
91 | | struct RangeQueryResult { |
92 | | idx_t qno; ///< id of the query |
93 | | size_t nres; ///< number of results for this query |
94 | | RangeSearchPartialResult* pres; /**< partial result this query entry is attached to */ |
95 | | |
96 | | /// called by search function to report a new result; note the argument order (dis, id) is the reverse of BufferList::add(id, dis) |
97 | | void add(float dis, idx_t id); |
98 | | }; |
99 | | |
100 | | /// partial range search result: a BufferList whose entries are split per query |
101 | | struct RangeSearchPartialResult : BufferList { |
102 | | RangeSearchResult* res; /**< destination result object, see the constructor */ |
103 | | |
104 | | /// eventually the result will be stored in res_in |
105 | | explicit RangeSearchPartialResult(RangeSearchResult* res_in); |
106 | | |
107 | | /// query ids + number of results per query. |
108 | | std::vector<RangeQueryResult> queries; |
109 | | |
110 | | /// begin a new result for query qno |
111 | | RangeQueryResult& new_result(idx_t qno); |
112 | | |
113 | | /***************************************** |
114 | | * functions used at the end of the search to merge the result |
115 | | * lists */ |
116 | | void finalize(); |
117 | | |
118 | | /// called by range_search before do_allocation |
119 | | void set_lims(); |
120 | | |
121 | | /// called by range_search after do_allocation |
122 | | void copy_result(bool incremental = false); |
123 | | |
124 | | /// merge a set of partial results into one RangeSearchResult; |
125 | | /// on output the partial results are empty! (do_delete presumably also deletes them -- confirm in the .cpp) |
126 | | static void merge( |
127 | | std::vector<RangeSearchPartialResult*>& partial_results, |
128 | | bool do_delete = true); |
129 | | }; |
130 | | |
131 | | /*********************************************************** |
132 | | * Interrupt callback |
133 | | ***********************************************************/ |
134 | | |
135 | | struct FAISS_API InterruptCallback { /* abstract interface: subclasses decide when to interrupt */ |
136 | | virtual bool want_interrupt() = 0; /**< presumably returns true to request an interrupt (see check() below) */ |
137 | 0 | virtual ~InterruptCallback() {} |
138 | | |
139 | | // lock that protects concurrent calls to is_interrupted |
140 | | static std::mutex lock; |
141 | | |
142 | | static std::unique_ptr<InterruptCallback> instance; /**< presumably the callback consulted by check() / is_interrupted() -- confirm in the .cpp */ |
143 | | |
144 | | static void clear_instance(); |
145 | | |
146 | | /** check if: |
147 | | * - an interrupt callback is set |
148 | | * - the callback returns true |
149 | | * if this is the case, then throw an exception. Should not be called |
150 | | * from multiple threads. |
151 | | */ |
152 | | static void check(); |
153 | | |
154 | | /// same as check() but returns true if interrupted instead of |
155 | | /// throwing. Can be called from multiple threads. |
156 | | static bool is_interrupted(); |
157 | | |
158 | | /** assuming each iteration takes a certain number of flops, what |
159 | | * is a reasonable interval (in iterations, presumably) to check for interrupts? |
160 | | */ |
161 | | static size_t get_period_hint(size_t flops); |
162 | | }; |
163 | | |
164 | | struct TimeoutCallback : InterruptCallback { /* InterruptCallback with a timeout; only declared here, the implementation is not in this header */ |
165 | | std::chrono::time_point<std::chrono::steady_clock> start; /**< NOTE(review): uses std::chrono, but this header does not include <chrono> directly */ |
166 | | double timeout; /**< presumably in seconds, like the set_timeout() argument -- confirm */ |
167 | | bool want_interrupt() override; |
168 | | void set_timeout(double timeout_in_seconds); |
169 | | static void reset(double timeout_in_seconds); |
170 | | }; |
171 | | |
172 | | /// set implementation optimized for fast access: flag #no is true iff visited[no] == visno. |
173 | | struct VisitedTable { |
174 | | std::vector<uint8_t> visited; /**< per-entry stamp, compared against visno */ |
175 | | uint8_t visno; /**< current stamp value; starts at 1 */ |
176 | | |
177 | 16.8k | explicit VisitedTable(int size) : visited(size), visno(1) {} |
178 | | |
179 | | /// set flag #no to true (no is not bounds-checked) |
180 | 5.96M | void set(int no) { |
181 | 5.96M | visited[no] = visno; |
182 | 5.96M | } |
183 | | |
184 | | /// get flag #no (no is not bounds-checked) |
185 | 26.1M | bool get(int no) const { |
186 | 26.1M | return visited[no] == visno; |
187 | 26.1M | } |
188 | | |
189 | | /// reset all flags to false by bumping visno; the table itself is only zeroed when visno reaches 250 |
190 | 25.0k | void advance() { |
191 | 25.0k | visno++; |
192 | 25.0k | if (visno == 250) { |
193 | | // 250 rather than 255 because sometimes we use visno and visno+1 |
194 | 0 | memset(visited.data(), 0, sizeof(visited[0]) * visited.size()); |
195 | 0 | visno = 1; |
196 | 0 | } |
197 | 25.0k | } |
198 | | }; |
199 | | |
200 | | } // namespace faiss |
201 | | |
202 | | #endif |