Skip to content

Commit

Permalink
feat: date & timestamp type support
Browse files Browse the repository at this point in the history
Timestamp is time.Time in go; and date is NullDate, a defined struct in
openmldb go sdk.
  • Loading branch information
aceforeverd committed Apr 25, 2024
1 parent 01e2025 commit 3a285fa
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 30 deletions.
28 changes: 24 additions & 4 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"io"
"net/http"
"strings"
"time"
)

// compile time validation that our types implements the expected interfaces
Expand Down Expand Up @@ -86,6 +87,7 @@ type respDataRows struct {
// slice. If a particular column name isn't known, an empty
// string should be returned for that entry.
func (r respDataRows) Columns() []string {
// FIXME(someone): current impl returns schema list, not name of columns
return make([]string, len(r.Schema))
}

Expand Down Expand Up @@ -129,7 +131,7 @@ type queryInput struct {
Data []interfaces.Value `json:"data"`
}

func parseReqToJson(mode, sql string, input ...interfaces.Value) ([]byte, error) {
func marshalQueryRequest(mode, sql string, input ...interfaces.Value) ([]byte, error) {
req := queryReq{
Mode: mode,
SQL: sql,
Expand All @@ -153,6 +155,10 @@ func parseReqToJson(mode, sql string, input ...interfaces.Value) ([]byte, error)
schema[i] = "double"
case string:
schema[i] = "string"
case time.Time:
schema[i] = "timestamp"
case NullDate:
schema[i] = "date"
default:
return nil, fmt.Errorf("unknown type at index %d", i)
}
Expand All @@ -166,7 +172,7 @@ func parseReqToJson(mode, sql string, input ...interfaces.Value) ([]byte, error)
return json.Marshal(req)
}

func parseRespFromJson(respBody io.Reader) (*queryResp, error) {
func unmarshalQueryResponse(respBody io.Reader) (*queryResp, error) {
var r queryResp
if err := json.NewDecoder(respBody).Decode(&r); err != nil {
return nil, err
Expand All @@ -190,6 +196,20 @@ func parseRespFromJson(respBody io.Reader) (*queryResp, error) {
row[i] = float64(col.(float64))
case "string":
row[i] = col.(string)
case "timestamp":
// timestamp value returned as int64 millisecond unix epoch time
row[i] = time.UnixMilli(int64(col.(float64)))
case "date":
// date values returned as "YYYY-mm-dd" formated string
var nullDate NullDate
dt, err := time.Parse(time.DateOnly, col.(string))
if err != nil {
nullDate.Valid = false
} else {
nullDate.Time = dt
nullDate.Valid = true
}
row[i] = nullDate
default:
return nil, fmt.Errorf("unknown type %s at index %d", r.Data.Schema[i], i)
}
Expand All @@ -205,7 +225,7 @@ func (c *conn) execute(ctx context.Context, sql string, parameters ...interfaces
return nil, interfaces.ErrBadConn
}

reqBody, err := parseReqToJson(string(c.mode), sql, parameters...)
reqBody, err := marshalQueryRequest(string(c.mode), sql, parameters...)
if err != nil {
return nil, err
}
Expand All @@ -227,7 +247,7 @@ func (c *conn) execute(ctx context.Context, sql string, parameters ...interfaces
return nil, err
}

if r, err := parseRespFromJson(resp.Body); err != nil {
if r, err := unmarshalQueryResponse(resp.Body); err != nil {
return nil, err
} else if r.Code != 0 {
return nil, fmt.Errorf("conn error: %s", r.Msg)
Expand Down
4 changes: 2 additions & 2 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestParseReqToJson(t *testing.T) {
}`,
},
} {
actual, err := parseReqToJson(tc.mode, tc.sql, tc.input...)
actual, err := marshalQueryRequest(tc.mode, tc.sql, tc.input...)
assert.NoError(t, err)
assert.JSONEq(t, tc.expect, string(actual))
}
Expand Down Expand Up @@ -102,7 +102,7 @@ func TestParseRespFromJson(t *testing.T) {
},
},
} {
actual, err := parseRespFromJson(strings.NewReader(tc.resp))
actual, err := unmarshalQueryResponse(strings.NewReader(tc.resp))
assert.NoError(t, err)
assert.Equal(t, &tc.expect, actual)
}
Expand Down
37 changes: 19 additions & 18 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,19 @@ import (
"strings"
)


func init() {
sql.Register("openmldb", &driver{})
sql.Register("openmldb", &openmldbDriver{})
}

// compile time validation that our types implements the expected interfaces
var (
_ interfaces.Driver = (*driver)(nil)
_ interfaces.DriverContext = (*driver)(nil)
_ interfaces.Driver = openmldbDriver{}
_ interfaces.DriverContext = openmldbDriver{}

_ interfaces.Connector = (*connecter)(nil)
_ interfaces.Connector = connecter{}
)

type driver struct{}
type openmldbDriver struct{}

func parseDsn(dsn string) (host string, db string, mode queryMode, err error) {
u, err := url.Parse(dsn)
Expand Down Expand Up @@ -51,7 +52,7 @@ func parseDsn(dsn string) (host string, db string, mode queryMode, err error) {
}

// Open implements driver.Driver.
func (driver) Open(name string) (interfaces.Conn, error) {
func (openmldbDriver) Open(name string) (interfaces.Conn, error) {
// name should be the URL of the api server, e.g. openmldb://localhost:6543/db
host, db, mode, err := parseDsn(name)
if err != nil {
Expand All @@ -61,6 +62,16 @@ func (driver) Open(name string) (interfaces.Conn, error) {
return &conn{host: host, db: db, mode: mode, closed: false}, nil
}

// OpenConnector implements driver.DriverContext.
func (openmldbDriver) OpenConnector(name string) (interfaces.Connector, error) {
host, db, mode, err := parseDsn(name)
if err != nil {
return nil, err
}

return &connecter{host, db, mode}, nil
}

type connecter struct {
host string
db string
Expand All @@ -78,15 +89,5 @@ func (c connecter) Connect(ctx context.Context) (interfaces.Conn, error) {

// Driver implements driver.Connector.
func (connecter) Driver() interfaces.Driver {
return &driver{}
}

// OpenConnector implements driver.DriverContext.
func (driver) OpenConnector(name string) (interfaces.Connector, error) {
host, db, mode, err := parseDsn(name)
if err != nil {
return nil, err
}

return &connecter{host, db, mode}, nil
return &openmldbDriver{}
}
18 changes: 12 additions & 6 deletions go_sdk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ import (
"log"
"os"
"testing"
"time"

// register openmldb driver
_ "github.com/4paradigm/openmldb-go-sdk"
"github.com/stretchr/testify/assert"

openmldb "github.com/4paradigm/openmldb-go-sdk"
)

var apiServer string
Expand All @@ -32,7 +34,7 @@ func Test_driver(t *testing.T) {
assert.NoError(t, db.PingContext(ctx), "fail to ping connect")

{
createTableStmt := "CREATE TABLE demo(c1 int, c2 string);"
createTableStmt := "CREATE TABLE demo(c1 int, c2 string, ts timestamp, dt date);"
_, err := db.ExecContext(ctx, createTableStmt)
assert.NoError(t, err, "fail to exec %s", createTableStmt)
}
Expand All @@ -47,28 +49,32 @@ func Test_driver(t *testing.T) {

{
// FIXME: ordering issue
insertValueStmt := `INSERT INTO demo VALUES (1, "bb");`
insertValueStmt := `INSERT INTO demo VALUES (1, "bb", 3000, "2022-12-12");`
// insertValueStmt := `INSERT INTO demo VALUES (1, "bb"), (2, "bb");`
_, err := db.ExecContext(ctx, insertValueStmt)
assert.NoError(t, err, "fail to exec %s", insertValueStmt)
}

t.Run("query", func(t *testing.T) {
queryStmt := `SELECT c1, c2 FROM demo`
queryStmt := `SELECT * FROM demo`
rows, err := db.QueryContext(ctx, queryStmt)
assert.NoError(t, err, "fail to query %s", queryStmt)

var demo struct {
c1 int32
c2 string
ts time.Time
dt openmldb.NullDate
}
{
assert.True(t, rows.Next())
assert.NoError(t, rows.Scan(&demo.c1, &demo.c2))
assert.NoError(t, rows.Scan(&demo.c1, &demo.c2, &demo.ts, &demo.dt))
assert.Equal(t, struct {
c1 int32
c2 string
}{1, "bb"}, demo)
ts time.Time
dt openmldb.NullDate
}{1, "bb", time.UnixMilli(3000), openmldb.NullDate{Time: time.Date(2022, time.December, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, demo)
}
// {
// assert.True(t, rows.Next())
Expand Down
48 changes: 48 additions & 0 deletions types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package openmldb

import (
"database/sql"
"database/sql/driver"
"errors"
"time"
)

var (
_ sql.Scanner = (*NullDate)(nil)
)

type NullDate struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
}

// Scan implements sql.Scanner for NullDate
func (dt *NullDate) Scan(src any) error {
switch val := src.(type) {
case string:
dval, err := time.Parse(time.DateOnly, val)
if err != nil {
dt.Valid = false
return err
} else {
dt.Time = dval
dt.Valid = true
return nil
}
case NullDate:
*dt = val
return nil
default:
return errors.New("scan NullDate from unsupported type")
}

}

// Value implements driver.Value for NullDate
func (dt NullDate) Value() (driver.Value, error) {
if !dt.Valid {
return nil, nil
}
return dt.Time, nil

}

0 comments on commit 3a285fa

Please sign in to comment.