Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensuring NIOAsyncChannel flushes all writes before closing #3049

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
93 changes: 72 additions & 21 deletions Sources/NIOCore/AsyncChannel/AsyncChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@

import DequeModule

@usableFromInline
enum OutboundAction<OutboundOut>: Sendable where OutboundOut: Sendable {
/// Write value
case write(OutboundOut)
/// Write value and flush pipeline
case writeAndFlush(OutboundOut, EventLoopPromise<Void>)
/// flush writes to writer
case flush(EventLoopPromise<Void>)
}

/// A ``ChannelHandler`` that is used to transform the inbound portion of a NIO
/// ``Channel`` into an asynchronous sequence that supports back-pressure. It's also used
/// to write the outbound portion of a NIO ``Channel`` from Swift Concurrency with back-pressure
Expand Down Expand Up @@ -77,7 +87,7 @@ internal final class NIOAsyncChannelHandler<InboundIn: Sendable, ProducerElement

@usableFromInline
typealias Writer = NIOAsyncWriter<
OutboundOut,
OutboundAction<OutboundOut>,
NIOAsyncChannelHandlerWriterDelegate<OutboundOut>
>

Expand Down Expand Up @@ -372,7 +382,10 @@ struct NIOAsyncChannelHandlerProducerDelegate: @unchecked Sendable, NIOAsyncSequ

