wip dynamic worker pool, not super elegant but actually faster
blefaudeux committed Jan 14, 2025
1 parent 3bc2636 commit 0894432
Showing 14 changed files with 236 additions and 122 deletions.
1 change: 0 additions & 1 deletion README.md
@@ -50,7 +50,6 @@ client_config = {
     # some performance options, best settings will depend on your machine
     "prefetch_buffer_size": 64,
     "samples_buffer_size": 128,
-    "concurrency": concurrency,
 }

 client = datago.GetClientFromJSON(json.dumps(config)) # Will return None if something goes wrong, check the logs
6 changes: 2 additions & 4 deletions cmd/main.go
@@ -13,7 +13,6 @@ import (
 func main() {

     cropAndResize := flag.Bool("crop_and_resize", false, "Whether to crop and resize the images and masks")
-    concurrency := flag.Int("concurrency", 64, "The number of concurrent http requests to make")
     itemFetchBuffer := flag.Int("item_fetch_buffer", 256, "The number of items to pre-load")
     itemReadyBuffer := flag.Int("item_ready_buffer", 128, "The number of items ready to be served")
     limit := flag.Int("limit", 2000, "The number of items to fetch")
@@ -36,9 +35,8 @@ func main() {
         CropAndResize: *cropAndResize,
     }
     config.SourceConfig = sourceConfig
-    config.Concurrency = *concurrency
-    config.PrefetchBufferSize = *itemFetchBuffer
-    config.SamplesBufferSize = *itemReadyBuffer
+    config.PrefetchBufferSize = int32(*itemFetchBuffer)
+    config.SamplesBufferSize = int32(*itemReadyBuffer)
     config.Limit = *limit

     dataroom_client := datago.GetClient(config)
57 changes: 27 additions & 30 deletions pkg/client.go
@@ -56,9 +56,8 @@ type DatagoConfig struct {
     SourceType DatagoSourceType `json:"source_type"`
     SourceConfig interface{} `json:"source_config"`
     ImageConfig ImageTransformConfig `json:"image_config"`
-    PrefetchBufferSize int `json:"prefetch_buffer_size"`
-    SamplesBufferSize int `json:"samples_buffer_size"`
-    Concurrency int `json:"concurrency"`
+    PrefetchBufferSize int32 `json:"prefetch_buffer_size"`
+    SamplesBufferSize int32 `json:"samples_buffer_size"`
     Limit int `json:"limit"`
 }

@@ -70,7 +69,6 @@ func (c *DatagoConfig) setDefaults() {
     c.ImageConfig.setDefaults()
     c.PrefetchBufferSize = 64
     c.SamplesBufferSize = 32
-    c.Concurrency = 64
     c.Limit = 0
 }

@@ -131,9 +129,8 @@ func DatagoConfigFromJSON(jsonString string) DatagoConfig {
         log.Panicf("Error unmarshalling Image config: %v", err)
     }

-    config.PrefetchBufferSize = int(tempConfig["prefetch_buffer_size"].(float64))
-    config.SamplesBufferSize = int(tempConfig["samples_buffer_size"].(float64))
-    config.Concurrency = int(tempConfig["concurrency"].(float64))
+    config.PrefetchBufferSize = int32(tempConfig["prefetch_buffer_size"].(float64))
+    config.SamplesBufferSize = int32(tempConfig["samples_buffer_size"].(float64))
     if err != nil {
         log.Panicf("Error unmarshalling JSON: %v", err)
     }

@@ -155,9 +152,9 @@ type DatagoClient struct {
     backend Backend

     // Channels - these will be used to communicate between the background goroutines
-    chanPages chan Pages
-    chanSampleMetadata chan SampleDataPointers
-    chanSamples chan Sample
+    chanPages BufferedChan[Pages]
+    chanSampleMetadata BufferedChan[SampleDataPointers]
+    chanSamples BufferedChan[Sample]
 }

 // -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
@@ -187,23 +184,23 @@ func GetClient(config DatagoConfig) *DatagoClient {
         if err != nil {
             return nil
         } else {
-            backend = BackendHTTP{config: &dbConfig, concurrency: config.Concurrency}
+            backend = BackendHTTP{config: &dbConfig}
         }
     case SourceFileSystemConfig:
         fmt.Println("Creating a FileSystem-backed dataloader.", config.Limit, " max samples")
         fsConfig := config.SourceConfig.(SourceFileSystemConfig)
         generator = newDatagoGeneratorFileSystem(fsConfig)
-        backend = BackendFileSystem{config: &config, concurrency: config.Concurrency}
+        backend = BackendFileSystem{config: &config}
     default:
-        fmt.Println("Unsupported source type")
+        log.Panic("Unsupported source type")
     }

     // Create the client
     client := &DatagoClient{
-        chanPages: make(chan Pages, 2),
-        chanSampleMetadata: make(chan SampleDataPointers, config.PrefetchBufferSize),
-        chanSamples: make(chan Sample, config.SamplesBufferSize),
+        chanPages: NewBufferedChan[Pages](2),
+        chanSampleMetadata: NewBufferedChan[SampleDataPointers](config.PrefetchBufferSize),
+        chanSamples: NewBufferedChan[Sample](config.SamplesBufferSize),
         imageConfig: config.ImageConfig,
         servedSamples: 0,
         limit: config.Limit,

@@ -264,7 +261,7 @@ func (c *DatagoClient) Start() {
     wg.Add(1)
     go func() {
         defer wg.Done()
-        c.generator.generatePages(c.context, c.chanPages) // Collect the root data source pages
+        c.generator.generatePages(c.context, &c.chanPages) // Collect the root data source pages
     }()

     wg.Add(1)
@@ -276,7 +273,7 @@ func (c *DatagoClient) Start() {
     wg.Add(1)
     go func() {
         defer wg.Done()
-        c.backend.collectSamples(c.chanSampleMetadata, c.chanSamples, arAwareTransform, c.imageConfig.PreEncodeImages) // Fetch the payloads and deserialize them
+        c.backend.collectSamples(&c.chanSampleMetadata, &c.chanSamples, arAwareTransform, c.imageConfig.PreEncodeImages) // Fetch the payloads and deserialize them
     }()

     c.waitGroup = &wg

@@ -296,7 +293,8 @@ func (c *DatagoClient) GetSample() Sample {
         return Sample{}
     }

-    if sample, ok := <-c.chanSamples; ok {
+    sample, ok := c.chanSamples.Receive()
+    if ok {
         c.servedSamples++
         return sample
     }

@@ -307,18 +305,16 @@ func (c *DatagoClient) GetSample() Sample {

 // Stop the background downloads, will clear the memory and CPU footprint
 func (c *DatagoClient) Stop() {
-    fmt.Println("Stopping the datago client")
-
     // Signal the coroutines that next round should be a stop
     if c.cancel == nil {
         return // Already stopped
     }
+    fmt.Println("Stopping the datago client")
     c.cancel()

     // Clear the channels, in case a commit is blocking
-    go consumeChannel(c.chanPages)
-    go consumeChannel(c.chanSampleMetadata)
-    go consumeChannel(c.chanSamples)
+    c.chanPages.Empty()
+    c.chanSampleMetadata.Empty()
+    c.chanSamples.Empty()

     // Wait for all goroutines to finish
     if c.waitGroup != nil {

@@ -339,22 +335,23 @@ func (c *DatagoClient) asyncDispatch() {
     for {
         select {
         case <-c.context.Done():
-            close(c.chanSampleMetadata)
+            c.chanSampleMetadata.Close()
             return
-        case page, open := <-c.chanPages:
+        default:
+            page, open := c.chanPages.Receive()
             if !open {
                 fmt.Println("No more metadata to fetch, wrapping up")
-                close(c.chanSampleMetadata)
+                c.chanSampleMetadata.Close()
                 return
             }

             for _, item := range page.samplesDataPointers {
                 select {
                 case <-c.context.Done():
-                    close(c.chanSampleMetadata)
+                    c.chanSampleMetadata.Close()
                     return
-                case c.chanSampleMetadata <- item:
-                    // Item sent to the channel
+                default:
+                    c.chanSampleMetadata.Send(item)
                 }
             }
         }
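
The public lifecycle is unchanged by this refactor; only the plumbing underneath moved from raw channels to BufferedChan. A minimal consumer sketch, assuming a config assembled the way cmd/main.go does above (the import path and the run wrapper are illustrative assumptions, not part of this commit):

package main

import (
    "log"

    datago "datago/pkg" // import path is an assumption
)

func run(config datago.DatagoConfig) {
    client := datago.GetClient(config) // returns nil if the source config is invalid
    if client == nil {
        log.Fatal("could not create the datago client")
    }

    client.Start() // spawns the page generator, the dispatcher and the backend goroutines
    for i := 0; i < config.Limit; i++ {
        sample := client.GetSample() // blocks until a sample is ready (chanSamples.Receive underneath)
        _ = sample                   // hand off to training, serialization, ...
    }
    client.Stop() // cancels the context, drains the BufferedChans, joins the goroutines
}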
50 changes: 47 additions & 3 deletions pkg/core.go
@@ -1,6 +1,50 @@
 package datago

-import "context"
+import (
+    "context"
+    "sync/atomic"
+)
+
+type BufferedChan[T any] struct {
+    _channel chan T
+    current_items int32
+    max_items int32
+    open bool
+}
+
+// --- Simple buffered channel implementation ---------------------------------------------------------------------------------------------------------------------------------------------------------------
+// Make it possible to track the current channel status from the outside
+
+func NewBufferedChan[T any](max_items int32) BufferedChan[T] {
+    return BufferedChan[T]{_channel: make(chan T, max_items), current_items: 0, max_items: max_items, open: true}
+}
+
+func (b *BufferedChan[T]) Send(item T) {
+    b._channel <- item
+
+    // Small perf hit, not sure it's worth it
+    atomic.AddInt32(&b.current_items, 1)
+}
+
+func (b *BufferedChan[T]) Receive() (T, bool) {
+    item, open := <-b._channel
+    if !open {
+        return item, false
+    }
+
+    // Small perf hit, not sure it's worth it
+    atomic.AddInt32(&b.current_items, -1)
+    return item, true
+}
+
+func (b *BufferedChan[T]) Empty() {
+    consumeChannel(b._channel)
+}
+
+func (b *BufferedChan[T]) Close() {
+    close(b._channel)
+    b.open = false
+}

 // --- Sample data structures - these will be exposed to the Python world ---------------------------------------------------------------------------------------------------------------------------------------------------------------
 type LatentPayload struct {

@@ -43,10 +87,10 @@ type Pages struct {
 }

 type Generator interface {
-    generatePages(ctx context.Context, chanPages chan Pages)
+    generatePages(ctx context.Context, chanPages *BufferedChan[Pages])
 }

 // The backend will be responsible for fetching the payloads and deserializing them
 type Backend interface {
-    collectSamples(chanSampleMetadata chan SampleDataPointers, chanSamples chan Sample, transform *ARAwareTransform, pre_encode_images bool)
+    collectSamples(inputSampleMetadata *BufferedChan[SampleDataPointers], outputSamples *BufferedChan[Sample], transform *ARAwareTransform, encodeImages bool)
 }
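
The whole point of the wrapper over a plain chan is that current_items exposes the backlog to observers such as the worker pool below. A toy in-package sketch of the intended usage, not part of the commit (the atomic read is ours; run_worker_pool itself reads the counter directly):

package datago

import (
    "fmt"
    "sync/atomic"
)

// One producer, one consumer, with the backlog observable from outside the pair.
func exampleBufferedChan() {
    queue := NewBufferedChan[int](8)

    go func() {
        for i := 0; i < 32; i++ {
            queue.Send(i) // blocks once the 8-slot buffer is full, then bumps current_items
        }
        queue.Close() // Receive reports open == false once the buffer is drained
    }()

    for {
        item, open := queue.Receive() // decrements current_items on success
        if !open {
            break
        }
        fmt.Println("got", item, "backlog:", atomic.LoadInt32(&queue.current_items))
    }
}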
89 changes: 89 additions & 0 deletions pkg/dynamic_pool.go
@@ -0,0 +1,89 @@
+package datago
+
+import (
+    "fmt"
+    "runtime"
+    "time"
+)
+
+// Define an enum which will be used to track the state of the worker
+type worker_state int
+
+const (
+    worker_idle worker_state = iota
+    worker_running
+    worker_done
+    worker_stopping
+)
+
+// Define a stateful worker struct which will be spawned by the worker pool
+type worker struct {
+    state worker_state
+}
+
+// Manage a pool of workers to fetch the samples.
+// We'll initially spawn half the machine capacity in terms of workers,
+// and then we'll dynamically adjust the number of workers based on the
+// work backlog and idle time.
+
+func run_worker_pool(sampleWorker func(*worker), chanInputs *BufferedChan[SampleDataPointers], chanOutputs *BufferedChan[Sample]) {
+
+    // Get the number of CPUs on the machine
+    numCPUs := runtime.NumCPU() // We suppose that this doesn't change during the execution
+    worker_pool_size := numCPUs / 2
+
+    // Start the workers and work on the metadata channel
+    var workers []*worker
+
+    for i := 0; i < worker_pool_size; i++ {
+        new_worker := worker{state: worker_idle}
+        workers = append(workers, &new_worker)
+        go sampleWorker(&new_worker)
+    }
+
+    // Every second, check the state of the workers and adjust the pool size
+    // based on the work backlog and idle time
+    for {
+        // FIXME: Logic is super crude here, although ballpark correct
+        if !idle_workers(workers) && chanInputs.current_items > int32(len(workers)) && len(workers) < numCPUs {
+            fmt.Println("Increasing the worker pool size. Now ", len(workers))
+            new_worker := worker{state: worker_idle}
+            workers = append(workers, &new_worker)
+            go sampleWorker(&new_worker)
+        }
+
+        if idle_workers(workers) && chanInputs.current_items < 10 && len(workers) > 1 {
+            fmt.Println("Decreasing the worker pool size. Now ", len(workers))
+            workers[len(workers)-1].state = worker_stopping
+            workers = workers[:len(workers)-1]
+        }
+
+        if done_workers(workers) {
+            fmt.Println("All workers are done")
+            break
+        }
+        time.Sleep(1 * time.Second)
+        fmt.Println("Samples in the input queue", chanInputs.current_items, " Output queue: ", chanOutputs.current_items)
+    }
+}
+
+func idle_workers(workers []*worker) bool {
+    // There will be some noise measuring this, it's ok, we're only interested in a big picture
+    idle := 0
+    for _, w := range workers {
+        if w.state == worker_idle {
+            idle += 1
+        }
+    }
+    return (float64(idle) / float64(len(workers))) > 0.5
+}
+
+func done_workers(workers []*worker) bool {
+    done := 0
+    for _, w := range workers {
+        if w.state == worker_done {
+            done += 1
+        }
+    }
+    return done == len(workers)
+}
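
The sampleWorker functions handed to run_worker_pool live in the backend files, which are among the changed files not shown on this page. A hypothetical worker shape that fits the pool contract above (makeSampleWorker and fetchSample are our names, not the commit's): it parks in worker_idle while blocked on the input channel, flags worker_running while fetching, and flips to worker_done so done_workers can unblock the pool.

package datago

// Hypothetical worker compatible with run_worker_pool; the state field is read
// and written without synchronization, mirroring the crude polling in the pool.
func makeSampleWorker(inputs *BufferedChan[SampleDataPointers], outputs *BufferedChan[Sample], fetchSample func(SampleDataPointers) Sample) func(*worker) {
    return func(w *worker) {
        for {
            // The pool flips a worker to worker_stopping when it shrinks
            if w.state == worker_stopping {
                w.state = worker_done
                return
            }

            w.state = worker_idle // idle while blocked on Receive; this is what idle_workers samples
            metadata, open := inputs.Receive()
            if !open {
                w.state = worker_done // input closed: lets done_workers return true
                return
            }

            w.state = worker_running
            outputs.Send(fetchSample(metadata))
        }
    }
}

run_worker_pool(makeSampleWorker(in, out, fetch), in, out) would then drive it, growing or shrinking the pool from the measured backlog.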
8 changes: 4 additions & 4 deletions pkg/generator_db.go
@@ -242,7 +242,7 @@ func newDatagoGeneratorDB(config SourceDBConfig) (datagoGeneratorDB, error) {
     return datagoGeneratorDB{baseRequest: *getHTTPRequest(api_url, api_key, request), config: config}, nil
 }

-func (f datagoGeneratorDB) generatePages(ctx context.Context, chanPages chan Pages) {
+func (f datagoGeneratorDB) generatePages(ctx context.Context, chanPages *BufferedChan[Pages]) {
     // Fetch pages from the API, and feed the results to the items channel
     // This is meant to be run in a goroutine
     http_client := http.Client{Timeout: 30 * time.Second}

@@ -304,13 +304,13 @@ func (f datagoGeneratorDB) generatePages(ctx context.Context, chanPages chan Pages) {
             samplesDataPointers[i] = sample
         }

-        chanPages <- Pages{samplesDataPointers}
+        chanPages.Send(Pages{samplesDataPointers})
     }

     // Check if there are more pages to fetch
     if data.Next == "" {
         fmt.Println("No more pages to fetch, wrapping up")
-        close(chanPages)
+        chanPages.Close()
         return
     }

@@ -328,7 +328,7 @@ func (f datagoGeneratorDB) generatePages(ctx context.Context, chanPages chan Pages) {
     // Check if we consumed all the retries
     if !valid_page {
         fmt.Println("Too many errors fetching new pages, wrapping up")
-        close(chanPages)
+        chanPages.Close()
         return
     }
 }