Add more utils from antsibull-core.
felixfontein committed Sep 8, 2024
1 parent 429b3ec commit 814c1e9
Showing 4 changed files with 292 additions and 0 deletions.
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -28,6 +28,7 @@ classifiers = [
]
requires-python = ">=3.9"
dependencies = [
    "aiofiles",
    "PyYAML",
]

@@ -64,12 +65,15 @@ formatters = [
    "isort",
]
test = [
    "asynctest",
    "pytest",
    "pytest-asyncio >= 0.20",
    "pytest-cov",
    "pytest-error-for-skips",
]
typing = [
    "mypy",
    "types-aiofiles",
    "types-PyYAML",
    "typing-extensions",
]
86 changes: 86 additions & 0 deletions src/antsibull_fileutils/hashing.py
@@ -0,0 +1,86 @@
# Author: Toshio Kuratomi <[email protected]>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or
# https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# SPDX-FileCopyrightText: 2020, Ansible Project
"""Functions to help with hashing."""

from __future__ import annotations

import dataclasses
import hashlib
import typing as t
from collections.abc import Mapping

import aiofiles

if t.TYPE_CHECKING:
    from _typeshed import StrOrBytesPath


@dataclasses.dataclass(frozen=True)
class _AlgorithmData:
    name: str
    algorithm: str
    kwargs: dict[str, t.Any]


_PREFERRED_HASHES: tuple[_AlgorithmData, ...] = (
    # https://pypi.org/help/#verify-hashes, https://github.com/pypi/warehouse/issues/9628
    _AlgorithmData(name="sha256", algorithm="sha256", kwargs={}),
    _AlgorithmData(name="blake2b_256", algorithm="blake2b", kwargs={"digest_size": 32}),
)


async def verify_hash(
    filename: StrOrBytesPath,
    hash_digest: str,
    *,
    algorithm: str = "sha256",
    algorithm_kwargs: dict[str, t.Any] | None = None,
    chunksize: int,
) -> bool:
    """
    Verify whether a file has a given hash digest.

    :arg filename: The file to verify the hash of.
    :arg hash_digest: The hash digest that is expected.
    :kwarg algorithm: The hash algorithm to use. This must be present in hashlib on this
        system. The default is 'sha256'.
    :kwarg algorithm_kwargs: Parameters to provide to the hash algorithm's constructor.
    :kwarg chunksize: The size of the chunks in which the file is read.
    :returns: ``True`` if the hash matches, otherwise ``False``.
    """
    hasher = getattr(hashlib, algorithm)(**(algorithm_kwargs or {}))
    async with aiofiles.open(filename, "rb") as f:
        while chunk := await f.read(chunksize):
            hasher.update(chunk)
    if hasher.hexdigest() != hash_digest:
        return False

    return True


async def verify_a_hash(
    filename: StrOrBytesPath,
    hash_digests: Mapping[str, str],
    *,
    chunksize: int,
) -> bool:
    """
    Verify whether a file has a given hash, given a set of digests with different algorithms.

    Will only test trustworthy hashes and return ``False`` if none matches.

    :arg filename: The file to verify the hash of.
    :arg hash_digests: A mapping of hash algorithm names to digests.
    :kwarg chunksize: The size of the chunks in which the file is read.
    :returns: ``True`` if a trustworthy hash matches, otherwise ``False``.
    """
    for algorithm_data in _PREFERRED_HASHES:
        if algorithm_data.name in hash_digests:
            return await verify_hash(
                filename,
                hash_digests[algorithm_data.name],
                algorithm=algorithm_data.algorithm,
                algorithm_kwargs=algorithm_data.kwargs,
                chunksize=chunksize,
            )
    return False
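
A minimal usage sketch for the two new helpers, assuming a hypothetical path example.bin; the digest is the SHA-256 of empty input, taken from the test vectors further below, so the checks only succeed for an empty file:

import asyncio

from antsibull_fileutils.hashing import verify_a_hash, verify_hash

EMPTY_SHA256 = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"


async def main() -> None:
    # chunksize is a required keyword and controls how much is read per iteration.
    matches = await verify_hash("example.bin", EMPTY_SHA256, chunksize=65536)
    print(matches)

    # verify_a_hash picks the first trustworthy algorithm present in the mapping
    # and ignores everything else.
    matches = await verify_a_hash(
        "example.bin", {"sha256": EMPTY_SHA256}, chunksize=65536
    )
    print(matches)


asyncio.run(main())
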
95 changes: 95 additions & 0 deletions src/antsibull_fileutils/io.py
@@ -0,0 +1,95 @@
# Author: Toshio Kuratomi <[email protected]>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or
# https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# SPDX-FileCopyrightText: 2021, Ansible Project
"""I/O helper functions."""

from __future__ import annotations

import os
import os.path
import typing as t

import aiofiles

if t.TYPE_CHECKING:
    from _typeshed import StrOrBytesPath


async def copy_file(
    source_path: StrOrBytesPath,
    dest_path: StrOrBytesPath,
    *,
    check_content: bool = True,
    file_check_content: int = 0,
    chunksize: int,
) -> None:
    """
    Copy content from one file to another.

    :arg source_path: Source path. Must be a file.
    :arg dest_path: Destination path.
    :kwarg check_content: If ``True`` (default) and ``file_check_content > 0`` and the
        destination file exists, first check whether source and destination are potentially
        equal before actually copying.
    :kwarg file_check_content: Maximum file size (in bytes) for which the content
        comparison is performed.
    :kwarg chunksize: The size of the chunks in which the file is copied.
    """
    if check_content and file_check_content > 0:
        # Check whether the destination file exists and has the same content as the source
        # file, in which case we won't overwrite the destination file
        try:
            stat_d = os.stat(dest_path)
            if stat_d.st_size <= file_check_content:
                stat_s = os.stat(source_path)
                if stat_d.st_size == stat_s.st_size:
                    # Read both files and compare
                    async with aiofiles.open(source_path, "rb") as f_in:
                        content_to_copy = await f_in.read()
                    async with aiofiles.open(dest_path, "rb") as f_in:
                        existing_content = await f_in.read()
                    if content_to_copy == existing_content:
                        return
                    # Since we already read the contents of the file to copy, simply write
                    # it to the destination instead of reading it again
                    async with aiofiles.open(dest_path, "wb") as f_out:
                        await f_out.write(content_to_copy)
                    return
        except FileNotFoundError:
            # Destination (or source) file does not exist
            pass

    async with aiofiles.open(source_path, "rb") as f_in:
        async with aiofiles.open(dest_path, "wb") as f_out:
            while chunk := await f_in.read(chunksize):
                await f_out.write(chunk)


async def write_file(
    filename: StrOrBytesPath, content: str, *, file_check_content: int = 0
) -> None:
    """
    Write text content to a file as UTF-8, skipping the write if the file already
    contains exactly that content.

    :arg filename: Destination path.
    :arg content: The text to write.
    :kwarg file_check_content: Maximum file size (in bytes) for which the content
        comparison is performed.
    """
    content_bytes = content.encode("utf-8")

    if file_check_content > 0 and len(content_bytes) <= file_check_content:
        # Check whether the destination file exists and has the same content as the one we
        # want to write, in which case we won't overwrite the file
        try:
            stat = os.stat(filename)
            if stat.st_size == len(content_bytes):
                # Read file and compare
                async with aiofiles.open(filename, "rb") as f:
                    existing_content = await f.read()
                if existing_content == content_bytes:
                    return
        except FileNotFoundError:
            # Destination file does not exist
            pass

    async with aiofiles.open(filename, "wb") as f:
        await f.write(content_bytes)


async def read_file(filename: StrOrBytesPath, *, encoding: str = "utf-8") -> str:
    async with aiofiles.open(filename, "r", encoding=encoding) as f:
        content = await f.read()

    return content
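
A minimal usage sketch for these I/O helpers; the paths are hypothetical, and the file_check_content and chunksize values are arbitrary choices:

import asyncio

from antsibull_fileutils.io import copy_file, read_file, write_file


async def main() -> None:
    # With file_check_content=262144, files up to 256 KiB are compared against
    # the new content first, so unchanged files are not rewritten.
    await write_file("hello.txt", "Hello, world!\n", file_check_content=262144)
    await copy_file(
        "hello.txt", "copy.txt", file_check_content=262144, chunksize=65536
    )
    print(await read_file("copy.txt"), end="")


asyncio.run(main())
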
107 changes: 107 additions & 0 deletions tests/units/test_hashing.py
@@ -0,0 +1,107 @@
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# SPDX-FileCopyrightText: Ansible Project

from __future__ import annotations

import pytest

from antsibull_fileutils.hashing import verify_a_hash, verify_hash

HASH_TESTS = [
    (
        b"",
        "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
        "sha256",
        {},
        True,
    ),
    (
        b"",
        "01ba4719c80b6fe911b091a7c05124b64eeece964e09c058ef8f9805daca546b",
        "sha256",
        {},
        False,
    ),
]


@pytest.mark.parametrize(
    "content, hash, algorithm, algorithm_kwargs, expected_match", HASH_TESTS
)
@pytest.mark.asyncio
async def test_verify_hash(
    content: bytes,
    hash: str,
    algorithm: str,
    algorithm_kwargs: dict | None,
    expected_match: bool,
    tmp_path,
):
    filename = tmp_path / "file"
    with open(filename, "wb") as f:
        f.write(content)
    result = await verify_hash(
        filename,
        hash,
        algorithm=algorithm,
        algorithm_kwargs=algorithm_kwargs,
        chunksize=65536,
    )
    assert result is expected_match


HASH_DICT_TESTS = [
    (
        b"foo",
        {
            "sha256": "2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae",
        },
        True,
    ),
    (
        b"bar",
        {
            "sha256": "2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae",
        },
        False,
    ),
    (
        b"foo",
        {
            "blake2b_256": "b8fe9f7f6255a6fa08f668ab632a8d081ad87983c77cd274e48ce450f0b349fd",
        },
        True,
    ),
    (
        b"bar",
        {
            "blake2b_256": "b8fe9f7f6255a6fa08f668ab632a8d081ad87983c77cd274e48ce450f0b349fd",
        },
        False,
    ),
    (
        b"",
        {},
        False,
    ),
    (
        b"",
        {
            "foo": "bar",
        },
        False,
    ),
]


@pytest.mark.parametrize("content, hash_digests, expected_match", HASH_DICT_TESTS)
@pytest.mark.asyncio
async def test_verify_a_hash(
    content: bytes, hash_digests: dict[str, str], expected_match: bool, tmp_path
):
    filename = tmp_path / "file"
    with open(filename, "wb") as f:
        f.write(content)
    result = await verify_a_hash(filename, hash_digests, chunksize=65536)
    assert result is expected_match
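
The vectors above can be regenerated directly with hashlib; "blake2b_256" is just BLAKE2b constrained to a 32-byte digest, matching the kwargs in _PREFERRED_HASHES:

import hashlib

# SHA-256 digests used in HASH_TESTS and HASH_DICT_TESTS
print(hashlib.sha256(b"").hexdigest())
print(hashlib.sha256(b"foo").hexdigest())
# blake2b_256: BLAKE2b with digest_size=32 (256 bits)
print(hashlib.blake2b(b"foo", digest_size=32).hexdigest())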
