Skip to content

Commit

Permalink
feat(cohort): cohorts sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
Hicham committed Dec 30, 2024
1 parent bce7730 commit 2b31c04
Show file tree
Hide file tree
Showing 9 changed files with 186 additions and 78 deletions.
24 changes: 24 additions & 0 deletions cohort/migrations/0017_cohortresult_parent_cohort_and_more.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Generated by Django 5.0.10 on 2024-12-30 12:57

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('cohort', '0016_fhirfilter_auto_generated_and_more'),
]

operations = [
migrations.AddField(
model_name='cohortresult',
name='parent_cohort',
field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='sample_cohorts', to='cohort.cohortresult'),
),
migrations.AddField(
model_name='cohortresult',
name='sampling_ratio',
field=models.FloatField(blank=True, null=True),
),
]
2 changes: 2 additions & 0 deletions cohort/models/cohort_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class CohortResult(CohortBaseModel, JobModel):
create_task_id = models.TextField(blank=True)
type = models.CharField(max_length=20, choices=COHORT_TYPES, default=MY_COHORTS_TYPE)
is_subset = models.BooleanField(default=False)
parent_cohort = models.ForeignKey("CohortResult", related_name="sample_cohorts", on_delete=models.SET_NULL, null=True)
sampling_ratio = models.FloatField(blank=True, null=True)

@property
def result_size(self) -> int:
Expand Down
34 changes: 31 additions & 3 deletions cohort/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,21 +83,49 @@ class Meta:
class CohortResultCreateSerializer(serializers.ModelSerializer):
name = serializers.CharField(required=True)
description = serializers.CharField(allow_blank=True, allow_null=True)
global_estimate = serializers.BooleanField(default=False)
request = PrimaryKeyRelatedFieldWithOwner(required=True, queryset=Request.objects.all())
owner = UserPrimaryKeyRelatedField(queryset=User.objects.all(), required=False)
request = PrimaryKeyRelatedFieldWithOwner(required=False, queryset=Request.objects.all())
request_query_snapshot = PrimaryKeyRelatedFieldWithOwner(required=True, queryset=RequestQuerySnapshot.objects.all())
dated_measure = PrimaryKeyRelatedFieldWithOwner(required=True, queryset=DatedMeasure.objects.all())

class Meta:
model = CohortResult
fields = ["name",
"description",
"global_estimate",
"owner",
"request",
"request_query_snapshot",
"dated_measure"]


class SampledCohortResultCreateSerializer(serializers.ModelSerializer):
name = serializers.CharField(required=True)
description = serializers.CharField(allow_blank=True, allow_null=True)
owner = UserPrimaryKeyRelatedField(queryset=User.objects.all(), required=False)
parent_cohort = PrimaryKeyRelatedFieldWithOwner(required=True, queryset=CohortResult.objects.filter(parent_cohort__isnull=True))
sampling_ratio = serializers.FloatField(required=True)

class Meta:
model = CohortResult
fields = ["name",
"description",
"owner",
"parent_cohort",
"sampling_ratio"]

def validate_sampling_ratio(self, value):
if not 0 < value < 1:
raise serializers.ValidationError("Sampling ratio must be between 0 and 1")
return value

def create(self, validated_data):
parent_cohort = validated_data.get("parent_cohort")
validated_data.update({"request_query_snapshot": parent_cohort.request_query_snapshot,
"dated_measure": parent_cohort.dated_measure
})
return super().create(validated_data)


class CohortResultPatchSerializer(serializers.ModelSerializer):
name = serializers.CharField(required=False)
description = serializers.CharField(required=False)
Expand Down
12 changes: 6 additions & 6 deletions cohort/services/cohort_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,12 @@ def handle_cohort_creation(self, cohort: CohortResult, request) -> None:
if request.data.pop("global_estimate", False):
dm_service.handle_global_count(cohort, request)
try:
create_cohort.s(cohort_id=cohort.pk,
json_query=cohort.request_query_snapshot.serialized_query,
auth_headers=get_authorization_header(request),
cohort_creator_cls=self.operator_cls) \
.apply_async()

