Skip to content

Commit

Permalink
Make sure not to overwrite existing images when exporting results
Browse files Browse the repository at this point in the history
  • Loading branch information
Acly committed Jan 1, 2024
1 parent 466886c commit 18b7473
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 1 deletion.
1 change: 1 addition & 0 deletions ai_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,7 @@ def _save_job_result(model: Model, job: Job | None, index: int):
prompt = util.sanitize_prompt(job.params.prompt)
path = Path(model.document.filename)
path = path.parent / f"{path.stem}-generated-{timestamp}-{index}-{prompt}.webp"
path = util.find_unused_path(path)
base_image = model._get_current_image(Bounds(0, 0, *model.document.extent))
result_image = job.results[index]
base_image.draw_image(result_image, job.params.bounds.offset)
Expand Down
12 changes: 12 additions & 0 deletions ai_diffusion/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@ def sanitize_prompt(prompt: str):
return "".join(c for c in prompt if c.isalnum() or c in " _-")


def find_unused_path(path: Path):
"""Finds an unused path by appending a number to the filename"""
if not path.exists():
return path
stem = path.stem
ext = path.suffix
i = 1
while (new_path := path.with_name(f"{stem}-{i}{ext}")).exists():
i += 1
return new_path


def get_path_dict(paths: Sequence[str | Path]) -> dict:
"""Builds a tree like structure out of a list of paths"""

Expand Down
18 changes: 17 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from tempfile import TemporaryDirectory
from pathlib import Path
from ai_diffusion.util import batched, get_path_dict, ZipFile
from ai_diffusion.util import batched, sanitize_prompt, find_unused_path, get_path_dict, ZipFile


def test_batched():
Expand All @@ -11,6 +11,22 @@ def test_batched():
assert list(batched(iterable, n)) == expected_output


def test_sanitize_prompt():
assert sanitize_prompt("") == "no prompt"
assert sanitize_prompt("a" * 50) == "a" * 40
assert sanitize_prompt("bla\nblub\n<x:2:4> (neg) [pos]") == "blablubx24 neg pos"


def test_unused_path():
with TemporaryDirectory() as dir:
file = Path(dir) / "test.txt"
assert find_unused_path(file) == file
file.touch()
assert find_unused_path(file) == Path(dir) / "test-1.txt"
(Path(dir) / "test-1.txt").touch()
assert find_unused_path(file) == Path(dir) / "test-2.txt"


def test_path_dict():
paths = [
"f1.txt",
Expand Down

0 comments on commit 18b7473

Please sign in to comment.