From 18b74732daab446d0912876a73b798755cf0107b Mon Sep 17 00:00:00 2001 From: Acly Date: Mon, 1 Jan 2024 12:14:48 +0100 Subject: [PATCH] Make sure not to overwrite existing images when exporting results --- ai_diffusion/model.py | 1 + ai_diffusion/util.py | 12 ++++++++++++ tests/test_util.py | 18 +++++++++++++++++- 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/ai_diffusion/model.py b/ai_diffusion/model.py index 2bdd8c6e1..f7beeb629 100644 --- a/ai_diffusion/model.py +++ b/ai_diffusion/model.py @@ -500,6 +500,7 @@ def _save_job_result(model: Model, job: Job | None, index: int): prompt = util.sanitize_prompt(job.params.prompt) path = Path(model.document.filename) path = path.parent / f"{path.stem}-generated-{timestamp}-{index}-{prompt}.webp" + path = util.find_unused_path(path) base_image = model._get_current_image(Bounds(0, 0, *model.document.extent)) result_image = job.results[index] base_image.draw_image(result_image, job.params.bounds.offset) diff --git a/ai_diffusion/util.py b/ai_diffusion/util.py index ea838be1a..e8e9cd798 100644 --- a/ai_diffusion/util.py +++ b/ai_diffusion/util.py @@ -64,6 +64,18 @@ def sanitize_prompt(prompt: str): return "".join(c for c in prompt if c.isalnum() or c in " _-") +def find_unused_path(path: Path): + """Finds an unused path by appending a number to the filename""" + if not path.exists(): + return path + stem = path.stem + ext = path.suffix + i = 1 + while (new_path := path.with_name(f"{stem}-{i}{ext}")).exists(): + i += 1 + return new_path + + def get_path_dict(paths: Sequence[str | Path]) -> dict: """Builds a tree like structure out of a list of paths""" diff --git a/tests/test_util.py b/tests/test_util.py index 037689906..c98d1653a 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,7 +1,7 @@ import pytest from tempfile import TemporaryDirectory from pathlib import Path -from ai_diffusion.util import batched, get_path_dict, ZipFile +from ai_diffusion.util import batched, sanitize_prompt, find_unused_path, get_path_dict, ZipFile def test_batched(): @@ -11,6 +11,22 @@ def test_batched(): assert list(batched(iterable, n)) == expected_output +def test_sanitize_prompt(): + assert sanitize_prompt("") == "no prompt" + assert sanitize_prompt("a" * 50) == "a" * 40 + assert sanitize_prompt("bla\nblub\n (neg) [pos]") == "blablubx24 neg pos" + + +def test_unused_path(): + with TemporaryDirectory() as dir: + file = Path(dir) / "test.txt" + assert find_unused_path(file) == file + file.touch() + assert find_unused_path(file) == Path(dir) / "test-1.txt" + (Path(dir) / "test-1.txt").touch() + assert find_unused_path(file) == Path(dir) / "test-2.txt" + + def test_path_dict(): paths = [ "f1.txt",