Skip to content

Commit

Permalink
Improve documentation; return copy/write status; add tests for io.
Browse files Browse the repository at this point in the history
  • Loading branch information
felixfontein committed Sep 8, 2024
1 parent 814c1e9 commit 57f8eb6
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 7 deletions.
38 changes: 31 additions & 7 deletions src/antsibull_fileutils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def copy_file(
check_content: bool = True,
file_check_content: int = 0,
chunksize: int,
) -> None:
) -> bool:
"""
Copy content from one file to another.
Expand All @@ -33,6 +33,7 @@ async def copy_file(
:kwarg check_content: If ``True`` (default) and ``file_check_content > 0`` and the
destination file exists, first check whether source and destination are potentially equal
before actually copying,
:return: ``True`` if the file was actually copied.
"""
if check_content and file_check_content > 0:
# Check whether the destination file exists and has the same content as the source file,
Expand All @@ -48,12 +49,12 @@ async def copy_file(
async with aiofiles.open(dest_path, "rb") as f_in:
existing_content = await f_in.read()
if content_to_copy == existing_content:
return
return False
# Since we already read the contents of the file to copy, simply write it to
# the destination instead of reading it again
async with aiofiles.open(dest_path, "wb") as f_out:
await f_out.write(content_to_copy)
return
return True
except FileNotFoundError:
# Destination (or source) file does not exist
pass
Expand All @@ -62,13 +63,29 @@ async def copy_file(
async with aiofiles.open(dest_path, "wb") as f_out:
while chunk := await f_in.read(chunksize):
await f_out.write(chunk)
return True


async def write_file(
filename: StrOrBytesPath, content: str, *, file_check_content: int = 0
) -> None:
content_bytes = content.encode("utf-8")
filename: StrOrBytesPath,
content: str,
*,
file_check_content: int = 0,
encoding: str = "utf-8",
) -> bool:
"""
Write encoded content to file.
:arg filename: The filename to write to.
:arg content: The content to write to the file.
:kwarg file_check_content: If > 0 and the file exists and its size in bytes does not exceed this
value, will read the file and compare it to the encoded content before overwriting.
:return: ``True`` if the file was actually written.
"""

content_bytes = content.encode(encoding)

print(file_check_content, len(content_bytes))
if file_check_content > 0 and len(content_bytes) <= file_check_content:
# Check whether the destination file exists and has the same content as the one we want to
# write, in which case we won't overwrite the file
Expand All @@ -79,16 +96,23 @@ async def write_file(
async with aiofiles.open(filename, "rb") as f:
existing_content = await f.read()
if existing_content == content_bytes:
return
return False
except FileNotFoundError:
# Destination file does not exist
pass

async with aiofiles.open(filename, "wb") as f:
await f.write(content_bytes)
return True


async def read_file(filename: StrOrBytesPath, *, encoding: str = "utf-8") -> str:
"""
Read the file and decode its contents with the given encoding.
:arg filename: The filename to read from.
:kwarg encoding: The encoding to use.
"""
async with aiofiles.open(filename, "r", encoding=encoding) as f:
content = await f.read()

Expand Down
161 changes: 161 additions & 0 deletions tests/units/test_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# SPDX-FileCopyrightText: Ansible Project

from __future__ import annotations

from io import BytesIO

import pytest

from antsibull_fileutils.io import copy_file, read_file, write_file


@pytest.mark.asyncio
async def test_copy_file(tmp_path):
content = "foo\x00bar\x00baz\x00çäø☺".encode("utf-8")
content_len = len(content)

alt_content = "boo\x00bar\x00baz\x00çäø☺".encode("utf-8")
alt_content_len = "foo\x00bar".encode("utf-8")

src_path = tmp_path / "file-src"
src_path.write_bytes(content)

dst_path = tmp_path / "file-dst"
assert await copy_file(src_path, dst_path, chunksize=4) is True
assert dst_path.read_bytes() == content

assert (
await copy_file(
src_path, dst_path, chunksize=4, file_check_content=content_len - 1
)
is True
)
assert dst_path.read_bytes() == content

assert (
await copy_file(src_path, dst_path, chunksize=4, file_check_content=content_len)
is False
)
assert dst_path.read_bytes() == content

src_path.write_bytes(alt_content)

assert (
await copy_file(src_path, dst_path, chunksize=4, file_check_content=content_len)
is True
)
assert dst_path.read_bytes() == alt_content

src_path.write_bytes(alt_content_len)

assert (
await copy_file(src_path, dst_path, chunksize=4, file_check_content=content_len)
is True
)
assert dst_path.read_bytes() == alt_content_len

dst_path = tmp_path / "file-dst-2"
assert (
await copy_file(src_path, dst_path, chunksize=4, file_check_content=content_len)
is True
)
assert dst_path.read_bytes() == alt_content_len


@pytest.mark.asyncio
async def test_read_file(tmp_path):
content = "foo\x00bar\x00baz\x00çäø☺"
filename = tmp_path / "file"
filename.write_text(content, encoding="utf-8")

assert await read_file(filename, encoding="utf-8") == content
assert await read_file(str(filename), encoding="utf-8") == content
assert await read_file(str(filename).encode("utf-8"), encoding="utf-8") == content
assert await read_file(str(filename).encode("utf-8"), encoding="utf-8") == content

filename.write_text(content, encoding="utf-16")

assert await read_file(filename, encoding="utf-16") == content


@pytest.mark.asyncio
async def test_write_file(tmp_path):
content = "foo\x00bar\x00baz\x00çäø☺"
encoded_content_len = len(content.encode("utf-8"))

alt_content = "boo\x00bar\x00baz\x00çäø☺"
alt_content_len = "foo\x00bar"

filename = tmp_path / "file"
alt_filename = tmp_path / "file2"

assert await write_file(filename, content, encoding="utf-8") is True
assert filename.read_text(encoding="utf-8") == content

assert await write_file(filename, content, encoding="utf-8") is True
assert filename.read_text(encoding="utf-8") == content

assert (
await write_file(
filename,
content,
encoding="utf-8",
file_check_content=encoded_content_len - 1,
)
is True
)
assert filename.read_text(encoding="utf-8") == content

assert (
await write_file(
filename, content, encoding="utf-8", file_check_content=encoded_content_len
)
is False
)
assert filename.read_text(encoding="utf-8") == content

assert (
await write_file(
filename,
content,
encoding="utf-8",
file_check_content=encoded_content_len + 1,
)
is False
)
assert filename.read_text(encoding="utf-8") == content

assert (
await write_file(
filename,
alt_content,
encoding="utf-8",
file_check_content=encoded_content_len,
)
is True
)
assert filename.read_text(encoding="utf-8") == alt_content

assert (
await write_file(
alt_filename,
alt_content,
encoding="utf-8",
file_check_content=encoded_content_len,
)
is True
)
assert alt_filename.read_text(encoding="utf-8") == alt_content

assert (
await write_file(
alt_filename,
alt_content_len,
encoding="utf-8",
file_check_content=encoded_content_len,
)
is True
)
assert alt_filename.read_text(encoding="utf-8") == alt_content_len

0 comments on commit 57f8eb6

Please sign in to comment.