diff --git a/cohort/migrations/0017_cohortresult_parent_cohort_and_more.py b/cohort/migrations/0017_cohortresult_parent_cohort_and_more.py new file mode 100644 index 00000000..ca70e82c --- /dev/null +++ b/cohort/migrations/0017_cohortresult_parent_cohort_and_more.py @@ -0,0 +1,24 @@ +# Generated by Django 5.0.10 on 2024-12-30 12:57 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('cohort', '0016_fhirfilter_auto_generated_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='cohortresult', + name='parent_cohort', + field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='sample_cohorts', to='cohort.cohortresult'), + ), + migrations.AddField( + model_name='cohortresult', + name='sampling_ratio', + field=models.FloatField(blank=True, null=True), + ), + ] diff --git a/cohort/models/cohort_result.py b/cohort/models/cohort_result.py index 10409072..c3b52d45 100644 --- a/cohort/models/cohort_result.py +++ b/cohort/models/cohort_result.py @@ -26,6 +26,8 @@ class CohortResult(CohortBaseModel, JobModel): create_task_id = models.TextField(blank=True) type = models.CharField(max_length=20, choices=COHORT_TYPES, default=MY_COHORTS_TYPE) is_subset = models.BooleanField(default=False) + parent_cohort = models.ForeignKey("CohortResult", related_name="sample_cohorts", on_delete=models.SET_NULL, null=True) + sampling_ratio = models.FloatField(blank=True, null=True) @property def result_size(self) -> int: diff --git a/cohort/serializers.py b/cohort/serializers.py index 9e83f5fd..28bb8fc8 100644 --- a/cohort/serializers.py +++ b/cohort/serializers.py @@ -83,8 +83,8 @@ class Meta: class CohortResultCreateSerializer(serializers.ModelSerializer): name = serializers.CharField(required=True) description = serializers.CharField(allow_blank=True, allow_null=True) - global_estimate = serializers.BooleanField(default=False) - request = PrimaryKeyRelatedFieldWithOwner(required=True, queryset=Request.objects.all()) + owner = UserPrimaryKeyRelatedField(queryset=User.objects.all(), required=False) + request = PrimaryKeyRelatedFieldWithOwner(required=False, queryset=Request.objects.all()) request_query_snapshot = PrimaryKeyRelatedFieldWithOwner(required=True, queryset=RequestQuerySnapshot.objects.all()) dated_measure = PrimaryKeyRelatedFieldWithOwner(required=True, queryset=DatedMeasure.objects.all()) @@ -92,12 +92,40 @@ class Meta: model = CohortResult fields = ["name", "description", - "global_estimate", + "owner", "request", "request_query_snapshot", "dated_measure"] +class SampledCohortResultCreateSerializer(serializers.ModelSerializer): + name = serializers.CharField(required=True) + description = serializers.CharField(allow_blank=True, allow_null=True) + owner = UserPrimaryKeyRelatedField(queryset=User.objects.all(), required=False) + parent_cohort = PrimaryKeyRelatedFieldWithOwner(required=True, queryset=CohortResult.objects.filter(parent_cohort__isnull=True)) + sampling_ratio = serializers.FloatField(required=True) + + class Meta: + model = CohortResult + fields = ["name", + "description", + "owner", + "parent_cohort", + "sampling_ratio"] + + def validate_sampling_ratio(self, value): + if not 0 < value < 1: + raise serializers.ValidationError("Sampling ratio must be between 0 and 1") + return value + + def create(self, validated_data): + parent_cohort = validated_data.get("parent_cohort") + validated_data.update({"request_query_snapshot": parent_cohort.request_query_snapshot, + "dated_measure": parent_cohort.dated_measure + }) + return super().create(validated_data) + + class CohortResultPatchSerializer(serializers.ModelSerializer): name = serializers.CharField(required=False) description = serializers.CharField(required=False) diff --git a/cohort/services/cohort_result.py b/cohort/services/cohort_result.py index 6ec09892..2bf82a96 100644 --- a/cohort/services/cohort_result.py +++ b/cohort/services/cohort_result.py @@ -68,12 +68,12 @@ def handle_cohort_creation(self, cohort: CohortResult, request) -> None: if request.data.pop("global_estimate", False): dm_service.handle_global_count(cohort, request) try: - create_cohort.s(cohort_id=cohort.pk, - json_query=cohort.request_query_snapshot.serialized_query, - auth_headers=get_authorization_header(request), - cohort_creator_cls=self.operator_cls) \ - .apply_async() - + create_cohort.s(cohort_id=cohort.pk, + json_query=cohort.request_query_snapshot.serialized_query, + auth_headers=get_authorization_header(request), + cohort_creator_cls=self.operator_cls, + sampling_ratio=cohort.sampling_ratio) \ + .apply_async() except Exception as e: cohort.delete() raise ServerError("Could not launch cohort creation") from e diff --git a/cohort/tests/tests_view_cohort_result.py b/cohort/tests/tests_view_cohort_result.py index f33746b1..92667e1c 100644 --- a/cohort/tests/tests_view_cohort_result.py +++ b/cohort/tests/tests_view_cohort_result.py @@ -188,32 +188,6 @@ def __init__(self, mock_create_task_called: bool, **kwargs): class CohortsCreateTests(CohortsTests): - @mock.patch('cohort.services.cohort_result.get_authorization_header') - @mock.patch('cohort.services.cohort_result.create_cohort.apply_async') - @mock.patch('cohort.services.dated_measure.count_cohort.apply_async') - def check_create_case_with_mock(self, case: CohortCreateCase, mock_count_task: MagicMock, mock_create_task: MagicMock, - mock_header: MagicMock, other_view: any, view_kwargs: dict): - mock_header.return_value = None - mock_create_task.return_value = None - mock_count_task.return_value = None - - with self.captureOnCommitCallbacks(execute=True): - super(CohortsCreateTests, self).check_create_case(case, other_view, **(view_kwargs or {})) - - if case.success: - inst = self.model_objects.filter(**case.retrieve_filter.args)\ - .exclude(**case.retrieve_filter.exclude).first() - self.assertIsNotNone(inst.dated_measure) - - if case.data.get('global_estimate'): - self.assertIsNotNone(inst.dated_measure_global) - mock_create_task.assert_called() if case.mock_create_task_called else mock_create_task.assert_not_called() - - mock_create_task.assert_called() if case.mock_create_task_called else mock_create_task.assert_not_called() - mock_header.assert_called() if case.mock_create_task_called else mock_header.assert_not_called() - - def check_create_case(self, case: CohortCreateCase, other_view: Any = None, **view_kwargs): - return self.check_create_case_with_mock(case, other_view=other_view or None, view_kwargs=view_kwargs) def setUp(self): super(CohortsCreateTests, self).setUp() @@ -231,11 +205,17 @@ def setUp(self): favorite=True, request_query_snapshot=self.user1_req1_snap1.pk, dated_measure=self.user1_req1_snap1_dm.pk, - # group_id - # dated_measure_global - # create_task_id - # type ) + parent_cohort = CohortResult.objects.create(name="Parent cohort for sampling", + description="Parent cohort for sampling", + owner=self.user1, + request_query_snapshot=self.user1_req1_snap1, + dated_measure=self.user1_req1_snap1_dm) + + self.basic_data_for_sampled_cohort = dict(name="Sampled Cohort", + description="Sampled Cohort", + parent_cohort=parent_cohort.uuid, + sampling_ratio=0.3) self.basic_case = CohortCreateCase( data=self.basic_data, status=status.HTTP_201_CREATED, @@ -251,20 +231,42 @@ def setUp(self): status=status.HTTP_400_BAD_REQUEST, ) + @mock.patch('cohort.services.cohort_result.get_authorization_header') + @mock.patch('cohort.services.cohort_result.create_cohort.apply_async') + @mock.patch('cohort.services.dated_measure.count_cohort.apply_async') + def check_create_case_with_mock(self, case: CohortCreateCase, mock_count_task: MagicMock, mock_create_task: MagicMock, + mock_header: MagicMock, other_view: any, view_kwargs: dict): + mock_header.return_value = None + mock_create_task.return_value = None + mock_count_task.return_value = None + + with self.captureOnCommitCallbacks(execute=True): + super(CohortsCreateTests, self).check_create_case(case, other_view, **(view_kwargs or {})) + + if case.success: + inst = self.model_objects.filter(**case.retrieve_filter.args)\ + .exclude(**case.retrieve_filter.exclude).first() + self.assertIsNotNone(inst.dated_measure) + + if case.data.get('global_estimate'): + self.assertIsNotNone(inst.dated_measure_global) + mock_create_task.assert_called() if case.mock_create_task_called else mock_create_task.assert_not_called() + + mock_create_task.assert_called() if case.mock_create_task_called else mock_create_task.assert_not_called() + mock_header.assert_called() if case.mock_create_task_called else mock_header.assert_not_called() + + def check_create_case(self, case: CohortCreateCase, other_view: Any = None, **view_kwargs): + return self.check_create_case_with_mock(case, other_view=other_view or None, view_kwargs=view_kwargs) + def test_create(self): - # As a user, I can create a DatedMeasure with only RQS, - # it will launch a task self.check_create_case(self.basic_case) def test_create_with_global(self): - # As a user, I can create a DatedMeasure with only RQS, - # it will launch a task self.check_create_case(self.basic_case.clone( data={**self.basic_data, 'global_estimate': True} )) def test_create_with_unread_fields(self): - # As a user, I can create a dm self.check_create_case(self.basic_case.clone( data={**self.basic_data, 'create_task_id': random_str(5), @@ -277,24 +279,13 @@ def test_create_with_unread_fields(self): )) def test_error_create_missing_field(self): - # As a user, I cannot create a dm if some field is missing - cases = (self.basic_err_case.clone( - data={**self.basic_data, k: None}, + case = self.basic_err_case.clone( + data={**self.basic_data, "request_query_snapshot": None}, success=False, - status=status.HTTP_400_BAD_REQUEST, - ) for k in ['request_query_snapshot']) - [self.check_create_case(case) for case in cases] + status=status.HTTP_400_BAD_REQUEST) + self.check_create_case(case) - def test_error_create_with_other_owner(self): - # As a user, I cannot create a cohort providing another user as owner - self.check_create_case(self.basic_err_case.clone( - data={**self.basic_data, 'owner': self.user2.pk}, - status=status.HTTP_400_BAD_REQUEST, - success=False, - )) - - def test_error_create_on_rqs_not_owned(self): - # As a user, I cannot create a dm on a Rqs I don't own + def test_error_create_with_rqs_not_owned(self): self.check_create_case(self.basic_err_case.clone( data={**self.basic_data, 'request_query_snapshot': self.user2_req1_snap1.pk}, @@ -302,6 +293,52 @@ def test_error_create_on_rqs_not_owned(self): status=status.HTTP_400_BAD_REQUEST, )) + def test_successfully_create_sampled_cohort(self): + self.check_create_case( + self.basic_case.clone(data=self.basic_data_for_sampled_cohort, + retrieve_filter=CohortCaseRetrieveFilter(name=self.basic_data_for_sampled_cohort.get("name")) + )) + + def test_error_create_sampled_cohort_with_ratio_0(self): + data = {**self.basic_data_for_sampled_cohort, + "name": "Sampled Cohort with ratio 0.0", + "sampling_ratio": 0.0 + } + self.check_create_case( + self.basic_err_case.clone(data=data, + retrieve_filter=CohortCaseRetrieveFilter(name=data.get("name")) + )) + + def test_error_create_sampled_cohort_with_ratio_1(self): + data = {**self.basic_data_for_sampled_cohort, + "name": "Sampled Cohort with ratio 1.0", + "sampling_ratio": 1.0 + } + self.check_create_case( + self.basic_err_case.clone(data=data, + retrieve_filter=CohortCaseRetrieveFilter(name=data.get("name")) + )) + + def test_error_create_sampled_cohort_with_ratio_gt_1(self): + data = {**self.basic_data_for_sampled_cohort, + "name": "Sampled Cohort with ratio gt 1", + "sampling_ratio": 2.5 + } + self.check_create_case( + self.basic_err_case.clone(data=data, + retrieve_filter=CohortCaseRetrieveFilter(name=data.get("name")) + )) + + def test_error_create_sampled_cohort_with_ratio_lt_0(self): + data = {**self.basic_data_for_sampled_cohort, + "name": "Sampled Cohort with ratio lt 0", + "sampling_ratio": -1.5 + } + self.check_create_case( + self.basic_err_case.clone(data=data, + retrieve_filter=CohortCaseRetrieveFilter(name=data.get("name")) + )) + class CohortsDeleteTests(CohortsTests): @@ -383,7 +420,6 @@ def test_update_cohort_as_owner(self, mock_patch_handler, mock_post_update, mock data_to_update = dict(name="new_name", description="new_desc", request_job_status=JobStatus.failed, - request_job_fail_msg="test_fail_msg", # read_only create_task_id="test_task_id", request_job_id="test_job_id", diff --git a/cohort/views/cohort_result.py b/cohort/views/cohort_result.py index 3d7a0c17..0fdef7a3 100644 --- a/cohort/views/cohort_result.py +++ b/cohort/views/cohort_result.py @@ -1,7 +1,7 @@ from django.db import transaction from django.db.models import Q, F from django_filters import rest_framework as filters, OrderingFilter -from drf_spectacular.utils import extend_schema +from drf_spectacular.utils import extend_schema, PolymorphicProxySerializer from rest_framework import status from rest_framework.decorators import action from rest_framework.permissions import AllowAny @@ -13,7 +13,7 @@ from cohort.services.cohort_result import cohort_service from cohort.models import CohortResult from cohort.serializers import CohortResultSerializer, CohortResultSerializerFullDatedMeasure, CohortResultCreateSerializer, \ - CohortResultPatchSerializer, CohortRightsSerializer + CohortResultPatchSerializer, CohortRightsSerializer, SampledCohortResultCreateSerializer from cohort.services.cohort_rights import cohort_rights_service from cohort.views.shared import UserObjectsRestrictedViewSet from exports.services.export import export_service @@ -52,7 +52,8 @@ class Meta: 'favorite', 'group_id', 'request_id', - 'status') + 'status', + 'parent_cohort') class CohortResultViewSet(NestedViewSetMixin, UserObjectsRestrictedViewSet): @@ -81,19 +82,27 @@ def get_queryset(self): return super().get_queryset().filter(is_subset=False) def get_serializer_class(self): - if self.request.method in ["POST", "PUT", "PATCH"] \ - and "dated_measure" in self.request.data \ - and isinstance(self.request.data["dated_measure"], dict) \ - or self.request.method == "GET": + if self.request.method == "GET": return CohortResultSerializerFullDatedMeasure - return self.serializer_class + elif self.request.method == "PATCH": + return CohortResultPatchSerializer + elif self.request.method == "POST": + if CohortResult.sampling_ratio.field.name in self.request.data: + return SampledCohortResultCreateSerializer + return CohortResultCreateSerializer + else: + return self.serializer_class @cache_response() def list(self, request, *args, **kwargs): return super().list(request, *args, **kwargs) - @extend_schema(request=CohortResultCreateSerializer, - responses={status.HTTP_201_CREATED: CohortResultSerializer}) + @extend_schema( + request=PolymorphicProxySerializer( + component_name="CreateCohortResult", + resource_type_field_name=None, + serializers=[CohortResultCreateSerializer, SampledCohortResultCreateSerializer]), + responses={status.HTTP_201_CREATED: CohortResultSerializer}) @transaction.atomic def create(self, request, *args, **kwargs): response = super().create(request, *args, **kwargs) diff --git a/cohort_job_server/cohort_creator.py b/cohort_job_server/cohort_creator.py index 18f89900..84c8c20f 100644 --- a/cohort_job_server/cohort_creator.py +++ b/cohort_job_server/cohort_creator.py @@ -12,15 +12,21 @@ class CohortCreator(BaseCohortOperator): - def launch_cohort_creation(self, cohort_id: Optional[str], json_query: str, auth_headers: dict, callback_path: Optional[str] = None, + def launch_cohort_creation(self, + cohort_id: Optional[str], + json_query: str, + auth_headers: dict, + callback_path: Optional[str] = None, existing_cohort_id: Optional[int] = None, - owner_username: Optional[str] = None) -> None: + owner_username: Optional[str] = None, + sampling_ratio: Optional[float] = None) -> None: self.sjs_requester.launch_request(CohortCreate(instance_id=cohort_id, json_query=json_query, auth_headers=auth_headers, callback_path=callback_path, owner_username=owner_username, - existing_cohort_id=existing_cohort_id + existing_cohort_id=existing_cohort_id, + sampling_ratio=sampling_ratio )) @staticmethod diff --git a/cohort_job_server/sjs_api/cohort_requests/base_cohort_request.py b/cohort_job_server/sjs_api/cohort_requests/base_cohort_request.py index efe9168a..c60b7848 100644 --- a/cohort_job_server/sjs_api/cohort_requests/base_cohort_request.py +++ b/cohort_job_server/sjs_api/cohort_requests/base_cohort_request.py @@ -25,7 +25,8 @@ def __init__(self, mode: Mode, auth_headers: dict, callback_path: str = None, existing_cohort_id: int = None, - owner_username: str = None): + owner_username: str = None, + sampling_ratio: Optional[float] = None): self.mode = mode self.instance_id = instance_id self.json_query = json_query @@ -34,6 +35,7 @@ def __init__(self, mode: Mode, self.callback_path = callback_path self.owner_username = owner_username self.existing_cohort_id = existing_cohort_id + self.sampling_ratio = sampling_ratio @staticmethod def is_cohort_request_pseudo_read(username: str, source_population: List[int]) -> bool: @@ -55,12 +57,13 @@ def create_sjs_request(self, cohort_query: CohortQuery) -> str: callback_path = self.callback_path or ( self.mode == Mode.COUNT_WITH_DETAILS and f"/cohort/feasibility-studies/{cohort_query.instance_id}/" or None) - spark_job_request = SparkJobObject(cohort_definition_name="Created from Django", + spark_job_request = SparkJobObject(cohort_definition_name="Created from C360 backend", cohort_definition_syntax=cohort_query, mode=self.mode, owner_entity_id=self.owner_username, callbackPath=callback_path, - existingCohortId=self.existing_cohort_id + existingCohortId=self.existing_cohort_id, + samplingRatio=self.sampling_ratio ) return format_spark_job_request_for_sjs(spark_job_request) diff --git a/cohort_job_server/sjs_api/schemas.py b/cohort_job_server/sjs_api/schemas.py index f9adcd84..6276446b 100644 --- a/cohort_job_server/sjs_api/schemas.py +++ b/cohort_job_server/sjs_api/schemas.py @@ -51,7 +51,6 @@ class TemporalConstraint(BaseModel): dates_are_not_null: list = Field(default=None, alias="dateIsNotNullList") filtered_criteria_id: list = Field(default=None, alias="filteredCriteriaIdList") - class SourcePopulation(BaseModel): care_site_cohort_list: list[int] = Field(default_factory=list, alias="caresiteCohortList") @@ -101,6 +100,7 @@ class SparkJobObject: owner_entity_id: str callbackPath: Optional[str] = Field(None, alias='callbackPath') existingCohortId: Optional[int] = Field(None, alias='existingCohortId') + samplingRatio: Optional[float] = Field(None, alias='samplingRatio') class FhirParameter(BaseModel):