Skip to content

Commit

Permalink
Prevent duplicative merge attempts.
Browse files Browse the repository at this point in the history
  • Loading branch information
rebeccacremona committed Jan 26, 2024
1 parent 64981a3 commit 07cfb27
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 5 deletions.
17 changes: 12 additions & 5 deletions web/main/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from django.core.exceptions import ValidationError
from django.core.paginator import Paginator, Page
from django.core.validators import MaxLengthValidator, validate_unicode_slug
from django.db import ProgrammingError, connection, models, transaction
from django.db import DatabaseError, ProgrammingError, connection, models, transaction
from django.db.models import Count, F, JSONField, Q, QuerySet

from django.template.defaultfilters import truncatechars
Expand Down Expand Up @@ -2964,10 +2964,17 @@ def merge_draft(self):
See main/test/test_drafts.py
"""
# set up variables
draft = self
if not self.is_draft:
raise ValueError("Only draft casebooks may be merged")
parent = self.draft_of
try:
# Technique from https://github.com/harvard-lil/capstone/blob/0f7fb80f26e753e36e0c7a6a199b8fdccdd318be/capstone/capapi/serializers.py#L121
#
# Fetch casebooks here inside a transaction, using select_for_update
# to lock the rows so we don't collide with any simultaneous requests
draft = Casebook.objects.select_for_update(nowait=True).get(pk=self.pk)
if not draft.is_draft:
raise ValueError("Only draft casebooks may be merged")
parent = Casebook.objects.select_for_update(nowait=True).get(draft=draft.id)
except DatabaseError:
raise ValueError("This casebook's draft is already being merged.")

# swap all attributes

Expand Down
27 changes: 27 additions & 0 deletions web/main/test/test_drafts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from concurrent.futures import ThreadPoolExecutor
from main.models import User, Section, Resource
from django.db import connections

from test.test_helpers import dump_casebook_outline

import pytest


def test_merge_drafts(reset_sequences, full_casebook, assert_num_queries, legal_document_factory):

Expand Down Expand Up @@ -59,3 +63,26 @@ def test_merge_drafts(reset_sequences, full_casebook, assert_num_queries, legal_

# Clones of the original casebook have proper attribution
assert elena in second_casebook.attributed_authors


@pytest.mark.django_db(transaction=True)
def test_duplicative_merge_prevented(full_casebook_with_draft):
""" Fetch two jobs at the same time in threads and make sure same job isn't returned to both. """
draft = full_casebook_with_draft.draft

def attempt_merge(i):
try:
draft.merge_draft()
return True
except Exception as e:
return e
finally:
for connection in connections.all():
connection.close()

with ThreadPoolExecutor(max_workers=2) as e:
results = e.map(attempt_merge, range(2))

first, second = list(results)
assert (first is True and "already being merged" in str(second)) or \
(second is True and "already being merged" in str(first))

0 comments on commit 07cfb27

Please sign in to comment.