/*
 * LegacyClonk
 *
 * Copyright (c) 2023, The LegacyClonk Team and contributors
 *
 * Distributed under the terms of the ISC license; see accompanying file
 * "COPYING" for details.
 *
 * "Clonk" is a registered trademark of Matthes Bender, used with permission.
 * See accompanying file "TRADEMARK" for details.
 *
 * To redistribute this file separately, substitute the full license texts
 * for the above references.
 */

#pragma once

#ifdef _WIN32
#include "C4Coroutine.h"
#include "C4WinRT.h"
#endif

#include <atomic>
#include <bit>
#include <coroutine>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <type_traits>
#include <unordered_map>
#include <utility>

#ifndef _WIN32
#include <mutex>
#include <queue>
#include <semaphore>
#include <thread>
#include <vector>
#endif

40class C4ThreadPool
41{
42#ifdef _WIN32
43private:
44 class CallbackEnvironment
45 {
46 public:
47 CallbackEnvironment() noexcept;
48 ~CallbackEnvironment() noexcept;
49
50 CallbackEnvironment(const CallbackEnvironment &) = delete;
51 CallbackEnvironment &operator=(const CallbackEnvironment &) = delete;
52
53 CallbackEnvironment(CallbackEnvironment &&) = delete;
54 CallbackEnvironment &operator=(CallbackEnvironment &&) = delete;
55
56 public:
57 bool HasPool() const noexcept { return callbackEnvironment.Pool; }
58
59 operator TP_CALLBACK_ENVIRON &() noexcept { return callbackEnvironment; }
60 operator const TP_CALLBACK_ENVIRON &() const noexcept { return callbackEnvironment; }
61
62 TP_CALLBACK_ENVIRON *operator&() noexcept { return &callbackEnvironment; }
63 const TP_CALLBACK_ENVIRON *operator&() const noexcept { return &callbackEnvironment; }
64
65 private:
66 TP_CALLBACK_ENVIRON callbackEnvironment;
67 };
68
69 template<typename T>
70 struct ThreadPoolTraitsBase
71 {
72 using type = T;
73 static constexpr type invalid()
74 {
75 return nullptr;
76 }
77 };
78
79 struct ThreadPoolTraits : ThreadPoolTraitsBase<PTP_POOL>
80 {
81 static void close(const type handle);
82 };
83
84 struct ThreadPoolCleanupTraits : ThreadPoolTraitsBase<PTP_CLEANUP_GROUP>
85 {
86 static void close(const type handle);
87 };
88
89 struct ThreadPoolIoTraits : ThreadPoolTraitsBase<PTP_IO>
90 {
91 static void close(const type handle);
92 };
93
94public:
95 class Io
96 {
97 public:
98 class Awaiter : public C4Task::CancellableAwaiter<Awaiter>
99 {
100 public:
101 using IoFunction = std::function<bool(HANDLE, OVERLAPPED *)>;
102
103 enum class State
104 {
105 NotStarted,
106 Started,
107 Result,
108 Error,
109 Cancelled
110 };
111
112 public:
113 Awaiter(Io &io, IoFunction &&function, std::uint64_t offset) noexcept;
114 ~Awaiter() noexcept;
115
116 public:
117 constexpr bool await_ready() const noexcept { return false; }
118
119 template<typename T>
120 bool await_suspend(const std::coroutine_handle<T> handle)
121 {
122 SetCancellablePromise(handle);
123 return DoSuspend(handle);
124 }
125
126 std::uint64_t await_resume() const;
127 void SetupCancellation(C4Task::CancellablePromise *promise);
128
129 void Callback(PTP_CALLBACK_INSTANCE instance, ULONG result, ULONG_PTR numberOfBytesTransferred);
130
131 private:
132 bool DoSuspend(std::coroutine_handle<> handle);
133
134 private:
135 Io &io;
136 IoFunction ioFunction;
137 OVERLAPPED overlapped{};
138 std::atomic<State> state{State::NotStarted};
139
140 union
141 {
142 std::uint64_t numberOfBytesTransferred;
143 DWORD error;
144 } result;
145
146 std::atomic<std::coroutine_handle<>> coroutineHandle;
147 };
148
149 public:
150 Io() noexcept = default;
151 Io(HANDLE fileHandle, PTP_CALLBACK_ENVIRON environment = nullptr);
152 ~Io() noexcept;
153
154 Io(const Io &) = delete;
155 Io &operator=(const Io &) = delete;
156
157 Io(Io &&other) : Io{}
158 {
159 swap(*this, other);
160 }
161
162 Io &operator=(Io &&other)
163 {
164 Io temp{std::move(other)};
165 swap(*this, temp);
166 return *this;
167 }
168
169 public:
170 Awaiter ExecuteAsync(Awaiter::IoFunction &&ioFunction, const std::uint64_t offset = 0)
171 {
172 return {*this, std::move(ioFunction), offset};
173 }
174
175 explicit operator bool() const noexcept
176 {
177 return io.get();
178 }
179
180 private:
181 void Start();
182 void Cancel();
183
184 void SetAwaiter(OVERLAPPED *overlapped, Awaiter *handle);
185 Awaiter *PopAwaiter(OVERLAPPED *overlapped);
186
187 static void CALLBACK Callback(PTP_CALLBACK_INSTANCE instance, void *context, void *overlapped, ULONG result, ULONG_PTR numberOfBytesTransferred, PTP_IO poolIo);
188
189 private:
190 HANDLE fileHandle{nullptr};
191 winrt::handle_type<ThreadPoolIoTraits> io;
192 std::unordered_map<OVERLAPPED *, Awaiter *> awaiterMap;
193 std::mutex awaiterMapMutex;
194
195 friend void swap(Io &first, Io &second)
196 {
197 using std::swap;
198
199 swap(first.fileHandle, second.fileHandle);
200 swap(first.io, second.io);
201
202 {
203 const std::scoped_lock lock{first.awaiterMapMutex, second.awaiterMapMutex};
204 swap(first.awaiterMap, second.awaiterMap);
205 }
206 }
207 };
208#else
209private:
210 using Callback = std::function<void()>;
211#endif
212
213public:
214 C4ThreadPool() = default;
215 C4ThreadPool(std::uint32_t minimum, std::uint32_t maximum);
216 ~C4ThreadPool();
217
218 C4ThreadPool(const C4ThreadPool &) = delete;
219 C4ThreadPool &operator=(const C4ThreadPool &) = delete;
220
221 C4ThreadPool(C4ThreadPool &&) = delete;
222 C4ThreadPool &operator=(C4ThreadPool &&) = delete;
223
224public:
225#ifdef _WIN32
226 template<typename T>
227 void SubmitCallback(T &&callback)
228 {
229 if constexpr (sizeof(void *) == sizeof(void(*)()) && std::is_convertible_v<T, void(*)()>)
230 {
231 SubmitCallback([](const PTP_CALLBACK_INSTANCE, void *const data)
232 {
233 std::bit_cast<void(*)()>(data)();
234 }, std::bit_cast<void *>(static_cast<void(*)()>(callback)));
235 }
236 else
237 {
238 using Type = std::remove_reference_t<T>;
239 auto callbackPtr = std::make_unique<Type>(std::move(callback));
240 SubmitCallback([](const PTP_CALLBACK_INSTANCE, void *const data)
241 {
242 const std::unique_ptr<Type> callback{reinterpret_cast<Type *>(data)};
243 (*callback)();
244 }, const_cast<void *>(static_cast<const void *>(callbackPtr.get())));
245
246 callbackPtr.release();
247 }
248 }
249
250 void SubmitCallback(const std::coroutine_handle<> handle)
251 {
252 SubmitCallback([](const PTP_CALLBACK_INSTANCE, void *const data)
253 {
254 std::coroutine_handle<>::from_address(data).resume();
255 }, handle.address());
256 }
257
258 void SubmitCallback(PTP_SIMPLE_CALLBACK callback, void *data);
259
260 template<typename Func>
261 decltype(auto) NativeThreadPoolOperation(Func &&operation)
262 {
263 return operation(pool.get(), GetCallbackEnvironment());
264 }
265#else
266 template<typename T>
267 void SubmitCallback(T &&callback)
268 {
269 {
270 const std::lock_guard lock{callbackMutex};
271 callbacks.push(std::move(callback));
272 }
273
274 availableCallbacks.release();
275 }
276#endif
277
278 auto operator co_await() & noexcept
279 {
280 struct Awaiter
281 {
282 C4ThreadPool &ThreadPool;
283
284 constexpr bool await_ready() const noexcept
285 {
286 return false;
287 }
288
289 void await_suspend(const std::coroutine_handle<> handle) const noexcept
290 {
291 ThreadPool.SubmitCallback(callback: handle);
292 }
293
294 constexpr void await_resume() const noexcept
295 {
296 }
297 };
298
299 return Awaiter{.ThreadPool: *this};
300 }
301
302private:
303#ifdef _WIN32
304 PTP_CALLBACK_ENVIRON GetCallbackEnvironment() noexcept
305 {
306 return callbackEnvironment.HasPool() ? &callbackEnvironment : nullptr;
307 }
308#else
309 void ThreadProc();
310#endif
311
312public:
313 static inline std::shared_ptr<C4ThreadPool> Global{};
314
315private:
316#ifdef _WIN32
317 CallbackEnvironment callbackEnvironment;
318 winrt::handle_type<ThreadPoolTraits> pool;
319 winrt::handle_type<ThreadPoolCleanupTraits> cleanupGroup;
320#else
321 std::vector<std::thread> threads;
322 std::atomic_bool quit{false};
323 std::queue<Callback> callbacks;
324 std::counting_semaphore<> availableCallbacks{0};
325 std::mutex callbackMutex;
326#endif
327};
328