diff --git a/src/antsibull_fileutils/io.py b/src/antsibull_fileutils/io.py index c456e7c..941e359 100644 --- a/src/antsibull_fileutils/io.py +++ b/src/antsibull_fileutils/io.py @@ -24,7 +24,7 @@ async def copy_file( check_content: bool = True, file_check_content: int = 0, chunksize: int, -) -> None: +) -> bool: """ Copy content from one file to another. @@ -33,6 +33,7 @@ async def copy_file( :kwarg check_content: If ``True`` (default) and ``file_check_content > 0`` and the destination file exists, first check whether source and destination are potentially equal before actually copying, + :return: ``True`` if the file was actually copied. """ if check_content and file_check_content > 0: # Check whether the destination file exists and has the same content as the source file, @@ -48,12 +49,12 @@ async def copy_file( async with aiofiles.open(dest_path, "rb") as f_in: existing_content = await f_in.read() if content_to_copy == existing_content: - return + return False # Since we already read the contents of the file to copy, simply write it to # the destination instead of reading it again async with aiofiles.open(dest_path, "wb") as f_out: await f_out.write(content_to_copy) - return + return True except FileNotFoundError: # Destination (or source) file does not exist pass @@ -62,13 +63,29 @@ async def copy_file( async with aiofiles.open(dest_path, "wb") as f_out: while chunk := await f_in.read(chunksize): await f_out.write(chunk) + return True async def write_file( - filename: StrOrBytesPath, content: str, *, file_check_content: int = 0 -) -> None: - content_bytes = content.encode("utf-8") + filename: StrOrBytesPath, + content: str, + *, + file_check_content: int = 0, + encoding: str = "utf-8", +) -> bool: + """ + Write encoded content to file. + + :arg filename: The filename to write to. + :arg content: The content to write to the file. + :kwarg file_check_content: If > 0 and the file exists and its size in bytes does not exceed this + value, will read the file and compare it to the encoded content before overwriting. + :return: ``True`` if the file was actually written. + """ + content_bytes = content.encode(encoding) + + print(file_check_content, len(content_bytes)) if file_check_content > 0 and len(content_bytes) <= file_check_content: # Check whether the destination file exists and has the same content as the one we want to # write, in which case we won't overwrite the file @@ -79,16 +96,23 @@ async def write_file( async with aiofiles.open(filename, "rb") as f: existing_content = await f.read() if existing_content == content_bytes: - return + return False except FileNotFoundError: # Destination file does not exist pass async with aiofiles.open(filename, "wb") as f: await f.write(content_bytes) + return True async def read_file(filename: StrOrBytesPath, *, encoding: str = "utf-8") -> str: + """ + Read the file and decode its contents with the given encoding. + + :arg filename: The filename to read from. + :kwarg encoding: The encoding to use. + """ async with aiofiles.open(filename, "r", encoding=encoding) as f: content = await f.read() diff --git a/tests/units/test_io.py b/tests/units/test_io.py new file mode 100644 index 0000000..badbf7f --- /dev/null +++ b/tests/units/test_io.py @@ -0,0 +1,161 @@ +# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt) +# SPDX-License-Identifier: GPL-3.0-or-later +# SPDX-FileCopyrightText: Ansible Project + +from __future__ import annotations + +from io import BytesIO + +import pytest + +from antsibull_fileutils.io import copy_file, read_file, write_file + + +@pytest.mark.asyncio +async def test_copy_file(tmp_path): + content = "foo\x00bar\x00baz\x00çäø☺".encode("utf-8") + content_len = len(content) + + alt_content = "boo\x00bar\x00baz\x00çäø☺".encode("utf-8") + alt_content_len = "foo\x00bar".encode("utf-8") + + src_path = tmp_path / "file-src" + src_path.write_bytes(content) + + dst_path = tmp_path / "file-dst" + assert await copy_file(src_path, dst_path, chunksize=4) is True + assert dst_path.read_bytes() == content + + assert ( + await copy_file( + src_path, dst_path, chunksize=4, file_check_content=content_len - 1 + ) + is True + ) + assert dst_path.read_bytes() == content + + assert ( + await copy_file(src_path, dst_path, chunksize=4, file_check_content=content_len) + is False + ) + assert dst_path.read_bytes() == content + + src_path.write_bytes(alt_content) + + assert ( + await copy_file(src_path, dst_path, chunksize=4, file_check_content=content_len) + is True + ) + assert dst_path.read_bytes() == alt_content + + src_path.write_bytes(alt_content_len) + + assert ( + await copy_file(src_path, dst_path, chunksize=4, file_check_content=content_len) + is True + ) + assert dst_path.read_bytes() == alt_content_len + + dst_path = tmp_path / "file-dst-2" + assert ( + await copy_file(src_path, dst_path, chunksize=4, file_check_content=content_len) + is True + ) + assert dst_path.read_bytes() == alt_content_len + + +@pytest.mark.asyncio +async def test_read_file(tmp_path): + content = "foo\x00bar\x00baz\x00çäø☺" + filename = tmp_path / "file" + filename.write_text(content, encoding="utf-8") + + assert await read_file(filename, encoding="utf-8") == content + assert await read_file(str(filename), encoding="utf-8") == content + assert await read_file(str(filename).encode("utf-8"), encoding="utf-8") == content + assert await read_file(str(filename).encode("utf-8"), encoding="utf-8") == content + + filename.write_text(content, encoding="utf-16") + + assert await read_file(filename, encoding="utf-16") == content + + +@pytest.mark.asyncio +async def test_write_file(tmp_path): + content = "foo\x00bar\x00baz\x00çäø☺" + encoded_content_len = len(content.encode("utf-8")) + + alt_content = "boo\x00bar\x00baz\x00çäø☺" + alt_content_len = "foo\x00bar" + + filename = tmp_path / "file" + alt_filename = tmp_path / "file2" + + assert await write_file(filename, content, encoding="utf-8") is True + assert filename.read_text(encoding="utf-8") == content + + assert await write_file(filename, content, encoding="utf-8") is True + assert filename.read_text(encoding="utf-8") == content + + assert ( + await write_file( + filename, + content, + encoding="utf-8", + file_check_content=encoded_content_len - 1, + ) + is True + ) + assert filename.read_text(encoding="utf-8") == content + + assert ( + await write_file( + filename, content, encoding="utf-8", file_check_content=encoded_content_len + ) + is False + ) + assert filename.read_text(encoding="utf-8") == content + + assert ( + await write_file( + filename, + content, + encoding="utf-8", + file_check_content=encoded_content_len + 1, + ) + is False + ) + assert filename.read_text(encoding="utf-8") == content + + assert ( + await write_file( + filename, + alt_content, + encoding="utf-8", + file_check_content=encoded_content_len, + ) + is True + ) + assert filename.read_text(encoding="utf-8") == alt_content + + assert ( + await write_file( + alt_filename, + alt_content, + encoding="utf-8", + file_check_content=encoded_content_len, + ) + is True + ) + assert alt_filename.read_text(encoding="utf-8") == alt_content + + assert ( + await write_file( + alt_filename, + alt_content_len, + encoding="utf-8", + file_check_content=encoded_content_len, + ) + is True + ) + assert alt_filename.read_text(encoding="utf-8") == alt_content_len