@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@usableFromInline
struct NIOAsyncChannelHandlerWriterDelegate<Element: Sendable>: NIOAsyncWriterSinkDelegate, @unchecked Sendable {
struct NIOAsyncChannelHandlerWriterDelegate<OutboundOut: Sendable>: NIOAsyncWriterSinkDelegate, @unchecked Sendable {
@usableFromInline
typealias Element = OutboundAction<OutboundOut>

@usableFromInline
let eventLoop: EventLoop

Expand All @@ -386,7 +399,7 @@ struct NIOAsyncChannelHandlerWriterDelegate<Element: Sendable>: NIOAsyncWriterSi
let _didTerminate: ((any Error)?) -> Void

@inlinable
init<InboundIn, ProducerElement>(handler: NIOAsyncChannelHandler<InboundIn, ProducerElement, Element>) {
init<InboundIn, ProducerElement>(handler: NIOAsyncChannelHandler<InboundIn, ProducerElement, OutboundOut>) {
self.eventLoop = handler.eventLoop
self._didYieldContentsOf = handler._didYield(sequence:)
self._didYield = handler._didYield(element:)
Expand Down Expand Up @@ -430,35 +443,27 @@ struct NIOAsyncChannelHandlerWriterDelegate<Element: Sendable>: NIOAsyncWriterSi
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
extension NIOAsyncChannelHandler {
@inlinable
func _didYield(sequence: Deque<OutboundOut>) {
func _didYield(sequence: Deque<OutboundAction<OutboundOut>>) {
// This is always called from an async context, so we must loop-hop.
// Because we always loop-hop, we're always at the top of a stack frame. As this
// is the only source of writes for us, and as this channel handler doesn't implement
// func write(), we cannot possibly re-entrantly write. That means we can skip many of the
// awkward re-entrancy protections NIO usually requires, and can safely just do an iterative
// write.
self.eventLoop.preconditionInEventLoop()
guard let context = self.context else {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test has been moved into _doOutboundWrites as we need to complete promises even if the channel handler is no longer there.

// Already removed from the channel by now, we can stop.
return
}

self._doOutboundWrites(context: context, writes: sequence)
}

@inlinable
func _didYield(element: OutboundOut) {
func _didYield(element: OutboundAction<OutboundOut>) {
// This is always called from an async context, so we must loop-hop.
// Because we always loop-hop, we're always at the top of a stack frame. As this
// is the only source of writes for us, and as this channel handler doesn't implement
// func write(), we cannot possibly re-entrantly write. That means we can skip many of the
// awkward re-entrancy protections NIO usually requires, and can safely just do an iterative
// write.
self.eventLoop.preconditionInEventLoop()
guard let context = self.context else {
// Already removed from the channel by now, we can stop.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test has been moved into _doOutboundWrites as we need to complete promises even if the channel handler is no longer there.

return
}

self._doOutboundWrite(context: context, write: element)
}
Expand All @@ -475,18 +480,64 @@ extension NIOAsyncChannelHandler {
}

@inlinable
func _doOutboundWrites(context: ChannelHandlerContext, writes: Deque<OutboundOut>) {
for write in writes {
context.write(Self.wrapOutboundOut(write), promise: nil)
func _doOutboundWrites(context: ChannelHandlerContext?, writes: Deque<OutboundAction<OutboundOut>>) {
// write everything but the last item
for write in writes.dropLast() {
switch write {
case .write(let value), .writeAndFlush(let value, _):
guard let context = self.context else {
// Already removed from the channel by now, we can stop.
return
}
context.write(Self.wrapOutboundOut(value), promise: nil)
context.flush()
case .flush(let promise):
promise.succeed()
}
}
// write last item
switch writes.last {
case .write(let value):
guard let context = self.context else {
// Already removed from the channel by now, we can stop.
return
}
context.write(Self.wrapOutboundOut(value), promise: nil)
context.flush()
case .flush(let promise):
promise.succeed()
case .writeAndFlush(let value, let promise):
guard let context = self.context else {
// Already removed from the channel by now, we can stop.
promise.succeed()
return
}
context.writeAndFlush(Self.wrapOutboundOut(value), promise: promise)
case .none:
break
}

context.flush()
}

@inlinable
func _doOutboundWrite(context: ChannelHandlerContext, write: OutboundOut) {
context.write(Self.wrapOutboundOut(write), promise: nil)
context.flush()
func _doOutboundWrite(context: ChannelHandlerContext?, write: OutboundAction<OutboundOut>) {
switch write {
case .write(let value):
guard let context = self.context else {
// Already removed from the channel by now, we can stop.
return
}
context.write(Self.wrapOutboundOut(value), promise: nil)
context.flush()
case .flush(let promise):
promise.succeed()
case .writeAndFlush(let value, let promise):
guard let context = self.context else {
// Already removed from the channel by now, we can stop.
promise.succeed()
return
}
context.writeAndFlush(Self.wrapOutboundOut(value), promise: promise)
}
}
}

Expand Down
75 changes: 66 additions & 9 deletions Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
@usableFromInline
typealias _Writer = NIOAsyncWriter<
OutboundOut,
OutboundAction<OutboundOut>,
NIOAsyncChannelHandlerWriterDelegate<OutboundOut>
>

Expand Down Expand Up @@ -66,7 +66,7 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
@usableFromInline
enum Backing: Sendable {
case asyncStream(AsyncStream<OutboundOut>.Continuation)
case writer(_Writer)
case writer(_Writer, EventLoop)
}

@usableFromInline
Expand All @@ -93,7 +93,7 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
) throws {
eventLoop.preconditionInEventLoop()
let writer = _Writer.makeWriter(
elementType: OutboundOut.self,
elementType: OutboundAction<OutboundOut>.self,
isWritable: true,
finishOnDeinit: closeOnDeinit,
delegate: .init(handler: handler)
Expand All @@ -102,7 +102,7 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
handler.sink = writer.sink
handler.writer = writer.writer

self._backing = .writer(writer.writer)
self._backing = .writer(writer.writer, eventLoop)
}

@inlinable
Expand All @@ -118,8 +118,23 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
switch self._backing {
case .asyncStream(let continuation):
continuation.yield(data)
case .writer(let writer):
try await writer.yield(data)
case .writer(let writer, _):
try await writer.yield(.write(data))
}
}

/// Send a write into the ``ChannelPipeline`` and flush it right away.
///
/// This method suspends until the write has been written and flushed.
@inlinable
public func writeAndFlush(_ data: OutboundOut) async throws {
switch self._backing {
case .asyncStream(let continuation):
continuation.yield(data)
case .writer(let writer, let eventLoop):
try await self.withPromise(eventLoop: eventLoop) { promise in
try await writer.yield(.writeAndFlush(data, promise))
}
}
}

Expand All @@ -133,8 +148,26 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
for data in sequence {
continuation.yield(data)
}
case .writer(let writer):
try await writer.yield(contentsOf: sequence)
case .writer(let writer, _):
try await writer.yield(contentsOf: sequence.map { .write($0) })
}
}

/// Send a sequence of writes into the ``ChannelPipeline`` and flush them right away.
///
/// This method suspends if the underlying channel is not writable and will resume once the it becomes writable again.
@inlinable
public func writeAndFlush<Writes: Sequence>(contentsOf sequence: Writes) async throws
where Writes.Element == OutboundOut {
switch self._backing {
case .asyncStream(let continuation):
for data in sequence {
continuation.yield(data)
}
case .writer(let writer, let eventLoop):
try await withPromise(eventLoop: eventLoop) { promise in
try await writer.yield(contentsOf: sequence.map { .writeAndFlush($0, promise) })
}
}
}

Expand All @@ -151,17 +184,41 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
}
}

/// Ensure all writes to the writer have been read
@inlinable
public func flush() async throws {
if case .writer(let writer, let eventLoop) = self._backing {
try await self.withPromise(eventLoop: eventLoop) { promise in
try await writer.yield(.flush(promise))
}
}
}

/// Finishes the writer.
///
/// This might trigger a half closure if the ``NIOAsyncChannel`` was configured to support it.
public func finish() {
switch self._backing {
case .asyncStream(let continuation):
continuation.finish()
case .writer(let writer):
case .writer(let writer, _):
writer.finish()
}
}

@usableFromInline
func withPromise(
eventLoop: EventLoop,
_ process: (EventLoopPromise<Void>) async throws -> Void
) async throws {
let promise = eventLoop.makePromise(of: Void.self)
do {
try await process(promise)
try await promise.futureResult.get()
} catch {
promise.fail(error)
}
}
}

@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
Expand Down
84 changes: 84 additions & 0 deletions Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,63 @@ final class AsyncChannelTests: XCTestCase {
}
}

func testAllWritesAreWritten() async throws {
let channel = NIOAsyncTestingChannel()
let promise = channel.testingEventLoop.makePromise(of: Void.self)
let wrapped = try await channel.testingEventLoop.executeInContext {
try channel.pipeline.syncOperations.addHandler(DelayingChannelHandler(promise: promise))
return try NIOAsyncChannel<Never, String>(wrappingChannelSynchronously: channel)
}
try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
try await wrapped.executeThenClose { inbound, outbound in
try await outbound.write("hello")
try await outbound.writeAndFlush("world")
}
}
group.addTask {
let firstRead = try await channel.waitForOutboundWrite(as: String.self)
let secondRead = try await channel.waitForOutboundWrite(as: String.self)

XCTAssertEqual(firstRead, "hello")
XCTAssertEqual(secondRead, "world")
}

// wait 50 milliseconds to ensure we are inside write and flush then
// trigger pipeline flush by succeeding promise in DelayingChannelHandler
try await Task.sleep(for: .milliseconds(50))
promise.succeed()
}
}

