From d0512c3a6563a3c26a776543c3cfb8c983a151a9 Mon Sep 17 00:00:00 2001
From: R1kaB3rN <100738684+R1kaB3rN@users.noreply.github.com>
Date: Fri, 24 Jan 2025 22:40:06 -0800
Subject: [PATCH] refactor: improve task submission and handling on delta
 update (#355)

* umu_bspatch: raise exception on verify failure

* umu_bspatch: refactor task submission

* umu_proton: refactor task handling for delta updates
---
 umu/umu_bspatch.py | 41 ++++++++++++++++++++---------------------
 umu/umu_proton.py  | 32 +++++++++++++++++---------------
 2 files changed, 37 insertions(+), 36 deletions(-)

diff --git a/umu/umu_bspatch.py b/umu/umu_bspatch.py
index 6f062217f..68d7c8f41 100644
--- a/umu/umu_bspatch.py
+++ b/umu/umu_bspatch.py
@@ -115,7 +115,12 @@ def __init__(  # noqa: D107
         self._arc_manifest: list[ManifestEntry] = self._arc_contents["manifest"]
         self._compat_tool = compat_tool
         self._thread_pool = thread_pool
-        self._futures: list[Future] = []
+        # Collection where each task creates a new file within an existing compatibility tool
+        self._add: list[Future] = []
+        # Collection where each task updates an existing file
+        self._update: list[Future] = []
+        # Collection where each task verifies an existing file
+        self._verify: list[Future] = []
 
     def add_binaries(self) -> None:
         """Add binaries within a compatibility tool.
@@ -131,7 +136,7 @@ def add_binaries(self) -> None:
             build_file: Path = self._compat_tool.joinpath(item["name"])
             if item["type"] == FileType.File.value:
                 # Decompress the zstd data and write the file
-                self._futures.append(
+                self._add.append(
                     self._thread_pool.submit(self._write_proton_file, build_file, item)
                 )
                 continue
@@ -160,7 +165,7 @@ def update_binaries(self) -> None:
             build_file: Path = self._compat_tool.joinpath(item["name"])
             if item["type"] == FileType.File.value:
                 # For files, apply a binary patch
-                self._futures.append(
+                self._update.append(
                     self._thread_pool.submit(self._patch_proton_file, build_file, item)
                 )
                 continue
@@ -209,17 +214,15 @@ def delete_binaries(self) -> None:
     def verify_integrity(self) -> None:
         """Verify the expected mode, size, file and digest of the compatibility tool."""
         for item in self._arc_manifest:
-            self._futures.append(
+            self._verify.append(
                 self._thread_pool.submit(self._check_binaries, self._compat_tool, item)
             )
 
-    def result(self) -> list[Future]:
-        """Return the currently submitted tasks."""
-        return self._futures
+    def result(self) -> tuple[list[Future], list[Future], list[Future]]:
+        """Return all the currently submitted tasks."""
+        return (self._verify, self._add, self._update)
 
-    def _check_binaries(
-        self, proton: Path, item: ManifestEntry
-    ) -> ManifestEntry | None:
+    def _check_binaries(self, proton: Path, item: ManifestEntry) -> ManifestEntry:
         rpath: Path = proton.joinpath(item["name"])
 
         try:
@@ -227,15 +230,11 @@
                 stats: os.stat_result = os.fstat(fp.fileno())
                 xxhash: int = 0
                 if item["size"] != stats.st_size:
-                    log.error(
-                        "Expected size %s, received %s", item["size"], stats.st_size
-                    )
-                    return None
+                    err: str = f"Expected size {item['size']}, received {stats.st_size}"
+                    raise ValueError(err)
                 if item["mode"] != stats.st_mode:
-                    log.error(
-                        "Expected mode %s, received %s", item["mode"], stats.st_mode
-                    )
-                    return None
+                    err: str = f"Expected mode {item['mode']}, received {stats.st_mode}"
+                    raise ValueError(err)
                 if stats.st_size > MMAP_MIN:
                     with mmap(fp.fileno(), length=0, access=ACCESS_READ) as mm:
                         # Ignore. Passing an mmap is valid here
@@ -245,11 +244,11 @@ def _check_binaries(
             else:
                 xxhash = xxh3_64_intdigest(fp.read())
             if item["xxhash"] != xxhash:
-                log.error("Expected xxhash %s, received %s", item["xxhash"], xxhash)
-                return None
+                err: str = f"Expected xxhash {item['xxhash']}, received {xxhash}"
+                raise ValueError(err)
         except FileNotFoundError:
             log.debug("Aborting partial update, file not found: %s", rpath)
-            return None
+            raise
 
         return item
 
diff --git a/umu/umu_proton.py b/umu/umu_proton.py
index 84a46aebe..fa16da15d 100644
--- a/umu/umu_proton.py
+++ b/umu/umu_proton.py
@@ -1,10 +1,12 @@
 import os
 import time
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ALL_COMPLETED, FIRST_EXCEPTION, ThreadPoolExecutor
+from concurrent.futures import wait as futures_wait
 from enum import Enum
 from hashlib import sha512
 from http import HTTPStatus
 from importlib.util import find_spec
+from itertools import chain
 from pathlib import Path
 from re import split as resplit
 from shutil import move
@@ -575,24 +577,21 @@ def _get_delta(
     # Apply the patch
     for content in cbor["contents"]:
         src: str = content["source"]
-
         if src.startswith((ProtonVersion.GE.value, ProtonVersion.UMU.value)):
             patchers.append(_apply_delta(proton, content, thread_pool))
             continue
-
         subdir: Path | None = next(umu_compat.joinpath(version).rglob(src), None)
         if not subdir:
             log.error("Could not find subdirectory '%s', skipping", subdir)
             continue
-
         patchers.append(_apply_delta(subdir, content, thread_pool))
         renames.append((subdir, subdir.parent / content["target"]))
 
     # Wait for results and rename versioned subdirectories
     start: float = time.time_ns()
     for patcher in filter(None, patchers):
-        for future in filter(None, patcher.result()):
-            future.result()
+        _, *futures = patcher.result()
+        futures_wait(list(chain.from_iterable(futures)), return_when=ALL_COMPLETED)
 
     for rename in renames:
         orig, new = rename
@@ -614,25 +613,28 @@ def _apply_delta(
     thread_pool: ThreadPoolExecutor,
 ) -> CustomPatcher | None:
     patcher: CustomPatcher = CustomPatcher(content, path, thread_pool)
-    is_updated: bool = False
 
     # Verify the identity of the build. At this point the patch file is authenticated.
     # Note, this will skip the update if the user had tinkered with their build. We do
     # this so we can ensure the result of each binary patch isn't garbage
     patcher.verify_integrity()
-    for item in patcher.result():
-        if item.result() is None:
-            is_updated = True
-            break
+    # Handle tasks that failed metadata validation. On success, skip waiting for results
+    futures, *_ = patcher.result()
+    done, not_done = futures_wait(futures, return_when=FIRST_EXCEPTION)
+    for future in done:
+        try:
+            future.result()
+        except (FileNotFoundError, ValueError) as e:
+            log.exception(e)
+            for future in not_done:
+                future.cancel()
+            return None
 
-    if is_updated:
-        log.debug("%s (latest) validation failed, skipping", os.environ["PROTONPATH"])
-        return None
+    futures_wait(not_done, return_when=ALL_COMPLETED)
 
     # Patch the current build, upgrading proton to the latest
     log.info("%s is OK, applying partial update...", os.environ["PROTONPATH"])
-
     patcher.update_binaries()
     patcher.add_binaries()
     patcher.delete_binaries()
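-- 
Below is a minimal, runnable sketch of the failure-handling pattern this patch
introduces in _apply_delta(): submit verification tasks, wait with
return_when=FIRST_EXCEPTION, and on the first raised error cancel the pending
tasks and skip the update. The check() helper and the manifest entries are
hypothetical stand-ins for CustomPatcher._check_binaries and the archive
manifest; only the concurrent.futures flow mirrors the patch.

from concurrent.futures import ALL_COMPLETED, FIRST_EXCEPTION, Future, ThreadPoolExecutor
from concurrent.futures import wait as futures_wait


def check(name: str, size: int) -> str:
    # Hypothetical stand-in for CustomPatcher._check_binaries: raise on a
    # metadata mismatch instead of returning None
    if size < 0:
        err: str = f"Expected non-negative size for {name}, received {size}"
        raise ValueError(err)
    return name


def main() -> None:
    manifest: list[tuple[str, int]] = [("proton", 8), ("files/bin/wine", 4), ("broken", -1)]
    with ThreadPoolExecutor() as thread_pool:
        futures: list[Future] = [thread_pool.submit(check, n, s) for n, s in manifest]
        # Stop waiting as soon as any verification task raises
        done, not_done = futures_wait(futures, return_when=FIRST_EXCEPTION)
        for future in done:
            try:
                future.result()
            except ValueError as e:
                # A failed check aborts the update: cancel pending tasks and bail
                for pending in not_done:
                    pending.cancel()
                print(f"validation failed, skipping: {e}")
                return
        # All checks passed; drain any stragglers before applying the update
        futures_wait(not_done, return_when=ALL_COMPLETED)
        print("build is OK, applying partial update...")


if __name__ == "__main__":
    main()

On success, the sketch drains the remaining futures with ALL_COMPLETED before
"applying" the update, matching the patch's choice to raise from the worker
rather than signal failure through a None result.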