From 2b31c042b8e030bbc5dce8ccd42cda6bee3a512d Mon Sep 17 00:00:00 2001
From: Hicham <hicham.taroq-ext@aphp.fr>
Date: Mon, 30 Dec 2024 17:39:18 +0100
Subject: [PATCH] feat(cohort): cohorts sampling

---
 ...017_cohortresult_parent_cohort_and_more.py |  24 +++
 cohort/models/cohort_result.py                |   2 +
 cohort/serializers.py                         |  34 ++++-
 cohort/services/cohort_result.py              |  12 +-
 cohort/tests/tests_view_cohort_result.py      | 140 +++++++++++-------
 cohort/views/cohort_result.py                 |  29 ++--
 cohort_job_server/cohort_creator.py           |  12 +-
 .../cohort_requests/base_cohort_request.py    |   9 +-
 cohort_job_server/sjs_api/schemas.py          |   2 +-
 9 files changed, 186 insertions(+), 78 deletions(-)
 create mode 100644 cohort/migrations/0017_cohortresult_parent_cohort_and_more.py

diff --git a/cohort/migrations/0017_cohortresult_parent_cohort_and_more.py b/cohort/migrations/0017_cohortresult_parent_cohort_and_more.py
new file mode 100644
index 00000000..ca70e82c
--- /dev/null
+++ b/cohort/migrations/0017_cohortresult_parent_cohort_and_more.py
@@ -0,0 +1,24 @@
+# Generated by Django 5.0.10 on 2024-12-30 12:57
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    """Add the ``parent_cohort`` and ``sampling_ratio`` fields to the ``cohortresult`` model."""
+    dependencies = [
+        ('cohort', '0016_fhirfilter_auto_generated_and_more'),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name='cohortresult',
+            name='parent_cohort',  # self-referencing FK, set to NULL when the referenced cohort is deleted
+            field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='sample_cohorts', to='cohort.cohortresult'),
+        ),
+        migrations.AddField(
+            model_name='cohortresult',
+            name='sampling_ratio',  # nullable float column
+            field=models.FloatField(blank=True, null=True),
+        ),
+    ]
diff --git a/cohort/models/cohort_result.py b/cohort/models/cohort_result.py
index 10409072..c3b52d45 100644
--- a/cohort/models/cohort_result.py
+++ b/cohort/models/cohort_result.py
@@ -26,6 +26,8 @@ class CohortResult(CohortBaseModel, JobModel):
     create_task_id = models.TextField(blank=True)
     type = models.CharField(max_length=20, choices=COHORT_TYPES, default=MY_COHORTS_TYPE)
     is_subset = models.BooleanField(default=False)
+    parent_cohort = models.ForeignKey("CohortResult", related_name="sample_cohorts", on_delete=models.SET_NULL, null=True)
+    sampling_ratio = models.FloatField(blank=True, null=True)
 
     @property
     def result_size(self) -> int:
diff --git a/cohort/serializers.py b/cohort/serializers.py
index 9e83f5fd..28bb8fc8 100644
--- a/cohort/serializers.py
+++ b/cohort/serializers.py
@@ -83,8 +83,8 @@ class Meta:
 class CohortResultCreateSerializer(serializers.ModelSerializer):
     name = serializers.CharField(required=True)
     description = serializers.CharField(allow_blank=True, allow_null=True)
-    global_estimate = serializers.BooleanField(default=False)
-    request = PrimaryKeyRelatedFieldWithOwner(required=True, queryset=Request.objects.all())
+    owner = UserPrimaryKeyRelatedField(queryset=User.objects.all(), required=False)
+    request = PrimaryKeyRelatedFieldWithOwner(required=False, queryset=Request.objects.all())
     request_query_snapshot = PrimaryKeyRelatedFieldWithOwner(required=True, queryset=RequestQuerySnapshot.objects.all())
     dated_measure = PrimaryKeyRelatedFieldWithOwner(required=True, queryset=DatedMeasure.objects.all())
 