func testAllWritesInSequenceAreWritten() async throws {
let channel = NIOAsyncTestingChannel()
let promise = channel.testingEventLoop.makePromise(of: Void.self)
let wrapped = try await channel.testingEventLoop.executeInContext {
try channel.pipeline.syncOperations.addHandler(DelayingChannelHandler(promise: promise))
return try NIOAsyncChannel<Never, String>(wrappingChannelSynchronously: channel)
}
try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
try await wrapped.executeThenClose { inbound, outbound in
try await outbound.writeAndFlush(contentsOf: ["hello", "world"])
}
}
group.addTask {
let firstRead = try await channel.waitForOutboundWrite(as: String.self)
let secondRead = try await channel.waitForOutboundWrite(as: String.self)

XCTAssertEqual(firstRead, "hello")
XCTAssertEqual(secondRead, "world")
}

// wait 50 milliseconds to ensure we are inside write and flush then
// trigger pipeline flush by succeeding promise in DelayingChannelHandler
try await Task.sleep(for: .milliseconds(50))
promise.succeed()
}
}

func testErrorsArePropagatedButAfterReads() async throws {
let channel = NIOAsyncTestingChannel()
let wrapped = try await channel.testingEventLoop.executeInContext {
Expand Down Expand Up @@ -429,6 +486,17 @@ private final class CloseRecorder: ChannelOutboundHandler, @unchecked Sendable {
}
}

struct UnsafeContext: @unchecked Sendable {
private let _context: ChannelHandlerContext
var context: ChannelHandlerContext {
self._context.eventLoop.preconditionInEventLoop()
return _context
}
init(_ context: ChannelHandlerContext) {
self._context = context
}
}

private final class CloseSuppressor: ChannelOutboundHandler, RemovableChannelHandler, Sendable {
typealias OutboundIn = Any

Expand All @@ -438,6 +506,22 @@ private final class CloseSuppressor: ChannelOutboundHandler, RemovableChannelHan
}
}

private final class DelayingChannelHandler: ChannelOutboundHandler, RemovableChannelHandler, Sendable {
typealias OutboundIn = Any
typealias OutboundOut = Any
let waitPromise: EventLoopPromise<Void>

init(promise: EventLoopPromise<Void>) {
self.waitPromise = promise
}
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
let unsafeTransfer = UnsafeTransfer((context: context, data: data))
self.waitPromise.futureResult.whenComplete { _ in
unsafeTransfer.wrappedValue.context.writeAndFlush(unsafeTransfer.wrappedValue.data, promise: promise)
}
}
}

@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
extension NIOAsyncTestingChannel {
fileprivate func closeIgnoringSuppression() async throws {
Expand Down