Skip to content

Commit

Permalink
feat: date & timestamp type (#7)
Browse files Browse the repository at this point in the history
1. support date & timestamp type
   you can use `time.Time` to represent SQL timestamp type, and `openmldb.NullDate` to represent date type
2. simply `mode` in DSN
    only `online` & `offline`, and later option `request`, rm `offsync` or `offasync`
3. deps: upgrade testify to v1.9.0
  • Loading branch information
aceforeverd authored Apr 26, 2024
1 parent c046517 commit 49f51dd
Show file tree
Hide file tree
Showing 9 changed files with 153 additions and 77 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ For example, to open a database to `test_db` by api server at `127.0.0.1:8080`:
db, err := sql.Open("openmldb", "openmldb://127.0.0.1:8080/test_db")
```

`<DB_NAME>` is mandatory in DSN, and at this time (version 0.2.0), you must ensure the database `<DB_NAME>` created before open go connection.

## Getting Start

```go
Expand Down
100 changes: 60 additions & 40 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,42 +3,42 @@ package openmldb
import (
"bytes"
"context"
interfaces "database/sql/driver"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
)

// compile time validation that our types implements the expected interfaces
var (
_ interfaces.Conn = (*conn)(nil)
_ driver.Conn = (*conn)(nil)

// All Conn implementations should implement the following interfaces:
// Pinger, SessionResetter, and Validator.

_ interfaces.Pinger = (*conn)(nil)
_ interfaces.SessionResetter = (*conn)(nil)
_ interfaces.Validator = (*conn)(nil)
_ driver.Pinger = (*conn)(nil)
_ driver.SessionResetter = (*conn)(nil)
_ driver.Validator = (*conn)(nil)

// If named parameters or context are supported, the driver's Conn should implement:
// ExecerContext, QueryerContext, ConnPrepareContext, and ConnBeginTx.

_ interfaces.ExecerContext = (*conn)(nil)
_ interfaces.QueryerContext = (*conn)(nil)
_ driver.ExecerContext = (*conn)(nil)
_ driver.QueryerContext = (*conn)(nil)

_ interfaces.Rows = (*respDataRows)(nil)
_ driver.Rows = (*respDataRows)(nil)
)

type queryMode string

func (m queryMode) String() string {
switch m {
case ModeOffsync:
return "offsync"
case ModeOffasync:
return "offasync"
case ModeOffline:
return "offline"
case ModeOnline:
return "online"
default:
Expand All @@ -47,15 +47,14 @@ func (m queryMode) String() string {
}

const (
ModeOffsync queryMode = "offsync"
ModeOffasync queryMode = "offasync"
ModeOnline queryMode = "online"
ModeOffline queryMode = "offline"
ModeOnline queryMode = "online"
// TODO(someone): "request"
)

var allQueryMode = map[string]queryMode{
"offsync": ModeOffsync,
"offasync": ModeOffasync,
"online": ModeOnline,
"offline": ModeOffline,
"online": ModeOnline,
}

type conn struct {
Expand All @@ -73,29 +72,36 @@ type queryResp struct {

type respData struct {
Schema []string `json:"schema"`
Data [][]interfaces.Value `json:"data"`
Data [][]driver.Value `json:"data"`
}

type respDataRows struct {
respData
i int
}

// Columns returns the names of the columns. The number of
// Columns implements driver.Rows.
//
// Returns the names of the columns. The number of
// columns of the result is inferred from the length of the
// slice. If a particular column name isn't known, an empty
// string should be returned for that entry.
func (r respDataRows) Columns() []string {
// FIXME(someone): current impl returns schema list, not name of columns
return make([]string, len(r.Schema))
}

// Close closes the rows iterator.
// Close implements driver.Rows.
//
// closes the rows iterator.
func (r *respDataRows) Close() error {
r.i = len(r.Data)
return nil
}

// Next is called to populate the next row of data into
// Next implements driver.Rows.
//
// called to populate the next row of data into
// the provided slice. The provided slice will be the same
// size as the Columns() are wide.
//
Expand All @@ -104,7 +110,7 @@ func (r *respDataRows) Close() error {
// The dest should not be written to outside of Next. Care
// should be taken when closing Rows not to modify
// a buffer held in dest.
func (r *respDataRows) Next(dest []interfaces.Value) error {
func (r *respDataRows) Next(dest []driver.Value) error {
if r.i >= len(r.Data) {
return io.EOF
}
Expand All @@ -122,10 +128,10 @@ type queryReq struct {

type queryInput struct {
Schema []string `json:"schema"`
Data []interfaces.Value `json:"data"`
Data []driver.Value `json:"data"`
}

func parseReqToJson(mode, sql string, input ...interfaces.Value) ([]byte, error) {
func marshalQueryRequest(mode, sql string, input ...driver.Value) ([]byte, error) {
req := queryReq{
Mode: mode,
SQL: sql,
Expand All @@ -149,6 +155,10 @@ func parseReqToJson(mode, sql string, input ...interfaces.Value) ([]byte, error)
schema[i] = "double"
case string:
schema[i] = "string"
case time.Time:
schema[i] = "timestamp"
case NullDate:
schema[i] = "date"
default:
return nil, fmt.Errorf("unknown type at index %d", i)
}
Expand All @@ -162,7 +172,7 @@ func parseReqToJson(mode, sql string, input ...interfaces.Value) ([]byte, error)
return json.Marshal(req)
}

func parseRespFromJson(respBody io.Reader) (*queryResp, error) {
func unmarshalQueryResponse(respBody io.Reader) (*queryResp, error) {
var r queryResp
if err := json.NewDecoder(respBody).Decode(&r); err != nil {
return nil, err
Expand All @@ -186,6 +196,14 @@ func parseRespFromJson(respBody io.Reader) (*queryResp, error) {
row[i] = float64(col.(float64))
case "string":
row[i] = col.(string)
case "timestamp":
// timestamp value returned as int64 millisecond unix epoch time
row[i] = time.UnixMilli(int64(col.(float64)))
case "date":
// date values returned as "YYYY-mm-dd" formated string
var nullDate NullDate
nullDate.Scan(col.(string))
row[i] = nullDate
default:
return nil, fmt.Errorf("unknown type %s at index %d", r.Data.Schema[i], i)
}
Expand All @@ -196,16 +214,18 @@ func parseRespFromJson(respBody io.Reader) (*queryResp, error) {
return &r, nil
}

func (c *conn) query(ctx context.Context, sql string, parameters ...interfaces.Value) (rows interfaces.Rows, err error) {
func (c *conn) execute(ctx context.Context, sql string, parameters ...driver.Value) (rows driver.Rows, err error) {
if c.closed {
return nil, interfaces.ErrBadConn
return nil, driver.ErrBadConn
}

reqBody, err := parseReqToJson(string(c.mode), sql, parameters...)
reqBody, err := marshalQueryRequest(string(c.mode), sql, parameters...)
if err != nil {
return nil, err
}

// POST endpoint/dbs/<db_name> is capable of all SQL, though it looks like
// a query API returns rows
req, err := http.NewRequestWithContext(
ctx,
"POST",
Expand All @@ -221,7 +241,7 @@ func (c *conn) query(ctx context.Context, sql string, parameters ...interfaces.V
return nil, err
}

if r, err := parseRespFromJson(resp.Body); err != nil {
if r, err := unmarshalQueryResponse(resp.Body); err != nil {
return nil, err
} else if r.Code != 0 {
return nil, fmt.Errorf("conn error: %s", r.Msg)
Expand All @@ -233,7 +253,7 @@ func (c *conn) query(ctx context.Context, sql string, parameters ...interfaces.V
}

// Prepare implements driver.Conn.
func (c *conn) Prepare(query string) (interfaces.Stmt, error) {
func (c *conn) Prepare(query string) (driver.Stmt, error) {
return nil, errors.New("Prepare is not implemented, use QueryContext instead")
}

Expand All @@ -244,13 +264,13 @@ func (c *conn) Close() error {
}

// Begin implements driver.Conn.
func (c *conn) Begin() (interfaces.Tx, error) {
func (c *conn) Begin() (driver.Tx, error) {
return nil, errors.New("begin not implemented")
}

// Ping implements driver.Pinger.
func (c *conn) Ping(ctx context.Context) error {
_, err := c.query(ctx, "SELECT 1")
_, err := c.execute(ctx, "SELECT 1")
return err
}

Expand All @@ -269,22 +289,22 @@ func (c *conn) IsValid() bool {
}

// ExecContext implements driver.ExecerContext.
func (c *conn) ExecContext(ctx context.Context, query string, args []interfaces.NamedValue) (interfaces.Result, error) {
parameters := make([]interfaces.Value, len(args))
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
parameters := make([]driver.Value, len(args))
for i, arg := range args {
parameters[i] = arg.Value
}
if _, err := c.query(ctx, query, parameters...); err != nil {
if _, err := c.execute(ctx, query, parameters...); err != nil {
return nil, err
}
return interfaces.ResultNoRows, nil
return driver.ResultNoRows, nil
}

// QueryContext implements driver.QueryerContext.
func (c *conn) QueryContext(ctx context.Context, query string, args []interfaces.NamedValue) (interfaces.Rows, error) {
parameters := make([]interfaces.Value, len(args))
func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
parameters := make([]driver.Value, len(args))
for i, arg := range args {
parameters[i] = arg.Value
}
return c.query(ctx, query, parameters...)
return c.execute(ctx, query, parameters...)
}
4 changes: 2 additions & 2 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestParseReqToJson(t *testing.T) {
}`,
},
} {
actual, err := parseReqToJson(tc.mode, tc.sql, tc.input...)
actual, err := marshalQueryRequest(tc.mode, tc.sql, tc.input...)
assert.NoError(t, err)
assert.JSONEq(t, tc.expect, string(actual))
}
Expand Down Expand Up @@ -102,7 +102,7 @@ func TestParseRespFromJson(t *testing.T) {
},
},
} {
actual, err := parseRespFromJson(strings.NewReader(tc.resp))
actual, err := unmarshalQueryResponse(strings.NewReader(tc.resp))
assert.NoError(t, err)
assert.Equal(t, &tc.expect, actual)
}
Expand Down
41 changes: 21 additions & 20 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,25 @@ package openmldb
import (
"context"
"database/sql"
interfaces "database/sql/driver"
"database/sql/driver"
"fmt"
"net/url"
"strings"
)

func init() {
sql.Register("openmldb", &driver{})
sql.Register("openmldb", &openmldbDriver{})
}

// compile time validation that our types implements the expected interfaces
var (
_ interfaces.Driver = (*driver)(nil)
_ interfaces.DriverContext = (*driver)(nil)
_ driver.Driver = openmldbDriver{}
_ driver.DriverContext = openmldbDriver{}

_ interfaces.Connector = (*connecter)(nil)
_ driver.Connector = connecter{}
)

type driver struct{}
type openmldbDriver struct{}

func parseDsn(dsn string) (host string, db string, mode queryMode, err error) {
u, err := url.Parse(dsn)
Expand Down Expand Up @@ -51,7 +52,7 @@ func parseDsn(dsn string) (host string, db string, mode queryMode, err error) {
}

// Open implements driver.Driver.
func (driver) Open(name string) (interfaces.Conn, error) {
func (openmldbDriver) Open(name string) (driver.Conn, error) {
// name should be the URL of the api server, e.g. openmldb://localhost:6543/db
host, db, mode, err := parseDsn(name)
if err != nil {
Expand All @@ -61,14 +62,24 @@ func (driver) Open(name string) (interfaces.Conn, error) {
return &conn{host: host, db: db, mode: mode, closed: false}, nil
}

// OpenConnector implements driver.DriverContext.
func (openmldbDriver) OpenConnector(name string) (driver.Connector, error) {
host, db, mode, err := parseDsn(name)
if err != nil {
return nil, err
}

return &connecter{host, db, mode}, nil
}

type connecter struct {
host string
db string
mode queryMode
}

// Connect implements driver.Connector.
func (c connecter) Connect(ctx context.Context) (interfaces.Conn, error) {
func (c connecter) Connect(ctx context.Context) (driver.Conn, error) {
conn := &conn{host: c.host, db: c.db, mode: c.mode, closed: false}
if err := conn.Ping(ctx); err != nil {
return nil, err
Expand All @@ -77,16 +88,6 @@ func (c connecter) Connect(ctx context.Context) (interfaces.Conn, error) {
}

// Driver implements driver.Connector.
func (connecter) Driver() interfaces.Driver {
return &driver{}
}

// OpenConnector implements driver.DriverContext.
func (driver) OpenConnector(name string) (interfaces.Connector, error) {
host, db, mode, err := parseDsn(name)
if err != nil {
return nil, err
}

return &connecter{host, db, mode}, nil
func (connecter) Driver() driver.Driver {
return &openmldbDriver{}
}
2 changes: 1 addition & 1 deletion driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func Test_parseDsn(t *testing.T) {
}{
{"openmldb://127.0.0.1:8080/test_db", "127.0.0.1:8080", "test_db", ModeOnline, nil},
{"openmldb://127.0.0.1:8080/test_db?mode=online", "127.0.0.1:8080", "test_db", ModeOnline, nil},
{"openmldb://127.0.0.1:8080/test_db?mode=offasync", "127.0.0.1:8080", "test_db", ModeOffasync, nil},
{"openmldb://127.0.0.1:8080/test_db?mode=offline", "127.0.0.1:8080", "test_db", ModeOffline, nil},
{"openmldb://127.0.0.1:8080/test_db?mode=unknown", "127.0.0.1:8080", "test_db", "", errors.New("")},
} {
host, db, mode, err := parseDsn(tc.dsn)
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module github.com/4paradigm/openmldb-go-sdk

go 1.18

require github.com/stretchr/testify v1.8.0
require github.com/stretchr/testify v1.9.0

require (
github.com/davecgh/go-spew v1.1.1 // indirect
Expand Down
Loading

0 comments on commit 49f51dd

Please sign in to comment.