Skip to content

Commit

Permalink
simplify boilerplate
Browse files Browse the repository at this point in the history
  • Loading branch information
nicpottier committed Feb 12, 2019
1 parent 58eab92 commit 05165d3
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 58 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*.dll
*.so
*.dylib
.vscode

# Test binary, build with `go test -c`
*.test
Expand Down
16 changes: 4 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,15 @@ func (i CustomID) MarshalJSON() ([]byte, error) {
}

func (i *CustomID) UnmarshalJSON(b []byte) error {
val, err := null.UnmarshalInt(b)
*i = CustomID(val)
return err
return null.UnmarshalInt(b, (*null.Int)(i))
}

func (i CustomID) Value() (driver.Value, error) {
return null.Int(i).Value()
}

func (i *CustomID) Scan(value interface{}) error {
val, err := null.ScanInt(value)
*i = CustomID(val)
return err
return null.ScanInt(value, (*null.Int)(i))
}
```

Expand All @@ -54,19 +50,15 @@ func (s CustomString) MarshalJSON() ([]byte, error) {
}

func (s *CustomString) UnmarshalJSON(b []byte) error {
val, err := null.UnmarshalString(b)
*s = CustomString(val)
return err
return null.UnmarshalString(b, (*null.String)(s))
}

func (s CustomString) Value() (driver.Value, error) {
return null.String(s).Value()
}

func (s *CustomString) Scan(value interface{}) error {
val, err := null.ScanString(value)
*s = CustomString(val)
return err
return null.ScanString(value, (*null.String)(s))
}
```

Expand Down
57 changes: 28 additions & 29 deletions null.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,32 @@ const NullInt = Int(0)

