Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure we backup the database before doing migration #60

Merged
merged 2 commits into from
Jun 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tests/test_db_migration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest
import tscat.orm_sqlalchemy
import os
from glob import glob
import tscat
import datetime as dt

Expand All @@ -16,6 +17,7 @@ def setUp(self) -> None:
def test_existing_event_now_has_rating_field(self):
existing, = tscat.get_events()
self.assertEqual(existing.rating, None)
self.assertGreater(len(glob(f'{tscat.base.backend()._tmp_dir}/*.sqlite.backup')), 0)

def test_creating_event_with_rating(self):
tscat.create_event(dt.datetime.now(), dt.datetime.now() + dt.timedelta(days=1), "Patrick", rating=3)
Expand Down
19 changes: 14 additions & 5 deletions tscat/orm_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import pickle
import datetime as dt
import os
from shutil import copyfile
from tempfile import mkdtemp
import orjson
from appdirs import user_data_dir

Expand Down Expand Up @@ -128,11 +130,12 @@ def visit_predicate(self, pred: Predicate):

class Backend:
def __init__(self, testing: Union[bool, str] = False):
in_memory = False
if testing is True:
sqlite_filename = ""
in_memory = True
elif isinstance(testing, str):
sqlite_filename = 'file:memdb1?mode=memory&cache=shared' # memory database

sqlite_filename = self._copy_to_tmp(testing)
else: # pragma: no cover
db_file_path = user_data_dir('tscat')
if not os.path.exists(db_file_path):
Expand All @@ -144,10 +147,9 @@ def __init__(self, testing: Union[bool, str] = False):
json_serializer=_serialize_json,
json_deserializer=_deserialize_json)

# copy testing database to memory
if isinstance(testing, str):
if in_memory:
import sqlite3
source = sqlite3.connect(testing)
source = sqlite3.connect("")
assert isinstance(self.engine.raw_connection(), _ConnectionFairy)
assert isinstance(self.engine.raw_connection().connection, sqlite3.Connection) # type: ignore
source.backup(self.engine.raw_connection().connection, pages=-1) # type: ignore
Expand All @@ -169,6 +171,13 @@ def do_begin(conn):

self.session = Session(bind=self.engine, autoflush=True)

def _copy_to_tmp(self, source_file) -> str:
# temp dir lives as long as the object
self._tmp_dir = mkdtemp()
destination_file = os.path.join(self._tmp_dir, os.path.basename(source_file))
copyfile(source_file, destination_file)
return destination_file

def close(self):
self.session.close()
self.engine.dispose()
Expand Down
22 changes: 20 additions & 2 deletions tscat/orm_sqlalchemy/migrations/env.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import os.path
from logging.config import fileConfig
from logging import getLogger
from datetime import datetime
from shutil import copyfile
from typing import Optional

from sqlalchemy import engine_from_config
from sqlalchemy import pool
Expand All @@ -17,13 +22,24 @@
# add your model's MetaData object here
# for 'autogenerate' support
from tscat.orm_sqlalchemy.orm import Base

target_metadata = Base.metadata


# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.

def _backup_database(db_url: Optional[str] = None):
if db_url is not None:
path = db_url.replace('sqlite://', '')
if os.path.exists(path) and os.path.isfile(path):
now = datetime.now().strftime('%Y%m%dT%H%M%S')
backup_path = path.replace('.sqlite', f'-{now}.sqlite.backup')
getLogger('alembic').info(f'Backing up database to {backup_path}')
copyfile(path, backup_path)


def run_migrations_offline() -> None: # pragma: no cover
"""Run migrations in 'offline' mode.
Expand All @@ -44,7 +60,8 @@ def run_migrations_offline() -> None: # pragma: no cover
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)

if context.get_head_revision() != context.get_context().get_current_revision():
_backup_database(url)
with context.begin_transaction():
context.run_migrations()

Expand All @@ -66,7 +83,8 @@ def run_migrations_online() -> None:
context.configure(
connection=connection, target_metadata=target_metadata
)

if context.get_head_revision() != context.get_context().get_current_revision():
_backup_database(config.get_main_option("sqlalchemy.url"))
with context.begin_transaction():
context.run_migrations()

Expand Down
Loading