1/*
2 * LegacyClonk
3 *
4 * Copyright (c) 2023, The LegacyClonk Team and contributors
5 *
6 * Distributed under the terms of the ISC license; see accompanying file
7 * "COPYING" for details.
8 *
9 * "Clonk" is a registered trademark of Matthes Bender, used with permission.
10 * See accompanying file "TRADEMARK" for details.
11 *
12 * To redistribute this file separately, substitute the full license texts
13 * for the above references.
14 */
15
16#include "C4ThreadPool.h"
17
18#ifdef _WIN32
19#include <format>
20#include <limits>
21
22#include <threadpoolapiset.h>
23#endif
24
25#ifdef _WIN32
26
// Initializes the wrapped TP_CALLBACK_ENVIRON; RAII counterpart to the
// TpDestroyCallbackEnviron call in the destructor.
C4ThreadPool::CallbackEnvironment::CallbackEnvironment() noexcept
{
	TpInitializeCallbackEnviron(&callbackEnvironment);
}
31
// Tears down the callback environment initialized in the constructor.
C4ThreadPool::CallbackEnvironment::~CallbackEnvironment() noexcept
{
	TpDestroyCallbackEnviron(&callbackEnvironment);
}
36
// Handle-trait hook: releases a PTP_POOL created by CreateThreadpool.
void C4ThreadPool::ThreadPoolTraits::close(const type handle)
{
	CloseThreadpool(handle);
}
41
// Handle-trait hook: waits for all outstanding callbacks that are members of
// the cleanup group (second argument false = do not cancel pending callbacks),
// then closes the group itself.
void C4ThreadPool::ThreadPoolCleanupTraits::close(const type handle)
{
	CloseThreadpoolCleanupGroupMembers(handle, false, nullptr);
	CloseThreadpoolCleanupGroup(handle);
}
47
// Handle-trait hook: releases a PTP_IO created by CreateThreadpoolIo.
void C4ThreadPool::ThreadPoolIoTraits::close(const type handle)
{
	CloseThreadpoolIo(handle);
}
52
53C4ThreadPool::Io::Awaiter::Awaiter(Io &io, IoFunction &&function, const std::uint64_t offset) noexcept
54 : io{io}, ioFunction{std::move(function)}
55{
56 overlapped.Offset = static_cast<std::uint32_t>(offset);
57 overlapped.OffsetHigh = static_cast<std::uint32_t>(offset >> 32);
58}
59
// Forces the final state to Cancelled. Only if the operation was actually in
// flight (previous state Started) is there cleanup to do: balance the
// threadpool notification, drop the awaiter-map entry, and request kernel-side
// cancellation of the overlapped operation.
// NOTE(review): CancelIoEx is asynchronous — the kernel may still write to
// "overlapped" (a member of this object) after the destructor returns; confirm
// callers guarantee quiescence before destroying a started awaiter. Also
// confirm that pairing CancelThreadpoolIo with CancelIoEx cannot unbalance the
// pending-callback count if the aborted completion is still delivered.
C4ThreadPool::Io::Awaiter::~Awaiter() noexcept
{
	if (state.exchange(State::Cancelled, std::memory_order_acq_rel) == State::Started)
	{
		io.Cancel();
		io.PopAwaiter(&overlapped);
		CancelIoEx(io.fileHandle, &overlapped);
	}
}
69
// Starts the overlapped operation. Returns true if the coroutine must suspend
// (completion will arrive via Awaiter::Callback), false if the result is
// already available and await_resume can run immediately.
bool C4ThreadPool::Io::Awaiter::DoSuspend(const std::coroutine_handle<> handle)
{
	// Relaxed store is published to the completion thread by the release CAS
	// below; register the awaiter so Io::Callback can find us by OVERLAPPED.
	coroutineHandle.store(handle, std::memory_order_relaxed);
	io.SetAwaiter(&overlapped, this);

	io.Start();
	if (ioFunction(io.fileHandle, &overlapped))
	{
		// Synchronous success: no completion packet is expected, so undo the
		// threadpool notification and deregister before fetching the result.
		state.store(State::Result, std::memory_order_relaxed);
		io.Cancel();
		io.PopAwaiter(&overlapped);

		DWORD transferred;
		winrt::check_bool(GetOverlappedResult(io.fileHandle, &overlapped, &transferred, false));

		result.numberOfBytesTransferred = transferred;
		return false;
	}
	else if (const DWORD error{GetLastError()}; error != ERROR_IO_PENDING)
	{
		// Synchronous failure: same cleanup, but record the error code so
		// await_resume can throw it.
		state.store(State::Error, std::memory_order_relaxed);
		io.Cancel();
		io.PopAwaiter(&overlapped);

		result.error = error;
		return false;
	}

	// The operation is pending; attempt the NotStarted -> Started transition.
	// NOTE(review): Awaiter::Callback only acts when it observes Started. If
	// the completion fires on another thread before this CAS, its own CAS
	// fails against NotStarted and the completion looks like it is dropped,
	// after which this CAS succeeds and the coroutine suspends forever —
	// confirm this window is impossible or handled elsewhere.
	State expected{State::NotStarted};
	if (!state.compare_exchange_strong(expected, State::Started, std::memory_order_release))
	{
		// The callback has already been called.
		return false;
	}

	return true;
}
107
108std::uint64_t C4ThreadPool::Io::Awaiter::await_resume() const
109{
110 switch (state.load(std::memory_order_acquire))
111 {
112 case State::Result:
113 return result.numberOfBytesTransferred;
114
115 case State::Error:
116 winrt::throw_hresult(HRESULT_FROM_WIN32(result.error));
117
118 case State::Cancelled:
119 throw C4Task::CancelledException{};
120
121 default:
122 std::terminate();
123 }
124}
125
126void C4ThreadPool::Io::Awaiter::SetupCancellation(C4Task::CancellablePromise *const promise)
127{
128 promise->SetCancellationCallback(
129 [](void *const argument)
130 {
131 auto *const awaiter = reinterpret_cast<Awaiter *>(argument);
132 CancelIoEx(awaiter->io.fileHandle, &awaiter->overlapped);
133 },
134 this
135 );
136}
137
// Invoked (via Io::Callback) when the overlapped operation completes.
// Translates the Win32 completion status into the awaiter's final state and,
// if the coroutine is suspended, schedules its resumption on a fresh
// threadpool work item.
void C4ThreadPool::Io::Awaiter::Callback(const PTP_CALLBACK_INSTANCE instance, const ULONG result, const ULONG_PTR numberOfBytesTransferred)
{
	State expected{State::Started};
	State desired;

	switch (result)
	{
	case ERROR_SUCCESS:
		this->result.numberOfBytesTransferred = numberOfBytesTransferred;
		desired = State::Result;
		break;

	case ERROR_OPERATION_ABORTED:
		// await_resume throws CancelledException for this state; the stored
		// error code is not read in that path.
		this->result.error = result;
		desired = State::Cancelled;
		break;

	default:
		this->result.error = result;
		desired = State::Error;
		break;
	}

	// Only transition if DoSuspend actually suspended the coroutine (state ==
	// Started); acq_rel on success publishes the result fields written above
	// and makes the relaxed coroutineHandle store visible here.
	// NOTE(review): if this runs before DoSuspend's own CAS (state still
	// NotStarted), this CAS fails and the completion appears to be dropped —
	// confirm that window cannot occur. Also, the return value of
	// TrySubmitThreadpoolCallback is ignored; on failure the coroutine would
	// never resume — confirm that is acceptable.
	if (state.compare_exchange_strong(expected, desired, std::memory_order_acq_rel, std::memory_order_acquire))
	{
		TrySubmitThreadpoolCallback([](const PTP_CALLBACK_INSTANCE, void *const context)
		{
			std::coroutine_handle<>::from_address(context).resume();
		}, coroutineHandle.load(std::memory_order_relaxed).address(), nullptr);
	}
}
169
// Creates a threadpool I/O object for fileHandle; all completions are routed
// to the static Io::Callback with "this" as the context. winrt::check_pointer
// throws on failure.
C4ThreadPool::Io::Io(const HANDLE fileHandle, const PTP_CALLBACK_ENVIRON environment)
	: fileHandle{fileHandle}, io{winrt::check_pointer(CreateThreadpoolIo(fileHandle, &Io::Callback, this, environment))}
{
}
174
// Blocks until outstanding I/O callbacks have finished (true = cancel
// callbacks that have not yet started); the PTP_IO itself is closed by the
// RAII handle wrapper via ThreadPoolIoTraits::close.
C4ThreadPool::Io::~Io() noexcept
{
	WaitForThreadpoolIoCallbacks(io.get(), true);
}
179
// Notifies the threadpool that an overlapped operation is about to start;
// must be balanced by a completion or by Cancel() if none will arrive.
void C4ThreadPool::Io::Start()
{
	StartThreadpoolIo(io.get());
}
184
// Revokes a prior Start() notification when no completion packet will be
// delivered (synchronous completion/failure or cancellation).
void C4ThreadPool::Io::Cancel()
{
	CancelThreadpoolIo(io.get());
}
189
190void C4ThreadPool::Io::SetAwaiter(OVERLAPPED *const overlapped, Awaiter *const handle)
191{
192 const std::lock_guard lock{awaiterMapMutex};
193
194 awaiterMap.insert_or_assign(overlapped, handle);
195}
196
197C4ThreadPool::Io::Awaiter *C4ThreadPool::Io::PopAwaiter(OVERLAPPED *const overlapped)
198{
199 const std::lock_guard lock{awaiterMapMutex};
200
201 const auto node = awaiterMap.extract(overlapped);
202 return node.empty() ? nullptr : node.mapped();
203}
204
205void C4ThreadPool::Io::Callback(const PTP_CALLBACK_INSTANCE instance, void *const context, void *const overlapped, const ULONG result, const ULONG_PTR numberOfBytesTransferred, const PTP_IO poolIo)
206{
207 auto &io = *reinterpret_cast<Io *>(context);
208 Awaiter *const awaiter{io.PopAwaiter(reinterpret_cast<OVERLAPPED *>(overlapped))};
209
210 if (awaiter)
211 {
212 awaiter->Callback(instance, result, numberOfBytesTransferred);
213 }
214}
215
// Builds a private Win32 threadpool with the given thread limits, binds it to
// the member callback environment, and attaches a cleanup group so that
// outstanding callbacks can be closed as a unit (see
// ThreadPoolCleanupTraits::close). MapHResultError translates the
// winrt::check_* failures into the error type used by the project.
C4ThreadPool::C4ThreadPool(const std::uint32_t minimum, const std::uint32_t maximum)
{
	MapHResultError([minimum, maximum, this]
	{
		pool.attach(winrt::check_pointer(CreateThreadpool(nullptr)));

		// Route callbacks queued through callbackEnvironment to this pool.
		SetThreadpoolCallbackPool(&callbackEnvironment, pool.get());

		// Maximum has no failure mode; minimum can fail (thread reservation).
		SetThreadpoolThreadMaximum(pool.get(), static_cast<DWORD>(maximum));
		winrt::check_bool(SetThreadpoolThreadMinimum(pool.get(), static_cast<DWORD>(minimum)));

		cleanupGroup.attach(winrt::check_pointer(CreateThreadpoolCleanupGroup()));

		SetThreadpoolCallbackCleanupGroup(&callbackEnvironment, cleanupGroup.get(), nullptr);
	});
}
232
// Intentionally empty: the RAII handle members release the cleanup group
// (waiting for outstanding callbacks, see ThreadPoolCleanupTraits::close) and
// the pool itself when they are destroyed.
C4ThreadPool::~C4ThreadPool()
{
}
236
237void C4ThreadPool::SubmitCallback(const PTP_SIMPLE_CALLBACK callback, void *const data)
238{
239 if (!TrySubmitThreadpoolCallback(callback, data, GetCallbackEnvironment()))
240 {
241 MapHResultError([] { winrt::throw_last_error(); });
242 }
243}
244
245#else
246
247C4ThreadPool::C4ThreadPool(const std::uint32_t minimum, const std::uint32_t maximum)
248{
249 threads.reserve(n: maximum);
250
251 for (std::size_t i{0}; i < maximum; ++i)
252 {
253 threads.emplace_back(args: std::thread{&C4ThreadPool::ThreadProc, this});
254 }
255}
256
257C4ThreadPool::~C4ThreadPool()
258{
259 quit.store(i: true, m: std::memory_order_release);
260 availableCallbacks.release(update: threads.size());
261
262 for (auto &thread : threads)
263 {
264 if (thread.joinable())
265 {
266 thread.join();
267 }
268 }
269}
270
271void C4ThreadPool::ThreadProc()
272{
273 for (;;)
274 {
275 availableCallbacks.acquire();
276 if (quit.load(m: std::memory_order_acquire))
277 {
278 return;
279 }
280
281 Callback callback;
282 {
283 std::lock_guard lock{callbackMutex};
284 callback = std::move(callbacks.front());
285 callbacks.pop();
286 }
287
288 callback();
289 }
290}
291
292#endif
293