Skip to content

Commit

Permalink
Rework revision (#797)
Browse files Browse the repository at this point in the history
* rework revisions

* fix tests
  • Loading branch information
roman-right authored Dec 24, 2023
1 parent 5f03c6f commit 75ba9e3
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 47 deletions.
24 changes: 8 additions & 16 deletions beanie/odm/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,12 @@
previous_saved_state_needed,
save_state_after,
saved_state_needed,
swap_revision_after,
)
from beanie.odm.utils.typing import extract_id_class

if IS_PYDANTIC_V2:
from pydantic import model_validator


DocType = TypeVar("DocType", bound="Document")
DocumentProjectionType = TypeVar("DocumentProjectionType", bound=BaseModel)

Expand Down Expand Up @@ -165,7 +163,6 @@ class Config:

# State
revision_id: Optional[UUID] = Field(default=None, exclude=True)
_previous_revision_id: Optional[UUID] = PrivateAttr(default=None)
_saved_state: Optional[Dict[str, Any]] = PrivateAttr(default=None)
_previous_saved_state: Optional[Dict[str, Any]] = PrivateAttr(default=None)

Expand All @@ -181,11 +178,6 @@ class Config:
# Database
_database_major_version: ClassVar[int] = 4

def _swap_revision(self):
if self.get_settings().use_revision:
self._previous_revision_id = self.revision_id
self.revision_id = uuid4()

def __init__(self, *args, **kwargs):
super(Document, self).__init__(*args, **kwargs)
self.get_motor_collection()
Expand Down Expand Up @@ -263,7 +255,6 @@ async def get(
)

