Skip to content

Commit

Permalink
Merge pull request #419 from mesozoic/pickling
Browse files Browse the repository at this point in the history
Allow pickling/unpickling of ChangeTrackingLists
  • Loading branch information
mesozoic authored Feb 15, 2025
2 parents 40b4a37 + dff7582 commit eb1c815
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 6 deletions.
20 changes: 16 additions & 4 deletions pyairtable/orm/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,13 @@ def __get__(
return self
return self._get_list_value(instance)

def __set__(self, instance: "Model", value: Optional[List[T_ORM]]) -> None:
if isinstance(value, list) and not isinstance(value, self.list_class):
assert isinstance(self.list_class, type)
assert issubclass(self.list_class, ChangeTrackingList)
value = self.list_class(value, field=self, model=instance)
super().__set__(instance, value)

def _get_list_value(self, instance: "Model") -> T_ORM_List:
value = instance._fields.get(self.field_name)
# If Airtable returns no value, substitute an empty list.
Expand Down Expand Up @@ -712,10 +719,15 @@ class Meta: ...
# If the list contains record IDs, replace the contents with instances.
# Other code may already have references to this specific list, so
# we replace the existing list's values.
records[: self._max_retrieve] = [
new_records[cast(RecordId, value)] if isinstance(value, RecordId) else value
for value in records[: self._max_retrieve]
]
with records.disable_tracking():
records[: self._max_retrieve] = [
(
new_records[cast(RecordId, value)]
if isinstance(value, RecordId)
else value
)
for value in records[: self._max_retrieve]
]

def _get_list_value(self, instance: "Model") -> ChangeTrackingList[T_Linked]:
"""
Expand Down
9 changes: 7 additions & 2 deletions pyairtable/orm/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,13 @@ def disable_tracking(self) -> Iterator[Self]:
self._tracking_enabled = prev

def _on_change(self) -> None:
if self._tracking_enabled:
self._model._changed[self._field.field_name] = True
try:
if not self._tracking_enabled:
return
except AttributeError:
# This means we're being unpickled and won't call __init__.
return
self._model._changed[self._field.field_name] = True

@overload
def __setitem__(self, index: SupportsIndex, value: T, /) -> None: ...
Expand Down
2 changes: 2 additions & 0 deletions tests/test_orm_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,8 @@ class Author(Model):
collection = [Book(), Book(), Book()]
author = Author()
author.books = collection
assert isinstance(author._fields["Books"], f.ChangeTrackingList)

assert author.books == collection

with pytest.raises(TypeError):
Expand Down
28 changes: 28 additions & 0 deletions tests/test_orm_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Fake(Model):
Meta = fake_meta()
attachments = F.AttachmentsField("Files")
readonly_attachments = F.AttachmentsField("Other Files", readonly=True)
others = F.LinkField["Fake"]("Others", F.LinkSelf)


@pytest.fixture
Expand Down Expand Up @@ -115,3 +116,30 @@ def test_attachment_upload__unsaved_value(mock_upload):
mock_upload.assert_called_once()
assert len(instance.attachments) == 1
assert instance.attachments[0]["url"] != unsaved_url


@pytest.mark.parametrize(
"op,retval,new_value",
[
(mock.call.append(4), None, [1, 2, 3, 4]),
(mock.call.insert(1, 4), None, [1, 4, 2, 3]),
(mock.call.remove(2), None, [1, 3]),
(mock.call.clear(), None, []),
(mock.call.extend([4, 5]), None, [1, 2, 3, 4, 5]),
(mock.call.pop(), 3, [1, 2]),
],
)
def test_change_tracking_list(op, retval, new_value):
"""
Test that ChangeTrackingList performs operations normally
and records (on the model instance) that the field changed.
"""
instance = Fake.from_record(fake_record())
ctl = F.ChangeTrackingList[int]([1, 2, 3], field=Fake.others, model=instance)
assert not instance._changed.get("Others")

fn = getattr(ctl, op._mock_parent._mock_name)
result = fn(*op.args, **op.kwargs)
assert result == retval
assert ctl == new_value
assert instance._changed["Others"] is True
45 changes: 45 additions & 0 deletions tests/test_orm_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pickle
from datetime import datetime, timezone
from functools import partial
from unittest import mock
Expand Down Expand Up @@ -463,3 +464,47 @@ def test_save_bool_deprecated():

with pytest.deprecated_call():
assert bool(SaveResult(fake_id(), created=True)) is True


def test_pickling():
"""
Test that a model instance can be pickled and unpickled.
"""
instance = FakeModel.from_record(fake_record(one="one", two="two"))
pickled = pickle.dumps(instance)
unpickled = pickle.loads(pickled)
assert isinstance(unpickled, FakeModel)
assert unpickled is not instance
assert unpickled.id == instance.id
assert unpickled.created_time == instance.created_time
assert unpickled._fields == instance._fields


class LinkedModel(Model):
Meta = fake_meta()
name = f.TextField("Name")
links = f.LinkField("Link", FakeModel)


def test_pickling_with_change_tracking_list():
"""
Test that a model with a ChangeTrackingList can be pickled and unpickled.
"""
fake_models = [FakeModel.from_record(fake_record()) for _ in range(5)]
instance = LinkedModel.from_record(fake_record())
instance.links = fake_models
instance._changed.clear() # Don't want to pickle that part.

# Now we need to be able to pickle and unpickle the model instance.
# We can't pickle/unpickle the list itself on its own, because it needs
# to retain references to the field and model.
pickled = pickle.dumps(instance)
unpickled = pickle.loads(pickled)
assert isinstance(unpickled, LinkedModel)
unpickled_link_ids = [link.id for link in unpickled.links]
assert unpickled_link_ids == [link.id for link in fake_models]

# Make sure change tracking still works.
assert "Link" not in unpickled._changed
unpickled.links.append(FakeModel.from_record(fake_record()))
assert unpickled._changed["Link"] is True
22 changes: 22 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,25 @@ def test_url_cannot_append_after_params():
v / "foo"
with pytest.raises(ValueError):
v // ["foo", "bar"]


@pytest.mark.parametrize(
"docstring,expected",
[
("", ""),
(
"This is a\ndocstring.",
"|enterprise_only|\n\nThis is a\ndocstring.",
),
(
"\t This is a\n\t docstring.",
"\t |enterprise_only|\n\n\t This is a\n\t docstring.",
),
],
)
def test_enterprise_docstring(docstring, expected):
@utils.enterprise_only
class Foo:
__doc__ = docstring

assert Foo.__doc__ == expected

0 comments on commit eb1c815

Please sign in to comment.