-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Benjamin Lefaudeux
committed
Nov 12, 2024
1 parent
99d5af7
commit 76b0913
Showing
11 changed files
with
405 additions
and
93 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,326 @@ | ||
from datago import datago | ||
import pytest | ||
import os | ||
import PIL | ||
|
||
|
||
def get_test_source():
    """Return the test DB source name from the DATAROOM_TEST_SOURCE env var (None if unset)."""
    return os.environ.get("DATAROOM_TEST_SOURCE")
|
||
|
||
def get_dataset(client_config):
    """Build a datago client from *client_config* and wrap it in an iterable.

    The returned object's iterator yields ``client.GetSample()`` results
    indefinitely (it never raises StopIteration itself).
    """
    go_client = datago.GetClient(client_config)

    class _SampleStream:
        """Minimal iterator adapter over a datago client."""

        def __init__(self, client):
            self.client = client

        def __iter__(self):
            return self

        def __next__(self):
            # Each step pulls one fresh sample from the Go client.
            return self.client.GetSample()

    return _SampleStream(go_client)
|
||
|
||
def test_get_sample_db():
    """Smoke test: build a DB-backed client and fetch a single sample.

    Check that we can instantiate a client and get a sample, nothing more.
    """
    config = datago.GetDatagoConfig()
    config.SamplesBufferSize = 10

    db_source = datago.GetSourceDBConfig()
    db_source.Sources = get_test_source()
    config.SourceConfig = db_source

    client = datago.GetClient(config)
    sample = client.GetSample()
    assert sample.ID != ""
|
||
|
||
N_SAMPLES = 3 | ||
|
||
|
||
def test_sample_shapes():
    """Check that masks, latents and captions round-trip with consistent shapes.

    Requests masks, crop latents and two captions, then verifies that the
    latent spatial dims match the decoded image sizes (x8 VAE downscale).
    """
    client_config = datago.GetDatagoConfig()
    client_config.SamplesBufferSize = 10

    source_config = datago.GetSourceDBConfig()
    source_config.Sources = get_test_source()
    source_config.HasMasks = "segmentation_mask"
    source_config.HasLatents = "image_crop,masked_image_crop"
    source_config.HasAttributes = "caption_coca,caption_cogvlm"
    # BUG FIX: the source config was built but never attached to the client
    # config (cf. test_get_sample_db), so the requested masks/latents/
    # attributes were silently ignored.
    client_config.SourceConfig = source_config
    dataset = get_dataset(client_config)

    # NOTE(review): the dataset iterator never raises StopIteration on its
    # own, so this loop terminates only if the client does — confirm intended.
    for sample in dataset:
        assert len(sample["caption_coca"]) != len(
            sample["caption_cogvlm"]
        ), "Caption lengths should not be equal"

        # Masked-crop latent: spatial dims are 1/8th of the decoded image.
        latent_chw_masked = sample["latent_masked_image_crop_SDv2_896"].shape
        assert (
            len(latent_chw_masked) > 1
        ), "Latent of cropped masked image should be multi-dimensional"
        if len(latent_chw_masked) > 1:
            w, h = sample["masked_image_crop"].size
            assert (w, h) == (
                latent_chw_masked[2] * 8,
                latent_chw_masked[1] * 8,
            ), "Masked Image latent / masked image size mismatch"

        # Plain-crop latent: same x8 relationship.
        latent_chw = sample["latent_image_crop_SDv2_896"].shape
        assert len(latent_chw) > 1, "Latent cropped image should be multi-dimensional"
        if len(latent_chw) > 1:
            w, h = sample["image_crop"].size
            assert (w, h) == (
                latent_chw[2] * 8,
                latent_chw[1] * 8,
            ), "Image latent / image size mismatch"

        assert (
            len(sample["latent_image_crop_SDv2_896"]) > 1
        ), "Latent image should be multi-dimensional"
        assert sample["segmentation_mask"].mode == "L", "Mask should be single channel"
|
||
|
||
def no_test_caption_and_image():
    # NOTE(review): disabled test (the ``no_`` prefix keeps pytest from
    # collecting it). It calls ``get_dataset()`` without the required
    # ``client_config`` argument and references ``DataRoomReader``, which is
    # not defined in this file — presumably a partial port from the legacy
    # reader. TODO: port to the datago client or delete.
    dataset = get_dataset()
    dataset = DataRoomReader(
        num_samples=N_SAMPLES,
        sources="COYO",
        lacks_attributes=[],
        has_attributes=["caption_coca", "caption_cogvlm"],
        has_latents=["latent_caption_cogvlm_t5xxl"],
        has_masks=["segmentation_mask"],
    )

    # Assert that *key* exists in the sample and holds a non-None value.
    def check_key_and_value(elem, key):
        assert key in elem, f"{key} not found in sample"
        if key in elem:
            assert elem[key] is not None, f"{key} is None"

    for sample in dataset:
        check_key_and_value(sample, "image")
        check_key_and_value(sample, "caption_cogvlm")
        check_key_and_value(sample, "latent_caption_cogvlm_t5xxl")
        del sample
