Skip to content

Commit

Permalink
#1266 resolving through tables defined as strings on m2m relations
Browse files Browse the repository at this point in the history
  • Loading branch information
trumpet2012 committed Sep 12, 2024
1 parent 7fdebc8 commit a7dd57f
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 19 deletions.
1 change: 1 addition & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ Authors
- `DanialErfanian <https://github.com/DanialErfanian>`_
- `Sridhar Marella <https://github.com/sridhar562345>`_
- `Mattia Fantoni <https://github.com/MattFanto>`_
- `Trent Holliday <https://github.com/trumpet2012>`_

Background
==========
Expand Down
2 changes: 1 addition & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Changes

Unreleased
----------

- Added support for m2m fields that specify through tables as strings
- Made ``skip_history_when_saving`` work when creating an object - not just when
updating an object (gh-1262)
- Improved performance of the ``latest_of_each()`` history manager method (gh-1360)
Expand Down
43 changes: 25 additions & 18 deletions simple_history/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from django.db import models
from django.db.models import ManyToManyField
from django.db.models.fields.proxy import OrderWrt
from django.db.models.fields.related import ForeignKey
from django.db.models.fields.related import ForeignKey, lazy_related_operation
from django.db.models.fields.related_descriptors import (
ForwardManyToOneDescriptor,
ReverseManyToOneDescriptor,
Expand Down Expand Up @@ -84,6 +84,8 @@ class HistoricalRecords:
DEFAULT_MODEL_NAME_PREFIX = "Historical"

thread = context = LocalContext() # retain thread for backwards compatibility
# Key is the m2m field and value is a tuple where first entry is the historical m2m
# model and second is the through model
m2m_models = {}

def __init__(
Expand Down Expand Up @@ -222,13 +224,6 @@ def finalize(self, sender, **kwargs):

m2m_fields = self.get_m2m_fields_from_model(sender)

for field in m2m_fields:
m2m_changed.connect(
partial(self.m2m_changed, attr=field.name),
sender=field.remote_field.through,
weak=False,
)

descriptor = HistoryDescriptor(
history_model,
manager=self.history_manager,
Expand All @@ -238,15 +233,29 @@ def finalize(self, sender, **kwargs):
sender._meta.simple_history_manager_attribute = self.manager_name

for field in m2m_fields:
m2m_model = self.create_history_m2m_model(
history_model, field.remote_field.through
)
self.m2m_models[field] = m2m_model

setattr(module, m2m_model.__name__, m2m_model)
def resolve_through_model(history_model, through_model):
m2m_changed.connect(
partial(self.m2m_changed, attr=field.name),
sender=through_model,
weak=False,
)
m2m_model = self.create_history_m2m_model(history_model, through_model)
# Save the created history model and the resolved through model together
# for reference later
self.m2m_models[field] = (m2m_model, through_model)

setattr(module, m2m_model.__name__, m2m_model)

m2m_descriptor = HistoryDescriptor(m2m_model)
setattr(history_model, field.name, m2m_descriptor)
m2m_descriptor = HistoryDescriptor(m2m_model)
setattr(history_model, field.name, m2m_descriptor)

# Lazily generate the historical m2m models for the fields when all of the
# associated models have been fully loaded. This handles resolving through
# models referenced as strings. This is how django m2m fields handle this.
lazy_related_operation(
resolve_through_model, history_model, field.remote_field.through
)

def get_history_model_name(self, model):
if not self.custom_model_name:
Expand Down Expand Up @@ -685,9 +694,7 @@ def m2m_changed(self, instance, action, attr, pk_set, reverse, **_):

def create_historical_record_m2ms(self, history_instance, instance):
for field in history_instance._history_m2m_fields:
m2m_history_model = self.m2m_models[field]
original_instance = history_instance.instance
through_model = getattr(original_instance, field.name).through
m2m_history_model, through_model = self.m2m_models[field]
through_model_field_names = [f.name for f in through_model._meta.fields]
through_model_fk_field_names = [
f.name for f in through_model._meta.fields if isinstance(f, ForeignKey)
Expand Down
12 changes: 12 additions & 0 deletions simple_history/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,18 @@ class PollWithSelfManyToMany(models.Model):
history = HistoricalRecords(m2m_fields=[relations])


class PollWithManyToManyThroughString(models.Model):
books = models.ManyToManyField("Book", through="PollBookThroughTable")
history = HistoricalRecords(m2m_fields=[books])


class PollBookThroughTable(models.Model):
book = models.ForeignKey("Book", on_delete=models.CASCADE)
poll = models.ForeignKey(
"PollWithManyToManyThroughString", on_delete=models.CASCADE
)


class CustomAttrNameForeignKey(models.ForeignKey):
def __init__(self, *args, **kwargs):
self.attr_name = kwargs.pop("attr_name", None)
Expand Down
55 changes: 55 additions & 0 deletions simple_history/tests/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
PollWithHistoricalIPAddress,
PollWithManyToMany,
PollWithManyToManyCustomHistoryID,
PollWithManyToManyThroughString,
PollWithManyToManyWithIPAddress,
PollWithNonEditableField,
PollWithQuerySetCustomizations,
Expand Down Expand Up @@ -2515,6 +2516,60 @@ def test_diff_against(self):
self.assertEqual(delta, expected_delta)


class ManyToManyThroughStringTest(TestCase):
def setUp(self):
self.model = PollWithManyToManyThroughString
self.history_model = self.model.history.model
self.poll = self.model.objects.create()
self.book = Book.objects.create(isbn="1234")

def assertDatetimesEqual(self, time1, time2):
self.assertAlmostEqual(time1, time2, delta=timedelta(seconds=2))

def assertRecordValues(self, record, klass, values_dict):
for key, value in values_dict.items():
self.assertEqual(getattr(record, key), value)
self.assertEqual(record.history_object.__class__, klass)
for key, value in values_dict.items():
if key not in ["history_type", "history_change_reason"]:
self.assertEqual(getattr(record.history_object, key), value)

def test_create(self):
# There should be 1 history record for our poll, the create from setUp
self.assertEqual(self.poll.history.all().count(), 1)

# The created history row should be normal and correct
(record,) = self.poll.history.all()
self.assertRecordValues(
record,
self.model,
{
"id": self.poll.id,
"history_type": "+",
},
)
self.assertDatetimesEqual(record.history_date, datetime.now())

historical_poll = self.poll.history.all()[0]

# There should be no books associated with the current poll yet
self.assertEqual(historical_poll.books.count(), 0)

# Add a many-to-many child
self.poll.books.add(self.book)

# A new history row has been created by adding the M2M
self.assertEqual(self.poll.history.all().count(), 2)

# The new row has a place attached to it
m2m_record = self.poll.history.all()[0]
self.assertEqual(m2m_record.books.count(), 1)

# And the historical place is the correct one
historical_book = m2m_record.books.first()
self.assertEqual(historical_book.book, self.book)


@override_settings(**database_router_override_settings)
class MultiDBExplicitHistoryUserIDTest(TestCase):
databases = {"default", "other"}
Expand Down

0 comments on commit a7dd57f

Please sign in to comment.