From 08fcd81a0d6d3f3cbe61f003835ffad7f000f2f9 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Wed, 25 Sep 2024 09:03:29 +0100 Subject: [PATCH] Don't update expires in `set` if it is nil (#33) --- .../PostgresPersistDriver.swift | 30 +++++++++++++------ .../PersistTests.swift | 4 +-- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/Sources/HummingbirdPostgres/PostgresPersistDriver.swift b/Sources/HummingbirdPostgres/PostgresPersistDriver.swift index 7f613b4..27b0fa0 100644 --- a/Sources/HummingbirdPostgres/PostgresPersistDriver.swift +++ b/Sources/HummingbirdPostgres/PostgresPersistDriver.swift @@ -98,15 +98,27 @@ public final class PostgresPersistDriver: PersistDriver { /// Set value for key. public func set(key: String, value: some Codable, expires: Duration?) async throws { - let expires = expires.map { Date.now + Double($0.components.seconds) } ?? Date.distantFuture - try await self.client.query( - """ - INSERT INTO _hb_pg_persist (id, data, expires) VALUES (\(key), \(WrapperObject(value)), \(expires)) - ON CONFLICT (id) - DO UPDATE SET data = \(WrapperObject(value)), expires = \(expires) - """, - logger: self.logger - ) + if let expires { + let expires = Date.now + Double(expires.components.seconds) + try await self.client.query( + """ + INSERT INTO _hb_pg_persist (id, data, expires) VALUES (\(key), \(WrapperObject(value)), \(expires)) + ON CONFLICT (id) + DO UPDATE SET data = \(WrapperObject(value)), expires = \(expires) + """, + logger: self.logger + ) + + } else { + try await self.client.query( + """ + INSERT INTO _hb_pg_persist (id, data, expires) VALUES (\(key), \(WrapperObject(value)), \(Date.distantFuture)) + ON CONFLICT (id) + DO UPDATE SET data = \(WrapperObject(value)) + """, + logger: self.logger + ) + } } /// Get value for key diff --git a/Tests/HummingbirdPostgresTests/PersistTests.swift b/Tests/HummingbirdPostgresTests/PersistTests.swift index 3c47539..24120ae 100644 --- a/Tests/HummingbirdPostgresTests/PersistTests.swift +++ b/Tests/HummingbirdPostgresTests/PersistTests.swift @@ -228,12 +228,12 @@ final class PersistTests: XCTestCase { try await app.test(.router) { client in let tag = UUID().uuidString - try await client.execute(uri: "/persist/\(tag)/0", method: .put, body: ByteBufferAllocator().buffer(string: "ThisIsTest1")) { _ in } + try await client.execute(uri: "/persist/\(tag)/0", method: .put, body: ByteBuffer(string: "ThisIsTest1")) { _ in } try await Task.sleep(nanoseconds: 1_000_000_000) try await client.execute(uri: "/persist/\(tag)", method: .get) { response in XCTAssertEqual(response.status, .noContent) } - try await client.execute(uri: "/persist/\(tag)/10", method: .put, body: ByteBufferAllocator().buffer(string: "ThisIsTest1")) { response in + try await client.execute(uri: "/persist/\(tag)/10", method: .put, body: ByteBuffer(string: "ThisIsTest1")) { response in XCTAssertEqual(response.status, .ok) } try await client.execute(uri: "/persist/\(tag)", method: .get) { response in