diff --git a/AUTHORS.rst b/AUTHORS.rst index 4c38bbbf..2013ea0c 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -144,6 +144,7 @@ Authors - `DanialErfanian `_ - `Sridhar Marella `_ - `Mattia Fantoni `_ +- `Trent Holliday `_ Background ========== diff --git a/CHANGES.rst b/CHANGES.rst index 9c44563c..7dc3aaeb 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -3,10 +3,10 @@ Changes Unreleased ---------- - - Made ``skip_history_when_saving`` work when creating an object - not just when updating an object (gh-1262) - Improved performance of the ``latest_of_each()`` history manager method (gh-1360) +- Added support for m2m fields that specify through tables as strings (gh-1390) 3.7.0 (2024-05-29) ------------------ diff --git a/simple_history/models.py b/simple_history/models.py index 3ffe42a5..ba319cea 100644 --- a/simple_history/models.py +++ b/simple_history/models.py @@ -15,7 +15,7 @@ from django.db import models from django.db.models import ManyToManyField from django.db.models.fields.proxy import OrderWrt -from django.db.models.fields.related import ForeignKey +from django.db.models.fields.related import ForeignKey, lazy_related_operation from django.db.models.fields.related_descriptors import ( ForwardManyToOneDescriptor, ReverseManyToOneDescriptor, @@ -84,6 +84,8 @@ class HistoricalRecords: DEFAULT_MODEL_NAME_PREFIX = "Historical" thread = context = LocalContext() # retain thread for backwards compatibility + # Key is the m2m field and value is a tuple where first entry is the historical m2m + # model and second is the through model m2m_models = {} def __init__( @@ -222,13 +224,6 @@ def finalize(self, sender, **kwargs): m2m_fields = self.get_m2m_fields_from_model(sender) - for field in m2m_fields: - m2m_changed.connect( - partial(self.m2m_changed, attr=field.name), - sender=field.remote_field.through, - weak=False, - ) - descriptor = HistoryDescriptor( history_model, manager=self.history_manager, @@ -238,15 +233,29 @@ def finalize(self, sender, **kwargs): sender._meta.simple_history_manager_attribute = self.manager_name for field in m2m_fields: - m2m_model = self.create_history_m2m_model( - history_model, field.remote_field.through - ) - self.m2m_models[field] = m2m_model - setattr(module, m2m_model.__name__, m2m_model) + def resolve_through_model(history_model, through_model): + m2m_changed.connect( + partial(self.m2m_changed, attr=field.name), + sender=through_model, + weak=False, + ) + m2m_model = self.create_history_m2m_model(history_model, through_model) + # Save the created history model and the resolved through model together + # for reference later + self.m2m_models[field] = (m2m_model, through_model) + + setattr(module, m2m_model.__name__, m2m_model) - m2m_descriptor = HistoryDescriptor(m2m_model) - setattr(history_model, field.name, m2m_descriptor) + m2m_descriptor = HistoryDescriptor(m2m_model) + setattr(history_model, field.name, m2m_descriptor) + + # Lazily generate the historical m2m models for the fields when all of the + # associated models have been fully loaded. This handles resolving through + # models referenced as strings. This is how django m2m fields handle this. + lazy_related_operation( + resolve_through_model, history_model, field.remote_field.through + ) def get_history_model_name(self, model): if not self.custom_model_name: @@ -685,9 +694,7 @@ def m2m_changed(self, instance, action, attr, pk_set, reverse, **_): def create_historical_record_m2ms(self, history_instance, instance): for field in history_instance._history_m2m_fields: - m2m_history_model = self.m2m_models[field] - original_instance = history_instance.instance - through_model = getattr(original_instance, field.name).through + m2m_history_model, through_model = self.m2m_models[field] through_model_field_names = [f.name for f in through_model._meta.fields] through_model_fk_field_names = [ f.name for f in through_model._meta.fields if isinstance(f, ForeignKey) diff --git a/simple_history/tests/models.py b/simple_history/tests/models.py index f35b5cf6..ec4914ce 100644 --- a/simple_history/tests/models.py +++ b/simple_history/tests/models.py @@ -237,6 +237,18 @@ class PollWithSelfManyToMany(models.Model): history = HistoricalRecords(m2m_fields=[relations]) +class PollWithManyToManyThroughString(models.Model): + books = models.ManyToManyField("Book", through="PollBookThroughTable") + history = HistoricalRecords(m2m_fields=[books]) + + +class PollBookThroughTable(models.Model): + book = models.ForeignKey("Book", on_delete=models.CASCADE) + poll = models.ForeignKey( + "PollWithManyToManyThroughString", on_delete=models.CASCADE + ) + + class CustomAttrNameForeignKey(models.ForeignKey): def __init__(self, *args, **kwargs): self.attr_name = kwargs.pop("attr_name", None) diff --git a/simple_history/tests/tests/test_models.py b/simple_history/tests/tests/test_models.py index d70ca3ce..c25266bd 100644 --- a/simple_history/tests/tests/test_models.py +++ b/simple_history/tests/tests/test_models.py @@ -96,6 +96,7 @@ PollWithHistoricalIPAddress, PollWithManyToMany, PollWithManyToManyCustomHistoryID, + PollWithManyToManyThroughString, PollWithManyToManyWithIPAddress, PollWithNonEditableField, PollWithQuerySetCustomizations, @@ -2515,6 +2516,60 @@ def test_diff_against(self): self.assertEqual(delta, expected_delta) +class ManyToManyThroughStringTest(TestCase): + def setUp(self): + self.model = PollWithManyToManyThroughString + self.history_model = self.model.history.model + self.poll = self.model.objects.create() + self.book = Book.objects.create(isbn="1234") + + def assertDatetimesEqual(self, time1, time2): + self.assertAlmostEqual(time1, time2, delta=timedelta(seconds=2)) + + def assertRecordValues(self, record, klass, values_dict): + for key, value in values_dict.items(): + self.assertEqual(getattr(record, key), value) + self.assertEqual(record.history_object.__class__, klass) + for key, value in values_dict.items(): + if key not in ["history_type", "history_change_reason"]: + self.assertEqual(getattr(record.history_object, key), value) + + def test_create(self): + # There should be 1 history record for our poll, the create from setUp + self.assertEqual(self.poll.history.all().count(), 1) + + # The created history row should be normal and correct + (record,) = self.poll.history.all() + self.assertRecordValues( + record, + self.model, + { + "id": self.poll.id, + "history_type": "+", + }, + ) + self.assertDatetimesEqual(record.history_date, datetime.now()) + + historical_poll = self.poll.history.all()[0] + + # There should be no books associated with the current poll yet + self.assertEqual(historical_poll.books.count(), 0) + + # Add a many-to-many child + self.poll.books.add(self.book) + + # A new history row has been created by adding the M2M + self.assertEqual(self.poll.history.all().count(), 2) + + # The new row has a place attached to it + m2m_record = self.poll.history.all()[0] + self.assertEqual(m2m_record.books.count(), 1) + + # And the historical place is the correct one + historical_book = m2m_record.books.first() + self.assertEqual(historical_book.book, self.book) + + @override_settings(**database_router_override_settings) class MultiDBExplicitHistoryUserIDTest(TestCase): databases = {"default", "other"}