From 4aa3046df9805e065fac5cb8c87bec3e5075c59a Mon Sep 17 00:00:00 2001 From: Andrii Ieroshenko Date: Thu, 12 Sep 2024 13:42:13 -0700 Subject: [PATCH] add create_workflow_task to handler and scheduler --- jupyter_scheduler/scheduler.py | 14 +++++++++++++- jupyter_scheduler/workflows.py | 15 +++++++++++++-- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/jupyter_scheduler/scheduler.py b/jupyter_scheduler/scheduler.py index d9c00875..5aece2be 100644 --- a/jupyter_scheduler/scheduler.py +++ b/jupyter_scheduler/scheduler.py @@ -46,7 +46,7 @@ create_output_directory, create_output_filename, ) -from jupyter_scheduler.workflows import CreateWorkflow, DescribeWorkflow +from jupyter_scheduler.workflows import CreateWorkflow, DescribeWorkflow, UpdateWorkflow class BaseScheduler(LoggingConfigurable): @@ -124,6 +124,10 @@ def get_workflow(self, workflow_id: str) -> DescribeWorkflow: """Returns workflow record for a single workflow.""" raise NotImplementedError("must be implemented by subclass") + def create_workflow_task(self, workflow_id: str, model: CreateJob) -> str: + """Adds a task to a workflow.""" + raise NotImplementedError("must be implemented by subclass") + def update_job(self, job_id: str, model: UpdateJob): """Updates job metadata in the persistence store, for example name, status etc. In case of status @@ -565,6 +569,14 @@ def get_workflow(self, workflow_id: str) -> DescribeWorkflow: model = DescribeWorkflow.from_orm(workflow_record) return model + def create_workflow_task(self, workflow_id: str, model: CreateJob) -> str: + job_id = self.scheduler.create_job(model, run=False) + workflow: DescribeWorkflow = self.scheduler.get_workflow(workflow_id) + updated_tasks = (workflow.tasks or [])[:] + updated_tasks.append(job_id) + self.scheduler.update_workflow(workflow_id, UpdateWorkflow(depends_on=updated_tasks)) + return job_id + def update_job(self, job_id: str, model: UpdateJob): with self.db_session() as session: session.query(Job).filter(Job.job_id == job_id).update(model.dict(exclude_none=True)) diff --git a/jupyter_scheduler/workflows.py b/jupyter_scheduler/workflows.py index fd4b6375..3b7707c3 100644 --- a/jupyter_scheduler/workflows.py +++ b/jupyter_scheduler/workflows.py @@ -1,5 +1,5 @@ import json -from typing import List +from typing import List, Optional from jupyter_server.utils import ensure_async from tornado.web import HTTPError, authenticated @@ -71,7 +71,11 @@ async def post(self, workflow_id: str): "Error during workflow job creation. workflow_id in the URL and payload don't match.", ) try: - job_id = await ensure_async(self.scheduler.create_job(CreateJob(**payload), run=False)) + job_id = await ensure_async( + self.scheduler.create_workflow_task( + workflow_id=workflow_id, model=CreateJob(**payload) + ) + ) except ValidationError as e: self.log.exception(e) raise HTTPError(500, str(e)) from e @@ -175,3 +179,10 @@ class UpdateWorkflow(BaseModel): class Config: orm_mode = True + + +class UpdateWorkflow(BaseModel): + status: Optional[Status] = None + name: Optional[str] = None + compute_type: Optional[str] = None + depends_on: Optional[str] = None