contrib/faiss/faiss/utils/WorkerThread.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #include <faiss/impl/FaissAssert.h> |
9 | | #include <faiss/utils/WorkerThread.h> |
10 | | #include <exception> |
11 | | |
12 | | namespace faiss { |
13 | | |
14 | | namespace { |
15 | | |
/// Invokes `fn` and reports the outcome through `promise`.
///
/// On normal return the promise is fulfilled with `true`. If `fn` throws,
/// the in-flight exception is stored in the promise instead, so it is
/// rethrown to whoever calls `get()` on the matching future rather than
/// escaping from this function.
void runCallback(std::function<void()>& fn, std::promise<bool>& promise) {
    std::exception_ptr failure;

    try {
        fn();
    } catch (...) {
        failure = std::current_exception();
    }

    // Settle the promise exactly once, after the callback has finished
    if (failure) {
        promise.set_exception(failure);
    } else {
        promise.set_value(true);
    }
}
25 | | |
26 | | } // namespace |
27 | | |
28 | 0 | WorkerThread::WorkerThread() : wantStop_(false) { |
29 | 0 | startThread(); |
30 | | |
31 | | // Make sure that the thread has started before continuing |
32 | 0 | add([]() {}).get(); |
33 | 0 | } |
34 | | |
35 | 0 | WorkerThread::~WorkerThread() { |
36 | 0 | stop(); |
37 | 0 | waitForThreadExit(); |
38 | 0 | } |
39 | | |
40 | 0 | void WorkerThread::startThread() { |
41 | 0 | thread_ = std::thread([this]() { threadMain(); }); |
42 | 0 | } |
43 | | |
44 | 0 | void WorkerThread::stop() { |
45 | 0 | std::lock_guard<std::mutex> guard(mutex_); |
46 | |
|
47 | 0 | wantStop_ = true; |
48 | 0 | monitor_.notify_one(); |
49 | 0 | } |
50 | | |
51 | 0 | std::future<bool> WorkerThread::add(std::function<void()> f) { |
52 | 0 | std::lock_guard<std::mutex> guard(mutex_); |
53 | |
|
54 | 0 | if (wantStop_) { |
55 | | // The timer thread has been stopped, or we want to stop; we can't |
56 | | // schedule anything else |
57 | 0 | std::promise<bool> p; |
58 | 0 | auto fut = p.get_future(); |
59 | | |
60 | | // did not execute |
61 | 0 | p.set_value(false); |
62 | 0 | return fut; |
63 | 0 | } |
64 | | |
65 | 0 | auto pr = std::promise<bool>(); |
66 | 0 | auto fut = pr.get_future(); |
67 | |
|
68 | 0 | queue_.emplace_back(std::make_pair(std::move(f), std::move(pr))); |
69 | | |
70 | | // Wake up our thread |
71 | 0 | monitor_.notify_one(); |
72 | 0 | return fut; |
73 | 0 | } |
74 | | |
75 | 0 | void WorkerThread::threadMain() { |
76 | 0 | threadLoop(); |
77 | | |
78 | | // Call all pending tasks |
79 | 0 | FAISS_ASSERT(wantStop_); |
80 | | |
81 | | // flush all pending operations |
82 | 0 | for (auto& f : queue_) { |
83 | 0 | runCallback(f.first, f.second); |
84 | 0 | } |
85 | 0 | } |
86 | | |
87 | 0 | void WorkerThread::threadLoop() { |
88 | 0 | while (true) { |
89 | 0 | std::pair<std::function<void()>, std::promise<bool>> data; |
90 | |
|
91 | 0 | { |
92 | 0 | std::unique_lock<std::mutex> lock(mutex_); |
93 | |
|
94 | 0 | while (!wantStop_ && queue_.empty()) { |
95 | 0 | monitor_.wait(lock); |
96 | 0 | } |
97 | |
|
98 | 0 | if (wantStop_) { |
99 | 0 | return; |
100 | 0 | } |
101 | | |
102 | 0 | data = std::move(queue_.front()); |
103 | 0 | queue_.pop_front(); |
104 | 0 | } |
105 | | |
106 | 0 | runCallback(data.first, data.second); |
107 | 0 | } |
108 | 0 | } |
109 | | |
110 | 0 | void WorkerThread::waitForThreadExit() { |
111 | 0 | try { |
112 | 0 | thread_.join(); |
113 | 0 | } catch (...) { |
114 | 0 | } |
115 | 0 | } |
116 | | |
117 | | } // namespace faiss |