From 3a285fae2c5f0f2c83a76546317793eb47db6551 Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Thu, 25 Apr 2024 16:11:35 +0800 Subject: [PATCH] feat: date & timestamp type support Timestamp is time.Time in go; and date is NullDate, a defined struct in openmldb go sdk. --- conn.go | 28 ++++++++++++++++++++++++---- conn_test.go | 4 ++-- driver.go | 37 +++++++++++++++++++------------------ go_sdk_test.go | 18 ++++++++++++------ types.go | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 105 insertions(+), 30 deletions(-) create mode 100644 types.go diff --git a/conn.go b/conn.go index 22a97d5..8ec75a8 100644 --- a/conn.go +++ b/conn.go @@ -10,6 +10,7 @@ import ( "io" "net/http" "strings" + "time" ) // compile time validation that our types implements the expected interfaces @@ -86,6 +87,7 @@ type respDataRows struct { // slice. If a particular column name isn't known, an empty // string should be returned for that entry. func (r respDataRows) Columns() []string { + // FIXME(someone): current impl returns schema list, not name of columns return make([]string, len(r.Schema)) } @@ -129,7 +131,7 @@ type queryInput struct { Data []interfaces.Value `json:"data"` } -func parseReqToJson(mode, sql string, input ...interfaces.Value) ([]byte, error) { +func marshalQueryRequest(mode, sql string, input ...interfaces.Value) ([]byte, error) { req := queryReq{ Mode: mode, SQL: sql, @@ -153,6 +155,10 @@ func parseReqToJson(mode, sql string, input ...interfaces.Value) ([]byte, error) schema[i] = "double" case string: schema[i] = "string" + case time.Time: + schema[i] = "timestamp" + case NullDate: + schema[i] = "date" default: return nil, fmt.Errorf("unknown type at index %d", i) } @@ -166,7 +172,7 @@ func parseReqToJson(mode, sql string, input ...interfaces.Value) ([]byte, error) return json.Marshal(req) } -func parseRespFromJson(respBody io.Reader) (*queryResp, error) { +func unmarshalQueryResponse(respBody io.Reader) (*queryResp, error) { var r queryResp if err := json.NewDecoder(respBody).Decode(&r); err != nil { return nil, err @@ -190,6 +196,20 @@ func parseRespFromJson(respBody io.Reader) (*queryResp, error) { row[i] = float64(col.(float64)) case "string": row[i] = col.(string) + case "timestamp": + // timestamp value returned as int64 millisecond unix epoch time + row[i] = time.UnixMilli(int64(col.(float64))) + case "date": + // date values returned as "YYYY-mm-dd" formated string + var nullDate NullDate + dt, err := time.Parse(time.DateOnly, col.(string)) + if err != nil { + nullDate.Valid = false + } else { + nullDate.Time = dt + nullDate.Valid = true + } + row[i] = nullDate default: return nil, fmt.Errorf("unknown type %s at index %d", r.Data.Schema[i], i) } @@ -205,7 +225,7 @@ func (c *conn) execute(ctx context.Context, sql string, parameters ...interfaces return nil, interfaces.ErrBadConn } - reqBody, err := parseReqToJson(string(c.mode), sql, parameters...) + reqBody, err := marshalQueryRequest(string(c.mode), sql, parameters...) if err != nil { return nil, err } @@ -227,7 +247,7 @@ func (c *conn) execute(ctx context.Context, sql string, parameters ...interfaces return nil, err } - if r, err := parseRespFromJson(resp.Body); err != nil { + if r, err := unmarshalQueryResponse(resp.Body); err != nil { return nil, err } else if r.Code != 0 { return nil, fmt.Errorf("conn error: %s", r.Msg) diff --git a/conn_test.go b/conn_test.go index b250882..560c7a5 100644 --- a/conn_test.go +++ b/conn_test.go @@ -38,7 +38,7 @@ func TestParseReqToJson(t *testing.T) { }`, }, } { - actual, err := parseReqToJson(tc.mode, tc.sql, tc.input...) + actual, err := marshalQueryRequest(tc.mode, tc.sql, tc.input...) assert.NoError(t, err) assert.JSONEq(t, tc.expect, string(actual)) } @@ -102,7 +102,7 @@ func TestParseRespFromJson(t *testing.T) { }, }, } { - actual, err := parseRespFromJson(strings.NewReader(tc.resp)) + actual, err := unmarshalQueryResponse(strings.NewReader(tc.resp)) assert.NoError(t, err) assert.Equal(t, &tc.expect, actual) } diff --git a/driver.go b/driver.go index 396cf6b..b3eb8e1 100644 --- a/driver.go +++ b/driver.go @@ -9,18 +9,19 @@ import ( "strings" ) - func init() { - sql.Register("openmldb", &driver{}) + sql.Register("openmldb", &openmldbDriver{}) } + +// compile time validation that our types implements the expected interfaces var ( - _ interfaces.Driver = (*driver)(nil) - _ interfaces.DriverContext = (*driver)(nil) + _ interfaces.Driver = openmldbDriver{} + _ interfaces.DriverContext = openmldbDriver{} - _ interfaces.Connector = (*connecter)(nil) + _ interfaces.Connector = connecter{} ) -type driver struct{} +type openmldbDriver struct{} func parseDsn(dsn string) (host string, db string, mode queryMode, err error) { u, err := url.Parse(dsn) @@ -51,7 +52,7 @@ func parseDsn(dsn string) (host string, db string, mode queryMode, err error) { } // Open implements driver.Driver. -func (driver) Open(name string) (interfaces.Conn, error) { +func (openmldbDriver) Open(name string) (interfaces.Conn, error) { // name should be the URL of the api server, e.g. openmldb://localhost:6543/db host, db, mode, err := parseDsn(name) if err != nil { @@ -61,6 +62,16 @@ func (driver) Open(name string) (interfaces.Conn, error) { return &conn{host: host, db: db, mode: mode, closed: false}, nil } +// OpenConnector implements driver.DriverContext. +func (openmldbDriver) OpenConnector(name string) (interfaces.Connector, error) { + host, db, mode, err := parseDsn(name) + if err != nil { + return nil, err + } + + return &connecter{host, db, mode}, nil +} + type connecter struct { host string db string @@ -78,15 +89,5 @@ func (c connecter) Connect(ctx context.Context) (interfaces.Conn, error) { // Driver implements driver.Connector. func (connecter) Driver() interfaces.Driver { - return &driver{} -} - -// OpenConnector implements driver.DriverContext. -func (driver) OpenConnector(name string) (interfaces.Connector, error) { - host, db, mode, err := parseDsn(name) - if err != nil { - return nil, err - } - - return &connecter{host, db, mode}, nil + return &openmldbDriver{} } diff --git a/go_sdk_test.go b/go_sdk_test.go index cccd7db..7484a4a 100644 --- a/go_sdk_test.go +++ b/go_sdk_test.go @@ -8,10 +8,12 @@ import ( "log" "os" "testing" + "time" // register openmldb driver - _ "github.com/4paradigm/openmldb-go-sdk" "github.com/stretchr/testify/assert" + + openmldb "github.com/4paradigm/openmldb-go-sdk" ) var apiServer string @@ -32,7 +34,7 @@ func Test_driver(t *testing.T) { assert.NoError(t, db.PingContext(ctx), "fail to ping connect") { - createTableStmt := "CREATE TABLE demo(c1 int, c2 string);" + createTableStmt := "CREATE TABLE demo(c1 int, c2 string, ts timestamp, dt date);" _, err := db.ExecContext(ctx, createTableStmt) assert.NoError(t, err, "fail to exec %s", createTableStmt) } @@ -47,28 +49,32 @@ func Test_driver(t *testing.T) { { // FIXME: ordering issue - insertValueStmt := `INSERT INTO demo VALUES (1, "bb");` + insertValueStmt := `INSERT INTO demo VALUES (1, "bb", 3000, "2022-12-12");` // insertValueStmt := `INSERT INTO demo VALUES (1, "bb"), (2, "bb");` _, err := db.ExecContext(ctx, insertValueStmt) assert.NoError(t, err, "fail to exec %s", insertValueStmt) } t.Run("query", func(t *testing.T) { - queryStmt := `SELECT c1, c2 FROM demo` + queryStmt := `SELECT * FROM demo` rows, err := db.QueryContext(ctx, queryStmt) assert.NoError(t, err, "fail to query %s", queryStmt) var demo struct { c1 int32 c2 string + ts time.Time + dt openmldb.NullDate } { assert.True(t, rows.Next()) - assert.NoError(t, rows.Scan(&demo.c1, &demo.c2)) + assert.NoError(t, rows.Scan(&demo.c1, &demo.c2, &demo.ts, &demo.dt)) assert.Equal(t, struct { c1 int32 c2 string - }{1, "bb"}, demo) + ts time.Time + dt openmldb.NullDate + }{1, "bb", time.UnixMilli(3000), openmldb.NullDate{Time: time.Date(2022, time.December, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, demo) } // { // assert.True(t, rows.Next()) diff --git a/types.go b/types.go new file mode 100644 index 0000000..a574fa8 --- /dev/null +++ b/types.go @@ -0,0 +1,48 @@ +package openmldb + +import ( + "database/sql" + "database/sql/driver" + "errors" + "time" +) + +var ( + _ sql.Scanner = (*NullDate)(nil) +) + +type NullDate struct { + Time time.Time + Valid bool // Valid is true if Time is not NULL +} + +// Scan implements sql.Scanner for NullDate +func (dt *NullDate) Scan(src any) error { + switch val := src.(type) { + case string: + dval, err := time.Parse(time.DateOnly, val) + if err != nil { + dt.Valid = false + return err + } else { + dt.Time = dval + dt.Valid = true + return nil + } + case NullDate: + *dt = val + return nil + default: + return errors.New("scan NullDate from unsupported type") + } + +} + +// Value implements driver.Value for NullDate +func (dt NullDate) Value() (driver.Value, error) { + if !dt.Valid { + return nil, nil + } + return dt.Time, nil + +}