Skip to content

WIP: Implement COPY … FROM STDIN #566

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions Sources/PostgresNIO/Connection/PostgresConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,100 @@ extension PostgresConnection {
}
}

// MARK: Copy from

fileprivate extension EventLoop {
/// Runs `task` right away when the caller is already on this event loop; otherwise passes it to
/// `execute(_:)` so that the loop schedules it.
func executeImmediatelyOrSchedule(_ task: @Sendable @escaping () -> Void) {
    guard inEventLoop else {
        // Off-loop caller: let the event loop schedule the task.
        execute(task)
        return
    }
    // Already on the loop: invoke inline.
    task()
}
Comment on lines +701 to +706
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this needs to allocate even in the fast path. we don't like that.

}

/// A handle to send data for a `COPY ... FROM STDIN` operation to the backend.
public struct PostgresCopyFromWriter: Sendable {
private let channelHandler: NIOLoopBound<PostgresChannelHandler>
private let context: NIOLoopBound<ChannelHandlerContext>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use the handlerContext: ChannelHandlerContext? with ! in PostgresChannelHandler instead.

private let eventLoop: any EventLoop

/// Error whose description states that no data may be written to a `PostgresCopyFromWriter` once it has
/// sent a CopyDone or CopyFail message.
struct NotWritableError: Error, CustomStringConvertible {
    /// Human-readable explanation of the error.
    var description: String =
        "No data must be written to `PostgresCopyFromWriter` after it has sent a CopyDone or CopyFail message, ie. after the closure producing the copy data has finished"
}
Comment on lines +715 to +717
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this isn't used anywhere.


/// Creates a writer for the given channel handler and handler context.
///
/// Both `handler` and `context` are wrapped in `NIOLoopBound` values tied to `eventLoop`.
init(handler: PostgresChannelHandler, context: ChannelHandlerContext, eventLoop: any EventLoop) {
    self.eventLoop = eventLoop
    self.context = NIOLoopBound(context, eventLoop: eventLoop)
    self.channelHandler = NIOLoopBound(handler, eventLoop: eventLoop)
}

/// Sends `byteBuffer` to the backend as part of a `COPY ... FROM STDIN` operation.
///
/// The buffer is passed to the channel handler's `copyData` on the writer's event loop (inline when already on
/// it), together with a continuation. This method returns once that continuation has been resumed.
/// Although the method is declared `throws`, the continuation used here is non-throwing.
public func write(_ byteBuffer: ByteBuffer) async throws {
    await withCheckedContinuation { (readyForMore: CheckedContinuation<Void, Never>) in
        self.eventLoop.executeImmediatelyOrSchedule {
            let handler = self.channelHandler.value
            handler.copyData(
                byteBuffer,
                context: self.context.value,
                readyForMoreWriteContinuation: readyForMore
            )
        }
    }
}

/// Ends the data transfer successfully: takes the state machine out of copy mode and sends a `CopyDone`
/// message to the backend.
///
/// The request is issued on the writer's event loop; the call returns (or throws) when the continuation
/// handed to the channel handler is resumed.
func done() async throws {
    try await withCheckedThrowingContinuation { (completion: CheckedContinuation<Void, any Error>) in
        self.eventLoop.executeImmediatelyOrSchedule {
            let handler = self.channelHandler.value
            handler.sendCopyDone(continuation: completion, context: self.context.value)
        }
    }
}

/// Ends the data transfer with a failure: takes the state machine out of copy mode and sends a `CopyFail`
/// message to the backend.
///
/// - Parameter error: The error whose string interpolation is used as the `CopyFail` message.
func failed(error: any Error) async throws {
    try await withCheckedThrowingContinuation { (completion: CheckedContinuation<Void, any Error>) in
        self.eventLoop.executeImmediatelyOrSchedule {
            let handler = self.channelHandler.value
            handler.sendCopyFailed(
                message: "\(error)",
                continuation: completion,
                context: self.context.value
            )
        }
    }
}
}

