Skip to content

Commit

Permalink
refactor: improve task submission and handling on delta update (#355)
Browse files Browse the repository at this point in the history
* umu_bspatch: raise exception on verify failure

* umu_bspatch: refactor task submission

* umu_proton: refactor task handling for delta updates
  • Loading branch information
R1kaB3rN authored Jan 25, 2025
1 parent 2439791 commit d0512c3
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 36 deletions.
41 changes: 20 additions & 21 deletions umu/umu_bspatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,12 @@ def __init__( # noqa: D107
self._arc_manifest: list[ManifestEntry] = self._arc_contents["manifest"]
self._compat_tool = compat_tool
self._thread_pool = thread_pool
self._futures: list[Future] = []
# Collection where each task creates a new file within an existing compatibility tool
self._add: list[Future] = []
# Collection where each task updates an existing file
self._update: list[Future] = []
# Collection where each task verifies an existing file
self._verify: list[Future] = []

def add_binaries(self) -> None:
"""Add binaries within a compatibility tool.
Expand All @@ -131,7 +136,7 @@ def add_binaries(self) -> None:
build_file: Path = self._compat_tool.joinpath(item["name"])
if item["type"] == FileType.File.value:
# Decompress the zstd data and write the file
self._futures.append(
self._add.append(
self._thread_pool.submit(self._write_proton_file, build_file, item)
)
continue
Expand Down Expand Up @@ -160,7 +165,7 @@ def update_binaries(self) -> None:
build_file: Path = self._compat_tool.joinpath(item["name"])
if item["type"] == FileType.File.value:
# For files, apply a binary patch
self._futures.append(
self._update.append(
self._thread_pool.submit(self._patch_proton_file, build_file, item)
)
continue
Expand Down Expand Up @@ -209,33 +214,27 @@ def delete_binaries(self) -> None:
def verify_integrity(self) -> None:
"""Verify the expected mode, size, file and digest of the compatibility tool."""
for item in self._arc_manifest:
self._futures.append(
self._verify.append(
self._thread_pool.submit(self._check_binaries, self._compat_tool, item)
)

def result(self) -> list[Future]:
"""Return the currently submitted tasks."""
return self._futures
def result(self) -> tuple[list[Future], list[Future], list[Future]]:
"""Return all the currently submitted tasks."""
return (self._verify, self._add, self._update)

def _check_binaries(
self, proton: Path, item: ManifestEntry
) -> ManifestEntry | None:
def _check_binaries(self, proton: Path, item: ManifestEntry) -> ManifestEntry:
rpath: Path = proton.joinpath(item["name"])

try:
with rpath.open("rb") as fp:
stats: os.stat_result = os.fstat(fp.fileno())
xxhash: int = 0
if item["size"] != stats.st_size:
log.error(
"Expected size %s, received %s", item["size"], stats.st_size
)
return None
err: str = f"Expected size {item['size']}, received {stats.st_size}"
raise ValueError(err)
if item["mode"] != stats.st_mode:
log.error(
"Expected mode %s, received %s", item["mode"], stats.st_mode
)
return None
err: str = f"Expected mode {item['mode']}, received {stats.st_mode}"
raise ValueError(err)
if stats.st_size > MMAP_MIN:
with mmap(fp.fileno(), length=0, access=ACCESS_READ) as mm:
# Ignore. Passing an mmap is valid here
Expand All @@ -245,11 +244,11 @@ def _check_binaries(
else:
xxhash = xxh3_64_intdigest(fp.read())
if item["xxhash"] != xxhash:
log.error("Expected xxhash %s, received %s", item["xxhash"], xxhash)
return None
err: str = f"Expected xxhash {item['xxhash']}, received {xxhash}"
raise ValueError(err)
except FileNotFoundError:
log.debug("Aborting partial update, file not found: %s", rpath)
return None
raise

return item

Expand Down
32 changes: 17 additions & 15 deletions umu/umu_proton.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
import time
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ALL_COMPLETED, FIRST_EXCEPTION, ThreadPoolExecutor
from concurrent.futures import wait as futures_wait
from enum import Enum
from hashlib import sha512
from http import HTTPStatus
from importlib.util import find_spec
from itertools import chain
from pathlib import Path
from re import split as resplit
from shutil import move
Expand Down Expand Up @@ -575,24 +577,21 @@ def _get_delta(
# Apply the patch
for content in cbor["contents"]:
src: str = content["source"]

if src.startswith((ProtonVersion.GE.value, ProtonVersion.UMU.value)):
patchers.append(_apply_delta(proton, content, thread_pool))
continue

subdir: Path | None = next(umu_compat.joinpath(version).rglob(src), None)
if not subdir:
log.error("Could not find subdirectory '%s', skipping", subdir)
continue

patchers.append(_apply_delta(subdir, content, thread_pool))
renames.append((subdir, subdir.parent / content["target"]))

# Wait for results and rename versioned subdirectories
start: float = time.time_ns()
for patcher in filter(None, patchers):
for future in filter(None, patcher.result()):
future.result()
_, *futures = patcher.result()
futures_wait(list(chain.from_iterable(futures)), return_when=ALL_COMPLETED)

for rename in renames:
orig, new = rename
Expand All @@ -614,25 +613,28 @@ def _apply_delta(
thread_pool: ThreadPoolExecutor,
) -> CustomPatcher | None:
patcher: CustomPatcher = CustomPatcher(content, path, thread_pool)
is_updated: bool = False

# Verify the identity of the build. At this point the patch file is authenticated.
# Note, this will skip the update if the user had tinkered with their build. We do
# this so we can ensure the result of each binary patch isn't garbage
patcher.verify_integrity()

for item in patcher.result():
if item.result() is None:
is_updated = True
break
# Handle tasks that failed metadata validation. On success, skip waiting for results
futures, *_ = patcher.result()
done, not_done = futures_wait(futures, return_when=FIRST_EXCEPTION)
for future in done:
try:
future.result()
except (FileNotFoundError, ValueError) as e:
log.exception(e)
for future in not_done:
future.cancel()
return None

if is_updated:
log.debug("%s (latest) validation failed, skipping", os.environ["PROTONPATH"])
return None
futures_wait(not_done, return_when=ALL_COMPLETED)

# Patch the current build, upgrading proton to the latest
log.info("%s is OK, applying partial update...", os.environ["PROTONPATH"])

patcher.update_binaries()
patcher.add_binaries()
patcher.delete_binaries()
Expand Down

0 comments on commit d0512c3

Please sign in to comment.