-
Notifications
You must be signed in to change notification settings - Fork 2
/
objectid.go
183 lines (156 loc) · 5.5 KB
/
objectid.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
// Package oid provides an easy-to-use, easy-to-integrate abstraction layer between your code and the primitive package of the official mongo driver.
//
// It helps bridge the gap between the new mongo-go-driver and earlier community-run drivers such as mgo.
// For people not migrating, it also covers some of the pitfalls and frustrations of mongo-go-driver's primitive package, which is extremely brittle, leaks too much of the driver to the clients of the API, and can even cause unexpected panics when unmarshalling.
//
// This package follows the community-run driver convention of representing ObjectIDs as strings instead of [12]byte arrays, allowing for a much smoother development experience.
//
// Features
//
// 1. ObjectID strings in a JSON payload are unmarshalled into oid.ObjectID values automatically, including MongoDB Extended JSON ({"$oid": "..."}), and oid.ObjectID values are marshalled to and unmarshalled from BSON ObjectIDs when interacting with the driver.
//
// 2. No panics on JSON unmarshalling.
//
// 3. Uses a string type so the driver does not bleed into your API, giving more control to the developer.
//
// 4. Fixes vet errors that were rampant in the community drivers.
//
// 5. Combines the best features of the community drivers and the primitive package.
//
// 6. Makes migrating significantly easier.
package oid
import (
"bytes"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx"
)
// ObjectID is a MongoDB ObjectID stored as a string of its raw bytes, not its
// hex text. A valid ObjectID is exactly 12 bytes long (see Valid). The empty
// string is the "no ID" value: JSON decoding in this package produces it for
// an empty JSON string.
//
// JSON encoding uses the 24-character lowercase hex string. JSON decoding also
// accepts MongoDB Extended JSON of the form {"$oid": "<24 hex chars>"}.
// https://www.mongodb.com/docs/manual/reference/method/ObjectId/
type ObjectID string
// ObjectIDHex returns the ObjectID whose hex representation is s.
//
// s must be exactly 24 hexadecimal characters (upper or lower case). For any
// other input it returns the empty ObjectID and a non-nil error. See also
// IsObjectIDHex.
func ObjectIDHex(s string) (ObjectID, error) {
	d, err := hex.DecodeString(s)
	if err != nil || len(d) != 12 {
		// The value is meaningless on error. Return the zero ObjectID instead
		// of leaking whatever bytes were decoded before the failure.
		return "", fmt.Errorf("invalid input to ObjectIDHex: %q", s)
	}
	return ObjectID(d), nil
}
// IsObjectIDHex reports whether s is a valid hex representation of an
// ObjectID: exactly 24 hexadecimal characters, in either case. These are the
// strings ObjectIDHex accepts.
func IsObjectIDHex(s string) bool {
	if len(s) != 24 {
		return false
	}
	_, decodeErr := hex.DecodeString(s)
	return decodeErr == nil
}
// NewObjectID returns a new ObjectID generated by the mongo driver's
// primitive.NewObjectID.
func NewObjectID() ObjectID {
	// primitive.ObjectID is a [12]byte, which is exactly the raw form ObjectID
	// stores. Copying the bytes directly avoids a hex encode/decode round trip
	// whose error return had to be discarded.
	raw := primitive.NewObjectID()
	return ObjectID(raw[:])
}
// String implements fmt.Stringer. It renders id as ObjectID("<lowercase hex>"),
// e.g. ObjectID("4d88e15b60f486e428412dc9"). The empty ObjectID renders as
// ObjectID("").
func (id ObjectID) String() string {
	const open, closing = `ObjectID("`, `")`
	return open + hex.EncodeToString([]byte(id)) + closing
}
// Hex returns the lowercase hexadecimal encoding of id's raw bytes. This is 24
// characters for a valid ObjectID and "" for the empty one.
func (id ObjectID) Hex() string {
	encoded := make([]byte, hex.EncodedLen(len(id)))
	hex.Encode(encoded, []byte(id))
	return string(encoded)
}
// Valid reports whether id holds exactly the 12 raw bytes of an ObjectID.
// The empty ObjectID "" is not valid.
func (id ObjectID) Valid() bool {
	// Any 12 bytes hex-encode to 24 valid hex characters, so re-parsing
	// id.Hex() through primitive.ObjectIDFromHex could only ever fail on
	// length. Checking the length directly is equivalent and avoids the
	// allocation and the driver call.
	return len(id) == 12
}
// MarshalBSONValue implements the mongo driver's ValueMarshaler (encoding)
// interface. It encodes id as a BSON ObjectID whose value is id's 12 raw bytes.
//
// It returns an error if id is not exactly 12 bytes long, which includes the
// empty ObjectID.
func (id ObjectID) MarshalBSONValue() (bsontype.Type, []byte, error) {
	if len(id) != 12 {
		return bsontype.ObjectID, []byte{}, fmt.Errorf("%s is not an ObjectID", id.String())
	}
	// A BSON ObjectID value is just its 12 raw bytes, which is what id already
	// stores. This replaces the hex -> primitive.ObjectID -> x/bsonx round trip,
	// which produced the same bytes via an experimental driver package.
	return bsontype.ObjectID, []byte(id), nil
}
// UnmarshalBSONValue implements the mongo driver's ValueUnmarshaler (decoding)
// interface.
//
// It accepts a BSON ObjectID, or a BSON string holding the 24-character hex
// form of an ObjectID. Any other BSON type is rejected. On error id is left
// unchanged. Underlying errors are wrapped, so they can be inspected with
// errors.Is / errors.As.
func (id *ObjectID) UnmarshalBSONValue(t bsontype.Type, b []byte) error {
	if t != bsontype.ObjectID && t != bsontype.String {
		return fmt.Errorf("type %s cannot be converted to %s", t, bsontype.ObjectID)
	}
	val := bsonx.Undefined()
	if err := val.UnmarshalBSONValue(t, b); err != nil {
		return fmt.Errorf("invalid objectID from source: %w", err)
	}
	// Normalize both accepted types to hex text so ObjectIDHex performs the
	// length and format validation in one place.
	var hexText string
	switch t {
	case bsontype.ObjectID:
		hexText = val.ObjectID().Hex()
	default: // bsontype.String, guaranteed by the guard above
		// NOTE(review): assumes bsonx.Val.String returns the raw string payload
		// for String values; confirm against the pinned driver version.
		hexText = val.String()
	}
	parsed, err := ObjectIDHex(hexText)
	if err != nil {
		return fmt.Errorf("error occurred while trying to convert, reason: %w", err)
	}
	*id = parsed
	return nil
}
// MarshalJSON implements json.Marshaler. It encodes id as a JSON string holding
// its lowercase hex form, e.g. "4d88e15b60f486e428412dc9". The empty ObjectID
// encodes as "".
func (id ObjectID) MarshalJSON() ([]byte, error) {
	// Layout: opening quote, hex digits, closing quote.
	out := make([]byte, hex.EncodedLen(len(id))+2)
	out[0] = '"'
	hex.Encode(out[1:len(out)-1], []byte(id))
	out[len(out)-1] = '"'
	return out, nil
}
// nullBytes is the JSON null literal, which UnmarshalJSON decodes to the empty ObjectID.
var nullBytes = []byte("null")

