Skip to content

Commit

Permalink
Enable mypy type checking
Browse files Browse the repository at this point in the history
  • Loading branch information
jameshilliard committed Sep 9, 2024
1 parent ffee8e7 commit bb7f43b
Show file tree
Hide file tree
Showing 14 changed files with 63 additions and 36 deletions.
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ repos:
hooks:
- id: isort
name: isort (python)
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.2
hooks:
- id: mypy
additional_dependencies:
[types-pyyaml==6.0.12.20240808, types-aiofiles==24.1.0.20240626]
- repo: local
hooks:
- id: pytest
Expand Down
9 changes: 7 additions & 2 deletions goosebit/api/v1/devices/device/routes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from fastapi import APIRouter, Depends, Security
from fastapi import APIRouter, Depends, HTTPException, Security
from fastapi.requests import Request

from goosebit.api.v1.devices.device.responses import DeviceLogResponse, DeviceResponse
Expand All @@ -15,7 +15,10 @@
dependencies=[Security(validate_user_permissions, scopes=["home.read"])],
)
async def device_get(_: Request, updater: UpdateManager = Depends(get_update_manager)) -> DeviceResponse:
return await DeviceResponse.convert(await updater.get_device())
device = await updater.get_device()
if device is None:
raise HTTPException(404)
return await DeviceResponse.convert(device)


@router.get(
Expand All @@ -24,4 +27,6 @@ async def device_get(_: Request, updater: UpdateManager = Depends(get_update_man
)
async def device_logs(_: Request, updater: UpdateManager = Depends(get_update_manager)) -> DeviceLogResponse:
device = await updater.get_device()
if device is None:
raise HTTPException(404)
return DeviceLogResponse(log=device.last_log)
6 changes: 4 additions & 2 deletions goosebit/api/v1/software/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,14 @@ async def post_update(_: Request, file: UploadFile | None = File(None), url: str
raise HTTPException(409, "Software with same URL already exists and is referenced by rollout")

software = await create_software_update(url, None)
else:
elif file is not None:
# local file
file_path = config.artifacts_dir.joinpath(file.filename)

async with aiofiles.tempfile.NamedTemporaryFile("w+b") as f:
await f.write(await file.read())
software = await create_software_update(file_path.absolute().as_uri(), Path(f.name))
software = await create_software_update(file_path.absolute().as_uri(), Path(str(f.name)))
else:
raise HTTPException(422)

return {"id": software.id}
14 changes: 7 additions & 7 deletions goosebit/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ def create_token(username: str) -> str:
return jwt.encode(header={"alg": "HS256"}, claims={"username": username}, key=config.secret_key)


def get_user_from_token(token: str) -> User | None:
def get_user_from_token(token: str | None) -> User | None:
if token is None:
return
return None
try:
token_data = jwt.decode(token, config.secret_key)
username = token_data.claims["username"]
return USERS.get(username)
except (BadSignatureError, LookupError, ValueError):
pass
return None


def login_user(username: str, password: str) -> str:
Expand All @@ -58,17 +58,17 @@ def login_user(username: str, password: str) -> str:


def get_current_user(
session_token: Annotated[str, Depends(session_auth)] = None,
oauth2_token: Annotated[str, Depends(oauth2_auth)] = None,
) -> User:
session_token: Annotated[str | None, Depends(session_auth)] = None,
oauth2_token: Annotated[str | None, Depends(oauth2_auth)] = None,
) -> User | None:
session_user = get_user_from_token(session_token)
oauth2_user = get_user_from_token(oauth2_token)
user = session_user or oauth2_user
return user


# using | Request because oauth2_auth.__call__ expects is
async def get_user_from_request(connection: HTTPConnection | Request) -> User:
async def get_user_from_request(connection: HTTPConnection | Request) -> User | None:
token = await session_auth(connection) or await oauth2_auth(connection)
return get_user_from_token(token)

Expand Down
4 changes: 2 additions & 2 deletions goosebit/schema/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ class DeviceSchema(BaseModel):
hw_revision: str
feed: str
progress: int | None
last_state: Annotated[UpdateStateSchema, BeforeValidator(UpdateStateSchema.convert)]
update_mode: Annotated[UpdateModeSchema, BeforeValidator(UpdateModeSchema.convert)]
last_state: Annotated[UpdateStateSchema, BeforeValidator(UpdateStateSchema.convert)] # type: ignore[valid-type]
update_mode: Annotated[UpdateModeSchema, BeforeValidator(UpdateModeSchema.convert)] # type: ignore[valid-type]
force_update: bool
last_ip: str | None
last_seen: int | None
Expand Down
2 changes: 1 addition & 1 deletion goosebit/ui/nav.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ class Navigation:
def __init__(self):
self.items = []

def route(self, text: str, permissions: str = None):
def route(self, text: str, permissions: str | None = None):
def decorator(func):
self.items.append({"function": func.__name__, "text": text, "permissions": permissions})
return func
Expand Down
8 changes: 4 additions & 4 deletions goosebit/ui/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def ui_root(request: Request):
"/home",
dependencies=[Security(validate_user_permissions, scopes=["home.read"])],
)
@nav.route("Home", permissions=["home.read"])
@nav.route("Home", permissions="home.read")
async def home_ui(request: Request):
return templates.TemplateResponse(request, "index.html.jinja", context={"title": "Home"})

Expand All @@ -33,7 +33,7 @@ async def home_ui(request: Request):
"/devices",
dependencies=[Security(validate_user_permissions, scopes=["device.read"])],
)
@nav.route("Devices", permissions=["device.read"])
@nav.route("Devices", permissions="device.read")
async def devices_ui(request: Request):
return templates.TemplateResponse(request, "devices.html.jinja", context={"title": "Devices"})

