Skip to content

Commit

Permalink
Another round of cleanup, getting more presentable / could be open fo…
Browse files Browse the repository at this point in the history
…r review
  • Loading branch information
blefaudeux committed Nov 2, 2024
1 parent 779f86e commit e05a2c5
Show file tree
Hide file tree
Showing 11 changed files with 121 additions and 126 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@ go.work.sum

# env file
.env

.vscode
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ Datago is rank and world-size aware, in which case the samples are dispatched de
```python
from datago import datago

config = datago.GetDefaultConfig()
config = datago.DatagoConfig()
config.SetDefaults()

# Check out the config fields, plenty of option to specify your DB query and optimize performance

client = datago.GetClient(config)
Expand Down
10 changes: 6 additions & 4 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ func main() {
client_config := datago.GetDefaultConfig()
client_config.SourceType = datago.SourceTypeFileSystem
client_config.SourceConfig = datago.GeneratorFileSystemConfig{RootPath: os.Getenv("DATAROOM_TEST_FILESYSTEM"), PageSize: 10}
client_config.DefaultImageSize = 1024
client_config.DownsamplingRatio = 32
client_config.ImageConfig = datago.ImageTransformConfig{
DefaultImageSize: 1024,
DownsamplingRatio: 32,
CropAndResize: *flag.Bool("crop_and_resize", false, "Whether to crop and resize the images and masks"),
}

client_config.CropAndResize = *flag.Bool("crop_and_resize", false, "Whether to crop and resize the images and masks")
client_config.ConcurrentDownloads = *flag.Int("concurrency", 64, "The number of concurrent http requests to make")
client_config.Concurrency = *flag.Int("concurrency", 64, "The number of concurrent http requests to make")
client_config.PrefetchBufferSize = *flag.Int("item_fetch_buffer", 256, "The number of items to pre-load")
client_config.SamplesBufferSize = *flag.Int("item_ready_buffer", 128, "The number of items ready to be served")

Expand Down
2 changes: 1 addition & 1 deletion generate_python_package.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ python_version=$(python3 --version 2>&1 | awk '{print $2}' | cut -d. -f1,2)
echo "Building package for python" $python_version

# Setup where the python package will be copied
DESTINATION="../../../python_$python_version"
DESTINATION="../build/python_$python_version"
rm -rf $DESTINATION

# Build the python package via the gopy toolchain
Expand Down
100 changes: 45 additions & 55 deletions pkg/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,35 +25,51 @@ const (

type DataSourceConfig interface{}

type ImageTransformConfig struct {
CropAndResize bool
DefaultImageSize int
DownsamplingRatio int
MinAspectRatio float64
MaxAspectRatio float64
PreEncodeImages bool
}

func (c *ImageTransformConfig) SetDefaults() {
c.DefaultImageSize = 512
c.DownsamplingRatio = 16
c.MinAspectRatio = 0.5
c.MaxAspectRatio = 2.0
c.PreEncodeImages = false
}

type DatagoConfig struct {
SourceType DatagoSourceType
SourceConfig DataSourceConfig
CropAndResize bool
DefaultImageSize int
DownsamplingRatio int
MinAspectRatio float64
MaxAspectRatio float64
PreEncodeImages bool
PrefetchBufferSize int
SamplesBufferSize int
ConcurrentDownloads int
PageSize int
SourceType DatagoSourceType `default:"DB"`
SourceConfig DataSourceConfig
ImageConfig ImageTransformConfig
PrefetchBufferSize int
SamplesBufferSize int
Concurrency int
}

type DatagoClient struct {
concurrency int
func (c *DatagoConfig) SetDefaults() {
c.SourceType = SourceTypeDB

dbConfig := GeneratorDBConfig{}
dbConfig.SetDefaults()
c.SourceConfig = dbConfig

c.ImageConfig.SetDefaults()
c.PrefetchBufferSize = 64
c.SamplesBufferSize = 32
c.Concurrency = 64
}

type DatagoClient struct {
context context.Context
waitGroup *sync.WaitGroup
cancel context.CancelFunc

// Online transform parameters
crop_and_resize bool
default_image_size int
downsampling_ratio int
min_aspect_ratio float64
max_aspect_ratio float64
pre_encode_images bool
ImageConfig ImageTransformConfig

// Flexible generator, backend and dispatch goroutines
generator Generator
Expand All @@ -67,27 +83,8 @@ type DatagoClient struct {

// -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

func GetDefaultConfig() DatagoConfig {
dbConfig := GetDefaultDBConfig()

return DatagoConfig{
SourceType: SourceTypeDB,
SourceConfig: dbConfig,
CropAndResize: false,
DefaultImageSize: 512,
DownsamplingRatio: 16,
MinAspectRatio: 0.5,
MaxAspectRatio: 2.0,
PreEncodeImages: false,
PrefetchBufferSize: 8,
SamplesBufferSize: 8,
PageSize: 20, // 1000 for a vectorDB, make this a default which depends on the source type
}
}

// Create a new Dataroom Client
func GetClient(config DatagoConfig) *DatagoClient {

// Create the generator and backend
var generator Generator
var backend Backend
Expand All @@ -96,30 +93,23 @@ func GetClient(config DatagoConfig) *DatagoClient {
fmt.Println("Creating a DB-backed dataloader")
db_config := config.SourceConfig.(GeneratorDBConfig)
generator = newDatagoGeneratorDB(db_config)
backend = BackendHTTP{config: &db_config}
backend = BackendHTTP{config: &db_config, concurrency: config.Concurrency}
} else if config.SourceType == SourceTypeFileSystem {
fmt.Println("Creating a FileSystem-backed dataloader")
fs_config := config.SourceConfig.(GeneratorFileSystemConfig)
generator = newDatagoGeneratorFileSystem(fs_config)
backend = BackendFileSystem{config: &config}
backend = BackendFileSystem{config: &config, concurrency: config.Concurrency}
} else {
// TODO: Handle other sources
log.Panic("Unsupported source type at the moment")
}

// Create the client
client := &DatagoClient{
concurrency: config.ConcurrentDownloads,
chanPages: make(chan Pages, 2),
chanSampleMetadata: make(chan SampleDataPointers, config.PrefetchBufferSize),
chanSamples: make(chan Sample, config.SamplesBufferSize),

crop_and_resize: config.CropAndResize,
default_image_size: config.DefaultImageSize,
downsampling_ratio: config.DownsamplingRatio,
min_aspect_ratio: config.MinAspectRatio,
max_aspect_ratio: config.MaxAspectRatio,
pre_encode_images: config.PreEncodeImages,
ImageConfig: config.ImageConfig,
context: nil,
cancel: nil,
waitGroup: nil,
Expand Down Expand Up @@ -162,13 +152,13 @@ func (c *DatagoClient) Start() {
// Optionally crop and resize the images and masks on the fly
var arAwareTransform *ARAwareTransform = nil

if c.crop_and_resize {
if c.ImageConfig.CropAndResize {
fmt.Println("Cropping and resizing images")
fmt.Println("Base image size | downsampling ratio | min | max:", c.default_image_size, c.downsampling_ratio, c.min_aspect_ratio, c.max_aspect_ratio)
arAwareTransform = newARAwareTransform(c.default_image_size, c.downsampling_ratio, c.min_aspect_ratio, c.max_aspect_ratio)
fmt.Println("Base image size | downsampling ratio | min | max:", c.ImageConfig.DefaultImageSize, c.ImageConfig.DownsamplingRatio, c.ImageConfig.MinAspectRatio, c.ImageConfig.MaxAspectRatio)
arAwareTransform = newARAwareTransform(c.ImageConfig)
}

if c.pre_encode_images {
if c.ImageConfig.PreEncodeImages {
fmt.Println("Pre-encoding images, we'll return serialized JPG and PNG bytes")
}

Expand All @@ -191,7 +181,7 @@ func (c *DatagoClient) Start() {
wg.Add(1)
go func() {
defer wg.Done()
c.backend.collectSamples(c.chanSampleMetadata, c.chanSamples, arAwareTransform, c.pre_encode_images) // Fetch the payloads and and deserialize them
c.backend.collectSamples(c.chanSampleMetadata, c.chanSamples, arAwareTransform, c.ImageConfig.PreEncodeImages) // Fetch the payloads and and deserialize them
}()

c.waitGroup = &wg
Expand Down
67 changes: 31 additions & 36 deletions pkg/generator_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,37 @@ type dbRequest struct {

// -- Define the front end goroutine ---------------------------------------------------------------------------------------------------------------------------------------------------------------
type GeneratorDBConfig struct {
// Request parameters
Sources string
RequireImages bool
RequireEmbeddings bool
Tags []string
TagsNE []string
HasAttributes []string
LacksAttributes []string
HasMasks []string
LacksMasks []string
HasLatents []string
LacksLatents []string
ConcurrentDownloads int
PageSize int
Rank uint32
WorldSize uint32
Sources string
RequireImages bool
RequireEmbeddings bool
Tags []string
TagsNE []string
HasAttributes []string
LacksAttributes []string
HasMasks []string
LacksMasks []string
HasLatents []string
LacksLatents []string
Rank uint32
WorldSize uint32
PageSize uint32
}

func (c *GeneratorDBConfig) SetDefaults() {
c.Sources = ""
c.RequireImages = true
c.RequireEmbeddings = false
c.Tags = []string{}
c.TagsNE = []string{}
c.HasAttributes = []string{}
c.LacksAttributes = []string{}
c.HasMasks = []string{}
c.LacksMasks = []string{}
c.HasLatents = []string{}
c.LacksLatents = []string{}
c.Rank = 0
c.WorldSize = 0
c.PageSize = 512
}

func (c *GeneratorDBConfig) getDbRequest() dbRequest {
Expand Down Expand Up @@ -122,26 +137,6 @@ func (c *GeneratorDBConfig) getDbRequest() dbRequest {
}
}

func GetDefaultDBConfig() GeneratorDBConfig {
return GeneratorDBConfig{
Sources: "",
RequireImages: true,
RequireEmbeddings: false,
Tags: []string{},
TagsNE: []string{},
HasAttributes: []string{},
LacksAttributes: []string{},
HasMasks: []string{},
LacksMasks: []string{},
HasLatents: []string{},
LacksLatents: []string{},
Rank: 0,
WorldSize: 0,
ConcurrentDownloads: 1,
PageSize: 1000,
}
}

type datagoGeneratorDB struct {
baseRequest http.Request
config GeneratorDBConfig
Expand Down
20 changes: 10 additions & 10 deletions pkg/generator_filesystem.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,16 @@ type fsSampleMetadata struct {
}

// -- Define the front end goroutine ---------------------------------------------------------------------------------------------------------------------------------------------------------------
type datagoGeneratorFileSystem struct {
root_directory string
extensions set
page_size int
}

type GeneratorFileSystemConfig struct {
RootPath string
PageSize int
}

type datagoGeneratorFileSystem struct {
extensions set
config GeneratorFileSystemConfig
}

func newDatagoGeneratorFileSystem(config GeneratorFileSystemConfig) datagoGeneratorFileSystem {
supported_img_extensions := []string{".jpg", ".jpeg", ".png", ".JPEG", ".JPG", ".PNG"}
var extensionsMap = make(set)
Expand All @@ -37,7 +36,7 @@ func newDatagoGeneratorFileSystem(config GeneratorFileSystemConfig) datagoGenera
fmt.Println("File system root directory", config.RootPath)
fmt.Println("Supported image extensions", supported_img_extensions)

return datagoGeneratorFileSystem{root_directory: config.RootPath, extensions: extensionsMap, page_size: config.PageSize}
return datagoGeneratorFileSystem{config: config, extensions: extensionsMap}
}

func (f datagoGeneratorFileSystem) generatePages(ctx context.Context, chanPages chan Pages) {
Expand All @@ -46,7 +45,7 @@ func (f datagoGeneratorFileSystem) generatePages(ctx context.Context, chanPages

var samples []SampleDataPointers

err := filepath.Walk(f.root_directory, func(path string, info os.FileInfo, err error) error {
err := filepath.Walk(f.config.RootPath, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
Expand All @@ -56,15 +55,16 @@ func (f datagoGeneratorFileSystem) generatePages(ctx context.Context, chanPages
}

// Check if we have enough files to send a page
if len(samples) >= f.page_size {
if len(samples) >= f.config.PageSize {
chanPages <- Pages{samples}
samples = nil
}
return nil
})

if err != nil {
fmt.Println("Error walking the path", f.root_directory)
fmt.Println("Error walking the path", f.config.RootPath)
panic(err)
} else {
// Send the last page
if len(samples) > 0 {
Expand Down
12 changes: 6 additions & 6 deletions pkg/transforms.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ func buildImageSizeList(defaultImageSize int, downsamplingRatio int, minAspectRa
return image_list
}

func newARAwareTransform(defaultImageSize int, downsamplingRatio int, minAspectRatio, maxAspectRatio float64) *ARAwareTransform {
func newARAwareTransform(imageConfig ImageTransformConfig) *ARAwareTransform {
// Build the image size list
image_list := buildImageSizeList(defaultImageSize, downsamplingRatio, minAspectRatio, maxAspectRatio)
image_list := buildImageSizeList(imageConfig.DefaultImageSize, imageConfig.DownsamplingRatio, imageConfig.MinAspectRatio, imageConfig.MaxAspectRatio)

// Fill in the map table to match aspect ratios and image sizes
aspectRatioToSize := make(map[float64]ImageSize)
Expand All @@ -64,10 +64,10 @@ func newARAwareTransform(defaultImageSize int, downsamplingRatio int, minAspectR

//
return &ARAwareTransform{
defaultImageSize: defaultImageSize,
downsamplingRatio: downsamplingRatio,
minAspectRatio: minAspectRatio,
maxAspectRatio: maxAspectRatio,
defaultImageSize: imageConfig.DefaultImageSize,
downsamplingRatio: imageConfig.DownsamplingRatio,
minAspectRatio: imageConfig.MinAspectRatio,
maxAspectRatio: imageConfig.MaxAspectRatio,
targetImageSizes: image_list,
aspectRatioToSize: aspectRatioToSize,
}
Expand Down
Loading

0 comments on commit e05a2c5

Please sign in to comment.