Skip to content

Commit

Permalink
[iree.build] Make the fetch_http action more robust. (#19330)
Browse files Browse the repository at this point in the history
* Downloads to a staging file and then atomically renames into place,
avoiding potential for partial downloads.
* Reports completion percent as part of the console updates.
* Persists metadata for the source URL and will refetch if changed.
* Fixes an error handling test for the onnx mnist_builder that missed
the prior update.

More sophistication is possible but this brings it up to min-viable from
a usability perspective.

Signed-off-by: Stella Laurenzo <[email protected]>
  • Loading branch information
stellaraccident authored Nov 28, 2024
1 parent d182e57 commit 9789438
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 6 deletions.
52 changes: 52 additions & 0 deletions compiler/bindings/python/iree/build/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import concurrent.futures
import enum
import json
import math
import multiprocessing
import os
Expand Down Expand Up @@ -128,6 +129,7 @@ def __init__(self, output_dir: Path, stderr: IO, reporter: ProgressReporter):
self.failed_deps: set["BuildDependency"] = set()
self.stderr = stderr
self.reporter = reporter
self.metadata_lock = threading.RLock()
BuildContext("", self)

def check_path_not_exists(self, path: str, for_entity):
Expand Down Expand Up @@ -160,6 +162,7 @@ def get_file(self, path: str) -> "BuildFile":
return existing

def write_status(self, message: str):
    """Print a one-line status message to the build's stderr stream.

    The progress reporter's display is reset before printing (presumably so
    the message does not land in the middle of a live progress display).

    Args:
        message: Text to emit; `print` appends the trailing newline.
    """
    self.reporter.reset_display()
    stream = self.stderr
    print(message, file=stream)

def get_root(self, namespace: FileNamespace) -> Path:
Expand Down Expand Up @@ -294,6 +297,9 @@ def finish(self):
self.future.set_result(self)


# Per-file metadata record: a flat mapping of string keys to scalar values.
# Records are persisted as JSON, so values are limited to JSON scalar types.
BuildFileMetadata = dict[str, str | int | bool | float]


class BuildFile(BuildDependency):
"""Generated file in the build tree."""

Expand Down Expand Up @@ -322,6 +328,35 @@ def get_fs_path(self) -> Path:
path.parent.mkdir(parents=True, exist_ok=True)
return path

def access_metadata(
    self,
    mutation_callback: Callable[[BuildFileMetadata], bool] | None = None,
) -> BuildFileMetadata:
    """Return the persistent metadata record for this build file.

    The record is meant for small pieces of build-system bookkeeping, such as
    the inputs of an up-to-date check.

    The whole operation runs under the executor's metadata lock. When a
    `mutation_callback` is given, it is invoked with the record while the lock
    is held; if it returns True the full metadata store is written back before
    this method returns. Changes made to the returned dict outside of a
    callback are not persisted by this method.

    Args:
        mutation_callback: Optional callable that may modify the record in
            place and returns True to request that the change be saved.

    Returns:
        The metadata dict for this file (freshly created if none existed).
    """

    def child_dict(parent: dict, key: str) -> dict:
        # Fetch parent[key], installing an empty dict when absent or None.
        child = parent.get(key)
        if child is None:
            child = {}
            parent[key] = child
        return child

    executor = self.executor
    with executor.metadata_lock:
        store = _load_metadata(executor)
        per_path = child_dict(store, "paths")
        record = child_dict(per_path, f"{self.namespace}/{self.path}")
        if mutation_callback and mutation_callback(record):
            _save_metadata(executor, store)
        return record

def __repr__(self):
    """Return a debug string showing this file's namespace and path."""
    return f"BuildFile[{self.namespace}]({self.path})"

Expand Down Expand Up @@ -658,3 +693,20 @@ def invoke():

# Type aliases.
BuildFileLike = BuildFile | str

# Private utilities.
_METADATA_FILENAME = ".metadata.json"


def _load_metadata(executor: Executor) -> dict:
    """Load the build tree's metadata store from its JSON file.

    Args:
        executor: Executor whose `output_dir` holds the metadata file.

    Returns:
        The decoded JSON content, or an empty dict when the metadata file does
        not exist yet.
    """
    path = executor.output_dir / _METADATA_FILENAME
    # Open directly and treat a missing file as "no metadata yet". Checking
    # exists() first would leave a window in which the file could disappear
    # between the check and the open.
    try:
        with open(path, "rb") as f:
            return json.load(f)
    except FileNotFoundError:
        return {}


def _save_metadata(executor: Executor, metadata: dict):
    """Persist the build tree's metadata store as JSON.

    The content is first written to a sibling staging file and then moved over
    the real metadata file in a single replace step. A failure while
    serializing or writing therefore leaves the previously saved metadata
    intact instead of a truncated, unparseable file.

    Args:
        executor: Executor whose `output_dir` holds the metadata file.
        metadata: JSON-serializable metadata store to write.
    """
    path = executor.output_dir / _METADATA_FILENAME
    staging_path = path.with_name(f"{path.name}.tmp")
    try:
        with open(staging_path, "wt") as f:
            json.dump(metadata, f, sort_keys=True, indent=2)
        # Path.replace overwrites an existing destination unconditionally.
        staging_path.replace(path)
    except BaseException:
        # Do not leave a partial staging file behind on failure.
        staging_path.unlink(missing_ok=True)
        raise
42 changes: 40 additions & 2 deletions compiler/bindings/python/iree/build/net_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import urllib.error
import urllib.request

from iree.build.executor import BuildAction, BuildContext, BuildFile
from iree.build.executor import BuildAction, BuildContext, BuildFile, BuildFileMetadata

__all__ = [
"fetch_http",
Expand All @@ -29,11 +29,49 @@ def __init__(self, url: str, output_file: BuildFile, **kwargs):
super().__init__(**kwargs)
self.url = url
self.output_file = output_file
self.original_desc = self.desc

def _invoke(self):
    """Download `self.url` into the output file unless it is up to date.

    A fetch is performed when the URL recorded in the file's metadata differs
    from `self.url`, or when the output file is missing or empty. The download
    goes to a hidden staging file next to the destination and is moved into
    place only after it completes, so a failed download never leaves a partial
    output file. While downloading, `self.desc` is updated with progress and is
    restored to its original value afterwards.

    Raises:
        IOError: If the server responds with an HTTP error.
    """
    path = self.output_file.get_fs_path()

    # Fetch if the recorded source URL changed, or if the output is absent or
    # empty.
    existing_metadata = self.output_file.access_metadata()
    needs_fetch = (
        existing_metadata.get("fetch_http.url") != self.url
        or not path.exists()
        or path.stat().st_size == 0
    )

    # Bail if already obtained.
    if not needs_fetch:
        return

    # Download to a staging file.
    stage_path = path.with_name(f".{path.name}.download")
    self.executor.write_status(f"Fetching URL: {self.url} -> {path}")

    def reporthook(received_blocks: int, block_size: int, total_size: int):
        received_size = received_blocks * block_size
        # urlretrieve may report an unknown size as -1 (or 0), in which case
        # only a byte count can be shown.
        if total_size <= 0:
            self.desc = f"{self.original_desc} ({received_size} bytes received)"
        else:
            # The last block is usually short, so blocks * block_size can
            # overshoot the total; clamp to keep the display at <= 100%.
            complete_percent = min(100, round(100 * received_size / total_size))
            self.desc = f"{self.original_desc} ({complete_percent}% complete)"

    try:
        urllib.request.urlretrieve(self.url, str(stage_path), reporthook=reporthook)
    except urllib.error.HTTPError as e:
        stage_path.unlink(missing_ok=True)
        raise IOError(f"Failed to fetch URL '{self.url}': {e}") from None
    except BaseException:
        # Any other failure (including interruption): drop the partial
        # staging file and let the original error propagate.
        stage_path.unlink(missing_ok=True)
        raise
    finally:
        self.desc = self.original_desc

    # Commit the download.
    def commit(metadata: BuildFileMetadata) -> bool:
        # Move the file first so the URL is only recorded once the new content
        # is actually in place. Path.replace overwrites any existing file in a
        # single step, avoiding an unlink-then-rename gap.
        stage_path.replace(path)
        metadata["fetch_http.url"] = self.url
        return True

    self.output_file.access_metadata(commit)
7 changes: 7 additions & 0 deletions compiler/bindings/python/test/build_api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,10 @@ iree_py_test(
SRCS
"basic_test.py"
)

# Runs net_test.py, which covers the fetch_http build action.
# NOTE(review): the test downloads from public URLs, so it presumably needs
# network access where it runs - confirm for CI.
iree_py_test(
NAME
net_test
SRCS
"net_test.py"
)
6 changes: 2 additions & 4 deletions compiler/bindings/python/test/build_api/mnist_builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,7 @@ def testActionCLArg(self):
mod = load_build_module(THIS_DIR / "mnist_builder.py")
out_file = io.StringIO()
err_file = io.StringIO()
with self.assertRaisesRegex(
IOError,
re.escape("Failed to fetch URL 'https://github.com/iree-org/doesnotexist'"),
):
with self.assertRaises(SystemExit):
iree_build_main(
mod,
args=[
Expand All @@ -104,6 +101,7 @@ def testActionCLArg(self):
stdout=out_file,
stderr=err_file,
)
self.assertIn("ERROR:", err_file.getvalue())

def testBuildNonDefaultSubTarget(self):
mod = load_build_module(THIS_DIR / "mnist_builder.py")
Expand Down
100 changes: 100 additions & 0 deletions compiler/bindings/python/test/build_api/net_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import io
import os
from pathlib import Path
import tempfile
import unittest

from iree.build import *
from iree.build.executor import BuildContext
from iree.build.test_actions import ExecuteOutOfProcessThunkAction


TEST_URL = None
TEST_URL_1 = "https://huggingface.co/google-bert/bert-base-cased/resolve/cd5ef92a9fb2f889e972770a36d4ed042daf221e/tokenizer.json"
TEST_URL_2 = "https://huggingface.co/google-bert/bert-base-cased/resolve/cd5ef92a9fb2f889e972770a36d4ed042daf221e/tokenizer_config.json"


@entrypoint
def tokenizer_via_http():
    """Build entrypoint that fetches `tokenizer.json` from the current `TEST_URL`.

    `TEST_URL` is read each time the entrypoint is called, so tests can switch
    the source URL between builds.
    """
    fetched = fetch_http(name="tokenizer.json", url=TEST_URL)
    return fetched


class BasicTest(unittest.TestCase):
    """Checks when the `fetch_http` action downloads versus reuses its output.

    NOTE: the test downloads from `TEST_URL_1` / `TEST_URL_2`, so it needs
    network access.
    """

    def setUp(self):
        self._temp_dir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
        self._temp_dir.__enter__()
        self.output_path = Path(self._temp_dir.name)

    def tearDown(self) -> None:
        self._temp_dir.__exit__(None, None, None)

    def test_fetch_http(self):
        """A fetch happens on first build, URL change, or missing output only."""
        global TEST_URL
        path = self.output_path / "genfiles" / "tokenizer_via_http" / "tokenizer.json"

        # The entrypoint reads the module-level TEST_URL; put it back when the
        # test finishes so the override does not leak into other tests.
        original_url = TEST_URL

        def restore_url():
            global TEST_URL
            TEST_URL = original_url

        self.addCleanup(restore_url)

        def run() -> str:
            """Run the build once and return everything written to stderr."""
            # Created before the `try` so the `finally` can always read them.
            out_io = io.StringIO()
            err_io = io.StringIO()
            try:
                iree_build_main(
                    args=[
                        "tokenizer_via_http",
                        "--output-dir",
                        str(self.output_path),
                        "--test-force-console",
                    ],
                    stderr=err_io,
                    stdout=out_io,
                )
            finally:
                # Dump captured output even on failure to ease debugging.
                print(f"::test_fetch_http err: {err_io.getvalue()!r}")
                print(f"::test_fetch_http out: {out_io.getvalue()!r}")
            return err_io.getvalue()

        def assertExists():
            self.assertTrue(path.exists(), msg=f"Path {path} exists")

        # First run should fetch.
        TEST_URL = TEST_URL_1
        err = run()
        self.assertIn("Fetching URL: https://", err)
        assertExists()

        # Second run should not fetch.
        TEST_URL = TEST_URL_1
        err = run()
        self.assertNotIn("Fetching URL: https://", err)
        assertExists()

        # Fetching a different URL should download again.
        TEST_URL = TEST_URL_2
        err = run()
        self.assertIn("Fetching URL: https://", err)
        assertExists()

        # Removing the file should fetch again.
        TEST_URL = TEST_URL_2
        path.unlink()
        err = run()
        self.assertIn("Fetching URL: https://", err)
        assertExists()


if __name__ == "__main__":
unittest.main()

0 comments on commit 9789438

Please sign in to comment.