// UnmarshalJSON implements json.Unmarshaler. It accepts:
//
//   - a JSON string holding the 24-character hex form, e.g. "4d88e15b60f486e428412dc9";
//   - MongoDB Extended JSON, e.g. {"$oid": "4d88e15b60f486e428412dc9"};
//   - null, "" or {"$oid": ""}, each of which sets id to the empty ObjectID.
//
// Any other input returns an error and leaves id unchanged.
func (id *ObjectID) UnmarshalJSON(b []byte) error {
	// encoding/json passes null to UnmarshalJSON for non-pointer ObjectID
	// fields. Handle it before json.Unmarshal below, which would decode it to
	// nil and fail both type checks.
	if bytes.Equal(b, nullBytes) {
		*id = ""
		return nil
	}
	str, err := objectIDTextFromJSON(b)
	if err != nil {
		return err
	}
	if str == "" {
		*id = ""
		return nil
	}
	if len(str) != 24 {
		return fmt.Errorf("invalid ObjectID in JSON: %s", str)
	}
	var buf [12]byte
	if _, err = hex.Decode(buf[:], []byte(str)); err != nil {
		return fmt.Errorf("invalid ObjectID in JSON: %s (%w)", string(b), err)
	}
	*id = ObjectID(buf[:])
	return nil
}

// objectIDTextFromJSON returns the string payload of a JSON string value or of
// an Extended JSON {"$oid": "..."} document. Any other JSON value is an error.
func objectIDTextFromJSON(b []byte) (string, error) {
	var res interface{}
	if err := json.Unmarshal(b, &res); err != nil {
		return "", err
	}
	switch v := res.(type) {
	case string:
		return v, nil
	case map[string]interface{}:
		if s, ok := v["$oid"].(string); ok {
			return s, nil
		}
	}
	return "", errors.New("not an extended JSON ObjectID")
}