From 648b159726c440d947f0a9db32738ffbdb3fbcf7 Mon Sep 17 00:00:00 2001 From: Dale Cannon Date: Tue, 19 Sep 2023 12:36:42 +0100 Subject: [PATCH] Prevent AdditionalCodeDescriptionCreate MultipleObjectError --- additional_codes/tests/test_views.py | 37 ++++++++++++++++++++++++++++ additional_codes/views.py | 4 +-- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/additional_codes/tests/test_views.py b/additional_codes/tests/test_views.py index a92818196..2906be63b 100644 --- a/additional_codes/tests/test_views.py +++ b/additional_codes/tests/test_views.py @@ -4,9 +4,13 @@ import pytest from dateutil.relativedelta import relativedelta from django.core.exceptions import ValidationError +from django.urls import reverse from additional_codes.models import AdditionalCode +from additional_codes.models import AdditionalCodeDescription from additional_codes.views import AdditionalCodeList +from common.models import Transaction +from common.models.utils import override_current_transaction from common.tests import factories from common.tests.util import assert_model_view_renders from common.tests.util import assert_read_only_model_view_returns_list @@ -186,3 +190,36 @@ def test_additional_code_type_api_list_view(valid_user_client): expected_results, valid_user_client, ) + + +def test_additional_code_description_create(valid_user_client): + """Tests that `AdditionalCodeDescriptionCreate` view returns 200 and creates + a description for the current version of an additional code.""" + additional_code = factories.AdditionalCodeFactory.create() + new_version = additional_code.new_version( + workbasket=additional_code.transaction.workbasket, + ) + assert not AdditionalCodeDescription.objects.exists() + + url = reverse( + "additional_code-ui-description-create", + kwargs={"sid": new_version.sid}, + ) + data = { + "description": "new test description", + "described_additionalcode": new_version.pk, + "validity_start_0": 1, + "validity_start_1": 1, + "validity_start_2": 2023, + } + + with override_current_transaction(Transaction.objects.last()): + get_response = valid_user_client.get(url) + assert get_response.status_code == 200 + + post_response = valid_user_client.post(url, data) + assert post_response.status_code == 302 + + assert AdditionalCodeDescription.objects.filter( + described_additionalcode__sid=new_version.sid, + ).exists() diff --git a/additional_codes/views.py b/additional_codes/views.py index c1df992f9..3219817eb 100644 --- a/additional_codes/views.py +++ b/additional_codes/views.py @@ -170,7 +170,7 @@ class AdditionalCodeCreateDescriptionMixin: def get_context_data(self, **kwargs): context = super().get_context_data(**kwargs) - context["described_object"] = AdditionalCode.objects.get( + context["described_object"] = AdditionalCode.objects.current().get( sid=(self.kwargs.get("sid")), ) return context @@ -193,7 +193,7 @@ class AdditionalCodeDescriptionCreate( def get_initial(self): initial = super().get_initial() - initial["described_additionalcode"] = AdditionalCode.objects.get( + initial["described_additionalcode"] = AdditionalCode.objects.current().get( sid=(self.kwargs.get("sid")), ) return initial