1/*
2 * LegacyClonk
3 *
4 * Copyright (c) 2023, The LegacyClonk Team and contributors
5 *
6 * Distributed under the terms of the ISC license; see accompanying file
7 * "COPYING" for details.
8 *
9 * "Clonk" is a registered trademark of Matthes Bender, used with permission.
10 * See accompanying file "TRADEMARK" for details.
11 *
12 * To redistribute this file separately, substitute the full license texts
13 * for the above references.
14 */
15
16#pragma once
17
18#include "C4Attributes.h"
19
20#ifdef _WIN32
21#include "C4WinRT.h"
22#endif
23
24#include "StdSync.h"
25
26#include <atomic>
27#include <bit>
28#include <concepts>
29#include <coroutine>
30#include <exception>
31#include <functional>
32#include <memory>
33#include <mutex>
34#include <utility>
35
36namespace C4Task
37{
38
// Exception type used to signal that a task was cancelled.
//
// On Windows it additionally derives from winrt::hresult_canceled and can be
// built from one, so a WinRT cancellation error can be carried as a
// CancelledException.
class CancelledException : public std::exception
#ifdef _WIN32
	, public winrt::hresult_canceled
#endif
{
public:
	CancelledException() = default;

#ifdef _WIN32
	// Take over the state of an existing WinRT cancellation error.
	CancelledException(winrt::hresult_canceled &&e) : hresult_canceled{std::move(e)} {}
	CancelledException(const winrt::hresult_canceled &e) : hresult_canceled{e} {}
#endif

	// Fixed, allocation-free description.
	const char *what() const noexcept override
	{
		return "Cancelled";
	}
};
55
// Holds a single value of type U and hands it out as U && so that the
// consumer can move from it (or, for reference types, receive the reference
// unchanged after reference collapsing).
template<typename U>
struct ValueWrapper
{
	[[NO_UNIQUE_ADDRESS]] U Value;

	U &&GetValue()
	{
		return std::forward<U>(Value);
	}
};

// void carries no value; GetValue() exists only so generic code can call it.
template<>
struct ValueWrapper<void>
{
	void GetValue() const
	{
	}
};
68
69template<typename T, bool IsNoExcept>
70union Optional
71{
72public:
73 Optional() noexcept {}
74 ~Optional() noexcept {}
75
76 ValueWrapper<T> Wrapper;
77 std::exception_ptr Exception;
78};
79
80template<typename T>
81union Optional<T, true>
82{
83public:
84 Optional() noexcept {}
85 ~Optional() noexcept {}
86
87 ValueWrapper<T> Wrapper;
88};
89
90template<typename T, bool IsNoExcept = false>
91class ResultBase
92{
93protected:
94 decltype(auto) GetValue()
95 {
96 return result.Wrapper.GetValue();
97 }
98
99 [[noreturn]] void ReThrowException() const
100 {
101 std::rethrow_exception(result.Exception);
102 }
103
104 template<typename U = std::type_identity_t<T>>
105 void CreateValue(U &&value) noexcept(noexcept(std::construct_at(std::addressof(result.Wrapper), std::forward<U>(value)))) requires (!std::same_as<U, void>)
106 {
107 std::construct_at(std::addressof(result.Wrapper), std::forward<U>(value));
108 }
109
110 void CreateValue() noexcept requires std::same_as<T, void>
111 {
112 std::construct_at(std::addressof(result.Wrapper));
113 }
114
115 void CreateException(std::exception_ptr &&ptr) requires (!IsNoExcept)
116 {
117 std::construct_at(std::addressof(this->result.Exception), std::move(ptr));
118 }
119
120 void DestroyWrapper() noexcept
121 {
122 std::destroy_at(std::addressof(result.Wrapper));
123 }
124
125 void DestroyException() noexcept requires (!IsNoExcept)
126 {
127 std::destroy_at(std::addressof(result.Exception));
128 }
129
130protected:
131 Optional<T, IsNoExcept> result;
132};
133
134template<>
135class ResultBase<void, true>
136{
137protected:
138 constexpr void GetValue() const noexcept {}
139 constexpr void CreateValue() const noexcept {}
140 constexpr void DestroyWrapper() const noexcept {}
141};
142
// Chooses the underlying integer type for the state enum of Result.
// Without the no-except flag the type is always int.
template<typename U, bool IsNoExcept_>
struct EnumBaseType
{
	using Type = int;
};

// No-except case: use an integer of the same size as U for 1- and 2-byte
// payloads, and int for everything larger.
template<typename U>
struct EnumBaseType<U, true>
{
private:
	using TwoByteOrInt = std::conditional_t<sizeof(U) == 2, std::uint16_t, int>;

public:
	using Type = std::conditional_t<sizeof(U) == 1, std::uint8_t, TwoByteOrInt>;
};

// sizeof(void) is ill-formed, so void gets its own specialization.
template<>
struct EnumBaseType<void, true>
{
	using Type = std::uint8_t;
};
164
165template<typename T, bool IsNoExcept = false>
166class Result : protected ResultBase<T, IsNoExcept>
167{
168private:
169 enum ResultState : typename EnumBaseType<T, IsNoExcept>::Type
170 {
171 NotPresent,
172 Present,
173 Exception
174 };
175
176public:
177 Result() = default;
178 ~Result()
179 {
180 switch (resultState)
181 {
182 case ResultState::Present:
183 ResultBase<T, IsNoExcept>::DestroyWrapper();
184 break;
185
186 case ResultState::Exception:
187 if constexpr (!IsNoExcept)
188 {
189 ResultBase<T, IsNoExcept>::DestroyException();
190 }
191 else
192 {
193 std::unreachable();
194 }
195 break;
196
197 default:
198 break;
199 }
200 }
201
202 Result(const Result &) = delete;
203 Result &operator=(const Result &) = delete;
204
205public:
206 decltype(auto) GetResult()
207 {
208 switch (resultState)
209 {
210 case ResultState::Present:
211 return ResultBase<T, IsNoExcept>::GetValue();
212
213 case ResultState::Exception:
214 if constexpr (!IsNoExcept)
215 {
216 ResultBase<T, IsNoExcept>::ReThrowException();
217 }
218 else
219 {
220 std::unreachable();
221 }
222
223 default:
224 std::terminate();
225 }
226 }
227
228 template<typename U = std::type_identity_t<T>>
229 void SetResult(U &&result) noexcept(noexcept(ResultBase<T, IsNoExcept>::CreateValue(std::forward<U>(result)))) requires (!std::same_as<U, void>)
230 {
231 if (resultState == ResultState::NotPresent)
232 {
233 ResultBase<T, IsNoExcept>::CreateValue(std::forward<U>(result));
234 resultState = ResultState::Present;
235 }
236 }
237
238 void SetResult() noexcept requires std::same_as<T, void>
239 {
240 if (resultState == ResultState::NotPresent)
241 {
242 ResultBase<T, IsNoExcept>::CreateValue();
243 resultState = ResultState::Present;
244 }
245 }
246
247 void SetException(std::exception_ptr &&exceptionPtr) noexcept requires (!IsNoExcept)
248 {
249 if (resultState == ResultState::NotPresent)
250 {
251 ResultBase<T, IsNoExcept>::CreateException(std::move(exceptionPtr));
252 resultState = ResultState::Exception;
253 }
254 }
255
256private:
257 ResultState resultState{ResultState::NotPresent};
258};
259
// Empty tag type. CancellablePromise::await_transform() maps it to a
// CancellationToken, so `co_await GetCancellationToken()` yields a token for
// the current coroutine's promise.
struct CancellationTokenMarker {};

// Creates the tag to be co_awaited.
inline CancellationTokenMarker GetCancellationToken()
{
	return CancellationTokenMarker{};
}
266
// Empty tag type. CancellablePromise::await_transform() maps it to an awaiter
// that resumes immediately with a reference to the promise, so
// `co_await GetPromise()` gives a coroutine access to its own promise.
struct PromiseAwaiterMarker {};

// Creates the tag to be co_awaited.
inline PromiseAwaiterMarker GetPromise()
{
	return PromiseAwaiterMarker{};
}
273
274class CancellablePromise
275#ifdef _WIN32
276 : public winrt::cancellable_promise
277#endif
278{
279public:
280 class CancellationToken
281 {
282 public:
283 constexpr CancellationToken(const CancellablePromise *const promise) : promise{promise} {}
284
285 public:
286 constexpr bool await_ready() const noexcept { return true; }
287 constexpr void await_suspend(const std::coroutine_handle<>) const noexcept {}
288 constexpr CancellationToken await_resume() const noexcept { return *this; }
289
290 bool operator()() const noexcept
291 {
292 return promise->IsCancelled();
293 }
294
295 private:
296 const CancellablePromise *promise;
297 };
298
299private:
300 struct PromiseAwaiter
301 {
302 CancellablePromise &promise;
303
304 constexpr bool await_ready() const noexcept
305 {
306 return true;
307 }
308
309 constexpr void await_suspend(const std::coroutine_handle<>) const noexcept {}
310
311 CancellablePromise &await_resume() const noexcept
312 {
313 return promise;
314 }
315 };
316
317public:
318#ifdef _WIN32
319 CancellablePromise()
320 {
321 enable_cancellation_propagation(true);
322 }
323#endif
324
325 void Cancel()
326 {
327#ifdef _WIN32
328 cancelled.store(true, std::memory_order_release);
329 cancel();
330#else
331
332 const auto callback = cancellationCallback.exchange(p: CancelledSentinel, m: std::memory_order_acq_rel);
333
334 if (callback && callback != CancelledSentinel)
335 {
336 const struct Cleanup
337 {
338 CancellablePromise &Promise;
339
340 ~Cleanup()
341 {
342 Promise.cancellationCallback.store(p: CancelledSentinel, m: std::memory_order_release);
343 }
344 } cleanup{.Promise: *this};
345
346 callback(cancellationArgument);
347 }
348#endif
349 }
350
351 void SetCancellationCallback(void(*const callback)(void *), void *const argument) noexcept
352 {
353#ifdef _WIN32
354 set_canceller(callback, argument);
355#else
356 cancellationArgument = argument;
357 cancellationCallback.store(p: callback, m: std::memory_order_release);
358#endif
359 }
360
361 void ResetCancellationCallback() noexcept
362 {
363#ifdef _WIN32
364 revoke_canceller();
365#else
366 auto callback = cancellationCallback.load(m: std::memory_order_acquire);
367
368 do
369 {
370 if (!callback || callback == CancelledSentinel)
371 {
372 return;
373 }
374 }
375 while (!cancellationCallback.compare_exchange_weak(p1&: callback, p2: nullptr, m: std::memory_order_release));
376#endif
377 }
378
379 bool IsCancelled() const noexcept
380 {
381#ifdef _WIN32
382 return cancelled.load(std::memory_order_acquire);
383#else
384 const auto callback = cancellationCallback.load(m: std::memory_order_acquire);
385 return callback == CancelledSentinel;
386#endif
387 }
388
389 CancellationToken await_transform(const CancellationTokenMarker) noexcept
390 {
391 return {this};
392 }
393
394 PromiseAwaiter await_transform(const PromiseAwaiterMarker) noexcept
395 {
396 return {.promise: *this};
397 }
398
399#ifdef _WIN32
400private:
401 std::atomic_bool cancelled{false};
402#else
403public:
404 static inline const auto CancelledSentinel = std::bit_cast<void(*)(void *)>(from: reinterpret_cast<void *>(0x02));
405
406protected:
407 std::atomic<void(*)(void *)> cancellationCallback{nullptr};
408 void *cancellationArgument{nullptr};
409#endif
410};
411
412template<typename Class>
413class CancellableAwaiter
414#ifdef _WIN32
415 : public winrt::cancellable_awaiter<Class>
416#endif
417{
418public:
419 CancellableAwaiter() noexcept = default;
420
421#ifndef _WIN32
422 ~CancellableAwaiter() noexcept
423 {
424 if (promise)
425 {
426 promise->ResetCancellationCallback();
427 }
428 }
429#endif
430
431 CancellableAwaiter(const CancellableAwaiter &) noexcept = delete;
432 CancellableAwaiter &operator=(const CancellableAwaiter &) noexcept = delete;
433
434#ifndef _WIN32
435 CancellableAwaiter(CancellableAwaiter &&other) noexcept : promise{std::exchange(obj&: other.promise, new_val: nullptr)} {}
436
437 CancellableAwaiter &operator=(CancellableAwaiter &&other) noexcept
438 {
439 promise = std::exchange(obj&: other.promise, new_val: nullptr);
440 return *this;
441 }
442#endif
443
444public:
445 template<typename T>
446 void SetCancellablePromise(const std::coroutine_handle<T> handle)
447 {
448#ifdef _WIN32
449 winrt::cancellable_awaiter<Class>::set_cancellable_promise_from_handle(handle);
450#else
451 if constexpr (std::derived_from<T, CancellablePromise>)
452 {
453 promise = &handle.promise();
454 static_cast<Class *>(this)->SetupCancellation(promise);
455 }
456#endif
457 }
458
459#ifdef _WIN32
460 void enable_cancellation(winrt::cancellable_promise *const promise)
461 {
462 static_cast<Class *>(this)->SetupCancellation(static_cast<CancellablePromise *>(promise));
463 }
464#endif
465
466private:
467#ifndef _WIN32
468 CancellablePromise *promise{nullptr};
469#endif
470};
471
// Default promise behaviour: exceptions escaping the coroutine body are
// stored in the result instead of terminating the program.
struct PromiseTraitsDefault
{
	static constexpr bool TerminateOnException = false;
};

// Exceptions other than cancellation escaping the coroutine body call
// std::terminate().
struct PromiseTraitsTerminateOnException : PromiseTraitsDefault
{
	static constexpr bool TerminateOnException = true;
};
481
482template<typename T, typename PromiseTraits>
483struct Promise : CancellablePromise
484{
485 struct Deleter
486 {
487 void operator()(Promise *const promise) noexcept;
488 };
489
490 struct FinalSuspendAwaiter : std::suspend_always
491 {
492 Promise &promise;
493
494 std::coroutine_handle<> await_suspend(const std::coroutine_handle<>) const noexcept
495 {
496 const auto waiter = promise.waiting.exchange(p: CompletedSentinel, m: std::memory_order_acq_rel);
497 if (waiter == AbandonedSentinel)
498 {
499 promise.GetHandle().destroy();
500 }
501
502 else if (waiter != StartedSentinel)
503 {
504 if (promise.resumer)
505 {
506 promise.resumer(waiter);
507 }
508 else
509 {
510 return std::coroutine_handle<>::from_address(a: waiter);
511 }
512 }
513
514 return std::noop_coroutine();
515 }
516 };
517
518 using PromisePtr = std::unique_ptr<Promise, Deleter>;
519 using ResumeFunction = void(*)(void *);
520
521 auto get_return_object() noexcept { return this; }
522
523 constexpr std::suspend_always initial_suspend() const noexcept
524 {
525 return {};
526 }
527
528 FinalSuspendAwaiter final_suspend() noexcept
529 {
530 return {.promise = *this};
531 }
532
533 void unhandled_exception() const noexcept
534 {
535 std::exception_ptr ptr{std::current_exception()};
536
537#ifndef _WIN32
538 if constexpr (PromiseTraits::TerminateOnException)
539 {
540#endif
541 try
542 {
543 std::rethrow_exception(ptr);
544 }
545 catch (const C4Task::CancelledException &)
546 {
547 SetException(ptr);
548 }
549#ifdef _WIN32
550 catch (const winrt::hresult_canceled &e)
551 {
552 SetException(std::make_exception_ptr(CancelledException{e}));
553 }
554#endif
555 catch (...)
556 {
557 if constexpr (PromiseTraits::TerminateOnException)
558 {
559 std::terminate();
560 }
561 else
562 {
563 SetException(ptr);
564 }
565 }
566#ifndef _WIN32
567 }
568 else
569 {
570 SetException(ptr);
571 }
572#endif
573 }
574
575 template<typename Awaiter>
576 Awaiter &&await_transform(Awaiter &&awaiter)
577 {
578 if (IsCancelled())
579 {
580 throw CancelledException{};
581 }
582
583 return std::forward<Awaiter>(awaiter);
584 }
585
586 using CancellablePromise::await_transform;
587
588 template<typename U = T> requires(!std::same_as<T, void>)
589 void SetResult(U &&result) const noexcept(noexcept(this->result.SetResult(std::forward<U>(result))))
590 {
591 this->result.SetResult(std::forward<U>(result));
592 }
593
594 void SetResult() const noexcept
595 {
596 this->result.SetResult();
597 }
598
599 void SetException(std::exception_ptr e) const noexcept
600 {
601 this->result.SetException(std::move(e));
602 }
603
604 void Start() noexcept
605 {
606 void *expected{ColdSentinel};
607 if (waiting.compare_exchange_strong(p1&: expected, p2: StartedSentinel, m: std::memory_order_acq_rel))
608 {
609 GetHandle().resume();
610 }
611 }
612
613 void Abandon() noexcept
614 {
615 const auto handle = waiting.exchange(p: AbandonedSentinel, m: std::memory_order_acq_rel);
616 if (handle != StartedSentinel)
617 {
618 GetHandle().destroy();
619 }
620 }
621
622 bool AwaitReady() const noexcept
623 {
624 return waiting.load(m: std::memory_order_acquire) != StartedSentinel;
625 }
626
627 bool AwaitSuspend(void *const address, const ResumeFunction resumer = {}) noexcept
628 {
629 this->resumer = resumer;
630 return waiting.exchange(p: address, m: std::memory_order_acq_rel) == StartedSentinel;
631 }
632
633 bool ColdAwaitSuspend(void *const address, const ResumeFunction resumer = {}) noexcept
634 {
635 Start();
636 return AwaitSuspend(address, resumer);
637 }
638
639 decltype(auto) AwaitResume()
640 {
641 return result.GetResult();
642 }
643
644private:
645 auto GetHandle() noexcept
646 {
647 return std::coroutine_handle<Promise>::from_promise(*this);
648 }
649
650public:
651 static inline const auto StartedSentinel = reinterpret_cast<void *>(0x00);
652 static inline const auto CompletedSentinel = reinterpret_cast<void *>(0x01);
653 static inline const auto AbandonedSentinel = reinterpret_cast<void *>(0x02);
654 static inline const auto ColdSentinel = reinterpret_cast<void *>(0x03);
655
656private:
657 std::atomic<void *> waiting{ColdSentinel};
658 ResumeFunction resumer;
659 mutable Result<T> result;
660};
661
662template<typename T, typename PromiseTraits>
663void Promise<T, PromiseTraits>::Deleter::operator()(Promise *const promise) noexcept
664{
665 if (promise)
666 {
667 promise->Abandon();
668 }
669}
670
// Compile-time switches for Task.
//   IsColdStart                 the coroutine body is not started on creation
//   CancelAndWaitOnDestruction  the Task destructor cancels and waits
struct TaskTraitsDefault
{
	static constexpr bool IsColdStart = false;
	static constexpr bool CancelAndWaitOnDestruction = false;
};

// Cold start, no waiting on destruction.
struct TaskTraitsCold : TaskTraitsDefault
{
	static constexpr bool IsColdStart = true;
};

// Hot start, cancel and wait on destruction.
struct TaskTraitsHotWaitOnDestruction : TaskTraitsDefault
{
	static constexpr bool CancelAndWaitOnDestruction = true;
};

// Cold start, cancel and wait on destruction.
struct TaskTraitsColdWaitOnDestruction : TaskTraitsCold
{
	static constexpr bool CancelAndWaitOnDestruction = true;
};
691
692template<typename T, typename TaskTraits, typename PromiseTraits = PromiseTraitsDefault>
693class Task
694{
695private:
696 struct Awaiter : public CancellableAwaiter<Awaiter>
697 {
698 typename Promise<T, PromiseTraits>::PromisePtr promisePtr;
699
700 constexpr bool await_ready() const
701 {
702 return false;
703 }
704
705 template<typename U>
706 bool await_suspend(const std::coroutine_handle<U> handle)
707 {
708 CancellableAwaiter<Awaiter>::SetCancellablePromise(handle);
709
710 if constexpr (TaskTraits::IsColdStart)
711 {
712 return promisePtr->ColdAwaitSuspend(handle.address());
713 }
714 else
715 {
716 return promisePtr->AwaitSuspend(handle.address());
717 }
718 }
719
720 decltype(auto) await_resume()
721 {
722 return promisePtr->AwaitResume();
723 }
724
725 void SetupCancellation(CancellablePromise *const promise)
726 {
727 promise->SetCancellationCallback(callback: [](void *const context)
728 {
729 reinterpret_cast<typename Promise<T, PromiseTraits>::PromisePtr::pointer>(context)->Cancel();
730 }, argument: this->promisePtr.get());
731 }
732 };
733
734public:
735 Task() = default;
736 Task(Promise<T, PromiseTraits> *const promise) noexcept : promise{promise}
737 {
738 if constexpr (!TaskTraits::IsColdStart)
739 {
740 promise->Start();
741 }
742 }
743
744 ~Task() noexcept
745 {
746 if constexpr (TaskTraits::CancelAndWaitOnDestruction)
747 {
748 if (promise)
749 {
750 std::move(*this).CancelAndWait();
751 }
752 }
753 }
754
755 Task(Task &&) = default;
756 Task &operator=(Task &&) = default;
757
758public:
759 T Get() &&
760 {
761 if constexpr (TaskTraits::IsColdStart)
762 {
763 promise->Start();
764 }
765
766 if (!promise->AwaitReady())
767 {
768 constexpr auto signalEvent = [](void *const event)
769 {
770 reinterpret_cast<CStdEvent *>(event)->Set();
771 };
772
773 CStdEvent event;
774 if (promise->AwaitSuspend(&event, signalEvent))
775 {
776 event.WaitFor(milliseconds: StdSync::Infinite);
777 }
778 }
779
780 return promise->AwaitResume();
781 }
782
783 void Cancel() const
784 {
785 promise->Cancel();
786 }
787
788 bool IsCancelled() const noexcept
789 {
790 return promise->IsCancelled();
791 }
792
793 void CancelAndWait() && noexcept(PromiseTraitsTerminateOnException::TerminateOnException)
794 {
795 Cancel();
796
797 try
798 {
799 (void) std::move(*this).Get();
800 }
801 catch (const CancelledException &)
802 {
803 }
804 }
805
806 void Start() const requires TaskTraits::IsColdStart
807 {
808 promise->Start();
809 }
810
811 explicit operator bool() const noexcept
812 {
813 return promise.get();
814 }
815
816 Awaiter operator co_await() && noexcept
817 {
818 return {.promisePtr = std::move(promise)};
819 }
820
821public:
822 static constexpr inline bool IsColdStart{false};
823
824protected:
825 typename Promise<T, PromiseTraits>::PromisePtr promise;
826};
827
// Task using TaskTraitsDefault: the coroutine body is started as soon as the
// Task object is created from its promise.
template<typename T, typename PromiseTraits = PromiseTraitsDefault>
using Hot = Task<T, TaskTraitsDefault, PromiseTraits>;

// Task using TaskTraitsCold: the coroutine body only runs once Start() or
// Get() is called or the task is co_awaited.
template<typename T, typename PromiseTraits = PromiseTraitsDefault>
using Cold = Task<T, TaskTraitsCold, PromiseTraits>;
833
834class OneShot
835{
836public:
837 struct promise_type
838 {
839 constexpr OneShot get_return_object() const noexcept { return {}; }
840 constexpr std::suspend_never initial_suspend() const noexcept { return {}; }
841 constexpr std::suspend_never final_suspend() const noexcept { return {}; }
842 constexpr void return_void() const noexcept {}
843 void unhandled_exception() const noexcept { std::terminate(); }
844 };
845};
846
847}
848
849template<template<typename, typename, typename> typename Task, typename T, typename TaskTraits, typename PromiseTraits, typename... Args> requires std::derived_from<Task<T, TaskTraits, PromiseTraits>, C4Task::Task<T, TaskTraits, PromiseTraits>> && (!std::same_as<T, void>)
850struct std::coroutine_traits<Task<T, TaskTraits, PromiseTraits>, Args...>
851{
852 struct promise_type : C4Task::Promise<T, PromiseTraits>
853 {
854 void return_value(T &&value) const noexcept(noexcept(C4Task::Promise<T, PromiseTraits>::SetResult(std::forward<T>(value))))
855 {
856 C4Task::Promise<T, PromiseTraits>::SetResult(std::forward<T>(value));
857 }
858 };
859};
860
861template<template<typename, typename, typename> typename Task, typename T, typename TaskTraits, typename PromiseTraits, typename... Args> requires std::derived_from<Task<T, TaskTraits, PromiseTraits>, C4Task::Task<T, TaskTraits, PromiseTraits>> && std::same_as<T, void>
862struct std::coroutine_traits<Task<T, TaskTraits, PromiseTraits>, Args...>
863{
864 struct promise_type : C4Task::Promise<T, PromiseTraits>
865 {
866 void return_void() const noexcept(noexcept(C4Task::Promise<T, PromiseTraits>::SetResult()))
867 {
868 C4Task::Promise<T, PromiseTraits>::SetResult();
869 }
870 };
871};
872