WIP
Benjamin Lefaudeux committed Nov 12, 2024
1 parent 99d5af7 commit 76b0913
Showing 11 changed files with 405 additions and 93 deletions.
5 changes: 2 additions & 3 deletions cmd/main.go
@@ -12,10 +12,9 @@ import (

func main() {
    // Define flags
-    config := datago.DatagoConfig{}
-    config.SetDefaults()
+    config := datago.GetDatagoConfig()

-    sourceConfig := datago.GeneratorFileSystemConfig{RootPath: os.Getenv("DATAROOM_TEST_FILESYSTEM")}
+    sourceConfig := datago.SourceFileSystemConfig{RootPath: os.Getenv("DATAROOM_TEST_FILESYSTEM")}
    sourceConfig.PageSize = 10
    config.ImageConfig = datago.ImageTransformConfig{
        DefaultImageSize: 1024,
28 changes: 17 additions & 11 deletions pkg/client.go
@@ -41,7 +41,7 @@ type ImageTransformConfig struct {
    PreEncodeImages bool `json:"pre_encode_images"`
}

-func (c *ImageTransformConfig) SetDefaults() {
+func (c *ImageTransformConfig) setDefaults() {
    c.DefaultImageSize = 512
    c.DownsamplingRatio = 16
    c.MinAspectRatio = 0.5
@@ -59,17 +59,23 @@ type DatagoConfig struct {
    Concurrency int `json:"concurrency"`
}

-func (c *DatagoConfig) SetDefaults() {
-    dbConfig := GeneratorDBConfig{}
-    dbConfig.SetDefaults()
+func (c *DatagoConfig) setDefaults() {
+    dbConfig := SourceDBConfig{}
+    dbConfig.setDefaults()
    c.SourceConfig = dbConfig

-    c.ImageConfig.SetDefaults()
+    c.ImageConfig.setDefaults()
    c.PrefetchBufferSize = 64
    c.SamplesBufferSize = 32
    c.Concurrency = 64
}

+func GetDatagoConfig() DatagoConfig {
+    config := DatagoConfig{}
+    config.setDefaults()
+    return config
+}
+
func DatagoConfigFromJSON(jsonString string) DatagoConfig {
    config := DatagoConfig{}
    var tempConfig map[string]interface{}
@@ -85,14 +91,14 @@ func DatagoConfigFromJSON(jsonString string) DatagoConfig {

    switch tempConfig["source_type"] {
    case string(SourceTypeDB):
-        var dbConfig GeneratorDBConfig
+        var dbConfig SourceDBConfig
        err = json.Unmarshal(sourceConfig, &dbConfig)
        if err != nil {
            log.Panicf("Error unmarshalling DB config: %v", err)
        }
        config.SourceConfig = dbConfig
    case string(SourceTypeFileSystem):
-        var fsConfig GeneratorFileSystemConfig
+        var fsConfig SourceFileSystemConfig
        err = json.Unmarshal(sourceConfig, &fsConfig)
        if err != nil {
            log.Panicf("Error unmarshalling FileSystem config: %v", err)
@@ -157,14 +163,14 @@ func GetClient(config DatagoConfig) *DatagoClient {
    fmt.Println(reflect.TypeOf(config.SourceConfig))

    switch config.SourceConfig.(type) {
-    case GeneratorDBConfig:
+    case SourceDBConfig:
        fmt.Println("Creating a DB-backed dataloader")
-        dbConfig := config.SourceConfig.(GeneratorDBConfig)
+        dbConfig := config.SourceConfig.(SourceDBConfig)
        generator = newDatagoGeneratorDB(dbConfig)
        backend = BackendHTTP{config: &dbConfig, concurrency: config.Concurrency}
-    case GeneratorFileSystemConfig:
+    case SourceFileSystemConfig:
        fmt.Println("Creating a FileSystem-backed dataloader")
-        fsConfig := config.SourceConfig.(GeneratorFileSystemConfig)
+        fsConfig := config.SourceConfig.(SourceFileSystemConfig)
        generator = newDatagoGeneratorFileSystem(fsConfig)
        backend = BackendFileSystem{config: &config, concurrency: config.Concurrency}
    default:
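
Taken together, the client.go changes replace the exported SetDefaults() mutators with Get*Config() constructors, so callers can no longer obtain a config with its defaults unpopulated. A minimal sketch of the resulting calling pattern, not part of the diff — the module path is hypothetical (the diff only shows the pkg directory), and "COYO" is just a source name borrowed from the tests below:

package main

import (
    "fmt"

    datago "example.com/datago/pkg" // hypothetical module path, adjust to the real one
)

func main() {
    config := datago.GetDatagoConfig() // replaces DatagoConfig{} + SetDefaults()

    sourceConfig := datago.GetSourceDBConfig() // replaces GeneratorDBConfig{} + SetDefaults()
    sourceConfig.Sources = "COYO"              // example source name, taken from the tests
    config.SourceConfig = sourceConfig

    client := datago.GetClient(config)
    client.Start() // optional, reduces the latency to the first sample
    fmt.Println(client.GetSample().ID)
}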
16 changes: 11 additions & 5 deletions pkg/generator_db.go
@@ -73,7 +73,7 @@ type dbRequest struct {
}

// -- Define the front end goroutine ---------------------------------------------------------------------------------------------------------------------------------------------------------------
-type GeneratorDBConfig struct {
+type SourceDBConfig struct {
    DataSourceConfig
    Sources       string `json:"sources"`
    RequireImages bool   `json:"require_images"`
@@ -95,7 +95,7 @@ type GeneratorDBConfig struct {
    RandomSampling bool `json:"random_sampling"`
}

-func (c *GeneratorDBConfig) SetDefaults() {
+func (c *SourceDBConfig) setDefaults() {
    c.PageSize = 512
    c.Rank = -1
    c.WorldSize = -1
@@ -120,7 +120,7 @@ func (c *GeneratorDBConfig) SetDefaults() {
    c.RandomSampling = false
}

-func (c *GeneratorDBConfig) getDbRequest() dbRequest {
+func (c *SourceDBConfig) getDbRequest() dbRequest {

    fields := "attributes,image_direct_url"
    if len(c.HasLatents) > 0 || len(c.HasMasks) > 0 {
@@ -176,12 +176,18 @@ func (c *GeneratorDBConfig) getDbRequest() dbRequest {
    }
}

+func GetSourceDBConfig() SourceDBConfig {
+    config := SourceDBConfig{}
+    config.setDefaults()
+    return config
+}
+
type datagoGeneratorDB struct {
    baseRequest http.Request
-    config      GeneratorDBConfig
+    config      SourceDBConfig
}

-func newDatagoGeneratorDB(config GeneratorDBConfig) datagoGeneratorDB {
+func newDatagoGeneratorDB(config SourceDBConfig) datagoGeneratorDB {
    request := config.getDbRequest()

    api_key := os.Getenv("DATAROOM_API_KEY")
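
For reference, a sketch (not part of the diff) of how the new GetSourceDBConfig() constructor combines with the filter fields of SourceDBConfig; the defaults are the ones set by setDefaults() above, and the filter values are borrowed from the Python tests further down:

// Sketch only: build a DB source that requires images plus specific masks and latents.
func maskedSource(source string) SourceDBConfig {
    config := GetSourceDBConfig() // PageSize=512, Rank=-1, WorldSize=-1, RandomSampling=false
    config.Sources = source
    config.RequireImages = true
    config.HasMasks = "segmentation_mask"              // comma-separated field names
    config.HasLatents = "image_crop,masked_image_crop" // likewise comma-separated
    return config
}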
8 changes: 4 additions & 4 deletions pkg/generator_filesystem.go
@@ -18,12 +18,12 @@ type fsSampleMetadata struct {
}

// -- Define the front end goroutine ---------------------------------------------------------------------------------------------------------------------------------------------------------------
-type GeneratorFileSystemConfig struct {
+type SourceFileSystemConfig struct {
    DataSourceConfig
    RootPath string `json:"root_path"`
}

-func (c *GeneratorFileSystemConfig) SetDefaults() {
+func (c *SourceFileSystemConfig) setDefaults() {
    c.PageSize = 512
    c.Rank = 0
    c.WorldSize = 1
@@ -33,10 +33,10 @@ func (c *GeneratorFileSystemConfig) SetDefaults() {

type datagoGeneratorFileSystem struct {
    extensions set
-    config     GeneratorFileSystemConfig
+    config     SourceFileSystemConfig
}

-func newDatagoGeneratorFileSystem(config GeneratorFileSystemConfig) datagoGeneratorFileSystem {
+func newDatagoGeneratorFileSystem(config SourceFileSystemConfig) datagoGeneratorFileSystem {
    supported_img_extensions := []string{".jpg", ".jpeg", ".png", ".JPEG", ".JPG", ".PNG"}
    var extensionsMap = make(set)
    for _, ext := range supported_img_extensions {
2 changes: 1 addition & 1 deletion pkg/serdes.go
@@ -185,7 +185,7 @@ func fetchImage(client *http.Client, url string, retries int, transform *ARAware
    return nil, -1., err_report
}

-func fetchSample(config *GeneratorDBConfig, http_client *http.Client, sample_result dbSampleMetadata, transform *ARAwareTransform, pre_encode_image bool) *Sample {
+func fetchSample(config *SourceDBConfig, http_client *http.Client, sample_result dbSampleMetadata, transform *ARAwareTransform, pre_encode_image bool) *Sample {
    // Per sample work:
    // - fetch the raw payloads
    // - deserialize / decode, depending on the types
2 changes: 1 addition & 1 deletion pkg/worker_http.go
@@ -7,7 +7,7 @@ import (
)

type BackendHTTP struct {
-    config      *GeneratorDBConfig
+    config      *SourceDBConfig
    concurrency int
}

15 changes: 9 additions & 6 deletions python/benchmark_db.py
@@ -18,21 +18,24 @@ def benchmark(
    test_latents: bool = typer.Option(True, help="Test latents"),
):
    print(f"Running benchmark for {source} - {limit} samples")
-    client_config = datago.DatagoConfig()
-    client_config.SetDefaults()
+
+    # Get a generic client config
+    client_config = datago.GetDatagoConfig()
    client_config.ImageConfig.CropAndResize = crop_and_resize

-    source_config = datago.GeneratorDBConfig()
-    source_config.SetDefaults()
+    # Specify the source parameters as you see fit
+    source_config = datago.GetSourceDBConfig()
    source_config.Sources = source
    source_config.RequireImages = require_images
    source_config.RequireEmbeddings = require_embeddings
    source_config.HasMasks = "segmentation_mask" if test_masks else ""
    source_config.HasLatents = "caption_latent_t5xxl" if test_latents else ""
-    client_config.SourceConfig = source_config

+    # Get a new client instance, happy benchmarking
+    client_config.SourceConfig = source_config
    client = datago.GetClient(client_config)
-    client.Start()
+
+    client.Start()  # Optional, but good practice to start the client to reduce latency to first sample (while you're instantiating models for instance)
    start = time.time()

    # Make sure in the following that we compare apples to apples, meaning in that case
54 changes: 0 additions & 54 deletions python/tests/test_datago.py

This file was deleted.

326 changes: 326 additions & 0 deletions python/tests/test_datago_db.py
@@ -0,0 +1,326 @@
from datago import datago
import pytest
import os
import PIL.Image  # import the submodule explicitly, `import PIL` alone does not expose PIL.Image


def get_test_source():
    return os.getenv("DATAROOM_TEST_SOURCE")


def get_dataset(client_config):
    client = datago.GetClient(client_config)

    class Dataset:
        def __init__(self, client):
            self.client = client

        def __iter__(self):
            return self

        def __next__(self):
            return self.client.GetSample()

    return Dataset(client)


def test_get_sample_db():
    # Check that we can instantiate a client and get a sample, nothing more
    client_config = datago.GetDatagoConfig()
    client_config.SamplesBufferSize = 10

    source_config = datago.GetSourceDBConfig()
    source_config.Sources = get_test_source()
    client_config.SourceConfig = source_config

    client = datago.GetClient(client_config)
    data = client.GetSample()
    assert data.ID != ""


N_SAMPLES = 3


def test_sample_shapes():
    client_config = datago.GetDatagoConfig()
    client_config.SamplesBufferSize = 10

    source_config = datago.GetSourceDBConfig()
    source_config.Sources = get_test_source()
    source_config.HasMasks = "segmentation_mask"
    source_config.HasLatents = "image_crop,masked_image_crop"
    source_config.HasAttributes = "caption_coca,caption_cogvlm"
    client_config.SourceConfig = source_config  # was missing, the filters above were silently dropped
    dataset = get_dataset(client_config)

    # Bound the loop explicitly, the raw iterator never raises StopIteration
    for _, sample in zip(range(N_SAMPLES), dataset):
        assert len(sample["caption_coca"]) != len(
            sample["caption_cogvlm"]
        ), "Caption lengths should not be equal"

        latent_chw_masked = sample["latent_masked_image_crop_SDv2_896"].shape
        assert (
            len(latent_chw_masked) > 1
        ), "Latent of cropped masked image should be multi-dimensional"
        if len(latent_chw_masked) > 1:
            w, h = sample["masked_image_crop"].size
            assert (w, h) == (
                latent_chw_masked[2] * 8,
                latent_chw_masked[1] * 8,
            ), "Masked Image latent / masked image size mismatch"

        latent_chw = sample["latent_image_crop_SDv2_896"].shape
        assert len(latent_chw) > 1, "Latent cropped image should be multi-dimensional"
        if len(latent_chw) > 1:
            w, h = sample["image_crop"].size
            assert (w, h) == (
                latent_chw[2] * 8,
                latent_chw[1] * 8,
            ), "Image latent / image size mismatch"

        assert (
            len(sample["latent_image_crop_SDv2_896"]) > 1
        ), "Latent image should be multi-dimensional"
        assert sample["segmentation_mask"].mode == "L", "Mask should be single channel"


# NOTE: DataRoomReader is not defined in this file; the tests below still target
# the previous reader API (the deleted test_datago.py) and have not been ported yet.
def no_test_caption_and_image():
    dataset = DataRoomReader(
        num_samples=N_SAMPLES,
        sources="COYO",
        lacks_attributes=[],
        has_attributes=["caption_coca", "caption_cogvlm"],
        has_latents=["latent_caption_cogvlm_t5xxl"],
        has_masks=["segmentation_mask"],
    )

    def check_key_and_value(elem, key):
        assert key in elem, f"{key} not found in sample"
        if key in elem:
            assert elem[key] is not None, f"{key} is None"

    for sample in dataset:
        check_key_and_value(sample, "image")
        check_key_and_value(sample, "caption_cogvlm")
        check_key_and_value(sample, "latent_caption_cogvlm_t5xxl")
        del sample


def no_test_image_resize():
    dataset = DataRoomReader(
        num_samples=N_SAMPLES,
        sources="COYO",
        has_masks=["segmentation_mask"],
        crop_and_resize=crop_and_resize,
    )
    sample = next(iter(dataset))

    # Now check that all the images have the same size
    ref_image = sample["image"]
    for k, v in sample.items():
        if isinstance(v, PIL.Image.Image):
            assert (
                v.size == ref_image.size
            ), f"Image edges should be {ref_image.size} - {k}"


def test_filtering():
    # Check that asking for this or that latent works
    dataset = DataRoomReader(
        num_samples=N_SAMPLES,
        sources="COYO",
        has_latents=["latent_caption_cogvlm_t5xxl"],
    )

    sample = next(iter(dataset))
    assert "latent_caption_cogvlm_t5xxl" in sample.keys(), "Latent not found"
    assert "caption_coca" in sample.keys(), "attribute not found"
    assert "caption_cogvlm" in sample.keys(), "attribute not found"

    # Masks
    dataset = DataRoomReader(
        num_samples=N_SAMPLES,
        sources="COYO",
        has_masks=["segmentation_mask"],
    )
    sample = next(iter(dataset))

    assert "segmentation_mask" in sample.keys(), "Mask not found"
    assert isinstance(
        sample["segmentation_mask"], PIL.Image.Image
    ), "Mask not correctly decoded"


def test_multiple_ranks():
    # Check that there's no collision in the samples that the DataroomReader serves, if rank aware
    world_size = 2
    sources = "FREEPIK"
    num_samples = 2048  # Make sure we go over the whole source

    # Pull all the samples, build a set per reader
    def get_samples_set(rank):
        reader = DataRoomReader(
            num_samples=num_samples,
            sources=sources,
            requires_image=False,
            rank=rank,
            world_size=world_size,
        )

        sample_ids = [sample["dataroom_id"] for sample in reader]
        del reader
        return sample_ids

    samples = [get_samples_set(rank) for rank in range(world_size)]

    # Check that the set intersection is null
    for i in range(world_size - 1):
        assert (
            len(set(samples[i]).intersection(set(samples[i + 1]))) == 0
        ), "There are collisions in the samples"


def test_jpg_compression():
    # Check that the images are compressed as expected
    dataset = DataRoomReader(
        num_samples=1,
        sources="COYO",
        has_masks=["segmentation_mask"],
        has_latents=["latent_caption_cogvlm_t5xxl", "masked_image"],
        crop_and_resize=crop_and_resize,
        pre_encode_images=True,
    )

    sample = next(iter(dataset))

    def decode_image(image_bytes):
        return PIL.Image.open(image_bytes)

    assert decode_image(sample["image"]).format == "JPEG", "Image should be JPEG"
    assert decode_image(sample["masked_image"]).format == "JPEG", "Image should be JPEG"
    assert (
        decode_image(sample["segmentation_mask"]).format == "PNG"
    ), "Image should be PNG"

    # decode_image(sample["image"]).save("test_image.jpg")
    # decode_image(sample["masked_image"]).save("test_masked_image.jpg")
    # decode_image(sample["segmentation_mask"]).save("test_segmentation_mask.png")


def test_has_tags():
    dataset = DataRoomReader(
        num_samples=1,
        sources="COYO",
        tags=["v4_trainset_hq"],
    )

    sample = next(iter(dataset))

    assert "v4_trainset_hq" in sample["tags"]


def test_duplicate_state():
    dataset = DataRoomReader(
        num_samples=1,
        fields=["duplicate_state"],
        sources="COYO",
        tags=["v4_trainset_hq"],
    )
    sample = next(iter(dataset))
    assert "duplicate_state" in list(sample.keys()), "Duplicate state not found"
    assert sample["duplicate_state"] in [
        None,
        1,
        2,
    ], "Duplicate state should be None, 1 or 2"


def test_any_latents():
    dataset = DataRoomReader(
        num_samples=1,
        sources="DataRoomWriter_tests",
        has_latents="any",
        max_short_edge=640,
    )
    sample = next(iter(dataset))
    assert (
        sample["dataroom_id"]
        == "29ffdfc1c18cfb46f323a6226cdece573a14be8e16d4e6e87cc39e44161756da"
    ), "This test needs to be updated if the sample changes"
    assert "latent_caption_coca_t5xxl" in sample.keys(), "Latent not found"
    assert "latent_caption_cogvlm_t5xxl" in sample.keys(), "Latent not found"


def test_random_sampling():
    dataset_1, dataset_2 = (
        DataRoomReader(
            num_samples=10,
            sources="COYO",
            random_sampling=True,
        ),
        DataRoomReader(
            num_samples=10,
            sources="COYO",
            random_sampling=True,
        ),
    )
    sample_ids_1 = [sample["dataroom_id"] for sample in dataset_1]
    sample_ids_2 = [sample["dataroom_id"] for sample in dataset_2]
    assert set(sample_ids_1) != set(
        sample_ids_2
    ), "Two random samples should not be identical"


def test_pixel_count_filtering():
    min_pixel_count = 1800000
    max_pixel_count = 2000000
    dataset = DataRoomReader(
        num_samples=10,
        sources="SOCIAL_ADS",
        min_pixel_count=min_pixel_count,
        max_pixel_count=max_pixel_count,
    )
    samples = [sample for sample in dataset]
    pixel_counts = [
        sample["image"].size[0] * sample["image"].size[1] for sample in samples
    ]
    print(pixel_counts)
    assert all(
        [
            min_pixel_count <= pixel_count <= max_pixel_count
            for pixel_count in pixel_counts
        ]
    ), "All samples should have pixel count between min and max"


def test_sources__ne_filtering():
    dataset = DataRoomReader(
        num_samples=10,
        sources=["COYO", "LAION_AESTHETICS"],
        sources__ne="COYO",
    )
    samples = [sample for sample in dataset]
    assert all(
        [sample["source"] != "COYO" for sample in samples]
    ), "No sample should come from the excluded COYO source"


def test_multiple_sources():
    dataset = DataRoomReader(
        num_samples=100,
        sources=["COYO", "LAION_AESTHETICS"],
        requires_image=False,
        random_sampling=True,
    )
    sample_sources = set([sample["source"] for sample in dataset])
    assert sample_sources == {
        "COYO",
        "LAION_AESTHETICS",
    }, "Sources should be COYO and LAION_AESTHETICS"


if __name__ == "__main__":
    pytest.main(["-v", __file__])
28 changes: 28 additions & 0 deletions python/tests/test_datago_filesystem.py
@@ -0,0 +1,28 @@
import os
from PIL import Image
from datago import datago


# FIXME: Would need to generate more fake data to test this
def no_test_get_sample_filesystem():
    cwd = os.getcwd()

    try:
        # Dump a sample image to the filesystem
        img = Image.new("RGB", (100, 100))
        img.save(cwd + "/test.png")

        # Check that we can instantiate a client and get a sample, nothing more
        client_config = datago.GetDatagoConfig()
        client_config.SourceType = "filesystem"
        client_config.SamplesBufferSize = 1

        source_config = datago.SourceFileSystemConfig()
        source_config.RootPath = cwd
        source_config.PageSize = 1

        # GetClient takes a single config, the source goes in through SourceConfig
        client_config.SourceConfig = source_config
        client = datago.GetClient(client_config)
        data = client.GetSample()
        assert data.ID != ""
    finally:
        os.remove(cwd + "/test.png")
14 changes: 6 additions & 8 deletions tests/client_test.go
@@ -15,15 +15,13 @@ func get_test_source() string {
}

func get_default_test_config() datago.DatagoConfig {
-    config := datago.DatagoConfig{}
-    config.SetDefaults()
+    config := datago.GetDatagoConfig()

-    db_config := datago.GeneratorDBConfig{}
-    db_config.SetDefaults()
+    db_config := datago.GetSourceDBConfig()
    db_config.Sources = get_test_source()
    db_config.PageSize = 32
-    config.SourceConfig = db_config

+    config.SourceConfig = db_config
    return config
}

@@ -104,7 +102,7 @@ func TestExtraFields(t *testing.T) {
    clientConfig := get_default_test_config()
    clientConfig.SamplesBufferSize = 1

-    dbConfig := clientConfig.SourceConfig.(datago.GeneratorDBConfig)
+    dbConfig := clientConfig.SourceConfig.(datago.SourceDBConfig)
    dbConfig.HasLatents = "masked_image"
    dbConfig.HasMasks = "segmentation_mask"
    clientConfig.SourceConfig = dbConfig
@@ -174,7 +172,7 @@ func TestImageBufferCompression(t *testing.T) {
    clientConfig.SamplesBufferSize = 1
    clientConfig.ImageConfig.PreEncodeImages = true

-    dbConfig := clientConfig.SourceConfig.(datago.GeneratorDBConfig)
+    dbConfig := clientConfig.SourceConfig.(datago.SourceDBConfig)
    dbConfig.HasLatents = "masked_image"
    dbConfig.HasMasks = "segmentation_mask"
    clientConfig.SourceConfig = dbConfig
@@ -250,7 +248,7 @@ func TestRanks(t *testing.T) {
    clientConfig := get_default_test_config()
    clientConfig.SamplesBufferSize = 1

-    dbConfig := clientConfig.SourceConfig.(datago.GeneratorDBConfig)
+    dbConfig := clientConfig.SourceConfig.(datago.SourceDBConfig)
    dbConfig.WorldSize = 2
    dbConfig.Rank = 0
    clientConfig.SourceConfig = dbConfig
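
Because DatagoConfig.SourceConfig is held behind an interface (see the type switch in GetClient), the tests follow an assert-mutate-reassign pattern: type-assert the concrete SourceDBConfig out, tweak the copy, and write it back. A minimal sketch of that pattern, reusing this file's get_default_test_config helper:

clientConfig := get_default_test_config()

dbConfig := clientConfig.SourceConfig.(datago.SourceDBConfig) // copy the concrete value out
dbConfig.WorldSize = 2                                        // shard the stream across two readers
dbConfig.Rank = 0                                             // this reader takes shard 0
clientConfig.SourceConfig = dbConfig                          // write the modified copy back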
