Skip to content

Commit

Permalink
add create_workflow_task to handler and scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
andrii-i committed Sep 12, 2024
1 parent 5c351ac commit 4aa3046
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
14 changes: 13 additions & 1 deletion jupyter_scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
create_output_directory,
create_output_filename,
)
from jupyter_scheduler.workflows import CreateWorkflow, DescribeWorkflow
from jupyter_scheduler.workflows import CreateWorkflow, DescribeWorkflow, UpdateWorkflow


class BaseScheduler(LoggingConfigurable):
Expand Down Expand Up @@ -124,6 +124,10 @@ def get_workflow(self, workflow_id: str) -> DescribeWorkflow:
"""Returns workflow record for a single workflow."""
raise NotImplementedError("must be implemented by subclass")

def create_workflow_task(self, workflow_id: str, model: CreateJob) -> str:
"""Adds a task to a workflow."""
raise NotImplementedError("must be implemented by subclass")

def update_job(self, job_id: str, model: UpdateJob):
"""Updates job metadata in the persistence store,
for example name, status etc. In case of status
Expand Down Expand Up @@ -565,6 +569,14 @@ def get_workflow(self, workflow_id: str) -> DescribeWorkflow:
model = DescribeWorkflow.from_orm(workflow_record)
return model

def create_workflow_task(self, workflow_id: str, model: CreateJob) -> str:
job_id = self.scheduler.create_job(model, run=False)
workflow: DescribeWorkflow = self.scheduler.get_workflow(workflow_id)
updated_tasks = (workflow.tasks or [])[:]
updated_tasks.append(job_id)
self.scheduler.update_workflow(workflow_id, UpdateWorkflow(depends_on=updated_tasks))
return job_id

def update_job(self, job_id: str, model: UpdateJob):
with self.db_session() as session:
session.query(Job).filter(Job.job_id == job_id).update(model.dict(exclude_none=True))
Expand Down
15 changes: 13 additions & 2 deletions jupyter_scheduler/workflows.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import List
from typing import List, Optional

from jupyter_server.utils import ensure_async
from tornado.web import HTTPError, authenticated
Expand Down Expand Up @@ -71,7 +71,11 @@ async def post(self, workflow_id: str):
"Error during workflow job creation. workflow_id in the URL and payload don't match.",
)
try:
job_id = await ensure_async(self.scheduler.create_job(CreateJob(**payload), run=False))
job_id = await ensure_async(
self.scheduler.create_workflow_task(
workflow_id=workflow_id, model=CreateJob(**payload)
)
)
except ValidationError as e:
self.log.exception(e)
raise HTTPError(500, str(e)) from e
Expand Down Expand Up @@ -175,3 +179,10 @@ class UpdateWorkflow(BaseModel):

class Config:
orm_mode = True


class UpdateWorkflow(BaseModel):
status: Optional[Status] = None
name: Optional[str] = None
compute_type: Optional[str] = None
depends_on: Optional[str] = None

0 comments on commit 4aa3046

Please sign in to comment.