diff --git a/docs/changelog.rst b/docs/changelog.rst index 16d40cb6..cec1df52 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -17,6 +17,10 @@ Compatibility * Added official support for Django 5.2 (`PR #1179 `__). * Dropped testing on MySQL’s MyISAM storage engine (`PR #1180 `__). +* Added fixtures :fixture:`django_assert_num_queries_all_connections` and + :fixture:`django_assert_max_num_queries_all_connections` to check all + your database connections at once. + Bugfixes ^^^^^^^^ diff --git a/docs/helpers.rst b/docs/helpers.rst index c9e189dd..7f08cb52 100644 --- a/docs/helpers.rst +++ b/docs/helpers.rst @@ -491,6 +491,75 @@ If you use type annotations, you can annotate the fixture like this:: ... +.. fixture:: django_assert_num_queries_all_connections + +``django_assert_num_queries_all_connections`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. py:function:: django_assert_num_queries_all_connections(num, info=None) + + :param num: expected number of queries + +This fixture allows to check for an expected number of DB queries on all +your database connections. + +If the assertion failed, the executed queries can be shown by using +the verbose command line option. + +It wraps ``django.test.utils.CaptureQueriesContext`` and yields the wrapped +``DjangoAssertNumAllConnectionsQueries`` instance. + +Example usage:: + + def test_queries(django_assert_num_queries_all_connections): + with django_assert_num_queries_all_connections(3) as captured: + Item.objects.using("default").create('foo') + Item.objects.using("logs").create('bar') + Item.objects.using("finance").create('baz') + + assert 'foo' in captured.captured_queries[0]['sql'] + +If you use type annotations, you can annotate the fixture like this:: + + from pytest_django import DjangoAssertNumAllConnectionsQueries + + def test_num_queries( + django_assert_num_queries: DjangoAssertNumAllConnectionsQueries, + ): + ... + + +.. fixture:: django_assert_max_num_queries_all_connections + +``django_assert_max_num_queries_all_connections`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. py:function:: django_assert_max_num_queries_all_connections(num, info=None) + + :param num: expected maximum number of queries + +This fixture allows to check for an expected maximum number of DB queries on all +your database connections. + +It is a specialized version of :fixture:`django_assert_num_queries_all_connections`. + +Example usage:: + + def test_max_queries(django_assert_max_num_queries_all_connections): + with django_assert_max_num_queries_all_connections(2): + Item.objects.using("logs").create('foo') + Item.objects.using("finance").create('bar') + +If you use type annotations, you can annotate the fixture like this:: + + from pytest_django import DjangoAssertNumAllConnectionsQueries + + def test_max_num_queries( + django_assert_max_num_queries_all_connections: DjangoAssertNumAllConnectionsQueries, + ): + ... + + .. fixture:: django_capture_on_commit_callbacks ``django_capture_on_commit_callbacks`` diff --git a/pyproject.toml b/pyproject.toml index 3fb403da..8d9d4585 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,6 +109,8 @@ skip_covered = true exclude_lines = [ "pragma: no cover", "if TYPE_CHECKING:", + "pass", + "...", ] [tool.ruff] diff --git a/pytest_django/__init__.py b/pytest_django/__init__.py index e4bb08f5..16008922 100644 --- a/pytest_django/__init__.py +++ b/pytest_django/__init__.py @@ -5,11 +5,16 @@ __version__ = "unknown" -from .fixtures import DjangoAssertNumQueries, DjangoCaptureOnCommitCallbacks +from .fixtures import ( + DjangoAssertNumAllConnectionsQueries, + DjangoAssertNumQueries, + DjangoCaptureOnCommitCallbacks, +) from .plugin import DjangoDbBlocker __all__ = [ + "DjangoAssertNumAllConnectionsQueries", "DjangoAssertNumQueries", "DjangoCaptureOnCommitCallbacks", "DjangoDbBlocker", diff --git a/pytest_django/fixtures.py b/pytest_django/fixtures.py index 6dc05fdb..ed648261 100644 --- a/pytest_django/fixtures.py +++ b/pytest_django/fixtures.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from collections.abc import Sized from contextlib import contextmanager from functools import partial from typing import ( @@ -11,15 +12,19 @@ Any, Callable, ContextManager, + Dict, Generator, Iterable, + Iterator, List, Literal, Optional, Protocol, Sequence, Tuple, + TypeVar, Union, + runtime_checkable, ) import pytest @@ -51,7 +56,9 @@ "client", "db", "django_assert_max_num_queries", + "django_assert_max_num_queries_all_connections", "django_assert_num_queries", + "django_assert_num_queries_all_connections", "django_capture_on_commit_callbacks", "django_db_reset_sequences", "django_db_serialized_rollback", @@ -65,6 +72,19 @@ ] +@runtime_checkable +class QueryCaptureContextProtocol(Protocol, Sized): + @property + def captured_queries(self) -> List[Dict[str, Any]]: ... + + def __enter__(self) -> QueryCaptureContextProtocol: ... + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: ... + + +_QueriesContext = TypeVar("_QueriesContext", bound=QueryCaptureContextProtocol) + + @pytest.fixture(scope="session") def django_db_modify_db_settings_tox_suffix() -> None: skip_if_no_django() @@ -654,6 +674,43 @@ def _live_server_helper(request: pytest.FixtureRequest) -> Generator[None, None, live_server._live_server_modified_settings.disable() +class CaptureAllConnectionsQueriesContext: + """ + Context manager that captures all queries executed by Django ORM across all Databases in settings.DATABASES. + """ + + def __init__(self) -> None: + from django.db import connections + from django.test.utils import CaptureQueriesContext + + self.contexts = {alias: CaptureQueriesContext(connections[alias]) for alias in connections} + + def __iter__(self) -> Iterable[dict[str, Any]]: + return iter(self.captured_queries) + + def __getitem__(self, index: int) -> dict[str, Any]: + return self.captured_queries[index] + + def __len__(self) -> int: + return len(self.captured_queries) + + @property + def captured_queries(self) -> list[dict[str, Any]]: + queries = [] + for context in self.contexts.values(): + queries.extend(context.captured_queries) + return queries + + def __enter__(self): + for context in self.contexts.values(): + context.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + for context in self.contexts.values(): + context.__exit__(exc_type, exc_val, exc_tb) + + class DjangoAssertNumQueries(Protocol): """The type of the `django_assert_num_queries` and `django_assert_max_num_queries` fixtures.""" @@ -665,8 +722,18 @@ def __call__( info: str | None = ..., *, using: str | None = ..., - ) -> django.test.utils.CaptureQueriesContext: - pass # pragma: no cover + ) -> ContextManager[django.test.utils.CaptureQueriesContext]: ... + + +class DjangoAssertNumAllConnectionsQueries(Protocol): + """The type of the `django_assert_num_queries_all_connections` and + `django_assert_max_num_queries_all_connections` fixtures.""" + + def __call__( + self, + num: int, + info: str | None = ..., + ) -> ContextManager[CaptureAllConnectionsQueriesContext]: ... @contextmanager @@ -692,8 +759,37 @@ def _assert_num_queries( else: conn = default_conn - verbose = config.getoption("verbose") > 0 with CaptureQueriesContext(conn) as context: + yield from _assert_num_queries_context( + config=config, context=context, num=num, exact=exact, info=info + ) + + +@contextmanager +def _assert_num_queries_all_db( + config, + num: int, + exact: bool = True, + info: str | None = None, +) -> Generator[CaptureAllConnectionsQueriesContext, None, None]: + """A recreation of pytest-django's assert_num_queries that works with all databases in settings.Databases.""" + + with CaptureAllConnectionsQueriesContext() as context: + yield from _assert_num_queries_context( + config=config, context=context, num=num, exact=exact, info=info + ) + + +def _assert_num_queries_context( + *, + config: pytest.Config, + context: _QueriesContext, + num: int, + exact: bool = True, + info: str | None = None, +) -> Iterator[_QueriesContext]: + verbose = config.getoption("verbose") > 0 + with context: yield context num_performed = len(context) if exact: @@ -728,6 +824,22 @@ def django_assert_max_num_queries(pytestconfig: pytest.Config) -> DjangoAssertNu return partial(_assert_num_queries, pytestconfig, exact=False) +@pytest.fixture(scope="function") +def django_assert_num_queries_all_connections( + pytestconfig: pytest.Config, +) -> DjangoAssertNumAllConnectionsQueries: + """Asserts that the number of queries executed by Django ORM across all connections in settings.DATABASES is equal to the given number.""" + return partial(_assert_num_queries_all_db, pytestconfig) + + +@pytest.fixture(scope="function") +def django_assert_max_num_queries_all_connections( + pytestconfig: pytest.Config, +) -> DjangoAssertNumAllConnectionsQueries: + """Asserts that the number of queries executed by Django ORM across all connections in settings.DATABASES is less than or equal to the given number.""" + return partial(_assert_num_queries_all_db, pytestconfig, exact=False) + + class DjangoCaptureOnCommitCallbacks(Protocol): """The type of the `django_capture_on_commit_callbacks` fixture.""" diff --git a/pytest_django/plugin.py b/pytest_django/plugin.py index e8e629f4..1680773d 100644 --- a/pytest_django/plugin.py +++ b/pytest_django/plugin.py @@ -28,7 +28,9 @@ client, # noqa: F401 db, # noqa: F401 django_assert_max_num_queries, # noqa: F401 + django_assert_max_num_queries_all_connections, # noqa: F401 django_assert_num_queries, # noqa: F401 + django_assert_num_queries_all_connections, # noqa: F401 django_capture_on_commit_callbacks, # noqa: F401 django_db_createdb, # noqa: F401 django_db_keepdb, # noqa: F401 diff --git a/tests/test_fixtures.py b/tests/test_fixtures.py index f88ed802..df9e4d10 100644 --- a/tests/test_fixtures.py +++ b/tests/test_fixtures.py @@ -20,7 +20,12 @@ from .helpers import DjangoPytester -from pytest_django import DjangoAssertNumQueries, DjangoCaptureOnCommitCallbacks, DjangoDbBlocker +from pytest_django import ( + DjangoAssertNumAllConnectionsQueries, + DjangoAssertNumQueries, + DjangoCaptureOnCommitCallbacks, + DjangoDbBlocker, +) from pytest_django_test.app.models import Item @@ -259,6 +264,40 @@ def test_queries(django_assert_num_queries): assert result.ret == 1 +@pytest.mark.django_db(databases=["default", "replica", "second"]) +def test_django_assert_num_queries_all_connections( + django_assert_num_queries_all_connections: DjangoAssertNumAllConnectionsQueries, +) -> None: + with django_assert_num_queries_all_connections(3): + Item.objects.count() + Item.objects.using("replica").count() + Item.objects.using("second").count() + + +@pytest.mark.django_db(databases=["default", "replica", "second"]) +def test_django_assert_max_num_queries_all_connections( + request: pytest.FixtureRequest, + django_assert_max_num_queries_all_connections: DjangoAssertNumAllConnectionsQueries, +) -> None: + with nonverbose_config(request.config): + with django_assert_max_num_queries_all_connections(2): + Item.objects.create(name="1-foo") + Item.objects.using("second").create(name="2-bar") + + with pytest.raises(pytest.fail.Exception) as excinfo: # noqa: PT012 + with django_assert_max_num_queries_all_connections(2) as captured: + Item.objects.create(name="1-foo") + Item.objects.create(name="2-bar") + Item.objects.using("second").create(name="3-quux") + + assert excinfo.value.args == ( + "Expected to perform 2 queries or less but 3 were done " + "(add -v option to show queries)", + ) + assert len(captured.captured_queries) == 3 + assert "1-foo" in captured.captured_queries[0]["sql"] + + @pytest.mark.django_db def test_django_capture_on_commit_callbacks( django_capture_on_commit_callbacks: DjangoCaptureOnCommitCallbacks,