Skip to content

Commit

Permalink
implement PostgresConnection.query and .execute with metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
MahdiBM committed Feb 20, 2025
1 parent 5d817be commit 61da70d
Show file tree
Hide file tree
Showing 6 changed files with 239 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
.DS_Store
/.build
/.index-build
/Packages
/*.xcodeproj
DerivedData
Expand Down
81 changes: 81 additions & 0 deletions Sources/PostgresNIO/Connection/PostgresConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,48 @@ extension PostgresConnection {
}
}

// use this for queries where you want to consume the rows.
// we can use the `consume` scope to better ensure structured concurrency when consuming the rows.
public func query<Result>(
_ query: PostgresQuery,
logger: Logger,
file: String = #fileID,
line: Int = #line,
_ consume: (PostgresRowSequence) async throws -> Result
) async throws -> (Result, PostgresQueryMetadata) {
var logger = logger
logger[postgresMetadataKey: .connectionID] = "\(self.id)"

guard query.binds.count <= Int(UInt16.max) else {
throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line)
}
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
let context = ExtendedQueryContext(
query: query,
logger: logger,
promise: promise
)

self.channel.write(HandlerTask.extendedQuery(context), promise: nil)

do {
let (rowStream, rowSequence) = try await promise.futureResult.map { rowStream in
(rowStream, rowStream.asyncSequence())
}.get()
let result = try await consume(rowSequence)
try await rowStream.drain().get()
guard let metadata = PostgresQueryMetadata(string: rowStream.commandTag) else {
throw PSQLError.invalidCommandTag(rowStream.commandTag)
}
return (result, metadata)
} catch var error as PSQLError {
error.file = file
error.line = line
error.query = query
throw error // rethrow with more metadata
}
}

/// Start listening for a channel
public func listen(_ channel: String) async throws -> PostgresNotificationSequence {
let id = self.internalListenID.loadThenWrappingIncrement(ordering: .relaxed)
Expand Down Expand Up @@ -531,6 +573,45 @@ extension PostgresConnection {
}
}

// use this for queries where you want to consume the rows.
// we can use the `consume` scope to better ensure structured concurrency when consuming the rows.
@discardableResult
public func execute(
_ query: PostgresQuery,
logger: Logger,
file: String = #fileID,
line: Int = #line
) async throws -> PostgresQueryMetadata {
var logger = logger
logger[postgresMetadataKey: .connectionID] = "\(self.id)"

guard query.binds.count <= Int(UInt16.max) else {
throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line)
}
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
let context = ExtendedQueryContext(
query: query,
logger: logger,
promise: promise
)

self.channel.write(HandlerTask.extendedQuery(context), promise: nil)

do {
let rowStream = try await promise.futureResult.get()
try await rowStream.drain().get()
guard let metadata = PostgresQueryMetadata(string: rowStream.commandTag) else {
throw PSQLError.invalidCommandTag(rowStream.commandTag)
}
return metadata
} catch var error as PSQLError {
error.file = file
error.line = line
error.query = query
throw error // rethrow with more metadata
}
}

#if compiler(>=6.0)
/// Puts the connection into an open transaction state, for the provided `closure`'s lifetime.
///
Expand Down
58 changes: 57 additions & 1 deletion Sources/PostgresNIO/New/PSQLRowStream.swift
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,63 @@ final class PSQLRowStream: @unchecked Sendable {
return self.eventLoop.makeFailedFuture(error)
}
}


// MARK: Drain on EventLoop

func drain() -> EventLoopFuture<Void> {
if self.eventLoop.inEventLoop {
return self.drain0()
} else {
return self.eventLoop.flatSubmit {
self.drain0()
}
}
}

private func drain0() -> EventLoopFuture<Void> {
self.eventLoop.preconditionInEventLoop()

switch self.downstreamState {
case .waitingForConsumer(let bufferState):
switch bufferState {
case .streaming(var buffer, let dataSource):
let promise = self.eventLoop.makePromise(of: Void.self)

buffer.removeAll()
self.downstreamState = .iteratingRows(onRow: { _ in }, promise, dataSource)
// immediately request more
dataSource.request(for: self)

return promise.futureResult

case .finished(_, let summary):
self.downstreamState = .consumed(.success(summary))
return self.eventLoop.makeSucceededVoidFuture()

case .failure(let error):
self.downstreamState = .consumed(.failure(error))
return self.eventLoop.makeFailedFuture(error)
}
case .asyncSequence(let consumer, let dataSource, _):
consumer.finish()

let promise = self.eventLoop.makePromise(of: Void.self)

self.downstreamState = .iteratingRows(onRow: { _ in }, promise, dataSource)
// immediately request more
dataSource.request(for: self)

return promise.futureResult
case .consumed(.success):
// already drained
return self.eventLoop.makeSucceededVoidFuture()
case .consumed(let .failure(error)):
return self.eventLoop.makeFailedFuture(error)
default:
preconditionFailure("Invalid state: \(self.downstreamState)")
}
}

internal func noticeReceived(_ notice: PostgresBackendMessage.NoticeResponse) {
self.logger.debug("Notice Received", metadata: [
.notice: "\(notice)"
Expand Down
2 changes: 2 additions & 0 deletions Sources/PostgresNIO/New/PostgresRowSequence.swift
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ extension PostgresRowSequence {
extension PostgresRowSequence.AsyncIterator: Sendable {}

extension PostgresRowSequence {
/// Collects all rows into an array.
/// - Returns: The rows.
public func collect() async throws -> [PostgresRow] {
var result = [PostgresRow]()
for try await row in self {
Expand Down
103 changes: 95 additions & 8 deletions Tests/IntegrationTests/AsyncTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,98 @@ final class AsyncPostgresConnectionTests: XCTestCase {
}
}

func testSelect10kRowsWithMetadata() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
let eventLoop = eventLoopGroup.next()

let start = 1
let end = 10000

try await withTestConnection(on: eventLoop) { connection in
let (result, metadata) = try await connection.query(
"SELECT generate_series(\(start), \(end));",
logger: .psqlTest
) { rows in
var counter = 0
for try await row in rows {
let element = try row.decode(Int.self)
XCTAssertEqual(element, counter + 1)
counter += 1
}
return counter
}

XCTAssertEqual(metadata.command, "SELECT")
XCTAssertEqual(metadata.oid, nil)
XCTAssertEqual(metadata.rows, end)

XCTAssertEqual(result, end)
}
}

func testSelectRowsWithMetadataNotConsumedAtAll() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
let eventLoop = eventLoopGroup.next()

let start = 1
let end = 10000

try await withTestConnection(on: eventLoop) { connection in
let (_, metadata) = try await connection.query(
"SELECT generate_series(\(start), \(end));",
logger: .psqlTest
) { _ in }

XCTAssertEqual(metadata.command, "SELECT")
XCTAssertEqual(metadata.oid, nil)
XCTAssertEqual(metadata.rows, end)
}
}

func testSelectRowsWithMetadataNotFullyConsumed() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
let eventLoop = eventLoopGroup.next()

try await withTestConnection(on: eventLoop) { connection in
do {
_ = try await connection.query(
"SELECT generate_series(1, 10000);",
logger: .psqlTest
) { rows in
for try await _ in rows { break }
}
// This path is also fine
} catch is CancellationError {
// Expected
} catch {
XCTFail("Expected 'CancellationError', got: \(String(reflecting: error))")
}
}
}

func testExecuteRowsWithMetadata() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
let eventLoop = eventLoopGroup.next()

let start = 1
let end = 10000

try await withTestConnection(on: eventLoop) { connection in
let metadata = try await connection.execute(
"SELECT generate_series(\(start), \(end));",
logger: .psqlTest
)

XCTAssertEqual(metadata.command, "SELECT")
XCTAssertEqual(metadata.oid, nil)
XCTAssertEqual(metadata.rows, end)
}
}

func testSelectActiveConnection() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
Expand Down Expand Up @@ -207,7 +299,7 @@ final class AsyncPostgresConnectionTests: XCTestCase {

try await withTestConnection(on: eventLoop) { connection in
// Max binds limit is UInt16.max which is 65535 which is 3 * 5 * 17 * 257
// Max columns limit is 1664, so we will only make 5 * 257 columns which is less
// Max columns limit appears to be ~1600, so we will only make 5 * 257 columns which is less
// Then we will insert 3 * 17 rows
// In the insertion, there will be a total of 3 * 17 * 5 * 257 == UInt16.max bindings
// If the test is successful, it means Postgres supports UInt16.max bindings
Expand Down Expand Up @@ -241,13 +333,8 @@ final class AsyncPostgresConnectionTests: XCTestCase {
unsafeSQL: "INSERT INTO table1 VALUES \(insertionValues)",
binds: binds
)
try await connection.query(insertionQuery, logger: .psqlTest)

let countQuery = PostgresQuery(unsafeSQL: "SELECT COUNT(*) FROM table1")
let countRows = try await connection.query(countQuery, logger: .psqlTest)
var countIterator = countRows.makeAsyncIterator()
let insertedRowsCount = try await countIterator.next()?.decode(Int.self, context: .default)
XCTAssertEqual(rowsCount, insertedRowsCount)
let metadata = try await connection.execute(insertionQuery, logger: .psqlTest)
XCTAssertEqual(metadata.rows, rowsCount)

let dropQuery = PostgresQuery(unsafeSQL: "DROP TABLE table1")
try await connection.query(dropQuery, logger: .psqlTest)
Expand Down
3 changes: 3 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ x-shared-config: &shared_config
- 5432:5432

services:
psql-17:
image: postgres:17
<<: *shared_config
psql-16:
image: postgres:16
<<: *shared_config
Expand Down

0 comments on commit 61da70d

Please sign in to comment.