@@ -92,12 +92,40 @@ class Meta:
         model = CohortResult
         fields = ["name",
                   "description",
-                  "global_estimate",
+                  "owner",
                   "request",
                   "request_query_snapshot",
                   "dated_measure"]
 
 
+class SampledCohortResultCreateSerializer(serializers.ModelSerializer):
+    """Create a cohort sampled from an existing parent cohort at a given ratio."""
+    name = serializers.CharField(required=True)
+    description = serializers.CharField(allow_blank=True, allow_null=True)
+    owner = UserPrimaryKeyRelatedField(queryset=User.objects.all(), required=False)
+    # Candidate parents are restricted to cohorts that have no parent themselves.
+    parent_cohort = PrimaryKeyRelatedFieldWithOwner(required=True, queryset=CohortResult.objects.filter(parent_cohort__isnull=True))
+    sampling_ratio = serializers.FloatField(required=True)
+
+    class Meta:
+        model = CohortResult
+        fields = ["name", "description", "owner", "parent_cohort", "sampling_ratio"]
+
+    def validate_sampling_ratio(self, value):
+        """Accept only ratios in the open interval (0, 1); both bounds are rejected."""
+        if not 0 < value < 1:
+            raise serializers.ValidationError("Sampling ratio must be strictly between 0 and 1 (0 and 1 excluded)")
+        return value
+
+    def create(self, validated_data):
+        """Create the sampled cohort, reusing the parent's query snapshot and dated measure."""
+        parent_cohort = validated_data["parent_cohort"]
+        # parent_cohort is a required field, so it is always present once validation has passed.
+        validated_data["request_query_snapshot"] = parent_cohort.request_query_snapshot
+        validated_data["dated_measure"] = parent_cohort.dated_measure
+        return super().create(validated_data)
+
+
 class CohortResultPatchSerializer(serializers.ModelSerializer):
     name = serializers.CharField(required=False)
     description = serializers.CharField(required=False)
diff --git a/cohort/services/cohort_result.py b/cohort/services/cohort_result.py
index 6ec09892..2bf82a96 100644
--- a/cohort/services/cohort_result.py
+++ b/cohort/services/cohort_result.py
@@ -68,12 +68,12 @@ def handle_cohort_creation(self, cohort: CohortResult, request) -> None:
         if request.data.pop("global_estimate", False):
             dm_service.handle_global_count(cohort, request)
         try:
-            create_cohort.s(cohort_id=cohort.pk,
-                            json_query=cohort.request_query_snapshot.serialized_query,
-                            auth_headers=get_authorization_header(request),
-                            cohort_creator_cls=self.operator_cls) \
-                .apply_async()
-
+            create_cohort.s(
+                cohort_id=cohort.pk,
+                json_query=cohort.request_query_snapshot.serialized_query,
+                auth_headers=get_authorization_header(request),
+                cohort_creator_cls=self.operator_cls,
+                sampling_ratio=cohort.sampling_ratio).apply_async()
         except Exception as e:
             cohort.delete()
             raise ServerError("Could not launch cohort creation") from e
diff --git a/cohort/tests/tests_view_cohort_result.py b/cohort/tests/tests_view_cohort_result.py
index f33746b1..92667e1c 100644
--- a/cohort/tests/tests_view_cohort_result.py
+++ b/cohort/tests/tests_view_cohort_result.py
@@ -188,32 +188,6 @@ def __init__(self, mock_create_task_called: bool, **kwargs):
 
 
 class CohortsCreateTests(CohortsTests):
