Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose query metadata #504

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
.DS_Store
/.build
/.index-build
/Packages
/*.xcodeproj
DerivedData
Expand Down
81 changes: 81 additions & 0 deletions Sources/PostgresNIO/Connection/PostgresConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,48 @@ extension PostgresConnection {
}
}

// use this for queries where you want to consume the rows.
// we can use the `consume` scope to better ensure structured concurrency when consuming the rows.
public func query<Result>(
_ query: PostgresQuery,
logger: Logger,
file: String = #fileID,
line: Int = #line,
_ consume: (PostgresRowSequence) async throws -> Result
) async throws -> (Result, PostgresQueryMetadata) {
var logger = logger
logger[postgresMetadataKey: .connectionID] = "\(self.id)"

guard query.binds.count <= Int(UInt16.max) else {
throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line)
}
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
let context = ExtendedQueryContext(
query: query,
logger: logger,
promise: promise
)

self.channel.write(HandlerTask.extendedQuery(context), promise: nil)

do {
let (rowStream, rowSequence) = try await promise.futureResult.map { rowStream in
(rowStream, rowStream.asyncSequence())
}.get()
let result = try await consume(rowSequence)
try await rowStream.drain().get()
guard let metadata = PostgresQueryMetadata(string: rowStream.commandTag) else {
throw PSQLError.invalidCommandTag(rowStream.commandTag)
}
return (result, metadata)
} catch var error as PSQLError {
error.file = file
error.line = line
error.query = query
throw error // rethrow with more metadata
}
}

/// Start listening for a channel
public func listen(_ channel: String) async throws -> PostgresNotificationSequence {
let id = self.internalListenID.loadThenWrappingIncrement(ordering: .relaxed)
Expand Down Expand Up @@ -531,6 +573,45 @@ extension PostgresConnection {
}
}

// use this for queries where you want to consume the rows.
// we can use the `consume` scope to better ensure structured concurrency when consuming the rows.
@discardableResult
public func execute(
_ query: PostgresQuery,
logger: Logger,
file: String = #fileID,
line: Int = #line
) async throws -> PostgresQueryMetadata {
var logger = logger
logger[postgresMetadataKey: .connectionID] = "\(self.id)"

guard query.binds.count <= Int(UInt16.max) else {
throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line)
}
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
let context = ExtendedQueryContext(
query: query,
logger: logger,
promise: promise
)

self.channel.write(HandlerTask.extendedQuery(context), promise: nil)

do {
let rowStream = try await promise.futureResult.get()
try await rowStream.drain().get()
guard let metadata = PostgresQueryMetadata(string: rowStream.commandTag) else {
throw PSQLError.invalidCommandTag(rowStream.commandTag)
}
return metadata
} catch var error as PSQLError {
error.file = file
error.line = line
error.query = query
throw error // rethrow with more metadata
}
}

#if compiler(>=6.0)
/// Puts the connection into an open transaction state, for the provided `closure`'s lifetime.
///
Expand Down
58 changes: 57 additions & 1 deletion Sources/PostgresNIO/New/PSQLRowStream.swift
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,63 @@ final class PSQLRowStream: @unchecked Sendable {
return self.eventLoop.makeFailedFuture(error)
}
}


// MARK: Drain on EventLoop

func drain() -> EventLoopFuture<Void> {
if self.eventLoop.inEventLoop {
return self.drain0()
} else {
return self.eventLoop.flatSubmit {
self.drain0()
}
}
}

private func drain0() -> EventLoopFuture<Void> {
self.eventLoop.preconditionInEventLoop()

switch self.downstreamState {
case .waitingForConsumer(let bufferState):
switch bufferState {
case .streaming(var buffer, let dataSource):
let promise = self.eventLoop.makePromise(of: Void.self)

buffer.removeAll()
self.downstreamState = .iteratingRows(onRow: { _ in }, promise, dataSource)
// immediately request more
dataSource.request(for: self)

return promise.futureResult

case .finished(_, let summary):
self.downstreamState = .consumed(.success(summary))
return self.eventLoop.makeSucceededVoidFuture()

case .failure(let error):
self.downstreamState = .consumed(.failure(error))
return self.eventLoop.makeFailedFuture(error)
}
case .asyncSequence(let consumer, let dataSource, _):
consumer.finish()

let promise = self.eventLoop.makePromise(of: Void.self)

self.downstreamState = .iteratingRows(onRow: { _ in }, promise, dataSource)
// immediately request more
dataSource.request(for: self)

return promise.futureResult
case .consumed(.success):
// already drained
return self.eventLoop.makeSucceededVoidFuture()
case .consumed(let .failure(error)):
return self.eventLoop.makeFailedFuture(error)
default:
preconditionFailure("Invalid state: \(self.downstreamState)")
}
}

internal func noticeReceived(_ notice: PostgresBackendMessage.NoticeResponse) {
self.logger.debug("Notice Received", metadata: [
.notice: "\(notice)"
Expand Down
2 changes: 2 additions & 0 deletions Sources/PostgresNIO/New/PostgresRowSequence.swift
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ extension PostgresRowSequence {
extension PostgresRowSequence.AsyncIterator: Sendable {}

extension PostgresRowSequence {
/// Collects all rows into an array.
/// - Returns: The rows.
public func collect() async throws -> [PostgresRow] {
var result = [PostgresRow]()
for try await row in self {
Expand Down
103 changes: 95 additions & 8 deletions Tests/IntegrationTests/AsyncTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,98 @@ final class AsyncPostgresConnectionTests: XCTestCase {
}
}

func testSelect10kRowsWithMetadata() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
let eventLoop = eventLoopGroup.next()

let start = 1
let end = 10000

try await withTestConnection(on: eventLoop) { connection in
let (result, metadata) = try await connection.query(
"SELECT generate_series(\(start), \(end));",
logger: .psqlTest
) { rows in
var counter = 0
for try await row in rows {
let element = try row.decode(Int.self)
XCTAssertEqual(element, counter + 1)
counter += 1
}
return counter
}

XCTAssertEqual(metadata.command, "SELECT")
XCTAssertEqual(metadata.oid, nil)
XCTAssertEqual(metadata.rows, end)

XCTAssertEqual(result, end)
}
}

func testSelectRowsWithMetadataNotConsumedAtAll() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
let eventLoop = eventLoopGroup.next()

let start = 1
let end = 10000

try await withTestConnection(on: eventLoop) { connection in
let (_, metadata) = try await connection.query(
"SELECT generate_series(\(start), \(end));",
logger: .psqlTest
) { _ in }

XCTAssertEqual(metadata.command, "SELECT")
XCTAssertEqual(metadata.oid, nil)
XCTAssertEqual(metadata.rows, end)
}
}

func testSelectRowsWithMetadataNotFullyConsumed() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
let eventLoop = eventLoopGroup.next()

try await withTestConnection(on: eventLoop) { connection in
do {
_ = try await connection.query(
"SELECT generate_series(1, 10000);",
logger: .psqlTest
) { rows in
for try await _ in rows { break }
}
// This path is also fine
} catch is CancellationError {
// Expected
} catch {
XCTFail("Expected 'CancellationError', got: \(String(reflecting: error))")
}
}
}

func testExecuteRowsWithMetadata() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
let eventLoop = eventLoopGroup.next()

let start = 1
let end = 10000

try await withTestConnection(on: eventLoop) { connection in
let metadata = try await connection.execute(
"SELECT generate_series(\(start), \(end));",
logger: .psqlTest
)

XCTAssertEqual(metadata.command, "SELECT")
XCTAssertEqual(metadata.oid, nil)
XCTAssertEqual(metadata.rows, end)
}
}

func testSelectActiveConnection() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
Expand Down Expand Up @@ -207,7 +299,7 @@ final class AsyncPostgresConnectionTests: XCTestCase {

try await withTestConnection(on: eventLoop) { connection in
// Max binds limit is UInt16.max which is 65535 which is 3 * 5 * 17 * 257
// Max columns limit is 1664, so we will only make 5 * 257 columns which is less
// Max columns limit appears to be ~1600, so we will only make 5 * 257 columns which is less
// Then we will insert 3 * 17 rows
// In the insertion, there will be a total of 3 * 17 * 5 * 257 == UInt16.max bindings
// If the test is successful, it means Postgres supports UInt16.max bindings
Expand Down Expand Up @@ -241,13 +333,8 @@ final class AsyncPostgresConnectionTests: XCTestCase {
unsafeSQL: "INSERT INTO table1 VALUES \(insertionValues)",
binds: binds
)
try await connection.query(insertionQuery, logger: .psqlTest)

let countQuery = PostgresQuery(unsafeSQL: "SELECT COUNT(*) FROM table1")
let countRows = try await connection.query(countQuery, logger: .psqlTest)
var countIterator = countRows.makeAsyncIterator()
let insertedRowsCount = try await countIterator.next()?.decode(Int.self, context: .default)
XCTAssertEqual(rowsCount, insertedRowsCount)
let metadata = try await connection.execute(insertionQuery, logger: .psqlTest)
XCTAssertEqual(metadata.rows, rowsCount)

let dropQuery = PostgresQuery(unsafeSQL: "DROP TABLE table1")
try await connection.query(dropQuery, logger: .psqlTest)
Expand Down
3 changes: 3 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ x-shared-config: &shared_config
- 5432:5432

services:
psql-17:
image: postgres:17
<<: *shared_config
psql-16:
image: postgres:16
<<: *shared_config
Expand Down
Loading