create_cohort.s(cohort_id=cohort.pk,
json_query=cohort.request_query_snapshot.serialized_query,
auth_headers=get_authorization_header(request),
cohort_creator_cls=self.operator_cls,
sampling_ratio=cohort.sampling_ratio) \
.apply_async()
except Exception as e:
cohort.delete()
raise ServerError("Could not launch cohort creation") from e
Expand Down
140 changes: 88 additions & 52 deletions cohort/tests/tests_view_cohort_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,32 +188,6 @@ def __init__(self, mock_create_task_called: bool, **kwargs):


class CohortsCreateTests(CohortsTests):
@mock.patch('cohort.services.cohort_result.get_authorization_header')
@mock.patch('cohort.services.cohort_result.create_cohort.apply_async')
@mock.patch('cohort.services.dated_measure.count_cohort.apply_async')
def check_create_case_with_mock(self, case: CohortCreateCase, mock_count_task: MagicMock, mock_create_task: MagicMock,
mock_header: MagicMock, other_view: any, view_kwargs: dict):
mock_header.return_value = None
mock_create_task.return_value = None
mock_count_task.return_value = None

with self.captureOnCommitCallbacks(execute=True):
super(CohortsCreateTests, self).check_create_case(case, other_view, **(view_kwargs or {}))

if case.success:
inst = self.model_objects.filter(**case.retrieve_filter.args)\
.exclude(**case.retrieve_filter.exclude).first()
self.assertIsNotNone(inst.dated_measure)

if case.data.get('global_estimate'):
self.assertIsNotNone(inst.dated_measure_global)
mock_create_task.assert_called() if case.mock_create_task_called else mock_create_task.assert_not_called()

mock_create_task.assert_called() if case.mock_create_task_called else mock_create_task.assert_not_called()
mock_header.assert_called() if case.mock_create_task_called else mock_header.assert_not_called()

def check_create_case(self, case: CohortCreateCase, other_view: Any = None, **view_kwargs):
return self.check_create_case_with_mock(case, other_view=other_view or None, view_kwargs=view_kwargs)

def setUp(self):
super(CohortsCreateTests, self).setUp()
Expand All @@ -231,11 +205,17 @@ def setUp(self):
favorite=True,
request_query_snapshot=self.user1_req1_snap1.pk,
dated_measure=self.user1_req1_snap1_dm.pk,
# group_id
# dated_measure_global
# create_task_id
# type
)
parent_cohort = CohortResult.objects.create(name="Parent cohort for sampling",
description="Parent cohort for sampling",
owner=self.user1,
request_query_snapshot=self.user1_req1_snap1,
dated_measure=self.user1_req1_snap1_dm)

self.basic_data_for_sampled_cohort = dict(name="Sampled Cohort",
description="Sampled Cohort",
parent_cohort=parent_cohort.uuid,
sampling_ratio=0.3)
self.basic_case = CohortCreateCase(
data=self.basic_data,
status=status.HTTP_201_CREATED,
Expand All @@ -251,20 +231,42 @@ def setUp(self):
status=status.HTTP_400_BAD_REQUEST,
)

@mock.patch('cohort.services.cohort_result.get_authorization_header')
@mock.patch('cohort.services.cohort_result.create_cohort.apply_async')
@mock.patch('cohort.services.dated_measure.count_cohort.apply_async')
def check_create_case_with_mock(self, case: CohortCreateCase, mock_count_task: MagicMock, mock_create_task: MagicMock,
mock_header: MagicMock, other_view: any, view_kwargs: dict):
mock_header.return_value = None
mock_create_task.return_value = None
mock_count_task.return_value = None

with self.captureOnCommitCallbacks(execute=True):
super(CohortsCreateTests, self).check_create_case(case, other_view, **(view_kwargs or {}))

if case.success:
inst = self.model_objects.filter(**case.retrieve_filter.args)\
.exclude(**case.retrieve_filter.exclude).first()
self.assertIsNotNone(inst.dated_measure)

