diff --git a/src/DotNext.Tests/Threading/Tasks/ConversionTests.cs b/src/DotNext.Tests/Threading/Tasks/ConversionTests.cs index b66ca9599..037c18904 100644 --- a/src/DotNext.Tests/Threading/Tasks/ConversionTests.cs +++ b/src/DotNext.Tests/Threading/Tasks/ConversionTests.cs @@ -55,4 +55,12 @@ public static async Task SuspendException2() var result = await t.AsDynamic().ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing | ConfigureAwaitOptions.ContinueOnCapturedContext); Same(result, Missing.Value); } + + [Fact] + public static async Task SuspendExceptionParametrized() + { + await Task.FromException(new Exception()).SuspendException(42, (_, i) => i is 42); + await ValueTask.FromException(new Exception()).SuspendException(42, (_, i) => i is 42); + await ThrowsAsync(async () => await Task.FromException(new Exception()).SuspendException(43, (_, i) => i is 42)); + } } \ No newline at end of file diff --git a/src/DotNext.Threading/Threading/Tasks/TaskQueue.cs b/src/DotNext.Threading/Threading/Tasks/TaskQueue.cs index 18e92f1ff..3c1aa4f7e 100644 --- a/src/DotNext.Threading/Threading/Tasks/TaskQueue.cs +++ b/src/DotNext.Threading/Threading/Tasks/TaskQueue.cs @@ -178,7 +178,7 @@ public async ValueTask EnqueueAsync(T task, CancellationToken token = default) } } - private T? TryPeekOrDequeue(out int head, out Task enqueueTask, out bool completed) + private T? TryPeekOrDequeue(out int head, out Task enqueueTask) { T? result; lock (array) @@ -187,7 +187,7 @@ public async ValueTask EnqueueAsync(T task, CancellationToken token = default) { result = this[head = this.head]; enqueueTask = Task.CompletedTask; - if (completed = result is { IsCompleted: true }) + if (result is { IsCompleted: true }) { MoveNext(ref head); ChangeCount(increment: false); @@ -197,7 +197,6 @@ public async ValueTask EnqueueAsync(T task, CancellationToken token = default) { head = default; result = null; - completed = default; signal ??= new(); enqueueTask = signal.Task; } @@ -258,17 +257,17 @@ public bool TryDequeue([NotNullWhen(true)] out T? task) /// The operation has been canceled. public async ValueTask DequeueAsync(CancellationToken token = default) { - for (var filter = token.CanBeCanceled ? null : Predicate.Constant(true);;) + for (;;) { - if (TryPeekOrDequeue(out var expectedHead, out var enqueueTask, out var completed) is not { } task) + if (TryPeekOrDequeue(out var expectedHead, out var enqueueTask) is not { } task) { await enqueueTask.WaitAsync(token).ConfigureAwait(false); continue; } - if (!completed) + if (!task.IsCompleted) { - await task.WaitAsync(token).SuspendException(filter ??= token.SuspendAllExceptCancellation).ConfigureAwait(false); + await task.WaitAsync(token).SuspendException(token, SuspendAllExceptCancellation).ConfigureAwait(false); if (!TryDequeue(expectedHead, task)) continue; @@ -286,12 +285,12 @@ public async ValueTask DequeueAsync(CancellationToken token = default) /// The operation has been canceled. public async ValueTask TryDequeueAsync(CancellationToken token = default) { - for (var filter = token.CanBeCanceled ? null : Predicate.Constant(true);;) + for (;;) { T? task; - if ((task = TryPeekOrDequeue(out var expectedHead, out _, out var completed)) is not null && !completed) + if ((task = TryPeekOrDequeue(out var expectedHead, out _)) is not null && !task.IsCompleted) { - await task.WaitAsync(token).SuspendException(filter ??= token.SuspendAllExceptCancellation).ConfigureAwait(false); + await task.WaitAsync(token).SuspendException(token, SuspendAllExceptCancellation).ConfigureAwait(false); if (!TryDequeue(expectedHead, task)) continue; @@ -311,12 +310,11 @@ public async ValueTask DequeueAsync(CancellationToken token = default) /// The enumerator over completed tasks. public async IAsyncEnumerator GetAsyncEnumerator(CancellationToken token) { - for (var filter = token.CanBeCanceled ? null : Predicate.Constant(true); - TryPeekOrDequeue(out var expectedHead, out _, out var completed) is { } task;) + while (TryPeekOrDequeue(out var expectedHead, out _) is { } task) { - if (!completed) + if (!task.IsCompleted) { - await task.WaitAsync(token).SuspendException(filter ??= token.SuspendAllExceptCancellation).ConfigureAwait(false); + await task.WaitAsync(token).SuspendException(token, SuspendAllExceptCancellation).ConfigureAwait(false); if (!TryDequeue(expectedHead, task)) continue; } @@ -325,6 +323,9 @@ public async IAsyncEnumerator GetAsyncEnumerator(CancellationToken token) } } + private static bool SuspendAllExceptCancellation(Exception e, CancellationToken token) + => e is not OperationCanceledException canceledEx || token != canceledEx.CancellationToken; + /// /// Clears the queue. /// @@ -343,10 +344,4 @@ public void Clear() void IResettable.Reset() => Clear(); private sealed class Signal() : TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); -} - -file static class CancellationTokenExtensions -{ - internal static bool SuspendAllExceptCancellation(this object token, Exception e) - => e is not OperationCanceledException canceledEx || !canceledEx.CancellationToken.Equals(token); } \ No newline at end of file diff --git a/src/DotNext/Runtime/CompilerServices/SuspendedExceptionTaskAwaitable.cs b/src/DotNext/Runtime/CompilerServices/SuspendedExceptionTaskAwaitable.cs index 6ec23a437..de636f575 100644 --- a/src/DotNext/Runtime/CompilerServices/SuspendedExceptionTaskAwaitable.cs +++ b/src/DotNext/Runtime/CompilerServices/SuspendedExceptionTaskAwaitable.cs @@ -1,3 +1,4 @@ +using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; @@ -15,7 +16,9 @@ internal SuspendedExceptionTaskAwaitable(ValueTask task) => this.task = task; internal SuspendedExceptionTaskAwaitable(Task task) - => this.task = new(task); + : this(new ValueTask(task)) + { + } internal bool ContinueOnCapturedContext { @@ -91,4 +94,97 @@ public void GetResult() } } } +} + +/// +/// Represents awaitable object that can suspend exception raised by the underlying task. +/// +/// The type of the argument to be passed to the exception filter. +[StructLayout(LayoutKind.Auto)] +public readonly struct SuspendedExceptionTaskAwaitable +{ + private readonly TArg arg; + private readonly ValueTask task; + private readonly Func filter; + + internal SuspendedExceptionTaskAwaitable(ValueTask task, TArg arg, Func filter) + { + Debug.Assert(filter is not null); + + this.task = task; + this.arg = arg; + this.filter = filter; + } + + internal SuspendedExceptionTaskAwaitable(Task task, TArg arg, Func filter) + : this(new ValueTask(task), arg, filter) + { + } + + internal bool ContinueOnCapturedContext + { + get; + init; + } + + /// + /// Configures an awaiter for this value. + /// + /// + /// to attempt to marshal the continuation back to the captured context; + /// otherwise, . + /// + /// The configured object. + public SuspendedExceptionTaskAwaitable ConfigureAwait(bool continueOnCapturedContext) + => this with { ContinueOnCapturedContext = continueOnCapturedContext }; + + /// + /// Gets the awaiter for this object. + /// + /// The awaiter for this object. + public Awaiter GetAwaiter() => new(task, arg, filter, ContinueOnCapturedContext); + + /// + /// Represents the awaiter that suspends exception. + /// + [StructLayout(LayoutKind.Auto)] + public readonly struct Awaiter : ICriticalNotifyCompletion + { + private readonly ConfiguredValueTaskAwaitable.ConfiguredValueTaskAwaiter awaiter; + private readonly TArg arg; + private readonly Func filter; + + internal Awaiter(in ValueTask task, TArg arg, Func filter, bool continueOnCapturedContext) + { + awaiter = task.ConfigureAwait(continueOnCapturedContext).GetAwaiter(); + this.arg = arg; + this.filter = filter; + } + + /// + /// Gets a value indicating that has completed. + /// + public bool IsCompleted => awaiter.IsCompleted; + + /// + public void OnCompleted(Action action) => awaiter.OnCompleted(action); + + /// + public void UnsafeOnCompleted(Action action) => awaiter.UnsafeOnCompleted(action); + + /// + /// Obtains a result of asynchronous operation, and suspends exception if needed. + /// + public void GetResult() + { + try + { + awaiter.GetResult(); + } + catch (Exception e) when (filter.Invoke(e, arg)) + { + // suspend exception + } + } + } } \ No newline at end of file diff --git a/src/DotNext/Threading/Tasks/Conversion.cs b/src/DotNext/Threading/Tasks/Conversion.cs index 0283f4fcd..daa2c5498 100644 --- a/src/DotNext/Threading/Tasks/Conversion.cs +++ b/src/DotNext/Threading/Tasks/Conversion.cs @@ -3,7 +3,7 @@ namespace DotNext.Threading.Tasks; -using SuspendedExceptionTaskAwaitable = Runtime.CompilerServices.SuspendedExceptionTaskAwaitable; +using Runtime.CompilerServices; /// /// Provides task result conversion methods. @@ -84,4 +84,32 @@ public static SuspendedExceptionTaskAwaitable SuspendException(this Task task, P /// The awaitable object that suspends exceptions according to the filter. public static SuspendedExceptionTaskAwaitable SuspendException(this ValueTask task, Predicate? filter = null) => new(task) { Filter = filter }; + + /// + /// Suspends the exception that can be raised by the task. + /// + /// The task. + /// The argument to be passed to the filter. + /// The filter of the exception to be suspended. + /// The awaitable object that suspends exceptions according to the filter. + public static SuspendedExceptionTaskAwaitable SuspendException(this Task task, TArg arg, Func filter) + { + ArgumentNullException.ThrowIfNull(filter); + + return new(task, arg, filter); + } + + /// + /// Suspends the exception that can be raised by the task. + /// + /// The task. + /// The argument to be passed to the filter. + /// The filter of the exception to be suspended. + /// The awaitable object that suspends exceptions according to the filter. + public static SuspendedExceptionTaskAwaitable SuspendException(this ValueTask task, TArg arg, Func filter) + { + ArgumentNullException.ThrowIfNull(filter); + + return new(task, arg, filter); + } } \ No newline at end of file