|
||
|
||
def no_test_image_resize():
    # NOTE(review): disabled test (``no_`` prefix). ``DataRoomReader`` and
    # ``crop_and_resize`` are not defined in this file — presumably left over
    # from the legacy reader. TODO: port to the datago client or delete.
    dataset = DataRoomReader(
        num_samples=N_SAMPLES,
        sources="COYO",
        has_masks=["segmentation_mask"],
        crop_and_resize=crop_and_resize,
    )
    sample = next(iter(dataset))

    # Now check that all the images have the same size
    ref_image = sample["image"]
    for k, v in sample.items():
        if isinstance(v, PIL.Image.Image):
            assert (
                v.size == ref_image.size
            ), f"Image edges should be {ref_image.size} - {k}"
|
||
|
||
def test_filtering():
    """Requesting a specific latent or mask surfaces it in the sample.

    Check that asking for this or that latent works.
    """
    reader = DataRoomReader(
        num_samples=N_SAMPLES,
        sources="COYO",
        has_latents=["latent_caption_cogvlm_t5xxl"],
    )

    first = next(iter(reader))
    assert "latent_caption_cogvlm_t5xxl" in first.keys(), "Latent not found"
    assert "caption_coca" in first.keys(), "attribute not found"
    assert "caption_cogvlm" in first.keys(), "attribute not found"

    # Masks
    reader = DataRoomReader(
        num_samples=N_SAMPLES,
        sources="COYO",
        has_masks=["segmentation_mask"],
    )
    first = next(iter(reader))

    assert "segmentation_mask" in first.keys(), "Mask not found"
    assert isinstance(
        first["segmentation_mask"], PIL.Image.Image
    ), "Mask not correctly decoded"
|
||
|
||
def test_multiple_ranks():
    """Rank-aware readers over one source must serve disjoint sample sets.

    Check that there's no collision in the samples that the DataroomReader
    serves, if rank aware.
    """
    world_size = 2
    sources = "FREEPIK"
    num_samples = 2048  # Make sure we go over the whole source

    # Pull all the samples for one rank and collect their IDs.
    def collect_ids(rank):
        reader = DataRoomReader(
            num_samples=num_samples,
            sources=sources,
            requires_image=False,
            rank=rank,
            world_size=world_size,
        )

        ids = [sample["dataroom_id"] for sample in reader]
        del reader
        return ids

    per_rank = [collect_ids(rank) for rank in range(world_size)]

    # Check that the set intersection of adjacent ranks is null
    for left, right in zip(per_rank, per_rank[1:]):
        assert (
            len(set(left).intersection(set(right))) == 0
        ), "There are collisions in the samples"
|
||
|
||
def test_jpg_compression():
    # Check that the images are compressed as expected
    # NOTE(review): ``DataRoomReader`` and ``crop_and_resize`` are not defined
    # in this file — this test will fail with NameError until it is ported to
    # the datago client (cf. the ``no_test_*`` legacy tests above).
    dataset = DataRoomReader(
        num_samples=1,
        sources="COYO",
        has_masks=["segmentation_mask"],
        has_latents=["latent_caption_cogvlm_t5xxl", "masked_image"],
        crop_and_resize=crop_and_resize,
        pre_encode_images=True,
    )

    sample = next(iter(dataset))

    # Decode pre-encoded image bytes; ``.format`` reveals the codec used.
    def decode_image(image_bytes):
        return PIL.Image.open(image_bytes)

    assert decode_image(sample["image"]).format == "JPEG", "Image should be JPEG"
    assert decode_image(sample["masked_image"]).format == "JPEG", "Image should be JPEG"
    # Masks must stay lossless, hence PNG rather than JPEG.
    assert (
        decode_image(sample["segmentation_mask"]).format == "PNG"
    ), "Image should be PNG"

    # decode_image(sample["image"]).save("test_image.jpg")
    # decode_image(sample["masked_image"]).save("test_masked_image.jpg")
    # decode_image(sample["segmentation_mask"]).save("test_segmentation_mask.png")
|
||
|
||
def test_has_tags():
    """A sample filtered by tag must carry that tag in its 'tags' field."""
    reader = DataRoomReader(
        num_samples=1,
        sources="COYO",
        tags=["v4_trainset_hq"],
    )

    first = next(iter(reader))

    assert "v4_trainset_hq" in first["tags"]
|
||
|
||
def test_duplicate_state():
    """The requested duplicate_state field is present and one of None/1/2."""
    reader = DataRoomReader(
        num_samples=1,
        fields=["duplicate_state"],
        sources="COYO",
        tags=["v4_trainset_hq"],
    )
    first = next(iter(reader))
    assert "duplicate_state" in list(first.keys()), "Duplicate state not found"
    state = first["duplicate_state"]
    assert state in [None, 1, 2], "Duplicate state should be None, 1 or 2"
