diff --git a/docs/contributing.rst b/docs/contributing.rst index 58282ae8..b7fd3108 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -13,7 +13,7 @@ We require features to be backed by a unit test. This way, we can test *django-polymorphic* against new Django versions. To run the included test suite, execute:: - ./runtests.py + py.test To test support for multiple Python and Django versions, run tox from the repository root:: diff --git a/manage.py b/manage.py new file mode 100644 index 00000000..f7785401 --- /dev/null +++ b/manage.py @@ -0,0 +1,9 @@ +# This helps pytest-django locate the project. +import os +import sys + +if __name__ == "__main__": + os.environ.setdefault("DJANGO_SETTINGS_MODULE", "polymorphic_test_settings") + from django.core.management import execute_from_command_line + + execute_from_command_line(sys.argv) diff --git a/polymorphic/tests/__init__.py b/polymorphic/tests/__init__.py index e69de29b..3d112022 100644 --- a/polymorphic/tests/__init__.py +++ b/polymorphic/tests/__init__.py @@ -0,0 +1,6 @@ +import django + +if django.VERSION < (4, 2): # TODO: remove when dropping support for Django < 4.2 + from django.test.testcases import TransactionTestCase + + TransactionTestCase.assertQuerySetEqual = TransactionTestCase.assertQuerysetEqual diff --git a/polymorphic/tests/admintestcase.py b/polymorphic/tests/admintestcase.py index 116b756c..ac872598 100644 --- a/polymorphic/tests/admintestcase.py +++ b/polymorphic/tests/admintestcase.py @@ -94,7 +94,7 @@ def admin_get_add(self, model, qs=""): admin_instance = self.get_admin_instance(model) request = self.create_admin_request("get", self.get_add_url(model) + qs) response = admin_instance.add_view(request) - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 return response def admin_post_add(self, model, formdata, qs=""): @@ -114,7 +114,7 @@ def admin_get_changelist(self, model): admin_instance = self.get_admin_instance(model) request = self.create_admin_request("get", self.get_changelist_url(model)) response = admin_instance.changelist_view(request) - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 return response def admin_get_change(self, model, object_id, query=None, **extra): @@ -126,7 +126,7 @@ def admin_get_change(self, model, object_id, query=None, **extra): "get", self.get_change_url(model, object_id), data=query, **extra ) response = admin_instance.change_view(request, str(object_id)) - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 return response def admin_post_change(self, model, object_id, formdata, **extra): @@ -150,7 +150,7 @@ def admin_get_history(self, model, object_id, query=None, **extra): "get", self.get_history_url(model, object_id), data=query, **extra ) response = admin_instance.history_view(request, str(object_id)) - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 return response def admin_get_delete(self, model, object_id, query=None, **extra): @@ -162,7 +162,7 @@ def admin_get_delete(self, model, object_id, query=None, **extra): "get", self.get_delete_url(model, object_id), data=query, **extra ) response = admin_instance.delete_view(request, str(object_id)) - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 return response def admin_post_delete(self, model, object_id, **extra): @@ -175,7 +175,7 @@ def admin_post_delete(self, model, object_id, **extra): admin_instance = self.get_admin_instance(model) request = self.create_admin_request("post", self.get_delete_url(model, object_id), **extra) response = admin_instance.delete_view(request, str(object_id)) - self.assertEqual(response.status_code, 302, f"Form errors in calling {request.path}") + assert response.status_code == 302, f"Form errors in calling {request.path}" return response def create_admin_request(self, method, url, data=None, **extra): @@ -209,7 +209,7 @@ def assertFormSuccess(self, request_url, response): """ Assert that the response was a redirect, not a form error. """ - self.assertIn(response.status_code, [200, 302]) + assert response.status_code in [200, 302] if response.status_code != 302: context_data = response.context_data if "errors" in context_data: @@ -219,12 +219,9 @@ def assertFormSuccess(self, request_url, response): else: raise KeyError("Unknown field for errors in the TemplateResponse!") - self.assertEqual( - response.status_code, - 302, - "Form errors in calling {}:\n{}".format(request_url, errors.as_text()), + assert response.status_code == 302, "Form errors in calling {}:\n{}".format( + request_url, errors.as_text() ) - self.assertTrue( - "/login/?next=" not in response["Location"], - f"Received login response for {request_url}", - ) + assert ( + "/login/?next=" not in response["Location"] + ), f"Received login response for {request_url}" diff --git a/polymorphic/tests/test_admin.py b/polymorphic/tests/test_admin.py index 222ee63b..4b1c3e56 100644 --- a/polymorphic/tests/test_admin.py +++ b/polymorphic/tests/test_admin.py @@ -1,3 +1,4 @@ +import pytest from django.contrib import admin from django.contrib.contenttypes.models import ContentType from django.utils.html import escape @@ -53,9 +54,9 @@ class Model2ChildAdmin(PolymorphicChildModelAdmin): ) d_obj = Model2A.objects.all()[0] - self.assertEqual(d_obj.__class__, Model2D) - self.assertEqual(d_obj.field1, "A") - self.assertEqual(d_obj.field2, "B") + assert d_obj.__class__ == Model2D + assert d_obj.field1 == "A" + assert d_obj.field2 == "B" # -- list page self.admin_get_changelist(Model2A) # asserts 200 @@ -70,10 +71,10 @@ class Model2ChildAdmin(PolymorphicChildModelAdmin): ) d_obj.refresh_from_db() - self.assertEqual(d_obj.field1, "A2") - self.assertEqual(d_obj.field2, "B2") - self.assertEqual(d_obj.field3, "C2") - self.assertEqual(d_obj.field4, "D2") + assert d_obj.field1 == "A2" + assert d_obj.field2 == "B2" + assert d_obj.field3 == "C2" + assert d_obj.field4 == "D2" # -- history self.admin_get_history(Model2A, d_obj.pk) @@ -81,7 +82,7 @@ class Model2ChildAdmin(PolymorphicChildModelAdmin): # -- delete self.admin_get_delete(Model2A, d_obj.pk) self.admin_post_delete(Model2A, d_obj.pk) - self.assertRaises(Model2A.DoesNotExist, lambda: d_obj.refresh_from_db()) + pytest.raises(Model2A.DoesNotExist, (lambda: d_obj.refresh_from_db())) def test_admin_inlines(self): """ @@ -103,7 +104,7 @@ class InlineParentAdmin(PolymorphicInlineSupportMixin, admin.ModelAdmin): inlines = (Inline,) parent = InlineParent.objects.create(title="FOO") - self.assertEqual(parent.inline_children.count(), 0) + assert parent.inline_children.count() == 0 # -- get edit page response = self.admin_get_change(InlineParent, parent.pk) @@ -133,9 +134,9 @@ class InlineParentAdmin(PolymorphicInlineSupportMixin, admin.ModelAdmin): ) parent.refresh_from_db() - self.assertEqual(parent.title, "FOO2") - self.assertEqual(parent.inline_children.count(), 1) + assert parent.title == "FOO2" + assert parent.inline_children.count() == 1 child = parent.inline_children.all()[0] - self.assertEqual(child.__class__, InlineModelB) - self.assertEqual(child.field1, "A2") - self.assertEqual(child.field2, "B2") + assert child.__class__ == InlineModelB + assert child.field1 == "A2" + assert child.field2 == "B2" diff --git a/polymorphic/tests/test_contrib.py b/polymorphic/tests/test_contrib.py index 6cb1f520..1e153df3 100644 --- a/polymorphic/tests/test_contrib.py +++ b/polymorphic/tests/test_contrib.py @@ -1,4 +1,4 @@ -from unittest import TestCase +from django.test import TestCase from polymorphic.contrib.guardian import get_polymorphic_base_content_type from polymorphic.tests.models import Model2D, PlainC @@ -13,15 +13,15 @@ def test_contrib_guardian(self): # Regular Django inheritance should return the child model content type. obj = PlainC() ctype = get_polymorphic_base_content_type(obj) - self.assertEqual(ctype.name, "plain c") + assert ctype.name == "plain c" ctype = get_polymorphic_base_content_type(PlainC) - self.assertEqual(ctype.name, "plain c") + assert ctype.name == "plain c" # Polymorphic inheritance should return the parent model content type. obj = Model2D() ctype = get_polymorphic_base_content_type(obj) - self.assertEqual(ctype.name, "model2a") + assert ctype.name == "model2a" ctype = get_polymorphic_base_content_type(Model2D) - self.assertEqual(ctype.name, "model2a") + assert ctype.name == "model2a" diff --git a/polymorphic/tests/test_multidb.py b/polymorphic/tests/test_multidb.py index 4f835f84..192bb778 100644 --- a/polymorphic/tests/test_multidb.py +++ b/polymorphic/tests/test_multidb.py @@ -26,13 +26,13 @@ def test_save_to_non_default_database(self): Model2B.objects.create(field1="B1", field2="B2") Model2D(field1="D1", field2="D2", field3="D3", field4="D4").save() - self.assertQuerysetEqual( + self.assertQuerySetEqual( Model2A.objects.order_by("id"), [Model2B, Model2D], transform=lambda o: o.__class__, ) - self.assertQuerysetEqual( + self.assertQuerySetEqual( Model2A.objects.db_manager("secondary").order_by("id"), [Model2A, Model2C], transform=lambda o: o.__class__, @@ -44,26 +44,26 @@ def test_instance_of_filter_on_non_default_database(self): ModelY.objects.db_manager("secondary").create(field_b="Y", field_y="Y") objects = Base.objects.db_manager("secondary").filter(instance_of=Base) - self.assertQuerysetEqual( + self.assertQuerySetEqual( objects, [Base, ModelX, ModelY], transform=lambda o: o.__class__, ordered=False, ) - self.assertQuerysetEqual( + self.assertQuerySetEqual( Base.objects.db_manager("secondary").filter(instance_of=ModelX), [ModelX], transform=lambda o: o.__class__, ) - self.assertQuerysetEqual( + self.assertQuerySetEqual( Base.objects.db_manager("secondary").filter(instance_of=ModelY), [ModelY], transform=lambda o: o.__class__, ) - self.assertQuerysetEqual( + self.assertQuerySetEqual( Base.objects.db_manager("secondary").filter( Q(instance_of=ModelX) | Q(instance_of=ModelY) ), @@ -78,7 +78,7 @@ def func(): entry = BlogEntry.objects.db_manager("secondary").create(blog=blog, text="Text") ContentType.objects.clear_cache() entry = BlogEntry.objects.db_manager("secondary").get(pk=entry.id) - self.assertEqual(blog, entry.blog) + assert blog == entry.blog # Ensure no queries are made using the default database. self.assertNumQueries(0, func) @@ -89,7 +89,7 @@ def func(): entry = BlogEntry.objects.db_manager("secondary").create(blog=blog, text="Text") ContentType.objects.clear_cache() blog = BlogA.objects.db_manager("secondary").get(pk=blog.id) - self.assertEqual(entry, blog.blogentry_set.using("secondary").get()) + assert entry == blog.blogentry_set.using("secondary").get() # Ensure no queries are made using the default database. self.assertNumQueries(0, func) @@ -102,7 +102,7 @@ def func(): ) ContentType.objects.clear_cache() m2a = Model2A.objects.db_manager("secondary").get(pk=m2a.id) - self.assertEqual(one2one, m2a.one2onerelatingmodel) + assert one2one == m2a.one2onerelatingmodel # Ensure no queries are made using the default database. self.assertNumQueries(0, func) @@ -114,7 +114,7 @@ def func(): rm.many2many.add(m2a) ContentType.objects.clear_cache() m2a = Model2A.objects.db_manager("secondary").get(pk=m2a.id) - self.assertEqual(rm, m2a.relatingmodel_set.using("secondary").get()) + assert rm == m2a.relatingmodel_set.using("secondary").get() # Ensure no queries are made using the default database. self.assertNumQueries(0, func) diff --git a/polymorphic/tests/test_orm.py b/polymorphic/tests/test_orm.py index 0bec30a4..61f7d242 100644 --- a/polymorphic/tests/test_orm.py +++ b/polymorphic/tests/test_orm.py @@ -1,3 +1,4 @@ +import pytest import re import uuid @@ -107,16 +108,16 @@ def test_annotate_aggregate_order(self): BlogB.objects.create(name="Bb3") qs = BlogBase.objects.annotate(entrycount=Count("BlogA___blogentry")) - self.assertEqual(len(qs), 4) + assert len(qs) == 4 for o in qs: if o.name == "B1": - self.assertEqual(o.entrycount, 2) + assert o.entrycount == 2 else: - self.assertEqual(o.entrycount, 0) + assert o.entrycount == 0 x = BlogBase.objects.aggregate(entrycount=Count("BlogA___blogentry")) - self.assertEqual(x["entrycount"], 2) + assert x["entrycount"] == 2 # create some more blogs for next test BlogA.objects.create(name="B2", info="i2") @@ -135,7 +136,7 @@ def test_annotate_aggregate_order(self): , ]""" x = "\n" + repr(BlogBase.objects.order_by("-name")) - self.assertEqual(x, expected) + assert x == expected # test ordering for field in one subclass only # MySQL and SQLite return this order @@ -161,7 +162,7 @@ def test_annotate_aggregate_order(self): ]""" x = "\n" + repr(BlogBase.objects.order_by("-BlogA___info")) - self.assertTrue(x == expected1 or x == expected2) + assert (x == expected1) or (x == expected2) def test_limit_choices_to(self): """ @@ -187,23 +188,23 @@ def test_primary_key_custom_field_problem(self): a = qs[0] b = qs[1] c = qs[2] - self.assertEqual(len(qs), 3) - self.assertIsInstance(a.uuid_primary_key, uuid.UUID) - self.assertIsInstance(a.pk, uuid.UUID) + assert len(qs) == 3 + assert isinstance(a.uuid_primary_key, uuid.UUID) + assert isinstance(a.pk, uuid.UUID) res = re.sub(' "(.*?)..", topic', ", topic", repr(qs)) res_exp = """[ , , ]""" - self.assertEqual(res, res_exp) + assert res == res_exp a = UUIDPlainA.objects.create(field1="A1") b = UUIDPlainB.objects.create(field1="B1", field2="B2") c = UUIDPlainC.objects.create(field1="C1", field2="C2", field3="C3") qs = UUIDPlainA.objects.all() # Test that primary key values are valid UUIDs - self.assertEqual(uuid.UUID("urn:uuid:%s" % a.pk, version=1), a.pk) - self.assertEqual(uuid.UUID("urn:uuid:%s" % c.pk, version=1), c.pk) + assert uuid.UUID(("urn:uuid:%s" % a.pk), version=1) == a.pk + assert uuid.UUID(("urn:uuid:%s" % c.pk), version=1) == c.pk def create_model2abcd(self): """ @@ -221,7 +222,7 @@ def test_simple_inheritance(self): self.create_model2abcd() objects = Model2A.objects.all() - self.assertQuerysetEqual( + self.assertQuerySetEqual( objects, [Model2A, Model2B, Model2C, Model2D], transform=lambda o: o.__class__, @@ -233,78 +234,68 @@ def test_defer_fields(self): objects_deferred = Model2A.objects.defer("field1").order_by("id") - self.assertNotIn( - "field1", - objects_deferred[0].__dict__, - "field1 was not deferred (using defer())", - ) + assert ( + "field1" not in objects_deferred[0].__dict__ + ), "field1 was not deferred (using defer())" # Check that we have exactly one deferred field ('field1') per resulting object. for obj in objects_deferred: deferred_fields = obj.get_deferred_fields() - self.assertEqual(1, len(deferred_fields)) - self.assertIn("field1", deferred_fields) + assert len(deferred_fields) == 1 + assert "field1" in deferred_fields objects_only = Model2A.objects.only("pk", "polymorphic_ctype", "field1") - self.assertIn( - "field1", - objects_only[0].__dict__, - 'qs.only("field1") was used, but field1 was incorrectly deferred', - ) - self.assertIn( - "field1", - objects_only[3].__dict__, - 'qs.only("field1") was used, but field1 was incorrectly deferred' " on a child model", - ) - self.assertNotIn( - "field4", objects_only[3].__dict__, "field4 was not deferred (using only())" - ) - self.assertNotIn("field1", objects_only[0].get_deferred_fields()) + assert ( + "field1" in objects_only[0].__dict__ + ), 'qs.only("field1") was used, but field1 was incorrectly deferred' + assert ( + "field1" in objects_only[3].__dict__ + ), 'qs.only("field1") was used, but field1 was incorrectly deferred on a child model' + assert "field4" not in objects_only[3].__dict__, "field4 was not deferred (using only())" + assert "field1" not in objects_only[0].get_deferred_fields() - self.assertIn("field2", objects_only[1].get_deferred_fields()) + assert "field2" in objects_only[1].get_deferred_fields() # objects_only[2] has several deferred fields, ensure they are all set as such. model2c_deferred = objects_only[2].get_deferred_fields() - self.assertIn("field2", model2c_deferred) - self.assertIn("field3", model2c_deferred) - self.assertIn("model2a_ptr_id", model2c_deferred) + assert "field2" in model2c_deferred + assert "field3" in model2c_deferred + assert "model2a_ptr_id" in model2c_deferred # objects_only[3] has a few more fields that should be set as deferred. model2d_deferred = objects_only[3].get_deferred_fields() - self.assertIn("field2", model2d_deferred) - self.assertIn("field3", model2d_deferred) - self.assertIn("field4", model2d_deferred) - self.assertIn("model2a_ptr_id", model2d_deferred) - self.assertIn("model2b_ptr_id", model2d_deferred) + assert "field2" in model2d_deferred + assert "field3" in model2d_deferred + assert "field4" in model2d_deferred + assert "model2a_ptr_id" in model2d_deferred + assert "model2b_ptr_id" in model2d_deferred ModelX.objects.create(field_b="A1", field_x="A2") ModelY.objects.create(field_b="B1", field_y="B2") # If we defer a field on a descendent, the parent's field is not deferred. objects_deferred = Base.objects.defer("ModelY___field_y") - self.assertNotIn("field_y", objects_deferred[0].get_deferred_fields()) - self.assertIn("field_y", objects_deferred[1].get_deferred_fields()) + assert "field_y" not in objects_deferred[0].get_deferred_fields() + assert "field_y" in objects_deferred[1].get_deferred_fields() objects_only = Base.objects.only( "polymorphic_ctype", "ModelY___field_y", "ModelX___field_x" ) - self.assertIn("field_b", objects_only[0].get_deferred_fields()) - self.assertIn("field_b", objects_only[1].get_deferred_fields()) + assert "field_b" in objects_only[0].get_deferred_fields() + assert "field_b" in objects_only[1].get_deferred_fields() def test_defer_related_fields(self): self.create_model2abcd() objects_deferred_field4 = Model2A.objects.defer("Model2D___field4") - self.assertNotIn( - "field4", - objects_deferred_field4[3].__dict__, - "field4 was not deferred (using defer(), traversing inheritance)", - ) - self.assertEqual(objects_deferred_field4[0].__class__, Model2A) - self.assertEqual(objects_deferred_field4[1].__class__, Model2B) - self.assertEqual(objects_deferred_field4[2].__class__, Model2C) - self.assertEqual(objects_deferred_field4[3].__class__, Model2D) + assert ( + "field4" not in objects_deferred_field4[3].__dict__ + ), "field4 was not deferred (using defer(), traversing inheritance)" + assert objects_deferred_field4[0].__class__ == Model2A + assert objects_deferred_field4[1].__class__ == Model2B + assert objects_deferred_field4[2].__class__ == Model2C + assert objects_deferred_field4[3].__class__ == Model2D objects_only_field4 = Model2A.objects.only( "polymorphic_ctype", @@ -318,22 +309,22 @@ def test_defer_related_fields(self): "Model2D___id", "Model2D___model2c_ptr", ) - self.assertEqual(objects_only_field4[0].__class__, Model2A) - self.assertEqual(objects_only_field4[1].__class__, Model2B) - self.assertEqual(objects_only_field4[2].__class__, Model2C) - self.assertEqual(objects_only_field4[3].__class__, Model2D) + assert objects_only_field4[0].__class__ == Model2A + assert objects_only_field4[1].__class__ == Model2B + assert objects_only_field4[2].__class__ == Model2C + assert objects_only_field4[3].__class__ == Model2D def test_manual_get_real_instance(self): self.create_model2abcd() o = Model2A.objects.non_polymorphic().get(field1="C1") - self.assertEqual(o.get_real_instance().__class__, Model2C) + assert o.get_real_instance().__class__ == Model2C def test_non_polymorphic(self): self.create_model2abcd() objects = list(Model2A.objects.all().non_polymorphic()) - self.assertQuerysetEqual( + self.assertQuerySetEqual( objects, [Model2A, Model2A, Model2A, Model2A], transform=lambda o: o.__class__, @@ -345,7 +336,7 @@ def test_get_real_instances(self): # from queryset objects = qs.get_real_instances() - self.assertQuerysetEqual( + self.assertQuerySetEqual( objects, [Model2A, Model2B, Model2C, Model2D], transform=lambda o: o.__class__, @@ -353,7 +344,7 @@ def test_get_real_instances(self): # from a manual list objects = Model2A.objects.get_real_instances(list(qs)) - self.assertQuerysetEqual( + self.assertQuerySetEqual( objects, [Model2A, Model2B, Model2C, Model2D], transform=lambda o: o.__class__, @@ -361,7 +352,7 @@ def test_get_real_instances(self): # from empty list objects = Model2A.objects.get_real_instances([]) - self.assertQuerysetEqual(objects, [], transform=lambda o: o.__class__) + self.assertQuerySetEqual(objects, [], transform=lambda o: o.__class__) def test_queryset_missing_derived(self): a = Model2A.objects.create(field1="A1") @@ -375,8 +366,8 @@ def test_queryset_missing_derived(self): qs_base = Model2A.objects.order_by("field1").non_polymorphic() qs_polymorphic = Model2A.objects.order_by("field1").all() - self.assertEqual(list(qs_base), [a, b_base, c_base]) - self.assertEqual(list(qs_polymorphic), [a, c]) + assert list(qs_base) == [a, b_base, c_base] + assert list(qs_polymorphic) == [a, c] def test_queryset_missing_contenttype(self): stale_ct = ContentType.objects.create(app_label="tests", model="nonexisting") @@ -390,15 +381,15 @@ def test_queryset_missing_contenttype(self): qs_base = Model2A.objects.order_by("field1").non_polymorphic() qs_polymorphic = Model2A.objects.order_by("field1").all() - self.assertEqual(list(qs_base), [a1, a2, c_base]) - self.assertEqual(list(qs_polymorphic), [a1, a2, c]) + assert list(qs_base) == [a1, a2, c_base] + assert list(qs_polymorphic) == [a1, a2, c] def test_translate_polymorphic_q_object(self): self.create_model2abcd() q = Model2A.translate_polymorphic_Q_object(Q(instance_of=Model2C)) objects = Model2A.objects.filter(q) - self.assertQuerysetEqual( + self.assertQuerySetEqual( objects, [Model2C, Model2D], transform=lambda o: o.__class__, ordered=False ) @@ -407,27 +398,24 @@ def test_create_instanceof_q(self): expected = sorted( ContentType.objects.get_for_model(m).pk for m in [Model2B, Model2C, Model2D] ) - self.assertEqual(dict(q.children), dict(polymorphic_ctype__in=expected)) + assert dict(q.children) == dict(polymorphic_ctype__in=expected) def test_base_manager(self): def base_manager(model): return (type(model._base_manager), model._base_manager.model) - self.assertEqual(base_manager(PlainA), (models.Manager, PlainA)) - self.assertEqual(base_manager(PlainB), (models.Manager, PlainB)) - self.assertEqual(base_manager(PlainC), (models.Manager, PlainC)) + assert base_manager(PlainA) == (models.Manager, PlainA) + assert base_manager(PlainB) == (models.Manager, PlainB) + assert base_manager(PlainC) == (models.Manager, PlainC) - self.assertEqual(base_manager(Model2A), (PolymorphicManager, Model2A)) - self.assertEqual(base_manager(Model2B), (PolymorphicManager, Model2B)) - self.assertEqual(base_manager(Model2C), (PolymorphicManager, Model2C)) + assert base_manager(Model2A) == (PolymorphicManager, Model2A) + assert base_manager(Model2B) == (PolymorphicManager, Model2B) + assert base_manager(Model2C) == (PolymorphicManager, Model2C) - self.assertEqual( - base_manager(One2OneRelatingModel), - (PolymorphicManager, One2OneRelatingModel), - ) - self.assertEqual( - base_manager(One2OneRelatingModelDerived), - (PolymorphicManager, One2OneRelatingModelDerived), + assert base_manager(One2OneRelatingModel) == (PolymorphicManager, One2OneRelatingModel) + assert base_manager(One2OneRelatingModelDerived) == ( + PolymorphicManager, + One2OneRelatingModelDerived, ) def test_instance_default_manager(self): @@ -445,22 +433,22 @@ def default_manager(instance): model_2b = Model2B(field2="C1") model_2c = Model2C(field3="C1") - self.assertEqual(default_manager(plain_a), (models.Manager, PlainA)) - self.assertEqual(default_manager(plain_b), (models.Manager, PlainB)) - self.assertEqual(default_manager(plain_c), (models.Manager, PlainC)) + assert default_manager(plain_a) == (models.Manager, PlainA) + assert default_manager(plain_b) == (models.Manager, PlainB) + assert default_manager(plain_c) == (models.Manager, PlainC) - self.assertEqual(default_manager(model_2a), (PolymorphicManager, Model2A)) - self.assertEqual(default_manager(model_2b), (PolymorphicManager, Model2B)) - self.assertEqual(default_manager(model_2c), (PolymorphicManager, Model2C)) + assert default_manager(model_2a) == (PolymorphicManager, Model2A) + assert default_manager(model_2b) == (PolymorphicManager, Model2B) + assert default_manager(model_2c) == (PolymorphicManager, Model2C) def test_foreignkey_field(self): self.create_model2abcd() object2a = Model2A.objects.get(field1="C1") - self.assertEqual(object2a.model2b.__class__, Model2B) + assert object2a.model2b.__class__ == Model2B object2b = Model2B.objects.get(field1="C1") - self.assertEqual(object2b.model2c.__class__, Model2C) + assert object2b.model2c.__class__ == Model2C def test_onetoone_field(self): self.create_model2abcd() @@ -470,57 +458,54 @@ def test_onetoone_field(self): b = One2OneRelatingModelDerived.objects.create(one2one=a, field1="f1", field2="f2") # FIXME: this result is basically wrong, probably due to Django cacheing (we used base_objects), but should not be a problem - self.assertEqual(b.one2one.__class__, Model2A) - self.assertEqual(b.one2one_id, b.one2one.id) + assert b.one2one.__class__ == Model2A + assert b.one2one_id == b.one2one.id c = One2OneRelatingModelDerived.objects.get(field1="f1") - self.assertEqual(c.one2one.__class__, Model2C) - self.assertEqual(a.one2onerelatingmodel.__class__, One2OneRelatingModelDerived) + assert c.one2one.__class__ == Model2C + assert a.one2onerelatingmodel.__class__ == One2OneRelatingModelDerived def test_manytomany_field(self): # Model 1 o = ModelShow1.objects.create(field1="abc") o.m2m.add(o) o.save() - self.assertEqual( - repr(ModelShow1.objects.all()), - "[ ]", + assert ( + repr(ModelShow1.objects.all()) + == "[ ]" ) # Model 2 o = ModelShow2.objects.create(field1="abc") o.m2m.add(o) o.save() - self.assertEqual( - repr(ModelShow2.objects.all()), - '[ ]', - ) + assert repr(ModelShow2.objects.all()) == '[ ]' # Model 3 o = ModelShow3.objects.create(field1="abc") o.m2m.add(o) o.save() - self.assertEqual( - repr(ModelShow3.objects.all()), - '[ ]', + assert ( + repr(ModelShow3.objects.all()) + == '[ ]' ) - self.assertEqual( - repr(ModelShow1.objects.all().annotate(Count("m2m"))), - "[ ]", + assert ( + repr(ModelShow1.objects.all().annotate(Count("m2m"))) + == "[ ]" ) - self.assertEqual( - repr(ModelShow2.objects.all().annotate(Count("m2m"))), - '[ ]', + assert ( + repr(ModelShow2.objects.all().annotate(Count("m2m"))) + == '[ ]' ) - self.assertEqual( - repr(ModelShow3.objects.all().annotate(Count("m2m"))), - '[ ]', + assert ( + repr(ModelShow3.objects.all().annotate(Count("m2m"))) + == '[ ]' ) # no pretty printing ModelShow1_plain.objects.create(field1="abc") ModelShow2_plain.objects.create(field1="abc", field2="def") - self.assertQuerysetEqual( + self.assertQuerySetEqual( ModelShow1_plain.objects.all(), [ModelShow1_plain, ModelShow2_plain], transform=lambda o: o.__class__, @@ -531,7 +516,7 @@ def test_extra_method(self): a, b, c, d = self.create_model2abcd() objects = Model2A.objects.extra(where=[f"id IN ({b.id}, {c.id})"]) - self.assertQuerysetEqual( + self.assertQuerySetEqual( objects, [Model2B, Model2C], transform=lambda o: o.__class__, ordered=False ) @@ -540,7 +525,7 @@ def test_extra_method(self): where=["field1 = 'A1' OR field1 = 'B1'"], order_by=["-id"], ) - self.assertQuerysetEqual(objects, [Model2B, Model2A], transform=lambda o: o.__class__) + self.assertQuerySetEqual(objects, [Model2B, Model2A], transform=lambda o: o.__class__) ModelExtraA.objects.create(field1="A1") ModelExtraB.objects.create(field1="B1", field2="B2") @@ -553,25 +538,25 @@ def test_extra_method(self): select={"topic": "tests_modelextraexternal.topic"}, where=["tests_modelextraa.id = tests_modelextraexternal.id"], ) - self.assertEqual( - repr(objects[0]), - '', + assert ( + repr(objects[0]) + == '' ) - self.assertEqual( - repr(objects[1]), - '', + assert ( + repr(objects[1]) + == '' ) - self.assertEqual( - repr(objects[2]), - '', + assert ( + repr(objects[2]) + == '' ) - self.assertEqual(len(objects), 3) + assert len(objects) == 3 def test_instance_of_filter(self): self.create_model2abcd() objects = Model2A.objects.instance_of(Model2B) - self.assertQuerysetEqual( + self.assertQuerySetEqual( objects, [Model2B, Model2C, Model2D], transform=lambda o: o.__class__, @@ -579,7 +564,7 @@ def test_instance_of_filter(self): ) objects = Model2A.objects.filter(instance_of=Model2B) - self.assertQuerysetEqual( + self.assertQuerySetEqual( objects, [Model2B, Model2C, Model2D], transform=lambda o: o.__class__, @@ -587,7 +572,7 @@ def test_instance_of_filter(self): ) objects = Model2A.objects.filter(Q(instance_of=Model2B)) - self.assertQuerysetEqual( + self.assertQuerySetEqual( objects, [Model2B, Model2C, Model2D], transform=lambda o: o.__class__, @@ -595,7 +580,7 @@ def test_instance_of_filter(self): ) objects = Model2A.objects.not_instance_of(Model2B) - self.assertQuerysetEqual( + self.assertQuerySetEqual( objects, [Model2A], transform=lambda o: o.__class__, ordered=False ) @@ -603,7 +588,7 @@ def test_polymorphic___filter(self): self.create_model2abcd() objects = Model2A.objects.filter(Q(Model2B___field2="B2") | Q(Model2C___field3="C3")) - self.assertQuerysetEqual( + self.assertQuerySetEqual( objects, [Model2B, Model2C], transform=lambda o: o.__class__, ordered=False ) @@ -614,7 +599,7 @@ def test_polymorphic_applabel___filter(self): objects = Model2A.objects.filter( Q(tests__Model2B___field2="B2") | Q(tests__Model2C___field3="C3") ) - self.assertQuerysetEqual( + self.assertQuerySetEqual( objects, [Model2B, Model2C], transform=lambda o: o.__class__, ordered=False ) @@ -625,7 +610,7 @@ def test_query_filter_exclude_is_immutable(self): # when Model2A.objects.filter(q_to_reuse).all() # then - self.assertEqual(q_to_reuse.children, untouched_q_object.children) + assert q_to_reuse.children == untouched_q_object.children # given q_to_reuse = Q(Model2B___field2="something") @@ -633,7 +618,7 @@ def test_query_filter_exclude_is_immutable(self): # when Model2B.objects.filter(q_to_reuse).all() # then - self.assertEqual(q_to_reuse.children, untouched_q_object.children) + assert q_to_reuse.children == untouched_q_object.children def test_polymorphic___filter_field(self): p = ModelUnderRelParent.objects.create(_private=True, field1="AA") @@ -641,7 +626,7 @@ def test_polymorphic___filter_field(self): # The "___" filter should also parse to "parent" -> "_private" as fallback. objects = ModelUnderRelChild.objects.filter(parent___private=True) - self.assertEqual(len(objects), 1) + assert len(objects) == 1 def test_polymorphic___filter_reverse_field(self): p = ModelUnderRelParent.objects.create(_private=True, field1="BB") @@ -649,18 +634,18 @@ def test_polymorphic___filter_reverse_field(self): # Also test for reverse relations objects = ModelUnderRelParent.objects.filter(children___private2=True) - self.assertEqual(len(objects), 1) + assert len(objects) == 1 def test_delete(self): a, b, c, d = self.create_model2abcd() oa = Model2A.objects.get(id=b.id) - self.assertEqual(oa.__class__, Model2B) - self.assertEqual(Model2A.objects.count(), 4) + assert oa.__class__ == Model2B + assert Model2A.objects.count() == 4 oa.delete() objects = Model2A.objects.all() - self.assertQuerysetEqual( + self.assertQuerySetEqual( objects, [Model2A, Model2C, Model2D], transform=lambda o: o.__class__, @@ -673,9 +658,9 @@ def test_combine_querysets(self): qs = Base.objects.instance_of(ModelX) | Base.objects.instance_of(ModelY) qs = qs.order_by("field_b") - self.assertEqual(repr(qs[0]), "") - self.assertEqual(repr(qs[1]), "") - self.assertEqual(len(qs), 2) + assert repr(qs[0]) == "" + assert repr(qs[1]) == "" + assert len(qs) == 2 def test_multiple_inheritance(self): # multiple inheritance, subclassing third party models (mix PolymorphicModel with models.Model) @@ -684,14 +669,13 @@ def test_multiple_inheritance(self): Enhance_Inherit.objects.create(field_b="b-inherit", field_p="p", field_i="i") qs = Enhance_Base.objects.all() - self.assertEqual(len(qs), 2) - self.assertEqual( - repr(qs[0]), - '', + assert len(qs) == 2 + assert ( + repr(qs[0]) == '' ) - self.assertEqual( - repr(qs[1]), - '', + assert ( + repr(qs[1]) + == '' ) def test_relation_base(self): @@ -704,58 +688,58 @@ def test_relation_base(self): oa.m2m.add(ob) objects = RelationBase.objects.all() - self.assertEqual( - repr(objects[0]), - '', + assert ( + repr(objects[0]) + == '' ) - self.assertEqual( - repr(objects[1]), - '', + assert ( + repr(objects[1]) + == '' ) - self.assertEqual( - repr(objects[2]), - '', + assert ( + repr(objects[2]) + == '' ) - self.assertEqual( - repr(objects[3]), - '', + assert ( + repr(objects[3]) + == '' ) - self.assertEqual(len(objects), 4) + assert len(objects) == 4 oa = RelationBase.objects.get(id=2) - self.assertEqual( - repr(oa.fk), - '', + assert ( + repr(oa.fk) + == '' ) objects = oa.relationbase_set.all() - self.assertEqual( - repr(objects[0]), - '', + assert ( + repr(objects[0]) + == '' ) - self.assertEqual( - repr(objects[1]), - '', + assert ( + repr(objects[1]) + == '' ) - self.assertEqual(len(objects), 2) + assert len(objects) == 2 ob = RelationBase.objects.get(id=3) - self.assertEqual( - repr(ob.fk), - '', + assert ( + repr(ob.fk) + == '' ) oa = RelationA.objects.get() objects = oa.m2m.all() - self.assertEqual( - repr(objects[0]), - '', + assert ( + repr(objects[0]) + == '' ) - self.assertEqual( - repr(objects[1]), - '', + assert ( + repr(objects[1]) + == '' ) - self.assertEqual(len(objects), 2) + assert len(objects) == 2 def test_user_defined_manager(self): self.create_model2abcd() @@ -764,14 +748,14 @@ def test_user_defined_manager(self): # MyManager should reverse the sorting of field1 objects = ModelWithMyManager.objects.all() - self.assertQuerysetEqual( + self.assertQuerySetEqual( objects, [(ModelWithMyManager, "D1b", "D4b"), (ModelWithMyManager, "D1a", "D4a")], transform=lambda o: (o.__class__, o.field1, o.field4), ) - self.assertIs(type(ModelWithMyManager.objects), MyManager) - self.assertIs(type(ModelWithMyManager._default_manager), MyManager) + assert type(ModelWithMyManager.objects) is MyManager + assert type(ModelWithMyManager._default_manager) is MyManager def test_user_defined_manager_as_secondary(self): self.create_model2abcd() @@ -780,7 +764,7 @@ def test_user_defined_manager_as_secondary(self): # MyManager should reverse the sorting of field1 objects = ModelWithMyManagerNoDefault.my_objects.all() - self.assertQuerysetEqual( + self.assertQuerySetEqual( objects, [ (ModelWithMyManagerNoDefault, "D1b", "D4b"), @@ -789,18 +773,18 @@ def test_user_defined_manager_as_secondary(self): transform=lambda o: (o.__class__, o.field1, o.field4), ) - self.assertIs(type(ModelWithMyManagerNoDefault.my_objects), MyManager) - self.assertIs(type(ModelWithMyManagerNoDefault.objects), PolymorphicManager) - self.assertIs(type(ModelWithMyManagerNoDefault._default_manager), PolymorphicManager) + assert type(ModelWithMyManagerNoDefault.my_objects) is MyManager + assert type(ModelWithMyManagerNoDefault.objects) is PolymorphicManager + assert type(ModelWithMyManagerNoDefault._default_manager) is PolymorphicManager def test_user_objects_manager_as_secondary(self): self.create_model2abcd() ModelWithMyManagerDefault.objects.create(field1="D1a", field4="D4a") ModelWithMyManagerDefault.objects.create(field1="D1b", field4="D4b") - self.assertIs(type(ModelWithMyManagerDefault.my_objects), MyManager) - self.assertIs(type(ModelWithMyManagerDefault.objects), PolymorphicManager) - self.assertIs(type(ModelWithMyManagerDefault._default_manager), MyManager) + assert type(ModelWithMyManagerDefault.my_objects) is MyManager + assert type(ModelWithMyManagerDefault.objects) is PolymorphicManager + assert type(ModelWithMyManagerDefault._default_manager) is MyManager def test_user_defined_queryset_as_manager(self): self.create_model2abcd() @@ -808,48 +792,47 @@ def test_user_defined_queryset_as_manager(self): ModelWithMyManager2.objects.create(field1="D1b", field4="D4b") objects = ModelWithMyManager2.objects.all() - self.assertQuerysetEqual( + self.assertQuerySetEqual( objects, [(ModelWithMyManager2, "D1a", "D4a"), (ModelWithMyManager2, "D1b", "D4b")], transform=lambda o: (o.__class__, o.field1, o.field4), ordered=False, ) - self.assertEqual( - type(ModelWithMyManager2.objects).__name__, - "PolymorphicManagerFromMyManagerQuerySet", + assert ( + type(ModelWithMyManager2.objects).__name__ == "PolymorphicManagerFromMyManagerQuerySet" ) - self.assertEqual( - type(ModelWithMyManager2._default_manager).__name__, - "PolymorphicManagerFromMyManagerQuerySet", + assert ( + type(ModelWithMyManager2._default_manager).__name__ + == "PolymorphicManagerFromMyManagerQuerySet" ) def test_manager_inheritance(self): # by choice of MRO, should be MyManager from MROBase1. - self.assertIs(type(MRODerived.objects), MyManager) + assert type(MRODerived.objects) is MyManager def test_queryset_assignment(self): # This is just a consistency check for now, testing standard Django behavior. parent = PlainParentModelWithManager.objects.create() child = PlainChildModelWithManager.objects.create(fk=parent) - self.assertIs(type(PlainParentModelWithManager._default_manager), models.Manager) - self.assertIs(type(PlainChildModelWithManager._default_manager), PlainMyManager) - self.assertIs(type(PlainChildModelWithManager.objects), PlainMyManager) - self.assertIs(type(PlainChildModelWithManager.objects.all()), PlainMyManagerQuerySet) + assert type(PlainParentModelWithManager._default_manager) is models.Manager + assert type(PlainChildModelWithManager._default_manager) is PlainMyManager + assert type(PlainChildModelWithManager.objects) is PlainMyManager + assert type(PlainChildModelWithManager.objects.all()) is PlainMyManagerQuerySet # A related set is created using the model's _default_manager, so does gain extra methods. - self.assertIs(type(parent.childmodel_set.my_queryset_foo()), PlainMyManagerQuerySet) + assert type(parent.childmodel_set.my_queryset_foo()) is PlainMyManagerQuerySet # For polymorphic models, the same should happen. parent = ParentModelWithManager.objects.create() child = ChildModelWithManager.objects.create(fk=parent) - self.assertIs(type(ParentModelWithManager._default_manager), PolymorphicManager) - self.assertIs(type(ChildModelWithManager._default_manager), MyManager) - self.assertIs(type(ChildModelWithManager.objects), MyManager) - self.assertIs(type(ChildModelWithManager.objects.my_queryset_foo()), MyManagerQuerySet) + assert type(ParentModelWithManager._default_manager) is PolymorphicManager + assert type(ChildModelWithManager._default_manager) is MyManager + assert type(ChildModelWithManager.objects) is MyManager + assert type(ChildModelWithManager.objects.my_queryset_foo()) is MyManagerQuerySet # A related set is created using the model's _default_manager, so does gain extra methods. - self.assertIs(type(parent.childmodel_set.my_queryset_foo()), MyManagerQuerySet) + assert type(parent.childmodel_set.my_queryset_foo()) is MyManagerQuerySet def test_proxy_models(self): # prepare some data @@ -863,7 +846,7 @@ def test_proxy_models(self): with self.assertNumQueries(1): items = list(ProxyBase.objects.all()) - self.assertIsInstance(items[0], ProxyChild) + assert isinstance(items[0], ProxyChild) def test_queryset_on_proxy_model_does_not_return_superclasses(self): ProxyBase.objects.create(some_data="Base1") @@ -872,8 +855,8 @@ def test_queryset_on_proxy_model_does_not_return_superclasses(self): ProxyChild.objects.create(some_data="Child2") ProxyChild.objects.create(some_data="Child3") - self.assertEqual(5, ProxyBase.objects.count()) - self.assertEqual(3, ProxyChild.objects.count()) + assert ProxyBase.objects.count() == 5 + assert ProxyChild.objects.count() == 3 def test_proxy_get_real_instance_class(self): """ @@ -885,21 +868,21 @@ def test_proxy_get_real_instance_class(self): nonproxychild = NonProxyChild.objects.create(name=name) pb = ProxyBase.objects.get(id=1) - self.assertEqual(pb.get_real_instance_class(), NonProxyChild) - self.assertEqual(pb.get_real_instance(), nonproxychild) - self.assertEqual(pb.name, name) + assert pb.get_real_instance_class() == NonProxyChild + assert pb.get_real_instance() == nonproxychild + assert pb.name == name pbm = NonProxyChild.objects.get(id=1) - self.assertEqual(pbm.get_real_instance_class(), NonProxyChild) - self.assertEqual(pbm.get_real_instance(), nonproxychild) - self.assertEqual(pbm.name, name) + assert pbm.get_real_instance_class() == NonProxyChild + assert pbm.get_real_instance() == nonproxychild + assert pbm.name == name def test_content_types_for_proxy_models(self): """Checks if ContentType is capable of returning proxy models.""" from django.contrib.contenttypes.models import ContentType ct = ContentType.objects.get_for_model(ProxyChild, for_concrete_model=False) - self.assertEqual(ProxyChild, ct.model_class()) + assert ProxyChild == ct.model_class() def test_proxy_model_inheritance(self): """ @@ -907,10 +890,10 @@ def test_proxy_model_inheritance(self): """ # The managers should point to the proper objects. # otherwise, the whole excersise is pointless. - self.assertEqual(ProxiedBase.objects.model, ProxiedBase) - self.assertEqual(ProxyModelBase.objects.model, ProxyModelBase) - self.assertEqual(ProxyModelA.objects.model, ProxyModelA) - self.assertEqual(ProxyModelB.objects.model, ProxyModelB) + assert ProxiedBase.objects.model == ProxiedBase + assert ProxyModelBase.objects.model == ProxyModelBase + assert ProxyModelA.objects.model == ProxyModelA + assert ProxyModelB.objects.model == ProxyModelB # Create objects object1_pk = ProxyModelA.objects.create(name="object1").pk @@ -919,52 +902,48 @@ def test_proxy_model_inheritance(self): # Getting single objects object1 = ProxyModelBase.objects.get(name="object1") object2 = ProxyModelBase.objects.get(name="object2") - self.assertEqual( - repr(object1), - '' % object1_pk, + assert repr(object1) == ( + '' % object1_pk ) - self.assertEqual( - repr(object2), + assert repr(object2) == ( '' - % object2_pk, + % object2_pk ) - self.assertIsInstance(object1, ProxyModelA) - self.assertIsInstance(object2, ProxyModelB) + assert isinstance(object1, ProxyModelA) + assert isinstance(object2, ProxyModelB) # Same for lists objects = list(ProxyModelBase.objects.all().order_by("name")) - self.assertEqual( - repr(objects[0]), - '' % object1_pk, + assert repr(objects[0]) == ( + '' % object1_pk ) - self.assertEqual( - repr(objects[1]), + assert repr(objects[1]) == ( '' - % object2_pk, + % object2_pk ) - self.assertIsInstance(objects[0], ProxyModelA) - self.assertIsInstance(objects[1], ProxyModelB) + assert isinstance(objects[0], ProxyModelA) + assert isinstance(objects[1], ProxyModelB) def test_custom_pk(self): CustomPkBase.objects.create(b="b") CustomPkInherit.objects.create(b="b", i="i") qs = CustomPkBase.objects.all() - self.assertEqual(len(qs), 2) - self.assertEqual(repr(qs[0]), '') - self.assertEqual( - repr(qs[1]), - '', + assert len(qs) == 2 + assert repr(qs[0]) == '' + assert ( + repr(qs[1]) + == '' ) def test_fix_getattribute(self): # fixed issue in PolymorphicModel.__getattribute__: field name same as model name o = ModelFieldNameTest.objects.create(modelfieldnametest="1") - self.assertEqual(repr(o), "") + assert repr(o) == "" # if subclass defined __init__ and accessed class members, # __getattribute__ had a problem: "...has no attribute 'sub_and_superclass_dict'" o = InitTestModelSubclass.objects.create() - self.assertEqual(o.bar, "XYZ") + assert o.bar == "XYZ" def test_parent_link_and_related_name(self): t = TestParentLinkAndRelatedName(field1="TestParentLinkAndRelatedName") @@ -972,15 +951,15 @@ def test_parent_link_and_related_name(self): p = ModelShow1_plain.objects.get(field1="TestParentLinkAndRelatedName") # check that p is equal to the - self.assertIsInstance(p, TestParentLinkAndRelatedName) - self.assertEqual(p, t) + assert isinstance(p, TestParentLinkAndRelatedName) + assert p == t # check that the accessors to parent and sublass work correctly and return the right object p = ModelShow1_plain.objects.non_polymorphic().get(field1="TestParentLinkAndRelatedName") # p should be Plain1 and t TestParentLinkAndRelatedName, so not equal - self.assertNotEqual(p, t) - self.assertEqual(p, t.superclass) - self.assertEqual(p.related_name_subclass, t) + assert p != t + assert p == t.superclass + assert p.related_name_subclass == t # test that we can delete the object t.delete() @@ -994,14 +973,14 @@ def test_polymorphic__aggregate(self): # aggregate using **kwargs result = Model2A.objects.aggregate(cnt=Count("Model2B___field2")) - self.assertEqual(result, {"cnt": 2}) + assert result == {"cnt": 2} # aggregate using **args - self.assertRaisesMessage( + with pytest.raises( AssertionError, - "PolymorphicModel: annotate()/aggregate(): ___ model lookup supported for keyword arguments only", - lambda: Model2A.objects.aggregate(Count("Model2B___field2")), - ) + match="model lookup supported for keyword arguments only", + ): + Model2A.objects.aggregate(Count("Model2B___field2")) def test_polymorphic__complex_aggregate(self): """test (complex expression on) aggregate (should work for annotate either)""" @@ -1015,7 +994,7 @@ def test_polymorphic__complex_aggregate(self): cnt_a1=Count(Case(When(field1="A1", then=1))), cnt_b2=Count(Case(When(Model2B___field2="B2", then=1))), ) - self.assertEqual(result, {"cnt_b2": 2, "cnt_a1": 3}) + assert result == {"cnt_b2": 2, "cnt_a1": 3} # aggregate using **args # we have to set the defaul alias or django won't except a complex expression @@ -1025,9 +1004,9 @@ def ComplexAgg(expression): complexagg.default_alias = "complexagg" return complexagg - with self.assertRaisesMessage( + with pytest.raises( AssertionError, - "PolymorphicModel: annotate()/aggregate(): ___ model lookup supported for keyword arguments only", + match="model lookup supported for keyword arguments only", ): Model2A.objects.aggregate(ComplexAgg("Model2B___field2")) @@ -1049,39 +1028,39 @@ def test_polymorphic__filtered_relation(self): "blogentry", condition=Q(blogentry__text__contains="joined") ), ).aggregate(Count("text_joined")) - self.assertEqual(result, {"text_joined__count": 3}) + assert result == {"text_joined__count": 3} result = BlogA.objects.annotate( text_joined=FilteredRelation( "blogentry", condition=Q(blogentry__text__contains="joined") ), ).aggregate(count=Count("text_joined")) - self.assertEqual(result, {"count": 3}) + assert result == {"count": 3} result = BlogBase.objects.annotate( info_joined=FilteredRelation("bloga", condition=Q(BlogA___info__contains="joined")), ).aggregate(Count("info_joined")) - self.assertEqual(result, {"info_joined__count": 2}) + assert result == {"info_joined__count": 2} result = BlogBase.objects.annotate( info_joined=FilteredRelation("bloga", condition=Q(BlogA___info__contains="joined")), ).aggregate(count=Count("info_joined")) - self.assertEqual(result, {"count": 2}) + assert result == {"count": 2} # We should get a BlogA and a BlogB result = BlogBase.objects.annotate( info_joined=FilteredRelation("bloga", condition=Q(BlogA___info__contains="joined")), ).filter(info_joined__isnull=True) - self.assertEqual(result.count(), 2) - self.assertIsInstance(result.first(), BlogA) - self.assertIsInstance(result.last(), BlogB) + assert result.count() == 2 + assert isinstance(result.first(), BlogA) + assert isinstance(result.last(), BlogB) def test_polymorphic__expressions(self): from django.db.models.functions import Concat # no exception raised result = Model2B.objects.annotate(val=Concat("field1", "field2")) - self.assertEqual(list(result), []) + assert list(result) == [] def test_null_polymorphic_id(self): """Test that a proper error message is displayed when the database lacks the ``polymorphic_ctype_id``""" @@ -1090,7 +1069,7 @@ def test_null_polymorphic_id(self): Model2B.objects.create(field1="A1", field2="B2") Model2A.objects.all().update(polymorphic_ctype_id=None) - with self.assertRaises(PolymorphicTypeUndefined): + with pytest.raises(PolymorphicTypeUndefined): list(Model2A.objects.all()) def test_invalid_polymorphic_id(self): @@ -1101,7 +1080,7 @@ def test_invalid_polymorphic_id(self): invalid = ContentType.objects.get_for_model(PlainA).pk Model2A.objects.all().update(polymorphic_ctype_id=invalid) - with self.assertRaises(PolymorphicTypeInvalid): + with pytest.raises(PolymorphicTypeInvalid): list(Model2A.objects.all()) def test_bulk_create_abstract_inheritance(self): @@ -1111,10 +1090,10 @@ def test_bulk_create_abstract_inheritance(self): ArtProject(topic="Sculpture with Tim", artist="T. Turner"), ] ) - self.assertEqual( - sorted(ArtProject.objects.values_list("topic", "artist")), - [("Painting with Tim", "T. Turner"), ("Sculpture with Tim", "T. Turner")], - ) + assert sorted(ArtProject.objects.values_list("topic", "artist")) == [ + ("Painting with Tim", "T. Turner"), + ("Sculpture with Tim", "T. Turner"), + ] def test_bulk_create_proxy_inheritance(self): RedheadDuck.objects.bulk_create( @@ -1131,28 +1110,25 @@ def test_bulk_create_proxy_inheritance(self): Duck(name="duck2"), ] ) - self.assertEqual( - sorted(RedheadDuck.objects.values_list("name", flat=True)), - ["redheadduck1", "redheadduck2"], - ) - self.assertEqual( - sorted(RubberDuck.objects.values_list("name", flat=True)), - ["rubberduck1", "rubberduck2"], - ) - self.assertEqual( - sorted(Duck.objects.values_list("name", flat=True)), - [ - "duck1", - "duck2", - "redheadduck1", - "redheadduck2", - "rubberduck1", - "rubberduck2", - ], - ) + assert sorted(RedheadDuck.objects.values_list("name", flat=True)) == [ + "redheadduck1", + "redheadduck2", + ] + assert sorted(RubberDuck.objects.values_list("name", flat=True)) == [ + "rubberduck1", + "rubberduck2", + ] + assert sorted(Duck.objects.values_list("name", flat=True)) == [ + "duck1", + "duck2", + "redheadduck1", + "redheadduck2", + "rubberduck1", + "rubberduck2", + ] def test_bulk_create_unsupported_multi_table_inheritance(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): MultiTableDerived.objects.bulk_create( [MultiTableDerived(field1="field1", field2="field2")] ) @@ -1165,10 +1141,10 @@ def test_bulk_create_ignore_conflicts(self): ], ignore_conflicts=True, ) - self.assertEqual(ArtProject.objects.count(), 2) + assert ArtProject.objects.count() == 2 def test_bulk_create_no_ignore_conflicts(self): - with self.assertRaises(IntegrityError): + with pytest.raises(IntegrityError): ArtProject.objects.bulk_create( [ ArtProject(topic="Painting with Tim", artist="T. Turner"), @@ -1176,7 +1152,7 @@ def test_bulk_create_no_ignore_conflicts(self): ], ignore_conflicts=False, ) - self.assertEqual(ArtProject.objects.count(), 1) + assert ArtProject.objects.count() == 1 def test_can_query_using_subclass_selector_on_abstract_model(self): obj = SubclassSelectorAbstractConcreteModel.objects.create(concrete_field="abc") @@ -1185,7 +1161,7 @@ def test_can_query_using_subclass_selector_on_abstract_model(self): SubclassSelectorAbstractConcreteModel___concrete_field="abc" ).get() - self.assertEqual(obj.pk, queried_obj.pk) + assert obj.pk == queried_obj.pk def test_can_query_using_subclass_selector_on_proxy_model(self): obj = SubclassSelectorProxyConcreteModel.objects.create(concrete_field="abc") @@ -1194,7 +1170,7 @@ def test_can_query_using_subclass_selector_on_proxy_model(self): SubclassSelectorProxyConcreteModel___concrete_field="abc" ).get() - self.assertEqual(obj.pk, queried_obj.pk) + assert obj.pk == queried_obj.pk def test_prefetch_related_behaves_normally_with_polymorphic_model(self): b1 = RelatingModel.objects.create() @@ -1203,7 +1179,7 @@ def test_prefetch_related_behaves_normally_with_polymorphic_model(self): b2.many2many.add(a) # add same to second relating model qs = RelatingModel.objects.prefetch_related("many2many") for obj in qs: - self.assertEqual(len(obj.many2many.all()), 1) + assert len(obj.many2many.all()) == 1 def test_prefetch_related_with_missing(self): b1 = RelatingModel.objects.create() @@ -1219,13 +1195,13 @@ def test_prefetch_related_with_missing(self): qs = RelatingModel.objects.order_by("pk").prefetch_related("many2many") objects = list(qs) - self.assertEqual(len(objects[0].many2many.all()), 1) + assert len(objects[0].many2many.all()) == 1 # derived object was not fetched - self.assertEqual(len(objects[1].many2many.all()), 0) + assert len(objects[1].many2many.all()) == 0 # base object does exist - self.assertEqual(len(objects[1].many2many.non_polymorphic()), 1) + assert len(objects[1].many2many.non_polymorphic()) == 1 def test_refresh_from_db_fields(self): """Test whether refresh_from_db(fields=..) works as it performs .only() queries""" diff --git a/polymorphic/tests/test_regression.py b/polymorphic/tests/test_regression.py index e416f4fb..0f48971e 100644 --- a/polymorphic/tests/test_regression.py +++ b/polymorphic/tests/test_regression.py @@ -18,16 +18,16 @@ def test_for_query_result_incomplete_with_inheritance(self): bottom.save() expected_queryset = [top, middle, bottom] - self.assertQuerysetEqual( + self.assertQuerySetEqual( Top.objects.order_by("pk"), [repr(r) for r in expected_queryset], **transform_arg ) expected_queryset = [middle, bottom] - self.assertQuerysetEqual( + self.assertQuerySetEqual( Middle.objects.order_by("pk"), [repr(r) for r in expected_queryset], **transform_arg ) expected_queryset = [bottom] - self.assertQuerysetEqual( + self.assertQuerySetEqual( Bottom.objects.order_by("pk"), [repr(r) for r in expected_queryset], **transform_arg ) diff --git a/polymorphic/tests/test_utils.py b/polymorphic/tests/test_utils.py index 03d17631..0934a507 100644 --- a/polymorphic/tests/test_utils.py +++ b/polymorphic/tests/test_utils.py @@ -1,3 +1,4 @@ +import pytest from django.test import TransactionTestCase from polymorphic.models import PolymorphicModel, PolymorphicTypeUndefined @@ -14,10 +15,13 @@ class UtilsTests(TransactionTestCase): def test_sort_by_subclass(self): - self.assertEqual( - sort_by_subclass(Model2D, Model2B, Model2D, Model2A, Model2C), - [Model2A, Model2B, Model2C, Model2D, Model2D], - ) + assert sort_by_subclass(Model2D, Model2B, Model2D, Model2A, Model2C) == [ + Model2A, + Model2B, + Model2C, + Model2D, + Model2D, + ] def test_reset_polymorphic_ctype(self): """ @@ -29,12 +33,12 @@ def test_reset_polymorphic_ctype(self): Model2B.objects.create(field1="A1", field2="B2") Model2A.objects.all().update(polymorphic_ctype_id=None) - with self.assertRaises(PolymorphicTypeUndefined): + with pytest.raises(PolymorphicTypeUndefined): list(Model2A.objects.all()) reset_polymorphic_ctype(Model2D, Model2B, Model2D, Model2A, Model2C) - self.assertQuerysetEqual( + self.assertQuerySetEqual( Model2A.objects.order_by("pk"), [Model2A, Model2D, Model2B, Model2B], transform=lambda o: o.__class__, @@ -45,16 +49,16 @@ def test_get_base_polymorphic_model(self): Test that finding the base polymorphic model works. """ # Finds the base from every level (including lowest) - self.assertIs(get_base_polymorphic_model(Model2D), Model2A) - self.assertIs(get_base_polymorphic_model(Model2C), Model2A) - self.assertIs(get_base_polymorphic_model(Model2B), Model2A) - self.assertIs(get_base_polymorphic_model(Model2A), Model2A) + assert get_base_polymorphic_model(Model2D) is Model2A + assert get_base_polymorphic_model(Model2C) is Model2A + assert get_base_polymorphic_model(Model2B) is Model2A + assert get_base_polymorphic_model(Model2A) is Model2A # Properly handles multiple inheritance - self.assertIs(get_base_polymorphic_model(Enhance_Inherit), Enhance_Base) + assert get_base_polymorphic_model(Enhance_Inherit) is Enhance_Base # Ignores PolymorphicModel itself. - self.assertIs(get_base_polymorphic_model(PolymorphicModel), None) + assert get_base_polymorphic_model(PolymorphicModel) is None def test_get_base_polymorphic_model_skip_abstract(self): """ @@ -71,8 +75,8 @@ class B(A): class C(B): pass - self.assertIs(get_base_polymorphic_model(A), None) - self.assertIs(get_base_polymorphic_model(B), B) - self.assertIs(get_base_polymorphic_model(C), B) + assert get_base_polymorphic_model(A) is None + assert get_base_polymorphic_model(B) is B + assert get_base_polymorphic_model(C) is B - self.assertIs(get_base_polymorphic_model(C, allow_abstract=True), A) + assert get_base_polymorphic_model(C, allow_abstract=True) is A diff --git a/polymorphic_test_settings.py b/polymorphic_test_settings.py new file mode 100644 index 00000000..6d2d3407 --- /dev/null +++ b/polymorphic_test_settings.py @@ -0,0 +1,56 @@ +import dj_database_url + +DEBUG = False +DATABASES = { + "default": dj_database_url.config( + env="PRIMARY_DATABASE", + default="sqlite://:memory:", + ), + "secondary": dj_database_url.config( + env="SECONDARY_DATABASE", + default="sqlite://:memory:", + ), +} +DEFAULT_AUTO_FIELD = "django.db.models.AutoField" +INSTALLED_APPS = ( + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.messages", + "django.contrib.sessions", + "django.contrib.sites", + "django.contrib.admin", + "polymorphic", + "polymorphic.tests", +) +MIDDLEWARE = ( + "django.middleware.common.CommonMiddleware", + "django.contrib.sessions.middleware.SessionMiddleware", + "django.middleware.csrf.CsrfViewMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + "django.contrib.messages.middleware.MessageMiddleware", +) +SITE_ID = 3 +TEMPLATES = [ + { + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": (), + "OPTIONS": { + "loaders": ( + "django.template.loaders.filesystem.Loader", + "django.template.loaders.app_directories.Loader", + ), + "context_processors": ( + "django.template.context_processors.debug", + "django.template.context_processors.i18n", + "django.template.context_processors.media", + "django.template.context_processors.request", + "django.template.context_processors.static", + "django.contrib.messages.context_processors.messages", + "django.contrib.auth.context_processors.auth", + ), + }, + } +] +POLYMORPHIC_TEST_SWAPPABLE = "polymorphic.swappedmodel" +ROOT_URLCONF = None +SECRET_KEY = "supersecret" diff --git a/pyproject.toml b/pyproject.toml index 60b79e90..4706713b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,3 +33,6 @@ select = [ "F841", "I", ] + +[tool.pytest.ini_options] +DJANGO_SETTINGS_MODULE = "polymorphic_test_settings" diff --git a/runtests.py b/runtests.py deleted file mode 100755 index 9dbdb70e..00000000 --- a/runtests.py +++ /dev/null @@ -1,91 +0,0 @@ -#!/usr/bin/env python -Wd -import sys -import warnings -from os.path import abspath, dirname - -import dj_database_url -import django -from django.conf import settings -from django.core.management import execute_from_command_line - -# python -Wd, or run via coverage: -warnings.simplefilter("always", DeprecationWarning) - -# Give feedback on used versions -sys.stderr.write(f"Using Python version {sys.version[:5]} from {sys.executable}\n") -sys.stderr.write( - "Using Django version {} from {}\n".format( - django.get_version(), dirname(abspath(django.__file__)) - ) -) - -if not settings.configured: - settings.configure( - DEBUG=False, - DATABASES={ - "default": dj_database_url.config(env="PRIMARY_DATABASE", default="sqlite://:memory:"), - "secondary": dj_database_url.config( - env="SECONDARY_DATABASE", default="sqlite://:memory:" - ), - }, - DEFAULT_AUTO_FIELD="django.db.models.AutoField", - TEST_RUNNER="django.test.runner.DiscoverRunner", - INSTALLED_APPS=( - "django.contrib.auth", - "django.contrib.contenttypes", - "django.contrib.messages", - "django.contrib.sessions", - "django.contrib.sites", - "django.contrib.admin", - "polymorphic", - "polymorphic.tests", - ), - MIDDLEWARE=( - "django.middleware.common.CommonMiddleware", - "django.contrib.sessions.middleware.SessionMiddleware", - "django.middleware.csrf.CsrfViewMiddleware", - "django.contrib.auth.middleware.AuthenticationMiddleware", - "django.contrib.messages.middleware.MessageMiddleware", - ), - SITE_ID=3, - TEMPLATES=[ - { - "BACKEND": "django.template.backends.django.DjangoTemplates", - "DIRS": (), - "OPTIONS": { - "loaders": ( - "django.template.loaders.filesystem.Loader", - "django.template.loaders.app_directories.Loader", - ), - "context_processors": ( - "django.template.context_processors.debug", - "django.template.context_processors.i18n", - "django.template.context_processors.media", - "django.template.context_processors.request", - "django.template.context_processors.static", - "django.contrib.messages.context_processors.messages", - "django.contrib.auth.context_processors.auth", - ), - }, - } - ], - POLYMORPHIC_TEST_SWAPPABLE="polymorphic.swappedmodel", - ROOT_URLCONF=None, - SECRET_KEY="supersecret", - ) - - -DEFAULT_TEST_APPS = ["polymorphic"] - - -def runtests(): - other_args = list(filter(lambda arg: arg.startswith("-"), sys.argv[1:])) - test_apps = ( - list(filter(lambda arg: not arg.startswith("-"), sys.argv[1:])) or DEFAULT_TEST_APPS - ) - argv = sys.argv[:1] + ["test", "--traceback"] + other_args + test_apps - execute_from_command_line(argv) - - -if __name__ == "__main__": - runtests() diff --git a/tox.ini b/tox.ini index c129822d..eedc0c46 100644 --- a/tox.ini +++ b/tox.ini @@ -12,7 +12,9 @@ setenv = postgres: DEFAULT_DATABASE = postgres:///default postgres: SECONDARY_DATABASE = postgres:///secondary deps = - coverage + pytest + pytest-cov + pytest-django dj-database-url django32: Django ~= 3.2 django40: Django ~= 4.0 @@ -22,7 +24,7 @@ deps = djangomain: https://github.com/django/django/archive/main.tar.gz postgres: psycopg2 commands = - coverage run --source polymorphic runtests.py + py.test --cov --cov-report=term-missing --cov-report=xml . [testenv:docs] deps =