diff --git a/Sources/SQLiteNIO/SQLiteConnection.swift b/Sources/SQLiteNIO/SQLiteConnection.swift index 7ed6ec2..6e4dce4 100644 --- a/Sources/SQLiteNIO/SQLiteConnection.swift +++ b/Sources/SQLiteNIO/SQLiteConnection.swift @@ -3,126 +3,6 @@ import NIOPosix import CSQLite import Logging -public protocol SQLiteDatabase { - var logger: Logger { get } - - var eventLoop: any EventLoop { get } - - @preconcurrency - func query( - _ query: String, - _ binds: [SQLiteData], - logger: Logger, - _ onRow: @escaping @Sendable (SQLiteRow) -> Void - ) -> EventLoopFuture - - @preconcurrency - func withConnection( - _: @escaping @Sendable (SQLiteConnection) -> EventLoopFuture - ) -> EventLoopFuture -} - -extension SQLiteDatabase { - /// Logger-less version of ``query(_:_:logger:_:)``. - @preconcurrency - public func query( - _ query: String, - _ binds: [SQLiteData] = [], - _ onRow: @escaping @Sendable (SQLiteRow) -> Void - ) -> EventLoopFuture { - self.query(query, binds, logger: self.logger, onRow) - } - - /// Logger-less async version of ``query(_:_:logger:_:)``. - public func query( - _ query: String, - _ binds: [SQLiteData], - _ onRow: @escaping @Sendable (SQLiteRow) -> Void - ) async throws { - try await self.query(query, binds, logger: self.logger, onRow).get() - } - - /// Data-returning version of ``query(_:_:_:)-2zmfi``. - public func query( - _ query: String, - _ binds: [SQLiteData] = [] - ) -> EventLoopFuture<[SQLiteRow]> { - #if swift(<5.10) - let rows: UnsafeMutableTransferBox<[SQLiteRow]> = .init([]) - - return self.query(query, binds, logger: self.logger) { rows.wrappedValue.append($0) }.map { rows.wrappedValue } - #else - nonisolated(unsafe) var rows: [SQLiteRow] = [] - - return self.query(query, binds, logger: self.logger) { rows.append($0) }.map { rows } - #endif - } - - /// Data-returning version of ``query(_:_:_:)-3s65n``. - public func query(_ query: String, _ binds: [SQLiteData] = []) async throws -> [SQLiteRow] { - try await self.query(query, binds).get() - } - - /// Async version of ``withConnection(_:)-48y34``. - public func withConnection( - _ closure: @escaping @Sendable (SQLiteConnection) async throws -> T - ) async throws -> T { - try await self.withConnection { conn in - conn.eventLoop.makeFutureWithTask { - try await closure(conn) - } - }.get() - } -} - -#if swift(<5.10) -fileprivate final class UnsafeMutableTransferBox: @unchecked Sendable { - var wrappedValue: Wrapped - init(_ wrappedValue: Wrapped) { self.wrappedValue = wrappedValue } -} -#endif - -extension SQLiteDatabase { - public func logging(to logger: Logger) -> any SQLiteDatabase { - SQLiteDatabaseCustomLogger(database: self, logger: logger) - } -} - -private struct SQLiteDatabaseCustomLogger: SQLiteDatabase { - let database: any SQLiteDatabase - var eventLoop: any EventLoop { self.database.eventLoop } - let logger: Logger - - func withConnection(_ closure: @escaping @Sendable (SQLiteConnection) -> EventLoopFuture) -> EventLoopFuture { - self.database.withConnection(closure) - } - func withConnection(_ closure: @escaping @Sendable (SQLiteConnection) async throws -> T) async throws -> T { - try await self.database.withConnection(closure) - } - - func query(_ query: String, _ binds: [SQLiteData], logger: Logger, _ onRow: @escaping @Sendable (SQLiteRow) -> Void) -> EventLoopFuture { - self.database.query(query, binds, logger: logger, onRow) - } - - func query(_ query: String, _ binds: [SQLiteData] = [], _ onRow: @escaping @Sendable (SQLiteRow) -> Void) -> EventLoopFuture { - self.database.query(query, binds, onRow) - } - func query(_ query: String, _ binds: [SQLiteData], _ onRow: @escaping @Sendable (SQLiteRow) -> Void) async throws { - try await self.database.query(query, binds, onRow) - } - - func query(_ query: String, _ binds: [SQLiteData] = []) -> EventLoopFuture<[SQLiteRow]> { - self.database.query(query, binds) - } - func query(_ query: String, _ binds: [SQLiteData] = []) async throws -> [SQLiteRow] { - try await self.database.query(query, binds) - } - - func logging(to logger: Logger) -> any SQLiteDatabase { - Self(database: self.database, logger: logger) - } -} - final class SQLiteConnectionHandle: @unchecked Sendable { var raw: OpaquePointer? @@ -149,7 +29,7 @@ public final class SQLiteConnection: SQLiteDatabase, Sendable { public let logger: Logger let handle: SQLiteConnectionHandle - let threadPool: NIOThreadPool + private let threadPool: NIOThreadPool public var isClosed: Bool { self.handle.raw == nil @@ -167,18 +47,6 @@ public final class SQLiteConnection: SQLiteDatabase, Sendable { ) } - public static func open( - storage: Storage = .memory, - logger: Logger = .init(label: "codes.vapor.sqlite") - ) async throws -> SQLiteConnection { - try await Self.open( - storage: storage, - threadPool: NIOThreadPool.singleton, - logger: logger, - on: MultiThreadedEventLoopGroup.singleton.any() - ) - } - public static func open( storage: Storage = .memory, threadPool: NIOThreadPool, @@ -190,17 +58,6 @@ public final class SQLiteConnection: SQLiteDatabase, Sendable { } } - public static func open( - storage: Storage = .memory, - threadPool: NIOThreadPool, - logger: Logger = .init(label: "codes.vapor.sqlite"), - on eventLoop: any EventLoop - ) async throws -> SQLiteConnection { - try await threadPool.runIfActive { - try self.openInternal(storage: storage, threadPool: threadPool, logger: logger, eventLoop: eventLoop) - } - } - private static func openInternal(storage: Storage, threadPool: NIOThreadPool, logger: Logger, eventLoop: any EventLoop) throws -> SQLiteConnection { let path: String switch storage { @@ -255,12 +112,6 @@ public final class SQLiteConnection: SQLiteDatabase, Sendable { } } - public func lastAutoincrementID() async throws -> Int { - try await self.threadPool.runIfActive { - numericCast(sqlite_nio_sqlite3_last_insert_rowid(self.handle.raw)) - } - } - @preconcurrency public func withConnection( _ closure: @escaping @Sendable (SQLiteConnection) -> EventLoopFuture @@ -268,12 +119,6 @@ public final class SQLiteConnection: SQLiteDatabase, Sendable { closure(self) } - public func withConnection( - _ closure: @escaping @Sendable (SQLiteConnection) async throws -> T - ) async throws -> T { - try await closure(self) - } - @preconcurrency public func query( _ query: String, @@ -305,6 +150,68 @@ public final class SQLiteConnection: SQLiteDatabase, Sendable { return promise.futureResult } + public func close() -> EventLoopFuture { + self.threadPool.runIfActive(eventLoop: self.eventLoop) { + sqlite_nio_sqlite3_close(self.handle.raw) + self.handle.raw = nil + } + } + + public func install(customFunction: SQLiteCustomFunction) -> EventLoopFuture { + self.logger.trace("Adding custom function \(customFunction.name)") + return self.threadPool.runIfActive(eventLoop: self.eventLoop) { + try customFunction.install(in: self) + } + } + + public func uninstall(customFunction: SQLiteCustomFunction) -> EventLoopFuture { + self.logger.trace("Removing custom function \(customFunction.name)") + return self.threadPool.runIfActive(eventLoop: self.eventLoop) { + try customFunction.uninstall(in: self) + } + } + + deinit { + assert(self.handle.raw == nil, "SQLiteConnection was not closed before deinitializing") + } +} + +extension SQLiteConnection { + public static func open( + storage: Storage = .memory, + logger: Logger = .init(label: "codes.vapor.sqlite") + ) async throws -> SQLiteConnection { + try await Self.open( + storage: storage, + threadPool: NIOThreadPool.singleton, + logger: logger, + on: MultiThreadedEventLoopGroup.singleton.any() + ) + } + + public static func open( + storage: Storage = .memory, + threadPool: NIOThreadPool, + logger: Logger = .init(label: "codes.vapor.sqlite"), + on eventLoop: any EventLoop + ) async throws -> SQLiteConnection { + try await threadPool.runIfActive { + try self.openInternal(storage: storage, threadPool: threadPool, logger: logger, eventLoop: eventLoop) + } + } + + public func lastAutoincrementID() async throws -> Int { + try await self.threadPool.runIfActive { + numericCast(sqlite_nio_sqlite3_last_insert_rowid(self.handle.raw)) + } + } + + public func withConnection( + _ closure: @escaping @Sendable (SQLiteConnection) async throws -> T + ) async throws -> T { + try await closure(self) + } + public func query( _ query: String, _ binds: [SQLiteData], @@ -313,13 +220,6 @@ public final class SQLiteConnection: SQLiteDatabase, Sendable { try await self.query(query, binds, onRow).get() } - public func close() -> EventLoopFuture { - self.threadPool.runIfActive(eventLoop: self.eventLoop) { - sqlite_nio_sqlite3_close(self.handle.raw) - self.handle.raw = nil - } - } - public func close() async throws { try await self.threadPool.runIfActive { sqlite_nio_sqlite3_close(self.handle.raw) @@ -327,13 +227,6 @@ public final class SQLiteConnection: SQLiteDatabase, Sendable { } } - public func install(customFunction: SQLiteCustomFunction) -> EventLoopFuture { - self.logger.trace("Adding custom function \(customFunction.name)") - return self.threadPool.runIfActive(eventLoop: self.eventLoop) { - try customFunction.install(in: self) - } - } - public func install(customFunction: SQLiteCustomFunction) async throws { self.logger.trace("Adding custom function \(customFunction.name)") return try await self.threadPool.runIfActive { @@ -341,21 +234,10 @@ public final class SQLiteConnection: SQLiteDatabase, Sendable { } } - public func uninstall(customFunction: SQLiteCustomFunction) -> EventLoopFuture { - self.logger.trace("Removing custom function \(customFunction.name)") - return self.threadPool.runIfActive(eventLoop: self.eventLoop) { - try customFunction.uninstall(in: self) - } - } - public func uninstall(customFunction: SQLiteCustomFunction) async throws { self.logger.trace("Removing custom function \(customFunction.name)") return try await self.threadPool.runIfActive { try customFunction.uninstall(in: self) } } - - deinit { - assert(self.handle.raw == nil, "SQLiteConnection was not closed before deinitializing") - } } diff --git a/Sources/SQLiteNIO/SQLiteDatabase.swift b/Sources/SQLiteNIO/SQLiteDatabase.swift new file mode 100644 index 0000000..fdb6a6b --- /dev/null +++ b/Sources/SQLiteNIO/SQLiteDatabase.swift @@ -0,0 +1,124 @@ +import NIOCore +import NIOPosix +import CSQLite +import Logging + +public protocol SQLiteDatabase { + var logger: Logger { get } + + var eventLoop: any EventLoop { get } + + @preconcurrency + func query( + _ query: String, + _ binds: [SQLiteData], + logger: Logger, + _ onRow: @escaping @Sendable (SQLiteRow) -> Void + ) -> EventLoopFuture + + @preconcurrency + func withConnection( + _: @escaping @Sendable (SQLiteConnection) -> EventLoopFuture + ) -> EventLoopFuture +} + +extension SQLiteDatabase { + /// Logger-less version of ``query(_:_:logger:_:)``. + @preconcurrency + public func query( + _ query: String, + _ binds: [SQLiteData] = [], + _ onRow: @escaping @Sendable (SQLiteRow) -> Void + ) -> EventLoopFuture { + self.query(query, binds, logger: self.logger, onRow) + } + + /// Logger-less async version of ``query(_:_:logger:_:)``. + public func query( + _ query: String, + _ binds: [SQLiteData], + _ onRow: @escaping @Sendable (SQLiteRow) -> Void + ) async throws { + try await self.query(query, binds, logger: self.logger, onRow).get() + } + + /// Data-returning version of ``query(_:_:_:)-2zmfi``. + public func query( + _ query: String, + _ binds: [SQLiteData] = [] + ) -> EventLoopFuture<[SQLiteRow]> { + #if swift(<5.10) + let rows: UnsafeMutableTransferBox<[SQLiteRow]> = .init([]) + + return self.query(query, binds, logger: self.logger) { rows.wrappedValue.append($0) }.map { rows.wrappedValue } + #else + nonisolated(unsafe) var rows: [SQLiteRow] = [] + + return self.query(query, binds, logger: self.logger) { rows.append($0) }.map { rows } + #endif + } + + /// Data-returning version of ``query(_:_:_:)-3s65n``. + public func query(_ query: String, _ binds: [SQLiteData] = []) async throws -> [SQLiteRow] { + try await self.query(query, binds).get() + } + + /// Async version of ``withConnection(_:)-48y34``. + public func withConnection( + _ closure: @escaping @Sendable (SQLiteConnection) async throws -> T + ) async throws -> T { + try await self.withConnection { conn in + conn.eventLoop.makeFutureWithTask { + try await closure(conn) + } + }.get() + } +} + +#if swift(<5.10) +fileprivate final class UnsafeMutableTransferBox: @unchecked Sendable { + var wrappedValue: Wrapped + init(_ wrappedValue: Wrapped) { self.wrappedValue = wrappedValue } +} +#endif + +extension SQLiteDatabase { + public func logging(to logger: Logger) -> any SQLiteDatabase { + SQLiteDatabaseCustomLogger(database: self, logger: logger) + } +} + +private struct SQLiteDatabaseCustomLogger: SQLiteDatabase { + let database: any SQLiteDatabase + var eventLoop: any EventLoop { self.database.eventLoop } + let logger: Logger + + func withConnection(_ closure: @escaping @Sendable (SQLiteConnection) -> EventLoopFuture) -> EventLoopFuture { + self.database.withConnection(closure) + } + func withConnection(_ closure: @escaping @Sendable (SQLiteConnection) async throws -> T) async throws -> T { + try await self.database.withConnection(closure) + } + + func query(_ query: String, _ binds: [SQLiteData], logger: Logger, _ onRow: @escaping @Sendable (SQLiteRow) -> Void) -> EventLoopFuture { + self.database.query(query, binds, logger: logger, onRow) + } + + func query(_ query: String, _ binds: [SQLiteData] = [], _ onRow: @escaping @Sendable (SQLiteRow) -> Void) -> EventLoopFuture { + self.database.query(query, binds, onRow) + } + func query(_ query: String, _ binds: [SQLiteData], _ onRow: @escaping @Sendable (SQLiteRow) -> Void) async throws { + try await self.database.query(query, binds, onRow) + } + + func query(_ query: String, _ binds: [SQLiteData] = []) -> EventLoopFuture<[SQLiteRow]> { + self.database.query(query, binds) + } + func query(_ query: String, _ binds: [SQLiteData] = []) async throws -> [SQLiteRow] { + try await self.database.query(query, binds) + } + + func logging(to logger: Logger) -> any SQLiteDatabase { + Self(database: self.database, logger: logger) + } +}