if case.data.get('global_estimate'):
self.assertIsNotNone(inst.dated_measure_global)
mock_create_task.assert_called() if case.mock_create_task_called else mock_create_task.assert_not_called()

mock_create_task.assert_called() if case.mock_create_task_called else mock_create_task.assert_not_called()
mock_header.assert_called() if case.mock_create_task_called else mock_header.assert_not_called()

def check_create_case(self, case: CohortCreateCase, other_view: Any = None, **view_kwargs):
return self.check_create_case_with_mock(case, other_view=other_view or None, view_kwargs=view_kwargs)

def test_create(self):
# As a user, I can create a DatedMeasure with only RQS,
# it will launch a task
self.check_create_case(self.basic_case)

def test_create_with_global(self):
# As a user, I can create a DatedMeasure with only RQS,
# it will launch a task
self.check_create_case(self.basic_case.clone(
data={**self.basic_data, 'global_estimate': True}
))

def test_create_with_unread_fields(self):
# As a user, I can create a dm
self.check_create_case(self.basic_case.clone(
data={**self.basic_data,
'create_task_id': random_str(5),
Expand All @@ -277,31 +279,66 @@ def test_create_with_unread_fields(self):
))

def test_error_create_missing_field(self):
# As a user, I cannot create a dm if some field is missing
cases = (self.basic_err_case.clone(
data={**self.basic_data, k: None},
case = self.basic_err_case.clone(
data={**self.basic_data, "request_query_snapshot": None},
success=False,
status=status.HTTP_400_BAD_REQUEST,
) for k in ['request_query_snapshot'])
[self.check_create_case(case) for case in cases]
status=status.HTTP_400_BAD_REQUEST)
self.check_create_case(case)

def test_error_create_with_other_owner(self):
# As a user, I cannot create a cohort providing another user as owner
self.check_create_case(self.basic_err_case.clone(
data={**self.basic_data, 'owner': self.user2.pk},
status=status.HTTP_400_BAD_REQUEST,
success=False,
))

def test_error_create_on_rqs_not_owned(self):
# As a user, I cannot create a dm on a Rqs I don't own
def test_error_create_with_rqs_not_owned(self):
self.check_create_case(self.basic_err_case.clone(
data={**self.basic_data,
'request_query_snapshot': self.user2_req1_snap1.pk},
success=False,
status=status.HTTP_400_BAD_REQUEST,
))

def test_successfully_create_sampled_cohort(self):
self.check_create_case(
self.basic_case.clone(data=self.basic_data_for_sampled_cohort,
retrieve_filter=CohortCaseRetrieveFilter(name=self.basic_data_for_sampled_cohort.get("name"))
))

def test_error_create_sampled_cohort_with_ratio_0(self):
data = {**self.basic_data_for_sampled_cohort,
"name": "Sampled Cohort with ratio 0.0",
"sampling_ratio": 0.0
}
self.check_create_case(
self.basic_err_case.clone(data=data,
retrieve_filter=CohortCaseRetrieveFilter(name=data.get("name"))
))

def test_error_create_sampled_cohort_with_ratio_1(self):
data = {**self.basic_data_for_sampled_cohort,
"name": "Sampled Cohort with ratio 1.0",
"sampling_ratio": 1.0
}
self.check_create_case(
self.basic_err_case.clone(data=data,
retrieve_filter=CohortCaseRetrieveFilter(name=data.get("name"))
))

def test_error_create_sampled_cohort_with_ratio_gt_1(self):
data = {**self.basic_data_for_sampled_cohort,
"name": "Sampled Cohort with ratio gt 1",
"sampling_ratio": 2.5
}
self.check_create_case(
self.basic_err_case.clone(data=data,
retrieve_filter=CohortCaseRetrieveFilter(name=data.get("name"))
))

def test_error_create_sampled_cohort_with_ratio_lt_0(self):
data = {**self.basic_data_for_sampled_cohort,
"name": "Sampled Cohort with ratio lt 0",
"sampling_ratio": -1.5
}
self.check_create_case(
self.basic_err_case.clone(data=data,
retrieve_filter=CohortCaseRetrieveFilter(name=data.get("name"))
))


class CohortsDeleteTests(CohortsTests):

Expand Down Expand Up @@ -383,7 +420,6 @@ def test_update_cohort_as_owner(self, mock_patch_handler, mock_post_update, mock
data_to_update = dict(name="new_name",
description="new_desc",
request_job_status=JobStatus.failed,
request_job_fail_msg="test_fail_msg",
# read_only
create_task_id="test_task_id",
request_job_id="test_job_id",
Expand Down
29 changes: 19 additions & 10 deletions cohort/views/cohort_result.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from django.db import transaction
from django.db.models import Q, F
from django_filters import rest_framework as filters, OrderingFilter
from drf_spectacular.utils import extend_schema
from drf_spectacular.utils import extend_schema, PolymorphicProxySerializer
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.permissions import AllowAny
Expand All @@ -13,7 +13,7 @@
from cohort.services.cohort_result import cohort_service
from cohort.models import CohortResult
from cohort.serializers import CohortResultSerializer, CohortResultSerializerFullDatedMeasure, CohortResultCreateSerializer, \
CohortResultPatchSerializer, CohortRightsSerializer
CohortResultPatchSerializer, CohortRightsSerializer, SampledCohortResultCreateSerializer
from cohort.services.cohort_rights import cohort_rights_service
from cohort.views.shared import UserObjectsRestrictedViewSet
from exports.services.export import export_service
Expand Down Expand Up @@ -52,7 +52,8 @@ class Meta:
'favorite',
'group_id',
'request_id',
'status')
'status',
'parent_cohort')


class CohortResultViewSet(NestedViewSetMixin, UserObjectsRestrictedViewSet):
Expand Down Expand Up @@ -81,19 +82,27 @@ def get_queryset(self):
return super().get_queryset().filter(is_subset=False)

def get_serializer_class(self):
if self.request.method in ["POST", "PUT", "PATCH"] \
and "dated_measure" in self.request.data \
and isinstance(self.request.data["dated_measure"], dict) \
or self.request.method == "GET":
if self.request.method == "GET":
return CohortResultSerializerFullDatedMeasure
return self.serializer_class
elif self.request.method == "PATCH":
return CohortResultPatchSerializer
elif self.request.method == "POST":
if CohortResult.sampling_ratio.field.name in self.request.data:
return SampledCohortResultCreateSerializer
return CohortResultCreateSerializer
else:
return self.serializer_class

@cache_response()
def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)