|
||
|
||
def test_any_latents():
    """has_latents="any" returns a sample carrying every available latent.

    Pinned to a known sample ID so any upstream data change is caught.
    """
    dataset = DataRoomReader(
        num_samples=1,
        sources="DataRoomWriter_tests",
        has_latents="any",
        max_short_edge=640,
    )
    sample = next(iter(dataset))
    # BUG FIX: removed a leftover debug ``print("",)`` that only emitted a
    # blank line.
    assert (
        sample["dataroom_id"]
        == "29ffdfc1c18cfb46f323a6226cdece573a14be8e16d4e6e87cc39e44161756da"
    ), "This test needs to be updated if the sample changes"
    assert "latent_caption_coca_t5xxl" in sample.keys(), "Latent not found"
    assert "latent_caption_cogvlm_t5xxl" in sample.keys(), "Latent not found"
|
||
|
||
def test_random_sampling():
    """Two independent random readers should not serve identical sample sets."""
    reader_a = DataRoomReader(
        num_samples=10,
        sources="COYO",
        random_sampling=True,
    )
    reader_b = DataRoomReader(
        num_samples=10,
        sources="COYO",
        random_sampling=True,
    )
    ids_a = [sample["dataroom_id"] for sample in reader_a]
    ids_b = [sample["dataroom_id"] for sample in reader_b]
    assert set(ids_a) != set(
        ids_b
    ), "Two random samples should not be identical"
|
||
|
||
def test_pixel_count_filtering():
    """Every served sample must fall inside the requested pixel-count window."""
    min_pixel_count = 1800000
    max_pixel_count = 2000000
    dataset = DataRoomReader(
        num_samples=10,
        sources="SOCIAL_ADS",
        min_pixel_count=min_pixel_count,
        max_pixel_count=max_pixel_count,
    )
    # width * height per sample.
    pixel_counts = [
        sample["image"].size[0] * sample["image"].size[1] for sample in dataset
    ]
    # BUG FIX: removed a leftover debug ``print(pixel_counts)``; also use a
    # generator expression instead of materializing a list inside all().
    assert all(
        min_pixel_count <= pixel_count <= max_pixel_count
        for pixel_count in pixel_counts
    ), "All samples should have pixel count between min and max"
|
||
|
||
def test_sources__ne_filtering():
    """sources__ne must exclude the named source from the served samples."""
    dataset = DataRoomReader(
        num_samples=10,
        sources=["COYO", "LAION_AESTHETICS"],
        sources__ne="COYO",
    )
    samples = [sample for sample in dataset]
    # BUG FIX: the assertion message was copy-pasted from the pixel-count
    # test and did not describe this check.
    assert all(
        [sample["source"] != "COYO" for sample in samples]
    ), "No sample should come from the excluded COYO source"
|
||
|
||
def test_multiple_sources():
    """Random sampling over two sources must eventually serve both of them."""
    reader = DataRoomReader(
        num_samples=100,
        sources=["COYO", "LAION_AESTHETICS"],
        requires_image=False,
        random_sampling=True,
    )
    seen_sources = {sample["source"] for sample in reader}
    assert seen_sources == {
        "COYO",
        "LAION_AESTHETICS",
    }, "Sources should be COYO and LAION_AESTHETICS"
|
||
|
||
# Allow running this file directly (outside a plain ``pytest`` invocation).
if __name__ == "__main__":
    pytest.main(["-v", __file__])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import os | ||
from PIL import Image | ||
from datago import datago | ||
|
||
|
||
# FIXME: Would need to generate more fake data to test this
def no_test_get_sample_filesystem():
    """Smoke test for the filesystem source: write one PNG, fetch one sample.

    Disabled (``no_`` prefix) until more fake data can be generated.
    """
    cwd = os.getcwd()
    test_image = cwd + "/test.png"

    try:
        # Dump a sample image to the filesystem
        img = Image.new("RGB", (100, 100))
        img.save(test_image)

        # Check that we can instantiate a client and get a sample, nothing more
        client_config = datago.GetDatagoConfig()
        client_config.SourceType = "filesystem"
        client_config.SamplesBufferSize = 1

        source_config = datago.SourceFileSystemConfig()
        source_config.RootPath = cwd
        source_config.PageSize = 1

        # NOTE(review): here the source config is passed as a second argument
        # to GetClient, while the DB tests attach it via
        # client_config.SourceConfig — confirm which signature is current.
        client = datago.GetClient(client_config, source_config)
        data = client.GetSample()
        assert data.ID != ""
    finally:
        # BUG FIX: only remove the file if it was actually written, so a
        # failure in img.save() isn't masked by a FileNotFoundError here.
        if os.path.exists(test_image):
            os.remove(test_image)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters