Skip to content

Commit

Permalink
Merge pull request #141 from invoke-ai/fix-masks-in-ui
Browse files Browse the repository at this point in the history
Update ImageCaptionJsonlDataset.save_jsonl() to include the mask column.
  • Loading branch information
RyanJDick authored Jun 5, 2024
2 parents 734a442 + c01449b commit 70d8e1c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,13 @@ def __init__(
def save_jsonl(self):
data = []
for example in self.examples:
data.append({self._image_column: example.image_path, self._caption_column: example.caption})
data.append(
{
self._image_column: example.image_path,
self._caption_column: example.caption,
MASK_COLUMN_DEFAULT: example.mask_path,
}
)
save_jsonl(data, self._jsonl_path)

def _get_image_path(self, idx: int) -> str:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import shutil
from pathlib import Path

import PIL.Image

from invoke_training._shared.data.datasets.image_caption_jsonl_dataset import ImageCaptionJsonlDataset
from invoke_training._shared.utils.jsonl import load_jsonl

from ..dataset_fixtures import image_caption_jsonl # noqa: F401

Expand Down Expand Up @@ -52,3 +56,21 @@ def test_image_caption_jsonl_dataset_get_image_dimensions(image_caption_jsonl):
image_dims = dataset.get_image_dimensions()

assert len(image_dims) == len(dataset)


def test_image_caption_jsonl_dataset_save_jsonl(image_caption_jsonl, tmp_path: Path): # noqa: F811
# Create a copy of the image_caption_jsonl file to avoid modifying the original file.
image_caption_jsonl_copy = tmp_path / "test.jsonl"
shutil.copy(image_caption_jsonl, image_caption_jsonl_copy)

# Load the dataset from the copied jsonl file.
dataset = ImageCaptionJsonlDataset(str(image_caption_jsonl))

# Save the dataset to a new jsonl file.
dataset.save_jsonl()

# Verify that the roundtrip was successful.
assert image_caption_jsonl != image_caption_jsonl_copy
original_jsonl = load_jsonl(image_caption_jsonl)
roundtrip_jsonl = load_jsonl(image_caption_jsonl_copy)
assert original_jsonl == roundtrip_jsonl

0 comments on commit 70d8e1c

Please sign in to comment.