diff --git a/README.md b/README.md index de761a6..b116e11 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,8 @@ For example, to open a database to `test_db` by api server at `127.0.0.1:8080`: db, err := sql.Open("openmldb", "openmldb://127.0.0.1:8080/test_db") ``` +`` is mandatory in DSN, and at this time (version 0.2.0), you must ensure the database `` created before open go connection. + ## Getting Start ```go diff --git a/conn.go b/conn.go index 13550a5..a2e75c9 100644 --- a/conn.go +++ b/conn.go @@ -3,42 +3,42 @@ package openmldb import ( "bytes" "context" - interfaces "database/sql/driver" + "database/sql/driver" "encoding/json" "errors" "fmt" "io" "net/http" "strings" + "time" ) +// compile time validation that our types implements the expected interfaces var ( - _ interfaces.Conn = (*conn)(nil) + _ driver.Conn = (*conn)(nil) // All Conn implementations should implement the following interfaces: // Pinger, SessionResetter, and Validator. - _ interfaces.Pinger = (*conn)(nil) - _ interfaces.SessionResetter = (*conn)(nil) - _ interfaces.Validator = (*conn)(nil) + _ driver.Pinger = (*conn)(nil) + _ driver.SessionResetter = (*conn)(nil) + _ driver.Validator = (*conn)(nil) // If named parameters or context are supported, the driver's Conn should implement: // ExecerContext, QueryerContext, ConnPrepareContext, and ConnBeginTx. - _ interfaces.ExecerContext = (*conn)(nil) - _ interfaces.QueryerContext = (*conn)(nil) + _ driver.ExecerContext = (*conn)(nil) + _ driver.QueryerContext = (*conn)(nil) - _ interfaces.Rows = (*respDataRows)(nil) + _ driver.Rows = (*respDataRows)(nil) ) type queryMode string func (m queryMode) String() string { switch m { - case ModeOffsync: - return "offsync" - case ModeOffasync: - return "offasync" + case ModeOffline: + return "offline" case ModeOnline: return "online" default: @@ -47,15 +47,14 @@ func (m queryMode) String() string { } const ( - ModeOffsync queryMode = "offsync" - ModeOffasync queryMode = "offasync" - ModeOnline queryMode = "online" + ModeOffline queryMode = "offline" + ModeOnline queryMode = "online" + // TODO(someone): "request" ) var allQueryMode = map[string]queryMode{ - "offsync": ModeOffsync, - "offasync": ModeOffasync, - "online": ModeOnline, + "offline": ModeOffline, + "online": ModeOnline, } type conn struct { @@ -73,7 +72,7 @@ type queryResp struct { type respData struct { Schema []string `json:"schema"` - Data [][]interfaces.Value `json:"data"` + Data [][]driver.Value `json:"data"` } type respDataRows struct { @@ -81,21 +80,28 @@ type respDataRows struct { i int } -// Columns returns the names of the columns. The number of +// Columns implements driver.Rows. +// +// Returns the names of the columns. The number of // columns of the result is inferred from the length of the // slice. If a particular column name isn't known, an empty // string should be returned for that entry. func (r respDataRows) Columns() []string { + // FIXME(someone): current impl returns schema list, not name of columns return make([]string, len(r.Schema)) } -// Close closes the rows iterator. +// Close implements driver.Rows. +// +// closes the rows iterator. func (r *respDataRows) Close() error { r.i = len(r.Data) return nil } -// Next is called to populate the next row of data into +// Next implements driver.Rows. +// +// called to populate the next row of data into // the provided slice. The provided slice will be the same // size as the Columns() are wide. // @@ -104,7 +110,7 @@ func (r *respDataRows) Close() error { // The dest should not be written to outside of Next. Care // should be taken when closing Rows not to modify // a buffer held in dest. -func (r *respDataRows) Next(dest []interfaces.Value) error { +func (r *respDataRows) Next(dest []driver.Value) error { if r.i >= len(r.Data) { return io.EOF } @@ -122,10 +128,10 @@ type queryReq struct { type queryInput struct { Schema []string `json:"schema"` - Data []interfaces.Value `json:"data"` + Data []driver.Value `json:"data"` } -func parseReqToJson(mode, sql string, input ...interfaces.Value) ([]byte, error) { +func marshalQueryRequest(mode, sql string, input ...driver.Value) ([]byte, error) { req := queryReq{ Mode: mode, SQL: sql, @@ -149,6 +155,10 @@ func parseReqToJson(mode, sql string, input ...interfaces.Value) ([]byte, error) schema[i] = "double" case string: schema[i] = "string" + case time.Time: + schema[i] = "timestamp" + case NullDate: + schema[i] = "date" default: return nil, fmt.Errorf("unknown type at index %d", i) } @@ -162,7 +172,7 @@ func parseReqToJson(mode, sql string, input ...interfaces.Value) ([]byte, error) return json.Marshal(req) } -func parseRespFromJson(respBody io.Reader) (*queryResp, error) { +func unmarshalQueryResponse(respBody io.Reader) (*queryResp, error) { var r queryResp if err := json.NewDecoder(respBody).Decode(&r); err != nil { return nil, err @@ -186,6 +196,14 @@ func parseRespFromJson(respBody io.Reader) (*queryResp, error) { row[i] = float64(col.(float64)) case "string": row[i] = col.(string) + case "timestamp": + // timestamp value returned as int64 millisecond unix epoch time + row[i] = time.UnixMilli(int64(col.(float64))) + case "date": + // date values returned as "YYYY-mm-dd" formated string + var nullDate NullDate + nullDate.Scan(col.(string)) + row[i] = nullDate default: return nil, fmt.Errorf("unknown type %s at index %d", r.Data.Schema[i], i) } @@ -196,16 +214,18 @@ func parseRespFromJson(respBody io.Reader) (*queryResp, error) { return &r, nil } -func (c *conn) query(ctx context.Context, sql string, parameters ...interfaces.Value) (rows interfaces.Rows, err error) { +func (c *conn) execute(ctx context.Context, sql string, parameters ...driver.Value) (rows driver.Rows, err error) { if c.closed { - return nil, interfaces.ErrBadConn + return nil, driver.ErrBadConn } - reqBody, err := parseReqToJson(string(c.mode), sql, parameters...) + reqBody, err := marshalQueryRequest(string(c.mode), sql, parameters...) if err != nil { return nil, err } + // POST endpoint/dbs/ is capable of all SQL, though it looks like + // a query API returns rows req, err := http.NewRequestWithContext( ctx, "POST", @@ -221,7 +241,7 @@ func (c *conn) query(ctx context.Context, sql string, parameters ...interfaces.V return nil, err } - if r, err := parseRespFromJson(resp.Body); err != nil { + if r, err := unmarshalQueryResponse(resp.Body); err != nil { return nil, err } else if r.Code != 0 { return nil, fmt.Errorf("conn error: %s", r.Msg) @@ -233,7 +253,7 @@ func (c *conn) query(ctx context.Context, sql string, parameters ...interfaces.V } // Prepare implements driver.Conn. -func (c *conn) Prepare(query string) (interfaces.Stmt, error) { +func (c *conn) Prepare(query string) (driver.Stmt, error) { return nil, errors.New("Prepare is not implemented, use QueryContext instead") } @@ -244,13 +264,13 @@ func (c *conn) Close() error { } // Begin implements driver.Conn. -func (c *conn) Begin() (interfaces.Tx, error) { +func (c *conn) Begin() (driver.Tx, error) { return nil, errors.New("begin not implemented") } // Ping implements driver.Pinger. func (c *conn) Ping(ctx context.Context) error { - _, err := c.query(ctx, "SELECT 1") + _, err := c.execute(ctx, "SELECT 1") return err } @@ -269,22 +289,22 @@ func (c *conn) IsValid() bool { } // ExecContext implements driver.ExecerContext. -func (c *conn) ExecContext(ctx context.Context, query string, args []interfaces.NamedValue) (interfaces.Result, error) { - parameters := make([]interfaces.Value, len(args)) +func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + parameters := make([]driver.Value, len(args)) for i, arg := range args { parameters[i] = arg.Value } - if _, err := c.query(ctx, query, parameters...); err != nil { + if _, err := c.execute(ctx, query, parameters...); err != nil { return nil, err } - return interfaces.ResultNoRows, nil + return driver.ResultNoRows, nil } // QueryContext implements driver.QueryerContext. -func (c *conn) QueryContext(ctx context.Context, query string, args []interfaces.NamedValue) (interfaces.Rows, error) { - parameters := make([]interfaces.Value, len(args)) +func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + parameters := make([]driver.Value, len(args)) for i, arg := range args { parameters[i] = arg.Value } - return c.query(ctx, query, parameters...) + return c.execute(ctx, query, parameters...) } diff --git a/conn_test.go b/conn_test.go index b250882..560c7a5 100644 --- a/conn_test.go +++ b/conn_test.go @@ -38,7 +38,7 @@ func TestParseReqToJson(t *testing.T) { }`, }, } { - actual, err := parseReqToJson(tc.mode, tc.sql, tc.input...) + actual, err := marshalQueryRequest(tc.mode, tc.sql, tc.input...) assert.NoError(t, err) assert.JSONEq(t, tc.expect, string(actual)) } @@ -102,7 +102,7 @@ func TestParseRespFromJson(t *testing.T) { }, }, } { - actual, err := parseRespFromJson(strings.NewReader(tc.resp)) + actual, err := unmarshalQueryResponse(strings.NewReader(tc.resp)) assert.NoError(t, err) assert.Equal(t, &tc.expect, actual) } diff --git a/driver.go b/driver.go index 5c76269..4f5a12e 100644 --- a/driver.go +++ b/driver.go @@ -3,24 +3,25 @@ package openmldb import ( "context" "database/sql" - interfaces "database/sql/driver" + "database/sql/driver" "fmt" "net/url" "strings" ) func init() { - sql.Register("openmldb", &driver{}) + sql.Register("openmldb", &openmldbDriver{}) } +// compile time validation that our types implements the expected interfaces var ( - _ interfaces.Driver = (*driver)(nil) - _ interfaces.DriverContext = (*driver)(nil) + _ driver.Driver = openmldbDriver{} + _ driver.DriverContext = openmldbDriver{} - _ interfaces.Connector = (*connecter)(nil) + _ driver.Connector = connecter{} ) -type driver struct{} +type openmldbDriver struct{} func parseDsn(dsn string) (host string, db string, mode queryMode, err error) { u, err := url.Parse(dsn) @@ -51,7 +52,7 @@ func parseDsn(dsn string) (host string, db string, mode queryMode, err error) { } // Open implements driver.Driver. -func (driver) Open(name string) (interfaces.Conn, error) { +func (openmldbDriver) Open(name string) (driver.Conn, error) { // name should be the URL of the api server, e.g. openmldb://localhost:6543/db host, db, mode, err := parseDsn(name) if err != nil { @@ -61,6 +62,16 @@ func (driver) Open(name string) (interfaces.Conn, error) { return &conn{host: host, db: db, mode: mode, closed: false}, nil } +// OpenConnector implements driver.DriverContext. +func (openmldbDriver) OpenConnector(name string) (driver.Connector, error) { + host, db, mode, err := parseDsn(name) + if err != nil { + return nil, err + } + + return &connecter{host, db, mode}, nil +} + type connecter struct { host string db string @@ -68,7 +79,7 @@ type connecter struct { } // Connect implements driver.Connector. -func (c connecter) Connect(ctx context.Context) (interfaces.Conn, error) { +func (c connecter) Connect(ctx context.Context) (driver.Conn, error) { conn := &conn{host: c.host, db: c.db, mode: c.mode, closed: false} if err := conn.Ping(ctx); err != nil { return nil, err @@ -77,16 +88,6 @@ func (c connecter) Connect(ctx context.Context) (interfaces.Conn, error) { } // Driver implements driver.Connector. -func (connecter) Driver() interfaces.Driver { - return &driver{} -} - -// OpenConnector implements driver.DriverContext. -func (driver) OpenConnector(name string) (interfaces.Connector, error) { - host, db, mode, err := parseDsn(name) - if err != nil { - return nil, err - } - - return &connecter{host, db, mode}, nil +func (connecter) Driver() driver.Driver { + return &openmldbDriver{} } diff --git a/driver_test.go b/driver_test.go index 8ced88c..56c426b 100644 --- a/driver_test.go +++ b/driver_test.go @@ -17,7 +17,7 @@ func Test_parseDsn(t *testing.T) { }{ {"openmldb://127.0.0.1:8080/test_db", "127.0.0.1:8080", "test_db", ModeOnline, nil}, {"openmldb://127.0.0.1:8080/test_db?mode=online", "127.0.0.1:8080", "test_db", ModeOnline, nil}, - {"openmldb://127.0.0.1:8080/test_db?mode=offasync", "127.0.0.1:8080", "test_db", ModeOffasync, nil}, + {"openmldb://127.0.0.1:8080/test_db?mode=offline", "127.0.0.1:8080", "test_db", ModeOffline, nil}, {"openmldb://127.0.0.1:8080/test_db?mode=unknown", "127.0.0.1:8080", "test_db", "", errors.New("")}, } { host, db, mode, err := parseDsn(tc.dsn) diff --git a/go.mod b/go.mod index 9ac3cbd..e57fa16 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module github.com/4paradigm/openmldb-go-sdk go 1.18 -require github.com/stretchr/testify v1.8.0 +require github.com/stretchr/testify v1.9.0 require ( github.com/davecgh/go-spew v1.1.1 // indirect diff --git a/go.sum b/go.sum index 5164829..60ce688 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,10 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/go_sdk_test.go b/go_sdk_test.go index cccd7db..4dff6fa 100644 --- a/go_sdk_test.go +++ b/go_sdk_test.go @@ -8,14 +8,19 @@ import ( "log" "os" "testing" + "time" // register openmldb driver - _ "github.com/4paradigm/openmldb-go-sdk" "github.com/stretchr/testify/assert" + + openmldb "github.com/4paradigm/openmldb-go-sdk" ) var apiServer string +// 1. NullTime + NullDate +// 2. Time + Time + func Test_driver(t *testing.T) { db, err := sql.Open("openmldb", fmt.Sprintf("openmldb://%s/test_db", apiServer)) if err != nil { @@ -32,7 +37,7 @@ func Test_driver(t *testing.T) { assert.NoError(t, db.PingContext(ctx), "fail to ping connect") { - createTableStmt := "CREATE TABLE demo(c1 int, c2 string);" + createTableStmt := "CREATE TABLE demo(c1 int, c2 string, ts timestamp, dt date);" _, err := db.ExecContext(ctx, createTableStmt) assert.NoError(t, err, "fail to exec %s", createTableStmt) } @@ -47,28 +52,32 @@ func Test_driver(t *testing.T) { { // FIXME: ordering issue - insertValueStmt := `INSERT INTO demo VALUES (1, "bb");` + insertValueStmt := `INSERT INTO demo VALUES (1, "bb", 3000, "2022-12-12");` // insertValueStmt := `INSERT INTO demo VALUES (1, "bb"), (2, "bb");` _, err := db.ExecContext(ctx, insertValueStmt) assert.NoError(t, err, "fail to exec %s", insertValueStmt) } t.Run("query", func(t *testing.T) { - queryStmt := `SELECT c1, c2 FROM demo` + queryStmt := `SELECT * FROM demo` rows, err := db.QueryContext(ctx, queryStmt) assert.NoError(t, err, "fail to query %s", queryStmt) var demo struct { c1 int32 c2 string + ts time.Time + dt openmldb.NullDate } { assert.True(t, rows.Next()) - assert.NoError(t, rows.Scan(&demo.c1, &demo.c2)) + assert.NoError(t, rows.Scan(&demo.c1, &demo.c2, &demo.ts, &demo.dt)) assert.Equal(t, struct { c1 int32 c2 string - }{1, "bb"}, demo) + ts time.Time + dt openmldb.NullDate + }{1, "bb", time.UnixMilli(3000), openmldb.NullDate{Time: time.Date(2022, time.December, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, demo) } // { // assert.True(t, rows.Next()) diff --git a/types.go b/types.go new file mode 100644 index 0000000..bb009cf --- /dev/null +++ b/types.go @@ -0,0 +1,49 @@ +package openmldb + +import ( + "database/sql" + "database/sql/driver" + "errors" + "time" +) + +var ( + _ sql.Scanner = (*NullDate)(nil) + _ driver.Valuer = NullDate{} +) + +type NullDate struct { + Time time.Time + Valid bool // Valid is true if Time is not NULL +} + +// Scan implements sql.Scanner for NullDate +func (dt *NullDate) Scan(src any) error { + switch val := src.(type) { + case string: + dval, err := time.Parse(time.DateOnly, val) + if err != nil { + dt.Valid = false + return err + } else { + dt.Time = dval + dt.Valid = true + return nil + } + case NullDate: + *dt = val + return nil + default: + return errors.New("scan NullDate from unsupported type") + } + +} + +// Value implements driver.Valuer for NullDate +func (dt NullDate) Value() (driver.Value, error) { + if !dt.Valid { + return nil, nil + } + return dt.Time, nil + +}