From f7a0ce7d05bce75bd168550313d611c17c42b00f Mon Sep 17 00:00:00 2001 From: James Hilliard Date: Sat, 7 Sep 2024 16:08:02 -0600 Subject: [PATCH] Use anyio based async file io operations --- goosebit/api/v1/software/routes.py | 13 +++++++------ goosebit/db/models.py | 4 ++-- goosebit/ui/bff/software/routes.py | 20 ++++++++++---------- goosebit/updates/__init__.py | 12 ++++++------ goosebit/updates/swdesc.py | 25 ++++++++++++++++++------- tests/api/v1/software/test_routes.py | 20 +++++++++++--------- tests/updates/test_swdesc.py | 6 +++--- 7 files changed, 57 insertions(+), 43 deletions(-) diff --git a/goosebit/api/v1/software/routes.py b/goosebit/api/v1/software/routes.py index b090ca01..10dd3462 100644 --- a/goosebit/api/v1/software/routes.py +++ b/goosebit/api/v1/software/routes.py @@ -1,6 +1,5 @@ -from pathlib import Path - import aiofiles +from anyio import Path from fastapi import APIRouter, File, Form, HTTPException, Security, UploadFile from fastapi.requests import Request @@ -42,8 +41,8 @@ async def software_delete(_: Request, delete_req: SoftwareDeleteRequest) -> Stat if software.local: path = software.path - if path.exists(): - path.unlink() + if await path.exists(): + await path.unlink() await software.delete() success = True @@ -68,11 +67,13 @@ async def post_update(_: Request, file: UploadFile | None = File(None), url: str software = await create_software_update(url, None) elif file is not None: # local file - file_path = config.artifacts_dir.joinpath(file.filename) + artifacts_dir = Path(config.artifacts_dir) + file_path = artifacts_dir.joinpath(file.filename) async with aiofiles.tempfile.NamedTemporaryFile("w+b") as f: await f.write(await file.read()) - software = await create_software_update(file_path.absolute().as_uri(), Path(str(f.name))) + absolute = await file_path.absolute() + software = await create_software_update(absolute.as_uri(), Path(f.name)) else: raise HTTPException(422) diff --git a/goosebit/db/models.py b/goosebit/db/models.py index 8ac25c4d..979f800a 100644 --- a/goosebit/db/models.py +++ b/goosebit/db/models.py @@ -1,10 +1,10 @@ from enum import IntEnum -from pathlib import Path from typing import Self from urllib.parse import unquote, urlparse from urllib.request import url2pathname import semver +from anyio import Path from tortoise import Model, fields from goosebit.api.telemetry.metrics import devices_count @@ -132,7 +132,7 @@ async def latest(cls, device: Device) -> Self | None: )[0] @property - def path(self): + def path(self) -> Path: return Path(url2pathname(unquote(urlparse(self.uri).path))) @property diff --git a/goosebit/ui/bff/software/routes.py b/goosebit/ui/bff/software/routes.py index c120cef2..95dc19bb 100644 --- a/goosebit/ui/bff/software/routes.py +++ b/goosebit/ui/bff/software/routes.py @@ -1,6 +1,6 @@ from __future__ import annotations -import aiofiles +from anyio import Path, open_file from fastapi import APIRouter, Form, HTTPException, Security, UploadFile from fastapi.requests import Request from tortoise.expressions import Q @@ -64,20 +64,20 @@ async def post_update( await create_software_update(url, None) else: # local file - file = config.artifacts_dir.joinpath(filename) - config.artifacts_dir.mkdir(parents=True, exist_ok=True) + artifacts_dir = Path(config.artifacts_dir) + file = artifacts_dir.joinpath(filename) + await artifacts_dir.mkdir(parents=True, exist_ok=True) temp_file = file.with_suffix(".tmp") if init: - temp_file.unlink(missing_ok=True) + await temp_file.unlink(missing_ok=True) - contents = await chunk.read() - - async with aiofiles.open(temp_file, mode="ab") as f: - await f.write(contents) + async with await open_file(temp_file, "ab") as f: + await f.write(await chunk.read()) if done: try: - await create_software_update(file.absolute().as_uri(), temp_file) + absolute = await file.absolute() + await create_software_update(absolute.as_uri(), temp_file) finally: - temp_file.unlink(missing_ok=True) + await temp_file.unlink(missing_ok=True) diff --git a/goosebit/updates/__init__.py b/goosebit/updates/__init__.py index 3512c21b..35c35dbf 100644 --- a/goosebit/updates/__init__.py +++ b/goosebit/updates/__init__.py @@ -1,10 +1,9 @@ from __future__ import annotations -import shutil -from pathlib import Path from urllib.parse import unquote, urlparse from urllib.request import url2pathname +from anyio import Path from fastapi import HTTPException from fastapi.requests import Request from tortoise.expressions import Q @@ -46,10 +45,11 @@ async def create_software_update(uri: str, temp_file: Path | None) -> Software: # for local file: rename temp file to final name if parsed_uri.scheme == "file" and temp_file is not None: filename = Path(url2pathname(unquote(parsed_uri.path))).name - path = config.artifacts_dir.joinpath(update_info["hash"], filename) - path.parent.mkdir(parents=True, exist_ok=True) - shutil.copy(temp_file, path) - uri = path.absolute().as_uri() + path = Path(config.artifacts_dir).joinpath(update_info["hash"], filename) + await path.parent.mkdir(parents=True, exist_ok=True) + await temp_file.rename(path) + absolute = await path.absolute() + uri = absolute.as_uri() # create software software = await Software.create( diff --git a/goosebit/updates/swdesc.py b/goosebit/updates/swdesc.py index 4ee01193..c6e5685c 100644 --- a/goosebit/updates/swdesc.py +++ b/goosebit/updates/swdesc.py @@ -1,12 +1,12 @@ import hashlib import logging -from pathlib import Path from typing import Any import aiofiles import httpx import libconf import semver +from anyio import AsyncFile, Path, open_file logger = logging.getLogger(__name__) @@ -41,7 +41,7 @@ def parse_descriptor(swdesc: libconf.AttrDict[Any, Any | None]): async def parse_file(file: Path): - async with aiofiles.open(file, "r+b") as f: + async with await open_file(file, "r+b") as f: # get file size size = int((await f.read(110))[54:62], 16) filename = b"" @@ -59,8 +59,9 @@ async def parse_file(file: Path): swdesc = libconf.loads((await f.read(size)).decode("utf-8")) swdesc_attrs = parse_descriptor(swdesc) - swdesc_attrs["size"] = file.stat().st_size - swdesc_attrs["hash"] = _sha1_hash_file(file) + stat = await file.stat() + swdesc_attrs["size"] = stat.st_size + swdesc_attrs["hash"] = await _sha1_hash_file(f) return swdesc_attrs @@ -72,7 +73,17 @@ async def parse_remote(url: str): return await parse_file(Path(str(f.name))) -def _sha1_hash_file(file_path: Path): - with file_path.open("rb") as f: - sha1_hash = hashlib.file_digest(f, "sha1") +async def _sha1_hash_file(fileobj: AsyncFile): + last = await fileobj.tell() + await fileobj.seek(0) + sha1_hash = hashlib.sha1() + buf = bytearray(2**18) + view = memoryview(buf) + while True: + size = await fileobj.readinto(buf) + if size == 0: + break + sha1_hash.update(view[:size]) + + await fileobj.seek(last) return sha1_hash.hexdigest() diff --git a/tests/api/v1/software/test_routes.py b/tests/api/v1/software/test_routes.py index c6810b34..5eb686e8 100644 --- a/tests/api/v1/software/test_routes.py +++ b/tests/api/v1/software/test_routes.py @@ -1,14 +1,15 @@ -from pathlib import Path - import pytest +from anyio import Path, open_file @pytest.mark.asyncio async def test_create_software_local(async_client, test_data): - path = Path(__file__).resolve().parent / "software-header.swu" - with open(path, "rb") as file: - files = {"file": file} - response = await async_client.post(f"/api/v1/software", files=files) + resolved = await Path(__file__).resolve() + path = resolved.parent / "software-header.swu" + async with await open_file(path, "rb") as file: + files = {"file": await file.read()} + + response = await async_client.post(f"/api/v1/software", files=files) assert response.status_code == 200 software = response.json() @@ -17,9 +18,10 @@ async def test_create_software_local(async_client, test_data): @pytest.mark.asyncio async def test_create_software_remote(async_client, httpserver, test_data): - path = Path(__file__).resolve().parent / "software-header.swu" - with open(path, "rb") as file: - byte_array = file.read() + resolved = await Path(__file__).resolve() + path = resolved.parent / "software-header.swu" + async with await open_file(path, "rb") as file: + byte_array = await file.read() httpserver.expect_request("/software-header.swu").respond_with_data(byte_array) diff --git a/tests/updates/test_swdesc.py b/tests/updates/test_swdesc.py index 761561ed..3649cdb5 100644 --- a/tests/updates/test_swdesc.py +++ b/tests/updates/test_swdesc.py @@ -1,6 +1,5 @@ -from pathlib import Path - import pytest +from anyio import Path from libconf import AttrDict from goosebit.updates.swdesc import parse_descriptor, parse_file @@ -102,7 +101,8 @@ def test_parse_descriptor_several_boardname(): @pytest.mark.asyncio async def test_parse_software_header(): - swdesc_attrs = await parse_file(Path(__file__).resolve().parent / "software-header.swu") + resolved = await Path(__file__).resolve() + swdesc_attrs = await parse_file(resolved.parent / "software-header.swu") assert str(swdesc_attrs["version"]) == "8.8.1-11-g8c926e5+188370" assert swdesc_attrs["compatibility"] == [ {"hw_model": "smart-gateway-mt7688", "hw_revision": "0.5"},