From 96d0d2c4d5f02cdf3f717ff3a31abcba962a4b66 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Fri, 5 Jul 2024 12:01:15 -0400 Subject: [PATCH] WIP --- bson/marshal.go | 22 +++++-------- bson/truncation_test.go | 1 - mongo/change_stream.go | 6 ++-- mongo/collection.go | 3 +- mongo/util.go | 70 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 80 insertions(+), 22 deletions(-) diff --git a/bson/marshal.go b/bson/marshal.go index 5d3407f162c..82c2d3ffc62 100644 --- a/bson/marshal.go +++ b/bson/marshal.go @@ -88,14 +88,6 @@ func Marshal(val interface{}) ([]byte, error) { // MarshalValue will use bson.DefaultRegistry to transform val into a BSON value. If val is a struct, this function will // inspect struct tags and alter the marshalling process accordingly. func MarshalValue(val interface{}) (Type, []byte, error) { - return MarshalValueWithRegistry(DefaultRegistry, val) -} - -// MarshalValueWithRegistry returns the BSON encoding of val using Registry r. -// -// Deprecated: Using a custom registry to marshal individual BSON values will not be supported in Go -// Driver 2.0. -func MarshalValueWithRegistry(r *Registry, val interface{}) (Type, []byte, error) { sw := bufPool.Get().(*bytes.Buffer) defer func() { // Proper usage of a sync.Pool requires each entry to have approximately @@ -115,8 +107,8 @@ func MarshalValueWithRegistry(r *Registry, val interface{}) (Type, []byte, error } }() sw.Reset() - vw := NewValueWriter(sw).(*valueWriter) - vwFlusher, err := vw.WriteDocumentElement("") + vwFlusher := NewValueWriter(sw).(*valueWriter) + vw, err := vwFlusher.WriteDocumentElement("") if err != nil { return 0, nil, err } @@ -124,8 +116,8 @@ func MarshalValueWithRegistry(r *Registry, val interface{}) (Type, []byte, error // get an Encoder and encode the value enc := encPool.Get().(*Encoder) defer encPool.Put(enc) - enc.Reset(vwFlusher) - enc.SetRegistry(r) + enc.Reset(vw) + enc.SetRegistry(DefaultRegistry) if err := enc.Encode(val); err != nil { return 0, nil, err } @@ -133,11 +125,11 @@ func MarshalValueWithRegistry(r *Registry, val interface{}) (Type, []byte, error // flush the bytes written because we cannot guarantee that a full document has been written // after the flush, *sw will be in the format // [value type, 0 (null byte to indicate end of empty element name), value bytes..] - if err := vw.Flush(); err != nil { + if err := vwFlusher.Flush(); err != nil { return 0, nil, err } - buf := sw.Bytes() - return Type(buf[0]), buf[2:], nil + typ := sw.Next(2) + return Type(typ[0]), sw.Bytes(), nil } // MarshalExtJSON returns the extended JSON encoding of val. diff --git a/bson/truncation_test.go b/bson/truncation_test.go index 019bc37b069..958a9b1915a 100644 --- a/bson/truncation_test.go +++ b/bson/truncation_test.go @@ -31,7 +31,6 @@ func unmarshalWithContext(t *testing.T, dc DecodeContext, data []byte, val inter } func TestTruncation(t *testing.T) { - t.Run("truncation", func(t *testing.T) { inputName := "truncation" inputVal := 4.7892 diff --git a/mongo/change_stream.go b/mongo/change_stream.go index f02010f53f9..5eb48bb3918 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -244,13 +244,12 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in // any errors from Marshaling. customOptions := make(map[string]bsoncore.Value) for optionName, optionValue := range cs.options.Custom { - bsonType, bsonData, err := bson.MarshalValueWithRegistry(cs.registry, optionValue) + optionValueBSON, err := marshalValueWithRegistry(cs.registry, optionValue) if err != nil { cs.err = err closeImplicitSession(cs.sess) return nil, cs.Err() } - optionValueBSON := bsoncore.Value{Type: bsoncore.Type(bsonType), Data: bsonData} customOptions[optionName] = optionValueBSON } cs.aggregate.CustomOptions(customOptions) @@ -260,13 +259,12 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in // any errors from Marshaling. cs.pipelineOptions = make(map[string]bsoncore.Value) for optionName, optionValue := range cs.options.CustomPipeline { - bsonType, bsonData, err := bson.MarshalValueWithRegistry(cs.registry, optionValue) + optionValueBSON, err := marshalValueWithRegistry(cs.registry, optionValue) if err != nil { cs.err = err closeImplicitSession(cs.sess) return nil, cs.Err() } - optionValueBSON := bsoncore.Value{Type: bsoncore.Type(bsonType), Data: bsonData} cs.pipelineOptions[optionName] = optionValueBSON } } diff --git a/mongo/collection.go b/mongo/collection.go index ea75a4a0cd2..8f90d28c59f 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -1079,11 +1079,10 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) { // any errors from Marshaling. customOptions := make(map[string]bsoncore.Value) for optionName, optionValue := range ao.Custom { - bsonType, bsonData, err := bson.MarshalValueWithRegistry(a.registry, optionValue) + optionValueBSON, err := marshalValueWithRegistry(a.registry, optionValue) if err != nil { return nil, err } - optionValueBSON := bsoncore.Value{Type: bsoncore.Type(bsonType), Data: bsonData} customOptions[optionName] = optionValueBSON } op.CustomOptions(customOptions) diff --git a/mongo/util.go b/mongo/util.go index 270fa24a255..7b6942fc8bd 100644 --- a/mongo/util.go +++ b/mongo/util.go @@ -5,3 +5,73 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 package mongo + +import ( + "bytes" + "sync" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" +) + +// Pool of buffers for marshalling BSON. +var bufPool = sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, +} + +// Pool of bson.Encoder. +var encPool = sync.Pool{ + New: func() interface{} { + return new(bson.Encoder) + }, +} + +func marshalValueWithRegistry(r *bson.Registry, val interface{}) (bsoncore.Value, error) { + sw := bufPool.Get().(*bytes.Buffer) + defer func() { + // Proper usage of a sync.Pool requires each entry to have approximately + // the same memory cost. To obtain this property when the stored type + // contains a variably-sized buffer, we add a hard limit on the maximum + // buffer to place back in the pool. We limit the size to 16MiB because + // that's the maximum wire message size supported by any current MongoDB + // server. + // + // Comment based on + // https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/fmt/print.go;l=147 + // + // Recycle byte slices that are smaller than 16MiB and at least half + // occupied. + if sw.Cap() < 16*1024*1024 && sw.Cap()/2 < sw.Len() { + bufPool.Put(sw) + } + }() + sw.Reset() + vwFlusher := bson.NewValueWriter(sw).(interface { + // The returned instance should be a *bson.valueWriter. + WriteDocumentElement(string) (bson.ValueWriter, error) + Flush() error + }) + vw, err := vwFlusher.WriteDocumentElement("") + if err != nil { + return bsoncore.Value{}, err + } + + enc := encPool.Get().(*bson.Encoder) + defer encPool.Put(enc) + enc.Reset(vw) + enc.SetRegistry(r) + if err := enc.Encode(val); err != nil { + return bsoncore.Value{}, err + } + + // flush the bytes written because we cannot guarantee that a full document has been written + // after the flush, *sw will be in the format + // [value type, 0 (null byte to indicate end of empty element name), value bytes..] + if err := vwFlusher.Flush(); err != nil { + return bsoncore.Value{}, err + } + typ := sw.Next(2) + return bsoncore.Value{Type: bsoncore.Type(typ[0]), Data: sw.Bytes()}, nil +}