From 0e7e5afbb36356e1e0fc5ab0edc479aaa5f51c2f Mon Sep 17 00:00:00 2001
From: Jack Cushman
Date: Thu, 5 Dec 2024 14:52:08 -0500
Subject: [PATCH] Add collection_tasks and id to signed-metadata.json

* Also add --collect-errors flag to control error handling
* Also add --timeout flag to set timeouts for URLs
---
 README.md                           |  3 ++
 src/nabit/bin/cli.py                | 26 +++++++++---
 src/nabit/lib/archive.py            | 37 +++++++++++++---
 src/nabit/lib/backends/base.py      | 26 ++++++++++--
 src/nabit/lib/backends/path.py      | 14 +++++-
 src/nabit/lib/backends/url.py       | 18 ++++++--
 tests/backends/test_path_backend.py | 21 ++++++++-
 tests/backends/test_url_backend.py  | 66 +++++++++++++++++++++++++----
 tests/conftest.py                   |  9 ++--
 tests/test_archive.py               |  2 +-
 tests/test_cli.py                   | 62 ++++++++++++++++++++++-----
 tests/test_validation.py            |  4 +-
 tests/utils.py                      | 12 ++++++
 13 files changed, 251 insertions(+), 49 deletions(-)

diff --git a/README.md b/README.md
index 2ca0188..083e396 100644
--- a/README.md
+++ b/README.md
@@ -156,6 +156,9 @@ Options:
   -t, --timestamp <tsa keyword> | <cert_chain>:<url>
                                   Timestamp using either a TSA keyword or a
                                   cert chain path and URL (can be repeated)