-    @mock.patch('cohort.services.cohort_result.get_authorization_header')
-    @mock.patch('cohort.services.cohort_result.create_cohort.apply_async')
-    @mock.patch('cohort.services.dated_measure.count_cohort.apply_async')
-    def check_create_case_with_mock(self, case: CohortCreateCase, mock_count_task: MagicMock, mock_create_task: MagicMock,
-                                    mock_header: MagicMock, other_view: any, view_kwargs: dict):
-        mock_header.return_value = None
-        mock_create_task.return_value = None
-        mock_count_task.return_value = None
-
-        with self.captureOnCommitCallbacks(execute=True):
-            super(CohortsCreateTests, self).check_create_case(case, other_view, **(view_kwargs or {}))
-
-        if case.success:
-            inst = self.model_objects.filter(**case.retrieve_filter.args)\
-                                     .exclude(**case.retrieve_filter.exclude).first()
-            self.assertIsNotNone(inst.dated_measure)
-
-            if case.data.get('global_estimate'):
-                self.assertIsNotNone(inst.dated_measure_global)
-                mock_create_task.assert_called() if case.mock_create_task_called else mock_create_task.assert_not_called()
-
-        mock_create_task.assert_called() if case.mock_create_task_called else mock_create_task.assert_not_called()
-        mock_header.assert_called() if case.mock_create_task_called else mock_header.assert_not_called()
-
-    def check_create_case(self, case: CohortCreateCase, other_view: Any = None, **view_kwargs):
-        return self.check_create_case_with_mock(case, other_view=other_view or None, view_kwargs=view_kwargs)
 
     def setUp(self):
         super(CohortsCreateTests, self).setUp()
@@ -231,11 +205,17 @@ def setUp(self):
             favorite=True,
             request_query_snapshot=self.user1_req1_snap1.pk,
             dated_measure=self.user1_req1_snap1_dm.pk,
-            # group_id
-            # dated_measure_global
-            # create_task_id
-            # type
         )
+        parent_cohort = CohortResult.objects.create(name="Parent cohort for sampling",
+                                                    description="Parent cohort for sampling",
+                                                    owner=self.user1,
+                                                    request_query_snapshot=self.user1_req1_snap1,
+                                                    dated_measure=self.user1_req1_snap1_dm)
+
+        self.basic_data_for_sampled_cohort = dict(name="Sampled Cohort",
+                                                  description="Sampled Cohort",
+                                                  parent_cohort=parent_cohort.uuid,
+                                                  sampling_ratio=0.3)
         self.basic_case = CohortCreateCase(
             data=self.basic_data,
             status=status.HTTP_201_CREATED,
@@ -251,20 +231,42 @@ def setUp(self):
             status=status.HTTP_400_BAD_REQUEST,
         )
 
+    @mock.patch('cohort.services.cohort_result.get_authorization_header')
+    @mock.patch('cohort.services.cohort_result.create_cohort.apply_async')
+    @mock.patch('cohort.services.dated_measure.count_cohort.apply_async')
+    def check_create_case_with_mock(self, case: CohortCreateCase, mock_count_task: MagicMock, mock_create_task: MagicMock,
+                                    mock_header: MagicMock, other_view: Any, view_kwargs: dict):
+        """Run the create case with the task launchers and auth-header lookup patched, then check the mocks."""
+        mock_header.return_value = None
+        mock_create_task.return_value = None
+        mock_count_task.return_value = None
+
+        with self.captureOnCommitCallbacks(execute=True):
+            super(CohortsCreateTests, self).check_create_case(case, other_view, **(view_kwargs or {}))
+
+        if case.success:
+            inst = self.model_objects.filter(**case.retrieve_filter.args)\
+                                     .exclude(**case.retrieve_filter.exclude).first()
+            self.assertIsNotNone(inst.dated_measure)
+            if case.data.get('global_estimate'):
+                self.assertIsNotNone(inst.dated_measure_global)
+
+        # The creation task and the header lookup must have been called iff the case expects a launch.
+        mock_create_task.assert_called() if case.mock_create_task_called else mock_create_task.assert_not_called()
+        mock_header.assert_called() if case.mock_create_task_called else mock_header.assert_not_called()
+
+    def check_create_case(self, case: CohortCreateCase, other_view: Any = None, **view_kwargs):
+        return self.check_create_case_with_mock(case, other_view=other_view or None, view_kwargs=view_kwargs)  # mocks are injected by the patched helper
+
     def test_create(self):
-        # As a user, I can create a DatedMeasure with only RQS,
-        # it will launch a task
         self.check_create_case(self.basic_case)
 
     def test_create_with_global(self):
-        # As a user, I can create a DatedMeasure with only RQS,
-        # it will launch a task
         self.check_create_case(self.basic_case.clone(
             data={**self.basic_data, 'global_estimate': True}
         ))
 
     def test_create_with_unread_fields(self):