extension PostgresConnection {
// TODO: Instead of an arbitrary query, make this a structured data structure.
/// Runs `query` as a copy-from operation and lets `writeData` stream the data through a
/// `PostgresCopyFromWriter`.
///
/// The query is written to the channel as an extended-query task. Once the continuation attached to that task
/// yields a writer, `writeData` is invoked with it. If `writeData` returns normally, `writer.done()` is awaited;
/// if it throws, `writer.failed(error:)` is awaited and the error is rethrown.
///
/// - Parameters:
///   - query: The query to execute. Must not have more than `UInt16.max` binds.
///   - writeData: Closure that receives the writer and produces the copy data.
///   - logger: Logger passed on to the query context; it is tagged with this connection's ID first.
///   - file: Defaults to `#fileID`. Not referenced in this body.
///   - line: Defaults to `#line`. Not referenced in this body.
/// - Throws: A `PSQLError` with code `.tooManyParameters` when `query` has too many binds, as well as any error
///   thrown while obtaining the writer, by `writeData`, by `writer.failed(error:)` or by `writer.done()`.
public func copyFrom(
_ query: PostgresQuery,
writeData: @escaping @Sendable (PostgresCopyFromWriter) async throws -> Void,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

closure should be last argument.

logger: Logger,
file: String = #fileID,
line: Int = #line
) async throws {
// Work on a local copy so the connection ID can be attached as metadata.
var logger = logger
logger[postgresMetadataKey: .connectionID] = "\(self.id)"
// Reject queries whose bind count does not fit into a UInt16.
guard query.binds.count <= Int(UInt16.max) else {
throw PSQLError(code: .tooManyParameters, query: query)
}

// Hand the query to the channel and suspend until the continuation is resumed with a writer (or an error).
let writer = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<PostgresCopyFromWriter, any Error>) in
let context = ExtendedQueryContext(
copyFromQuery: query,
triggerCopy: continuation,
logger: logger
)
self.channel.write(HandlerTask.extendedQuery(context), promise: nil)
}

do {
try await writeData(writer)
} catch {
// Report the failure through the writer before surfacing the caller's error.
// If `failed(error:)` itself throws, that error propagates instead of `error`.
try await writer.failed(error: error)
throw error
}
// The producer finished without error: finalize the transfer.
try await writer.done()
}

}

// MARK: PostgresDatabase conformance

extension PostgresConnection: PostgresDatabase {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,23 @@ struct ConnectionStateMachine {
case sendParseDescribeBindExecuteSync(PostgresQuery)
case sendBindExecuteSync(PSQLExecuteStatement)
case failQuery(EventLoopPromise<PSQLRowStream>, with: PSQLError, cleanupContext: CleanUpContext?)
/// Fail a query's execution by throwing an error on the given continuation.
case failQueryContinuation(any AnyErrorContinuation, with: PSQLError, cleanupContext: CleanUpContext?)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

existential, where Existential isn't necessary.

case succeedQuery(EventLoopPromise<PSQLRowStream>, with: QueryResult)
case succeedQueryContinuation(CheckedContinuation<Void, any Error>)

/// Trigger a data transfer returning a `PostgresCopyFromWriter` to the given continuation.
///
/// Once the data transfer is triggered, it will send `CopyData` messages to the backend. After that the state
/// machine needs to be prodded again to send a `CopyDone` or `CopyFail` by calling
/// `PostgresChannelHandler.sendCopyDone` or `PostgresChannelHandler.sendCopyFailed`.
case triggerCopyData(CheckedContinuation<PostgresCopyFromWriter, any Error>)

/// Send a `CopyDone` message to the backend, followed by a `Sync`.
case sendCopyDone

/// Send a `CopyFail` message to the backend with the given error message.
case sendCopyFailed(message: String)

// --- streaming actions
// actions if query has requested next row but we are waiting for backend
Expand Down Expand Up @@ -587,6 +603,8 @@ struct ConnectionStateMachine {
switch queryContext.query {
case .executeStatement(_, let promise), .unnamed(_, let promise):
return .failQuery(promise, with: psqlErrror, cleanupContext: nil)
case .copyFrom(_, let triggerCopy):
return .failQueryContinuation(triggerCopy, with: psqlErrror, cleanupContext: nil)
case .prepareStatement(_, _, _, let promise):
return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil)
}
Expand Down Expand Up @@ -660,6 +678,15 @@ struct ConnectionStateMachine {
preconditionFailure("Invalid state: \(self.state)")
}
}

/// Forwards a channel writability change to the extended query state machine.
///
/// Has no effect unless the connection is currently in the `.extendedQuery` state.
///
/// - Parameter isWritable: The new writability, passed through unchanged.
mutating func channelWritabilityChanged(isWritable: Bool) {
    if case .extendedQuery(var extendedQuery, let context) = self.state {
        // Park the state in `.modifying` while the extracted query state is mutated (avoid CoW).
        self.state = .modifying
        extendedQuery.channelWritabilityChanged(isWritable: isWritable)
        self.state = .extendedQuery(extendedQuery, context)
    }
}

// MARK: - Running Queries -

Expand Down Expand Up @@ -751,6 +778,55 @@ struct ConnectionStateMachine {
self.state = .extendedQuery(queryState, connectionContext)
return self.modify(with: action)
}

/// Handles a `CopyInResponse` message from the backend by forwarding it to the extended query state machine.
///
/// If the connection is not in the `.extendedQuery` state, or the query is already complete, the connection is
/// closed and cleaned up with an "unexpected backend message" error.
///
/// - Parameter copyInResponse: The message received from the backend.
/// - Returns: The connection-level action derived from the query state machine's action via `modify(with:)`.
mutating func copyInResponseReceived(
_ copyInResponse: PostgresBackendMessage.CopyInResponseMessage
) -> ConnectionAction {
guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else {
// NOTE(review): the error below is reported as `.emptyQueryResponse` although the message received here is a
// CopyInResponse. This looks like a copy-paste slip; presumably it should wrap `copyInResponse` instead.
// Confirm against the `PostgresBackendMessage` cases.
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.emptyQueryResponse))
}

