Skip to content

Commit

Permalink
Introduced parametrized exception handler
Browse files Browse the repository at this point in the history
  • Loading branch information
sakno committed Sep 23, 2024
1 parent 71744f1 commit 9ab5679
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 22 deletions.
8 changes: 8 additions & 0 deletions src/DotNext.Tests/Threading/Tasks/ConversionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,12 @@ public static async Task SuspendException2()
var result = await t.AsDynamic().ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing | ConfigureAwaitOptions.ContinueOnCapturedContext);
Same(result, Missing.Value);
}

[Fact]
public static async Task SuspendExceptionParametrized()
{
await Task.FromException(new Exception()).SuspendException(42, (_, i) => i is 42);
await ValueTask.FromException(new Exception()).SuspendException(42, (_, i) => i is 42);
await ThrowsAsync<Exception>(async () => await Task.FromException(new Exception()).SuspendException(43, (_, i) => i is 42));
}
}
35 changes: 15 additions & 20 deletions src/DotNext.Threading/Threading/Tasks/TaskQueue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ public async ValueTask EnqueueAsync(T task, CancellationToken token = default)
}
}

private T? TryPeekOrDequeue(out int head, out Task enqueueTask, out bool completed)
private T? TryPeekOrDequeue(out int head, out Task enqueueTask)
{
T? result;
lock (array)
Expand All @@ -187,7 +187,7 @@ public async ValueTask EnqueueAsync(T task, CancellationToken token = default)
{
result = this[head = this.head];
enqueueTask = Task.CompletedTask;
if (completed = result is { IsCompleted: true })
if (result is { IsCompleted: true })
{
MoveNext(ref head);
ChangeCount(increment: false);
Expand All @@ -197,7 +197,6 @@ public async ValueTask EnqueueAsync(T task, CancellationToken token = default)
{
head = default;
result = null;
completed = default;
signal ??= new();
enqueueTask = signal.Task;
}
Expand Down Expand Up @@ -258,17 +257,17 @@ public bool TryDequeue([NotNullWhen(true)] out T? task)
/// <exception cref="OperationCanceledException">The operation has been canceled.</exception>
public async ValueTask<T> DequeueAsync(CancellationToken token = default)
{
for (var filter = token.CanBeCanceled ? null : Predicate.Constant<Exception>(true);;)
for (;;)
{
if (TryPeekOrDequeue(out var expectedHead, out var enqueueTask, out var completed) is not { } task)
if (TryPeekOrDequeue(out var expectedHead, out var enqueueTask) is not { } task)
{
await enqueueTask.WaitAsync(token).ConfigureAwait(false);
continue;
}

if (!completed)
if (!task.IsCompleted)
{
await task.WaitAsync(token).SuspendException(filter ??= token.SuspendAllExceptCancellation).ConfigureAwait(false);
await task.WaitAsync(token).SuspendException(token, SuspendAllExceptCancellation).ConfigureAwait(false);

if (!TryDequeue(expectedHead, task))
continue;
Expand All @@ -286,12 +285,12 @@ public async ValueTask<T> DequeueAsync(CancellationToken token = default)
/// <exception cref="OperationCanceledException">The operation has been canceled.</exception>
public async ValueTask<T?> TryDequeueAsync(CancellationToken token = default)
{
for (var filter = token.CanBeCanceled ? null : Predicate.Constant<Exception>(true);;)
for (;;)
{
T? task;
if ((task = TryPeekOrDequeue(out var expectedHead, out _, out var completed)) is not null && !completed)
if ((task = TryPeekOrDequeue(out var expectedHead, out _)) is not null && !task.IsCompleted)
{
await task.WaitAsync(token).SuspendException(filter ??= token.SuspendAllExceptCancellation).ConfigureAwait(false);
await task.WaitAsync(token).SuspendException(token, SuspendAllExceptCancellation).ConfigureAwait(false);

if (!TryDequeue(expectedHead, task))
continue;
Expand All @@ -311,12 +310,11 @@ public async ValueTask<T> DequeueAsync(CancellationToken token = default)
/// <returns>The enumerator over completed tasks.</returns>
public async IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken token)
{
for (var filter = token.CanBeCanceled ? null : Predicate.Constant<Exception>(true);
TryPeekOrDequeue(out var expectedHead, out _, out var completed) is { } task;)
while (TryPeekOrDequeue(out var expectedHead, out _) is { } task)
{
if (!completed)
if (!task.IsCompleted)
{
await task.WaitAsync(token).SuspendException(filter ??= token.SuspendAllExceptCancellation).ConfigureAwait(false);
await task.WaitAsync(token).SuspendException(token, SuspendAllExceptCancellation).ConfigureAwait(false);
if (!TryDequeue(expectedHead, task))
continue;
}
Expand All @@ -325,6 +323,9 @@ public async IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken token)
}
}

private static bool SuspendAllExceptCancellation(Exception e, CancellationToken token)
=> e is not OperationCanceledException canceledEx || token != canceledEx.CancellationToken;

/// <summary>
/// Clears the queue.
/// </summary>
Expand All @@ -343,10 +344,4 @@ public void Clear()
void IResettable.Reset() => Clear();

private sealed class Signal() : TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
}

file static class CancellationTokenExtensions
{
internal static bool SuspendAllExceptCancellation(this object token, Exception e)
=> e is not OperationCanceledException canceledEx || !canceledEx.CancellationToken.Equals(token);
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

Expand All @@ -15,7 +16,9 @@ internal SuspendedExceptionTaskAwaitable(ValueTask task)
=> this.task = task;

internal SuspendedExceptionTaskAwaitable(Task task)
=> this.task = new(task);
: this(new ValueTask(task))
{
}

internal bool ContinueOnCapturedContext
{
Expand Down Expand Up @@ -91,4 +94,97 @@ public void GetResult()
}
}
}
}

/// <summary>
/// Represents awaitable object that can suspend exception raised by the underlying task.
/// </summary>
/// <typeparam name="TArg">The type of the argument to be passed to the exception filter.</typeparam>
[StructLayout(LayoutKind.Auto)]
public readonly struct SuspendedExceptionTaskAwaitable<TArg>
{
private readonly TArg arg;
private readonly ValueTask task;
private readonly Func<Exception, TArg, bool> filter;

internal SuspendedExceptionTaskAwaitable(ValueTask task, TArg arg, Func<Exception, TArg, bool> filter)
{
Debug.Assert(filter is not null);

this.task = task;
this.arg = arg;
this.filter = filter;
}

internal SuspendedExceptionTaskAwaitable(Task task, TArg arg, Func<Exception, TArg, bool> filter)
: this(new ValueTask(task), arg, filter)
{
}

internal bool ContinueOnCapturedContext
{
get;
init;
}

/// <summary>
/// Configures an awaiter for this value.
/// </summary>
/// <param name="continueOnCapturedContext">
/// <see langword="true"/> to attempt to marshal the continuation back to the captured context;
/// otherwise, <see langword="false"/>.
/// </param>
/// <returns>The configured object.</returns>
public SuspendedExceptionTaskAwaitable<TArg> ConfigureAwait(bool continueOnCapturedContext)
=> this with { ContinueOnCapturedContext = continueOnCapturedContext };

/// <summary>
/// Gets the awaiter for this object.
/// </summary>
/// <returns>The awaiter for this object.</returns>
public Awaiter GetAwaiter() => new(task, arg, filter, ContinueOnCapturedContext);

/// <summary>
/// Represents the awaiter that suspends exception.
/// </summary>
[StructLayout(LayoutKind.Auto)]
public readonly struct Awaiter : ICriticalNotifyCompletion
{
private readonly ConfiguredValueTaskAwaitable.ConfiguredValueTaskAwaiter awaiter;
private readonly TArg arg;
private readonly Func<Exception, TArg, bool> filter;

internal Awaiter(in ValueTask task, TArg arg, Func<Exception, TArg, bool> filter, bool continueOnCapturedContext)
{
awaiter = task.ConfigureAwait(continueOnCapturedContext).GetAwaiter();
this.arg = arg;
this.filter = filter;
}

/// <summary>
/// Gets a value indicating that <see cref="SuspendedExceptionTaskAwaitable"/> has completed.
/// </summary>
public bool IsCompleted => awaiter.IsCompleted;

/// <inheritdoc/>
public void OnCompleted(Action action) => awaiter.OnCompleted(action);

/// <inheritdoc/>
public void UnsafeOnCompleted(Action action) => awaiter.UnsafeOnCompleted(action);

/// <summary>
/// Obtains a result of asynchronous operation, and suspends exception if needed.
/// </summary>
public void GetResult()
{
try
{
awaiter.GetResult();
}
catch (Exception e) when (filter.Invoke(e, arg))
{
// suspend exception
}
}
}
}
30 changes: 29 additions & 1 deletion src/DotNext/Threading/Tasks/Conversion.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

namespace DotNext.Threading.Tasks;

using SuspendedExceptionTaskAwaitable = Runtime.CompilerServices.SuspendedExceptionTaskAwaitable;
using Runtime.CompilerServices;

/// <summary>
/// Provides task result conversion methods.
Expand Down Expand Up @@ -84,4 +84,32 @@ public static SuspendedExceptionTaskAwaitable SuspendException(this Task task, P
/// <returns>The awaitable object that suspends exceptions according to the filter.</returns>
public static SuspendedExceptionTaskAwaitable SuspendException(this ValueTask task, Predicate<Exception>? filter = null)
=> new(task) { Filter = filter };

/// <summary>
/// Suspends the exception that can be raised by the task.
/// </summary>
/// <param name="task">The task.</param>
/// <param name="arg">The argument to be passed to the filter.</param>
/// <param name="filter">The filter of the exception to be suspended.</param>
/// <returns>The awaitable object that suspends exceptions according to the filter.</returns>
public static SuspendedExceptionTaskAwaitable<TArg> SuspendException<TArg>(this Task task, TArg arg, Func<Exception, TArg, bool> filter)
{
ArgumentNullException.ThrowIfNull(filter);

return new(task, arg, filter);
}

/// <summary>
/// Suspends the exception that can be raised by the task.
/// </summary>
/// <param name="task">The task.</param>
/// <param name="arg">The argument to be passed to the filter.</param>
/// <param name="filter">The filter of the exception to be suspended.</param>
/// <returns>The awaitable object that suspends exceptions according to the filter.</returns>
public static SuspendedExceptionTaskAwaitable<TArg> SuspendException<TArg>(this ValueTask task, TArg arg, Func<Exception, TArg, bool> filter)
{
ArgumentNullException.ThrowIfNull(filter);

return new(task, arg, filter);
}
}

0 comments on commit 9ab5679

Please sign in to comment.