diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index b695dcfe..6449ab29 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -172,6 +172,16 @@ public struct PostgresBindings: Sendable, Hashable { try self.append(value, context: .default) } + @inlinable + public mutating func append(_ value: Optional) throws { + switch value { + case .none: + self.appendNull() + case let .some(value): + try self.append(value) + } + } + @inlinable public mutating func append( _ value: Value, @@ -181,11 +191,34 @@ public struct PostgresBindings: Sendable, Hashable { self.metadata.append(.init(value: value, protected: true)) } + @inlinable + public mutating func append( + _ value: Optional, + context: PostgresEncodingContext + ) throws { + switch value { + case .none: + self.appendNull() + case let .some(value): + try self.append(value, context: context) + } + } + @inlinable public mutating func append(_ value: Value) { self.append(value, context: .default) } + @inlinable + public mutating func append(_ value: Optional) { + switch value { + case .none: + self.appendNull() + case let .some(value): + self.append(value) + } + } + @inlinable public mutating func append( _ value: Value, @@ -195,6 +228,19 @@ public struct PostgresBindings: Sendable, Hashable { self.metadata.append(.init(value: value, protected: true)) } + @inlinable + public mutating func append( + _ value: Optional, + context: PostgresEncodingContext + ) { + switch value { + case .none: + self.appendNull() + case let .some(value): + self.append(value, context: context) + } + } + @inlinable mutating func appendUnprotected( _ value: Value, diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index 513157fd..b4c8e93f 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -476,6 +476,87 @@ final class AsyncPostgresConnectionTests: XCTestCase { XCTFail("Unexpected error: \(String(describing: error))") } } + + static let preparedStatementWithOptionalTestTable = "AsyncTestPreparedStatementWithOptionalTestTable" + func testPreparedStatementWithOptionalBinding() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + struct InsertPreparedStatement: PostgresPreparedStatement { + static let name = "INSERT-AsyncTestPreparedStatementWithOptionalTestTable" + + static let sql = #"INSERT INTO "\#(AsyncPostgresConnectionTests.preparedStatementWithOptionalTestTable)" (uuid) VALUES ($1);"# + typealias Row = () + + var uuid: UUID? + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.uuid) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + () + } + } + + struct SelectPreparedStatement: PostgresPreparedStatement { + static let name = "SELECT-AsyncTestPreparedStatementWithOptionalTestTable" + + static let sql = #"SELECT id, uuid FROM "\#(AsyncPostgresConnectionTests.preparedStatementWithOptionalTestTable)" WHERE id <= $1;"# + typealias Row = (Int, UUID?) + + var id: Int + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.id) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode((Int, UUID?).self) + } + } + + do { + try await withTestConnection(on: eventLoop) { connection in + try await connection.query(""" + CREATE TABLE IF NOT EXISTS "\(unescaped: Self.preparedStatementWithOptionalTestTable)" ( + id SERIAL PRIMARY KEY, + uuid UUID + ) + """, + logger: .psqlTest + ) + + _ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest) + + let rows = try await connection.execute(SelectPreparedStatement(id: 3), logger: .psqlTest) + var counter = 0 + for try await (id, uuid) in rows { + Logger.psqlTest.info("Received row", metadata: [ + "id": "\(id)", "uuid": "\(String(describing: uuid))" + ]) + counter += 1 + } + + try await connection.query(""" + DROP TABLE "\(unescaped: Self.preparedStatementWithOptionalTestTable)"; + """, + logger: .psqlTest + ) + } + } catch { + XCTFail("Unexpected error: \(String(describing: error))") + } + } } extension XCTestCase {