diff --git a/conftest.py b/conftest.py index 00fcc072e..f6c4636af 100644 --- a/conftest.py +++ b/conftest.py @@ -39,7 +39,7 @@ def pytest_itemcollected(item): @pytest.fixture(scope="session") -def engine(request, sqlalchemy_connect_url, app_config): +def engine(request, sqlalchemy_db, sqlalchemy_connect_url, app_config): """Engine configuration. See http://docs.sqlalchemy.org/en/latest/core/engines.html for more details. @@ -66,44 +66,11 @@ def engine(request, sqlalchemy_connect_url, app_config): engine.url.database = "{}_{}".format(engine.url.database, xdist_suffix) engine = create_engine(engine.url) # override engine - def fin(): - print("Disposing engine") - engine.dispose() - - request.addfinalizer(fin) - return engine - - -@pytest.fixture(scope="session") -@pytest.mark.django_db -def db( - request: pytest.FixtureRequest, - engine, - sqlalchemy_connect_url, - django_db_blocker, -): - # Bootstrap the DB by running the Django bootstrap version. - from django.conf import settings - from django.test.utils import setup_databases, teardown_databases - - with django_db_blocker.unblock(): - # Temporarily reset the database to the SQLAlchemy DBs - original_name = settings.DATABASES["default"]["NAME"] - original_test_name = settings.DATABASES["default"]["TEST"]["NAME"] - settings.DATABASES["default"]["NAME"] = "test_postgres_sqlalchemy" - settings.DATABASES["default"]["TEST"]["NAME"] = "test_postgres_sqlalchemy" - db_cfg = setup_databases( - verbosity=request.config.option.verbose, - interactive=False, - ) - settings.DATABASES["default"]["NAME"] = original_name - settings.DATABASES["default"]["TEST"]["NAME"] = original_test_name - + # Check that the DB exist and migrate the unmigrated SQLALchemy models as a stop-gap database_url = sqlalchemy_connect_url if not database_exists(database_url): raise RuntimeError(f"SQLAlchemy cannot connect to DB at {database_url}") - # Create the unmigrated models as a stopgap Base.metadata.tables["profiling_profilingcommit"].create(bind=engine) Base.metadata.tables["profiling_profilingupload"].create(bind=engine) Base.metadata.tables["timeseries_measurement"].create(bind=engine) @@ -126,10 +93,44 @@ def db( bind=engine ) + yield engine + + print("Disposing engine") + engine.dispose() + + +@pytest.fixture(scope="session") +@pytest.mark.django_db +def sqlalchemy_db(request: pytest.FixtureRequest, django_db_blocker): + # Bootstrap the DB by running the Django bootstrap version. + from django.conf import settings + from django.test.utils import setup_databases, teardown_databases + + with django_db_blocker.unblock(): + # Temporarily reset the database to the SQLAlchemy DBs + original_name, original_test_name = ( + settings.DATABASES["default"]["NAME"], + settings.DATABASES["default"]["TEST"]["NAME"], + ) + settings.DATABASES["default"]["NAME"] = "test_postgres_sqlalchemy" + settings.DATABASES["default"]["TEST"]["NAME"] = "test_postgres_sqlalchemy" + db_cfg = setup_databases( + verbosity=request.config.option.verbose, + interactive=False, + ) + settings.DATABASES["default"]["NAME"] = original_name + settings.DATABASES["default"]["TEST"]["NAME"] = original_test_name + yield + # Cleanup with Django version as well try: - teardown_databases(db_cfg, verbosity=request.config.option.verbose) + with django_db_blocker.unblock(): + settings.DATABASES["default"]["NAME"] = "test_postgres_sqlalchemy" + settings.DATABASES["default"]["TEST"]["NAME"] = "test_postgres_sqlalchemy" + teardown_databases(db_cfg, verbosity=request.config.option.verbose) + settings.DATABASES["default"]["NAME"] = original_name + settings.DATABASES["default"]["TEST"]["NAME"] = original_test_name except Exception as exc: # noqa: BLE001 request.node.warn( pytest.PytestWarning( @@ -139,7 +140,7 @@ def db( @pytest.fixture -def dbsession(db, engine): +def dbsession(sqlalchemy_db, engine): """Sets up the SQLAlchemy dbsession.""" connection = engine.connect() diff --git a/database/tests/unit/test_model_utils.py b/database/tests/unit/test_model_utils.py index f4fb67a5a..9105ef10b 100644 --- a/database/tests/unit/test_model_utils.py +++ b/database/tests/unit/test_model_utils.py @@ -56,14 +56,14 @@ def test_subclass_validation(self, mocker): self.ClassWithArchiveFieldMissingMethods, ArchiveFieldInterface ) - def test_archive_getter_db_field_set(self, db): + def test_archive_getter_db_field_set(self, sqlalchemy_db): commit = CommitFactory() test_class = self.ClassWithArchiveField(commit, "db_value", "gcs_path") assert test_class._archive_field == "db_value" assert test_class._archive_field_storage_path == "gcs_path" assert test_class.archive_field == "db_value" - def test_archive_getter_archive_field_set(self, db, mocker): + def test_archive_getter_archive_field_set(self, sqlalchemy_db, mocker): some_json = {"some": "data"} mock_read_file = mocker.MagicMock(return_value=json.dumps(some_json)) mock_archive_service = mocker.patch("database.utils.ArchiveService") @@ -81,7 +81,7 @@ def test_archive_getter_archive_field_set(self, db, mocker): assert test_class.archive_field == some_json assert mock_read_file.call_count == 1 - def test_archive_getter_file_not_in_storage(self, db, mocker): + def test_archive_getter_file_not_in_storage(self, sqlalchemy_db, mocker): mocker.patch( "database.utils.ArchiveField.read_timeout", new_callable=PropertyMock, @@ -99,7 +99,7 @@ def test_archive_getter_file_not_in_storage(self, db, mocker): mock_read_file.assert_called_with("gcs_path") mock_archive_service.assert_called_with(repository=commit.repository) - def test_archive_setter_db_field(self, db, mocker): + def test_archive_setter_db_field(self, sqlalchemy_db, mocker): commit = CommitFactory() test_class = self.ClassWithArchiveField(commit, "db_value", "gcs_path", False) assert test_class._archive_field == "db_value" @@ -111,7 +111,7 @@ def test_archive_setter_db_field(self, db, mocker): assert test_class._archive_field == "batata frita" assert test_class.archive_field == "batata frita" - def test_archive_setter_archive_field(self, db, mocker): + def test_archive_setter_archive_field(self, sqlalchemy_db, mocker): commit = CommitFactory() test_class = self.ClassWithArchiveField(commit, "db_value", None, True) some_json = {"some": "data"} diff --git a/django_scaffold/tests_settings.py b/django_scaffold/tests_settings.py index ba21eccb8..c97e47811 100644 --- a/django_scaffold/tests_settings.py +++ b/django_scaffold/tests_settings.py @@ -22,3 +22,5 @@ }, } ] + +DATABASES["default"]["TEST"] = {"NAME": "test_postgres_django"}