+  --timeout FLOAT                 Timeout for collection tasks (default: 5.0)
+  --collect-errors [fail|ignore]  How to handle collection task errors
+                                  (default: fail)
   --help                          Show this message and exit.
 ```
diff --git a/src/nabit/bin/cli.py b/src/nabit/bin/cli.py
index aafdc52..78201b8 100644
--- a/src/nabit/bin/cli.py
+++ b/src/nabit/bin/cli.py
@@ -6,7 +6,7 @@
 from .utils import assert_file_exists, assert_url, cli_validate, CaptureCommand
 from ..lib.archive import package
 from ..lib.sign import KNOWN_TSAS
-from ..lib.backends.base import CollectionTask
+from ..lib.backends.base import CollectionTask, CollectionError
 from ..lib.backends.path import PathCollectionTask
 
 @click.group()
@@ -39,6 +39,10 @@ def main():
     help='Timestamp using either a TSA keyword or a cert chain path and URL (can be repeated)',
     metavar='<tsa keyword> | <cert_chain>:<url>',
 )
+@click.option('--timeout', type=float, default=5.0,
+    help='Timeout for collection tasks (default: 5.0)')
+@click.option('--collect-errors', type=click.Choice(['fail', 'ignore']), default='fail',
+    help='How to handle collection task errors (default: fail)')
 @click.pass_context
 def archive(
     ctx,
@@ -53,7 +57,9 @@ def archive(
     unsigned_metadata_path,
     signed_metadata_json,
     unsigned_metadata_json,
-    signature_args
+    signature_args,
+    collect_errors,
+    timeout,
 ):
     """
     Archive files and URLs into a BagIt package.
@@ -128,7 +134,10 @@ def archive(
     processed_collect = []
     for task in collect:
         try:
-            processed_collect.append(CollectionTask.from_dict(task))
+            task = CollectionTask.from_dict(task)
+            if hasattr(task, 'timeout'):
+                task.timeout = timeout
+            processed_collect.append(task)
         except Exception as e:
             raise click.BadParameter(f'Invalid task definition for --collect: {task} resulted in {e}')
 
@@ -173,16 +182,19 @@ def archive(
     click.echo(f"Creating package at {bag_path} ...")
 
-    package(
-        output_path=bag_path,
+    try:
+        package(
+            output_path=bag_path,
         collect=processed_collect,
         bag_info=bag_info,
         signatures=signatures,
         signed_metadata=metadata['signed'],
         unsigned_metadata=metadata['unsigned'],
         amend=amend,
-        use_hard_links=hard_link,
-    )
+            collect_errors=collect_errors,
+        )
+    except CollectionError as e:
+        raise click.BadParameter(f'Collection task failed: {e}')
 
     cli_validate(bag_path)
 
diff --git a/src/nabit/lib/archive.py b/src/nabit/lib/archive.py
index 243cb7f..79dacec 100644
--- a/src/nabit/lib/archive.py
+++ b/src/nabit/lib/archive.py
@@ -4,12 +4,13 @@
 import os
 import hashlib
 import json
-
+import uuid
 from .utils import noop
 from .backends.url import validate_warc_headers
 from .sign import validate_signatures, KNOWN_TSAS, add_signatures
 from .. import __version__
-from .backends.base import CollectionTask
+from .backends.base import CollectionTask, CollectionError
+from typing import Literal
 
 
 def validate_bag_format(bag_path: Path, error, warn, success) -> None:
@@ -27,6 +28,13 @@ def validate_bag_format(bag_path: Path, error, warn, success) -> None:
 
 def validate_data_files(bag_path: Path, error = None, warn = noop, success = noop) -> None:
     """Validate only expected files are present in data/."""
+
+    # make sure there are files in files_path
+    files_path = bag_path / "data/files"
+    if not files_path.exists() or not any(files_path.iterdir()):
+        warn("No files in data/files")
+
+    # make sure only expected files are present
     expected_files = set(['files', 'headers.warc', 'signed-metadata.json'])
     actual_files = set(f.name for f in bag_path.glob('data/*'))
     unexpected_files = actual_files - expected_files
@@ -61,7 +69,7 @@ def package(
     signatures: list[dict] | None = None,
     signed_metadata: dict | None = None,
     unsigned_metadata: dict | None = None,
-    use_hard_links: bool = False,
+    collect_errors: Literal['fail', 'ignore'] = 'fail',
 ) -> None:
     """
     Create a BagIt package.
@@ -79,15 +87,32 @@ def package(
     data_path = output_path / 'data'
     files_path = data_path / 'files'
     files_path.mkdir(exist_ok=True, parents=True)
+    signed_metadata_path = data_path / "signed-metadata.json"
+
+    # set or extend signed metadata
+    if signed_metadata is None:
+        if signed_metadata_path.exists():
+            signed_metadata = json.loads(signed_metadata_path.read_text())
+        else:
+            signed_metadata = {}
+
+    if not signed_metadata.get('id'):
+        signed_metadata['id'] = str(uuid.uuid4())
 
+    # run collection tasks and record results
     if collect:
+        results = []
         for task in collect:
-            task.collect(files_path)
+            result = task.collect(files_path)
+            if collect_errors == 'fail' and not result['response']['success']:
+                raise CollectionError(f"Collection task failed: {result}")
+            results.append(result)
+        signed_metadata.setdefault('collection_tasks', []).extend(results)
 
     # Add metadata files
-    if signed_metadata is not None:
+    if signed_metadata:
         (data_path / "signed-metadata.json").write_text(json.dumps(signed_metadata, indent=2))
-    if unsigned_metadata is not None:
+    if unsigned_metadata:
         (output_path / "unsigned-metadata.json").write_text(json.dumps(unsigned_metadata, indent=2))
 
     ## add bag files
diff --git a/src/nabit/lib/backends/base.py b/src/nabit/lib/backends/base.py
index 2b8e47e..5081c93 100644
--- a/src/nabit/lib/backends/base.py
+++ b/src/nabit/lib/backends/base.py
@@ -1,5 +1,6 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, asdict
 from functools import lru_cache
+from pathlib import Path
 
 @lru_cache
 def get_backends() -> dict[str, type['CollectionTask']]:
@@ -8,13 +9,32 @@ def get_backends() -> dict[str, type['CollectionTask']]:
     from .path import PathCollectionTask
 
     return {
-        'url': UrlCollectionTask,
-        'path': PathCollectionTask,
+        UrlCollectionTask.backend: UrlCollectionTask,
+        PathCollectionTask.backend: PathCollectionTask,
     }
 
+class CollectionError(Exception):
+    """Base class for collection errors"""
+
 @dataclass
 class CollectionTask:
     @classmethod
     def from_dict(cls, data: dict) -> 'CollectionTask':
         backend = data.pop('backend')
         return get_backends()[backend](**data)
+
+    def collect(self, files_dir: Path) -> dict:
+        """Call the backend-specific _collect method and return the result, handling any errors."""
+        try:
+            result = self._collect(files_dir)
+            result['success'] = True
+        except Exception as e:
+            result = {'success': False, 'error': str(e)}
+        return {
+            'request': self.request_dict(),
+            'response': result,
+        }
+
+    def _collect(self, files_dir: Path) -> dict:
+        """Collect the data to the given directory."""
+        raise NotImplementedError
diff --git a/src/nabit/lib/backends/path.py b/src/nabit/lib/backends/path.py
index 5170a0b..4123fff 100644
--- a/src/nabit/lib/backends/path.py
+++ b/src/nabit/lib/backends/path.py
@@ -8,6 +8,8 @@
 @dataclass
 class PathCollectionTask(CollectionTask):
     """Collect files or directories from the local filesystem."""
+    backend = 'path'
+
     path: Path
     output: Path | None = None
     hard_links: bool = False
@@ -22,7 +24,7 @@ def __post_init__(self):
         if self.output is not None:
             self.output = Path(self.output)  # Also coerce output if provided
 
-    def collect(self, files_dir: Path) -> None:
+    def _collect(self, files_dir: Path) -> dict:
         """Copy paths to a destination directory, optionally using hard links."""
         path = self.path
         dest_path = get_unique_path(files_dir / path.name)
@@ -43,3 +45,13 @@ def collect(self, files_dir: Path) -> None:
             copy_function=copy_function,
             ignore=shutil.ignore_patterns(*self.ignore_patterns)
         )
+        return {'path': str(dest_path.relative_to(files_dir))}
+
+    def request_dict(self) -> dict:
+        """Return a dictionary representation of the request."""
+        return {
+            'path': str(self.path),
+            'output': str(self.output) if self.output else None,
+            'hard_links': self.hard_links,
+            'ignore_patterns': self.ignore_patterns,
+        }
diff --git a/src/nabit/lib/backends/url.py b/src/nabit/lib/backends/url.py
index 6db728e..df8dfc4 100644
--- a/src/nabit/lib/backends/url.py
+++ b/src/nabit/lib/backends/url.py
@@ -35,6 +35,8 @@
 @dataclass
 class UrlCollectionTask(CollectionTask):
     """Collect URLs and request/response metadata."""
+    backend = 'url'
+
     url: str
     output: Path | None = None
 
@@ -44,7 +46,7 @@ def __post_init__(self):
         """Validate the URL by attempting to prepare a request."""
         requests.Request('GET', self.url).prepare()
 
-    def collect(self, files_dir: Path) -> None:
+    def _collect(self, files_dir: Path) -> dict:
         """
         Capture URL to a WARC file using our custom FileWriter.
         Appends to the WARC file if it already exists.
@@ -55,6 +57,15 @@ def _collect(self, files_dir: Path) -> dict:
         with capture_http(warc_writer):
             warc_writer.custom_out_path = self.output
             requests.get(self.url, timeout=self.timeout)
+        return {'path': str(warc_writer.result_path)}
+
+    def request_dict(self) -> dict:
+        """Return a dictionary representation of the request."""
+        return {
+            'url': self.url,
+            'output': str(self.output) if self.output else None,
+            'timeout': self.timeout,
+        }
 
 
 class FileWriter(WARCWriter):
@@ -63,6 +74,7 @@ class FileWriter(WARCWriter):
     """
     revisit_status_codes = set(['200', '203'])
     custom_out_path = None  # override output path
+    result_path = None
 
     def __init__(self, filebuf, warc_path: Path, *args, **kwargs):
         super(WARCWriter, self).__init__(*args, **kwargs)
@@ -97,6 +109,7 @@ def _write_warc_record(self, out, record):
                 out_path = f'{stem}{extension}'
             out_path = get_unique_path(self.files_path / out_path)
             relative_path = out_path.relative_to(self.warc_path.parent)
+            self.result_path = out_path.relative_to(self.files_path)
 
             # add our custom WARC-Profile header
             headers.add_header('WARC-Profile', f'file-content; filename="{relative_path}"')
@@ -136,9 +149,6 @@ def validate_warc_headers(headers_path: Path, error, warn, success) -> None:
     data_path = headers_path.parent
     files_path = data_path / "files"
 
-    # make sure there are files in files_path
-    if not files_path.exists() or not any(files_path.iterdir()):
-        error("No files in data/files")
     if not headers_path.exists():
         warn("No headers.warc found; archive lacks request and response metadata")
     else:
diff --git a/tests/backends/test_path_backend.py b/tests/backends/test_path_backend.py
index c1e0cc8..835212c 100644
--- a/tests/backends/test_path_backend.py
+++ b/tests/backends/test_path_backend.py
@@ -1,5 +1,6 @@
 from nabit.lib.backends.path import PathCollectionTask
-
+from inline_snapshot import snapshot
+from ..utils import filter_str
 
 def test_ds_store_ignored(tmp_path):
     """Test that files in ignore_patterns are ignored when copying directories"""
@@ -14,7 +15,23 @@ def test_ds_store_ignored(tmp_path):
     dest_dir.mkdir()
 
     # Test copying
-    PathCollectionTask(path=str(source_dir)).collect(dest_dir)
+    response = PathCollectionTask(path=str(source_dir)).collect(dest_dir)
+    assert filter_str(response, path=tmp_path) == snapshot("""\
+{
+  "request": {
+    "path": "<path>/test_dir",
+    "output": null,
+    "hard_links": false,
+    "ignore_patterns": [
+      ".DS_Store"
+    ]
+  },
+  "response": {
+    "path": "test_dir",
+    "success": true
+  }
+}\
+""")
 
     # Verify results
     assert not (dest_dir / "test_dir/.DS_Store").exists()
"test_dir/.DS_Store").exists() diff --git a/tests/backends/test_url_backend.py b/tests/backends/test_url_backend.py index 45b9daf..842cfca 100644 --- a/tests/backends/test_url_backend.py +++ b/tests/backends/test_url_backend.py @@ -3,7 +3,8 @@ from nabit.lib.backends.url import UrlCollectionTask import requests from time import sleep - +from inline_snapshot import snapshot +from ..utils import filter_str @pytest.fixture def capture_dir(tmp_path): @@ -20,9 +21,21 @@ def capture_dir(tmp_path): def test_capture_with_content(capture_dir, server): """Test capturing a 200 response with body content""" - UrlCollectionTask(url=server.url_for("/test.txt")).collect(capture_dir["files_dir"]) + response = UrlCollectionTask(url=server.url_for("/test.txt")).collect(capture_dir["files_dir"]) + assert filter_str(response, port=server.port) == snapshot("""\ +{ + "request": { + "url": "http://localhost:/test.txt", + "output": null, + "timeout": 5.0 + }, + "response": { + "path": "test.txt", + "success": true + } +}\ +""") - # Check headers.warc with open(capture_dir["headers_path"], 'rb') as fh: records = list(ArchiveIterator(fh)) assert len(records) == 2 # request and response @@ -42,7 +55,20 @@ def test_capture_empty_response(capture_dir, server): # Add empty response to server server.expect_request("/empty").respond_with_data("") - UrlCollectionTask(url=server.url_for("/empty")).collect(capture_dir["headers_path"]) + response = UrlCollectionTask(url=server.url_for("/empty")).collect(capture_dir["headers_path"]) + assert filter_str(response, port=server.port) == snapshot("""\ +{ + "request": { + "url": "http://localhost:/empty", + "output": null, + "timeout": 5.0 + }, + "response": { + "path": "empty.txt", + "success": true + } +}\ +""") # Check headers.warc - should be a response record, not revisit with open(capture_dir["headers_path"], 'rb') as fh: @@ -66,7 +92,20 @@ def test_capture_redirect(capture_dir, server): headers={"Location": target_url} ) - UrlCollectionTask(url=redirect_url).collect(capture_dir["headers_path"]) + response = UrlCollectionTask(url=redirect_url).collect(capture_dir["headers_path"]) + assert filter_str(response, port=server.port) == snapshot("""\ +{ + "request": { + "url": "http://localhost:/redirect", + "output": null, + "timeout": 5.0 + }, + "response": { + "path": "test.txt", + "success": true + } +}\ +""") # Check headers.warc with open(capture_dir["headers_path"], 'rb') as fh: @@ -92,6 +131,17 @@ def test_capture_redirect(capture_dir, server): def test_capture_timeout(capture_dir, server): """Test that requests timeout after the specified duration""" server.expect_request("/slow").respond_with_handler(lambda req: sleep(.2)) - task = UrlCollectionTask(url=server.url_for("/slow"), timeout=0.1) - with pytest.raises(requests.exceptions.Timeout): - task.collect(capture_dir["files_dir"]) \ No newline at end of file + response = UrlCollectionTask(url=server.url_for("/slow"), timeout=0.1).collect(capture_dir["files_dir"]) + assert filter_str(response, port=server.port) == snapshot("""\ +{ + "request": { + "url": "http://localhost:/slow", + "output": null, + "timeout": 0.1 + }, + "response": { + "success": false, + "error": "HTTPConnectionPool(host='localhost', port=): Read timed out. 
(read timeout=0.1)" + } +}\ +""") diff --git a/tests/conftest.py b/tests/conftest.py index 500c5a1..d1d51e4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ from click.testing import CliRunner +import json import pytest from pytest_httpserver import HTTPServer @@ -31,8 +32,8 @@ def test_bag(tmp_path, test_files): PathCollectionTask(path=str(test_files["payload"][0])), PathCollectionTask(path=str(test_files["payload"][1])) ], - signed_metadata=test_files["signed_metadata"].read_text(), - unsigned_metadata=test_files["unsigned_metadata"].read_text(), + signed_metadata=json.loads(test_files["signed_metadata"].read_text()), + unsigned_metadata=json.loads(test_files["unsigned_metadata"].read_text()), bag_info={"Source-Organization": "Test Org"} ) return bag_path @@ -69,8 +70,8 @@ def signed_bag(tmp_path, test_files, root_ca): PathCollectionTask(path=str(test_files["payload"][0])), PathCollectionTask(path=str(test_files["payload"][1])) ], - signed_metadata=test_files["signed_metadata"].read_text(), - unsigned_metadata=test_files["unsigned_metadata"].read_text(), + signed_metadata=json.loads(test_files["signed_metadata"].read_text()), + unsigned_metadata=json.loads(test_files["unsigned_metadata"].read_text()), bag_info={"Source-Organization": "Test Org"}, signatures=[ { diff --git a/tests/test_archive.py b/tests/test_archive.py index f74e158..b8adadc 100644 --- a/tests/test_archive.py +++ b/tests/test_archive.py @@ -5,5 +5,5 @@ def test_validate_raises(tmp_path): # make sure that vanilla validate_package raises an error # unless there's an error callback that does something else - with pytest.raises(ValueError, match='No files in data/files'): + with pytest.raises(ValueError, match='bagit.txt does not exist'): validate_package(tmp_path) diff --git a/tests/test_cli.py b/tests/test_cli.py index 51a6fe6..4f08114 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -5,8 +5,7 @@ import re import pytest -from tests.utils import validate_passing -from .utils import validate_passing, validate_failing +from .utils import validate_passing, validate_failing, filter_str ### helpers @@ -111,8 +110,8 @@ def test_metadata(runner, tmp_path, test_files, metadata_format): """) # check metadata files - assert json.loads((bag_path / 'unsigned-metadata.json').read_text()) == {'metadata': 'unsigned'} - assert json.loads((bag_path / 'data/signed-metadata.json').read_text()) == {'metadata': 'signed'} + assert json.loads((bag_path / 'unsigned-metadata.json').read_text())['metadata'] == 'unsigned' + assert json.loads((bag_path / 'data/signed-metadata.json').read_text())['metadata'] == 'signed' # check bag-info.txt metadata bag_info = (bag_path / 'bag-info.txt').read_text() @@ -345,6 +344,13 @@ def test_collect_json(runner, tmp_path, server): assert (bag_path / 'data/files/data.html').read_text() == 'root content' assert (bag_path / 'data/files/custom.html').read_text() == 'another content' +def test_empty_package(runner, tmp_path): + """Test creating a package with no content""" + run(runner, [ + 'archive', + str(tmp_path), + ]) + ## validation errors def test_invalid_metadata_file_extension(runner, tmp_path): @@ -431,13 +437,6 @@ def test_invalid_url(runner, tmp_path): '-u', 'not_a_url', ], exit_code=2, output='Invalid URL') -def test_empty_package(runner, tmp_path): - """Test creating a package with no content""" - run(runner, [ - 'archive', - str(tmp_path), - ], exit_code=1, output='No files in data/files') - def test_invalid_collect_json(runner, tmp_path): """Test error handling for invalid 
--collect JSON""" # Test invalid JSON syntax @@ -460,3 +459,44 @@ def test_invalid_collect_json(runner, tmp_path): str(tmp_path / 'bag'), '--collect', '[{"backend": "invalid"}]' ], exit_code=2, output='Invalid task definition for --collect') + +def test_collect_errors_fail(runner, tmp_path): + """Test --collect-errors=fail with a non-resolving URL""" + bag_path = tmp_path / 'bag' + non_resolving_url = 'http://nonexistent.local' + + result = run(runner, [ + 'archive', + str(bag_path), + '--collect-errors', 'fail', + '-u', non_resolving_url, + ], exit_code=2, output='Max retries exceeded') + +def test_collect_errors_ignore(runner, tmp_path): + """Test --collect-errors=ignore with a non-resolving URL""" + bag_path = tmp_path / 'bag' + non_resolving_url = 'http://nonexistent.local' + + result = run(runner, [ + 'archive', + str(bag_path), + '--collect-errors', 'ignore', + '-u', non_resolving_url, + ], output='Package created') + + collection_tasks = json.loads((bag_path / 'data/signed-metadata.json').read_text())['collection_tasks'] + assert filter_str(collection_tasks) == snapshot("""\ +[ + { + "request": { + "url": "http://nonexistent.local", + "output": null, + "timeout": 5.0 + }, + "response": { + "success": false, + "error": "HTTPConnectionPool(host='nonexistent.local', port=80): Max retries exceeded with url: / (Caused by NameResolutionError(\\">: Failed to resolve 'nonexistent.local' ([Errno 8] nodename nor servname provided, or not known)\\"))" + } + } +]\ +""") diff --git a/tests/test_validation.py b/tests/test_validation.py index 07d954c..8ef31c1 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -53,9 +53,9 @@ def test_extra_payload(test_bag): def test_missing_warc_file(warc_bag): (warc_bag / "data/files/data.html").unlink() assert validate_failing(warc_bag) == snapshot("""\ -ERROR: No files in data/files SUCCESS: headers.warc found ERROR: headers.warc specifies files that do not exist in data/files. Example: files/data.html +WARNING: No files in data/files ERROR: bag format is invalid: Bag validation failed: data/files/data.html exists in manifest but was not found on filesystem WARNING: Cannot verify the validity of empty directories: /data/files WARNING: No signatures found @@ -100,8 +100,8 @@ def test_extra_data(test_bag): def test_missing_data(test_bag): shutil.rmtree(test_bag / "data") assert validate_failing(test_bag) == snapshot("""\ -ERROR: No files in data/files WARNING: No headers.warc found; archive lacks request and response metadata +WARNING: No files in data/files ERROR: bag format is invalid: Expected data directory /data does not exist WARNING: No signatures found WARNING: No timestamps found\ diff --git a/tests/utils.py b/tests/utils.py index 651a31c..0750bcf 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,3 +1,4 @@ +import json from pathlib import Path import re from nabit.lib.archive import validate_package @@ -35,3 +36,14 @@ def validate_passing(bag_path: Path): def replace_hashes(text: str) -> str: """Replace all hashes with a placeholder""" return re.sub(r'\b[0-9a-f]{64}\b', '', text) + +def filter_str(obj, **kwargs): + """ + Turn obj into a string, replacing any kwarg values with their keys. + Helpful for consistent comparisons in assertions. + """ + out = json.dumps(obj, indent=2, default=str) + out = re.sub(r'object at 0x[0-9a-f]+', 'object at ', out) + for key, value in kwargs.items(): + out = out.replace(str(value), f"<{key}>") + return out