Skip to content

Commit

Permalink
Allow SQLite query parameters and support cached databases (#561)
Browse files Browse the repository at this point in the history
* add support for sqlite connection string query parameters, cached memory databases

* add additional comments #196

* tweaked comments #196

* Lint

---------

Co-authored-by: Nathan Joshi <[email protected]>
  • Loading branch information
zanieb and Nathan Joshi authored Aug 2, 2023
1 parent 25fa295 commit 9795187
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 5 deletions.
18 changes: 15 additions & 3 deletions databases/backends/sqlite.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import sqlite3
import typing
import uuid
from urllib.parse import urlencode

import aiosqlite
from sqlalchemy.dialects.sqlite import pysqlite
Expand Down Expand Up @@ -45,7 +47,9 @@ async def connect(self) -> None:
# )

async def disconnect(self) -> None:
pass
# if it extsis, remove reference to connection to cached in-memory database on disconnect
if self._pool._memref:
self._pool._memref = None
# assert self._pool is not None, "DatabaseBackend is not running"
# self._pool.close()
# await self._pool.wait_closed()
Expand All @@ -57,12 +61,20 @@ def connection(self) -> "SQLiteConnection":

class SQLitePool:
def __init__(self, url: DatabaseURL, **options: typing.Any) -> None:
self._url = url
self._database = url.database
self._memref = None
# add query params to database connection string
if url.options:
self._database += "?" + urlencode(url.options)
self._options = options

if url.options and "cache" in url.options:
# reference to a connection to the cached in-memory database must be held to keep it from being deleted
self._memref = sqlite3.connect(self._database, **self._options)

async def acquire(self) -> aiosqlite.Connection:
connection = aiosqlite.connect(
database=self._url.database, isolation_level=None, **self._options
database=self._database, isolation_level=None, **self._options
)
await connection.__aenter__()
return connection
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ aiosqlite==0.17.0
asyncpg==0.26.0

# Sync database drivers for standard tooling around setup/teardown/migrations.
psycopg2-binary==2.9.3
pymysql==1.0.2
# psycopg2-binary==2.9.3
# pymysql==1.0.2

# Testing
autoflake==1.4
Expand Down
62 changes: 62 additions & 0 deletions tests/test_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import itertools
import os
import re
import sqlite3
from typing import MutableMapping
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -1529,6 +1530,67 @@ async def test_result_named_access(database_url):
assert result.completed is True


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_mapping_property_interface(database_url):
"""
Test that all connections implement interface with `_mapping` property
"""
async with Database(database_url) as database:
query = notes.select()
single_result = await database.fetch_one(query=query)
assert single_result._mapping["text"] == "example1"
assert single_result._mapping["completed"] is True

list_result = await database.fetch_all(query=query)
assert list_result[0]._mapping["text"] == "example1"
assert list_result[0]._mapping["completed"] is True


@async_adapter
async def test_should_not_maintain_ref_when_no_cache_param():
async with Database("sqlite:///file::memory:", uri=True) as database:
query = sqlalchemy.schema.CreateTable(notes)
await database.execute(query)

query = notes.insert()
values = {"text": "example1", "completed": True}
with pytest.raises(sqlite3.OperationalError):
await database.execute(query, values)


@async_adapter
async def test_should_maintain_ref_when_cache_param():
async with Database("sqlite:///file::memory:?cache=shared", uri=True) as database:
query = sqlalchemy.schema.CreateTable(notes)
await database.execute(query)

query = notes.insert()
values = {"text": "example1", "completed": True}
await database.execute(query, values)

query = notes.select().where(notes.c.text == "example1")
result = await database.fetch_one(query=query)
assert result.text == "example1"
assert result.completed is True


@async_adapter
async def test_should_remove_ref_on_disconnect():
async with Database("sqlite:///file::memory:?cache=shared", uri=True) as database:
query = sqlalchemy.schema.CreateTable(notes)
await database.execute(query)

query = notes.insert()
values = {"text": "example1", "completed": True}
await database.execute(query, values)

async with Database("sqlite:///file::memory:?cache=shared", uri=True) as database:
query = notes.select()
with pytest.raises(sqlite3.OperationalError):
await database.fetch_all(query=query)


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_mapping_property_interface(database_url):
Expand Down

0 comments on commit 9795187

Please sign in to comment.