diff --git a/cads_broker/database.py b/cads_broker/database.py index a05d584b..38724d47 100644 --- a/cads_broker/database.py +++ b/cads_broker/database.py @@ -592,9 +592,9 @@ def init_database(connection_string: str, force: bool = False) -> sa.engine.Engi if not sqlalchemy_utils.database_exists(engine.url): sqlalchemy_utils.create_database(engine.url) # cleanup and create the schema + BaseModel.metadata.drop_all(engine) cacholote.database.Base.metadata.drop_all(engine) cacholote.database.Base.metadata.create_all(engine) - BaseModel.metadata.drop_all(engine) BaseModel.metadata.create_all(engine) alembic.command.stamp(alembic_cfg, "head") else: @@ -608,6 +608,8 @@ def init_database(connection_string: str, force: bool = False) -> sa.engine.Engi if force: # cleanup and create the schema BaseModel.metadata.drop_all(engine) + cacholote.database.Base.metadata.drop_all(engine) + cacholote.database.Base.metadata.create_all(engine) BaseModel.metadata.create_all(engine) alembic.command.stamp(alembic_cfg, "head") else: diff --git a/tests/test_02_database.py b/tests/test_02_database.py index e9bc54e0..424fedd0 100644 --- a/tests/test_02_database.py +++ b/tests/test_02_database.py @@ -731,8 +731,10 @@ def test_init_database(postgresql: Connection[str]) -> None: # verify create structure db.init_database(connection_string, force=True) - expected_tables_complete = set(db.BaseModel.metadata.tables).union( - {"alembic_version"} + expected_tables_complete = ( + set(db.BaseModel.metadata.tables) + .union({"alembic_version"}) + .union(set(cacholote.database.Base.metadata.tables)) ) assert set(conn.execute(query).scalars()) == expected_tables_complete # type: ignore @@ -779,8 +781,10 @@ def test_init_database_with_password(postgresql2: Connection[str]) -> None: # verify create structure db.init_database(connection_string, force=True) - expected_tables_complete = set(db.BaseModel.metadata.tables).union( - {"alembic_version"} + expected_tables_complete = ( + set(db.BaseModel.metadata.tables) + .union({"alembic_version"}) + .union(set(cacholote.database.Base.metadata.tables)) ) assert set(conn.execute(query).scalars()) == expected_tables_complete # type: ignore diff --git a/tests/test_90_entry_points.py b/tests/test_90_entry_points.py index fa0a52c0..5579fd88 100644 --- a/tests/test_90_entry_points.py +++ b/tests/test_90_entry_points.py @@ -1,5 +1,6 @@ from typing import Any +import cacholote import sqlalchemy as sa from psycopg import Connection from typer.testing import CliRunner @@ -40,5 +41,5 @@ def test_init_db(postgresql: Connection[str], mocker) -> None: ) assert set(conn.execute(query).scalars()) == set( database.BaseModel.metadata.tables - ).union({"alembic_version"}) + ).union({"alembic_version"}).union(set(cacholote.database.Base.metadata.tables)) conn.close()