diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go
index 3d7adb6a6..9f0138115 100644
--- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go
+++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go
@@ -199,11 +199,6 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head
 	}
 
 	if !enableTrace {
-		if response.ContentLength > 0 {
-			out.Grow(int(response.ContentLength))
-		} else {
-			out.Grow(1024 * 4)
-		}
 		_, err = out.ReadFrom(respReader)
 		return
 	}
diff --git a/v2/pkg/engine/resolve/fetch.go b/v2/pkg/engine/resolve/fetch.go
index 9f5d68fc9..eca88e703 100644
--- a/v2/pkg/engine/resolve/fetch.go
+++ b/v2/pkg/engine/resolve/fetch.go
@@ -1,8 +1,10 @@
 package resolve
 
 import (
+	"bytes"
 	"encoding/json"
 	"slices"
+	"sync"
 
 	"github.com/wundergraph/graphql-go-tools/v2/pkg/ast"
 )
@@ -20,6 +22,42 @@ type Fetch interface {
 	FetchKind() FetchKind
 	Dependencies() FetchDependencies
 	DataSourceInfo() DataSourceInfo
+
+	GetBuffer() *bytes.Buffer
+	ReportResponseSize(out *bytes.Buffer)
+}
+
+// FetchBufferSizeCalculator calculates the right size for a buffer based on the previous (up to) 64 fetches
+// Instead of allocating a buffer with an arbitrary default size and growing it to the right cap,
+// FetchBufferSizeCalculator uses information about previous fetches to suggest a reasonable size
+// Overall, this has been shown to reduce bytes.growSlice operations to almost zero in hot paths
+type FetchBufferSizeCalculator struct {
+	mux   sync.RWMutex
+	count int
+	total int
+}
+
+func (f *FetchBufferSizeCalculator) GetBuffer() *bytes.Buffer {
+	f.mux.RLock()
+	defer f.mux.RUnlock()
+	if f.count == 0 {
+		return bytes.NewBuffer(make([]byte, 0, 1024*4))
+	}
+	size := f.total / f.count
+	return bytes.NewBuffer(make([]byte, 0, size))
+}
+
+func (f *FetchBufferSizeCalculator) ReportResponseSize(out *bytes.Buffer) {
+	f.mux.Lock()
+	defer f.mux.Unlock()
+	inc := out.Cap()
+	if f.count > 64 { // reset after 64 fetches
+		f.total = inc
+		f.count = 1
+	} else {
+		f.count++
+		f.total += inc
+	}
 }
 
 type FetchItem struct {
@@ -71,6 +109,7 @@ const (
 )
 
 type SingleFetch struct {
+	FetchBufferSizeCalculator
 	FetchConfiguration
 	FetchDependencies
 	InputTemplate InputTemplate
@@ -140,6 +179,7 @@ func (_ *SingleFetch) FetchKind() FetchKind {
 // allows to join nested fetches to the same subgraph into a single fetch
 // representations variable will contain multiple items according to amount of entities matching this query
 type BatchEntityFetch struct {
+	FetchBufferSizeCalculator
 	FetchDependencies
 	Input      BatchInput
 	DataSource DataSource
@@ -182,6 +222,7 @@ func (_ *BatchEntityFetch) FetchKind() FetchKind {
 // EntityFetch - represents nested entity fetch on object field
 // representations variable will contain single item
 type EntityFetch struct {
+	FetchBufferSizeCalculator
 	FetchDependencies
 	Input      EntityInput
 	DataSource DataSource
@@ -217,6 +258,7 @@ func (_ *EntityFetch) FetchKind() FetchKind {
 // Usually, you want to batch fetches within a list, which is the default behavior of SingleFetch
 // However, if the data source does not support batching, you can use this fetch to make parallel fetches within a list
 type ParallelListItemFetch struct {
+	FetchBufferSizeCalculator
 	Fetch  *SingleFetch
 	Traces []*SingleFetch
 	Trace  *DataSourceLoadTrace
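Note: the calculator is a rolling average. GetBuffer pre-sizes new buffers to the mean capacity of recently reported responses (falling back to 4KB until the first report), and ReportResponseSize restarts the window after 64 samples so the estimate can adapt when response sizes shift. A minimal sketch of that behavior, driving the calculator directly rather than through the Loader (standalone example, not part of the patch; assumes the v2 module at this PR's revision):

package main

import (
	"bytes"
	"fmt"

	"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
)

func main() {
	var c resolve.FetchBufferSizeCalculator // zero value is ready to use

	// No history yet: falls back to the 4KB default.
	fmt.Println(c.GetBuffer().Cap()) // 4096

	// Report three responses that ended up with a 16KB capacity.
	for i := 0; i < 3; i++ {
		c.ReportResponseSize(bytes.NewBuffer(make([]byte, 0, 16*1024)))
	}

	// Later buffers are pre-sized to the average reported capacity,
	// so out.ReadFrom rarely has to grow the underlying slice.
	fmt.Println(c.GetBuffer().Cap()) // 16384
}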
diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go
index 18788ee02..2947fb326 100644
--- a/v2/pkg/engine/resolve/loader.go
+++ b/v2/pkg/engine/resolve/loader.go
@@ -155,12 +155,13 @@ func (l *Loader) resolveSingle(item *FetchItem) error {
 	switch f := item.Fetch.(type) {
 	case *SingleFetch:
 		res := &result{
-			out: &bytes.Buffer{},
+			out: f.GetBuffer(),
 		}
 		err := l.loadSingleFetch(l.ctx.ctx, f, item, items, res)
 		if err != nil {
 			return err
 		}
+		f.ReportResponseSize(res.out)
 		err = l.mergeResult(item, res, items)
 		if l.ctx.LoaderHooks != nil && res.loaderHookContext != nil {
 			l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.statusCode, res.ds, goerrors.Join(res.err, l.ctx.subgraphErrors))
@@ -168,12 +169,13 @@ func (l *Loader) resolveSingle(item *FetchItem) error {
 		return err
 	case *BatchEntityFetch:
 		res := &result{
-			out: &bytes.Buffer{},
+			out: f.GetBuffer(),
 		}
 		err := l.loadBatchEntityFetch(l.ctx.ctx, item, f, items, res)
 		if err != nil {
 			return errors.WithStack(err)
 		}
+		f.ReportResponseSize(res.out)
 		err = l.mergeResult(item, res, items)
 		if l.ctx.LoaderHooks != nil && res.loaderHookContext != nil {
 			l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.statusCode, res.ds, goerrors.Join(res.err, l.ctx.subgraphErrors))
@@ -181,12 +183,13 @@ func (l *Loader) resolveSingle(item *FetchItem) error {
 		return err
 	case *EntityFetch:
 		res := &result{
-			out: &bytes.Buffer{},
+			out: f.GetBuffer(),
 		}
 		err := l.loadEntityFetch(l.ctx.ctx, item, f, items, res)
 		if err != nil {
 			return errors.WithStack(err)
 		}
+		f.ReportResponseSize(res.out)
 		err = l.mergeResult(item, res, items)
 		if l.ctx.LoaderHooks != nil && res.loaderHookContext != nil {
 			l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.statusCode, res.ds, goerrors.Join(res.err, l.ctx.subgraphErrors))
@@ -201,11 +204,10 @@ func (l *Loader) resolveSingle(item *FetchItem) error {
 		for i := range items {
 			i := i
 			results[i] = &result{
-				out: &bytes.Buffer{},
+				out: f.GetBuffer(),
 			}
 			if l.ctx.TracingOptions.Enable {
-				f.Traces[i] = new(SingleFetch)
-				*f.Traces[i] = *f.Fetch
+				f.Traces[i] = f.Fetch
 				g.Go(func() error {
 					return l.loadFetch(ctx, f.Traces[i], item, items[i:i+1], results[i])
 				})
@@ -220,6 +222,7 @@ func (l *Loader) resolveSingle(item *FetchItem) error {
 			return errors.WithStack(err)
 		}
 		for i := range results {
+			f.ReportResponseSize(results[i].out)
 			err = l.mergeResult(item, results[i], items[i:i+1])
 			if l.ctx.LoaderHooks != nil && results[i].loaderHookContext != nil {
 				l.ctx.LoaderHooks.OnFinished(results[i].loaderHookContext, results[i].statusCode, results[i].ds, goerrors.Join(results[i].err, l.ctx.subgraphErrors))
@@ -369,7 +372,8 @@ func (l *Loader) itemsData(items []*astjson.Value) *astjson.Value {
 func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, fetchItem *FetchItem, items []*astjson.Value, res *result) error {
 	switch f := fetch.(type) {
 	case *SingleFetch:
-		res.out = &bytes.Buffer{}
+		res.out = fetch.GetBuffer()
+		defer fetch.ReportResponseSize(res.out)
 		return l.loadSingleFetch(ctx, f, fetchItem, items, res)
 	case *ParallelListItemFetch:
 		results := make([]*result, len(items))
@@ -380,11 +384,10 @@ func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, fetchItem *FetchIte
 		for i := range items {
 			i := i
 			results[i] = &result{
-				out: &bytes.Buffer{},
+				out: fetch.GetBuffer(),
 			}
 			if l.ctx.TracingOptions.Enable {
-				f.Traces[i] = new(SingleFetch)
-				*f.Traces[i] = *f.Fetch
+				f.Traces[i] = f.Fetch
 				g.Go(func() error {
 					return l.loadFetch(ctx, f.Traces[i], fetchItem, items[i:i+1], results[i])
 				})
@@ -399,12 +402,17 @@ func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, fetchItem *FetchIte
 			return errors.WithStack(err)
 		}
 		res.nestedMergeItems = results
+		for i := range results {
+			fetch.ReportResponseSize(results[i].out)
+		}
 		return nil
	case *EntityFetch:
-		res.out = &bytes.Buffer{}
+		res.out = fetch.GetBuffer()
+		defer fetch.ReportResponseSize(res.out)
 		return l.loadEntityFetch(ctx, fetchItem, f, items, res)
	case *BatchEntityFetch:
-		res.out = &bytes.Buffer{}
+		res.out = fetch.GetBuffer()
+		defer fetch.ReportResponseSize(res.out)
 		return l.loadBatchEntityFetch(ctx, fetchItem, f, items, res)
	}
	return nil
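Note: every branch in resolveSingle and loadFetch now follows the same buffer lifecycle: borrow a pre-sized buffer from the fetch's embedded calculator, run the load, and report the resulting capacity back so the next buffer converges on the real response size. A condensed sketch of that contract (runFetch and load are hypothetical stand-ins, not part of the patch; the real code reports at slightly different points per case, e.g. only after a successful load in resolveSingle):

package sketch

import (
	"bytes"

	"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
)

// runFetch shows the shape shared by the switch cases above; load stands in
// for loadSingleFetch / loadEntityFetch / loadBatchEntityFetch.
func runFetch(f resolve.Fetch, load func(out *bytes.Buffer) error) error {
	out := f.GetBuffer()            // pre-sized from the rolling average
	defer f.ReportResponseSize(out) // feed the observed capacity back
	return load(out)
}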
diff --git a/v2/pkg/engine/resolve/resolvable.go b/v2/pkg/engine/resolve/resolvable.go
index da73aef23..808d13e5a 100644
--- a/v2/pkg/engine/resolve/resolvable.go
+++ b/v2/pkg/engine/resolve/resolvable.go
@@ -159,7 +159,11 @@ func (r *Resolvable) ResolveNode(node Node, data *astjson.Value, out io.Writer)
 	r.printErr = nil
 	r.authorizationError = nil
 	r.errors = r.astjsonArena.NewArray()
-
+	defer func() {
+		// remove references to buffers when no longer needed
+		r.out = nil
+		r.errors = nil
+	}()
 	hasErrors := r.walkNode(node, data)
 	if hasErrors {
 		return fmt.Errorf("error resolving node")
diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go
index 57207633b..f0d2927a0 100644
--- a/v2/pkg/engine/resolve/resolve.go
+++ b/v2/pkg/engine/resolve/resolve.go
@@ -7,6 +7,7 @@ import (
 	"context"
 	"fmt"
 	"io"
+	"runtime"
 	"sync"
 	"time"
 
@@ -67,6 +68,8 @@ type Resolver struct {
 	propagateSubgraphErrors      bool
 	propagateSubgraphStatusCodes bool
+
+	resolvableBufferPool *pool.LimitBufferPool
 }
 
 func (r *Resolver) SetAsyncErrorWriter(w AsyncErrorWriter) {
@@ -142,6 +145,8 @@ type ResolverOptions struct {
 	ResolvableOptions ResolvableOptions
 	// AllowedCustomSubgraphErrorFields defines which fields are allowed in the subgraph error when in passthrough mode
 	AllowedSubgraphErrorFields []string
+	// BufferPoolOptions defines the size & limits of the resolvable buffer pool
+	BufferPoolOptions pool.LimitBufferPoolOptions
 }
 
 // New returns a new Resolver, ctx.Done() is used to cancel all active subscriptions & streams
@@ -175,6 +180,19 @@ func New(ctx context.Context, options ResolverOptions) *Resolver {
 		allowedErrorFields[field] = struct{}{}
 	}
 
+	if options.BufferPoolOptions.MaxBuffers == 0 {
+		options.BufferPoolOptions.MaxBuffers = runtime.GOMAXPROCS(-1)
+	}
+	if options.BufferPoolOptions.MaxBuffers < 8 {
+		options.BufferPoolOptions.MaxBuffers = 8
+	}
+	if options.BufferPoolOptions.MaxBufferSize == 0 {
+		options.BufferPoolOptions.MaxBufferSize = 1024 * 1024 * 10 // 10MB
+	}
+	if options.BufferPoolOptions.DefaultBufferSize < 1024*8 {
+		options.BufferPoolOptions.DefaultBufferSize = 1024 * 8 // 8KB
+	}
+
 	resolver := &Resolver{
 		ctx:     ctx,
 		options: options,
@@ -188,6 +206,7 @@ func New(ctx context.Context, options ResolverOptions) *Resolver {
 		triggerUpdateBuf:            bytes.NewBuffer(make([]byte, 0, 1024)),
 		allowedErrorExtensionFields: allowedExtensionFields,
 		allowedErrorFields:          allowedErrorFields,
+		resolvableBufferPool:        pool.NewLimitBufferPool(ctx, options.BufferPoolOptions),
 	}
 	resolver.maxConcurrency = make(chan struct{}, options.MaxConcurrency)
 	for i := 0; i < options.MaxConcurrency; i++ {
@@ -242,6 +261,7 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons
 	}()
 
 	t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields)
+	defer t.resolvable.Reset() // set all references to nil, e.g. pointers to buffers
 
 	err := t.resolvable.Init(ctx, data, response.Info.OperationType)
 	if err != nil {
@@ -254,15 +274,17 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons
 			return nil, err
 		}
 	}
-
-	buf := &bytes.Buffer{}
-	err = t.resolvable.Resolve(ctx.ctx, response.Data, response.Fetches, buf)
+	buf := r.resolvableBufferPool.Get()
+	defer r.resolvableBufferPool.Put(buf)
+	err = t.resolvable.Resolve(ctx.ctx, response.Data, response.Fetches, buf.Buf)
 	if err != nil {
 		return nil, err
 	}
-
-	_, err = buf.WriteTo(writer)
-	return resp, err
+	_, err = buf.Buf.WriteTo(writer)
+	if err != nil {
+		return nil, err
+	}
+	return resp, nil
 }
 
 type trigger struct {
@@ -287,6 +309,7 @@ func (r *Resolver) executeSubscriptionUpdate(ctx *Context, sub *sub, sharedInput
 		fmt.Printf("resolver:trigger:subscription:update:%d\n", sub.id.SubscriptionID)
 	}
 	t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields)
+	defer t.resolvable.Reset() // reset all references
 	input := make([]byte, len(sharedInput))
 	copy(input, sharedInput)
diff --git a/v2/pkg/pool/limitbufferpool.go b/v2/pkg/pool/limitbufferpool.go
new file mode 100644
index 000000000..7b77a7e6e
--- /dev/null
+++ b/v2/pkg/pool/limitbufferpool.go
@@ -0,0 +1,98 @@
+package pool
+
+import (
+	"bytes"
+	"context"
+	"time"
+)
+
+// LimitBufferPool is a pool of buffers that caps both the number of buffers it hands out and the max size of buffers that get recycled
+// In addition, it runs a GC goroutine that resets one buffer per GCTime tick (one second by default) to keep memory usage low when the pool is idle
+// This is an alternative to sync.Pool, which can grow unbounded rather quickly
+type LimitBufferPool struct {
+	buffers []*ResolvableBuffer
+	index   chan int
+	options LimitBufferPoolOptions
+}
+
+type ResolvableBuffer struct {
+	Buf *bytes.Buffer
+	idx int
+}
+
+type LimitBufferPoolOptions struct {
+	// MaxBuffers limits the total amount of buffers that can be allocated for printing the response
+	// It's recommended to set this to the number of CPU cores available, or a multiple of it
+	// If set to 0, the number of CPU cores is used
+	MaxBuffers int
+	// MaxBufferSize limits the size of the buffer that can be recycled back into the pool
+	// If set to 0, a limit of 10MB is applied
+	// If the buffer cap exceeds this limit, a new buffer with the default size is created
+	MaxBufferSize int
+	// DefaultBufferSize is used to initialize the buffer with a default size
+	// If set to 0, a default size of 8KB is used
+	DefaultBufferSize int
+	// GCTime is the time interval to run the GC to reset the buffer
+	GCTime time.Duration
+}
+
+func NewLimitBufferPool(ctx context.Context, options LimitBufferPoolOptions) *LimitBufferPool {
+	if options.MaxBufferSize == 0 {
+		options.MaxBufferSize = 1024 * 1024 * 10 // 10MB
+	}
+	if options.DefaultBufferSize == 0 {
+		options.DefaultBufferSize = 1024 * 8 // 8KB
+	}
+	if options.MaxBuffers == 0 {
+		options.MaxBuffers = 8
+	}
+	if options.GCTime == 0 {
+		options.GCTime = time.Second
+	}
+	pool := &LimitBufferPool{
+		buffers: make([]*ResolvableBuffer, options.MaxBuffers),
+		index:   make(chan int, options.MaxBuffers),
+		options: options,
+	}
+	for i := range pool.buffers {
+		pool.buffers[i] = &ResolvableBuffer{
+			idx: i,
+		}
+		pool.index <- i
+	}
+	go pool.runGC(ctx)
+	return pool
+}
+
+func (p *LimitBufferPool) runGC(ctx context.Context) {
+	ticker := time.NewTicker(p.options.GCTime)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-ticker.C:
+			b := p.Get()
+			b.Buf = nil
+			p.Put(b)
+		}
+	}
+}
+
+func (p *LimitBufferPool) Get() *ResolvableBuffer {
+	i := <-p.index
+	if p.buffers[i].Buf == nil {
+		p.buffers[i].Buf = bytes.NewBuffer(make([]byte, 0, p.options.DefaultBufferSize))
+	}
+	return p.buffers[i]
+}
+
+func (p *LimitBufferPool) Put(buf *ResolvableBuffer) {
+	if buf.Buf != nil {
+		buf.Buf.Reset()
+		if buf.Buf.Cap() > p.options.MaxBufferSize {
+			buf.Buf = bytes.NewBuffer(make([]byte, 0, p.options.DefaultBufferSize))
+		}
+	}
+	p.index <- buf.idx
+}
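Note: the capacity control is the buffered index channel: Get blocks once all MaxBuffers slots are checked out, which is what bounds memory, unlike sync.Pool. The GC goroutine simply nils one buffer per tick and returns its slot; Get re-allocates lazily on the next checkout. A minimal usage sketch (the payload written is illustrative):

package main

import (
	"context"
	"fmt"

	"github.com/wundergraph/graphql-go-tools/v2/pkg/pool"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel() // stops the pool's GC goroutine

	p := pool.NewLimitBufferPool(ctx, pool.LimitBufferPoolOptions{}) // all defaults

	buf := p.Get()   // blocks while all MaxBuffers buffers are in use
	defer p.Put(buf) // resets the buffer; oversized ones are replaced

	buf.Buf.WriteString(`{"data":{}}`)
	fmt.Println(buf.Buf.Len())
}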
diff --git a/v2/pkg/pool/limitbufferpool_test.go b/v2/pkg/pool/limitbufferpool_test.go
new file mode 100644
index 000000000..c74650d22
--- /dev/null
+++ b/v2/pkg/pool/limitbufferpool_test.go
@@ -0,0 +1,75 @@
+package pool
+
+import (
+	"bytes"
+	"context"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"go.uber.org/goleak"
+)
+
+func TestLimitBufferPool(t *testing.T) {
+	defer goleak.VerifyNone(t,
+		goleak.IgnoreCurrent(), // ignore the test itself
+	)
+
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+	defer cancel()
+
+	p := NewLimitBufferPool(ctx, LimitBufferPoolOptions{
+		MaxBuffers:        4,
+		DefaultBufferSize: 1024,
+		MaxBufferSize:     1024 * 8,
+		GCTime:            time.Millisecond * 10,
+	})
+
+	buffers := make([]*ResolvableBuffer, 4)
+
+	for i := 0; i < 4; i++ {
+		buf := p.Get()
+		_, err := buf.Buf.Write(bytes.Repeat([]byte("a"), 64))
+		assert.NoError(t, err)
+		buffers[i] = buf
+	}
+
+	select {
+	case <-p.index:
+		t.Fatal("should not be able to get more buffers")
+	default:
+	}
+
+	for i := 0; i < 4; i++ {
+		p.Put(buffers[i])
+	}
+
+	b := p.Get()
+	assert.NotNil(t, b)
+	assert.Equal(t, 1024, b.Buf.Cap())
+	assert.Equal(t, 0, b.Buf.Len())
+
+	_, err := b.Buf.Write(bytes.Repeat([]byte("a"), 1024*9)) // write over the limit
+	assert.NoError(t, err)
+	assert.Equal(t, 1024*9, b.Buf.Len())
+	assert.Equal(t, 9472, b.Buf.Cap()) // capacity now exceeds MaxBufferSize
+	p.Put(b)                           // should reset the buffer
+
+	for i := 0; i < 4; i++ {
+		buf := p.Get()
+		assert.NotNil(t, buf.Buf)
+		assert.Equal(t, 1024, buf.Buf.Cap()) // default size
+		_, err = buf.Buf.Write(bytes.Repeat([]byte("a"), 2048))
+		assert.NoError(t, err)
+		p.Put(buf)
+	}
+
+	time.Sleep(time.Millisecond * 100) // wait for GC to run
+
+	for i := 0; i < 4; i++ {
+		buf := p.Get()
+		assert.NotNil(t, buf.Buf)
+		assert.Equal(t, 1024, buf.Buf.Cap()) // default size after GC
+		p.Put(buf)
+	}
+}