Skip to content

Commit

Permalink
lint & hint
Browse files Browse the repository at this point in the history
  • Loading branch information
davidlougheed committed Jul 6, 2023
1 parent bb772b1 commit 6733751
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 22 deletions.
15 changes: 7 additions & 8 deletions bento_wes/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from bento_lib.events.notifications import format_notification
from bento_lib.events.types import EVENT_CREATE_NOTIFICATION, EVENT_WES_RUN_UPDATED
from flask import current_app, g
from typing import Optional, Tuple, Union
from urllib.parse import urljoin

from . import states
Expand Down Expand Up @@ -66,7 +65,7 @@ def finish_run(
event_bus: EventBus,
run: dict,
state: str,
logger: Optional[logging.Logger] = None,
logger: logging.Logger | None = None,
) -> None:
"""
Updates a run's state, sets the run log's end time, and publishes an event corresponding with a run failure
Expand Down Expand Up @@ -169,14 +168,14 @@ def _strip_first_slash(string: str):
return string[1:] if len(string) > 0 and string[0] == "/" else string


def _stream_url(run_id: Union[uuid.UUID, str], stream: str):
def _stream_url(run_id: uuid.UUID | str, stream: str):
return urljoin(
urljoin(current_app.config["CHORD_URL"], _strip_first_slash(current_app.config["SERVICE_URL_BASE_PATH"]) + "/"),
f"runs/{str(run_id)}/{stream}"
)


def run_log_dict(run_id: Union[uuid.UUID, str], run_log: sqlite3.Row) -> dict:
def run_log_dict(run_id: uuid.UUID | str, run_log: sqlite3.Row) -> dict:
return {
"id": run_log["id"], # TODO: This is non-WES-compliant
"name": run_log["name"],
Expand All @@ -201,12 +200,12 @@ def task_log_dict(task_log: sqlite3.Row) -> dict:
}


def get_task_logs(c: sqlite3.Cursor, run_id: Union[uuid.UUID, str]) -> list:
def get_task_logs(c: sqlite3.Cursor, run_id: uuid.UUID | str) -> list:
c.execute("SELECT * FROM task_logs WHERE run_id = ?", (str(run_id),))
return [task_log_dict(task_log) for task_log in c.fetchall()]


def get_run_details(c: sqlite3.Cursor, run_id: Union[uuid.UUID, str]) -> tuple[None, str] | tuple[dict, None]:
def get_run_details(c: sqlite3.Cursor, run_id: uuid.UUID | str) -> tuple[None, str] | tuple[dict, None]:
# Runs, run requests, and run logs are created at the same time, so if any of them is missing return None.

c.execute("SELECT * FROM runs WHERE id = ?", (str(run_id),))
Expand Down Expand Up @@ -240,9 +239,9 @@ def update_run_state_and_commit(
db: sqlite3.Connection,
c: sqlite3.Cursor,
event_bus: EventBus,
run_id: Union[uuid.UUID, str],
run_id: uuid.UUID | str,
state: str,
logger: Optional[logging.Logger] = None,
logger: logging.Logger | None = None,
):
if logger:
logger.info(f"Updating run state of {run_id} to {state}")
Expand Down
2 changes: 1 addition & 1 deletion bento_wes/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
logger = get_task_logger(__name__)


def build_workflow_outputs(run_dir, workflow_id, workflow_params: dict, c_workflow_metadata: dict):
def build_workflow_outputs(run_dir, workflow_id: str, workflow_params: dict, c_workflow_metadata: dict):
logger.info(f"Building workflow outputs for workflow ID {workflow_id} "
f"(WRITE_OUTPUT_TO_DRS={current_app.config['WRITE_OUTPUT_TO_DRS']})")
output_params = w.make_output_params(workflow_id, workflow_params, c_workflow_metadata["inputs"])
Expand Down
10 changes: 5 additions & 5 deletions bento_wes/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _check_single_run_permission_and_mark(project_and_dataset: tuple[str | None,
return p_res and p_res[0]


def _create_run(db, c):
def _create_run(db: sqlite3.Connection, c: sqlite3.Cursor) -> Response:
try:
assert "workflow_params" in request.form
assert "workflow_type" in request.form
Expand Down Expand Up @@ -299,7 +299,7 @@ def run_list():


@bp_runs.route("/runs/<uuid:run_id>", methods=["GET"])
def run_detail(run_id):
def run_detail(run_id: uuid.UUID):
run_details, err = get_run_details(get_db().cursor(), run_id)

if not _check_single_run_permission_and_mark(
Expand Down Expand Up @@ -354,13 +354,13 @@ def run_stdout(run_id: uuid.UUID):


@bp_runs.route("/runs/<uuid:run_id>/stderr", methods=["GET"])
def run_stderr(run_id):
def run_stderr(run_id: uuid.UUID):
c = get_db().cursor()
return check_run_authz_then_return_response(c, run_id, lambda: get_stream(c, "stderr", run_id))


@bp_runs.route("/runs/<uuid:run_id>/cancel", methods=["POST"])
def run_cancel(run_id):
def run_cancel(run_id: uuid.UUID):
# TODO: Check if already completed
# TODO: Check if run log exists
# TODO: from celery.task.control import revoke; revoke(celery_id, terminate=True)
Expand Down Expand Up @@ -414,7 +414,7 @@ def perform_run_cancel():


@bp_runs.route("/runs/<uuid:run_id>/status", methods=["GET"])
def run_status(run_id):
def run_status(run_id: uuid.UUID):
# TODO: check permissions based on project/dataset

c = get_db().cursor()
Expand Down
2 changes: 1 addition & 1 deletion bento_wes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
__all__ = ["iso_now"]


def iso_now():
def iso_now() -> str:
return datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ") # ISO date format
23 changes: 16 additions & 7 deletions bento_wes/workflows.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import logging

import bento_lib.workflows as w
import os
import shutil
import requests

from base64 import urlsafe_b64encode
from typing import Dict, NewType, Optional, Set
from typing import NewType
from urllib.parse import urlparse

from bento_wes import states
Expand All @@ -28,7 +30,7 @@
# Currently, only WDL is supported
WES_SUPPORTED_WORKFLOW_TYPES = frozenset({WES_WORKFLOW_TYPE_WDL})

WORKFLOW_EXTENSIONS: Dict[WorkflowType, str] = {
WORKFLOW_EXTENSIONS: dict[WorkflowType, str] = {
WES_WORKFLOW_TYPE_WDL: "wdl",
WES_WORKFLOW_TYPE_CWL: "cwl",
}
Expand All @@ -40,7 +42,7 @@


# TODO: Types for params/metadata
def count_bento_workflow_file_outputs(workflow_id, workflow_params: dict, workflow_metadata: dict) -> int:
def count_bento_workflow_file_outputs(workflow_id: str, workflow_params: dict, workflow_metadata: dict) -> int:
"""
Given a workflow run's parameters and workflow metadata, returns the number
of files being output for the purposes of generating one-time ingest tokens
Expand All @@ -66,7 +68,7 @@ def count_bento_workflow_file_outputs(workflow_id, workflow_params: dict, workfl
return n_file_outputs


def parse_workflow_host_allow_list(allow_list: Optional[str]) -> Optional[Set[str]]:
def parse_workflow_host_allow_list(allow_list: str | None) -> set[str] | None:
"""
Get set of allowed workflow hosts from a configuration string for any
checks while downloading workflows. If it's blank, assume that means
Expand All @@ -87,8 +89,15 @@ class WorkflowDownloadError(Exception):


class WorkflowManager:
def __init__(self, tmp_dir: str, chord_url: Optional[str] = None, logger: Optional = None,
workflow_host_allow_list: Optional[set] = None, validate_ssl: bool = True, debug: bool = False):
def __init__(
self,
tmp_dir: str,
chord_url: str | None = None,
logger: logging.Logger | None = None,
workflow_host_allow_list: str | None = None,
validate_ssl: bool = True,
debug: bool = False,
):
self.tmp_dir = tmp_dir
self.chord_url = chord_url
self.logger = logger
Expand Down Expand Up @@ -121,7 +130,7 @@ def workflow_path(self, workflow_uri: str, workflow_type: WorkflowType) -> str:
return os.path.join(self.tmp_dir, f"workflow_{workflow_name}.{WORKFLOW_EXTENSIONS[workflow_type]}")

def download_or_copy_workflow(self, workflow_uri: str, workflow_type: WorkflowType, auth_headers: dict) \
-> Optional[str]:
-> str | None:
"""
Given a URI, downloads the specified workflow via its URI, or copies it over if it's on the local
file system. # TODO: Local file system = security issue?
Expand Down

0 comments on commit 6733751

Please sign in to comment.