From f6387161bc0e213299ad34b4fb7f9b2ac36e5a60 Mon Sep 17 00:00:00 2001 From: zawan-ila <87228907+zawan-ila@users.noreply.github.com> Date: Mon, 29 Apr 2024 16:27:09 +0500 Subject: [PATCH] feat: add CRUD support for RestrictedCourseRun in CourseRun Api (#4331) --- course_discovery/apps/api/serializers.py | 6 +- .../apps/api/tests/test_serializers.py | 5 + .../v1/tests/test_views/test_course_runs.py | 141 +++++++++++++++++- .../api/v1/tests/test_views/test_courses.py | 31 +++- .../apps/api/v1/views/course_runs.py | 11 +- .../apps/course_metadata/models.py | 29 ++++ .../apps/course_metadata/tests/test_utils.py | 21 ++- .../apps/course_metadata/utils.py | 13 +- 8 files changed, 235 insertions(+), 22 deletions(-) diff --git a/course_discovery/apps/api/serializers.py b/course_discovery/apps/api/serializers.py index 0ce31f8819..b0cff4420e 100644 --- a/course_discovery/apps/api/serializers.py +++ b/course_discovery/apps/api/serializers.py @@ -922,6 +922,7 @@ class MinimalCourseRunSerializer(FlexFieldsSerializerMixin, TimestampModelSerial queryset=CourseRunType.objects.all()) term = serializers.CharField(required=False, write_only=True) variant_id = serializers.UUIDField(allow_null=True, required=False) + restriction_type = serializers.CharField(source='restricted_run.restriction_type', read_only=True) @classmethod def prefetch_queryset(cls, queryset=None): @@ -932,6 +933,7 @@ def prefetch_queryset(cls, queryset=None): return queryset.select_related('course', 'type').prefetch_related( '_official_version', 'course__partner', + 'restricted_run', Prefetch('seats', queryset=SeatSerializer.prefetch_queryset()), ) @@ -939,8 +941,8 @@ class Meta: model = CourseRun fields = ('key', 'uuid', 'title', 'external_key', 'image', 'short_description', 'marketing_url', 'seats', 'start', 'end', 'go_live_date', 'enrollment_start', 'enrollment_end', 'weeks_to_complete', - 'pacing_type', 'type', 'run_type', 'status', 'is_enrollable', 'is_marketable', 'term', 'availability', - 'variant_id') + 'pacing_type', 'type', 'restriction_type', 'run_type', 'status', 'is_enrollable', 'is_marketable', + 'term', 'availability', 'variant_id') def get_marketing_url(self, obj): include_archived = self.context.get('include_archived') diff --git a/course_discovery/apps/api/tests/test_serializers.py b/course_discovery/apps/api/tests/test_serializers.py index c83bd84558..29f553c19d 100644 --- a/course_discovery/apps/api/tests/test_serializers.py +++ b/course_discovery/apps/api/tests/test_serializers.py @@ -638,6 +638,11 @@ def get_expected_data(cls, course_run, request): 'is_marketable': course_run.is_marketable, 'availability': course_run.availability, 'variant_id': str(course_run.variant_id), + 'restriction_type': ( + course_run.restricted_run.restriction_type + if hasattr(course_run, 'restricted_run') + else None + ) } diff --git a/course_discovery/apps/api/v1/tests/test_views/test_course_runs.py b/course_discovery/apps/api/v1/tests/test_views/test_course_runs.py index 3215065c74..f997d5b9c7 100644 --- a/course_discovery/apps/api/v1/tests/test_views/test_course_runs.py +++ b/course_discovery/apps/api/v1/tests/test_views/test_course_runs.py @@ -21,14 +21,14 @@ from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, OAuth2Mixin, SerializationMixin from course_discovery.apps.core.tests.factories import UserFactory from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin -from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus -from course_discovery.apps.course_metadata.models import CourseRun, CourseRunType, Seat, SeatType +from course_discovery.apps.course_metadata.choices import CourseRunRestrictionType, CourseRunStatus, ProgramStatus +from course_discovery.apps.course_metadata.models import CourseRun, CourseRunType, RestrictedCourseRun, Seat, SeatType from course_discovery.apps.course_metadata.signals import ( connect_course_data_modified_timestamp_signal_handlers, disconnect_course_data_modified_timestamp_signal_handlers ) from course_discovery.apps.course_metadata.tests.factories import ( CourseEditorFactory, CourseFactory, CourseRunFactory, CourseRunTypeFactory, CourseTypeFactory, OrganizationFactory, - PersonFactory, ProgramFactory, SeatFactory, SourceFactory, SubjectFactory, TrackFactory + PersonFactory, ProgramFactory, RestrictedCourseRunFactory, SeatFactory, SourceFactory, SubjectFactory, TrackFactory ) from course_discovery.apps.course_metadata.toggles import ( IS_COURSE_RUN_VARIANT_ID_EDITABLE, IS_SUBDIRECTORY_SLUG_FORMAT_ENABLED @@ -182,6 +182,7 @@ def test_create_minimum(self): }, format='json') assert response.status_code == 201 new_course_run = CourseRun.everything.get(key=new_key) + assert RestrictedCourseRun.everything.count() == 0 self.assertDictEqual(response.data, self.serialize_course_run(new_course_run)) assert new_course_run.pacing_type == 'instructor_paced' # default we provide @@ -266,11 +267,16 @@ def test_create_sets_canonical_course_run(self, has_canonical_run): @responses.activate def test_create_sets_additional_fields(self): - """ Verify that instructors, languages, min & max effort, and weeks to complete are set on a rerun. """ - self.draft_course_run.staff.add(PersonFactory()) + """ + Verify that instructors, languages, min & max effort, and weeks to complete are set on a rerun. + Verify that the course run restriction is not copied to a rerun + """ self.draft_course_run.transcript_languages.add(self.draft_course_run.language) self.draft_course_run.save() - + RestrictedCourseRun.everything.create( + course_run=self.draft_course_run, + restriction_type=CourseRunRestrictionType.CustomB2BEnterprise.value + ) # Create rerun based on draft course course = self.draft_course_run.course new_key = f'course-v1:{course.key_for_reruns}+1T2000' @@ -295,6 +301,7 @@ def test_create_sets_additional_fields(self): assert list(new_course_run.staff.all()) == list(self.draft_course_run.staff.all()) assert new_course_run.language == self.draft_course_run.language assert list(new_course_run.transcript_languages.all()) == list(self.draft_course_run.transcript_languages.all()) + assert not hasattr(new_course_run, "restricted_run") @freeze_time("2022-01-14 12:00:01") @ddt.data(True, False, "bogus") @@ -320,6 +327,40 @@ def test_create_draft_ignored(self, draft): self.assertDictEqual(response.data, self.serialize_course_run(new_course_run)) assert new_course_run.draft + @ddt.data( + True, + False + ) + @freeze_time("2022-01-14 12:00:01") + @responses.activate + def test_create_restriction_type(self, is_restricted): + """ Verify the endpoint supports creating a course_run with a restriction_type. """ + course = self.draft_course_run.course + new_key = f'course-v1:{course.key_for_reruns}+1T2000' + self.mock_post_to_studio(new_key) + url = reverse('api:v1:course_run-list') + + post_data = { + 'course': course.key, + 'start': '2000-01-01T00:00:00Z', + 'end': '2001-01-01T00:00:00Z', + 'run_type': str(self.course_run_type.uuid), + } + + if is_restricted: + post_data['restriction_type'] = 'custom-b2c' + + response = self.client.post(url, post_data, format='json') + assert response.status_code == 201 + new_course_run = CourseRun.everything.get(key=new_key) + assert RestrictedCourseRun.everything.count() == ( + 1 if is_restricted else 0 + ) + if is_restricted: + assert new_course_run.restricted_run == RestrictedCourseRun.everything.get() + assert response.data['restriction_type'] == 'custom-b2c' + self.assertDictEqual(response.data, self.serialize_course_run(new_course_run)) + @freeze_time("2022-01-14 12:00:01") @responses.activate def test_create_using_type_with_price(self): @@ -526,6 +567,7 @@ def test_partial_update(self): 'max_effort': expected_max_effort, 'min_effort': expected_min_effort, 'variant_id': variant_id, + 'restriction_type': CourseRunRestrictionType.CustomB2BEnterprise.value } # Update this course_run with the new info @@ -539,6 +581,7 @@ def test_partial_update(self): assert self.draft_course_run.max_effort == expected_max_effort assert self.draft_course_run.min_effort == expected_min_effort assert self.draft_course_run.variant_id == prev_variant_id + assert self.draft_course_run.restricted_run == RestrictedCourseRun.everything.get() def test_partial_update_with_waffle_switch_variant_id_editable_enable(self): """ @@ -622,6 +665,10 @@ def test_partial_update_bad_permission(self): {'min_effort': 10000, 'max_effort': 10000}, 'Minimum effort and Maximum effort cannot be the same', ), + ( + {'restriction_type': 'foobar'}, + 'Not a valid choice for restriction_type' + ) ) @ddt.unpack def test_partial_update_common_errors(self, data, error): @@ -725,9 +772,10 @@ def test_patch_put_restrict_when_reviewing(self, status): 'start': self.draft_course_run.start, # required, so we need for a put 'end': self.draft_course_run.end, # required, so we need for a put 'run_type': str(self.draft_course_run.type.uuid), # required, so we need for a put + 'restriction_type': 'custom-b2c', }, format='json') assert response.status_code == 403 - + assert RestrictedCourseRun.everything.count() == 0 response = self.client.patch(url, {}, format='json') assert response.status_code == 403 @@ -902,6 +950,85 @@ def test_patch_published(self): assert draft_run.end == updated_end assert official_run.end == updated_end + @ddt.data( + (True, {'restriction_type': 'custom-b2b-enterprise'}, 'custom-b2b-enterprise'), + (True, {'restriction_type': ''}, None), + (True, {}, 'custom-b2c',), + (False, {'restriction_type': 'custom-b2c'}, 'custom-b2c'), + (False, {}, None), + ) + @ddt.unpack + @responses.activate + def test_patch_restriction_type(self, is_restricted, patch_data, changed_restriction_value): + """ + is_restriced: indicates if the run is restricted when the test starts. + If so, it is assigned the CustomB2C restriction + patch_data: data for the patch request + changed_restriction_value: expected restriction value after the patch request + + This test proceeds in 4 steps. First, we patch the course run and verify if + the restriction is created/updated successfully. Then we patch the course run to + the published state and verify that the restrictions are copied over to the official + versions. Thirdly, we update the restriction_type in the published state to verify changes + made in the published state. Finally, we remove the restrictions entirely. + """ + + self.mock_patch_to_studio(self.draft_course_run.key) + self.mock_ecommerce_publication() + + if is_restricted: + RestrictedCourseRunFactory( + course_run=self.draft_course_run, + restriction_type=CourseRunRestrictionType.CustomB2C.value, + draft=True + ) + url = reverse('api:v1:course_run-detail', kwargs={'key': self.draft_course_run.key}) + + response = self.client.patch(url, patch_data, format='json') + assert response.status_code == 200, f"Status {response.status_code}: {response.content}" + + self.draft_course_run.refresh_from_db() + assert hasattr(self.draft_course_run, 'restricted_run') == bool(changed_restriction_value) + if changed_restriction_value: + assert self.draft_course_run.restricted_run.restriction_type == changed_restriction_value + assert RestrictedCourseRun.everything.count() == 1 + else: + assert RestrictedCourseRun.everything.count() == 0 + + # Publish the course run and verify that official versions + # of RestrictedCourseRuns are created + self.draft_course_run.status = CourseRunStatus.InternalReview + self.draft_course_run.save() + assert CourseRun.objects.filter(key=self.draft_course_run.key, draft=False).count() == 0 + response = self.client.patch(url, {'status': 'reviewed'}, format='json') + assert response.status_code == 200, f"Status {response.status_code}: {response.content}" + + official_run = CourseRun.everything.get(key=self.draft_course_run.key, draft=False) + draft_run = official_run.draft_version + assert draft_run == self.draft_course_run + if not changed_restriction_value: + assert RestrictedCourseRun.everything.count() == 0 + else: + assert RestrictedCourseRun.everything.count() == 2 + assert official_run.restricted_run.draft_version == draft_run.restricted_run + assert official_run.restricted_run.restriction_type == draft_run.restricted_run.restriction_type + assert draft_run.restricted_run.restriction_type == changed_restriction_value + + # Make Changes while Published + response = self.client.patch(url, {'restriction_type': 'custom-b2b-enterprise', 'draft': False}, format='json') + assert response.status_code == 200, f"Status {response.status_code}: {response.content}" + official_run = CourseRun.everything.get(key=self.draft_course_run.key, draft=False) + draft_run = official_run.draft_version + assert RestrictedCourseRun.everything.count() == 2 + assert official_run.restricted_run.draft_version == draft_run.restricted_run + assert official_run.restricted_run.restriction_type == draft_run.restricted_run.restriction_type + assert draft_run.restricted_run.restriction_type == 'custom-b2b-enterprise' + + # Remove the restrictions + response = self.client.patch(url, {'restriction_type': '', 'draft': False}, format='json') + assert response.status_code == 200, f"Status {response.status_code}: {response.content}" + assert RestrictedCourseRun.everything.count() == 0 + def create_course_and_run_types(self, seat_type): tracks = [] entitlement_types = [] diff --git a/course_discovery/apps/api/v1/tests/test_views/test_courses.py b/course_discovery/apps/api/v1/tests/test_views/test_courses.py index cbcc314dc3..ae5a2018f6 100644 --- a/course_discovery/apps/api/v1/tests/test_views/test_courses.py +++ b/course_discovery/apps/api/v1/tests/test_views/test_courses.py @@ -25,8 +25,8 @@ from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus from course_discovery.apps.course_metadata.models import ( AbstractLocationRestrictionModel, AdditionalMetadata, CertificateInfo, Course, CourseEditor, CourseEntitlement, - CourseLocationRestriction, CourseRun, CourseRunType, CourseType, Fact, GeoLocation, ProductMeta, ProductValue, Seat, - Source + CourseLocationRestriction, CourseRun, CourseRunType, CourseType, Fact, GeoLocation, ProductMeta, ProductValue, + RestrictedCourseRun, Seat, Source ) from course_discovery.apps.course_metadata.signals import ( additional_metadata_facts_changed, connect_course_data_modified_timestamp_signal_handlers, @@ -769,12 +769,30 @@ def test_create_makes_editor(self): CourseEditor.objects.get(user=self.user, course=course) assert CourseEditor.objects.count() == 1 - def test_create_makes_course_and_course_run(self): + @ddt.data( + ({'restriction_type': 'custom-b2c'}, True), + ({}, False), + ) + @ddt.unpack + def test_create_makes_course_and_course_run(self, restriction_data, is_run_restricted): """ When creating a course and supplying a course_run, it should create both the course and course run as drafts. When mode = 'audit', an audit seat should also be created. + + The is_run_restricted param specifies if the created course run is restricted i.e has + an associated RestrictedCourseRun object """ - response = self.create_course_and_course_run() + + data = { + "course_run": { + "start": "2001-01-01T00:00:00Z", + "end": datetime.datetime.now() + datetime.timedelta(days=1), + "run_type": str(CourseRunType.objects.get(slug=CourseRunType.AUDIT).uuid), + } + } + data['course_run'] = {**data['course_run'], **restriction_data} + + response = self.create_course_and_course_run(data=data) assert response.status_code == 201 course = Course.everything.last() @@ -784,6 +802,11 @@ def test_create_makes_course_and_course_run(self): assert course_run.draft assert course_run.course == course + assert hasattr(course_run, 'restricted_run') == is_run_restricted + assert RestrictedCourseRun.everything.count() == ( + 1 if is_run_restricted else 0 + ) + # Creating with mode = 'audit' should also create an audit seat assert 1 == Seat.everything.count() seat = course_run.seats.first() diff --git a/course_discovery/apps/api/v1/views/course_runs.py b/course_discovery/apps/api/v1/views/course_runs.py index a7f9369694..cb8fe47bac 100644 --- a/course_discovery/apps/api/v1/views/course_runs.py +++ b/course_discovery/apps/api/v1/views/course_runs.py @@ -216,7 +216,7 @@ def create_run_helper(self, run_data, request=None): run_data.pop('draft', None) prices = run_data.pop('prices', {}) - + restriction_type = run_data.pop('restriction_type', None) # Grab any existing course run for this course (we'll use it when talking to studio to form basis of rerun) course_key = run_data.get('course', None) # required field if not course_key: @@ -235,6 +235,7 @@ def create_run_helper(self, run_data, request=None): course_run = serializer.save(draft=True) course_run.update_or_create_seats(course_run.type, prices) + course_run.update_or_create_restriction(restriction_type) # Set canonical course run if needed (done this way to match historical behavior - but shouldn't this be # updated *each* time we make a new run?) @@ -273,7 +274,7 @@ def create(self, request, *args, **kwargs): return response @writable_request_wrapper - def _update_course_run(self, course_run, draft, changed, serializer, request, prices, upgrade_deadline_override): + def _update_course_run(self, course_run, draft, changed, serializer, request, prices, upgrade_deadline_override, restriction_type=None): # pylint: disable=line-too-long save_kwargs = {} # If changes are made after review and before publish, revert status to unpublished. # Unless we're just switching the status @@ -293,7 +294,7 @@ def _update_course_run(self, course_run, draft, changed, serializer, request, pr if course_run in course_run.course.active_course_runs: course_run.update_or_create_seats(course_run.type, prices, upgrade_deadline_override,) - + course_run.update_or_create_restriction(restriction_type) self.push_to_studio(request, course_run, create=False) # Published course runs can be re-published directly or course runs that remain in the Reviewed @@ -333,6 +334,8 @@ def update(self, request, **kwargs): # Sending draft=False triggers the review process for unpublished courses draft = request.data.pop('draft', True) # Don't let draft parameter trickle down prices = request.data.pop('prices', {}) + restriction_type = request.data.pop('restriction_type', None) + upgrade_deadline_override = request.data.pop('upgrade_deadline_override', None) \ if self.request.user.is_staff else None @@ -367,7 +370,7 @@ def update(self, request, **kwargs): CourseRun.STATUS_CHANGE_EXEMPT_FIELDS ) response = self._update_course_run(course_run, draft, bool(changed_fields), - serializer, request, prices, upgrade_deadline_override,) + serializer, request, prices, upgrade_deadline_override, restriction_type) self.update_course_run_image_in_studio(course_run) diff --git a/course_discovery/apps/course_metadata/models.py b/course_discovery/apps/course_metadata/models.py index cdc31346a9..314f557074 100644 --- a/course_discovery/apps/course_metadata/models.py +++ b/course_discovery/apps/course_metadata/models.py @@ -2656,6 +2656,32 @@ def update_or_create_seats(self, run_type=None, prices=None, upgrade_deadline_ov self.seats.exclude(type__in=seat_types).delete() self.seats.set(seats) + def update_or_create_restriction(self, restriction_type): + """ + Updates or creates a CourseRunRestriction object for a draft course run + """ + + if restriction_type is None: + return + + if restriction_type and restriction_type not in CourseRunRestrictionType.values: + raise Exception('Not a valid choice for restriction_type') + + if not restriction_type and hasattr(self, "restricted_run"): + self.restricted_run.delete() + official_obj = self.official_version + if hasattr(official_obj, "restricted_run"): + official_obj.restricted_run.delete() + elif restriction_type: + RestrictedCourseRun.everything.update_or_create( + course_run=self, + draft=True, + defaults={ + 'restriction_type': restriction_type, + } + ) + self.refresh_from_db() + def update_or_create_official_version(self, notify_services=True): draft_version = CourseRun.everything.get(pk=self.pk) official_version = set_official_state(draft_version, CourseRun) @@ -2663,6 +2689,9 @@ def update_or_create_official_version(self, notify_services=True): for seat in self.seats.all(): set_official_state(seat, Seat, {'course_run': official_version}) + if hasattr(self, "restricted_run"): + set_official_state(self.restricted_run, RestrictedCourseRun, {'course_run': official_version}) + official_course = self.course._update_or_create_official_version(official_version) # pylint: disable=protected-access official_version.slug = self.slug official_version.course = official_course diff --git a/course_discovery/apps/course_metadata/tests/test_utils.py b/course_discovery/apps/course_metadata/tests/test_utils.py index 67cf955213..9898f953c5 100644 --- a/course_discovery/apps/course_metadata/tests/test_utils.py +++ b/course_discovery/apps/course_metadata/tests/test_utils.py @@ -28,13 +28,13 @@ EcommerceSiteAPIClientException, MarketingSiteAPIClientException ) from course_discovery.apps.course_metadata.models import ( - Course, CourseEditor, CourseRun, CourseType, CourseUrlSlug, Seat, SeatType, Track + Course, CourseEditor, CourseRun, CourseType, CourseUrlSlug, RestrictedCourseRun, Seat, SeatType, Track ) from course_discovery.apps.course_metadata.tests.constants import MOCK_PRODUCTS_DATA from course_discovery.apps.course_metadata.tests.factories import ( CourseEditorFactory, CourseEntitlementFactory, CourseFactory, CourseRunFactory, CourseTypeFactory, ModeFactory, - OrganizationFactory, OrganizationMappingFactory, PartnerFactory, ProgramFactory, SeatFactory, SeatTypeFactory, - SourceFactory, SubjectFactory + OrganizationFactory, OrganizationMappingFactory, PartnerFactory, ProgramFactory, RestrictedCourseRunFactory, + SeatFactory, SeatTypeFactory, SourceFactory, SubjectFactory ) from course_discovery.apps.course_metadata.tests.mixins import MarketingSiteAPIClientTestMixin from course_discovery.apps.course_metadata.toggles import ( @@ -557,12 +557,18 @@ def test_ensure_draft_world_not_draft_course_run_given(self): assert not_draft_course_run.draft_version == ensured_draft_course_run def test_ensure_draft_world_not_draft_course_given(self): + # pylint: disable=undefined-loop-variable course = CourseFactory() entitlement = CourseEntitlementFactory(course=course) course.entitlements.add(entitlement) course_runs = CourseRunFactory.create_batch(3, course=course) for run in course_runs: course.course_runs.add(run) + + RestrictedCourseRunFactory( + course_run=run, + restriction_type='custom-b2c' + ) course.canonical_course_run = course_runs[0] course.save() org = OrganizationFactory() @@ -616,6 +622,15 @@ def test_ensure_draft_world_not_draft_course_given(self): assert draft_entitlement.official_version == not_draft_entitlement assert not_draft_entitlement.draft_version == draft_entitlement + # Check restricted runs + run.refresh_from_db() + draft_run = run.draft_version + assert draft_run.restricted_run == run.restricted_run.draft_version + assert draft_run.restricted_run.restriction_type == 'custom-b2c' + assert run.restricted_run.restriction_type == 'custom-b2c' + assert RestrictedCourseRun.objects.count() == 1 + assert RestrictedCourseRun.everything.count() == 2 + def test_ensure_draft_world_creates_course_entitlement_from_seats(self): """ If the official course has no entitlement, an entitlement is created from the seat data from active runs. diff --git a/course_discovery/apps/course_metadata/utils.py b/course_discovery/apps/course_metadata/utils.py index 3c945db836..4c3549f961 100644 --- a/course_discovery/apps/course_metadata/utils.py +++ b/course_discovery/apps/course_metadata/utils.py @@ -85,7 +85,7 @@ def set_official_state(obj, model, attrs=None): the official version of that object with the attributes updated to attrs """ # pylint: disable=import-outside-toplevel - from course_discovery.apps.course_metadata.models import Course, CourseRun + from course_discovery.apps.course_metadata.models import Course, CourseRun, RestrictedCourseRun # This is so we don't create the marketing node with an incorrect slug. # We correct the slug after setting official state, but the AutoSlugField initially overwrites it. @@ -102,6 +102,11 @@ def set_official_state(obj, model, attrs=None): obj.draft_version = draft_version if isinstance(obj, Course): obj.canonical_course_run = official_obj.canonical_course_run if official_obj else None + # If an object has a one-to-one field, it is necessary to change that field + # before officiating(saving) it here, else we'll get an IntegrityError + # because of two different objects "being" 1-1 connected with the same object + if isinstance(obj, RestrictedCourseRun): + obj.course_run = attrs['course_run'] obj.save(**save_kwargs) official_obj = obj # Copy many-to-many fields manually (they are not copied by the pk trick above). @@ -264,7 +269,9 @@ def ensure_draft_world(obj): obj (Model object): The returned object will be the draft version on the input object. """ # pylint: disable=import-outside-toplevel - from course_discovery.apps.course_metadata.models import Course, CourseEntitlement, CourseRun, Seat + from course_discovery.apps.course_metadata.models import ( + Course, CourseEntitlement, CourseRun, RestrictedCourseRun, Seat + ) if obj.draft: return obj elif obj.draft_version: @@ -294,6 +301,8 @@ def ensure_draft_world(obj): for seat in original_run.seats.all(): set_draft_state(seat, Seat, {'course_run': draft_run}) + if hasattr(original_run, "restricted_run"): + set_draft_state(original_run.restricted_run, RestrictedCourseRun, {'course_run': draft_run}) if original_course.canonical_course_run and draft_run.uuid == original_course.canonical_course_run.uuid: draft_course.canonical_course_run = draft_run