diff --git a/servicelayer/archive/file.py b/servicelayer/archive/file.py index d3a7213..0838a46 100644 --- a/servicelayer/archive/file.py +++ b/servicelayer/archive/file.py @@ -5,7 +5,7 @@ from normality import safe_filename from servicelayer.archive.archive import Archive -from servicelayer.archive.util import ensure_path, checksum, BUF_SIZE +from servicelayer.archive.util import ensure_path, checksum, sanitize_checksum, BUF_SIZE from servicelayer.archive.util import path_prefix, path_content_hash log = logging.getLogger(__name__) @@ -33,6 +33,8 @@ def archive_file(self, file_path, content_hash=None, mime_type=None): """Import the given file into the archive.""" if content_hash is None: content_hash = checksum(file_path) + else: + content_hash = sanitize_checksum(content_hash) if content_hash is None: return @@ -51,6 +53,7 @@ def archive_file(self, file_path, content_hash=None, mime_type=None): return content_hash def load_file(self, content_hash, file_name=None, temp_path=None): + content_hash = sanitize_checksum(content_hash) return self._locate_key(content_hash) def list_files(self, prefix=None): @@ -67,6 +70,7 @@ def list_files(self, prefix=None): yield path_content_hash(file_path) def delete_file(self, content_hash): + content_hash = sanitize_checksum(content_hash) prefix = path_prefix(content_hash) if prefix is None: return diff --git a/servicelayer/archive/gs.py b/servicelayer/archive/gs.py index 0a6eb62..92adce4 100644 --- a/servicelayer/archive/gs.py +++ b/servicelayer/archive/gs.py @@ -9,7 +9,7 @@ from google.resumable_media.common import DataCorruption, InvalidResponse from servicelayer.archive.virtual import VirtualArchive -from servicelayer.archive.util import checksum, ensure_path +from servicelayer.archive.util import checksum, sanitize_checksum, ensure_path from servicelayer.archive.util import path_prefix, ensure_posix_path from servicelayer.archive.util import path_content_hash, HASH_LENGTH from servicelayer.util import service_retries, backoff 
@@ -89,6 +89,8 @@ def archive_file(self, file_path, content_hash=None, mime_type=None): file_path = ensure_path(file_path) if content_hash is None: content_hash = checksum(file_path) + else: + content_hash = sanitize_checksum(content_hash) if content_hash is None: return @@ -111,6 +113,7 @@ def archive_file(self, file_path, content_hash=None, mime_type=None): def load_file(self, content_hash, file_name=None, temp_path=None): """Retrieve a file from Google storage and put it onto the local file system for further processing.""" + content_hash = sanitize_checksum(content_hash) for attempt in service_retries(): try: blob = self._locate_contenthash(content_hash) @@ -147,6 +150,7 @@ def delete_file(self, content_hash): """Check if a file with the given hash exists on S3.""" if content_hash is None or len(content_hash) < HASH_LENGTH: return + content_hash = sanitize_checksum(content_hash) prefix = path_prefix(content_hash) if prefix is None: return diff --git a/servicelayer/archive/s3.py b/servicelayer/archive/s3.py index 4b9c3b1..750b461 100644 --- a/servicelayer/archive/s3.py +++ b/servicelayer/archive/s3.py @@ -5,7 +5,7 @@ from servicelayer import settings from servicelayer.archive.virtual import VirtualArchive -from servicelayer.archive.util import checksum, ensure_path +from servicelayer.archive.util import checksum, sanitize_checksum, ensure_path from servicelayer.archive.util import path_prefix, path_content_hash log = logging.getLogger(__name__) @@ -86,6 +86,8 @@ def archive_file(self, file_path, content_hash=None, mime_type=None): file_path = ensure_path(file_path) if content_hash is None: content_hash = checksum(file_path) + else: + content_hash = sanitize_checksum(content_hash) # if content_hash is None: # return @@ -105,6 +107,7 @@ def archive_file(self, file_path, content_hash=None, mime_type=None): def load_file(self, content_hash, file_name=None, temp_path=None): """Retrieve a file from S3 storage and put it onto the local file system for further 
def sanitize_checksum(checksum):
    """Validate a content hash and return it unchanged.

    The archive backends pass the content hash on to helpers that build
    storage prefixes and keys from it, so anything other than plain
    hexadecimal digits (path separators, dots, ...) is rejected here.

    The value is *not* modified: upper- and lower-case hex digits are both
    accepted and returned exactly as given.

    :param checksum: the content hash to validate.
    :returns: ``checksum``, unchanged, when it is a non-empty hex string.
    :raises ValueError: if ``checksum`` is empty/``None``, is not a ``str``,
        or contains a character that is not a hexadecimal digit.
    """
    if not checksum:
        raise ValueError("Checksum is empty")

    # The character check below is only sound for ``str``. Iterating bytes
    # yields ints (``int in str`` raises TypeError), and iterating a list of
    # strings turns ``in`` into a substring test, which would let values such
    # as ["ab", "cd"] through. Reject non-strings with the same exception
    # type callers already handle.
    if not isinstance(checksum, str):
        raise ValueError("Checksum is not a string")

    for char in checksum:
        if char not in string.hexdigits:
            raise ValueError(f'Checksum contains invalid character "{char}"')

    return checksum
