Skip to content

Use apsw instead of sqlite #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,21 @@ except: print("Delete succeeded!")
```

Delete succeeded!

## Differences from sqlite-utils and sqlite-minutils

- WAL is the default
- Setting `Database(recursive_triggers=False)` works as expected
- Primary keys must be set on a table for it to be a target of a foreign
key
- Errors have been changed minimally; future PRs will change them
incrementally

## Differences in error handling

| Old/sqlite3/dbapi | New/APSW | Reason |
|----|----|----|
| IntegrityError | apsw.ConstraintError | Caused by a SQL transformation being blocked by database constraints |
| sqlite3.dbapi2.OperationalError | apsw.Error | General error; OperationalError is now proxied to apsw.Error |
| sqlite3.dbapi2.OperationalError | apsw.SQLError | When an error is due to flawed SQL statements |
| sqlite3.ProgrammingError | apsw.ConnectionClosedError | Caused by using a database connection after it has been closed |
122 changes: 56 additions & 66 deletions apswutils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,19 @@
from typing import ( cast, Any, Callable, Dict, Generator, Iterable, Union, Optional, List, Tuple,Iterator)
from functools import cache
import uuid
import apsw.ext
import apsw.bestpractice

# We don't use apsw.bestpractice.connection_dqs because sqlite-utils
# allowed double-quoted strings
apsw.bestpractice.apply((
apsw.bestpractice.connection_busy_timeout,
apsw.bestpractice.connection_enable_foreign_keys,
apsw.bestpractice.connection_optimize,
apsw.bestpractice.connection_recursive_triggers,
apsw.bestpractice.connection_wal,
apsw.bestpractice.library_logging
))

try: from sqlite_dump import iterdump
except ImportError: iterdump = None
Expand Down Expand Up @@ -238,33 +251,31 @@ def __init__(
), "Either specify a filename_or_conn or pass memory=True"
if memory_name:
uri = "file:{}?mode=memory&cache=shared".format(memory_name)
self.conn = sqlite3.connect(
uri,
uri=True,
check_same_thread=False,
isolation_level=None
# The flags being set allow apswutils to maintain the same behavior here
# as sqlite-minutils
self.conn = sqlite3.Connection(
uri, flags=apsw.SQLITE_OPEN_URI|apsw.SQLITE_OPEN_READWRITE
)
elif memory or filename_or_conn == ":memory:":
self.conn = sqlite3.connect(":memory:", isolation_level=None)
self.conn = sqlite3.Connection(":memory:")
elif isinstance(filename_or_conn, (str, pathlib.Path)):
if recreate and os.path.exists(filename_or_conn):
try:
os.remove(filename_or_conn)
except OSError:
# Avoid mypy and __repr__ errors, see:
# https://github.com/simonw/sqlite-utils/issues/503
self.conn = sqlite3.connect(":memory:", isolation_level=None)
self.conn = sqlite3.Connection(":memory:")
raise
self.conn = sqlite3.connect(str(filename_or_conn), check_same_thread=False, isolation_level=None)
self.conn = sqlite3.Connection(str(filename_or_conn))
else:
assert not recreate, "recreate cannot be used with connections, only paths"
self.conn = filename_or_conn
if not hasattr(self.conn, '__enter__'):
self.conn.__enter__ = __conn_enter__
self.conn.__exit__ = __conn_exit__
self._tracer = tracer
if recursive_triggers:
self.execute("PRAGMA recursive_triggers=on;")
self.conn.pragma('recursive_triggers', recursive_triggers)
self._registered_functions: set = set()
self.use_counts_table = use_counts_table
self.strict = strict
Expand All @@ -278,25 +289,6 @@ def get_last_rowid(self):
if res is None: return None
return int(res[0])

@contextlib.contextmanager
def ensure_autocommit_off(self):
"""
Ensure autocommit is off for this database connection.

Example usage::

with db.ensure_autocommit_off():
# do stuff here

This will reset to the previous autocommit state at the end of the block.
"""
old_isolation_level = self.conn.isolation_level
try:
self.conn.isolation_level = None
yield
finally:
self.conn.isolation_level = old_isolation_level

@contextlib.contextmanager
def tracer(self, tracer: Optional[Callable] = None):
"""
Expand Down Expand Up @@ -370,16 +362,18 @@ def register(fn):
kwargs = {}
registered = False
if deterministic:
# Try this, but fall back if sqlite3.NotSupportedError
# Try this, but fall back if apsw.Error
try:
self.conn.create_function(
fn_name, arity, fn, **dict(kwargs, deterministic=True)
self.conn.create_scalar_function(
fn_name, fn, arity, **dict(kwargs, deterministic=True)
)
registered = True
except sqlite3.NotSupportedError:
except sqlite3.Error: # Remember, sqlite3 here is actually apsw
# TODO Find the precise error, sqlite-minutils used sqlite3.NotSupportedError
# but as this isn't defined in APSW we fall back to apsw.Error
pass
if not registered:
self.conn.create_function(fn_name, arity, fn, **kwargs)
self.conn.create_scalar_function(fn_name, fn, arity, **kwargs)
self._registered_functions.add((fn_name, arity))
return fn

Expand Down Expand Up @@ -421,10 +415,10 @@ def query(
parameters, or a dictionary for ``where id = :id``
"""
cursor = self.execute(sql, tuple(params or tuple()))
if cursor.description is None: return []
keys = [d[0] for d in cursor.description]
try: columns = [c[0] for c in cursor.description]
except apsw.ExecutionCompleteError: return []
for row in cursor:
yield dict(zip(keys, row))
yield dict(zip(columns, row))

def execute(
self, sql: str, parameters: Optional[Union[Iterable, dict]] = None
Expand All @@ -449,7 +443,7 @@ def executescript(self, sql: str) -> sqlite3.Cursor:
"""
if self._tracer:
self._tracer(sql, None)
return self.conn.executescript(sql)
return self.conn.execute(sql)

def __hash__(self): return hash(self.conn)

Expand Down Expand Up @@ -643,14 +637,12 @@ def enable_wal(self):
Sets ``journal_mode`` to ``'wal'`` to enable Write-Ahead Log mode.
"""
if self.journal_mode != "wal":
with self.ensure_autocommit_off():
self.execute("PRAGMA journal_mode=wal;")
self.execute("PRAGMA journal_mode=wal;")

def disable_wal(self):
"Sets ``journal_mode`` back to ``'delete'`` to disable Write-Ahead Log mode."
if self.journal_mode != "delete":
with self.ensure_autocommit_off():
self.execute("PRAGMA journal_mode=delete;")
self.execute("PRAGMA journal_mode=delete;")

def _ensure_counts_table(self):
self.execute(_COUNTS_TABLE_CREATE_SQL.format(self._counts_table_name))
Expand Down Expand Up @@ -1292,7 +1284,9 @@ def rows_where(
if offset is not None:
sql += f" offset {offset}"
cursor = self.db.execute(sql, where_args or [])
columns = [c[0] for c in cursor.description]
# If no records found, return empty list
try: columns = [c[0] for c in cursor.description]
except apsw.ExecutionCompleteError: return []
for row in cursor:
yield dict(zip(columns, row))

Expand Down Expand Up @@ -1735,20 +1729,18 @@ def transform(
column_order=column_order,
keep_table=keep_table,
)
pragma_foreign_keys_was_on = self.db.execute("PRAGMA foreign_keys").fetchone()[
0
]
pragma_foreign_keys_was_on = self.db.conn.pragma('foreign_keys')
try:
if pragma_foreign_keys_was_on:
self.db.execute("PRAGMA foreign_keys=0;")
self.db.conn.pragma('foreign_keys', 0)
for sql in sqls:
self.db.execute(sql)
# Run the foreign_key_check before we commit
if pragma_foreign_keys_was_on:
self.db.execute("PRAGMA foreign_key_check;")
self.db.conn.pragma('foreign_key_check')
finally:
if pragma_foreign_keys_was_on:
self.db.execute("PRAGMA foreign_keys=1;")
self.db.conn.pragma('foreign_keys', 1)
return self

def transform_sql(
Expand Down Expand Up @@ -2211,7 +2203,7 @@ def drop(self, ignore: bool = False):
"""
try:
self.db.execute("DROP TABLE [{}]".format(self.name))
except sqlite3.OperationalError:
except apsw.SQLError:
if not ignore:
raise

Expand Down Expand Up @@ -2667,7 +2659,8 @@ def search(
),
args,
)
columns = [c[0] for c in cursor.description]
try: columns = [c[0] for c in cursor.description]
except apsw.ExecutionCompleteError: return []
for row in cursor:
yield dict(zip(columns, row))

Expand Down Expand Up @@ -2757,26 +2750,23 @@ def update(
table=self.name, sets=", ".join(sets), wheres=" and ".join(wheres)
)
sql += ' RETURNING *'
records = []
self.result = []
try:
cursor = self.db.execute(sql, args)
rowcount = cursor.rowcount
if cursor.description is not None:
columns = [d[0] for d in cursor.description]
for row in cursor:
records.append(dict(zip(columns, row)))
try: columns = [c[0] for c in cursor.description]
except apsw.ExecutionCompleteError: return self

for row in cursor:
self.result.append(dict(zip(columns, row)))
except OperationalError as e:
if alter and (" column" in e.args[0]):
# Attempt to add any missing columns, then try again
self.add_missing_columns([updates])
rowcount = self.db.execute(sql, args).rowcount
self.db.execute(sql, args)
else:
raise

# TODO: Test this works (rolls back) - use better exception:
# assert rowcount == 1
self.last_pk = pk_values[0] if len(pks) == 1 else pk_values
self.result = records
return self

def build_insert_queries_and_params(
Expand Down Expand Up @@ -2929,17 +2919,17 @@ def insert_chunk(
for query, params in queries_and_params:
try:
cursor = self.db.execute(query, tuple(params))
if cursor.description is None: continue
columns = [d[0] for d in cursor.description]
try: columns = [c[0] for c in cursor.description]
except apsw.ExecutionCompleteError: continue
for row in cursor:
records.append(dict(zip(columns, row)))
except OperationalError as e:
if alter and (" column" in e.args[0]):
# Attempt to add any missing columns, then try again
self.add_missing_columns(chunk)
cursor = self.db.execute(query, params)
if cursor.description is None: continue
columns = [d[0] for d in cursor.description]
try: columns = [c[0] for c in cursor.description]
except apsw.ExecutionCompleteError: continue
for row in cursor:
records.append(dict(zip(columns, row)))
elif e.args[0] == "too many SQL variables":
Expand Down Expand Up @@ -3671,7 +3661,7 @@ def drop(self, ignore=False):

try:
self.db.execute("DROP VIEW [{}]".format(self.name))
except sqlite3.OperationalError:
except apsw.SQLError:
if not ignore:
raise

Expand Down
23 changes: 6 additions & 17 deletions apswutils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,12 @@
import json
from typing import Dict, cast, BinaryIO, Iterable, Optional, Tuple, Type

try:
import pysqlite3 as sqlite3 # noqa: F401
from pysqlite3 import dbapi2 # noqa: F401

OperationalError = dbapi2.OperationalError
except ImportError:
try:
import sqlean as sqlite3 # noqa: F401
from sqlean import dbapi2 # noqa: F401

OperationalError = dbapi2.OperationalError
except ImportError:
import sqlite3 # noqa: F401
from sqlite3 import dbapi2 # noqa: F401

OperationalError = dbapi2.OperationalError

# TODO: Make the use of apsw as a shim for sqlite3 more explicit
# In order to keep this PR minimal, we use sqlite3 as a shim for APSW
import apsw as sqlite3
# TODO: Replace use of OperationalError with more explicit apsw errors
# In order to keep this PR minimal, we use OperationalError as a shim for apsw.Error
OperationalError = sqlite3.Error

SPATIALITE_PATHS = (
"/usr/lib/x86_64-linux-gnu/mod_spatialite.so",
Expand Down
32 changes: 28 additions & 4 deletions nbs/index.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -548,11 +548,35 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {},
"outputs": [],
"source": []
"source": [
"## Differences from sqlite-utils and sqlite-minutils\n",
"\n",
"- WAL is the default\n",
"- Setting `Database(recursive_triggers=False)` works as expected \n",
"- Primary keys must be set on a table for it to be a target of a foreign key\n",
"- Errors have been changed minimally; future PRs will change them incrementally"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Differences in error handling"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"|Old/sqlite3/dbapi|New/APSW|Reason|\n",
"|---|---|---|\n",
"|IntegrityError|apsw.ConstraintError|Caused by a SQL transformation being blocked by database constraints|\n",
"|sqlite3.dbapi2.OperationalError|apsw.Error|General error; OperationalError is now proxied to apsw.Error|\n",
"|sqlite3.dbapi2.OperationalError|apsw.SQLError|When an error is due to flawed SQL statements|\n",
"|sqlite3.ProgrammingError|apsw.ConnectionClosedError|Caused by using a database connection after it has been closed|\n"
]
}
],
"metadata": {
Expand Down
4 changes: 2 additions & 2 deletions tests/test_constructor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from apswutils import Database
from apswutils.utils import sqlite3
import apsw
import pytest


Expand Down Expand Up @@ -37,5 +37,5 @@ def test_database_close(tmpdir, memory):
db = Database(str(tmpdir / "test.db"))
assert db.execute("select 1 + 1").fetchone()[0] == 2
db.close()
with pytest.raises(sqlite3.ProgrammingError):
with pytest.raises(apsw.ConnectionClosedError):
db.execute("select 1 + 1")
Loading