Skip to content

Commit

Permalink
Add support for url, json number, and decimal
Browse files Browse the repository at this point in the history
GODRIVER-363
GODRIVER-343

Change-Id: I3a7e4198beb878b7f38f0a296b3be7fab604148f
  • Loading branch information
skriptble committed May 2, 2018
1 parent 5c4209e commit ff5ad99
Show file tree
Hide file tree
Showing 4 changed files with 379 additions and 55 deletions.
22 changes: 19 additions & 3 deletions bson/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import (
"fmt"
"io"
"math"
"net/url"
"reflect"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -355,16 +357,26 @@ func (d *Decoder) getReflectValue(v *Value, containerType reflect.Type, outer re

case tFloat64, tEmpty:
val = reflect.ValueOf(f)
case tJSONNumber:
val = reflect.ValueOf(strconv.FormatFloat(f, 'f', -1, 64)).Convert(tJSONNumber)
default:
return val, nil
}

case 0x2:
if containerType != tString && containerType != tEmpty {
str := v.StringValue()
switch containerType {
case tString, tEmpty:
val = reflect.ValueOf(str)
case tURL:
u, err := url.Parse(str)
if err != nil {
return val, err
}
val = reflect.ValueOf(u).Elem()
default:
return val, nil
}

val = reflect.ValueOf(v.StringValue())
case 0x4:
if containerType == tEmpty {
d := NewDecoder(bytes.NewBuffer(v.ReaderArray()))
Expand Down Expand Up @@ -547,6 +559,8 @@ func (d *Decoder) getReflectValue(v *Value, containerType reflect.Type, outer re

case tEmpty, tInt32, tInt64, tInt, tFloat32, tFloat64:
val = reflect.ValueOf(i).Convert(containerType)
case tJSONNumber:
val = reflect.ValueOf(strconv.FormatInt(int64(i), 10)).Convert(tJSONNumber)
default:
return val, nil
}
Expand Down Expand Up @@ -609,6 +623,8 @@ func (d *Decoder) getReflectValue(v *Value, containerType reflect.Type, outer re
val = reflect.ValueOf(float32(i))
case tFloat64:
val = reflect.ValueOf(float64(i))
case tJSONNumber:
val = reflect.ValueOf(strconv.FormatInt(i, 10)).Convert(tJSONNumber)
}

case 0x13:
Expand Down
101 changes: 101 additions & 0 deletions bson/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ package bson

import (
"bytes"
"encoding/json"
"net/url"
"reflect"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/mongodb/mongo-go-driver/bson/decimal"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -2007,6 +2010,47 @@ func TestDecoder(t *testing.T) {
return
}

require.True(t, reflect.DeepEqual(tc.expected, tc.actual))
})
}
})
t.Run("decimal128", func(t *testing.T) {
decimal128, err := decimal.ParseDecimal128("1.5e10")
if err != nil {
t.Errorf("Error parsing decimal128: %v", err)
t.FailNow()
}
testCases := []struct {
name string
reader []byte
expected interface{}
actual interface{}
err error
}{
{
"decimal128",
docToBytes(NewDocument(EC.Decimal128("a", decimal128))),
&struct {
A decimal.Decimal128
}{
A: decimal128,
},
&struct {
A decimal.Decimal128
}{},
nil,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
d := NewDecoder(bytes.NewBuffer(tc.reader))

err := d.Decode(tc.actual)
requireErrEqual(t, tc.err, err)
if err != nil {
return
}

require.True(t, reflect.DeepEqual(tc.expected, tc.actual))
})
}
Expand Down Expand Up @@ -2382,6 +2426,63 @@ func TestDecoder(t *testing.T) {
})
}
})
t.Run("pluggable types", func(t *testing.T) {
murl, err := url.Parse("https://mongodb.com/random-url?hello=world")
if err != nil {
t.Errorf("Error parsing URL: %v", err)
t.FailNow()
}
testCases := []struct {
name string
reader []byte
expected interface{}
actual interface{}
err error
}{
{
"*url.URL",
docToBytes(NewDocument(EC.String("a", murl.String()))),
&struct {
A *url.URL
}{
A: murl,
},
&struct {
A *url.URL
}{},
nil,
},
{
"json.Number",
docToBytes(NewDocument(EC.Int64("a", 5), EC.Double("b", 10.10))),
&struct {
A json.Number
B json.Number
}{
A: json.Number("5"),
B: json.Number("10.1"),
},
&struct {
A json.Number
B json.Number
}{},
nil,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
d := NewDecoder(bytes.NewBuffer(tc.reader))

err := d.Decode(tc.actual)
requireErrEqual(t, tc.err, err)
if err != nil {
return
}

require.True(t, reflect.DeepEqual(tc.expected, tc.actual))
})
}
})
}

