From 6733751cfa804e3bb88424a81cd4c7699891d7a5 Mon Sep 17 00:00:00 2001 From: David Lougheed Date: Thu, 6 Jul 2023 13:37:31 -0400 Subject: [PATCH] lint & hint --- bento_wes/db.py | 15 +++++++-------- bento_wes/runner.py | 2 +- bento_wes/runs.py | 10 +++++----- bento_wes/utils.py | 2 +- bento_wes/workflows.py | 23 ++++++++++++++++------- 5 files changed, 30 insertions(+), 22 deletions(-) diff --git a/bento_wes/db.py b/bento_wes/db.py index 73b391d2..a28799f9 100644 --- a/bento_wes/db.py +++ b/bento_wes/db.py @@ -7,7 +7,6 @@ from bento_lib.events.notifications import format_notification from bento_lib.events.types import EVENT_CREATE_NOTIFICATION, EVENT_WES_RUN_UPDATED from flask import current_app, g -from typing import Optional, Tuple, Union from urllib.parse import urljoin from . import states @@ -66,7 +65,7 @@ def finish_run( event_bus: EventBus, run: dict, state: str, - logger: Optional[logging.Logger] = None, + logger: logging.Logger | None = None, ) -> None: """ Updates a run's state, sets the run log's end time, and publishes an event corresponding with a run failure @@ -169,14 +168,14 @@ def _strip_first_slash(string: str): return string[1:] if len(string) > 0 and string[0] == "/" else string -def _stream_url(run_id: Union[uuid.UUID, str], stream: str): +def _stream_url(run_id: uuid.UUID | str, stream: str): return urljoin( urljoin(current_app.config["CHORD_URL"], _strip_first_slash(current_app.config["SERVICE_URL_BASE_PATH"]) + "/"), f"runs/{str(run_id)}/{stream}" ) -def run_log_dict(run_id: Union[uuid.UUID, str], run_log: sqlite3.Row) -> dict: +def run_log_dict(run_id: uuid.UUID | str, run_log: sqlite3.Row) -> dict: return { "id": run_log["id"], # TODO: This is non-WES-compliant "name": run_log["name"], @@ -201,12 +200,12 @@ def task_log_dict(task_log: sqlite3.Row) -> dict: } -def get_task_logs(c: sqlite3.Cursor, run_id: Union[uuid.UUID, str]) -> list: +def get_task_logs(c: sqlite3.Cursor, run_id: uuid.UUID | str) -> list: c.execute("SELECT * FROM task_logs WHERE run_id = ?", (str(run_id),)) return [task_log_dict(task_log) for task_log in c.fetchall()] -def get_run_details(c: sqlite3.Cursor, run_id: Union[uuid.UUID, str]) -> tuple[None, str] | tuple[dict, None]: +def get_run_details(c: sqlite3.Cursor, run_id: uuid.UUID | str) -> tuple[None, str] | tuple[dict, None]: # Runs, run requests, and run logs are created at the same time, so if any of them is missing return None. c.execute("SELECT * FROM runs WHERE id = ?", (str(run_id),)) @@ -240,9 +239,9 @@ def update_run_state_and_commit( db: sqlite3.Connection, c: sqlite3.Cursor, event_bus: EventBus, - run_id: Union[uuid.UUID, str], + run_id: uuid.UUID | str, state: str, - logger: Optional[logging.Logger] = None, + logger: logging.Logger | None = None, ): if logger: logger.info(f"Updating run state of {run_id} to {state}") diff --git a/bento_wes/runner.py b/bento_wes/runner.py index eb03fc68..6667decf 100644 --- a/bento_wes/runner.py +++ b/bento_wes/runner.py @@ -20,7 +20,7 @@ logger = get_task_logger(__name__) -def build_workflow_outputs(run_dir, workflow_id, workflow_params: dict, c_workflow_metadata: dict): +def build_workflow_outputs(run_dir, workflow_id: str, workflow_params: dict, c_workflow_metadata: dict): logger.info(f"Building workflow outputs for workflow ID {workflow_id} " f"(WRITE_OUTPUT_TO_DRS={current_app.config['WRITE_OUTPUT_TO_DRS']})") output_params = w.make_output_params(workflow_id, workflow_params, c_workflow_metadata["inputs"]) diff --git a/bento_wes/runs.py b/bento_wes/runs.py index 6f08c01f..f6c9e0a4 100644 --- a/bento_wes/runs.py +++ b/bento_wes/runs.py @@ -68,7 +68,7 @@ def _check_single_run_permission_and_mark(project_and_dataset: tuple[str | None, return p_res and p_res[0] -def _create_run(db, c): +def _create_run(db: sqlite3.Connection, c: sqlite3.Cursor) -> Response: try: assert "workflow_params" in request.form assert "workflow_type" in request.form @@ -299,7 +299,7 @@ def run_list(): @bp_runs.route("/runs/", methods=["GET"]) -def run_detail(run_id): +def run_detail(run_id: uuid.UUID): run_details, err = get_run_details(get_db().cursor(), run_id) if not _check_single_run_permission_and_mark( @@ -354,13 +354,13 @@ def run_stdout(run_id: uuid.UUID): @bp_runs.route("/runs//stderr", methods=["GET"]) -def run_stderr(run_id): +def run_stderr(run_id: uuid.UUID): c = get_db().cursor() return check_run_authz_then_return_response(c, run_id, lambda: get_stream(c, "stderr", run_id)) @bp_runs.route("/runs//cancel", methods=["POST"]) -def run_cancel(run_id): +def run_cancel(run_id: uuid.UUID): # TODO: Check if already completed # TODO: Check if run log exists # TODO: from celery.task.control import revoke; revoke(celery_id, terminate=True) @@ -414,7 +414,7 @@ def perform_run_cancel(): @bp_runs.route("/runs//status", methods=["GET"]) -def run_status(run_id): +def run_status(run_id: uuid.UUID): # TODO: check permissions based on project/dataset c = get_db().cursor() diff --git a/bento_wes/utils.py b/bento_wes/utils.py index d5a46d20..ee0c07db 100644 --- a/bento_wes/utils.py +++ b/bento_wes/utils.py @@ -4,5 +4,5 @@ __all__ = ["iso_now"] -def iso_now(): +def iso_now() -> str: return datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ") # ISO date format diff --git a/bento_wes/workflows.py b/bento_wes/workflows.py index 7a53adc9..5531662e 100644 --- a/bento_wes/workflows.py +++ b/bento_wes/workflows.py @@ -1,10 +1,12 @@ +import logging + import bento_lib.workflows as w import os import shutil import requests from base64 import urlsafe_b64encode -from typing import Dict, NewType, Optional, Set +from typing import NewType from urllib.parse import urlparse from bento_wes import states @@ -28,7 +30,7 @@ # Currently, only WDL is supported WES_SUPPORTED_WORKFLOW_TYPES = frozenset({WES_WORKFLOW_TYPE_WDL}) -WORKFLOW_EXTENSIONS: Dict[WorkflowType, str] = { +WORKFLOW_EXTENSIONS: dict[WorkflowType, str] = { WES_WORKFLOW_TYPE_WDL: "wdl", WES_WORKFLOW_TYPE_CWL: "cwl", } @@ -40,7 +42,7 @@ # TODO: Types for params/metadata -def count_bento_workflow_file_outputs(workflow_id, workflow_params: dict, workflow_metadata: dict) -> int: +def count_bento_workflow_file_outputs(workflow_id: str, workflow_params: dict, workflow_metadata: dict) -> int: """ Given a workflow run's parameters and workflow metadata, returns the number of files being output for the purposes of generating one-time ingest tokens @@ -66,7 +68,7 @@ def count_bento_workflow_file_outputs(workflow_id, workflow_params: dict, workfl return n_file_outputs -def parse_workflow_host_allow_list(allow_list: Optional[str]) -> Optional[Set[str]]: +def parse_workflow_host_allow_list(allow_list: str | None) -> set[str] | None: """ Get set of allowed workflow hosts from a configuration string for any checks while downloading workflows. If it's blank, assume that means @@ -87,8 +89,15 @@ class WorkflowDownloadError(Exception): class WorkflowManager: - def __init__(self, tmp_dir: str, chord_url: Optional[str] = None, logger: Optional = None, - workflow_host_allow_list: Optional[set] = None, validate_ssl: bool = True, debug: bool = False): + def __init__( + self, + tmp_dir: str, + chord_url: str | None = None, + logger: logging.Logger | None = None, + workflow_host_allow_list: str | None = None, + validate_ssl: bool = True, + debug: bool = False, + ): self.tmp_dir = tmp_dir self.chord_url = chord_url self.logger = logger @@ -121,7 +130,7 @@ def workflow_path(self, workflow_uri: str, workflow_type: WorkflowType) -> str: return os.path.join(self.tmp_dir, f"workflow_{workflow_name}.{WORKFLOW_EXTENSIONS[workflow_type]}") def download_or_copy_workflow(self, workflow_uri: str, workflow_type: WorkflowType, auth_headers: dict) \ - -> Optional[str]: + -> str | None: """ Given a URI, downloads the specified workflow via its URI, or copies it over if it's on the local file system. # TODO: Local file system = security issue?