-        # As a user, I can create a dm
         self.check_create_case(self.basic_case.clone(
             data={**self.basic_data,
                   'create_task_id': random_str(5),
@@ -277,24 +279,13 @@ def test_create_with_unread_fields(self):
         ))
 
     def test_error_create_missing_field(self):
-        # As a user, I cannot create a dm if some field is missing
-        cases = (self.basic_err_case.clone(
-            data={**self.basic_data, k: None},
+        case = self.basic_err_case.clone(
+            data={**self.basic_data, "request_query_snapshot": None},  # a null snapshot is expected to yield a 400
             success=False,
-            status=status.HTTP_400_BAD_REQUEST,
-        ) for k in ['request_query_snapshot'])
-        [self.check_create_case(case) for case in cases]
+            status=status.HTTP_400_BAD_REQUEST)
+        self.check_create_case(case)
 
-    def test_error_create_with_other_owner(self):
-        # As a user, I cannot create a cohort providing another user as owner
-        self.check_create_case(self.basic_err_case.clone(
-            data={**self.basic_data, 'owner': self.user2.pk},
-            status=status.HTTP_400_BAD_REQUEST,
-            success=False,
-        ))
-
-    def test_error_create_on_rqs_not_owned(self):
-        # As a user, I cannot create a dm on a Rqs I don't own
+    def test_error_create_with_rqs_not_owned(self):
         self.check_create_case(self.basic_err_case.clone(
             data={**self.basic_data,
                   'request_query_snapshot': self.user2_req1_snap1.pk},
@@ -302,6 +293,52 @@ def test_error_create_on_rqs_not_owned(self):
             status=status.HTTP_400_BAD_REQUEST,
         ))
 
+    def test_successfully_create_sampled_cohort(self):
+        """A sampled cohort can be created from the parent cohort prepared in ``setUp``."""
+        data = self.basic_data_for_sampled_cohort
+        case = self.basic_case.clone(data=data, retrieve_filter=CohortCaseRetrieveFilter(name=data.get("name")))
+        self.check_create_case(case)
+
+    def test_error_create_sampled_cohort_with_ratio_0(self):
+        """Creating a sampled cohort with a ratio of exactly 0 is expected to fail."""
+        data = {
+            **self.basic_data_for_sampled_cohort,
+            "name": "Sampled Cohort with ratio 0.0",
+            "sampling_ratio": 0.0,
+        }
+        case = self.basic_err_case.clone(data=data, retrieve_filter=CohortCaseRetrieveFilter(name=data.get("name")))
+        self.check_create_case(case)
+
+    def test_error_create_sampled_cohort_with_ratio_1(self):
+        """Creating a sampled cohort with a ratio of exactly 1 is expected to fail."""
+        data = {
+            **self.basic_data_for_sampled_cohort,
+            "name": "Sampled Cohort with ratio 1.0",
+            "sampling_ratio": 1.0,
+        }
+        case = self.basic_err_case.clone(data=data, retrieve_filter=CohortCaseRetrieveFilter(name=data.get("name")))
+        self.check_create_case(case)
+
+    def test_error_create_sampled_cohort_with_ratio_gt_1(self):
+        """Creating a sampled cohort with a ratio above 1 is expected to fail."""
+        data = {
+            **self.basic_data_for_sampled_cohort,
+            "name": "Sampled Cohort with ratio gt 1",
+            "sampling_ratio": 2.5,
+        }
+        case = self.basic_err_case.clone(data=data, retrieve_filter=CohortCaseRetrieveFilter(name=data.get("name")))
+        self.check_create_case(case)
+
+    def test_error_create_sampled_cohort_with_ratio_lt_0(self):
+        """Creating a sampled cohort with a negative ratio is expected to fail."""
+        data = {
+            **self.basic_data_for_sampled_cohort,
+            "name": "Sampled Cohort with ratio lt 0",
+            "sampling_ratio": -1.5,
+        }
+        case = self.basic_err_case.clone(data=data, retrieve_filter=CohortCaseRetrieveFilter(name=data.get("name")))
+        self.check_create_case(case)
+
 
 class CohortsDeleteTests(CohortsTests):
 
@@ -383,7 +420,6 @@ def test_update_cohort_as_owner(self, mock_patch_handler, mock_post_update, mock
         data_to_update = dict(name="new_name",
                               description="new_desc",
                               request_job_status=JobStatus.failed,
-                              request_job_fail_msg="test_fail_msg",
                               # read_only
                               create_task_id="test_task_id",
                               request_job_id="test_job_id",
diff --git a/cohort/views/cohort_result.py b/cohort/views/cohort_result.py
index 3d7a0c17..0fdef7a3 100644
--- a/cohort/views/cohort_result.py
+++ b/cohort/views/cohort_result.py
@@ -1,7 +1,7 @@
 from django.db import transaction
 from django.db.models import Q, F
 from django_filters import rest_framework as filters, OrderingFilter
-from drf_spectacular.utils import extend_schema
+from drf_spectacular.utils import extend_schema, PolymorphicProxySerializer
 from rest_framework import status
 from rest_framework.decorators import action
 from rest_framework.permissions import AllowAny
@@ -13,7 +13,7 @@
 from cohort.services.cohort_result import cohort_service
 from cohort.models import CohortResult
 from cohort.serializers import CohortResultSerializer, CohortResultSerializerFullDatedMeasure, CohortResultCreateSerializer, \
-    CohortResultPatchSerializer, CohortRightsSerializer
+    CohortResultPatchSerializer, CohortRightsSerializer, SampledCohortResultCreateSerializer
 from cohort.services.cohort_rights import cohort_rights_service
 from cohort.views.shared import UserObjectsRestrictedViewSet
 from exports.services.export import export_service
@@ -52,7 +52,8 @@ class Meta:
                   'favorite',
                   'group_id',
                   'request_id',
-                  'status')
+                  'status',
+                  'parent_cohort')
 
 
 class CohortResultViewSet(NestedViewSetMixin, UserObjectsRestrictedViewSet):
@@ -81,19 +82,27 @@ def get_queryset(self):
         return super().get_queryset().filter(is_subset=False)
 
     def get_serializer_class(self):
-        if self.request.method in ["POST", "PUT", "PATCH"] \
-                and "dated_measure" in self.request.data \
-                and isinstance(self.request.data["dated_measure"], dict) \
-                or self.request.method == "GET":
+        if self.request.method == "GET":
             return CohortResultSerializerFullDatedMeasure
-        return self.serializer_class
+        elif self.request.method == "PATCH":
+            return CohortResultPatchSerializer
+        elif self.request.method == "POST":
+            if CohortResult.sampling_ratio.field.name in self.request.data:  # a "sampling_ratio" key in the payload selects the sampled-cohort serializer
+                return SampledCohortResultCreateSerializer
+            return CohortResultCreateSerializer
+        else:
+            return self.serializer_class
 
     @cache_response()
     def list(self, request, *args, **kwargs):
         return super().list(request, *args, **kwargs)
 
-    @extend_schema(request=CohortResultCreateSerializer,
-                   responses={status.HTTP_201_CREATED: CohortResultSerializer})
+    @extend_schema(
+        request=PolymorphicProxySerializer(
+            component_name="CreateCohortResult",
+            resource_type_field_name=None,
+            serializers=[CohortResultCreateSerializer, SampledCohortResultCreateSerializer]),
+        responses={status.HTTP_201_CREATED: CohortResultSerializer})
     @transaction.atomic
     def create(self, request, *args, **kwargs):
         response = super().create(request, *args, **kwargs)
diff --git a/cohort_job_server/cohort_creator.py b/cohort_job_server/cohort_creator.py
index 18f89900..84c8c20f 100644
--- a/cohort_job_server/cohort_creator.py
+++ b/cohort_job_server/cohort_creator.py
@@ -12,15 +12,21 @@
 
 class CohortCreator(BaseCohortOperator):
 
-    def launch_cohort_creation(self, cohort_id: Optional[str], json_query: str, auth_headers: dict, callback_path: Optional[str] = None,
+    def launch_cohort_creation(self,
+                               cohort_id: Optional[str],
+                               json_query: str,
+                               auth_headers: dict,
+                               callback_path: Optional[str] = None,
                                existing_cohort_id: Optional[int] = None,
-                               owner_username: Optional[str] = None) -> None:
+                               owner_username: Optional[str] = None,
+                               sampling_ratio: Optional[float] = None) -> None:  # passed through unchanged to CohortCreate
         self.sjs_requester.launch_request(CohortCreate(instance_id=cohort_id,
                                                        json_query=json_query,
                                                        auth_headers=auth_headers,
                                                        callback_path=callback_path,
                                                        owner_username=owner_username,
-                                                       existing_cohort_id=existing_cohort_id
+                                                       existing_cohort_id=existing_cohort_id,
+                                                       sampling_ratio=sampling_ratio
                                                        ))
 
     @staticmethod
diff --git a/cohort_job_server/sjs_api/cohort_requests/base_cohort_request.py b/cohort_job_server/sjs_api/cohort_requests/base_cohort_request.py
index efe9168a..c60b7848 100644
--- a/cohort_job_server/sjs_api/cohort_requests/base_cohort_request.py
+++ b/cohort_job_server/sjs_api/cohort_requests/base_cohort_request.py
@@ -25,7 +25,8 @@ def __init__(self, mode: Mode,
                  auth_headers: dict,
                  callback_path: str = None,
                  existing_cohort_id: int = None,
-                 owner_username: str = None):
+                 owner_username: str = None,
+                 sampling_ratio: Optional[float] = None):
         self.mode = mode
         self.instance_id = instance_id
         self.json_query = json_query
