Skip to content

Commit

Permalink
Split the SQLDatabase protocol into its own file. Group SQLiteConnnec…
Browse files Browse the repository at this point in the history
…tion's async methods together.
  • Loading branch information
gwynne committed May 2, 2024
1 parent 3028277 commit c51dcfe
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 181 deletions.
244 changes: 63 additions & 181 deletions Sources/SQLiteNIO/SQLiteConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,126 +3,6 @@ import NIOPosix
import CSQLite
import Logging

public protocol SQLiteDatabase {
var logger: Logger { get }

var eventLoop: any EventLoop { get }

@preconcurrency
func query(
_ query: String,
_ binds: [SQLiteData],
logger: Logger,
_ onRow: @escaping @Sendable (SQLiteRow) -> Void
) -> EventLoopFuture<Void>

@preconcurrency
func withConnection<T>(
_: @escaping @Sendable (SQLiteConnection) -> EventLoopFuture<T>
) -> EventLoopFuture<T>
}

extension SQLiteDatabase {
/// Logger-less version of ``query(_:_:logger:_:)``.
@preconcurrency
public func query(
_ query: String,
_ binds: [SQLiteData] = [],
_ onRow: @escaping @Sendable (SQLiteRow) -> Void
) -> EventLoopFuture<Void> {
self.query(query, binds, logger: self.logger, onRow)
}

/// Logger-less async version of ``query(_:_:logger:_:)``.
public func query(
_ query: String,
_ binds: [SQLiteData],
_ onRow: @escaping @Sendable (SQLiteRow) -> Void
) async throws {
try await self.query(query, binds, logger: self.logger, onRow).get()
}

/// Data-returning version of ``query(_:_:_:)-2zmfi``.
public func query(
_ query: String,
_ binds: [SQLiteData] = []
) -> EventLoopFuture<[SQLiteRow]> {
#if swift(<5.10)
let rows: UnsafeMutableTransferBox<[SQLiteRow]> = .init([])

return self.query(query, binds, logger: self.logger) { rows.wrappedValue.append($0) }.map { rows.wrappedValue }
#else
nonisolated(unsafe) var rows: [SQLiteRow] = []

return self.query(query, binds, logger: self.logger) { rows.append($0) }.map { rows }
#endif
}

/// Data-returning version of ``query(_:_:_:)-3s65n``.
public func query(_ query: String, _ binds: [SQLiteData] = []) async throws -> [SQLiteRow] {
try await self.query(query, binds).get()
}

/// Async version of ``withConnection(_:)-48y34``.
public func withConnection<T>(
_ closure: @escaping @Sendable (SQLiteConnection) async throws -> T
) async throws -> T {
try await self.withConnection { conn in
conn.eventLoop.makeFutureWithTask {
try await closure(conn)
}
}.get()
}
}

#if swift(<5.10)
fileprivate final class UnsafeMutableTransferBox<Wrapped: Sendable>: @unchecked Sendable {
var wrappedValue: Wrapped
init(_ wrappedValue: Wrapped) { self.wrappedValue = wrappedValue }
}
#endif

extension SQLiteDatabase {
public func logging(to logger: Logger) -> any SQLiteDatabase {
SQLiteDatabaseCustomLogger(database: self, logger: logger)
}
}

private struct SQLiteDatabaseCustomLogger: SQLiteDatabase {
let database: any SQLiteDatabase
var eventLoop: any EventLoop { self.database.eventLoop }
let logger: Logger

func withConnection<T>(_ closure: @escaping @Sendable (SQLiteConnection) -> EventLoopFuture<T>) -> EventLoopFuture<T> {
self.database.withConnection(closure)
}
func withConnection<T>(_ closure: @escaping @Sendable (SQLiteConnection) async throws -> T) async throws -> T {
try await self.database.withConnection(closure)
}

func query(_ query: String, _ binds: [SQLiteData], logger: Logger, _ onRow: @escaping @Sendable (SQLiteRow) -> Void) -> EventLoopFuture<Void> {
self.database.query(query, binds, logger: logger, onRow)
}

func query(_ query: String, _ binds: [SQLiteData] = [], _ onRow: @escaping @Sendable (SQLiteRow) -> Void) -> EventLoopFuture<Void> {
self.database.query(query, binds, onRow)
}
func query(_ query: String, _ binds: [SQLiteData], _ onRow: @escaping @Sendable (SQLiteRow) -> Void) async throws {
try await self.database.query(query, binds, onRow)
}

func query(_ query: String, _ binds: [SQLiteData] = []) -> EventLoopFuture<[SQLiteRow]> {
self.database.query(query, binds)
}
func query(_ query: String, _ binds: [SQLiteData] = []) async throws -> [SQLiteRow] {
try await self.database.query(query, binds)
}

func logging(to logger: Logger) -> any SQLiteDatabase {
Self(database: self.database, logger: logger)
}
}