@extend_schema(request=CohortResultCreateSerializer,
responses={status.HTTP_201_CREATED: CohortResultSerializer})
@extend_schema(
request=PolymorphicProxySerializer(
component_name="CreateCohortResult",
resource_type_field_name=None,
serializers=[CohortResultCreateSerializer, SampledCohortResultCreateSerializer]),
responses={status.HTTP_201_CREATED: CohortResultSerializer})
@transaction.atomic
def create(self, request, *args, **kwargs):
response = super().create(request, *args, **kwargs)
Expand Down
12 changes: 9 additions & 3 deletions cohort_job_server/cohort_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,21 @@

class CohortCreator(BaseCohortOperator):

def launch_cohort_creation(self, cohort_id: Optional[str], json_query: str, auth_headers: dict, callback_path: Optional[str] = None,
def launch_cohort_creation(self,
cohort_id: Optional[str],
json_query: str,
auth_headers: dict,
callback_path: Optional[str] = None,
existing_cohort_id: Optional[int] = None,
owner_username: Optional[str] = None) -> None:
owner_username: Optional[str] = None,
sampling_ratio: Optional[float] = None) -> None:
self.sjs_requester.launch_request(CohortCreate(instance_id=cohort_id,
json_query=json_query,
auth_headers=auth_headers,
callback_path=callback_path,
owner_username=owner_username,
existing_cohort_id=existing_cohort_id
existing_cohort_id=existing_cohort_id,
sampling_ratio=sampling_ratio
))

@staticmethod
Expand Down
Loading

0 comments on commit 2b31c04

Please sign in to comment.