// UnmarshalInt is a utility method that can be used to unmarshal a json value to an Int and error
// In the case of an error, null or "" values, NullInt is returned
func UnmarshalInt(b []byte) (Int, error) {
var val *int64
func UnmarshalInt(b []byte, v *Int) error {
val := int64(0)
err := json.Unmarshal(b, &val)
if err != nil || val == nil {
return NullInt, err
if err != nil {
return err
}
return Int(*val), nil
*v = Int(val)
return nil
}

// ScanInt is a utility method that can be used to scan a db value and return an Int and error
// In the case of an error, null or "" values, NullInt is returned
func ScanInt(value interface{}) (Int, error) {
func ScanInt(value interface{}, v *Int) error {
val := &sql.NullInt64{}
err := val.Scan(value)
if err != nil {
return NullInt, err
return err
}

if !val.Valid {
return NullInt, nil
*v = NullInt
return nil
}

return Int(val.Int64), nil
*v = Int(val.Int64)
return nil
}

// MarshalJSON marshals our int to JSON. 0 values will be marshalled as null
Expand All @@ -50,16 +53,12 @@ func (i Int) MarshalJSON() ([]byte, error) {

// UnmarshalJSON unmarshals our JSON to int. null values will be marshalled to 0
func (i *Int) UnmarshalJSON(b []byte) error {
val, err := UnmarshalInt(b)
*i = val
return err
return UnmarshalInt(b, i)
}

// Scan implements the Scanner interface for Int
func (i *Int) Scan(value interface{}) error {
val, err := ScanInt(value)
*i = val
return err
return ScanInt(value, i)
}

// Value implements the driver Valuer interface for Int
Expand All @@ -79,33 +78,37 @@ const NullString = String("")

// UnmarshalString is a utility method that can be used to unmarshal a json value to a String and error
// In the case of an error, null or "" values, NullString is returned
func UnmarshalString(b []byte) (String, error) {
func UnmarshalString(b []byte, v *String) error {
var val *string
err := json.Unmarshal(b, &val)
if err != nil {
return NullString, err
return err
}
if val == nil {
return NullString, nil
*v = NullString
return nil
}

return String(*val), nil
*v = String(*val)
return nil
}

// ScanString is a utility method that can be used to scan a db value and return a String and error
// In the case of an error, null or "" values, NullString is returned
func ScanString(value interface{}) (String, error) {
func ScanString(value interface{}, v *String) error {
val := &sql.NullString{}
err := val.Scan(value)
if err != nil {
return NullString, err
return err
}

if !val.Valid {
return NullString, nil
*v = NullString
return nil
}

return String(val.String), nil
*v = String(val.String)
return nil
}

// MarshalJSON marshals our string to JSON. "" values will be marshalled as null
Expand All @@ -118,16 +121,12 @@ func (s String) MarshalJSON() ([]byte, error) {

// UnmarshalJSON unmarshals our json to a string. null values will be marshalled to ""
func (s *String) UnmarshalJSON(b []byte) error {
val, err := UnmarshalString(b)
*s = val
return err
return UnmarshalString(b, s)
}

// Scan implements the Scanner interface for String
func (s *String) Scan(value interface{}) error {
val, err := ScanString(value)
*s = val
return err
return ScanString(value, s)
}

// Value implements the driver Valuer interface for String
Expand Down
26 changes: 9 additions & 17 deletions null_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,15 @@ func (i CustomID) MarshalJSON() ([]byte, error) {
}

func (i *CustomID) UnmarshalJSON(b []byte) error {
val, err := UnmarshalInt(b)
*i = CustomID(val)
return err
return UnmarshalInt(b, (*Int)(i))
}

func (i CustomID) Value() (driver.Value, error) {
return Int(i).Value()
}

func (i *CustomID) Scan(value interface{}) error {
val, err := ScanInt(value)
*i = CustomID(val)
return err
return ScanInt(value, (*Int)(i))
}

type OtherCustom = Int
Expand Down Expand Up @@ -70,8 +66,8 @@ func TestCustomInt(t *testing.T) {
id := CustomID(10)
err = json.Unmarshal(b, &id)
assert.NoError(t, err)
assert.True(t, tc.Value == id)
assert.True(t, tc.Test == id)
assert.True(t, tc.Value == id, "%d: %s not equal to %s", i, tc.Value, id)
assert.True(t, tc.Test == id, "%d: %s not equal to %s", i, tc.Test, id)

_, err = db.Exec(`INSERT INTO custom_id(id) VALUES($1)`, tc.Value)
assert.NoError(t, err)
Expand Down Expand Up @@ -168,19 +164,15 @@ func (s CustomString) MarshalJSON() ([]byte, error) {
}

func (s *CustomString) UnmarshalJSON(b []byte) error {
val, err := UnmarshalString(b)
*s = CustomString(val)
return err
return UnmarshalString(b, (*String)(s))
}

func (s CustomString) Value() (driver.Value, error) {
return String(s).Value()
}

func (s *CustomString) Scan(value interface{}) error {
val, err := ScanString(value)
*s = CustomString(val)
return err
return ScanString(value, (*String)(s))
}

const NullCustomString = CustomString("")
Expand Down Expand Up @@ -266,18 +258,18 @@ func TestString(t *testing.T) {
{NullString, "null", nil},
}

for _, tc := range tcs {
for i, tc := range tcs {
_, err = db.Exec(`DELETE FROM custom_string;`)
assert.NoError(t, err)

b, err := json.Marshal(tc.Value)
assert.NoError(t, err)
assert.True(t, tc.JSON == string(b), "%s not equal to %s", tc.JSON, string(b))
assert.True(t, tc.JSON == string(b), "%d: %s not equal to %s", i, tc.JSON, string(b))

str := String("blah")
err = json.Unmarshal(b, &str)
assert.NoError(t, err)
assert.True(t, tc.Value == str)
assert.True(t, tc.Value == str, "%d: %s not equal to %s", i, tc.Value, str)

_, err = db.Exec(`INSERT INTO custom_string(string) VALUES($1)`, tc.Value)
assert.NoError(t, err)
Expand Down

0 comments on commit 05165d3

Please sign in to comment.