diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 32dea4a5..53dbd8c9 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -594,7 +594,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private func makeStartListeningQuery(channel: String, context: ChannelHandlerContext) -> PSQLTask { let promise = context.eventLoop.makePromise(of: PSQLRowStream.self) let query = ExtendedQueryContext( - query: PostgresQuery(unsafeSQL: "LISTEN \(channel);"), + query: PostgresQuery(unsafeSQL: #"LISTEN "\#(channel)";"#), logger: self.logger, promise: promise ) @@ -642,7 +642,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private func makeUnlistenQuery(channel: String, context: ChannelHandlerContext) -> PSQLTask { let promise = context.eventLoop.makePromise(of: PSQLRowStream.self) let query = ExtendedQueryContext( - query: PostgresQuery(unsafeSQL: "UNLISTEN \(channel);"), + query: PostgresQuery(unsafeSQL: #"UNLISTEN "\#(channel)";"#), logger: self.logger, promise: promise ) diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index 75e5b6ba..ce6fe027 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -225,25 +225,32 @@ final class AsyncPostgresConnectionTests: XCTestCase { } func testListenAndNotify() async throws { + let channelNames = [ + "foo", + "default" + ] + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - try await self.withTestConnection(on: eventLoop) { connection in - let stream = try await connection.listen("foo") - var iterator = stream.makeAsyncIterator() + for channelName in channelNames { + try await self.withTestConnection(on: eventLoop) { connection in + let stream = try await connection.listen(channelName) + var iterator = stream.makeAsyncIterator() - try await self.withTestConnection(on: eventLoop) { other in - try await other.query(#"NOTIFY foo, 'bar';"#, logger: .psqlTest) + try await self.withTestConnection(on: eventLoop) { other in + try await other.query(#"NOTIFY "\#(unescaped: channelName)", 'bar';"#, logger: .psqlTest) - try await other.query(#"NOTIFY foo, 'foo';"#, logger: .psqlTest) - } + try await other.query(#"NOTIFY "\#(unescaped: channelName)", 'foo';"#, logger: .psqlTest) + } - let first = try await iterator.next() - XCTAssertEqual(first?.payload, "bar") + let first = try await iterator.next() + XCTAssertEqual(first?.payload, "bar") - let second = try await iterator.next() - XCTAssertEqual(second?.payload, "foo") + let second = try await iterator.next() + XCTAssertEqual(second?.payload, "foo") + } } } diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index f2cd96f8..fe94633a 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -51,7 +51,7 @@ class PostgresConnectionTests: XCTestCase { } let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, "LISTEN foo;") + XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) try await channel.writeInbound(PostgresBackendMessage.parseComplete) try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) @@ -63,7 +63,7 @@ class PostgresConnectionTests: XCTestCase { try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) let unlistenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(unlistenMessage.parse.query, "UNLISTEN foo;") + XCTAssertEqual(unlistenMessage.parse.query, #"UNLISTEN "foo";"#) try await channel.writeInbound(PostgresBackendMessage.parseComplete) try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) @@ -111,7 +111,7 @@ class PostgresConnectionTests: XCTestCase { } let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, "LISTEN foo;") + XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) try await channel.writeInbound(PostgresBackendMessage.parseComplete) try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) @@ -124,7 +124,7 @@ class PostgresConnectionTests: XCTestCase { try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo2"))) let unlistenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(unlistenMessage.parse.query, "UNLISTEN foo;") + XCTAssertEqual(unlistenMessage.parse.query, #"UNLISTEN "foo";"#) try await channel.writeInbound(PostgresBackendMessage.parseComplete) try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) @@ -160,7 +160,7 @@ class PostgresConnectionTests: XCTestCase { } let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, "LISTEN foo;") + XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) try await channel.writeInbound(PostgresBackendMessage.parseComplete) try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: [])))