@wrap_with_actions(EventTypes.INSERT)
@swap_revision_after
@save_state_after
@validate_self_before
async def insert(
Expand Down Expand Up @@ -402,7 +393,6 @@ async def insert_many(
)

@wrap_with_actions(EventTypes.REPLACE)
@swap_revision_after
@save_state_after
@validate_self_before
async def replace(
Expand Down Expand Up @@ -470,7 +460,8 @@ async def replace(
find_query: Dict[str, Any] = {"_id": self.id}

if use_revision_id and not ignore_revision:
find_query["revision_id"] = self._previous_revision_id
find_query["revision_id"] = self.revision_id
self.revision_id = uuid4()
try:
await self.find_one(find_query).replace_one(
self,
Expand Down Expand Up @@ -662,10 +653,11 @@ async def update(
find_query = {"_id": PydanticObjectId()}

if use_revision_id and not ignore_revision:
find_query["revision_id"] = self._previous_revision_id
find_query["revision_id"] = self.revision_id

if use_revision_id:
arguments.append(SetRevisionId(self.revision_id))
new_revision_id = uuid4()
arguments.append(SetRevisionId(new_revision_id))
try:
result = await self.find_one(find_query).update(
*arguments,
Expand Down Expand Up @@ -922,7 +914,7 @@ def _save_state(self) -> None:
self,
to_db=True,
keep_nulls=self.get_settings().keep_nulls,
exclude={"revision_id", "_previous_revision_id"},
exclude={"revision_id"},
)

def get_saved_state(self) -> Optional[Dict[str, Any]]:
Expand All @@ -946,7 +938,7 @@ def is_changed(self) -> bool:
self,
to_db=True,
keep_nulls=self.get_settings().keep_nulls,
exclude={"revision_id", "_previous_revision_id"},
exclude={"revision_id"},
):
return False
return True
Expand Down Expand Up @@ -1009,7 +1001,7 @@ def get_changes(self) -> Dict[str, Any]:
self,
to_db=True,
keep_nulls=self.get_settings().keep_nulls,
exclude={"revision_id", "_previous_revision_id"},
exclude={"revision_id"},
),
)

Expand Down
10 changes: 2 additions & 8 deletions beanie/odm/utils/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@ def merge_models(left: BaseModel, right: BaseModel) -> None:
"""
from beanie.odm.fields import Link

if hasattr(left, "_previous_revision_id") and hasattr(
right, "_previous_revision_id"
):
left._previous_revision_id = right._previous_revision_id # type: ignore
for k, right_value in right.__iter__():
left_value = getattr(left, k)
if isinstance(right_value, BaseModel) and isinstance(
Expand All @@ -49,11 +45,9 @@ def merge_models(left: BaseModel, right: BaseModel) -> None:
left.__setattr__(k, right_value)


def save_state_swap_revision(item: BaseModel):
def save_state(item: BaseModel):
if hasattr(item, "_save_state"):
item._save_state() # type: ignore
if hasattr(item, "_swap_revision"):
item._swap_revision() # type: ignore


def parse_obj(
Expand Down Expand Up @@ -108,5 +102,5 @@ def parse_obj(
o._saved_state = {"_id": o.id}
return o
result = parse_model(model, data)
save_state_swap_revision(result)
save_state(result)
return result
10 changes: 0 additions & 10 deletions beanie/odm/utils/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,3 @@ async def wrapper(self: "DocType", *args, **kwargs):
return result

return wrapper


def swap_revision_after(f: Callable):
@wraps(f)
async def wrapper(self: "DocType", *args, **kwargs):
result = await f(self, *args, **kwargs)
self._swap_revision()
return result

return wrapper
34 changes: 24 additions & 10 deletions tests/odm/documents/test_revision.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from beanie import BulkWriter
from beanie.exceptions import RevisionIdWasChanged
from beanie.odm.operators.update.general import Inc
from tests.odm.models import DocumentWithRevisionTurnedOn
from tests.odm.models import (
DocumentWithRevisionTurnedOn,
LockWithRevision,
WindowWithRevision,
)


async def test_replace():
Expand All @@ -22,7 +26,7 @@ async def test_replace():
found_doc.num_1 += 1
await found_doc.replace()

doc._previous_revision_id = "wrong"
doc.revision_id = "wrong"
doc.num_1 = 4
with pytest.raises(RevisionIdWasChanged):
await doc.replace()
Expand All @@ -43,7 +47,7 @@ async def test_update():
found_doc = await DocumentWithRevisionTurnedOn.get(doc.id)
await found_doc.update(Inc({DocumentWithRevisionTurnedOn.num_1: 1}))

doc._previous_revision_id = "wrong"
doc.revision_id = "wrong"
with pytest.raises(RevisionIdWasChanged):
await doc.update(Inc({DocumentWithRevisionTurnedOn.num_1: 1}))

Expand All @@ -68,7 +72,7 @@ async def test_save_changes():
found_doc.num_1 += 1
await found_doc.save_changes()

doc._previous_revision_id = "wrong"
doc.revision_id = "wrong"
doc.num_1 = 4
with pytest.raises(RevisionIdWasChanged):
await doc.save_changes()
Expand All @@ -91,7 +95,7 @@ async def test_save():
found_doc.num_1 += 1
await found_doc.save()

doc._previous_revision_id = "wrong"
doc.revision_id = "wrong"
doc.num_1 = 4
with pytest.raises(RevisionIdWasChanged):
await doc.save()
Expand Down Expand Up @@ -122,7 +126,7 @@ async def test_update_bulk_writer():
async with BulkWriter() as bulk_writer:
await found_doc.save(bulk_writer=bulk_writer)

doc._previous_revision_id = "wrong"
doc.revision_id = "wrong"
doc.num_1 = 4
with pytest.raises(BulkWriteError):
async with BulkWriter() as bulk_writer:
Expand All @@ -144,11 +148,21 @@ async def test_save_changes_when_there_were_no_changes():
doc = DocumentWithRevisionTurnedOn(num_1=1, num_2=2)
await doc.insert()
revision = doc.revision_id
old_revision = doc._previous_revision_id

await doc.save_changes()
assert doc.revision_id == revision
assert doc._previous_revision_id == old_revision

doc = await DocumentWithRevisionTurnedOn.get(doc.id)
assert doc._previous_revision_id == old_revision
await DocumentWithRevisionTurnedOn.get(doc.id)
assert doc.revision_id == revision


async def test_revision_id_for_link():
lock = LockWithRevision(k=1)
await lock.insert()

lock_rev_id = lock.revision_id

window = WindowWithRevision(x=0, y=0, lock=lock)

await window.insert()
assert lock.revision_id == lock_rev_id
4 changes: 1 addition & 3 deletions tests/odm/test_state_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,9 +431,7 @@ async def test_replace_save_previous(self, saved_doc_previous):
assert saved_doc_previous.get_saved_state()["num_1"] == 100
assert saved_doc_previous.get_previous_saved_state()["num_1"] == 1

async def test_exclude_revision_id_and_previous_revision_id(
self, saved_doc_previous
):
async def test_exclude_revision_id(self, saved_doc_previous):
saved_doc_previous.num_1 = 100
await saved_doc_previous.replace()

Expand Down

0 comments on commit 75ba9e3

Please sign in to comment.