Skip to content

Commit

Permalink
NIOThrowingAsyncSequenceProducer throws when cancelled (#2415)
Browse files Browse the repository at this point in the history
* NIOThrowingAsyncSequenceProducer throws when cancelled

* PR review
  • Loading branch information
fabianfett authored Apr 28, 2023
1 parent 5f8b064 commit d1690f8
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,21 @@ extension NIOThrowingAsyncSequenceProducer {
return nil
}

case .returnCancellationError:
self._lock.unlock()
// We have deprecated the generic Failure type in the public API and Failure should
// now be `Swift.Error`. However, if users have not migrated to the new API they could
// still use a custom generic Error type and this cast might fail.
// In addition, we use `NIOThrowingAsyncSequenceProducer` in the implementation of the
// non-throwing variant `NIOAsyncSequenceProducer` where `Failure` will be `Never` and
// this cast will fail as well.
// Everything is marked @inlinable and the Failure type is known at compile time,
// therefore this cast should be optimised away in release build.
if let error = CancellationError() as? Failure {
throw error
}
return nil

case .returnNil:
self._lock.unlock()
return nil
Expand Down Expand Up @@ -603,6 +618,9 @@ extension NIOThrowingAsyncSequenceProducer {
failure: Failure?
)

/// The state once a call to next has been cancelled. Cancel the source when entering this state.
case cancelled(iteratorInitialized: Bool)

/// The state once there can be no outstanding demand. This can happen if:
/// 1. The ``NIOThrowingAsyncSequenceProducer/AsyncIterator`` was deinited
/// 2. The underlying source finished and all buffered elements have been consumed
Expand Down Expand Up @@ -644,15 +662,17 @@ extension NIOThrowingAsyncSequenceProducer {
switch self._state {
case .initial(_, iteratorInitialized: false),
.streaming(_, _, _, _, iteratorInitialized: false),
.sourceFinished(_, iteratorInitialized: false, _):
.sourceFinished(_, iteratorInitialized: false, _),
.cancelled(iteratorInitialized: false):
// No iterator was created so we can transition to finished right away.
self._state = .finished(iteratorInitialized: false)

return .callDidTerminate

case .initial(_, iteratorInitialized: true),
.streaming(_, _, _, _, iteratorInitialized: true),
.sourceFinished(_, iteratorInitialized: true, _):
.sourceFinished(_, iteratorInitialized: true, _),
.cancelled(iteratorInitialized: true):
// An iterator was created and we deinited the sequence.
// This is an expected pattern and we just continue on normal.
return .none
Expand All @@ -673,6 +693,7 @@ extension NIOThrowingAsyncSequenceProducer {
case .initial(_, iteratorInitialized: true),
.streaming(_, _, _, _, iteratorInitialized: true),
.sourceFinished(_, iteratorInitialized: true, _),
.cancelled(iteratorInitialized: true),
.finished(iteratorInitialized: true):
// Our sequence is a unicast sequence and does not support multiple AsyncIterator's
fatalError("NIOThrowingAsyncSequenceProducer allows only a single AsyncIterator to be created")
Expand All @@ -694,6 +715,10 @@ extension NIOThrowingAsyncSequenceProducer {
iteratorInitialized: true
)

case .cancelled(iteratorInitialized: false):
// An iterator needs to be initialized before we can be cancelled.
preconditionFailure("Internal inconsistency")

case .sourceFinished(let buffer, false, let failure):
// The first and only iterator was initialized.
self._state = .sourceFinished(
Expand Down Expand Up @@ -727,13 +752,15 @@ extension NIOThrowingAsyncSequenceProducer {
switch self._state {
case .initial(_, iteratorInitialized: false),
.streaming(_, _, _, _, iteratorInitialized: false),
.sourceFinished(_, iteratorInitialized: false, _):
.sourceFinished(_, iteratorInitialized: false, _),
.cancelled(iteratorInitialized: false):
// An iterator needs to be initialized before it can be deinitialized.
preconditionFailure("Internal inconsistency")

case .initial(_, iteratorInitialized: true),
.streaming(_, _, _, _, iteratorInitialized: true),
.sourceFinished(_, iteratorInitialized: true, _):
.sourceFinished(_, iteratorInitialized: true, _),
.cancelled(iteratorInitialized: true):
// An iterator was created and deinited. Since we only support
// a single iterator we can now transition to finish and inform the delegate.
self._state = .finished(iteratorInitialized: true)
Expand Down Expand Up @@ -861,7 +888,7 @@ extension NIOThrowingAsyncSequenceProducer {

return .init(shouldProduceMore: shouldProduceMore)

case .sourceFinished, .finished:
case .cancelled, .sourceFinished, .finished:
// If the source has finished we are dropping the elements.
return .returnDropped

Expand Down Expand Up @@ -913,7 +940,7 @@ extension NIOThrowingAsyncSequenceProducer {

return .none

case .sourceFinished, .finished:
case .cancelled, .sourceFinished, .finished:
// If the source has finished, finishing again has no effect.
return .none

Expand Down Expand Up @@ -968,11 +995,14 @@ extension NIOThrowingAsyncSequenceProducer {
return .resumeContinuationWithCancellationErrorAndCallDidTerminate(continuation)

case .streaming(_, _, continuation: .none, _, let iteratorInitialized):
self._state = .finished(iteratorInitialized: iteratorInitialized)
// We may have elements in the buffer, which is why we have no continuation
// waiting. We must store the cancellation error to hand it out on the next
// next() call.
self._state = .cancelled(iteratorInitialized: iteratorInitialized)

return .callDidTerminate

case .sourceFinished, .finished:
case .cancelled, .sourceFinished, .finished:
// If the source has finished, finishing again has no effect.
return .none

Expand All @@ -992,6 +1022,8 @@ extension NIOThrowingAsyncSequenceProducer {
/// Indicates that the `Failure` should be returned to the caller and
/// that ``NIOAsyncSequenceProducerDelegate/didTerminate()`` should be called.
case returnFailureAndCallDidTerminate(Failure?)
/// Indicates that the next call to AsyncSequence got cancelled
case returnCancellationError
/// Indicates that the `nil` should be returned to the caller.
case returnNil
/// Indicates that the `Task` of the caller should be suspended.
Expand Down Expand Up @@ -1075,6 +1107,10 @@ extension NIOThrowingAsyncSequenceProducer {
return .returnFailureAndCallDidTerminate(failure)
}

case .cancelled(let iteratorInitialized):
self._state = .finished(iteratorInitialized: iteratorInitialized)
return .returnCancellationError

case .finished:
return .returnNil

Expand Down Expand Up @@ -1119,7 +1155,7 @@ extension NIOThrowingAsyncSequenceProducer {
return .none
}

case .streaming(_, _, .some(_), _, _), .sourceFinished, .finished:
case .streaming(_, _, .some(_), _, _), .sourceFinished, .finished, .cancelled:
preconditionFailure("This should have already been handled by `next()`")

case .modifying:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,36 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase {

XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate])
}

func testIteratorThrows_whenCancelled() async {
_ = self.source.yield(contentsOf: Array(0..<100))
await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
var counter = 0
guard let sequence = self.sequence else {
return XCTFail("Expected to have an AsyncSequence")
}

do {
for try await next in sequence {
XCTAssertEqual(next, counter)
counter += 1
}
XCTFail("Expected that this throws")
} catch is CancellationError {
// expected
} catch {
XCTFail("Unexpected error: \(error)")
}

XCTAssertLessThan(counter, 100)
}

group.cancelAll()
}

XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate])
}
}

// This is needed until async let is supported to be used in autoclosures
Expand Down

0 comments on commit d1690f8

Please sign in to comment.