diff --git a/null.go b/null.go index ab143b8..cc603e6 100644 --- a/null.go +++ b/null.go @@ -4,6 +4,7 @@ import ( "database/sql" "database/sql/driver" "encoding/json" + "fmt" ) // Int is an int that will write as null when it is zero both to databases and json @@ -136,3 +137,76 @@ func (s String) Value() (driver.Value, error) { } return string(s), nil } + +// StringMap is a one level deep dictionary that is represented as JSON text in the database. +// Empty maps will be written as null to the database and to JSON. +type StringMap struct { + m map[string]string +} + +// NewStringMap creates a new StringMap +func NewStringMap(m map[string]string) StringMap { + return StringMap{m: m} +} + +// Map returns our underlying map +func (m *StringMap) Map() map[string]string { + if m.m == nil { + m.m = make(map[string]string) + } + return m.m +} + +// Scan implements the Scanner interface for decoding from a database +func (m *StringMap) Scan(src interface{}) error { + m.m = make(map[string]string) + if src == nil { + return nil + } + + var source []byte + switch src.(type) { + case string: + source = []byte(src.(string)) + case []byte: + source = src.([]byte) + default: + return fmt.Errorf("incompatible type for map") + } + + // 0 length string is same as nil + if len(source) == 0 { + return nil + } + + err := json.Unmarshal(source, &m.m) + if err != nil { + return err + } + return nil +} + +// Value implements the driver Valuer interface +func (m StringMap) Value() (driver.Value, error) { + if m.m == nil || len(m.m) == 0 { + return nil, nil + } + return json.Marshal(m.m) +} + +// MarshalJSON encodes our map to JSON +func (m StringMap) MarshalJSON() ([]byte, error) { + if m.m == nil || len(m.m) == 0 { + return json.Marshal(nil) + } + return json.Marshal(m.m) +} + +// UnmarshalJSON sets our map from the passed in JSON +func (m *StringMap) UnmarshalJSON(data []byte) error { + m.m = make(map[string]string) + if len(data) == 0 { + return nil + } + return json.Unmarshal(data, &m.m) +} diff --git a/null_test.go b/null_test.go index db81013..ca846db 100644 --- a/null_test.go +++ b/null_test.go @@ -297,3 +297,69 @@ func TestString(t *testing.T) { assert.True(t, tc.Value == str) } } + +func TestMap(t *testing.T) { + db, err := sql.Open("postgres", "postgres://localhost/null_test?sslmode=disable") + assert.NoError(t, err) + + _, err = db.Exec(`DROP TABLE IF EXISTS map; CREATE TABLE map(value varchar(255) null);`) + assert.NoError(t, err) + + sp := func(s string) *string { + return &s + } + + tcs := []struct { + Value StringMap + JSON string + DB *string + }{ + {NewStringMap(map[string]string{"foo": "bar"}), `{"foo":"bar"}`, sp(`{"foo": "bar"}`)}, + {NewStringMap(map[string]string{}), "null", nil}, + {NewStringMap(nil), "null", nil}, + {NewStringMap(nil), "null", sp("")}, + } + + for i, tc := range tcs { + _, err = db.Exec(`DELETE FROM map;`) + assert.NoError(t, err) + + b, err := json.Marshal(tc.Value) + assert.NoError(t, err) + assert.Equal(t, tc.JSON, string(b), "%d: %s not equal to %s", i, tc.JSON, string(b)) + + m := StringMap{} + err = json.Unmarshal(b, &m) + assert.NoError(t, err) + assert.Equal(t, tc.Value.Map(), m.Map(), "%d: %s not equal to %s", i, tc.Value, m) + + _, err = db.Exec(`INSERT INTO map(value) VALUES($1)`, tc.Value) + assert.NoError(t, err) + + rows, err := db.Query(`SELECT value FROM map;`) + assert.NoError(t, err) + + m2 := StringMap{} + assert.True(t, rows.Next()) + err = rows.Scan(&m2) + assert.NoError(t, err) + + assert.Equal(t, tc.Value.Map(), m2.Map()) + + _, err = db.Exec(`DELETE FROM map;`) + assert.NoError(t, err) + + _, err = db.Exec(`INSERT INTO map(value) VALUES($1)`, tc.DB) + assert.NoError(t, err) + + rows, err = db.Query(`SELECT value FROM map;`) + assert.NoError(t, err) + + m2 = StringMap{} + assert.True(t, rows.Next()) + err = rows.Scan(&m2) + assert.NoError(t, err) + + assert.Equal(t, tc.Value.Map(), m2.Map()) + } +}