self.state = .modifying // avoid CoW
let action = queryState.copyInResponseReceived(copyInResponse)
self.state = .extendedQuery(queryState, connectionContext)
return self.modify(with: action)
}

/// Registers `continuation` with the extended query state machine so that it is resumed once the write buffer
/// to the backend is writable again. Callers are expected to invoke this while the channel is not writable.
///
/// - Precondition: The connection must be in the `.extendedQuery` state; any other state traps.
mutating func waitForWritableBuffer(continuation: CheckedContinuation<Void, Never>) {
    guard case .extendedQuery(var extendedQuery, let context) = self.state else {
        preconditionFailure("Copy mode is only supported for extended queries")
    }

    // Park the state in `.modifying` while the extracted query state is mutated (avoid CoW).
    self.state = .modifying
    extendedQuery.waitForWritableBuffer(continuation: continuation)
    self.state = .extendedQuery(extendedQuery, context)
}

/// Takes the state machine out of copy mode and sends a `CopyDone` message to the backend.
///
/// - Parameter continuation: Passed on to the extended query state machine.
/// - Returns: The connection-level action derived from the query state machine's action via `modify(with:)`.
/// - Precondition: The connection must be in the `.extendedQuery` state; any other state traps.
mutating func sendCopyDone(continuation: CheckedContinuation<Void, any Error>) -> ConnectionAction {
    guard case .extendedQuery(var extendedQuery, let context) = self.state else {
        preconditionFailure("Copy mode is only supported for extended queries")
    }

    // Park the state in `.modifying` while the extracted query state is mutated (avoid CoW).
    self.state = .modifying
    let queryAction = extendedQuery.sendCopyDone(continuation: continuation)
    self.state = .extendedQuery(extendedQuery, context)
    return self.modify(with: queryAction)
}

/// Takes the state machine out of copy mode and sends a `CopyFail` message to the backend.
///
/// - Parameters:
///   - message: The failure message passed on to the extended query state machine.
///   - continuation: Passed on to the extended query state machine.
/// - Returns: The connection-level action derived from the query state machine's action via `modify(with:)`.
/// - Precondition: The connection must be in the `.extendedQuery` state; any other state traps.
mutating func sendCopyFail(message: String, continuation: CheckedContinuation<Void, any Error>) -> ConnectionAction {
    guard case .extendedQuery(var extendedQuery, let context) = self.state else {
        preconditionFailure("Copy mode is only supported for extended queries")
    }

    // Park the state in `.modifying` while the extracted query state is mutated (avoid CoW).
    self.state = .modifying
    let queryAction = extendedQuery.sendCopyFailed(message: message, continuation: continuation)
    self.state = .extendedQuery(extendedQuery, context)
    return self.modify(with: queryAction)
}

mutating func emptyQueryResponseReceived() -> ConnectionAction {
guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else {
Expand Down Expand Up @@ -860,14 +936,21 @@ struct ConnectionStateMachine {
.forwardRows,
.forwardStreamComplete,
.wait,
.read:
.read,
.triggerCopyData,
.sendCopyDone,
.sendCopyFailed,
.succeedQueryContinuation:
preconditionFailure("Invalid query state machine action in state: \(self.state), action: \(action)")

case .evaluateErrorAtConnectionLevel:
return .closeConnectionAndCleanup(cleanupContext)

case .failQuery(let queryContext, with: let error):
return .failQuery(queryContext, with: error, cleanupContext: cleanupContext)
case .failQuery(let promise, with: let error):
return .failQuery(promise, with: error, cleanupContext: cleanupContext)

case .failQueryContinuation(let continuation, with: let error):
return .failQueryContinuation(continuation, with: error, cleanupContext: cleanupContext)

case .forwardStreamError(let error, let read):
return .forwardStreamError(error, read: read, cleanupContext: cleanupContext)
Expand Down Expand Up @@ -1038,8 +1121,19 @@ extension ConnectionStateMachine {
case .failQuery(let requestContext, with: let error):
let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error)
return .failQuery(requestContext, with: error, cleanupContext: cleanupContext)
case .failQueryContinuation(let continuation, with: let error):
let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error)
return .failQueryContinuation(continuation, with: error, cleanupContext: cleanupContext)
case .succeedQuery(let requestContext, with: let result):
return .succeedQuery(requestContext, with: result)
case .succeedQueryContinuation(let continuation):
return .succeedQueryContinuation(continuation)
case .triggerCopyData(let triggerCopy):
return .triggerCopyData(triggerCopy)
case .sendCopyDone:
return .sendCopyDone
case .sendCopyFailed(message: let message):
return .sendCopyFailed(message: message)
case .forwardRows(let buffer):
return .forwardRows(buffer)
case .forwardStreamComplete(let buffer, let commandTag):
Expand Down
Loading