diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelHandler.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelHandler.swift index 149c0ff5ea..83e48cd229 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelHandler.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelHandler.swift @@ -14,6 +14,16 @@ import DequeModule +@usableFromInline +enum OutboundAction: Sendable where OutboundOut: Sendable { + /// Write value + case write(OutboundOut) + /// Write value and flush pipeline + case writeAndFlush(OutboundOut, EventLoopPromise) + /// flush writes to writer + case flush(EventLoopPromise) +} + /// A ``ChannelHandler`` that is used to transform the inbound portion of a NIO /// ``Channel`` into an asynchronous sequence that supports back-pressure. It's also used /// to write the outbound portion of a NIO ``Channel`` from Swift Concurrency with back-pressure @@ -77,7 +87,7 @@ internal final class NIOAsyncChannelHandler, NIOAsyncChannelHandlerWriterDelegate > @@ -372,7 +382,10 @@ struct NIOAsyncChannelHandlerProducerDelegate: @unchecked Sendable, NIOAsyncSequ @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @usableFromInline -struct NIOAsyncChannelHandlerWriterDelegate: NIOAsyncWriterSinkDelegate, @unchecked Sendable { +struct NIOAsyncChannelHandlerWriterDelegate: NIOAsyncWriterSinkDelegate, @unchecked Sendable { + @usableFromInline + typealias Element = OutboundAction + @usableFromInline let eventLoop: EventLoop @@ -386,7 +399,7 @@ struct NIOAsyncChannelHandlerWriterDelegate: NIOAsyncWriterSi let _didTerminate: ((any Error)?) -> Void @inlinable - init(handler: NIOAsyncChannelHandler) { + init(handler: NIOAsyncChannelHandler) { self.eventLoop = handler.eventLoop self._didYieldContentsOf = handler._didYield(sequence:) self._didYield = handler._didYield(element:) @@ -430,7 +443,7 @@ struct NIOAsyncChannelHandlerWriterDelegate: NIOAsyncWriterSi @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension NIOAsyncChannelHandler { @inlinable - func _didYield(sequence: Deque) { + func _didYield(sequence: Deque>) { // This is always called from an async context, so we must loop-hop. // Because we always loop-hop, we're always at the top of a stack frame. As this // is the only source of writes for us, and as this channel handler doesn't implement @@ -438,16 +451,12 @@ extension NIOAsyncChannelHandler { // awkward re-entrancy protections NIO usually requires, and can safely just do an iterative // write. self.eventLoop.preconditionInEventLoop() - guard let context = self.context else { - // Already removed from the channel by now, we can stop. - return - } self._doOutboundWrites(context: context, writes: sequence) } @inlinable - func _didYield(element: OutboundOut) { + func _didYield(element: OutboundAction) { // This is always called from an async context, so we must loop-hop. // Because we always loop-hop, we're always at the top of a stack frame. As this // is the only source of writes for us, and as this channel handler doesn't implement @@ -455,10 +464,6 @@ extension NIOAsyncChannelHandler { // awkward re-entrancy protections NIO usually requires, and can safely just do an iterative // write. self.eventLoop.preconditionInEventLoop() - guard let context = self.context else { - // Already removed from the channel by now, we can stop. - return - } self._doOutboundWrite(context: context, write: element) } @@ -475,18 +480,64 @@ extension NIOAsyncChannelHandler { } @inlinable - func _doOutboundWrites(context: ChannelHandlerContext, writes: Deque) { - for write in writes { - context.write(Self.wrapOutboundOut(write), promise: nil) + func _doOutboundWrites(context: ChannelHandlerContext?, writes: Deque>) { + // write everything but the last item + for write in writes.dropLast() { + switch write { + case .write(let value), .writeAndFlush(let value, _): + guard let context = self.context else { + // Already removed from the channel by now, we can stop. + return + } + context.write(Self.wrapOutboundOut(value), promise: nil) + context.flush() + case .flush(let promise): + promise.succeed() + } + } + // write last item + switch writes.last { + case .write(let value): + guard let context = self.context else { + // Already removed from the channel by now, we can stop. + return + } + context.write(Self.wrapOutboundOut(value), promise: nil) + context.flush() + case .flush(let promise): + promise.succeed() + case .writeAndFlush(let value, let promise): + guard let context = self.context else { + // Already removed from the channel by now, we can stop. + promise.succeed() + return + } + context.writeAndFlush(Self.wrapOutboundOut(value), promise: promise) + case .none: + break } - - context.flush() } @inlinable - func _doOutboundWrite(context: ChannelHandlerContext, write: OutboundOut) { - context.write(Self.wrapOutboundOut(write), promise: nil) - context.flush() + func _doOutboundWrite(context: ChannelHandlerContext?, write: OutboundAction) { + switch write { + case .write(let value): + guard let context = self.context else { + // Already removed from the channel by now, we can stop. + return + } + context.write(Self.wrapOutboundOut(value), promise: nil) + context.flush() + case .flush(let promise): + promise.succeed() + case .writeAndFlush(let value, let promise): + guard let context = self.context else { + // Already removed from the channel by now, we can stop. + promise.succeed() + return + } + context.writeAndFlush(Self.wrapOutboundOut(value), promise: promise) + } } } diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift index dfdeeb0fda..604f5eff6f 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift @@ -21,7 +21,7 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { @usableFromInline typealias _Writer = NIOAsyncWriter< - OutboundOut, + OutboundAction, NIOAsyncChannelHandlerWriterDelegate > @@ -66,7 +66,7 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { @usableFromInline enum Backing: Sendable { case asyncStream(AsyncStream.Continuation) - case writer(_Writer) + case writer(_Writer, EventLoop) } @usableFromInline @@ -93,7 +93,7 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { ) throws { eventLoop.preconditionInEventLoop() let writer = _Writer.makeWriter( - elementType: OutboundOut.self, + elementType: OutboundAction.self, isWritable: true, finishOnDeinit: closeOnDeinit, delegate: .init(handler: handler) @@ -102,7 +102,7 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { handler.sink = writer.sink handler.writer = writer.writer - self._backing = .writer(writer.writer) + self._backing = .writer(writer.writer, eventLoop) } @inlinable @@ -118,8 +118,23 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { switch self._backing { case .asyncStream(let continuation): continuation.yield(data) - case .writer(let writer): - try await writer.yield(data) + case .writer(let writer, _): + try await writer.yield(.write(data)) + } + } + + /// Send a write into the ``ChannelPipeline`` and flush it right away. + /// + /// This method suspends until the write has been written and flushed. + @inlinable + public func writeAndFlush(_ data: OutboundOut) async throws { + switch self._backing { + case .asyncStream(let continuation): + continuation.yield(data) + case .writer(let writer, let eventLoop): + try await self.withPromise(eventLoop: eventLoop) { promise in + try await writer.yield(.writeAndFlush(data, promise)) + } } } @@ -133,8 +148,26 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { for data in sequence { continuation.yield(data) } - case .writer(let writer): - try await writer.yield(contentsOf: sequence) + case .writer(let writer, _): + try await writer.yield(contentsOf: sequence.map { .write($0) }) + } + } + + /// Send a sequence of writes into the ``ChannelPipeline`` and flush them right away. + /// + /// This method suspends if the underlying channel is not writable and will resume once the it becomes writable again. + @inlinable + public func writeAndFlush(contentsOf sequence: Writes) async throws + where Writes.Element == OutboundOut { + switch self._backing { + case .asyncStream(let continuation): + for data in sequence { + continuation.yield(data) + } + case .writer(let writer, let eventLoop): + try await withPromise(eventLoop: eventLoop) { promise in + try await writer.yield(contentsOf: sequence.map { .writeAndFlush($0, promise) }) + } } } @@ -151,6 +184,16 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { } } + /// Ensure all writes to the writer have been read + @inlinable + public func flush() async throws { + if case .writer(let writer, let eventLoop) = self._backing { + try await self.withPromise(eventLoop: eventLoop) { promise in + try await writer.yield(.flush(promise)) + } + } + } + /// Finishes the writer. /// /// This might trigger a half closure if the ``NIOAsyncChannel`` was configured to support it. @@ -158,10 +201,24 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { switch self._backing { case .asyncStream(let continuation): continuation.finish() - case .writer(let writer): + case .writer(let writer, _): writer.finish() } } + + @usableFromInline + func withPromise( + eventLoop: EventLoop, + _ process: (EventLoopPromise) async throws -> Void + ) async throws { + let promise = eventLoop.makePromise(of: Void.self) + do { + try await process(promise) + try await promise.futureResult.get() + } catch { + promise.fail(error) + } + } } @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) diff --git a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift index 3d21c2c7a9..581d5a4548 100644 --- a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift +++ b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift @@ -180,6 +180,63 @@ final class AsyncChannelTests: XCTestCase { } } + func testAllWritesAreWritten() async throws { + let channel = NIOAsyncTestingChannel() + let promise = channel.testingEventLoop.makePromise(of: Void.self) + let wrapped = try await channel.testingEventLoop.executeInContext { + try channel.pipeline.syncOperations.addHandler(DelayingChannelHandler(promise: promise)) + return try NIOAsyncChannel(wrappingChannelSynchronously: channel) + } + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await wrapped.executeThenClose { inbound, outbound in + try await outbound.write("hello") + try await outbound.writeAndFlush("world") + } + } + group.addTask { + let firstRead = try await channel.waitForOutboundWrite(as: String.self) + let secondRead = try await channel.waitForOutboundWrite(as: String.self) + + XCTAssertEqual(firstRead, "hello") + XCTAssertEqual(secondRead, "world") + } + + // wait 50 milliseconds to ensure we are inside write and flush then + // trigger pipeline flush by succeeding promise in DelayingChannelHandler + try await Task.sleep(for: .milliseconds(50)) + promise.succeed() + } + } + + func testAllWritesInSequenceAreWritten() async throws { + let channel = NIOAsyncTestingChannel() + let promise = channel.testingEventLoop.makePromise(of: Void.self) + let wrapped = try await channel.testingEventLoop.executeInContext { + try channel.pipeline.syncOperations.addHandler(DelayingChannelHandler(promise: promise)) + return try NIOAsyncChannel(wrappingChannelSynchronously: channel) + } + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await wrapped.executeThenClose { inbound, outbound in + try await outbound.writeAndFlush(contentsOf: ["hello", "world"]) + } + } + group.addTask { + let firstRead = try await channel.waitForOutboundWrite(as: String.self) + let secondRead = try await channel.waitForOutboundWrite(as: String.self) + + XCTAssertEqual(firstRead, "hello") + XCTAssertEqual(secondRead, "world") + } + + // wait 50 milliseconds to ensure we are inside write and flush then + // trigger pipeline flush by succeeding promise in DelayingChannelHandler + try await Task.sleep(for: .milliseconds(50)) + promise.succeed() + } + } + func testErrorsArePropagatedButAfterReads() async throws { let channel = NIOAsyncTestingChannel() let wrapped = try await channel.testingEventLoop.executeInContext { @@ -429,6 +486,17 @@ private final class CloseRecorder: ChannelOutboundHandler, @unchecked Sendable { } } +struct UnsafeContext: @unchecked Sendable { + private let _context: ChannelHandlerContext + var context: ChannelHandlerContext { + self._context.eventLoop.preconditionInEventLoop() + return _context + } + init(_ context: ChannelHandlerContext) { + self._context = context + } +} + private final class CloseSuppressor: ChannelOutboundHandler, RemovableChannelHandler, Sendable { typealias OutboundIn = Any @@ -438,6 +506,22 @@ private final class CloseSuppressor: ChannelOutboundHandler, RemovableChannelHan } } +private final class DelayingChannelHandler: ChannelOutboundHandler, RemovableChannelHandler, Sendable { + typealias OutboundIn = Any + typealias OutboundOut = Any + let waitPromise: EventLoopPromise + + init(promise: EventLoopPromise) { + self.waitPromise = promise + } + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + let unsafeTransfer = UnsafeTransfer((context: context, data: data)) + self.waitPromise.futureResult.whenComplete { _ in + unsafeTransfer.wrappedValue.context.writeAndFlush(unsafeTransfer.wrappedValue.data, promise: promise) + } + } +} + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension NIOAsyncTestingChannel { fileprivate func closeIgnoringSuppression() async throws {