diff --git a/pyproject.toml b/pyproject.toml index 8991622..fb70913 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ classifiers = [ ] requires-python = ">=3.9" dependencies = [ + "aiofiles", "PyYAML", ] @@ -64,12 +65,15 @@ formatters = [ "isort", ] test = [ + "asynctest", "pytest", + "pytest-asyncio >= 0.20", "pytest-cov", "pytest-error-for-skips", ] typing = [ "mypy", + "types-aiofiles", "types-PyYAML", "typing-extensions", ] diff --git a/src/antsibull_fileutils/hashing.py b/src/antsibull_fileutils/hashing.py new file mode 100644 index 0000000..18eb070 --- /dev/null +++ b/src/antsibull_fileutils/hashing.py @@ -0,0 +1,86 @@ +# Author: Toshio Kuratomi +# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or +# https://www.gnu.org/licenses/gpl-3.0.txt) +# SPDX-License-Identifier: GPL-3.0-or-later +# SPDX-FileCopyrightText: 2020, Ansible Project +"""Functions to help with hashing.""" + +from __future__ import annotations + +import dataclasses +import hashlib +import typing as t +from collections.abc import Mapping + +import aiofiles + +if t.TYPE_CHECKING: + from _typeshed import StrOrBytesPath + + +@dataclasses.dataclass(frozen=True) +class _AlgorithmData: + name: str + algorithm: str + kwargs: dict[str, t.Any] + + +_PREFERRED_HASHES: tuple[_AlgorithmData, ...] = ( + # https://pypi.org/help/#verify-hashes, https://github.com/pypi/warehouse/issues/9628 + _AlgorithmData(name="sha256", algorithm="sha256", kwargs={}), + _AlgorithmData(name="blake2b_256", algorithm="blake2b", kwargs={"digest_size": 32}), +) + + +async def verify_hash( + filename: StrOrBytesPath, + hash_digest: str, + *, + algorithm: str = "sha256", + algorithm_kwargs: dict[str, t.Any] | None = None, + chunksize: int, +) -> bool: + """ + Verify whether a file has a given hash. + + :arg filename: The file to verify the hash of. + :arg hash_digest: The hash that is expected. + :kwarg algorithm: The hash algorithm to use. 
This must be present in hashlib on this + system. The default is 'sha256'. + :kwarg algorithm_kwargs: Parameters to provide to the hash algorithm's constructor. + :returns: True if the hash matches, otherwise False. + """ + hasher = getattr(hashlib, algorithm)(**(algorithm_kwargs or {})) + async with aiofiles.open(filename, "rb") as f: + while chunk := await f.read(chunksize): + hasher.update(chunk) + if hasher.hexdigest() != hash_digest: + return False + + return True + + +async def verify_a_hash( + filename: StrOrBytesPath, + hash_digests: Mapping[str, str], + *, + chunksize: int, +) -> bool: + """ + Verify whether a file has a given hash, given a set of digests with different algorithms. + Will only test trustworthy hashes and return ``False`` if none matches. + + :arg filename: The file to verify the hash of. + :arg hash_digests: A mapping of hash types to digests. + :returns: True if the hash matches, otherwise False. + """ + for algorithm_data in _PREFERRED_HASHES: + if algorithm_data.name in hash_digests: + return await verify_hash( + filename, + hash_digests[algorithm_data.name], + algorithm=algorithm_data.algorithm, + algorithm_kwargs=algorithm_data.kwargs, + chunksize=chunksize, + ) + return False diff --git a/src/antsibull_fileutils/io.py b/src/antsibull_fileutils/io.py new file mode 100644 index 0000000..c456e7c --- /dev/null +++ b/src/antsibull_fileutils/io.py @@ -0,0 +1,95 @@ +# Author: Toshio Kuratomi +# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or +# https://www.gnu.org/licenses/gpl-3.0.txt) +# SPDX-License-Identifier: GPL-3.0-or-later +# SPDX-FileCopyrightText: 2021, Ansible Project +"""I/O helper functions.""" + +from __future__ import annotations + +import os +import os.path +import typing as t + +import aiofiles + +if t.TYPE_CHECKING: + from _typeshed import StrOrBytesPath + + +async def copy_file( + source_path: StrOrBytesPath, + dest_path: StrOrBytesPath, + *, + check_content: bool = True, + file_check_content: int = 0, 
+ chunksize: int, +) -> None: + """ + Copy content from one file to another. + + :arg source_path: Source path. Must be a file. + :arg dest_path: Destination path. + :kwarg check_content: If ``True`` (default) and ``file_check_content > 0`` and the + destination file exists, first check whether source and destination are potentially equal + before actually copying. + """ + if check_content and file_check_content > 0: + # Check whether the destination file exists and has the same content as the source file, + # in which case we won't overwrite the destination file + try: + stat_d = os.stat(dest_path) + if stat_d.st_size <= file_check_content: + stat_s = os.stat(source_path) + if stat_d.st_size == stat_s.st_size: + # Read both files and compare + async with aiofiles.open(source_path, "rb") as f_in: + content_to_copy = await f_in.read() + async with aiofiles.open(dest_path, "rb") as f_in: + existing_content = await f_in.read() + if content_to_copy == existing_content: + return + # Since we already read the contents of the file to copy, simply write it to + # the destination instead of reading it again + async with aiofiles.open(dest_path, "wb") as f_out: + await f_out.write(content_to_copy) + return + except FileNotFoundError: + # Destination (or source) file does not exist + pass + + async with aiofiles.open(source_path, "rb") as f_in: + async with aiofiles.open(dest_path, "wb") as f_out: + while chunk := await f_in.read(chunksize): + await f_out.write(chunk) + + +async def write_file( + filename: StrOrBytesPath, content: str, *, file_check_content: int = 0 +) -> None: + content_bytes = content.encode("utf-8") + + if file_check_content > 0 and len(content_bytes) <= file_check_content: + # Check whether the destination file exists and has the same content as the one we want to + # write, in which case we won't overwrite the file + try: + stat = os.stat(filename) + if stat.st_size == len(content_bytes): + # Read file and compare + async with aiofiles.open(filename, 
"rb") as f: + existing_content = await f.read() + if existing_content == content_bytes: + return + except FileNotFoundError: + # Destination file does not exist + pass + + async with aiofiles.open(filename, "wb") as f: + await f.write(content_bytes) + + +async def read_file(filename: StrOrBytesPath, *, encoding: str = "utf-8") -> str: + async with aiofiles.open(filename, "r", encoding=encoding) as f: + content = await f.read() + + return content diff --git a/tests/units/test_hashing.py b/tests/units/test_hashing.py new file mode 100644 index 0000000..6c68287 --- /dev/null +++ b/tests/units/test_hashing.py @@ -0,0 +1,107 @@ +# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt) +# SPDX-License-Identifier: GPL-3.0-or-later +# SPDX-FileCopyrightText: Ansible Project + +from __future__ import annotations + +import pytest + +from antsibull_fileutils.hashing import verify_a_hash, verify_hash + +HASH_TESTS = [ + ( + b"", + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + "sha256", + {}, + True, + ), + ( + b"", + "01ba4719c80b6fe911b091a7c05124b64eeece964e09c058ef8f9805daca546b", + "sha256", + {}, + False, + ), +] + + +@pytest.mark.parametrize( + "content, hash, algorithm, algorithm_kwargs, expected_match", HASH_TESTS +) +@pytest.mark.asyncio +async def test_verify_hash( + content: bytes, + hash: bytes, + algorithm: str, + algorithm_kwargs: dict | None, + expected_match: bool, + tmp_path, +): + filename = tmp_path / "file" + with open(filename, "wb") as f: + f.write(content) + result = await verify_hash( + filename, + hash, + algorithm=algorithm, + algorithm_kwargs=algorithm_kwargs, + chunksize=65536, + ) + assert result is expected_match + + +HASH_DICT_TESTS = [ + ( + b"foo", + { + "sha256": "2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae", + }, + True, + ), + ( + b"bar", + { + "sha256": "2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae", + }, + False, + ), 
+ ( + b"foo", + { + "blake2b_256": "b8fe9f7f6255a6fa08f668ab632a8d081ad87983c77cd274e48ce450f0b349fd", + }, + True, + ), + ( + b"bar", + { + "blake2b_256": "b8fe9f7f6255a6fa08f668ab632a8d081ad87983c77cd274e48ce450f0b349fd", + }, + False, + ), + ( + b"", + {}, + False, + ), + ( + b"", + { + "foo": "bar", + }, + False, + ), +] + + +@pytest.mark.parametrize("content, hash_digests, expected_match", HASH_DICT_TESTS) +@pytest.mark.asyncio +async def test_verify_a_hash( + content: bytes, hash_digests: dict[str, str], expected_match: bool, tmp_path +): + filename = tmp_path / "file" + with open(filename, "wb") as f: + f.write(content) + result = await verify_a_hash(filename, hash_digests, chunksize=65536) + assert result is expected_match