From edf753b6133414f4749ba9367dcabb94fa394030 Mon Sep 17 00:00:00 2001
From: David Lougheed
Date: Wed, 12 Jul 2023 12:59:50 -0400
Subject: [PATCH] refact!: use pydantic models for db objects/requests/some
 responses

---
 bento_wes/backends/_wes_backend.py   | 102 +++++++-----
 bento_wes/backends/cromwell_local.py |   5 +-
 bento_wes/constants.py               |   4 +-
 bento_wes/db.py                      | 115 ++++++++-----
 bento_wes/models.py                  |  22 ++-
 bento_wes/runner.py                  |   4 +-
 bento_wes/runs.py                    | 233 ++++++++++-----------------
 bento_wes/workflows.py               |  25 +--
 tests/test_runs.py                   |  26 +--
 9 files changed, 262 insertions(+), 274 deletions(-)

diff --git a/bento_wes/backends/_wes_backend.py b/bento_wes/backends/_wes_backend.py
index cfcc94e8..e588b26b 100644
--- a/bento_wes/backends/_wes_backend.py
+++ b/bento_wes/backends/_wes_backend.py
@@ -15,13 +15,14 @@
 from typing import Optional, Tuple, Union
 
 from bento_wes import states
+from bento_wes.constants import SERVICE_ARTIFACT, RUN_PARAM_FROM_CONFIG
 from bento_wes.db import get_db, finish_run, update_run_state_and_commit
+from bento_wes.models import Run, RunWithDetails, BentoWorkflowMetadata
 from bento_wes.states import STATE_EXECUTOR_ERROR, STATE_SYSTEM_ERROR
 from bento_wes.utils import iso_now
 from bento_wes.workflows import WorkflowType, WorkflowManager
 
 from .backend_types import Command, ProcessResult
-from ..constants import SERVICE_ARTIFACT
 
 __all__ = ["WESBackend"]
 
@@ -109,7 +110,7 @@ def _get_supported_types(self) -> Tuple[WorkflowType]:
         pass
 
     @abstractmethod
-    def _get_params_file(self, run: dict) -> str:
+    def _get_params_file(self, run: Run) -> str:
         """
         Returns the name of the params file to use for the workflow run.
         :param run: The run description
@@ -126,27 +127,26 @@ def _serialize_params(self, workflow_params: ParamDict) -> str:
         """
         pass
 
-    def workflow_path(self, run: dict) -> str:
+    def workflow_path(self, run: RunWithDetails) -> str:
         """
         Gets the local filesystem path to the workflow file specified by a run's workflow URI.
         """
-        return self._workflow_manager.workflow_path(run["request"]["workflow_url"],
-                                                    WorkflowType(run["request"]["workflow_type"]))
+        return self._workflow_manager.workflow_path(run.request.workflow_url, WorkflowType(run.request.workflow_type))
 
-    def run_dir(self, run: dict) -> str:
+    def run_dir(self, run: Run) -> str:
         """
         Returns a path to the work directory for executing a run.
         """
-        return os.path.join(self.tmp_dir, run["run_id"])
+        return os.path.join(self.tmp_dir, run.run_id)
 
-    def _params_path(self, run: dict) -> str:
+    def _params_path(self, run: Run) -> str:
         """
         Returns a path to the workflow parameters file for a run.
         """
         return os.path.join(self.run_dir(run), self._get_params_file(run))
 
     @abstractmethod
-    def _check_workflow(self, run: dict) -> Optional[Tuple[str, str]]:
+    def _check_workflow(self, run: Run) -> Optional[Tuple[str, str]]:
         """
         Checks that a workflow can be executed by the backend via the workflow's URI.
         :param run: The run, including a request with the workflow URI
@@ -154,7 +154,7 @@ def _check_workflow(self, run: dict) -> Optional[Tuple[str, str]]:
         """
         pass
 
-    def _check_workflow_wdl(self, run: dict) -> Optional[Tuple[str, str]]:
+    def _check_workflow_wdl(self, run: RunWithDetails) -> Optional[Tuple[str, str]]:
         """
         Checks that a particular WDL workflow is valid.
        :param run: The run whose workflow is being checked
@@ -207,14 +207,14 @@ def _check_workflow_wdl(self, run: dict) -> Optional[Tuple[str, str]]:
             STATE_EXECUTOR_ERROR
         )
 
-    def _check_workflow_and_type(self, run: dict) -> Optional[Tuple[str, str]]:
+    def _check_workflow_and_type(self, run: RunWithDetails) -> Optional[Tuple[str, str]]:
         """
         Checks a workflow file's validity.
         :param run: The run specifying the workflow in question
         :return: None if the workflow is valid; a tuple of an error message and an error state otherwise
         """
 
-        workflow_type: WorkflowType = WorkflowType(run["request"]["workflow_type"])
+        workflow_type: WorkflowType = WorkflowType(run.request.workflow_type)
 
         if workflow_type not in self._get_supported_types():
             raise NotImplementedError(f"The specified WES backend cannot execute workflows of type {workflow_type}")
@@ -264,7 +264,7 @@ def _update_run_state_and_commit(self, run_id: Union[uuid.UUID, str], state: str
         """
         update_run_state_and_commit(self.db, self.db.cursor(), run_id, state, event_bus=self.event_bus)
 
-    def _finish_run_and_clean_up(self, run: dict, state: str) -> None:
+    def _finish_run_and_clean_up(self, run: Run, state: str) -> None:
         """
         Performs standard run-finishing operations (updating state, setting end time, etc.) as well as deleting the
         run folder if it exists.
@@ -288,7 +288,12 @@ def _finish_run_and_clean_up(self, run: dict, state: str) -> None:
         if not self.debug:
             shutil.rmtree(self.run_dir(run), ignore_errors=True)
 
-    def _initialize_run_and_get_command(self, run: dict, celery_id, access_token: str) -> tuple[Command, dict] | None:
+    def _initialize_run_and_get_command(
+        self,
+        run: RunWithDetails,
+        celery_id: int,
+        access_token: str,
+    ) -> tuple[Command, dict] | None:
         """
         Performs "initialization" operations on the run, including setting states, downloading and validating the
         workflow file, and generating and logging the workflow-running command.
@@ -298,7 +303,7 @@ def _initialize_run_and_get_command(self, run: dict, celery_id, access_token: st
         :return: The command to execute, if no errors occurred; None otherwise
         """
 
-        self._update_run_state_and_commit(run["run_id"], states.STATE_INITIALIZING)
+        self._update_run_state_and_commit(run.run_id, states.STATE_INITIALIZING)
 
         run_dir = self.run_dir(run)
@@ -310,10 +315,9 @@ def _initialize_run_and_get_command(self, run: dict, celery_id, access_token: st
 
         c = self.db.cursor()
 
-        workflow_id = run["request"]["tags"].get("workflow_id", run["request"]["workflow_url"])
-
+        workflow_id = run.request.tags.workflow_id
         workflow_params: ParamDict = {
-            **run["request"]["workflow_params"],
+            **run.request.workflow_params,
             f"{workflow_id}.{PARAM_SECRET_PREFIX}access_token": access_token,
 
             # In export/analysis mode, as we rely on services located in different containers
@@ -332,6 +336,15 @@ def _initialize_run_and_get_command(self, run: dict, celery_id, access_token: st
             # TODO: more special parameters: service URLs, system__run_dir...
         }
 
+        # Some workflow parameters depend on the WES application configuration
+        # and need to be added from there.
+        # The reserved keyword `FROM_CONFIG` is used to detect those inputs.
+        # All parameters in config are upper case. e.g. drs_url --> DRS_URL
+        for i in run.request.tags.workflow_metadata.inputs:
+            if getattr(i, "value", None) != RUN_PARAM_FROM_CONFIG:
+                continue
+            workflow_params[f"{workflow_id}.{i.id}"] = current_app.config.get(i.id.upper(), "")
+
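Editorial note: the added block above moves FROM_CONFIG resolution out of runs.py (see the matching removal later in this patch) and into run initialization. Two details worth underlining: only `BentoWorkflowInputWithValue` instances carry a `value` field, hence the `getattr` with a default, and config keys are the upper-cased input IDs, matching both the comment and the removed runs.py code. A minimal, self-contained sketch with stand-in models and a hypothetical config:

```python
from typing import Literal
from pydantic import BaseModel

RUN_PARAM_FROM_CONFIG = "FROM_CONFIG"

class Input(BaseModel):          # stand-in for BentoWorkflowInput
    id: str
    type: str

class InputWithValue(Input):     # stand-in for BentoWorkflowInputWithValue
    value: Literal["FROM_CONFIG"]

config = {"DRS_URL": "http://drs.local"}  # hypothetical Flask app config
inputs = [
    InputWithValue(id="drs_url", type="string", value="FROM_CONFIG"),
    Input(id="json_document", type="file"),
]

workflow_id = "phenopackets_json"  # hypothetical
workflow_params = {}
for i in inputs:
    # Plain inputs have no `value` attribute, hence getattr with a default
    if getattr(i, "value", None) != RUN_PARAM_FROM_CONFIG:
        continue
    # Config keys are the upper-cased input IDs: drs_url -> DRS_URL
    workflow_params[f"{workflow_id}.{i.id}"] = config.get(i.id.upper(), "")

assert workflow_params == {"phenopackets_json.drs_url": "http://drs.local"}
```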
         # -- Validate the workflow --------------------------------------------
         error = self._check_workflow_and_type(run)
         if error is not None:
@@ -354,9 +367,7 @@
             pf.write(self._serialize_params(workflow_params))
 
         # -- Create the runner command based on inputs ------------------------
-        cmd = self._get_command(self.workflow_path(run),
-                                self._params_path(run),
-                                self.run_dir(run))
+        cmd = self._get_command(self.workflow_path(run), self._params_path(run), self.run_dir(run))
 
         # -- Update run log with command and Celery ID ------------------------
         c.execute(
@@ -366,34 +377,41 @@
         return cmd, workflow_params
 
-    def _build_workflow_outputs(self, run_dir: str, workflow_id: str, workflow_params: dict, c_workflow_metadata: dict):
+    def _build_workflow_outputs(
+        self,
+        run_dir: str,
+        workflow_id: str,
+        workflow_params: dict,
+        workflow_metadata: BentoWorkflowMetadata,
+    ):
         self.logger.info(f"Building workflow outputs for workflow ID {workflow_id}")
-        output_params = w.make_output_params(workflow_id, workflow_params, c_workflow_metadata["inputs"])
+        output_params = w.make_output_params(workflow_id, workflow_params, [dict(i) for i in workflow_metadata.inputs])
 
         workflow_outputs = {}
-        for output in c_workflow_metadata["outputs"]:
-            fo = w.formatted_output(output, output_params)
+        for output in workflow_metadata.outputs:
+            o_id = output.id
+            fo = w.formatted_output(dict(output), output_params)
 
             # Skip optional outputs resulting from optional inputs
             if fo is None:
                 continue
 
             # Rewrite file outputs to include full path to temporary location
-            if output["type"] == w.WORKFLOW_TYPE_FILE:
-                workflow_outputs[output["id"]] = os.path.abspath(os.path.join(run_dir, "output", fo))
+            if output.type == w.WORKFLOW_TYPE_FILE:
+                workflow_outputs[o_id] = os.path.abspath(os.path.join(run_dir, "output", fo))
 
-            elif output["type"] == w.WORKFLOW_TYPE_FILE_ARRAY:
-                workflow_outputs[output["id"]] = [os.path.abspath(os.path.join(run_dir, wo)) for wo in fo]
+            elif output.type == w.WORKFLOW_TYPE_FILE_ARRAY:
+                workflow_outputs[o_id] = [os.path.abspath(os.path.join(run_dir, wo)) for wo in fo]
                 self.logger.info(
-                    f"Setting workflow output {output['id']} to [{', '.join(workflow_outputs[output['id']])}]")
+                    f"Setting workflow output {o_id} to [{', '.join(workflow_outputs[o_id])}]")
 
             else:
-                workflow_outputs[output["id"]] = fo
-                self.logger.info(f"Setting workflow output {output['id']} to {workflow_outputs[output['id']]}")
+                workflow_outputs[o_id] = fo
+                self.logger.info(f"Setting workflow output {o_id} to {workflow_outputs[o_id]}")
 
         return workflow_outputs
 
-    def _perform_run(self, run: dict, cmd: Command, params_with_extras: ParamDict) -> Optional[ProcessResult]:
+    def _perform_run(self, run: RunWithDetails, cmd: Command, params_with_extras: ParamDict) -> Optional[ProcessResult]:
         """
         Performs a run based on a provided command and returns stdout, stderr, exit code, and whether the process
         timed out while running.
@@ -410,8 +428,8 @@ def _perform_run(self, run: dict, cmd: Command, params_with_extras: ParamDict) -
 
         # -- Start process running the generated command ----------------------
         runner_process = subprocess.Popen(
             cmd, cwd=self.tmp_dir, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8")
-        c.execute("UPDATE runs SET run_log__start_time = ? WHERE id = ?", (iso_now(), run["id"]))
-        self._update_run_state_and_commit(run["run_id"], states.STATE_RUNNING)
+        c.execute("UPDATE runs SET run_log__start_time = ? WHERE id = ?", (iso_now(), run.run_id))
+        self._update_run_state_and_commit(run.run_id, states.STATE_RUNNING)
 
         # -- Wait for and capture output --------------------------------------
@@ -441,16 +459,16 @@
 
         # -- Get various Bento-specific data from tags ------------------------
 
-        tags = run["request"]["tags"]
+        tags = run.request.tags
 
-        workflow_metadata = tags.get("workflow_metadata", {})
-        project_id: str = tags["project_id"]
-        dataset_id: str | None = tags.get("dataset_id")
+        workflow_metadata = tags.workflow_metadata
+        project_id: str = tags.project_id
+        dataset_id: str | None = tags.dataset_id
 
         # -- Update run log with stdout/stderr, exit code ---------------------
         #  - Explicitly don't commit here; sync with state update
         c.execute("UPDATE runs SET run_log__stdout = ?, run_log__stderr = ?, run_log__exit_code = ? WHERE id = ?",
-                  (stdout, stderr, exit_code, run["id"]))
+                  (stdout, stderr, exit_code, run.run_id))
 
         if timed_out:
             # TODO: Report error somehow
@@ -468,7 +486,7 @@
 
         run_dir = self.run_dir(run)
 
         workflow_name = self.get_workflow_name(self.workflow_path(run))
-        workflow_params: dict = run["request"]["workflow_params"]
+        workflow_params: dict = run.request.workflow_params
 
         workflow_outputs = self._build_workflow_outputs(run_dir, workflow_name, workflow_params, workflow_metadata)
 
@@ -494,7 +512,7 @@
 
         return ProcessResult((stdout, stderr, exit_code, timed_out))
 
-    def perform_run(self, run: dict, celery_id, access_token: str) -> Optional[ProcessResult]:
+    def perform_run(self, run: RunWithDetails, celery_id, access_token: str) -> Optional[ProcessResult]:
         """
         Executes a run from start to finish (initialization, startup, and completion / cleanup.)
         :param run: The run to execute
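Editorial note: the hunks above flip the run to RUNNING immediately after `Popen`, while the actual wait/timeout logic lives in an unchanged section of `_perform_run` that this patch does not show. For reference, a minimal sketch of the standard start/wait/timeout pattern that produces the `stdout`, `stderr`, `exit_code`, and `timed_out` values used above (the command and timeout here are hypothetical, not the backend's exact code):

```python
import subprocess

cmd = ["echo", "hello"]  # hypothetical runner command
proc = subprocess.Popen(
    cmd, cwd="/tmp", stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8")
try:
    # Block until the process exits, capturing both streams
    stdout, stderr = proc.communicate(timeout=60)
    timed_out = False
except subprocess.TimeoutExpired:
    # On timeout, kill the process and collect whatever output it produced
    proc.kill()
    stdout, stderr = proc.communicate()
    timed_out = True
exit_code = proc.returncode
```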
diff --git a/bento_wes/backends/cromwell_local.py b/bento_wes/backends/cromwell_local.py
index cfdcb2d8..63ce2941 100644
--- a/bento_wes/backends/cromwell_local.py
+++ b/bento_wes/backends/cromwell_local.py
@@ -5,6 +5,7 @@
 
 from bento_wes.backends import WESBackend
 from bento_wes.backends.backend_types import Command
+from bento_wes.models import Run, RunWithDetails
 from bento_wes.workflows import WorkflowType, WES_WORKFLOW_TYPE_WDL
 
 
@@ -24,7 +25,7 @@ def _get_supported_types(self) -> Tuple[WorkflowType]:
         """
         return WES_WORKFLOW_TYPE_WDL,
 
-    def _get_params_file(self, run: dict) -> str:
+    def _get_params_file(self, run: Run) -> str:
         """
         Returns the name of the params file to use for the workflow run.
         :param run: The run description; unused here
@@ -40,7 +41,7 @@ def _serialize_params(self, workflow_params: dict) -> str:
         """
         return json.dumps(workflow_params)
 
-    def _check_workflow(self, run: dict) -> Optional[Tuple[str, str]]:
+    def _check_workflow(self, run: RunWithDetails) -> Optional[Tuple[str, str]]:
         return self._check_workflow_wdl(run)
 
     def get_workflow_name(self, workflow_path: str) -> Optional[str]:
diff --git a/bento_wes/constants.py b/bento_wes/constants.py
index 5aec1232..981c8a20 100644
--- a/bento_wes/constants.py
+++ b/bento_wes/constants.py
@@ -1,6 +1,8 @@
 import bento_wes
 import os
 
+from typing import Literal
+
 __all__ = [
     "BENTO_SERVICE_KIND",
@@ -21,4 +23,4 @@
 SERVICE_ID = os.environ.get("SERVICE_ID", ":".join(SERVICE_TYPE.values()))
 SERVICE_NAME = "Bento WES"
 
-RUN_PARAM_FROM_CONFIG = "FROM_CONFIG"
+RUN_PARAM_FROM_CONFIG: Literal["FROM_CONFIG"] = "FROM_CONFIG"
diff --git a/bento_wes/db.py b/bento_wes/db.py
index 5246ee1e..264cd281 100644
--- a/bento_wes/db.py
+++ b/bento_wes/db.py
@@ -12,6 +12,7 @@
 from . import states
 from .constants import SERVICE_ARTIFACT
 from .events import get_flask_event_bus
+from .models import RunLog, RunRequest, Run, RunWithDetailsAndOutput
 from .types import RunStream
 from .utils import iso_now
 
@@ -23,11 +24,13 @@
     "finish_run",
     "update_stuck_runs",
     "update_db",
-    "run_request_dict",
-    "run_log_dict",
+    "run_request_from_row",
+    "run_log_from_row",
     "task_log_dict",
     "get_task_logs",
-    "get_run_details",
+    "run_with_details_and_output_from_row",
+    "get_run",
+    "get_run_with_details",
     "update_run_state_and_commit",
 ]
 
@@ -64,7 +67,7 @@ def finish_run(
     db: sqlite3.Connection,
     c: sqlite3.Cursor,
     event_bus: EventBus,
-    run: dict,
+    run: Run,
     state: str,
     logger: logging.Logger | None = None,
 ) -> None:
@@ -80,7 +83,7 @@ def finish_run(
     :return:
     """
 
-    run_id = run["run_id"]
+    run_id = run.run_id
     end_time = iso_now()
 
     # Explicitly don't commit here to sync with state update
@@ -127,13 +130,14 @@ def update_stuck_runs(db: sqlite3.Connection):
     c.execute("SELECT id FROM runs WHERE state = ? OR state = ?", (states.STATE_INITIALIZING, states.STATE_RUNNING))
     stuck_run_ids: list[sqlite3.Row] = c.fetchall()
 
-    for run, err in (get_run_details(c, r["id"]) for r in stuck_run_ids):
-        if err:
-            logger.error(f"Encountered error while updating stuck runs: {err}")
+    for r in stuck_run_ids:
+        run = get_run_with_details(c, r["id"], stream_content=True)
+        if run is None:
+            logger.error(f"Missing run: {r['id']}")
             continue
 
         logger.info(
-            f"Found stuck run: {run['run_id']} at state {run['state']}. Setting state to {states.STATE_SYSTEM_ERROR}")
+            f"Found stuck run: {run.run_id} at state {run.state}. Setting state to {states.STATE_SYSTEM_ERROR}")
Setting state to {states.STATE_SYSTEM_ERROR}") finish_run(db, c, event_bus, run, states.STATE_SYSTEM_ERROR) db.commit() @@ -153,15 +157,15 @@ def update_db(): # TODO: Migrations if needed -def run_request_dict(run: sqlite3.Row) -> dict: - return { - "workflow_params": json.loads(run["request__workflow_params"]), - "workflow_type": run["request__workflow_type"], - "workflow_type_version": run["request__workflow_type_version"], - "workflow_engine_parameters": json.loads(run["request__workflow_engine_parameters"]), # TODO - "workflow_url": run["request__workflow_url"], - "tags": json.loads(run["request__tags"]) - } +def run_request_from_row(run: sqlite3.Row) -> RunRequest: + return RunRequest( + workflow_params=run["request__workflow_params"], + workflow_type=run["request__workflow_type"], + workflow_type_version=run["request__workflow_type_version"], + workflow_engine_parameters=run["request__workflow_engine_parameters"], + workflow_url=run["request__workflow_url"], + tags=run["request__tags"], + ) def _strip_first_slash(string: str) -> str: @@ -172,17 +176,17 @@ def _stream_url(run_id: uuid.UUID | str, stream: RunStream) -> str: return urljoin(current_app.config["SERVICE_BASE_URL"], f"runs/{str(run_id)}/{stream}") -def run_log_dict(run: sqlite3.Row) -> dict: +def run_log_from_row(run: sqlite3.Row, stream_content: bool) -> RunLog: run_id = run["id"] - return { - "name": run["run_log__name"], - "cmd": run["run_log__cmd"], - "start_time": run["run_log__start_time"], - "end_time": run["run_log__end_time"], - "stdout": _stream_url(run_id, "stdout"), - "stderr": _stream_url(run_id, "stderr"), - "exit_code": run["run_log__exit_code"] - } + return RunLog( + name=run["run_log__name"], + cmd=run["run_log__cmd"], + start_time=run["run_log__start_time"] or None, + end_time=run["run_log__end_time"] or None, + stdout=run["run_log__stdout"] if stream_content else _stream_url(run_id, "stdout"), + stderr=run["run_log__stderr"] if stream_content else _stream_url(run_id, "stderr"), + exit_code=run["run_log__exit_code"], + ) def task_log_dict(task_log: sqlite3.Row) -> dict: @@ -202,24 +206,43 @@ def get_task_logs(c: sqlite3.Cursor, run_id: uuid.UUID | str) -> list: return [task_log_dict(task_log) for task_log in c.fetchall()] -def get_run_details(c: sqlite3.Cursor, run_id: uuid.UUID | str) -> tuple[None, str] | tuple[dict, None]: - # Runs, run requests, and run logs are created at the same time, so if any of them is missing return None. 
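Editorial note: the `stream_content` flag threaded through these db.py helpers picks between two representations of `RunLog.stdout`/`stderr`: the captured text itself (needed when serving the stream, or when marking stuck runs above), or a URL pointing at the stream endpoints (the WES-style API representation built by `_stream_url`). A small sketch of the two shapes, with hypothetical row values and base URL:

```python
from urllib.parse import urljoin

# Hypothetical sqlite3.Row contents and SERVICE_BASE_URL
row = {"id": "example-run-id", "run_log__stdout": "hello from the workflow\n", "run_log__stderr": ""}
service_base_url = "https://bento.example.org/api/wes/"

def stream_value(stream: str, stream_content: bool) -> str:
    if stream_content:
        # Inline content, used when serving the stream body itself
        return row[f"run_log__{stream}"]
    # URL form, used in API responses per the WES spec
    return urljoin(service_base_url, f"runs/{row['id']}/{stream}")

assert stream_value("stdout", stream_content=True).startswith("hello")
assert stream_value("stdout", stream_content=False).endswith("/runs/example-run-id/stdout")
```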
@@ -202,24 +206,43 @@ def get_task_logs(c: sqlite3.Cursor, run_id: uuid.UUID | str) -> list:
     return [task_log_dict(task_log) for task_log in c.fetchall()]
 
 
-def get_run_details(c: sqlite3.Cursor, run_id: uuid.UUID | str) -> tuple[None, str] | tuple[dict, None]:
-    # Runs, run requests, and run logs are created at the same time, so if any of them is missing return None.
+def run_from_row(run: sqlite3.Row) -> Run:
+    return Run(run_id=run["id"], state=run["state"])
 
-    c.execute("SELECT * FROM runs WHERE id = ?", (str(run_id),))
-    run = c.fetchone()
 
-    if run is None:
-        return None, "Missing entry in table 'runs'"
+def run_with_details_and_output_from_row(
+    c: sqlite3.Cursor,
+    run: sqlite3.Row,
+    stream_content: bool,
+) -> RunWithDetailsAndOutput:
+    return RunWithDetailsAndOutput(
+        run_id=run["id"],
+        state=run["state"],
+        request=run_request_from_row(run),
+        run_log=run_log_from_row(run, stream_content),
+        task_logs=get_task_logs(c, run["id"]),
+        outputs=json.loads(run["outputs"]),
+    )
+
+
+def _get_run_row(c: sqlite3.Cursor, run_id: uuid.UUID | str) -> sqlite3.Row | None:
+    return c.execute("SELECT * FROM runs WHERE id = ?", (str(run_id),)).fetchone()
 
-    c.execute("SELECT * FROM task_logs WHERE run_id = ?", (str(run_id),))
 
-    return {
-        "run_id": run["id"],
-        "request": run_request_dict(run),
-        "state": run["state"],
-        "run_log": run_log_dict(run),
-        "task_logs": get_task_logs(c, run["id"]),
-        "outputs": json.loads(run["outputs"])
-    }, None
+
+def get_run(c: sqlite3.Cursor, run_id: uuid.UUID | str) -> Run | None:
+    if run := _get_run_row(c, run_id):
+        return run_from_row(run)
+    return None
+
+
+def get_run_with_details(
+    c: sqlite3.Cursor,
+    run_id: uuid.UUID | str,
+    stream_content: bool,
+) -> RunWithDetailsAndOutput | None:
+    if run := _get_run_row(c, run_id):
+        return run_with_details_and_output_from_row(c, run, stream_content)
+    return None
 
 
 def update_run_state_and_commit(
@@ -236,4 +259,8 @@ def update_run_state_and_commit(
     c.execute("UPDATE runs SET state = ? WHERE id = ?", (state, str(run_id)))
     db.commit()
     if event_bus and publish_event:
-        event_bus.publish_service_event(SERVICE_ARTIFACT, EVENT_WES_RUN_UPDATED, get_run_details(c, run_id)[0])
+        event_bus.publish_service_event(
+            SERVICE_ARTIFACT,
+            EVENT_WES_RUN_UPDATED,
+            get_run_with_details(c, run_id, stream_content=False).model_dump(),
+        )
diff --git a/bento_wes/models.py b/bento_wes/models.py
index 6fff94a7..2f22b368 100644
--- a/bento_wes/models.py
+++ b/bento_wes/models.py
@@ -1,5 +1,5 @@
 from datetime import datetime
-from pydantic import BaseModel, ConfigDict, AnyUrl
+from pydantic import BaseModel, ConfigDict, AnyUrl, Json
 from typing import Literal
 
 __all__ = [
@@ -20,7 +20,15 @@ class BentoWorkflowInput(BaseModel):
     type: Literal["string", "string[]", "number", "number[]", "enum", "enum[]", "file", "file[]"]
     required: bool = False,
     extensions: list[str] | None = None
-    value: str | None = None
+
+
+class BentoWorkflowInputWithFileExtensions(BentoWorkflowInput):
+    type: Literal["file", "file[]"]
+    extensions: list[str] | None = None
+
+
+class BentoWorkflowInputWithValue(BentoWorkflowInput):
+    value: Literal["FROM_CONFIG"]
 
 
 class BentoWorkflowOutput(BaseModel):
@@ -33,10 +41,10 @@
 
 class BentoWorkflowMetadata(BaseModel):
     name: str
     description: str
-    action: Literal["BentoWorkflowMetadata", "analysis", "export"]
+    action: Literal["ingestion", "analysis", "export"]
     data_type: str | None = None
     file: str
-    inputs: list[BentoWorkflowInput]
+    inputs: list[BentoWorkflowInputWithValue | BentoWorkflowInputWithFileExtensions | BentoWorkflowInput]
     outputs: list[BentoWorkflowOutput]
 
 
@@ -51,12 +59,12 @@ class BentoRunRequestTags(BaseModel):
 
 
 class RunRequest(BaseModel):
-    workflow_params: dict[str, str | int | float | bool]
+    workflow_params: Json[dict[str, str | int | float | bool]]
     workflow_type: Literal["WDL"]
     workflow_type_version: Literal["1.0"]
-    workflow_engine_parameters: dict[str, str]
+    workflow_engine_parameters: Json[dict[str, str]]
     workflow_url: AnyUrl
-    tags: BentoRunRequestTags
+    tags: Json[BentoRunRequestTags]
 
 
 class RunLog(BaseModel):
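Editorial note: the `Json[...]` annotations above are what replace the `json.loads` calls removed from `run_request_dict`: both HTML form fields and the sqlite TEXT columns deliver these values as JSON strings, and pydantic now parses them during validation. A minimal sketch of that behavior (field values hypothetical):

```python
from pydantic import BaseModel, Json

class Req(BaseModel):  # stand-in for RunRequest's Json fields
    workflow_params: Json[dict[str, str | int | float | bool]]
    tags: Json[dict]

# Form fields and sqlite TEXT columns both arrive as JSON-encoded strings;
# pydantic decodes them into the inner type at validation time.
req = Req(
    workflow_params='{"phenopackets_json.json_document": "doc.json"}',
    tags='{"workflow_id": "phenopackets_json", "project_id": "p1"}',
)
assert req.workflow_params["phenopackets_json.json_document"] == "doc.json"
assert req.tags["project_id"] == "p1"
```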
diff --git a/bento_wes/runner.py b/bento_wes/runner.py
index 6ce0b5bc..1c8e57fc 100644
--- a/bento_wes/runner.py
+++ b/bento_wes/runner.py
@@ -8,7 +8,7 @@
 from .backends import WESBackend
 from .backends.cromwell_local import CromwellLocalBackend
 from .celery import celery
-from .db import get_db, get_run_details, finish_run
+from .db import get_db, get_run_with_details, finish_run
 from .events import get_new_event_bus
 from .workflows import parse_workflow_host_allow_list
 
@@ -25,7 +25,7 @@ def run_workflow(self, run_id: uuid.UUID):
     # Checks ------------------------------------------------------------------
 
     # Check that the run and its associated objects exist
-    run, err = get_run_details(c, run_id)
+    run = get_run_with_details(c, run_id, stream_content=False)
     if run is None:
-        logger.error(f"Cannot find run {run_id} ({err})")
+        logger.error(f"Cannot find run {run_id}")
         return
diff --git a/bento_wes/runs.py b/bento_wes/runs.py
index 3d842b12..d3aeb81a 100644
--- a/bento_wes/runs.py
+++ b/bento_wes/runs.py
@@ -1,7 +1,7 @@
 import json
 import os
 import sqlite3
-
+import pydantic
 import requests
 import shutil
 import traceback
@@ -20,9 +20,16 @@
 from . import states
 from .authz import authz_middleware, PERMISSION_INGEST_DATA, PERMISSION_VIEW_RUNS
 from .celery import celery
-from .constants import RUN_PARAM_FROM_CONFIG
+from .db import (
+    get_db,
+    run_with_details_and_output_from_row,
+    get_run,
+    get_run_with_details,
+    update_run_state_and_commit,
+)
 from .events import get_flask_event_bus
 from .logger import logger
+from .models import RunRequest, Run, RunWithDetails
 from .runner import run_workflow
 from .types import RunStream
 from .workflows import (
@@ -33,99 +40,48 @@
     parse_workflow_host_allow_list,
 )
 
-from .db import get_db, run_request_dict, run_log_dict, get_task_logs, get_run_details, update_run_state_and_commit
-
 bp_runs = Blueprint("runs", __name__)
 
 
-def _get_project_and_dataset_id_from_tags(tags: dict) -> tuple[str, str | None]:
-    project_id = tags["project_id"]
-    dataset_id = tags.get("dataset_id", None)
-    return project_id, dataset_id
-
-
-def _get_project_and_dataset_id_from_run_request(run_request: dict) -> tuple[str, str | None]:
-    return _get_project_and_dataset_id_from_tags(run_request["tags"])
-
-
-def _check_runs_permission(runs_project_datasets: list[tuple[str, str | None]], permission: str) -> tuple[bool, ...]:
+def _check_runs_permission(run_requests: list[RunRequest], permission: str) -> tuple[bool, ...]:
     if not current_app.config["AUTHZ_ENABLED"]:
-        return tuple([True] * len(runs_project_datasets))  # Assume we have permission for everything if authz disabled
+        return tuple([True] * len(run_requests))  # Assume we have permission for everything if authz disabled
 
     return authz_middleware.authz_post(request, "/policy/evaluate", body={
         "requested_resource": [
             {
-                "project": project_id,
-                **({"dataset": dataset_id} if dataset_id else {}),
+                "project": run_request.tags.project_id,
+                **({"dataset": run_request.tags.dataset_id} if run_request.tags.dataset_id else {}),
             }
-            for project_id, dataset_id in runs_project_datasets
+            for run_request in run_requests
         ],
         "required_permissions": [permission],
     }).json()["result"]
 
 
-def _check_single_run_permission_and_mark(project_and_dataset: tuple[str, str | None], permission: str) -> bool:
-    p_res = _check_runs_permission([project_and_dataset], permission)
+def _check_single_run_permission_and_mark(run_req: RunRequest, permission: str) -> bool:
+    p_res = _check_runs_permission([run_req], permission)
     # By calling this, the developer indicates that they will have handled permissions adequately:
     authz_middleware.mark_authz_done(request)
     return p_res and p_res[0]
 
 
 def _create_run(db: sqlite3.Connection, c: sqlite3.Cursor) -> Response:
-    assert "workflow_params" in request.form
-    assert "workflow_type" in request.form
-    assert "workflow_type_version" in request.form
-    assert "workflow_engine_parameters" in request.form
-    assert "workflow_url" in request.form
-    assert "tags" in request.form
-
-    workflow_params = json.loads(request.form["workflow_params"])
-    workflow_type = request.form["workflow_type"].upper().strip()
-    workflow_type_version = request.form["workflow_type_version"].strip()
-    workflow_engine_parameters = json.loads(request.form["workflow_engine_parameters"])  # TODO: Unused
-    workflow_url = request.form["workflow_url"].lower()  # TODO: This can refer to an attachment
-    workflow_attachment_list = request.files.getlist("workflow_attachment")  # TODO: Use this fully
-    tags = json.loads(request.form["tags"])
-
-    # TODO: Move Bento-specific stuff out somehow?
-
-    # Bento-specific required tags
-    assert "workflow_id" in tags
-    assert "workflow_metadata" in tags
-    workflow_metadata = tags["workflow_metadata"]
-    assert "action" in workflow_metadata
-
-    workflow_id = tags.get("workflow_id", workflow_url)
+    run_req = RunRequest(**request.form)
+
+    # TODO: Use this fully
+    #  - files inside the workflow
+    #  - workflow_url can refer to an attachment
+    workflow_attachment_list = request.files.getlist("workflow_attachment")
 
     # Check ingest permissions before continuing
-    if not _check_single_run_permission_and_mark(
-            _get_project_and_dataset_id_from_tags(tags), PERMISSION_INGEST_DATA):
+    if not _check_single_run_permission_and_mark(run_req, PERMISSION_INGEST_DATA):
         return flask_forbidden_error("Forbidden")
 
     # We have permission - so continue ---------
 
-    # Don't accept anything (ex. CWL) other than WDL
-    assert workflow_type == "WDL"
-    assert workflow_type_version == "1.0"
-
-    assert isinstance(workflow_params, dict)
-    assert isinstance(workflow_engine_parameters, dict)
-    assert isinstance(tags, dict)
-
-    # Some workflow parameters depend on the WES application configuration
-    # and need to be added from there.
-    # The reserved keyword `FROM_CONFIG` is used to detect those inputs.
-    # All parameters in config are upper case. e.g. drs_url --> DRS_URL
-    for i in workflow_metadata["inputs"]:
-        if i.get("value") != RUN_PARAM_FROM_CONFIG:
-            continue
-        param_name = i["id"]
-        workflow_params[f"{workflow_id}.{param_name}"] = current_app.config.get(param_name.upper(), "")
-
-    # TODO: Use JSON schemas for workflow params / engine parameters / tags
-
     # Get list of allowed workflow hosts from configuration for any checks inside the runner
     # If it's blank, assume that means "any host is allowed" and pass None to the runner
     workflow_host_allow_list = parse_workflow_host_allow_list(current_app.config["WORKFLOW_HOST_ALLOW_LIST"])
@@ -151,11 +107,12 @@ def _create_run(db: sqlite3.Connection, c: sqlite3.Cursor) -> Response:
     auth_header_dict = {"Authorization": auth_header} if auth_header else {}
 
     try:
-        wm.download_or_copy_workflow(workflow_url, WorkflowType(workflow_type), auth_headers=auth_header_dict)
+        wm.download_or_copy_workflow(
+            run_req.workflow_url, WorkflowType(run_req.workflow_type), auth_headers=auth_header_dict)
     except UnsupportedWorkflowType:
-        return flask_bad_request_error(f"Unsupported workflow type: {workflow_type}")
+        return flask_bad_request_error(f"Unsupported workflow type: {run_req.workflow_type}")
     except (WorkflowDownloadError, requests.exceptions.ConnectionError) as e:
-        return flask_bad_request_error(f"Could not access workflow file: {workflow_url} (Python error: {e})")
+        return flask_bad_request_error(f"Could not access workflow file: {run_req.workflow_url} (Python error: {e})")
 
     # ---
 
@@ -201,14 +158,14 @@ def _create_run(db: sqlite3.Connection, c: sqlite3.Cursor) -> Response:
         states.STATE_UNKNOWN,
         json.dumps({}),
 
-        json.dumps(workflow_params),
-        workflow_type,
-        workflow_type_version,
-        json.dumps(workflow_engine_parameters),
-        workflow_url,
-        json.dumps(tags),
+        json.dumps(run_req.workflow_params),
+        run_req.workflow_type,
+        run_req.workflow_type_version,
+        json.dumps(run_req.workflow_engine_parameters),
+        str(run_req.workflow_url),
+        run_req.tags.model_dump_json(),
 
-        workflow_id,
+        run_req.tags.workflow_id,
     ))
     db.commit()
@@ -230,44 +187,28 @@ def run_list():
     if request.method == "POST":
         try:
             return _create_run(db, c)
+        except pydantic.ValidationError:  # TODO: Better error messages
+            authz_middleware.mark_authz_done(request)
+            logger.error(f"Encountered validation error: {traceback.format_exc()}")
+            return flask_bad_request_error("Validation error: bad run request format")
         except ValueError:
             authz_middleware.mark_authz_done(request)
             return flask_bad_request_error("Value error")
-        except AssertionError:  # TODO: Better error messages
-            authz_middleware.mark_authz_done(request)
-            logger.error(f"Encountered assertion error: {traceback.format_exc()}")
-            return flask_bad_request_error("Assertion error: bad run request format")
 
     # GET
     # Bento Extension: Include run details with /runs request
     with_details = request.args.get("with_details", "false").lower() == "true"
 
     res_list = []
-    perms_list: list[tuple[str, str | None]] = []
-
-    c.execute("SELECT * FROM runs")
-
-    for r in c.fetchall():
-        run = {
-            "run_id": r["id"],
-            "state": r["state"],
-        }
-
-        run_req = run_request_dict(r)
-
-        project_id, dataset_id = _get_project_and_dataset_id_from_run_request(run_req)
-        perms_list.append((project_id, dataset_id))
-
-        if with_details:
-            run["details"] = {
-                "run_id": r["id"],
-                "state": r["state"],
-                "request": run_req,
-                "run_log": run_log_dict(r),
-                "task_logs": get_task_logs(c, r["id"])
-            }
-
-        res_list.append(run)
+    perms_list: list[RunRequest] = []
+
+    for r in c.execute("SELECT * FROM runs").fetchall():
+        run = run_with_details_and_output_from_row(c, r, stream_content=False)
+        perms_list.append(run.request)
+        res_list.append({
+            **run.model_dump(mode="json", include={"run_id", "state"}),
+            **({"details": run.model_dump(mode="json", exclude={"outputs"})} if with_details else {}),
+        })
 
     p_res = _check_runs_permission(perms_list, PERMISSION_VIEW_RUNS)
     res_list = [v for v, p in zip(res_list, p_res) if p]
@@ -280,37 +221,32 @@ def run_list():
 
 @bp_runs.route("/runs/<uuid:run_id>", methods=["GET"])
 def run_detail(run_id: uuid.UUID):
     authz_enabled = current_app.config["AUTHZ_ENABLED"]
-    run_details, err = get_run_details(get_db().cursor(), run_id)
+    run_details = get_run_with_details(get_db().cursor(), run_id, stream_content=False)
 
     if run_details is None:
         if authz_enabled:
             return flask_forbidden_error("Forbidden")
         else:
-            return flask_not_found_error(f"Run {run_id} not found ({err})")
+            return flask_not_found_error(f"Run {run_id} not found")
 
-    if not _check_single_run_permission_and_mark(
-            _get_project_and_dataset_id_from_run_request(run_details["request"]), PERMISSION_VIEW_RUNS):
+    if not _check_single_run_permission_and_mark(run_details.request, PERMISSION_VIEW_RUNS):
         return flask_forbidden_error("Forbidden")
 
-    if run_details is None and not authz_enabled:
-        return flask_not_found_error(f"Run {run_id} not found ({err})")
-
-    return jsonify(run_details)
+    return jsonify(run_details.model_dump(mode="json"))
 
 
 def get_stream(c: sqlite3.Cursor, stream: RunStream, run_id: uuid.UUID):
-    c.execute("SELECT * FROM runs WHERE id = ?", (str(run_id),))
-    run = c.fetchone()
+    run = get_run_with_details(c, run_id, stream_content=True)
     return (current_app.response_class(
         headers={
             # If we've finished, we allow long-term (24h) caching of the stdout/stderr responses.
             # Otherwise, no caching allowed!
"Cache-Control": ( - "private, max-age=86400" if run["state"] in states.TERMINATED_STATES + "private, max-age=86400" if run.state in states.TERMINATED_STATES else "no-cache, no-store, must-revalidate, max-age=0" ), }, - response=run[f"run_log__{stream}"], + response=run.run_log.stdout if stream == "stdout" else run.run_log.stderr, mimetype="text/plain", status=200, ) if run is not None else flask_not_found_error(f"Stream {stream} not found for run {run_id}")) @@ -322,18 +258,17 @@ def check_run_authz_then_return_response( cb: Callable[[], Response | dict], permission: str = PERMISSION_VIEW_RUNS, ): - run_details, rd_err = get_run_details(c, run_id) + run = get_run_with_details(c, run_id, stream_content=False) - if rd_err: + if run is None: if current_app.config["AUTHZ_ENABLED"]: # Without the required permissions, don't even leak if this run exists - just return forbidden authz_middleware.mark_authz_done(request) return flask_forbidden_error("Forbidden") else: - return flask_not_found_error(rd_err) + return flask_not_found_error(f"Run {run_id} not found") - if not _check_single_run_permission_and_mark( - _get_project_and_dataset_id_from_run_request(run_details["request"]), permission): + if not _check_single_run_permission_and_mark(run.request, permission): return flask_forbidden_error("Forbidden") return cb() @@ -351,6 +286,13 @@ def run_stderr(run_id: uuid.UUID): return check_run_authz_then_return_response(c, run_id, lambda: get_stream(c, "stderr", run_id)) +RUN_CANCEL_BAD_REQUEST_STATES = ( + ((states.STATE_CANCELING, states.STATE_CANCELED), "Run already canceled"), + (states.FAILURE_STATES, "Run already terminated with error"), + (states.SUCCESS_STATES, "Run already completed"), +) + + @bp_runs.route("/runs//cancel", methods=["POST"]) def run_cancel(run_id: uuid.UUID): # TODO: Check if already completed @@ -359,42 +301,38 @@ def run_cancel(run_id: uuid.UUID): db = get_db() c = db.cursor() - def perform_run_cancel(): - c.execute("SELECT * FROM runs WHERE id = ?", (str(run_id),)) - run = c.fetchone() + run_id_str = str(run_id) - if run is None: - return flask_not_found_error(f"Run {run_id} not found") - - if run["state"] in (states.STATE_CANCELING, states.STATE_CANCELED): - return flask_bad_request_error("Run already canceled") + def perform_run_cancel() -> Response: + run = get_run_with_details(c, run_id_str, stream_content=False) - if run["state"] in states.FAILURE_STATES: - return flask_bad_request_error("Run already terminated with error") + if run is None: + return flask_not_found_error(f"Run {run_id_str} not found") - if run["state"] in states.SUCCESS_STATES: - return flask_bad_request_error("Run already completed") + for bad_req_states, bad_req_err in RUN_CANCEL_BAD_REQUEST_STATES: + if run.state in bad_req_states: + return flask_bad_request_error(bad_req_err) - celery_id = run["run_log__celery_id"] + celery_id = run.run_log.celery_id if celery_id is None: # Never made it into the queue, so "cancel" it - return flask_internal_server_error(f"No Celery ID present for run {run_id}") + return flask_internal_server_error(f"No Celery ID present for run {run_id_str}") event_bus = get_flask_event_bus() # TODO: terminate=True might be iffy - update_run_state_and_commit(db, c, run["id"], states.STATE_CANCELING, event_bus=event_bus) + update_run_state_and_commit(db, c, run_id_str, states.STATE_CANCELING, event_bus=event_bus) celery.control.revoke(celery_id, terminate=True) # Remove from queue if there, terminate if running # TODO: wait for revocation / failure and update status... 
         # TODO: Generalize clean-up code / fetch from back-end
-        run_dir = os.path.join(current_app.config["SERVICE_TEMP"], run["run_id"])
+        run_dir = os.path.join(current_app.config["SERVICE_TEMP"], run_id_str)
 
         if not current_app.config["BENTO_DEBUG"]:
             shutil.rmtree(run_dir, ignore_errors=True)
 
-        update_run_state_and_commit(db, c, run["id"], states.STATE_CANCELED, event_bus=event_bus)
+        update_run_state_and_commit(db, c, run_id_str, states.STATE_CANCELED, event_bus=event_bus)
 
         return current_app.response_class(status=204)  # TODO: Better response
 
@@ -405,16 +343,9 @@ def perform_run_cancel():
 def run_status(run_id: uuid.UUID):
     c = get_db().cursor()
 
-    def run_status_response():
-        c.execute("SELECT * FROM runs WHERE id = ?", (str(run_id),))
-        run = c.fetchone()
-
-        if run is None:
-            return flask_not_found_error(f"Run {run_id} not found")
-
-        return jsonify({
-            "run_id": run["id"],
-            "state": run["state"]
-        })
+    def run_status_response() -> Response:
+        if run := get_run(c, run_id):
+            return jsonify(run.model_dump())
+        return flask_not_found_error(f"Run {run_id} not found")
 
     return check_run_authz_then_return_response(c, run_id, run_status_response)
diff --git a/bento_wes/workflows.py b/bento_wes/workflows.py
index fbb0cb23..6d5887ae 100644
--- a/bento_wes/workflows.py
+++ b/bento_wes/workflows.py
@@ -5,6 +5,7 @@
 import requests
 
 from base64 import urlsafe_b64encode
+from pydantic import AnyUrl
 from typing import NewType
 from urllib.parse import urlparse
 
@@ -92,19 +93,19 @@ def _error(self, message: str):
         if self.logger:
             self.logger.error(message)
 
-    def workflow_path(self, workflow_uri: str, workflow_type: WorkflowType) -> str:
+    def workflow_path(self, workflow_uri: AnyUrl, workflow_type: WorkflowType) -> str:
         """
         Generates a unique filesystem path name for a specified workflow URI.
         """
         if workflow_type not in WES_SUPPORTED_WORKFLOW_TYPES:
             raise UnsupportedWorkflowType(f"Unsupported workflow type: {workflow_type}")
 
-        workflow_name = str(urlsafe_b64encode(bytes(workflow_uri, encoding="utf-8")), encoding="utf-8")
+        workflow_name = str(urlsafe_b64encode(bytes(str(workflow_uri), encoding="utf-8")), encoding="utf-8")
         return os.path.join(self.tmp_dir, f"workflow_{workflow_name}.{WORKFLOW_EXTENSIONS[workflow_type]}")
 
     def download_or_copy_workflow(
         self,
-        workflow_uri: str,
+        workflow_uri: AnyUrl,
         workflow_type: WorkflowType,
         auth_headers: dict,
     ) -> str | None:
@@ -116,23 +117,23 @@ def download_or_copy_workflow(
         :param auth_headers: Authorization headers to pass while requesting the workflow file.
""" - parsed_wf_uri = urlparse(workflow_uri) # TODO: Handle errors, handle references to attachments + # TODO: Handle references to attachments workflow_path = self.workflow_path(workflow_uri, workflow_type) - if parsed_wf_uri.scheme not in ALLOWED_WORKFLOW_REQUEST_SCHEMES: # file:// + if workflow_uri.scheme not in ALLOWED_WORKFLOW_REQUEST_SCHEMES: # file:// # TODO: Other else cases # TODO: Handle exceptions - shutil.copyfile(parsed_wf_uri.path, workflow_path) + shutil.copyfile(workflow_uri.path, workflow_path) return if self.workflow_host_allow_list is not None: # We need to check that the workflow in question is from an # allowed set of workflow hosts - if parsed_wf_uri.scheme != "file" and parsed_wf_uri.netloc not in self.workflow_host_allow_list: + if workflow_uri.scheme != "file" and workflow_uri.netloc not in self.workflow_host_allow_list: # Dis-allowed workflow URL self._error( - f"Dis-allowed workflow host: {parsed_wf_uri.netloc} (allow list: {self.workflow_host_allow_list})") + f"Dis-allowed workflow host: {workflow_uri.netloc} (allow list: {self.workflow_host_allow_list})") return states.STATE_EXECUTOR_ERROR self._info(f"Fetching workflow file from {workflow_uri}") @@ -145,15 +146,15 @@ def download_or_copy_workflow( parsed_bento_url = urlparse(self.bento_url) use_auth_headers = all(( self.bento_url, - parsed_bento_url.scheme == parsed_wf_uri.scheme, - parsed_bento_url.netloc == parsed_wf_uri.netloc, - parsed_wf_uri.path.startswith(parsed_bento_url.path), + parsed_bento_url.scheme == workflow_uri.scheme, + parsed_bento_url.netloc == workflow_uri.netloc, + workflow_uri.path.startswith(parsed_bento_url.path), )) # TODO: Better auth? May only be allowed to access specific workflows try: wr = requests.get( - workflow_uri, + str(workflow_uri), headers={ "Host": urlparse(self.service_base_url or "").netloc or "", **(auth_headers if use_auth_headers else {}), diff --git a/tests/test_runs.py b/tests/test_runs.py index 564b2bcb..c7d2a8ee 100644 --- a/tests/test_runs.py +++ b/tests/test_runs.py @@ -18,6 +18,13 @@ def _add_workflow_response(r): content_type="text/plain") +def _create_valid_run(client): + rv = client.post("/runs", data=EXAMPLE_RUN_BODY) + data = rv.get_json() + assert rv.status_code == 200 # 200 is WES spec, even though 201 would be better (?) + return data + + def test_runs_endpoint(client, mocked_responses): _add_workflow_response(mocked_responses) @@ -26,9 +33,7 @@ def test_runs_endpoint(client, mocked_responses): data = rv.get_json() assert json.dumps(data) == json.dumps([]) - rv = client.post("/runs", data=EXAMPLE_RUN_BODY) - assert rv.status_code == 200 # 200 is WES spec, even though 201 would be better (?) 
diff --git a/tests/test_runs.py b/tests/test_runs.py
index 564b2bcb..c7d2a8ee 100644
--- a/tests/test_runs.py
+++ b/tests/test_runs.py
@@ -18,6 +18,13 @@ def _add_workflow_response(r):
         content_type="text/plain")
 
 
+def _create_valid_run(client):
+    rv = client.post("/runs", data=EXAMPLE_RUN_BODY)
+    data = rv.get_json()
+    assert rv.status_code == 200  # 200 is WES spec, even though 201 would be better (?)
+    return data
+
+
 def test_runs_endpoint(client, mocked_responses):
     _add_workflow_response(mocked_responses)
 
@@ -26,9 +33,7 @@ def test_runs_endpoint(client, mocked_responses):
     data = rv.get_json()
     assert json.dumps(data) == json.dumps([])
 
-    rv = client.post("/runs", data=EXAMPLE_RUN_BODY)
-    assert rv.status_code == 200  # 200 is WES spec, even though 201 would be better (?)
-    cr_data = rv.get_json()
+    cr_data = _create_valid_run(client)
     assert "run_id" in cr_data
 
     rv = client.get("/runs")
@@ -71,15 +76,13 @@ def test_run_create_errors(client):
         assert rv.status_code == 400
         error = rv.get_json()
         assert len(error["errors"]) == 1
-        assert error["errors"][0]["message"].startswith("Assertion error")
+        assert error["errors"][0]["message"].startswith("Validation error")
 
 
 def test_run_detail_endpoint(client, mocked_responses):
     _add_workflow_response(mocked_responses)
 
-    rv = client.post("/runs", data=EXAMPLE_RUN_BODY)
-    assert rv.status_code == 200
-    cr_data = rv.get_json()
+    cr_data = _create_valid_run(client)
 
     rv = client.get(f"/runs/{uuid.uuid4()}")
     assert rv.status_code == 404
@@ -110,8 +113,7 @@ def test_run_detail_endpoint(client, mocked_responses):
 def test_run_status_endpoint(client, mocked_responses):
     _add_workflow_response(mocked_responses)
 
-    rv = client.post("/runs", data=EXAMPLE_RUN_BODY)
-    cr_data = rv.get_json()
+    cr_data = _create_valid_run(client)
 
     rv = client.get(f"/runs/{uuid.uuid4()}/status")
     assert rv.status_code == 404
@@ -124,8 +126,7 @@ def test_run_status_endpoint(client, mocked_responses):
 def test_run_streams(client, mocked_responses):
     _add_workflow_response(mocked_responses)
 
-    rv = client.post("/runs", data=EXAMPLE_RUN_BODY)
-    cr_data = rv.get_json()
+    cr_data = _create_valid_run(client)
 
     rv = client.get(f"/runs/{uuid.uuid4()}/stdout")
     assert rv.status_code == 404
@@ -145,8 +146,7 @@ def test_run_streams(client, mocked_responses):
 def test_run_cancel_endpoint(client, mocked_responses):
     _add_workflow_response(mocked_responses)
 
-    rv = client.post("/runs", data=EXAMPLE_RUN_BODY)
-    cr_data = rv.get_json()
+    cr_data = _create_valid_run(client)
 
     rv = client.post(f"/runs/{uuid.uuid4()}/cancel")
     assert rv.status_code == 404
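Editorial note: the contents of EXAMPLE_RUN_BODY are not shown in this patch. For reference, a hypothetical form body that the new pydantic `RunRequest` model would accept looks like the sketch below; every value is illustrative, and the Json-typed fields must be JSON-encoded strings since the request arrives as an HTML form:

```python
import json

example_run_body = {
    "workflow_params": json.dumps({"phenopackets_json.json_document": "doc.json"}),
    "workflow_type": "WDL",
    "workflow_type_version": "1.0",
    "workflow_engine_parameters": json.dumps({}),
    "workflow_url": "http://workflow-host.local/phenopackets_json.wdl",
    "tags": json.dumps({
        "workflow_id": "phenopackets_json",
        "workflow_metadata": {
            "name": "Phenopackets JSON",
            "description": "Ingest phenopackets JSON documents.",
            "action": "ingestion",
            "file": "phenopackets_json.wdl",
            "inputs": [{"id": "json_document", "type": "file", "extensions": [".json"]}],
            "outputs": [],
        },
        "project_id": "project-1",
    }),
}
# client.post("/runs", data=example_run_body) should then return 200, per the WES spec.
```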