Skip to content

Commit

Permalink
Feat(wfr) api add "DEPRECATED" state (#770)
Browse files Browse the repository at this point in the history
* add deprecated state for state api

* add deprecated in stats api

* add api for rerun/deprecated vlidation

* add allowed_dataset_choice in rerun valid api

* bring the emit events back
  • Loading branch information
raylrui authored Dec 12, 2024
1 parent a139e7e commit 8a6af8f
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def get_current_state(self, obj) -> dict:
class WorkflowRunListParamSerializer(OptionalFieldsMixin, WorkflowRunBaseSerializer):
class Meta:
model = WorkflowRun
fields = "__all__"
fields = ["orcabus_id", "workflow", "analysis_run", "workflow_run_name", "portal_run_id", "execution_id", "comment",]

class WorkflowRunSerializer(WorkflowRunBaseSerializer):
from .workflow import WorkflowMinSerializer
Expand Down Expand Up @@ -59,6 +59,7 @@ class WorkflowRunCountByStatusSerializer(serializers.Serializer):
failed = serializers.IntegerField()
resolved = serializers.IntegerField()
ongoing = serializers.IntegerField()
deprecated = serializers.IntegerField()

def update(self, instance, validated_data):
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

class AllowedRerunWorkflow(StrEnum):
RNASUM = "rnasum"


class AllowedRerunWorkflowSerializer(serializers.Serializer):
is_valid = serializers.BooleanField()
allowed_dataset_choice = serializers.ListField(child=serializers.CharField())
valid_workflows = serializers.ListField(child=serializers.CharField())

class BaseRerunInputSerializer(serializers.Serializer):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

router = OptionalSlashDefaultRouter()

router.register("workflowrun/stats", WorkflowRunStatsViewSet, basename="workflowrun_list_all") # put it before workflowrun, as it will match the workflowrun/list_all/ url
router.register(r"workflowrun/stats", WorkflowRunStatsViewSet, basename="workflowrun_list_all") # put it before workflowrun, as it will match the workflowrun/list_all/ url
router.register(r"analysis", AnalysisViewSet, basename="analysis")
router.register(r"analysisrun", AnalysisRunViewSet, basename="analysisrun")
router.register(r"analysiscontext", AnalysisContextViewSet, basename="analysiscontext")
Expand All @@ -29,7 +29,6 @@
router.register(r"workflowrun", WorkflowRunActionViewSet, basename="workflowrun-action")
router.register(r"payload", PayloadViewSet, basename="payload")

# may no longer need this as it's currently included in the detail response for an individual WorkflowRun record
router.register(
"workflowrun/(?P<orcabus_id>[^/.]+)/state",
StateViewSet,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@

from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema, PolymorphicProxySerializer
from rest_framework.decorators import action
from rest_framework import mixins, status
from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet
Expand All @@ -14,52 +16,70 @@ class StateViewSet(mixins.CreateModelMixin, mixins.UpdateModelMixin, mixins.List
orcabus_id_prefix = State.orcabus_id_prefix
http_method_names = ['get', 'post', 'patch']
pagination_class = None

"""
valid_states_map for state creation, update
refer:
"Resolved" -- https://github.com/umccr/orcabus/issues/593
"Deprecated" -- https://github.com/umccr/orcabus/issues/695
"""
valid_states_map = {
'RESOLVED': ['FAILED'],
'DEPRECATED': ['SUCCEEDED']
}

def get_queryset(self):
return State.objects.filter(workflow_run=self.kwargs["orcabus_id"])

@extend_schema(responses=OpenApiTypes.OBJECT, description="Valid states map for new state creation, update")
@action(detail=False, methods=['get'], url_name='valid_states_map', url_path='valid_states_map')
def get_valid_states_map(self, request, **kwargs):
return Response(self.valid_states_map)

def create(self, request, *args, **kwargs):
"""
Create a customed new state for a workflow run.
Currently we support "Resolved", "Deprecated"
"""
wfr_orcabus_id = self.kwargs.get("orcabus_id")
workflow_run = WorkflowRun.objects.get(orcabus_id=wfr_orcabus_id)

# Check if the workflow run has a "Failed" or "Aborted" state
latest_state = workflow_run.get_latest_state()
if latest_state.status not in ["FAILED"]:
return Response({"detail": "Can only create 'Resolved' state for workflow runs with 'Failed' states."},
if not latest_state:
return Response({"detail": "No state found for workflow run '{}'".format(wfr_orcabus_id)},
status=status.HTTP_400_BAD_REQUEST)

# Check if the new state is "Resolved"
if request.data.get('status', '').upper() != "RESOLVED":
return Response({"detail": "Can only create 'Resolved' state."},
latest_status = latest_state.status
request_status = request.data.get('status', '').upper()

# check if the state status is valid
if not self.check_state_status(latest_status, request_status):
return Response({"detail": "Invalid state request. Can't add state '{}' to '{}'".format(request_status, latest_status)},
status=status.HTTP_400_BAD_REQUEST)

# comment is required when status is "Resolved"
# comment is required when request change state
if not request.data.get('comment'):
return Response({"detail": "Comment is required when status is 'Resolved'."},
return Response({"detail": "Comment is required when request status is '{}'".format(request_status)},
status=status.HTTP_400_BAD_REQUEST)


# Prepare data for serializer
data = request.data.copy()
data['timestamp'] = timezone.now()
data['workflow_run'] = wfr_orcabus_id

data['status'] = request_status

serializer = self.get_serializer(data=data)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
serializer.save()
headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)

def perform_create(self, serializer):
serializer.save(workflow_run_id=self.kwargs["orcabus_id"], status="RESOLVED")

def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False)
instance = self.get_object()

# Check if the state being updated is "Resolved"
if instance.status != "RESOLVED":
return Response({"detail": "Can only update 'Resolved' state records."},
if instance.status not in self.valid_states_map:
return Response({"detail": "Invalid state status."},
status=status.HTTP_400_BAD_REQUEST)

# Check if only the comment field is being updated
Expand All @@ -69,7 +89,7 @@ def update(self, request, *args, **kwargs):

serializer = self.get_serializer(instance, data=request.data, partial=partial)
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)
serializer.save()

if getattr(instance, '_prefetched_objects_cache', None):
# If 'prefetch_related' has been applied to a queryset, we need to
Expand All @@ -78,5 +98,14 @@ def update(self, request, *args, **kwargs):

return Response(serializer.data)

def perform_update(self, serializer):
serializer.save(status="RESOLVED")

def check_state_status(self, current_status, request_status):
"""
check if the state status is valid:
valid_states_map[request_state] == current_state.status
"""
if request_status not in self.valid_states_map:
return False
if current_status not in self.valid_states_map[request_status]:
return False
return True
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from workflow_manager.models.utils import create_portal_run_id
from workflow_manager.serializers.library import LibrarySerializer
from workflow_manager.serializers.payload import PayloadSerializer
from workflow_manager.serializers.workflow_run_action import AllowedRerunWorkflow, RERUN_INPUT_SERIALIZERS
from workflow_manager.serializers.workflow_run_action import AllowedRerunWorkflow, RERUN_INPUT_SERIALIZERS, AllowedRerunWorkflowSerializer
from workflow_manager.models import (
WorkflowRun,
State,
Expand All @@ -27,6 +27,25 @@ class WorkflowRunActionViewSet(ViewSet):
queryset = WorkflowRun.objects.prefetch_related('states').all()
orcabus_id_prefix = WorkflowRun.orcabus_id_prefix

@extend_schema(responses=AllowedRerunWorkflowSerializer, description="Allowed rerun workflows")
@action(detail=True, methods=['get'], url_name='validate_rerun_workflows', url_path='validate_rerun_workflows')
def validate_rerun_workflows(self, request, *args, **kwargs):
wfl_run = get_object_or_404(self.queryset, pk=kwargs.get('pk'))
is_valid = wfl_run.workflow.workflow_name in AllowedRerunWorkflow

# Get allowed dataset choice for the workflow
wfl_name = wfl_run.workflow.workflow_name
allowed_dataset_choice = []
if wfl_name == AllowedRerunWorkflow.RNASUM.value:
allowed_dataset_choice = RERUN_INPUT_SERIALIZERS[wfl_name].allowed_dataset_choice

reponse = {
'is_valid': is_valid,
'allowed_dataset_choice': allowed_dataset_choice,
'valid_workflows': AllowedRerunWorkflow,
}
return Response(reponse, status=status.HTTP_200_OK)

@extend_schema(
request=PolymorphicProxySerializer(
component_name='WorkflowRunRerun',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def exclude_params(params):
~Q(states__status="FAILED") &
~Q(states__status="ABORTED") &
~Q(states__status="SUCCEEDED") &
~Q(states__status="RESOLVED")
~Q(states__status="RESOLVED") &
~Q(states__status="DEPRECATED")
)

if status:
Expand Down Expand Up @@ -131,6 +132,11 @@ def count_by_status(self, request):
states__status="RESOLVED"
).count()

deprecated_count = annotate_queryset.filter(
states__timestamp=F('latest_state_time'),
states__status="DEPRECATED"
).count()

ongoing_count = base_queryset.filter(
~Q(states__status="FAILED") &
~Q(states__status="ABORTED") &
Expand All @@ -143,6 +149,7 @@ def count_by_status(self, request):
'aborted': aborted_count,
'failed': failed_count,
'resolved': resolved_count,
'deprecated': deprecated_count,
'ongoing': ongoing_count
}, status=200)

Expand Down

0 comments on commit 8a6af8f

Please sign in to comment.