Expand All @@ -42,7 +42,7 @@ async def devices_ui(request: Request):
"/software",
dependencies=[Security(validate_user_permissions, scopes=["software.read"])],
)
@nav.route("Software", permissions=["software.read"])
@nav.route("Software", permissions="software.read")
async def software_ui(request: Request):
return templates.TemplateResponse(request, "software.html.jinja", context={"title": "Software"})

Expand All @@ -51,7 +51,7 @@ async def software_ui(request: Request):
"/rollouts",
dependencies=[Security(validate_user_permissions, scopes=["rollout.read"])],
)
@nav.route("Rollouts", permissions=["rollout.read"])
@nav.route("Rollouts", permissions="rollout.read")
async def rollouts_ui(request: Request):
return templates.TemplateResponse(request, "rollouts.html.jinja", context={"title": "Rollouts"})

Expand Down
10 changes: 7 additions & 3 deletions goosebit/updater/controller/v1/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@

@router.get("/{dev_id}")
async def polling(request: Request, dev_id: str, updater: UpdateManager = Depends(get_update_manager)):
links = {}
links: dict[str, dict[str, str]] = {}

sleep = updater.poll_time
device = await updater.get_device()

if device is None:
raise HTTPException(404)

if device.last_state == UpdateStateEnum.UNKNOWN:
# device registration
sleep = config.poll_time_registration
Expand All @@ -49,7 +52,7 @@ async def polling(request: Request, dev_id: str, updater: UpdateManager = Depend
# provide update if available. Note: this is also required while in state "running", otherwise swupdate
# won't confirm a successful testing (might be a bug/problem in swupdate)
handling_type, software = await updater.get_update()
if handling_type != HandlingType.SKIP:
if handling_type != HandlingType.SKIP and software is not None:
links["deploymentBase"] = {
"href": str(
request.url_for(
Expand Down Expand Up @@ -152,7 +155,8 @@ async def deployment_feedback(

try:
log = data.status.details
await updater.update_log("\n".join(log))
if log is not None:
await updater.update_log("\n".join(log))
except AttributeError:
logging.warning(f"No details to update device update log, device={updater.dev_id}")

Expand Down
8 changes: 4 additions & 4 deletions goosebit/updater/controller/v1/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,16 @@ class FeedbackStatusResultFinished(StrEnum):

class FeedbackStatusResultSchema(BaseModel):
finished: FeedbackStatusResultFinished
progress: FeedbackStatusProgressSchema = None
progress: FeedbackStatusProgressSchema | None = None


class FeedbackStatusSchema(BaseModel):
execution: FeedbackStatusExecutionState
result: FeedbackStatusResultSchema
code: int = None
details: list[str] = None
code: int | None = None
details: list[str] | None = None


class FeedbackSchema(BaseModel):
time: str = None
time: str | None = None
status: FeedbackStatusSchema
20 changes: 13 additions & 7 deletions goosebit/updater/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, dev_id: str):
self.dev_id = dev_id

async def get_device(self) -> Device | None:
return
return None

async def update_force_update(self, force_update: bool) -> None:
return
Expand Down Expand Up @@ -86,7 +86,8 @@ async def subscribe_log(self, callback: Callable):
subscribers = self.log_subscribers
subscribers.append(callback)
self.log_subscribers = subscribers
await callback(device.last_log)
if device is not None:
await callback(device.last_log)
try:
yield
except asyncio.CancelledError:
Expand Down Expand Up @@ -126,7 +127,7 @@ async def publish_log(self, log_data: str | None):
await cb(log_data)

@abstractmethod
async def get_update(self) -> tuple[HandlingType, Software]: ...
async def get_update(self) -> tuple[HandlingType, Software | None]: ...

@abstractmethod
async def update_log(self, log_data: str) -> None: ...
Expand All @@ -137,11 +138,16 @@ def __init__(self, dev_id: str):
super().__init__(dev_id)
self.poll_time = config.poll_time_updating

async def _get_software(self) -> Software:
return await Software.latest(await self.get_device())
async def _get_software(self) -> Software | None:
device = await self.get_device()
if device is None:
return None
return await Software.latest(device)

async def get_update(self) -> tuple[HandlingType, Software]:
async def get_update(self) -> tuple[HandlingType, Software | None]:
software = await self._get_software()
if software is None:
return HandlingType.SKIP, None
return HandlingType.FORCED, software

async def update_log(self, log_data: str) -> None:
Expand Down Expand Up @@ -276,7 +282,7 @@ async def _get_software(self) -> Software | None:
assert device.update_mode == UpdateModeEnum.PINNED
return None

async def get_update(self) -> tuple[HandlingType, Software]:
async def get_update(self) -> tuple[HandlingType, Software | None]:
device = await self.get_device()
software = await self._get_software()

Expand Down
4 changes: 2 additions & 2 deletions goosebit/updates/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async def create_software_update(uri: str, temp_file: Path | None) -> Software:
parsed_uri = urlparse(uri)

# parse swu header into update_info
if parsed_uri.scheme == "file":
if parsed_uri.scheme == "file" and temp_file is not None:
try:
update_info = await swdesc.parse_file(temp_file)
except Exception:
Expand All @@ -44,7 +44,7 @@ async def create_software_update(uri: str, temp_file: Path | None) -> Software:
raise HTTPException(409, "Software with same version and overlapping compatibility already exists")

# for local file: rename temp file to final name
if parsed_uri.scheme == "file":
if parsed_uri.scheme == "file" and temp_file is not None:
filename = Path(url2pathname(unquote(parsed_uri.path))).name
path = config.artifacts_dir.joinpath(update_info["hash"], filename)
path.parent.mkdir(parents=True, exist_ok=True)
Expand Down
4 changes: 2 additions & 2 deletions goosebit/updates/swdesc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def parse_descriptor(swdesc: libconf.AttrDict[Any, Any | None]):
swdesc_attrs = {}
try:
swdesc_attrs["version"] = semver.Version.parse(swdesc["software"]["version"])
compatibility = []
compatibility: list[dict[str, str]] = []
_append_compatibility("default", swdesc["software"], compatibility)

for key in swdesc["software"]:
Expand Down Expand Up @@ -69,7 +69,7 @@ async def parse_remote(url: str):
file = await c.get(url)
async with aiofiles.tempfile.NamedTemporaryFile("w+b") as f:
await f.write(file.content)
return await parse_file(Path(f.name))
return await parse_file(Path(str(f.name)))


def _sha1_hash_file(file_path: Path):
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ isort = "^5.13.2"
black = "^24.2.0"
pre-commit = "^3.6.2"
flake8 = "7.1.0"
mypy = "^1.11.2"
types-pyyaml = "^6.0.12.20240808"
types-aiofiles = "^24.1.0.20240626"

[tool.poetry.group.docs.dependencies]
mkdocs = "^1.6.0"
Expand Down
1 change: 1 addition & 0 deletions tests/updater/controller/v1/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ async def _poll(async_client, device_uuid, software: Software | None, expect_upd
if expect_update:
assert "deploymentBase" in data["_links"], "expected update, but none available"
deployment_base = data["_links"]["deploymentBase"]["href"]
assert software is not None
assert deployment_base == f"http://test/ddi/controller/v1/{device_uuid}/deploymentBase/{software.id}"
return deployment_base
else:
Expand Down

0 comments on commit bb7f43b

Please sign in to comment.