Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Jul 5, 2024
1 parent 6d76528 commit 96d0d2c
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 22 deletions.
22 changes: 7 additions & 15 deletions bson/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,6 @@ func Marshal(val interface{}) ([]byte, error) {
// MarshalValue will use bson.DefaultRegistry to transform val into a BSON value. If val is a struct, this function will
// inspect struct tags and alter the marshalling process accordingly.
func MarshalValue(val interface{}) (Type, []byte, error) {
return MarshalValueWithRegistry(DefaultRegistry, val)
}

// MarshalValueWithRegistry returns the BSON encoding of val using Registry r.
//
// Deprecated: Using a custom registry to marshal individual BSON values will not be supported in Go
// Driver 2.0.
func MarshalValueWithRegistry(r *Registry, val interface{}) (Type, []byte, error) {
sw := bufPool.Get().(*bytes.Buffer)
defer func() {
// Proper usage of a sync.Pool requires each entry to have approximately
Expand All @@ -115,29 +107,29 @@ func MarshalValueWithRegistry(r *Registry, val interface{}) (Type, []byte, error
}
}()
sw.Reset()
vw := NewValueWriter(sw).(*valueWriter)
vwFlusher, err := vw.WriteDocumentElement("")
vwFlusher := NewValueWriter(sw).(*valueWriter)
vw, err := vwFlusher.WriteDocumentElement("")
if err != nil {
return 0, nil, err
}

// get an Encoder and encode the value
enc := encPool.Get().(*Encoder)
defer encPool.Put(enc)
enc.Reset(vwFlusher)
enc.SetRegistry(r)
enc.Reset(vw)
enc.SetRegistry(DefaultRegistry)
if err := enc.Encode(val); err != nil {
return 0, nil, err
}

// flush the bytes written because we cannot guarantee that a full document has been written
// after the flush, *sw will be in the format
// [value type, 0 (null byte to indicate end of empty element name), value bytes..]
if err := vw.Flush(); err != nil {
if err := vwFlusher.Flush(); err != nil {
return 0, nil, err
}
buf := sw.Bytes()
return Type(buf[0]), buf[2:], nil
typ := sw.Next(2)
return Type(typ[0]), sw.Bytes(), nil
}

// MarshalExtJSON returns the extended JSON encoding of val.
Expand Down
1 change: 0 additions & 1 deletion bson/truncation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ func unmarshalWithContext(t *testing.T, dc DecodeContext, data []byte, val inter
}

func TestTruncation(t *testing.T) {

t.Run("truncation", func(t *testing.T) {
inputName := "truncation"
inputVal := 4.7892
Expand Down
6 changes: 2 additions & 4 deletions mongo/change_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,12 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in
// any errors from Marshaling.
customOptions := make(map[string]bsoncore.Value)
for optionName, optionValue := range cs.options.Custom {
bsonType, bsonData, err := bson.MarshalValueWithRegistry(cs.registry, optionValue)
optionValueBSON, err := marshalValueWithRegistry(cs.registry, optionValue)
if err != nil {
cs.err = err
closeImplicitSession(cs.sess)
return nil, cs.Err()
}
optionValueBSON := bsoncore.Value{Type: bsoncore.Type(bsonType), Data: bsonData}
customOptions[optionName] = optionValueBSON
}
cs.aggregate.CustomOptions(customOptions)
Expand All @@ -260,13 +259,12 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in
// any errors from Marshaling.
cs.pipelineOptions = make(map[string]bsoncore.Value)
for optionName, optionValue := range cs.options.CustomPipeline {
bsonType, bsonData, err := bson.MarshalValueWithRegistry(cs.registry, optionValue)
optionValueBSON, err := marshalValueWithRegistry(cs.registry, optionValue)
if err != nil {
cs.err = err
closeImplicitSession(cs.sess)
return nil, cs.Err()
}
optionValueBSON := bsoncore.Value{Type: bsoncore.Type(bsonType), Data: bsonData}
cs.pipelineOptions[optionName] = optionValueBSON
}
}
Expand Down
3 changes: 1 addition & 2 deletions mongo/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -1079,11 +1079,10 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) {
// any errors from Marshaling.
customOptions := make(map[string]bsoncore.Value)
for optionName, optionValue := range ao.Custom {
bsonType, bsonData, err := bson.MarshalValueWithRegistry(a.registry, optionValue)
optionValueBSON, err := marshalValueWithRegistry(a.registry, optionValue)
if err != nil {
return nil, err
}
optionValueBSON := bsoncore.Value{Type: bsoncore.Type(bsonType), Data: bsonData}
customOptions[optionName] = optionValueBSON
}
op.CustomOptions(customOptions)
Expand Down
70 changes: 70 additions & 0 deletions mongo/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,73 @@
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package mongo

import (
"bytes"
"sync"

"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)

// Pool of buffers for marshalling BSON.
var bufPool = sync.Pool{
New: func() interface{} {
return new(bytes.Buffer)
},
}

// Pool of bson.Encoder.
var encPool = sync.Pool{
New: func() interface{} {
return new(bson.Encoder)
},
}

func marshalValueWithRegistry(r *bson.Registry, val interface{}) (bsoncore.Value, error) {
sw := bufPool.Get().(*bytes.Buffer)
defer func() {
// Proper usage of a sync.Pool requires each entry to have approximately
// the same memory cost. To obtain this property when the stored type
// contains a variably-sized buffer, we add a hard limit on the maximum
// buffer to place back in the pool. We limit the size to 16MiB because
// that's the maximum wire message size supported by any current MongoDB
// server.
//
// Comment based on
// https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/fmt/print.go;l=147
//
// Recycle byte slices that are smaller than 16MiB and at least half
// occupied.
if sw.Cap() < 16*1024*1024 && sw.Cap()/2 < sw.Len() {
bufPool.Put(sw)
}
}()
sw.Reset()
vwFlusher := bson.NewValueWriter(sw).(interface {
// The returned instance should be a *bson.valueWriter.
WriteDocumentElement(string) (bson.ValueWriter, error)
Flush() error
})
vw, err := vwFlusher.WriteDocumentElement("")
if err != nil {
return bsoncore.Value{}, err
}

enc := encPool.Get().(*bson.Encoder)
defer encPool.Put(enc)
enc.Reset(vw)
enc.SetRegistry(r)
if err := enc.Encode(val); err != nil {
return bsoncore.Value{}, err
}

// flush the bytes written because we cannot guarantee that a full document has been written
// after the flush, *sw will be in the format
// [value type, 0 (null byte to indicate end of empty element name), value bytes..]
if err := vwFlusher.Flush(); err != nil {
return bsoncore.Value{}, err
}
typ := sw.Next(2)
return bsoncore.Value{Type: bsoncore.Type(typ[0]), Data: sw.Bytes()}, nil
}

0 comments on commit 96d0d2c

Please sign in to comment.