WIP
Benjamin Lefaudeux committed Nov 12, 2024

Verified: signed with GitHub's verified signature.
1 parent 99d5af7 commit 76b0913
Showing 11 changed files with 405 additions and 93 deletions.
5 changes: 2 additions & 3 deletions cmd/main.go
@@ -12,10 +12,9 @@ import (
 
 func main() {
     // Define flags
-    config := datago.DatagoConfig{}
-    config.SetDefaults()
+    config := datago.GetDatagoConfig()
 
-    sourceConfig := datago.GeneratorFileSystemConfig{RootPath: os.Getenv("DATAROOM_TEST_FILESYSTEM")}
+    sourceConfig := datago.SourceFileSystemConfig{RootPath: os.Getenv("DATAROOM_TEST_FILESYSTEM")}
     sourceConfig.PageSize = 10
     config.ImageConfig = datago.ImageTransformConfig{
         DefaultImageSize: 1024,
28 changes: 17 additions & 11 deletions pkg/client.go
@@ -41,7 +41,7 @@ type ImageTransformConfig struct {
     PreEncodeImages bool `json:"pre_encode_images"`
 }
 
-func (c *ImageTransformConfig) SetDefaults() {
+func (c *ImageTransformConfig) setDefaults() {
     c.DefaultImageSize = 512
     c.DownsamplingRatio = 16
     c.MinAspectRatio = 0.5
@@ -59,17 +59,23 @@ type DatagoConfig struct {
     Concurrency int `json:"concurrency"`
 }
 
-func (c *DatagoConfig) SetDefaults() {
-    dbConfig := GeneratorDBConfig{}
-    dbConfig.SetDefaults()
+func (c *DatagoConfig) setDefaults() {
+    dbConfig := SourceDBConfig{}
+    dbConfig.setDefaults()
     c.SourceConfig = dbConfig
 
-    c.ImageConfig.SetDefaults()
+    c.ImageConfig.setDefaults()
     c.PrefetchBufferSize = 64
     c.SamplesBufferSize = 32
     c.Concurrency = 64
 }
 
+func GetDatagoConfig() DatagoConfig {
+    config := DatagoConfig{}
+    config.setDefaults()
+    return config
+}
+
 func DatagoConfigFromJSON(jsonString string) DatagoConfig {
     config := DatagoConfig{}
     var tempConfig map[string]interface{}
@@ -85,14 +91,14 @@ func DatagoConfigFromJSON(jsonString string) DatagoConfig {
 
     switch tempConfig["source_type"] {
     case string(SourceTypeDB):
-        var dbConfig GeneratorDBConfig
+        var dbConfig SourceDBConfig
         err = json.Unmarshal(sourceConfig, &dbConfig)
         if err != nil {
             log.Panicf("Error unmarshalling DB config: %v", err)
         }
         config.SourceConfig = dbConfig
     case string(SourceTypeFileSystem):
-        var fsConfig GeneratorFileSystemConfig
+        var fsConfig SourceFileSystemConfig
         err = json.Unmarshal(sourceConfig, &fsConfig)
         if err != nil {
             log.Panicf("Error unmarshalling FileSystem config: %v", err)
@@ -157,14 +163,14 @@ func GetClient(config DatagoConfig) *DatagoClient {
     fmt.Println(reflect.TypeOf(config.SourceConfig))
 
     switch config.SourceConfig.(type) {
-    case GeneratorDBConfig:
+    case SourceDBConfig:
         fmt.Println("Creating a DB-backed dataloader")
-        dbConfig := config.SourceConfig.(GeneratorDBConfig)
+        dbConfig := config.SourceConfig.(SourceDBConfig)
         generator = newDatagoGeneratorDB(dbConfig)
         backend = BackendHTTP{config: &dbConfig, concurrency: config.Concurrency}
-    case GeneratorFileSystemConfig:
+    case SourceFileSystemConfig:
         fmt.Println("Creating a FileSystem-backed dataloader")
-        fsConfig := config.SourceConfig.(GeneratorFileSystemConfig)
+        fsConfig := config.SourceConfig.(SourceFileSystemConfig)
         generator = newDatagoGeneratorFileSystem(fsConfig)
         backend = BackendFileSystem{config: &config, concurrency: config.Concurrency}
     default:
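Taken together, these client.go changes move the public API from mutate-after-construction (DatagoConfig{} plus SetDefaults()) to Get* constructors that return configs with defaults pre-applied; setDefaults is now package-private. A minimal sketch of the resulting DB-backed flow — the import path is an assumption, and Start()/GetSample() are the client methods exercised by the tests further down:

    package main

    import (
        "fmt"

        datago "datago/pkg" // assumed import path, adjust to the module layout
    )

    func main() {
        // Defaults are pre-applied, no SetDefaults() call needed anymore
        config := datago.GetDatagoConfig()
        config.SamplesBufferSize = 10

        // DB-backed source; page size, rank, etc. are filled in by setDefaults()
        dbConfig := datago.GetSourceDBConfig()
        dbConfig.Sources = "COYO" // example source name taken from the tests
        config.SourceConfig = dbConfig

        client := datago.GetClient(config)
        client.Start() // optional, warms up the prefetch pipeline early

        sample := client.GetSample()
        fmt.Println(sample.ID)
    }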
16 changes: 11 additions & 5 deletions pkg/generator_db.go
@@ -73,7 +73,7 @@ type dbRequest struct {
 }
 
 // -- Define the front end goroutine ---------------------------------------------------------------
-type GeneratorDBConfig struct {
+type SourceDBConfig struct {
     DataSourceConfig
     Sources       string `json:"sources"`
     RequireImages bool   `json:"require_images"`
@@ -95,7 +95,7 @@ type GeneratorDBConfig struct {
     RandomSampling bool `json:"random_sampling"`
 }
 
-func (c *GeneratorDBConfig) SetDefaults() {
+func (c *SourceDBConfig) setDefaults() {
     c.PageSize = 512
     c.Rank = -1
     c.WorldSize = -1
@@ -120,7 +120,7 @@ func (c *GeneratorDBConfig) SetDefaults() {
     c.RandomSampling = false
 }
 
-func (c *GeneratorDBConfig) getDbRequest() dbRequest {
+func (c *SourceDBConfig) getDbRequest() dbRequest {
 
     fields := "attributes,image_direct_url"
     if len(c.HasLatents) > 0 || len(c.HasMasks) > 0 {
@@ -176,12 +176,18 @@ func (c *GeneratorDBConfig) getDbRequest() dbRequest {
     }
 }
 
+func GetSourceDBConfig() SourceDBConfig {
+    config := SourceDBConfig{}
+    config.setDefaults()
+    return config
+}
+
 type datagoGeneratorDB struct {
     baseRequest http.Request
-    config      GeneratorDBConfig
+    config      SourceDBConfig
 }
 
-func newDatagoGeneratorDB(config GeneratorDBConfig) datagoGeneratorDB {
+func newDatagoGeneratorDB(config SourceDBConfig) datagoGeneratorDB {
     request := config.getDbRequest()
 
     api_key := os.Getenv("DATAROOM_API_KEY")
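GetSourceDBConfig() mirrors GetDatagoConfig(): it returns a fresh SourceDBConfig with the defaults above already applied (PageSize 512, Rank and WorldSize at -1, random sampling off). Since Rank and WorldSize are plain fields, sharded loading is just one config per rank — a sketch under the same import assumption as above, with collision-freedom between ranks covered by TestRanks in tests/client_test.go:

    // One client per rank; ranks then draw from disjoint shards of the source.
    func clientForRank(rank, worldSize int, source string) *datago.DatagoClient {
        config := datago.GetDatagoConfig()

        dbConfig := datago.GetSourceDBConfig()
        dbConfig.Sources = source
        dbConfig.Rank = rank
        dbConfig.WorldSize = worldSize
        config.SourceConfig = dbConfig

        return datago.GetClient(config)
    }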
8 changes: 4 additions & 4 deletions pkg/generator_filesystem.go
@@ -18,12 +18,12 @@ type fsSampleMetadata struct {
 }
 
 // -- Define the front end goroutine ---------------------------------------------------------------
-type GeneratorFileSystemConfig struct {
+type SourceFileSystemConfig struct {
     DataSourceConfig
     RootPath string `json:"root_path"`
 }
 
-func (c *GeneratorFileSystemConfig) SetDefaults() {
+func (c *SourceFileSystemConfig) setDefaults() {
     c.PageSize = 512
     c.Rank = 0
     c.WorldSize = 1
@@ -33,10 +33,10 @@ func (c *GeneratorFileSystemConfig) SetDefaults() {
 
 type datagoGeneratorFileSystem struct {
     extensions set
-    config     GeneratorFileSystemConfig
+    config     SourceFileSystemConfig
 }
 
-func newDatagoGeneratorFileSystem(config GeneratorFileSystemConfig) datagoGeneratorFileSystem {
+func newDatagoGeneratorFileSystem(config SourceFileSystemConfig) datagoGeneratorFileSystem {
     supported_img_extensions := []string{".jpg", ".jpeg", ".png", ".JPEG", ".JPG", ".PNG"}
     var extensionsMap = make(set)
     for _, ext := range supported_img_extensions {
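The filesystem source keeps plain struct-literal construction (this commit adds no GetSourceFileSystemConfig()), so non-default fields are set by hand. A short sketch pieced together from the cmd/main.go hunk above; RootPath and PageSize are the only SourceFileSystemConfig knobs this diff shows:

    // Assumes "os" and the datago package are imported as in cmd/main.go.
    config := datago.GetDatagoConfig()

    fsConfig := datago.SourceFileSystemConfig{RootPath: os.Getenv("DATAROOM_TEST_FILESYSTEM")}
    fsConfig.PageSize = 10
    config.SourceConfig = fsConfig

    // GetClient switches on the concrete SourceConfig type and wires up
    // the filesystem generator and backend accordingly.
    client := datago.GetClient(config)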
2 changes: 1 addition & 1 deletion pkg/serdes.go
@@ -185,7 +185,7 @@ func fetchImage(client *http.Client, url string, retries int, transform *ARAwareTransform,
     return nil, -1., err_report
 }
 
-func fetchSample(config *GeneratorDBConfig, http_client *http.Client, sample_result dbSampleMetadata, transform *ARAwareTransform, pre_encode_image bool) *Sample {
+func fetchSample(config *SourceDBConfig, http_client *http.Client, sample_result dbSampleMetadata, transform *ARAwareTransform, pre_encode_image bool) *Sample {
     // Per sample work:
     // - fetch the raw payloads
     // - deserialize / decode, depending on the types
2 changes: 1 addition & 1 deletion pkg/worker_http.go
@@ -7,7 +7,7 @@ import (
 )
 
 type BackendHTTP struct {
-    config      *GeneratorDBConfig
+    config      *SourceDBConfig
     concurrency int
 }

15 changes: 9 additions & 6 deletions python/benchmark_db.py
@@ -18,21 +18,24 @@ def benchmark(
     test_latents: bool = typer.Option(True, help="Test latents"),
 ):
     print(f"Running benchmark for {source} - {limit} samples")
-    client_config = datago.DatagoConfig()
-    client_config.SetDefaults()
+
+    # Get a generic client config
+    client_config = datago.GetDatagoConfig()
     client_config.ImageConfig.CropAndResize = crop_and_resize
 
-    source_config = datago.GeneratorDBConfig()
-    source_config.SetDefaults()
+    # Specify the source parameters as you see fit
+    source_config = datago.GetSourceDBConfig()
     source_config.Sources = source
     source_config.RequireImages = require_images
     source_config.RequireEmbeddings = require_embeddings
     source_config.HasMasks = "segmentation_mask" if test_masks else ""
     source_config.HasLatents = "caption_latent_t5xxl" if test_latents else ""
-    client_config.SourceConfig = source_config
 
+    # Get a new client instance, happy benchmarking
+    client_config.SourceConfig = source_config
     client = datago.GetClient(client_config)
-    client.Start()
 
+    client.Start()  # Optional, but good practice to start the client to reduce latency to first sample (while you're instantiating models for instance)
     start = time.time()
 
     # Make sure in the following that we compare apples to apples, meaning in that case
54 changes: 0 additions & 54 deletions python/tests/test_datago.py

This file was deleted.

326 changes: 326 additions & 0 deletions python/tests/test_datago_db.py
@@ -0,0 +1,326 @@
from datago import datago
import pytest
import os
import PIL.Image


def get_test_source():
    return os.getenv("DATAROOM_TEST_SOURCE")


def get_dataset(client_config):
    client = datago.GetClient(client_config)

    class Dataset:
        def __init__(self, client):
            self.client = client

        def __iter__(self):
            return self

        def __next__(self):
            return self.client.GetSample()

    return Dataset(client)


def test_get_sample_db():
    # Check that we can instantiate a client and get a sample, nothing more
    client_config = datago.GetDatagoConfig()
    client_config.SamplesBufferSize = 10

    source_config = datago.GetSourceDBConfig()
    source_config.Sources = get_test_source()
    client_config.SourceConfig = source_config

    client = datago.GetClient(client_config)
    data = client.GetSample()
    assert data.ID != ""


N_SAMPLES = 3


def test_sample_shapes():
    client_config = datago.GetDatagoConfig()
    client_config.SamplesBufferSize = 10

    source_config = datago.GetSourceDBConfig()
    source_config.Sources = get_test_source()
    source_config.HasMasks = "segmentation_mask"
    source_config.HasLatents = "image_crop,masked_image_crop"
    source_config.HasAttributes = "caption_coca,caption_cogvlm"
    client_config.SourceConfig = source_config
    dataset = get_dataset(client_config)

    for i, sample in enumerate(dataset):
        if i >= N_SAMPLES:
            break  # the dataset iterator is unbounded, cap the check at N_SAMPLES

        assert len(sample["caption_coca"]) != len(
            sample["caption_cogvlm"]
        ), "Caption lengths should not be equal"

        latent_chw_masked = sample["latent_masked_image_crop_SDv2_896"].shape
        assert (
            len(latent_chw_masked) > 1
        ), "Latent of cropped masked image should be multi-dimensional"
        if len(latent_chw_masked) > 1:
            w, h = sample["masked_image_crop"].size
            assert (w, h) == (
                latent_chw_masked[2] * 8,
                latent_chw_masked[1] * 8,
            ), "Masked image latent / masked image size mismatch"

        latent_chw = sample["latent_image_crop_SDv2_896"].shape
        assert len(latent_chw) > 1, "Latent of cropped image should be multi-dimensional"
        if len(latent_chw) > 1:
            w, h = sample["image_crop"].size
            assert (w, h) == (
                latent_chw[2] * 8,
                latent_chw[1] * 8,
            ), "Image latent / image size mismatch"

        assert (
            len(sample["latent_image_crop_SDv2_896"]) > 1
        ), "Latent image should be multi-dimensional"
        assert sample["segmentation_mask"].mode == "L", "Mask should be single channel"


def no_test_caption_and_image():
    dataset = DataRoomReader(
        num_samples=N_SAMPLES,
        sources="COYO",
        lacks_attributes=[],
        has_attributes=["caption_coca", "caption_cogvlm"],
        has_latents=["latent_caption_cogvlm_t5xxl"],
        has_masks=["segmentation_mask"],
    )

    def check_key_and_value(elem, key):
        assert key in elem, f"{key} not found in sample"
        if key in elem:
            assert elem[key] is not None, f"{key} is None"

    for sample in dataset:
        check_key_and_value(sample, "image")
        check_key_and_value(sample, "caption_cogvlm")
        check_key_and_value(sample, "latent_caption_cogvlm_t5xxl")
        del sample


def no_test_image_resize():
    dataset = DataRoomReader(
        num_samples=N_SAMPLES,
        sources="COYO",
        has_masks=["segmentation_mask"],
        crop_and_resize=crop_and_resize,
    )
    sample = next(iter(dataset))

    # Now check that all the images have the same size
    ref_image = sample["image"]
    for k, v in sample.items():
        if isinstance(v, PIL.Image.Image):
            assert (
                v.size == ref_image.size
            ), f"Image edges should be {ref_image.size} - {k}"


def test_filtering():
    # Check that asking for this or that latent works
    dataset = DataRoomReader(
        num_samples=N_SAMPLES,
        sources="COYO",
        has_latents=["latent_caption_cogvlm_t5xxl"],
    )

    sample = next(iter(dataset))
    assert "latent_caption_cogvlm_t5xxl" in sample.keys(), "Latent not found"
    assert "caption_coca" in sample.keys(), "attribute not found"
    assert "caption_cogvlm" in sample.keys(), "attribute not found"

    # Masks
    dataset = DataRoomReader(
        num_samples=N_SAMPLES,
        sources="COYO",
        has_masks=["segmentation_mask"],
    )
    sample = next(iter(dataset))

    assert "segmentation_mask" in sample.keys(), "Mask not found"
    assert isinstance(
        sample["segmentation_mask"], PIL.Image.Image
    ), "Mask not correctly decoded"


def test_multiple_ranks():
    # Check that there's no collision in the samples that the DataRoomReader serves, if rank aware
    world_size = 2
    sources = "FREEPIK"
    num_samples = 2048  # Make sure we go over the whole source

    # Pull all the samples, build a set per reader
    def get_samples_set(rank):
        reader = DataRoomReader(
            num_samples=num_samples,
            sources=sources,
            requires_image=False,
            rank=rank,
            world_size=world_size,
        )

        sample_ids = [sample["dataroom_id"] for sample in reader]
        del reader
        return sample_ids

    samples = [get_samples_set(rank) for rank in range(world_size)]

    # Check that the set intersection is null
    for i in range(world_size - 1):
        assert (
            len(set(samples[i]).intersection(set(samples[i + 1]))) == 0
        ), "There are collisions in the samples"


def test_jpg_compression():
    # Check that the images are compressed as expected
    dataset = DataRoomReader(
        num_samples=1,
        sources="COYO",
        has_masks=["segmentation_mask"],
        has_latents=["latent_caption_cogvlm_t5xxl", "masked_image"],
        crop_and_resize=crop_and_resize,
        pre_encode_images=True,
    )

    sample = next(iter(dataset))

    def decode_image(image_bytes):
        return PIL.Image.open(image_bytes)

    assert decode_image(sample["image"]).format == "JPEG", "Image should be JPEG"
    assert decode_image(sample["masked_image"]).format == "JPEG", "Image should be JPEG"
    assert (
        decode_image(sample["segmentation_mask"]).format == "PNG"
    ), "Image should be PNG"

    # decode_image(sample["image"]).save("test_image.jpg")
    # decode_image(sample["masked_image"]).save("test_masked_image.jpg")
    # decode_image(sample["segmentation_mask"]).save("test_segmentation_mask.png")


def test_has_tags():
    dataset = DataRoomReader(
        num_samples=1,
        sources="COYO",
        tags=["v4_trainset_hq"],
    )

    sample = next(iter(dataset))

    assert "v4_trainset_hq" in sample["tags"]


def test_duplicate_state():
    dataset = DataRoomReader(
        num_samples=1,
        fields=["duplicate_state"],
        sources="COYO",
        tags=["v4_trainset_hq"],
    )
    sample = next(iter(dataset))
    assert "duplicate_state" in list(sample.keys()), "Duplicate state not found"
    assert sample["duplicate_state"] in [
        None,
        1,
        2,
    ], "Duplicate state should be None, 1 or 2"


def test_any_latents():
    dataset = DataRoomReader(
        num_samples=1,
        sources="DataRoomWriter_tests",
        has_latents="any",
        max_short_edge=640,
    )
    sample = next(iter(dataset))
    assert (
        sample["dataroom_id"]
        == "29ffdfc1c18cfb46f323a6226cdece573a14be8e16d4e6e87cc39e44161756da"
    ), "This test needs to be updated if the sample changes"
    assert "latent_caption_coca_t5xxl" in sample.keys(), "Latent not found"
    assert "latent_caption_cogvlm_t5xxl" in sample.keys(), "Latent not found"


def test_random_sampling():
    dataset_1, dataset_2 = (
        DataRoomReader(
            num_samples=10,
            sources="COYO",
            random_sampling=True,
        ),
        DataRoomReader(
            num_samples=10,
            sources="COYO",
            random_sampling=True,
        ),
    )
    sample_ids_1 = [sample["dataroom_id"] for sample in dataset_1]
    sample_ids_2 = [sample["dataroom_id"] for sample in dataset_2]
    assert set(sample_ids_1) != set(
        sample_ids_2
    ), "Two random samples should not be identical"


def test_pixel_count_filtering():
    min_pixel_count = 1800000
    max_pixel_count = 2000000
    dataset = DataRoomReader(
        num_samples=10,
        sources="SOCIAL_ADS",
        min_pixel_count=min_pixel_count,
        max_pixel_count=max_pixel_count,
    )
    samples = [sample for sample in dataset]
    pixel_counts = [
        sample["image"].size[0] * sample["image"].size[1] for sample in samples
    ]
    print(pixel_counts)
    assert all(
        [
            min_pixel_count <= pixel_count <= max_pixel_count
            for pixel_count in pixel_counts
        ]
    ), "All samples should have pixel count between min and max"


def test_sources__ne_filtering():
    dataset = DataRoomReader(
        num_samples=10,
        sources=["COYO", "LAION_AESTHETICS"],
        sources__ne="COYO",
    )
    samples = [sample for sample in dataset]
    assert all(
        [sample["source"] != "COYO" for sample in samples]
    ), "No sample should come from the excluded COYO source"


def test_multiple_sources():
    dataset = DataRoomReader(
        num_samples=100,
        sources=["COYO", "LAION_AESTHETICS"],
        requires_image=False,
        random_sampling=True,
    )
    sample_sources = set([sample["source"] for sample in dataset])
    assert sample_sources == {
        "COYO",
        "LAION_AESTHETICS",
    }, "Sources should be COYO and LAION_AESTHETICS"


if __name__ == "__main__":
    pytest.main(["-v", __file__])
28 changes: 28 additions & 0 deletions python/tests/test_datago_filesystem.py
@@ -0,0 +1,28 @@
import os
from PIL import Image
from datago import datago


# FIXME: Would need to generate more fake data to test this
def no_test_get_sample_filesystem():
    cwd = os.getcwd()

    try:
        # Dump a sample image to the filesystem
        img = Image.new("RGB", (100, 100))
        img.save(cwd + "/test.png")

        # Check that we can instantiate a client and get a sample, nothing more
        client_config = datago.GetDatagoConfig()
        client_config.SourceType = "filesystem"
        client_config.SamplesBufferSize = 1

        source_config = datago.SourceFileSystemConfig()
        source_config.RootPath = cwd
        source_config.PageSize = 1
        client_config.SourceConfig = source_config

        client = datago.GetClient(client_config)
        data = client.GetSample()
        assert data.ID != ""
    finally:
        os.remove(cwd + "/test.png")
14 changes: 6 additions & 8 deletions tests/client_test.go
@@ -15,15 +15,13 @@ func get_test_source() string {
 }
 
 func get_default_test_config() datago.DatagoConfig {
-    config := datago.DatagoConfig{}
-    config.SetDefaults()
+    config := datago.GetDatagoConfig()
 
-    db_config := datago.GeneratorDBConfig{}
-    db_config.SetDefaults()
+    db_config := datago.GetSourceDBConfig()
     db_config.Sources = get_test_source()
     db_config.PageSize = 32
-    config.SourceConfig = db_config
 
+    config.SourceConfig = db_config
     return config
 }

@@ -104,7 +102,7 @@ func TestExtraFields(t *testing.T) {
     clientConfig := get_default_test_config()
     clientConfig.SamplesBufferSize = 1
 
-    dbConfig := clientConfig.SourceConfig.(datago.GeneratorDBConfig)
+    dbConfig := clientConfig.SourceConfig.(datago.SourceDBConfig)
     dbConfig.HasLatents = "masked_image"
     dbConfig.HasMasks = "segmentation_mask"
     clientConfig.SourceConfig = dbConfig
@@ -174,7 +172,7 @@ func TestImageBufferCompression(t *testing.T) {
     clientConfig.SamplesBufferSize = 1
     clientConfig.ImageConfig.PreEncodeImages = true
 
-    dbConfig := clientConfig.SourceConfig.(datago.GeneratorDBConfig)
+    dbConfig := clientConfig.SourceConfig.(datago.SourceDBConfig)
    dbConfig.HasLatents = "masked_image"
    dbConfig.HasMasks = "segmentation_mask"
    clientConfig.SourceConfig = dbConfig
@@ -250,7 +248,7 @@ func TestRanks(t *testing.T) {
     clientConfig := get_default_test_config()
     clientConfig.SamplesBufferSize = 1
 
-    dbConfig := clientConfig.SourceConfig.(datago.GeneratorDBConfig)
+    dbConfig := clientConfig.SourceConfig.(datago.SourceDBConfig)
     dbConfig.WorldSize = 2
     dbConfig.Rank = 0
     clientConfig.SourceConfig = dbConfig
