Skip to content

Commit

Permalink
Merge pull request #188 from opensafely-core/evansd/add-mypy
Browse files Browse the repository at this point in the history
Add mypy type checking
  • Loading branch information
evansd authored Mar 22, 2024
2 parents c422602 + 4b8dd74 commit 07b5d3b
Show file tree
Hide file tree
Showing 19 changed files with 327 additions and 107 deletions.
144 changes: 101 additions & 43 deletions airlock/business_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from pathlib import Path

# we use PurePosixPath as a convenient url path representation
from pathlib import PurePosixPath as UrlPath
from typing import Optional, Protocol
from pathlib import Path, PurePosixPath
from typing import TYPE_CHECKING, Protocol, Self, cast

from django.conf import settings
from django.shortcuts import reverse
from django.urls import reverse
from django.utils.functional import SimpleLazyObject
from django.utils.module_loading import import_string

Expand All @@ -20,6 +17,17 @@
from airlock.users import User


# We use PurePosixPath as a convenient URL path representation. In theory we could use
# `NewType` here to indicate that we want this to be treated as a distinct type without
# actually creating one. But doing so results in a number of spurious type errors for
# reasons I don't fully understand (possibly because PurePosixPath isn't itself type
# annotated?).
if TYPE_CHECKING: # pragma: no cover

class UrlPath(PurePosixPath): ...
else:
UrlPath = PurePosixPath

ROOT_PATH = UrlPath() # empty path


Expand Down Expand Up @@ -75,6 +83,9 @@ def get_contents_url(
def is_supporting_file(self, relpath: UrlPath):
"""Is this path a supporting file?"""

def abspath(self, path: UrlPath) -> Path:
"""Get the absolute path of the container object with path"""


@dataclass(order=True)
class Workspace:
Expand All @@ -85,7 +96,7 @@ class Workspace:
"""

name: str
metadata: dict = field(default_factory=dict)
metadata: dict[str, str] = field(default_factory=dict)

# can be set to mark the currently selected path in this workspace
selected_path: UrlPath = ROOT_PATH
Expand Down Expand Up @@ -172,7 +183,7 @@ class RequestFile:
filetype: RequestFileType = RequestFileType.OUTPUT

@classmethod
def from_dict(cls, attrs):
def from_dict(cls, attrs) -> Self:
return cls(
**{k: v for k, v in attrs.items() if k != "reviews"},
reviews=[FileReview.from_dict(value) for value in attrs.get("reviews", ())],
Expand All @@ -186,7 +197,7 @@ class FileGroup:
"""

name: str
files: dict[RequestFile]
files: dict[UrlPath, RequestFile]

@property
def output_files(self):
Expand All @@ -199,7 +210,7 @@ def supporting_files(self):
]