func elementSliceEqual(t *testing.T, e1 []*Element, e2 []*Element) {
Expand Down
88 changes: 86 additions & 2 deletions bson/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@
package bson

import (
"encoding/json"
"errors"
"fmt"
"io"
"math"
"net/url"
"reflect"
"strconv"
"strings"
"time"

"github.com/mongodb/mongo-go-driver/bson/decimal"
"github.com/mongodb/mongo-go-driver/bson/objectid"
)

Expand All @@ -25,6 +28,8 @@ var ErrEncoderNilWriter = errors.New("encoder.Encode called on Encoder with nil
var tByteSlice = reflect.TypeOf(([]byte)(nil))
var tByte = reflect.TypeOf(byte(0x00))
var tElement = reflect.TypeOf((*Element)(nil))
var tURL = reflect.TypeOf(url.URL{})
var tJSONNumber = reflect.TypeOf(json.Number(""))

// Marshaler describes a type that can marshal a BSON representation of itself into bytes.
type Marshaler interface {
Expand Down Expand Up @@ -278,6 +283,7 @@ func (e *encoder) encodeMap(val reflect.Value) ([]*Element, error) {
mapkeys := val.MapKeys()
elems := make([]*Element, 0, val.Len())
for _, rkey := range mapkeys {
orig := rkey
rkey = e.underlyingVal(rkey)

var key string
Expand All @@ -297,9 +303,15 @@ func (e *encoder) encodeMap(val reflect.Value) ([]*Element, error) {
case reflect.String:
key = rkey.String()
default:
if rkey.Type() == tOID {
switch rkey.Type() {
case tOID:
key = fmt.Sprintf("%s", rkey.Interface())
} else {
case tURL:
rkey = orig
key = fmt.Sprintf("%s", rkey.Interface())
case tDecimal:
key = fmt.Sprintf("%s", rkey.Interface())
default:
return nil, fmt.Errorf("Unsupported map key type %s", rkey.Kind())
}
}
Expand All @@ -316,6 +328,24 @@ func (e *encoder) encodeMap(val reflect.Value) ([]*Element, error) {
case Reader:
elems = append(elems, EC.SubDocumentFromReader(key, t))
continue
case json.Number:
// We try to do an int first
if i64, err := t.Int64(); err == nil {
elems = append(elems, EC.Int64(key, i64))
continue
}
f64, err := t.Float64()
if err != nil {
return nil, fmt.Errorf("Invalid json.Number used as map value: %s", err)
}
elems = append(elems, EC.Double(key, f64))
continue
case *url.URL:
elems = append(elems, EC.String(key, t.String()))
continue
case decimal.Decimal128:
elems = append(elems, EC.Decimal128(key, t))
continue
}
rval = e.underlyingVal(rval)

Expand Down Expand Up @@ -343,6 +373,24 @@ func (e *encoder) encodeSlice(val reflect.Value) ([]*Element, error) {
case Reader:
elems = append(elems, EC.SubDocumentFromReader(key, t))
continue
case json.Number:
// We try to do an int first
if i64, err := t.Int64(); err == nil {
elems = append(elems, EC.Int64(key, i64))
continue
}
f64, err := t.Float64()
if err != nil {
return nil, fmt.Errorf("Invalid json.Number used as map value: %s", err)
}
elems = append(elems, EC.Double(key, f64))
continue
case *url.URL:
elems = append(elems, EC.String(key, t.String()))
continue
case decimal.Decimal128:
elems = append(elems, EC.Decimal128(key, t))
continue
}
sval = e.underlyingVal(sval)
elem, err := e.elemFromValue(key, sval, false)
Expand Down Expand Up @@ -371,6 +419,24 @@ func (e *encoder) encodeSliceAsArray(rval reflect.Value, minsize bool) ([]*Value
case Reader:
vals = append(vals, VC.DocumentFromReader(t))
continue
case json.Number:
// We try to do an int first
if i64, err := t.Int64(); err == nil {
vals = append(vals, VC.Int64(i64))
continue
}
f64, err := t.Float64()
if err != nil {
return nil, fmt.Errorf("Invalid json.Number used as map value: %s", err)
}
vals = append(vals, VC.Double(f64))
continue
case *url.URL:
vals = append(vals, VC.String(t.String()))
continue
case decimal.Decimal128:
vals = append(vals, VC.Decimal128(t))
continue
}

sval = e.underlyingVal(sval)
Expand Down Expand Up @@ -429,6 +495,24 @@ func (e *encoder) encodeStruct(val reflect.Value) ([]*Element, error) {
case Reader:
elems = append(elems, EC.SubDocumentFromReader(key, t))
continue
case json.Number:
// We try to do an int first
if i64, err := t.Int64(); err == nil {
elems = append(elems, EC.Int64(key, i64))
continue
}
f64, err := t.Float64()
if err != nil {
return nil, fmt.Errorf("Invalid json.Number used as map value: %s", err)
}
elems = append(elems, EC.Double(key, f64))
continue
case *url.URL:
elems = append(elems, EC.String(key, t.String()))
continue
case decimal.Decimal128:
elems = append(elems, EC.Decimal128(key, t))
continue
}
field = e.underlyingVal(field)

Expand Down
Loading

0 comments on commit ff5ad99

Please sign in to comment.