diff --git a/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift b/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift index b4afa657cf..3ffefc6881 100644 --- a/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift +++ b/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift @@ -511,6 +511,21 @@ extension NIOThrowingAsyncSequenceProducer { return nil } + case .returnCancellationError: + self._lock.unlock() + // We have deprecated the generic Failure type in the public API and Failure should + // now be `Swift.Error`. However, if users have not migrated to the new API they could + // still use a custom generic Error type and this cast might fail. + // In addition, we use `NIOThrowingAsyncSequenceProducer` in the implementation of the + // non-throwing variant `NIOAsyncSequenceProducer` where `Failure` will be `Never` and + // this cast will fail as well. + // Everything is marked @inlinable and the Failure type is known at compile time, + // therefore this cast should be optimised away in release build. + if let error = CancellationError() as? Failure { + throw error + } + return nil + case .returnNil: self._lock.unlock() return nil @@ -603,6 +618,9 @@ extension NIOThrowingAsyncSequenceProducer { failure: Failure? ) + /// The state once a call to next has been cancelled. Cancel the source when entering this state. + case cancelled(iteratorInitialized: Bool) + /// The state once there can be no outstanding demand. This can happen if: /// 1. The ``NIOThrowingAsyncSequenceProducer/AsyncIterator`` was deinited /// 2. The underlying source finished and all buffered elements have been consumed @@ -644,7 +662,8 @@ extension NIOThrowingAsyncSequenceProducer { switch self._state { case .initial(_, iteratorInitialized: false), .streaming(_, _, _, _, iteratorInitialized: false), - .sourceFinished(_, iteratorInitialized: false, _): + .sourceFinished(_, iteratorInitialized: false, _), + .cancelled(iteratorInitialized: false): // No iterator was created so we can transition to finished right away. self._state = .finished(iteratorInitialized: false) @@ -652,7 +671,8 @@ extension NIOThrowingAsyncSequenceProducer { case .initial(_, iteratorInitialized: true), .streaming(_, _, _, _, iteratorInitialized: true), - .sourceFinished(_, iteratorInitialized: true, _): + .sourceFinished(_, iteratorInitialized: true, _), + .cancelled(iteratorInitialized: true): // An iterator was created and we deinited the sequence. // This is an expected pattern and we just continue on normal. return .none @@ -673,6 +693,7 @@ extension NIOThrowingAsyncSequenceProducer { case .initial(_, iteratorInitialized: true), .streaming(_, _, _, _, iteratorInitialized: true), .sourceFinished(_, iteratorInitialized: true, _), + .cancelled(iteratorInitialized: true), .finished(iteratorInitialized: true): // Our sequence is a unicast sequence and does not support multiple AsyncIterator's fatalError("NIOThrowingAsyncSequenceProducer allows only a single AsyncIterator to be created") @@ -694,6 +715,10 @@ extension NIOThrowingAsyncSequenceProducer { iteratorInitialized: true ) + case .cancelled(iteratorInitialized: false): + // An iterator needs to be initialized before we can be cancelled. + preconditionFailure("Internal inconsistency") + case .sourceFinished(let buffer, false, let failure): // The first and only iterator was initialized. self._state = .sourceFinished( @@ -727,13 +752,15 @@ extension NIOThrowingAsyncSequenceProducer { switch self._state { case .initial(_, iteratorInitialized: false), .streaming(_, _, _, _, iteratorInitialized: false), - .sourceFinished(_, iteratorInitialized: false, _): + .sourceFinished(_, iteratorInitialized: false, _), + .cancelled(iteratorInitialized: false): // An iterator needs to be initialized before it can be deinitialized. preconditionFailure("Internal inconsistency") case .initial(_, iteratorInitialized: true), .streaming(_, _, _, _, iteratorInitialized: true), - .sourceFinished(_, iteratorInitialized: true, _): + .sourceFinished(_, iteratorInitialized: true, _), + .cancelled(iteratorInitialized: true): // An iterator was created and deinited. Since we only support // a single iterator we can now transition to finish and inform the delegate. self._state = .finished(iteratorInitialized: true) @@ -861,7 +888,7 @@ extension NIOThrowingAsyncSequenceProducer { return .init(shouldProduceMore: shouldProduceMore) - case .sourceFinished, .finished: + case .cancelled, .sourceFinished, .finished: // If the source has finished we are dropping the elements. return .returnDropped @@ -913,7 +940,7 @@ extension NIOThrowingAsyncSequenceProducer { return .none - case .sourceFinished, .finished: + case .cancelled, .sourceFinished, .finished: // If the source has finished, finishing again has no effect. return .none @@ -968,11 +995,14 @@ extension NIOThrowingAsyncSequenceProducer { return .resumeContinuationWithCancellationErrorAndCallDidTerminate(continuation) case .streaming(_, _, continuation: .none, _, let iteratorInitialized): - self._state = .finished(iteratorInitialized: iteratorInitialized) + // We may have elements in the buffer, which is why we have no continuation + // waiting. We must store the cancellation error to hand it out on the next + // next() call. + self._state = .cancelled(iteratorInitialized: iteratorInitialized) return .callDidTerminate - case .sourceFinished, .finished: + case .cancelled, .sourceFinished, .finished: // If the source has finished, finishing again has no effect. return .none @@ -992,6 +1022,8 @@ extension NIOThrowingAsyncSequenceProducer { /// Indicates that the `Failure` should be returned to the caller and /// that ``NIOAsyncSequenceProducerDelegate/didTerminate()`` should be called. case returnFailureAndCallDidTerminate(Failure?) + /// Indicates that the next call to AsyncSequence got cancelled + case returnCancellationError /// Indicates that the `nil` should be returned to the caller. case returnNil /// Indicates that the `Task` of the caller should be suspended. @@ -1075,6 +1107,10 @@ extension NIOThrowingAsyncSequenceProducer { return .returnFailureAndCallDidTerminate(failure) } + case .cancelled(let iteratorInitialized): + self._state = .finished(iteratorInitialized: iteratorInitialized) + return .returnCancellationError + case .finished: return .returnNil @@ -1119,7 +1155,7 @@ extension NIOThrowingAsyncSequenceProducer { return .none } - case .streaming(_, _, .some(_), _, _), .sourceFinished, .finished: + case .streaming(_, _, .some(_), _, _), .sourceFinished, .finished, .cancelled: preconditionFailure("This should have already been handled by `next()`") case .modifying: diff --git a/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift b/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift index 49f5d78445..ebd38f87db 100644 --- a/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift +++ b/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift @@ -743,6 +743,36 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate]) } + + func testIteratorThrows_whenCancelled() async { + _ = self.source.yield(contentsOf: Array(0..<100)) + await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + var counter = 0 + guard let sequence = self.sequence else { + return XCTFail("Expected to have an AsyncSequence") + } + + do { + for try await next in sequence { + XCTAssertEqual(next, counter) + counter += 1 + } + XCTFail("Expected that this throws") + } catch is CancellationError { + // expected + } catch { + XCTFail("Unexpected error: \(error)") + } + + XCTAssertLessThan(counter, 100) + } + + group.cancelAll() + } + + XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate]) + } } // This is needed until async let is supported to be used in autoclosures