@classmethod
def from_dict(cls, attrs):
def from_dict(cls, attrs) -> Self:
return cls(
**{k: v for k, v in attrs.items() if k != "files"},
files={
Expand All @@ -224,13 +235,13 @@ class ReleaseRequest:
author: str
created_at: datetime
status: RequestStatus = RequestStatus.PENDING
filegroups: dict[FileGroup] = field(default_factory=dict)
filegroups: dict[str, FileGroup] = field(default_factory=dict)

# can be set to mark the currently selected path in this release request
selected_path: UrlPath = ROOT_PATH

@classmethod
def from_dict(cls, attrs):
def from_dict(cls, attrs) -> Self:
return cls(
**{k: v for k, v in attrs.items() if k != "filegroups"},
filegroups=cls._filegroups_from_dict(attrs.get("filegroups", {})),
Expand Down Expand Up @@ -346,12 +357,48 @@ def store_file(release_request: ReleaseRequest, abspath: Path) -> str:
return digest


class DataAccessLayerProtocol:
class DataAccessLayerProtocol(Protocol):
"""
Placeholder for a structural type class we can use to define what a data access
layer should look like, once we've settled what that is.
Structural type class for the Data Access Layer
Implementations aren't obliged to subclass this as long as they implement the
specified methods, though it may be clearer to do so.
"""

def get_release_request(self, request_id: str):
raise NotImplementedError()

def create_release_request(self, **kwargs):
raise NotImplementedError()

def get_active_requests_for_workspace_by_user(self, workspace: str, username: str):
raise NotImplementedError()

def get_requests_authored_by_user(self, username: str):
raise NotImplementedError()

def get_outstanding_requests_for_review(self):
raise NotImplementedError()

def set_status(self, request_id: str, status: RequestStatus):
raise NotImplementedError()

def add_file_to_request(
self,
request_id,
relpath: UrlPath,
file_id: str,
group_name: str,
filetype: RequestFileType,
):
raise NotImplementedError()

def approve_file(self, request_id: str, relpath: UrlPath, username: str):
raise NotImplementedError()

def reject_file(self, request_id: str, relpath: UrlPath, username: str):
raise NotImplementedError()


class BusinessLogicLayer:
"""
Expand Down Expand Up @@ -437,38 +484,47 @@ def get_release_request(self, request_id: str, user: User) -> ReleaseRequest:
return release_request

def get_current_request(
self, workspace_name: str, user: User, create: bool = False
) -> ReleaseRequest:
"""Get the current request for the a workspace/user.
If create is True, create one.
"""
self, workspace_name: str, user: User
) -> ReleaseRequest | None:
"""Get the current request for a workspace/user."""
active_requests = self._dal.get_active_requests_for_workspace_by_user(
workspace=workspace_name,
username=user.username,
)

n = len(active_requests)
if n > 1:
raise Exception(
f"Multiple active release requests for user {user.username} in workspace {workspace_name}"
)
if n == 0:
return None
elif n == 1:
return ReleaseRequest.from_dict(active_requests[0])
elif create:
# To create a request, you must have explicit workspace permissions.
# Output checkers can view all workspaces, but are not allowed to
# create requests for all workspaces.
if workspace_name not in user.workspaces:
raise BusinessLogicLayer.RequestPermissionDenied(workspace_name)

new_request = self._dal.create_release_request(
workspace=workspace_name,
author=user.username,
)
return ReleaseRequest.from_dict(new_request)
else:
return None
raise Exception(
f"Multiple active release requests for user {user.username} in "
f"workspace {workspace_name}"
)

def get_or_create_current_request(
self, workspace_name: str, user: User
) -> ReleaseRequest:
"""
Get the current request for a workspace/user, or create a new one if there is
none.
"""
request = self.get_current_request(workspace_name, user)
if request is not None:
return request

# To create a request, you must have explicit workspace permissions. Output
# checkers can view all workspaces, but are not allowed to create requests for
# all workspaces.
if workspace_name not in user.workspaces:
raise BusinessLogicLayer.RequestPermissionDenied(workspace_name)

new_request = self._dal.create_release_request(
workspace=workspace_name,
author=user.username,
)
return ReleaseRequest.from_dict(new_request)

def get_requests_authored_by_user(self, user: User) -> list[ReleaseRequest]:
"""Get all current requests authored by user."""
Expand Down Expand Up @@ -582,7 +638,7 @@ def add_file_to_request(
release_request: ReleaseRequest,
relpath: UrlPath,
user: User,
group_name: Optional[str] = "default",
group_name: str = "default",
filetype: RequestFileType = RequestFileType.OUTPUT,
):
if user.username != release_request.author:
Expand Down Expand Up @@ -663,7 +719,7 @@ def approve_file(

self._verify_permission_to_review_file(release_request, relpath, user)

bll._dal.approve_file(release_request.id, relpath, user)
bll._dal.approve_file(release_request.id, relpath, user.username)

def reject_file(
self, release_request: ReleaseRequest, relpath: UrlPath, user: User
Expand All @@ -672,7 +728,7 @@ def reject_file(

self._verify_permission_to_review_file(release_request, relpath, user)

bll._dal.reject_file(release_request.id, relpath, user)
bll._dal.reject_file(release_request.id, relpath, user.username)


def _get_configured_bll():
Expand All @@ -681,5 +737,7 @@ def _get_configured_bll():


# We follow the Django pattern of using a lazy object which configures itself on first
# access so as to avoid reading `settings` during import
bll = SimpleLazyObject(_get_configured_bll)
# access so as to avoid reading `settings` during import. The `cast` here is a runtime
# no-op, but indicates to the type-checker that this should be treated as an instance of
# BusinessLogicLayer not SimpleLazyObject.
bll = cast(BusinessLogicLayer, SimpleLazyObject(_get_configured_bll))
30 changes: 17 additions & 13 deletions airlock/file_browser_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum

Expand Down Expand Up @@ -35,9 +37,9 @@ class PathNotFound(Exception):
container: AirlockContainer
relpath: UrlPath

type: PathType = None
children: list["PathItem"] = field(default_factory=list)
parent: "PathItem" = None
type: PathType | None = None
children: list[PathItem] = field(default_factory=list)
parent: PathItem | None = None

# is this the currently selected path?
selected: bool = False
Expand All @@ -49,7 +51,7 @@ class PathNotFound(Exception):

# what to display for this node when rendering the tree. Defaults to name,
# but this allow it to be overridden.
display_text: str = None
display_text: str | None = None

def __post_init__(self):
# ensure is UrlPath
Expand All @@ -76,18 +78,18 @@ def display(self):

def url(self):
suffix = "/" if self.is_directory() else ""
return self.container.get_url(f"{self.relpath}{suffix}")
return self.container.get_url(self.relpath) + suffix

def contents_url(self, download=False):
if self.type != PathType.FILE:
raise Exception(f"contents_url called on non-file path {self.relpath}")
return self.container.get_contents_url(f"{self.relpath}", download=download)
return self.container.get_contents_url(self.relpath, download=download)

def download_url(self):
return self.contents_url(download=True)

def siblings(self):
if not self.relpath.parents:
if self.parent is None:
return []
else:
return self.parent.children
Expand Down Expand Up @@ -132,7 +134,7 @@ def html_classes(self):
distinguish file/dirs, and maybe even file types, in the UI, in case we
need to.
"""
classes = [self.type.value.lower()]
classes = [self.type.value.lower()] if self.type else []

if self.type == PathType.FILE:
classes.append(self.file_type())
Expand Down Expand Up @@ -330,14 +332,16 @@ def get_path_tree(
):
"""Walk a flat list of paths and create a tree from them."""

def build_path_tree(path_parts, parent):
def build_path_tree(
path_parts: list[list[str]], parent: PathItem
) -> list[PathItem]:
# group multiple paths into groups by first part of path
grouped = dict()
for child, *descendants in path_parts:
grouped: dict[str, list[list[str]]] = dict()
for child, *descendant_parts in path_parts:
if child not in grouped:
grouped[child] = []
if descendants:
grouped[child].append(descendants)
if descendant_parts:
grouped[child].append(descendant_parts)

tree = []

Expand Down
10 changes: 5 additions & 5 deletions airlock/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ def __init__(self, *args, **kwargs):
self.filegroup_names = release_request.filegroups.keys()
else:
self.filegroup_names = set()
group_choices = {
(name, name) for name in self.filegroup_names if name != "default"
}
group_choices = [("default", "default"), *sorted(group_choices)]
group_names = sorted(self.filegroup_names - {"default"})
group_choices = [(name, name) for name in ["default", *group_names]]
# Use type narrowing to persuade mpy this has a `choices` attr
assert isinstance(self.fields["filegroup"], forms.ChoiceField)
self.fields["filegroup"].choices = group_choices
self.fields["new_filegroup"]

def clean_new_filegroup(self):
new_filegroup = self.cleaned_data.get("new_filegroup").lower()
new_filegroup = self.cleaned_data.get("new_filegroup", "").lower()
if new_filegroup in [fg.lower() for fg in self.filegroup_names]:
self.add_error(
"new_filegroup",
Expand Down
7 changes: 4 additions & 3 deletions airlock/login_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from pathlib import Path

import requests
from django.conf import settings
Expand All @@ -13,7 +14,7 @@ class LoginError(Exception):

def get_user_data(user: str, token: str):
if settings.AIRLOCK_DEV_USERS_FILE and not settings.AIRLOCK_API_TOKEN:
return get_user_data_dev(user, token)
return get_user_data_dev(settings.AIRLOCK_DEV_USERS_FILE, user, token)
else:
return get_user_data_prod(user, token)

Expand All @@ -31,9 +32,9 @@ def get_user_data_prod(user: str, token: str):
return response.json()


def get_user_data_dev(user: str, token: str):
def get_user_data_dev(dev_users_file: Path, user: str, token: str):
try:
dev_users = json.loads(settings.AIRLOCK_DEV_USERS_FILE.read_text())
dev_users = json.loads(dev_users_file.read_text())
except FileNotFoundError as e: # pragma: no cover
e.add_note(
"You may want to run:\n\n just load-example-data\n\nto create one."
Expand Down
3 changes: 2 additions & 1 deletion airlock/middleware.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from urllib.parse import urlencode

from django.shortcuts import redirect, reverse
from django.shortcuts import redirect
from django.urls import reverse

from airlock.users import User

Expand Down
Loading

0 comments on commit 07b5d3b

Please sign in to comment.