a/tests/archive/test_file.py b/tests/archive/test_file.py index c77e2f4..d059f80 100644 --- a/tests/archive/test_file.py +++ b/tests/archive/test_file.py @@ -1,3 +1,4 @@ +import pytest import shutil import tempfile from unittest import TestCase @@ -26,9 +27,11 @@ def test_basic_archive(self): assert out == out2, (out, out2) def test_basic_archive_with_checksum(self): - checksum_ = "banana" - out = self.archive.archive_file(self.file, checksum_) - assert checksum_ == out, (checksum_, out) + with pytest.raises(ValueError): + self.archive.archive_file(self.file, content_hash="banana") + + out = self.archive.archive_file(self.file, content_hash="01234567890abcdef") + assert out == "01234567890abcdef" def test_generate_url(self): out = self.archive.archive_file(self.file) @@ -39,6 +42,15 @@ def test_publish(self): assert not self.archive.can_publish def test_load_file(self): + # Invalid content hash + with pytest.raises(ValueError): + self.archive.load_file("banana") + + # Valid content hash, but file does not exist + path = self.archive.load_file("01234567890abcdef") + assert path is None + + # Valid content hash, file exists out = self.archive.archive_file(self.file) path = self.archive.load_file(out) assert path is not None, path @@ -64,10 +76,17 @@ def test_list_files(self): assert len(keys) == 0, keys def test_delete_file(self): + # Invalid content hash + with pytest.raises(ValueError): + self.archive.delete_file("banana") + + # File does not exist + assert self.archive.delete_file("01234567890abcdef") is None + + # Valid content hash, file exists out = self.archive.archive_file(self.file) path = self.archive.load_file(out) assert path is not None, path - self.archive.cleanup_file(out) - self.archive.delete_file(out) + assert self.archive.delete_file(out) is None path = self.archive.load_file(out) assert path is None, path diff --git a/tests/archive/test_s3.py b/tests/archive/test_s3.py index b7474ab..71bf624 100644 --- a/tests/archive/test_s3.py +++ 
b/tests/archive/test_s3.py @@ -1,3 +1,4 @@ +import pytest from unittest import TestCase from urllib.parse import urlparse, parse_qs @@ -26,9 +27,11 @@ def test_basic_archive(self): assert out == out2, (out, out2) def test_basic_archive_with_checksum(self): - checksum_ = "banana" - out = self.archive.archive_file(self.file, checksum_) - assert checksum_ == out, (checksum_, out) + with pytest.raises(ValueError): + self.archive.archive_file(self.file, content_hash="banana") + + out = self.archive.archive_file(self.file, content_hash="01234567890abcdef") + assert out == "01234567890abcdef" def test_generate_url(self): content_hash = self.archive.archive_file(self.file) @@ -60,12 +63,29 @@ def test_publish_file(self): assert "https://foo.s3.amazonaws.com/self.py" in url, url def test_load_file(self): + # Invalid content hash + with pytest.raises(ValueError): + self.archive.load_file("banana") + + # Valid content hash, but file does not exist + path = self.archive.load_file("01234567890abcdef") + assert path is None + + # Valid content hash, file exists out = self.archive.archive_file(self.file) path = self.archive.load_file(out) assert path is not None, path assert path.is_file(), path def test_cleanup_file(self): + # Invalid content hash + with pytest.raises(ValueError): + self.archive.cleanup_file("banana") + + # File does not exist + assert self.archive.cleanup_file("01234567890abcdef") is None + + # Valid content hash, file exists out = self.archive.archive_file(self.file) self.archive.cleanup_file(out) path = self.archive.load_file(out) @@ -86,6 +106,14 @@ def test_list_files(self): assert len(keys) == 0, keys def test_delete_file(self): + # Invalid content hash + with pytest.raises(ValueError): + self.archive.delete_file("banana") + + # File does not exist + assert self.archive.delete_file("01234567890abcdef") is None + + # Valid content hash, file exists out = self.archive.archive_file(self.file) path = self.archive.load_file(out) assert path is not None, path 
import pytest
from unittest import TestCase

from servicelayer.archive.util import sanitize_checksum


class UtilTest(TestCase):
    """Unit tests for the helpers in ``servicelayer.archive.util``."""

    def test_sanitize_checksum(self):
        """A hex-only checksum is returned as-is; anything else raises."""
        valid = "0123456789abcdef"
        assert sanitize_checksum(valid) == valid

        # (input, expected error message) pairs, checked in order.
        rejected = (
            (None, "Checksum is empty"),
            ("", "Checksum is empty"),
            ("banana", 'Checksum contains invalid character "n"'),
            ("/", 'Checksum contains invalid character "/"'),
        )
        for value, message in rejected:
            with pytest.raises(ValueError, match=message):
                sanitize_checksum(value)