final class SQLiteConnectionHandle: @unchecked Sendable {
var raw: OpaquePointer?

Expand All @@ -149,7 +29,7 @@ public final class SQLiteConnection: SQLiteDatabase, Sendable {
public let logger: Logger

let handle: SQLiteConnectionHandle
let threadPool: NIOThreadPool
private let threadPool: NIOThreadPool

public var isClosed: Bool {
self.handle.raw == nil
Expand All @@ -167,18 +47,6 @@ public final class SQLiteConnection: SQLiteDatabase, Sendable {
)
}

public static func open(
storage: Storage = .memory,
logger: Logger = .init(label: "codes.vapor.sqlite")
) async throws -> SQLiteConnection {
try await Self.open(
storage: storage,
threadPool: NIOThreadPool.singleton,
logger: logger,
on: MultiThreadedEventLoopGroup.singleton.any()
)
}

public static func open(
storage: Storage = .memory,
threadPool: NIOThreadPool,
Expand All @@ -190,17 +58,6 @@ public final class SQLiteConnection: SQLiteDatabase, Sendable {
}
}

public static func open(
storage: Storage = .memory,
threadPool: NIOThreadPool,
logger: Logger = .init(label: "codes.vapor.sqlite"),
on eventLoop: any EventLoop
) async throws -> SQLiteConnection {
try await threadPool.runIfActive {
try self.openInternal(storage: storage, threadPool: threadPool, logger: logger, eventLoop: eventLoop)
}
}

private static func openInternal(storage: Storage, threadPool: NIOThreadPool, logger: Logger, eventLoop: any EventLoop) throws -> SQLiteConnection {
let path: String
switch storage {
Expand Down Expand Up @@ -255,25 +112,13 @@ public final class SQLiteConnection: SQLiteDatabase, Sendable {
}
}

public func lastAutoincrementID() async throws -> Int {
try await self.threadPool.runIfActive {
numericCast(sqlite_nio_sqlite3_last_insert_rowid(self.handle.raw))
}
}

@preconcurrency
public func withConnection<T>(
_ closure: @escaping @Sendable (SQLiteConnection) -> EventLoopFuture<T>
) -> EventLoopFuture<T> {
closure(self)
}

public func withConnection<T>(
_ closure: @escaping @Sendable (SQLiteConnection) async throws -> T
) async throws -> T {
try await closure(self)
}

@preconcurrency
public func query(
_ query: String,
Expand Down Expand Up @@ -305,6 +150,68 @@ public final class SQLiteConnection: SQLiteDatabase, Sendable {
return promise.futureResult
}

public func close() -> EventLoopFuture<Void> {
self.threadPool.runIfActive(eventLoop: self.eventLoop) {
sqlite_nio_sqlite3_close(self.handle.raw)
self.handle.raw = nil
}
}

public func install(customFunction: SQLiteCustomFunction) -> EventLoopFuture<Void> {
self.logger.trace("Adding custom function \(customFunction.name)")
return self.threadPool.runIfActive(eventLoop: self.eventLoop) {
try customFunction.install(in: self)
}
}

public func uninstall(customFunction: SQLiteCustomFunction) -> EventLoopFuture<Void> {
self.logger.trace("Removing custom function \(customFunction.name)")
return self.threadPool.runIfActive(eventLoop: self.eventLoop) {
try customFunction.uninstall(in: self)
}
}

deinit {
assert(self.handle.raw == nil, "SQLiteConnection was not closed before deinitializing")
}
}

extension SQLiteConnection {
public static func open(
storage: Storage = .memory,
logger: Logger = .init(label: "codes.vapor.sqlite")
) async throws -> SQLiteConnection {
try await Self.open(
storage: storage,
threadPool: NIOThreadPool.singleton,
logger: logger,
on: MultiThreadedEventLoopGroup.singleton.any()
)
}

public static func open(
storage: Storage = .memory,
threadPool: NIOThreadPool,
logger: Logger = .init(label: "codes.vapor.sqlite"),
on eventLoop: any EventLoop
) async throws -> SQLiteConnection {
try await threadPool.runIfActive {
try self.openInternal(storage: storage, threadPool: threadPool, logger: logger, eventLoop: eventLoop)
}
}

public func lastAutoincrementID() async throws -> Int {
try await self.threadPool.runIfActive {
numericCast(sqlite_nio_sqlite3_last_insert_rowid(self.handle.raw))
}
}

public func withConnection<T>(
_ closure: @escaping @Sendable (SQLiteConnection) async throws -> T
) async throws -> T {
try await closure(self)
}

public func query(
_ query: String,
_ binds: [SQLiteData],
Expand All @@ -313,49 +220,24 @@ public final class SQLiteConnection: SQLiteDatabase, Sendable {
try await self.query(query, binds, onRow).get()
}

public func close() -> EventLoopFuture<Void> {
self.threadPool.runIfActive(eventLoop: self.eventLoop) {
sqlite_nio_sqlite3_close(self.handle.raw)
self.handle.raw = nil
}
}

public func close() async throws {
try await self.threadPool.runIfActive {
sqlite_nio_sqlite3_close(self.handle.raw)
self.handle.raw = nil
}
}

public func install(customFunction: SQLiteCustomFunction) -> EventLoopFuture<Void> {
self.logger.trace("Adding custom function \(customFunction.name)")
return self.threadPool.runIfActive(eventLoop: self.eventLoop) {
try customFunction.install(in: self)
}
}

public func install(customFunction: SQLiteCustomFunction) async throws {
self.logger.trace("Adding custom function \(customFunction.name)")
return try await self.threadPool.runIfActive {
try customFunction.install(in: self)
}
}

public func uninstall(customFunction: SQLiteCustomFunction) -> EventLoopFuture<Void> {
self.logger.trace("Removing custom function \(customFunction.name)")
return self.threadPool.runIfActive(eventLoop: self.eventLoop) {
try customFunction.uninstall(in: self)
}
}

public func uninstall(customFunction: SQLiteCustomFunction) async throws {
self.logger.trace("Removing custom function \(customFunction.name)")
return try await self.threadPool.runIfActive {
try customFunction.uninstall(in: self)
}
}

deinit {
assert(self.handle.raw == nil, "SQLiteConnection was not closed before deinitializing")
}
}
Loading

0 comments on commit c51dcfe

Please sign in to comment.