@@ -34,6 +35,7 @@ def __init__(self, mode: Mode,
         self.callback_path = callback_path
         self.owner_username = owner_username
         self.existing_cohort_id = existing_cohort_id
+        self.sampling_ratio = sampling_ratio
 
     @staticmethod
     def is_cohort_request_pseudo_read(username: str, source_population: List[int]) -> bool:
@@ -55,12 +57,13 @@ def create_sjs_request(self, cohort_query: CohortQuery) -> str:
 
         callback_path = self.callback_path or (
                 self.mode == Mode.COUNT_WITH_DETAILS and f"/cohort/feasibility-studies/{cohort_query.instance_id}/" or None)
-        spark_job_request = SparkJobObject(cohort_definition_name="Created from Django",
+        spark_job_request = SparkJobObject(cohort_definition_name="Created from C360 backend",
                                            cohort_definition_syntax=cohort_query,
                                            mode=self.mode,
                                            owner_entity_id=self.owner_username,
                                            callbackPath=callback_path,
-                                           existingCohortId=self.existing_cohort_id
+                                           existingCohortId=self.existing_cohort_id,
+                                           samplingRatio=self.sampling_ratio
                                            )
         return format_spark_job_request_for_sjs(spark_job_request)
 
diff --git a/cohort_job_server/sjs_api/schemas.py b/cohort_job_server/sjs_api/schemas.py
index f9adcd84..6276446b 100644
--- a/cohort_job_server/sjs_api/schemas.py
+++ b/cohort_job_server/sjs_api/schemas.py
@@ -51,7 +51,6 @@ class TemporalConstraint(BaseModel):
     dates_are_not_null: list = Field(default=None, alias="dateIsNotNullList")
     filtered_criteria_id: list = Field(default=None, alias="filteredCriteriaIdList")
 
-
 class SourcePopulation(BaseModel):
     care_site_cohort_list: list[int] = Field(default_factory=list, alias="caresiteCohortList")
 
@@ -101,6 +100,7 @@ class SparkJobObject:
     owner_entity_id: str
     callbackPath: Optional[str] = Field(None, alias='callbackPath')
     existingCohortId: Optional[int] = Field(None, alias='existingCohortId')
+    samplingRatio: Optional[float] = Field(None, alias='samplingRatio')
 
 
 class FhirParameter(BaseModel):