Skip to content

Commit

Permalink
Merge pull request cedar-policy#64 from strongdm/json-encoding-cleanup
Browse files Browse the repository at this point in the history
types: clean up JSON unmarshaling story for extension types
  • Loading branch information
patjakdev authored Nov 12, 2024
2 parents 96c854e + 3b282ac commit c7cc1e3
Show file tree
Hide file tree
Showing 10 changed files with 235 additions and 543 deletions.
44 changes: 6 additions & 38 deletions types/datetime.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package types

import (
"bytes"
"encoding/json"
"errors"
"fmt"
"strconv"
"time"
Expand Down Expand Up @@ -247,46 +245,16 @@ func (a Datetime) String() string {

// UnmarshalJSON implements encoding/json.Unmarshaler for Datetime
//
// It is capable of unmarshaling 4 different representations supported by Cedar
// - { "__extn": { "fn": "datetime", "arg": "1970-01-01" }}
// - { "fn": "datetime", "arg": "1970-01-01" }
// - "datetime(\"1970-01-01\")"
// - "1970-01-01"
// It is capable of unmarshaling 3 different representations supported by Cedar
// - { "__extn": { "fn": "datetime", "arg": "1970-01-01" }}
// - { "fn": "datetime", "arg": "1970-01-01" }
// - "1970-01-01"
func (a *Datetime) UnmarshalJSON(b []byte) error {
var arg string
if bytes.HasPrefix(b, []byte(`"datetime(\"`)) && bytes.HasSuffix(b, []byte(`\")"`)) {
arg = string(b[12 : len(b)-4])
} else if len(b) > 0 && b[0] == '"' {
if err := json.Unmarshal(b, &arg); err != nil {
return errors.Join(errJSONDecode, err)
}
} else {
var res extValueJSON
if err := json.Unmarshal(b, &res); err != nil {
return errors.Join(errJSONDecode, err)
}
if res.Extn == nil {
// If we didn't find an Extn, maybe it's just an extn.
var res2 extn
json.Unmarshal(b, &res2)
// We've tried Ext.Fn and Fn, so no good.
if res2.Fn == "" {
return errJSONExtNotFound
}
if res2.Fn != "datetime" {
return errJSONExtFnMatch
}
arg = res2.Arg
} else if res.Extn.Fn != "datetime" {
return errJSONExtFnMatch
} else {
arg = res.Extn.Arg
}
}
aa, err := ParseDatetime(arg)
aa, err := unmarshalExtensionValue(b, "datetime", ParseDatetime)
if err != nil {
return err
}

*a = aa
return nil
}
Expand Down
22 changes: 20 additions & 2 deletions types/datetime_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package types_test

import (
"encoding/json"
"fmt"
"testing"
"time"
Expand Down Expand Up @@ -197,8 +198,25 @@ func TestDatetime(t *testing.T) {

t.Run("MarshalJSON", func(t *testing.T) {
t.Parallel()
bs, err := types.NewDatetime(time.UnixMilli(42)).MarshalJSON()
expected := `{
"__extn": {
"fn": "datetime",
"arg": "1970-01-01T00:00:00.042Z"
}
}`
dt1 := types.NewDatetime(time.UnixMilli(42))
testutil.JSONMarshalsTo(t, dt1, expected)

var dt2 types.Datetime
err := json.Unmarshal([]byte(expected), &dt2)
testutil.OK(t, err)
testutil.Equals(t, string(bs), `{"__extn":{"fn":"datetime","arg":"1970-01-01T00:00:00.042Z"}}`)
testutil.Equals(t, dt1, dt2)
})

t.Run("UnmarshalJSON/error", func(t *testing.T) {
t.Parallel()
var dt2 types.Datetime
err := json.Unmarshal([]byte("{}"), &dt2)
testutil.Error(t, err)
})
}
33 changes: 9 additions & 24 deletions types/decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,34 +160,19 @@ func (d Decimal) String() string {
return res[:right]
}

// UnmarshalJSON implements encoding/json.Unmarshaler for Decimal
//
// It is capable of unmarshaling 3 different representations supported by Cedar
// - { "__extn": { "fn": "decimal", "arg": "1234.5678" }}
// - { "fn": "decimal", "arg": "1234.5678" }
// - "1234.5678"
func (d *Decimal) UnmarshalJSON(b []byte) error {
var arg string
if len(b) > 0 && b[0] == '"' {
if err := json.Unmarshal(b, &arg); err != nil {
return errors.Join(errJSONDecode, err)
}
} else {
// NOTE: cedar supports two other forms, for now we're only supporting the smallest implicit and explicit form.
// The following are not supported:
// "decimal(\"1234.5678\")"
// {"fn":"decimal","arg":"1234.5678"}
var res extValueJSON
if err := json.Unmarshal(b, &res); err != nil {
return errors.Join(errJSONDecode, err)
}
if res.Extn == nil {
return errJSONExtNotFound
}
if res.Extn.Fn != "decimal" {
return errJSONExtFnMatch
}
arg = res.Extn.Arg
}
vv, err := ParseDecimal(arg)
dd, err := unmarshalExtensionValue(b, "decimal", ParseDecimal)
if err != nil {
return err
}
*d = vv

*d = dd
return nil
}

Expand Down
25 changes: 25 additions & 0 deletions types/decimal_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package types_test

import (
"encoding/json"
"fmt"
"testing"

Expand Down Expand Up @@ -287,4 +288,28 @@ func TestDecimal(t *testing.T) {
string(testutil.Must(types.NewDecimal(42, 0)).MarshalCedar()),
`decimal("42.0")`)
})

t.Run("MarshalJSON", func(t *testing.T) {
t.Parallel()
expected := `{
"__extn": {
"fn": "decimal",
"arg": "1234.5678"
}
}`
d1 := testutil.Must(types.NewDecimal(12345678, -4))
testutil.JSONMarshalsTo(t, d1, expected)

var d2 types.Decimal
err := json.Unmarshal([]byte(expected), &d2)
testutil.OK(t, err)
testutil.Equals(t, d1, d2)
})

t.Run("UnmarshalJSON/error", func(t *testing.T) {
t.Parallel()
var dt2 types.Decimal
err := json.Unmarshal([]byte("{}"), &dt2)
testutil.Error(t, err)
})
}
43 changes: 6 additions & 37 deletions types/duration.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package types
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"math"
"strconv"
Expand Down Expand Up @@ -214,46 +213,16 @@ func (v Duration) String() string {

// UnmarshalJSON implements encoding/json.Unmarshaler for Duration
//
// It is capable of unmarshaling 4 different representations supported by Cedar
// - { "__extn": { "fn": "duration", "arg": "1h10m" }}
// - { "fn": "duration", "arg": "1h10m" }
// - "duration(\"1h10m\")"
// - "1h10m"
// It is capable of unmarshaling 3 different representations supported by Cedar
// - { "__extn": { "fn": "duration", "arg": "1h10m" }}
// - { "fn": "duration", "arg": "1h10m" }
// - "1h10m"
func (v *Duration) UnmarshalJSON(b []byte) error {
var arg string
if bytes.HasPrefix(b, []byte(`"duration(\"`)) && bytes.HasSuffix(b, []byte(`\")"`)) {
arg = string(b[12 : len(b)-4])
} else if len(b) > 0 && b[0] == '"' {
if err := json.Unmarshal(b, &arg); err != nil {
return errors.Join(errJSONDecode, err)
}
} else {
var res extValueJSON
if err := json.Unmarshal(b, &res); err != nil {
return errors.Join(errJSONDecode, err)
}
if res.Extn == nil {
// If we didn't find an Extn, maybe it's just an extn.
var res2 extn
json.Unmarshal(b, &res2)
// We've tried Ext.Fn and Fn, so no good.
if res2.Fn == "" {
return errJSONExtNotFound
}
if res2.Fn != "duration" {
return errJSONExtFnMatch
}
arg = res2.Arg
} else if res.Extn.Fn != "duration" {
return errJSONExtFnMatch
} else {
arg = res.Extn.Arg
}
}
vv, err := ParseDuration(arg)
vv, err := unmarshalExtensionValue(b, "duration", ParseDuration)
if err != nil {
return err
}

*v = vv
return nil
}
Expand Down
22 changes: 20 additions & 2 deletions types/duration_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package types_test

import (
"encoding/json"
"fmt"
"testing"
"time"
Expand Down Expand Up @@ -176,8 +177,25 @@ func TestDuration(t *testing.T) {

t.Run("MarshalJSON", func(t *testing.T) {
t.Parallel()
bs, err := types.NewDuration(42 * time.Millisecond).MarshalJSON()
expected := `{
"__extn": {
"fn": "duration",
"arg": "42ms"
}
}`
d1 := types.NewDuration(42 * time.Millisecond)
testutil.JSONMarshalsTo(t, d1, expected)

var d2 types.Duration
err := json.Unmarshal([]byte(expected), &d2)
testutil.OK(t, err)
testutil.Equals(t, string(bs), `{"__extn":{"fn":"duration","arg":"42ms"}}`)
testutil.Equals(t, d1, d2)
})

t.Run("UnmarshalJSON/error", func(t *testing.T) {
t.Parallel()
var dt2 types.Duration
err := json.Unmarshal([]byte("{}"), &dt2)
testutil.Error(t, err)
})
}
32 changes: 8 additions & 24 deletions types/ipaddr.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package types

import (
"encoding/json"
"errors"
"fmt"
"hash/fnv"
"net/netip"
Expand Down Expand Up @@ -104,33 +103,18 @@ func (c IPAddr) Contains(o IPAddr) bool {
return c.Prefix().Contains(o.Addr()) && c.Prefix().Bits() <= o.Prefix().Bits()
}

// UnmarshalJSON implements encoding/json.Unmarshaler for IPAddr
//
// It is capable of unmarshaling 3 different representations supported by Cedar
// - { "__extn": { "fn": "ip", "arg": "12.34.56.78" }}
// - { "fn": "ip", "arg": "12.34.56.78" }
// - "12.34.56.78"
func (v *IPAddr) UnmarshalJSON(b []byte) error {
var arg string
if len(b) > 0 && b[0] == '"' {
if err := json.Unmarshal(b, &arg); err != nil {
return errors.Join(errJSONDecode, err)
}
} else {
// NOTE: cedar supports two other forms, for now we're only supporting the smallest implicit explicit form.
// The following are not supported:
// "ip(\"192.168.0.42\")"
// {"fn":"ip","arg":"192.168.0.42"}
var res extValueJSON
if err := json.Unmarshal(b, &res); err != nil {
return errors.Join(errJSONDecode, err)
}
if res.Extn == nil {
return errJSONExtNotFound
}
if res.Extn.Fn != "ip" {
return errJSONExtFnMatch
}
arg = res.Extn.Arg
}
vv, err := ParseIPAddr(arg)
vv, err := unmarshalExtensionValue(b, "ip", ParseIPAddr)
if err != nil {
return err
}

*v = vv
return nil
}
Expand Down
24 changes: 24 additions & 0 deletions types/ipaddr_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package types_test

import (
"encoding/json"
"fmt"
"testing"

Expand Down Expand Up @@ -277,4 +278,27 @@ func TestIP(t *testing.T) {
`ip("10.0.0.42")`)
})

t.Run("MarshalJSON", func(t *testing.T) {
t.Parallel()
expected := `{
"__extn": {
"fn": "ip",
"arg": "12.34.56.78"
}
}`
i1 := testutil.Must(types.ParseIPAddr("12.34.56.78"))
testutil.JSONMarshalsTo(t, i1, expected)

var i2 types.IPAddr
err := json.Unmarshal([]byte(expected), &i2)
testutil.OK(t, err)
testutil.Equals(t, i1, i2)
})

t.Run("UnmarshalJSON/error", func(t *testing.T) {
t.Parallel()
var dt2 types.IPAddr
err := json.Unmarshal([]byte("{}"), &dt2)
testutil.Error(t, err)
})
}
Loading

0 comments on commit c7cc1e3

Please sign in to comment.