diff --git a/strings/base_coroutine_foundation.h b/strings/base_coroutine_foundation.h index ba61cd49a..670a65403 100644 --- a/strings/base_coroutine_foundation.h +++ b/strings/base_coroutine_foundation.h @@ -151,11 +151,12 @@ namespace winrt::impl #ifdef WINRT_IMPL_COROUTINES template - struct await_adapter : cancellable_awaiter> + struct await_adapter : cancellable_awaiter> { - await_adapter(Async const& async) : async(async) { } + template + await_adapter(T&& async) : async(std::forward(async)) { } - Async const& async; + std::conditional_t async; Windows::Foundation::AsyncStatus status = Windows::Foundation::AsyncStatus::Started; int32_t failure = 0; std::atomic suspending = true; @@ -190,6 +191,11 @@ namespace winrt::impl private: bool register_completed_callback(coroutine_handle<> handle) { + if constexpr (!preserve_context) + { + // Ensure that the illegal delegate assignment propagates properly. + suspending.store(true, std::memory_order_relaxed); + } async.Completed(disconnect_aware_handler(this, handle)); return suspending.exchange(false, std::memory_order_acquire); } @@ -257,9 +263,9 @@ namespace winrt::impl WINRT_EXPORT namespace winrt { template>> - inline impl::await_adapter resume_agile(Async const& async) + inline impl::await_adapter, false> resume_agile(Async&& async) { - return { async }; + return { std::forward(async) }; }; } diff --git a/test/test/await_adapter.cpp b/test/test/await_adapter.cpp index 701bc6b15..dccb9ce3d 100644 --- a/test/test/await_adapter.cpp +++ b/test/test/await_adapter.cpp @@ -136,3 +136,36 @@ TEST_CASE("await_adapter_agile") AgileAsync(dispatcher).get(); controller.ShutdownQueueAsync().get(); } + +namespace +{ + IAsyncAction AgileAsyncVariable(DispatcherQueue dispatcher) + { + // Switch to the STA. + co_await resume_foreground(dispatcher); + REQUIRE(is_sta()); + + // Ask for agile resumption of a coroutine that finishes on a background thread. + // Add a 100ms delay to ensure we suspend. Store the resume_agile in a variable + // and await the variable. + auto op = resume_agile(OtherBackgroundDelayAsync()); + co_await op; + // We should be on the background thread now. + REQUIRE(!is_sta()); + + // Second attempt to await the op should fail cleanly. + REQUIRE_THROWS_AS(co_await op, hresult_illegal_delegate_assignment); + // We should still be on the background thread. + REQUIRE(!is_sta()); + } +} + + +TEST_CASE("await_adapter_agile_variable") +{ + auto controller = DispatcherQueueController::CreateOnDedicatedThread(); + auto dispatcher = controller.DispatcherQueue(); + + AgileAsyncVariable(dispatcher).get(); + controller.ShutdownQueueAsync().get(); +}