From 2c5a00d91619af9451f5e75a78b9b6532815d3ba Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Sat, 27 Apr 2024 00:43:16 +0800 Subject: [PATCH] feat: timestamp & date (#8) - fix timestamp or date type as go query parameters. - basic facility to support SQL Null as input or output. - more tests --- .github/workflows/go.yml | 4 +- conn.go | 56 +++++++--- conn_test.go | 50 +++++++-- encode.go | 15 +++ go.mod | 2 +- go_sdk_test.go | 233 ++++++++++++++++++++++++++++----------- types.go | 64 ++++++----- 7 files changed, 299 insertions(+), 125 deletions(-) create mode 100644 encode.go diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index bbb8cca..daf06c5 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -15,7 +15,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '^1.18' + go-version: '^1.22' - name: OpenMLDB cluster run: | @@ -33,7 +33,7 @@ jobs: docker compose -f docker-compose.yml exec openmldb-ns1 /opt/openmldb/bin/openmldb --zk_cluster=openmldb-zk:2181 --zk_root_path=/openmldb --role=sql_client --cmd 'SET GLOBAL execute_mode = "online"' - name: go test - run: go test ./... -race -covermode=atomic -coverprofile=coverage.out + run: go test ./... -race -covermode=atomic -coverprofile=coverage.out -v - name: Coverage uses: codecov/codecov-action@v4 diff --git a/conn.go b/conn.go index a2e75c9..7beccb0 100644 --- a/conn.go +++ b/conn.go @@ -3,6 +3,7 @@ package openmldb import ( "bytes" "context" + "database/sql" "database/sql/driver" "encoding/json" "errors" @@ -71,7 +72,7 @@ type queryResp struct { } type respData struct { - Schema []string `json:"schema"` + Schema []string `json:"schema"` Data [][]driver.Value `json:"data"` } @@ -127,36 +128,48 @@ type queryReq struct { } type queryInput struct { - Schema []string `json:"schema"` + Schema []string `json:"schema"` Data []driver.Value `json:"data"` } -func marshalQueryRequest(mode, sql string, input ...driver.Value) ([]byte, error) { +func marshalQueryRequest(mode string, sqlStr string, input ...driver.Value) ([]byte, error) { req := queryReq{ Mode: mode, - SQL: sql, + SQL: sqlStr, } + // TODO(someone): Type infer from input slice does not work always. Consider those cases: + // 1. a int type can be a int32 or int64, depends on value size. + // 2. we're not covering more input types like uint. + // 3. For a int16 or int32 input from DB.Query(...), it always convert to int64 because driver.Value + // only expect int64 from primitive types. + // + // A better approach is to ask the schema types from api server, which in turn ask types info to SQL compiler. + if len(input) > 0 { schema := make([]string, len(input)) + // TODO(someone): support value as nil, at current time it is not possible to infer SQL type from a nil for i, v := range input { - switch v.(type) { - case bool: + switch vv := v.(type) { + case bool, Null[bool]: schema[i] = "bool" - case int16: + case int16, Null[int16]: schema[i] = "int16" - case int32: + case int32, Null[int32]: schema[i] = "int32" - case int64: + case int64, Null[int64]: schema[i] = "int64" - case float32: + case float32, Null[float32]: schema[i] = "float" - case float64: + case float64, Null[float64]: schema[i] = "double" - case string: + case string, Null[string]: schema[i] = "string" case time.Time: schema[i] = "timestamp" + input[i] = Null[time.Time]{Null: sql.Null[time.Time]{V: vv, Valid: true}} + case Null[time.Time]: + schema[i] = "timestamp" case NullDate: schema[i] = "date" default: @@ -179,8 +192,14 @@ func unmarshalQueryResponse(respBody io.Reader) (*queryResp, error) { } if r.Data != nil { + // queryResp.Data may nil for DDL for _, row := range r.Data.Data { for i, col := range row { + if col == nil { + row[i] = nil + continue + } + switch strings.ToLower(r.Data.Schema[i]) { case "bool": row[i] = col.(bool) @@ -196,14 +215,17 @@ func unmarshalQueryResponse(respBody io.Reader) (*queryResp, error) { row[i] = float64(col.(float64)) case "string": row[i] = col.(string) + // date and timestamp values saved internally as time.Time case "timestamp": // timestamp value returned as int64 millisecond unix epoch time row[i] = time.UnixMilli(int64(col.(float64))) case "date": - // date values returned as "YYYY-mm-dd" formated string - var nullDate NullDate - nullDate.Scan(col.(string)) - row[i] = nullDate + t, err := parseDateStr(col.(string)) + if err != nil { + row[i] = nil + } + + row[i] = t default: return nil, fmt.Errorf("unknown type %s at index %d", r.Data.Schema[i], i) } @@ -244,7 +266,7 @@ func (c *conn) execute(ctx context.Context, sql string, parameters ...driver.Val if r, err := unmarshalQueryResponse(resp.Body); err != nil { return nil, err } else if r.Code != 0 { - return nil, fmt.Errorf("conn error: %s", r.Msg) + return nil, fmt.Errorf("execute error: %s", r.Msg) } else if r.Data != nil { return &respDataRows{*r.Data, 0}, nil } diff --git a/conn_test.go b/conn_test.go index 560c7a5..2c1c2be 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,9 +1,11 @@ package openmldb import ( - interfaces "database/sql/driver" + "database/sql" + "database/sql/driver" "strings" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -12,28 +14,36 @@ func TestParseReqToJson(t *testing.T) { for _, tc := range []struct { mode string sql string - input []interfaces.Value + input []driver.Value expect string }{ { - "offsync", + "offline", "SELECT 1;", nil, `{ - "mode": "offsync", + "mode": "offline", "sql": "SELECT 1;" }`, }, { - "offsync", + "online", "SELECT c1, c2 FROM demo WHERE c1 = ? AND c2 = ?;", - []interfaces.Value{int32(1), "bb"}, + []driver.Value{ + int16(2), // int16 + int32(1), // int32 + "bb", // string + Null[string]{Null: sql.Null[string]{V: "foo", Valid: true}}, // string + time.UnixMilli(8000), // timestamp + Null[time.Time]{Null: sql.Null[time.Time]{V: time.UnixMilli(4000), Valid: true}}, // timestamp + Null[time.Time]{Null: sql.Null[time.Time]{V: time.UnixMilli(4000), Valid: false}}, // timestamp + NullDate{Null: sql.Null[time.Time]{V: time.Date(2022, time.October, 10, 0, 0, 0, 0, time.UTC), Valid: true}}}, // date `{ - "mode": "offsync", + "mode": "online", "sql": "SELECT c1, c2 FROM demo WHERE c1 = ? AND c2 = ?;", "input": { - "schema": ["int32", "string"], - "data": [1, "bb"] + "schema": ["int16", "int32", "string", "string", "timestamp", "timestamp", "timestamp", "date"], + "data": [2, 1, "bb", "foo", 8000, 4000, null, "2022-10-10"] } }`, }, @@ -60,6 +70,24 @@ func TestParseRespFromJson(t *testing.T) { Data: nil, }, }, + { + `{ + "code": 0, + "msg": "ok", + "data": { + "schema": ["date", "string"], + "data": [] + } + }`, + queryResp{ + Code: 0, + Msg: "ok", + Data: &respData{ + Schema: []string{"date", "string"}, + Data: [][]driver.Value{}, + }, + }, + }, { `{ "code": 0, @@ -74,7 +102,7 @@ func TestParseRespFromJson(t *testing.T) { Msg: "ok", Data: &respData{ Schema: []string{"Int32", "String"}, - Data: [][]interfaces.Value{ + Data: [][]driver.Value{ {int32(1), "bb"}, {int32(2), "bb"}, }, @@ -95,7 +123,7 @@ func TestParseRespFromJson(t *testing.T) { Msg: "ok", Data: &respData{ Schema: []string{"Bool", "Int16", "Int32", "Int64", "Float", "Double", "String"}, - Data: [][]interfaces.Value{ + Data: [][]driver.Value{ {true, int16(1), int32(1), int64(1), float32(1), float64(1), "bb"}, }, }, diff --git a/encode.go b/encode.go new file mode 100644 index 0000000..c1d7547 --- /dev/null +++ b/encode.go @@ -0,0 +1,15 @@ +package openmldb + +import ( + "time" +) + +func parseDateStr(src string) (time.Time, error) { + // api server returns date type as string formatted 'yyyy-mm-dd' + dval, err := time.Parse(time.DateOnly, src) + if err != nil { + return time.Time{}, err + } + + return dval, nil +} diff --git a/go.mod b/go.mod index e57fa16..bffc635 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/4paradigm/openmldb-go-sdk -go 1.18 +go 1.22 require github.com/stretchr/testify v1.9.0 diff --git a/go_sdk_test.go b/go_sdk_test.go index 4dff6fa..0fa8401 100644 --- a/go_sdk_test.go +++ b/go_sdk_test.go @@ -18,95 +18,196 @@ import ( var apiServer string -// 1. NullTime + NullDate -// 2. Time + Time +var db *sql.DB +var ctx context.Context -func Test_driver(t *testing.T) { - db, err := sql.Open("openmldb", fmt.Sprintf("openmldb://%s/test_db", apiServer)) - if err != nil { - t.Errorf("fail to open connect: %s", err) - } +// user may use sql.NullXXX types to represent SQL values that may be null - defer func() { - if err := db.Close(); err != nil { - t.Errorf("fail to close connection: %s", err) - } - }() +type demoStruct1 struct { + c1 int32 + c2 string + ts time.Time + dt time.Time +} +type demoStruct2 struct { + c1 sql.NullInt32 + c2 sql.NullString + ts sql.NullTime + dt sql.NullTime +} +type demoStruct3 struct { + c1 openmldb.Null[int32] + c2 openmldb.Null[string] + ts openmldb.Null[time.Time] + dt openmldb.NullDate +} - ctx := context.Background() +func TestPingCtx(t *testing.T) { assert.NoError(t, db.PingContext(ctx), "fail to ping connect") +} + +func TestQuery1(t *testing.T) { + // use time.Time to represent both timestamp and date + queryStmt := `SELECT * FROM demo` + rows, err := db.QueryContext(ctx, queryStmt) + assert.NoError(t, err, "fail to query %s", queryStmt) + + var demo demoStruct1 + { + assert.True(t, rows.Next()) + assert.NoError(t, rows.Scan(&demo.c1, &demo.c2, &demo.ts, &demo.dt)) + assert.Equal(t, demoStruct1{1, "bb", time.UnixMilli(3000), time.Date(2022, time.December, 12, 0, 0, 0, 0, time.UTC)}, demo) + } + // { + // assert.True(t, rows.Next()) + // assert.NoError(t, rows.Scan(&demo.c1, &demo.c2)) + // assert.Equal(t, struct { + // c1 int32 + // c2 string + // }{2, "bb"}, demo) + // } +} + +func TestQuery2(t *testing.T) { + // use sql.NullTime to represent both timestamp and date + queryStmt := `SELECT * FROM demo` + rows, err := db.QueryContext(ctx, queryStmt) + assert.NoError(t, err, "fail to query %s", queryStmt) + + var demo demoStruct2 + assert.True(t, rows.Next()) + assert.NoError(t, rows.Scan(&demo.c1, &demo.c2, &demo.ts, &demo.dt)) + assert.Equal(t, sql.NullInt32{Int32: 1, Valid: true}, demo.c1) + assert.Equal(t, sql.NullString{String: "bb", Valid: true}, demo.c2) + assert.Equal(t, sql.NullTime{Time: time.UnixMilli(3000), Valid: true}, demo.ts) + assert.Equal(t, sql.NullTime{Time: time.Date(2022, time.December, 12, 0, 0, 0, 0, time.UTC), Valid: true}, demo.dt) +} + +func TestQuery3(t *testing.T) { + // use openmldb.Null[T] and openmldb.NullDate to represent timestamp and date + queryStmt := `SELECT * FROM demo` + rows, err := db.QueryContext(ctx, queryStmt) + assert.NoError(t, err, "fail to query %s", queryStmt) + + var demo demoStruct3 + assert.True(t, rows.Next()) + assert.NoError(t, rows.Scan(&demo.c1, &demo.c2, &demo.ts, &demo.dt)) + assert.Equal(t, openmldb.Null[int32]{Null: sql.Null[int32]{V: 1, Valid: true}}, demo.c1) + assert.Equal(t, openmldb.Null[string]{Null: sql.Null[string]{V: "bb", Valid: true}}, demo.c2) + assert.Equal(t, openmldb.Null[time.Time]{Null: sql.Null[time.Time]{V: time.UnixMilli(3000), Valid: true}}, demo.ts) + assert.Equal(t, openmldb.NullDate{Null: sql.Null[time.Time]{V: time.Date(2022, time.December, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, demo.dt) +} + +func TestQueryWithParams(t *testing.T) { + parameterQueryStmt := `SELECT * FROM demo WHERE c2 = ? AND c1 = ? AND ts = ?;` + rows, err := db.QueryContext(ctx, parameterQueryStmt, "bb", 1, time.UnixMilli(3000)) + assert.NoError(t, err, "fail to query %s", parameterQueryStmt) + + var demo demoStruct1 + { + assert.True(t, rows.Next()) + assert.NoError(t, rows.Scan(&demo.c1, &demo.c2, &demo.ts, &demo.dt)) + assert.Equal(t, demoStruct1{1, "bb", time.UnixMilli(3000), time.Date(2022, time.December, 12, 0, 0, 0, 0, time.UTC)}, demo) + } +} + +func TestQueryWithParamsExpectsNull(t *testing.T) { + _, err := db.ExecContext(ctx, "create table test2 (id int16, val int64, dt date)") + assert.NoError(t, err) + t.Cleanup(func() { + _, err := db.ExecContext(ctx, "drop table test2") + assert.NoError(t, err) + }) + + { + _, err := db.ExecContext(ctx, "insert into test2 values (1, NULL, NULL)") + assert.NoError(t, err) + } + + rows, err := db.QueryContext(ctx, "select * from test2 where id = ?", 1) + assert.NoError(t, err) + var demo struct { + id sql.NullInt16 + val sql.NullInt64 + dt sql.NullTime + } + { + assert.True(t, rows.Next()) + assert.NoError(t, rows.Scan(&demo.id, &demo.val, &demo.dt)) + assert.Equal(t, sql.NullInt16{Int16: 1, Valid: true}, demo.id) + assert.Equal(t, sql.NullInt64{Int64: 0, Valid: false}, demo.val) + assert.Equal(t, sql.NullTime{Time: time.Time{}, Valid: false}, demo.dt) + } +} + +func TestQueryWithParamsResultsEmpty(t *testing.T) { + _, err := db.ExecContext(ctx, "create table test3 (id int16, val int64, dt date)") + assert.NoError(t, err) + t.Cleanup(func() { + _, err := db.ExecContext(ctx, "drop table test3") + assert.NoError(t, err) + }) + + { + _, err := db.ExecContext(ctx, "insert into test3 values (1, NULL, NULL)") + assert.NoError(t, err) + } + + { + rows, err := db.QueryContext(ctx, "select * from test3 where id = ?", int16(10)) + assert.NoError(t, err) + assert.False(t, rows.Next()) + } + + { + // disabled since https://github.com/4paradigm/OpenMLDB/issues/3902 + // _, err := db.QueryContext(ctx, "select * from test3 where id = ?", + // openmldb.Null[int16]{Null: sql.Null[int16]{V: 0, Valid: false}}) + // assert.NoError(t, err) + // assert.False(t, rows.Next()) + } +} + +func PrepareAndRun(m *testing.M) int { + var err error + db, err = sql.Open("openmldb", fmt.Sprintf("openmldb://%s/test_db", apiServer)) + if err != nil { + fmt.Fprintf(os.Stderr, "fail to open connect: %s", err) + os.Exit(1) + } + + ctx = context.Background() { createTableStmt := "CREATE TABLE demo(c1 int, c2 string, ts timestamp, dt date);" _, err := db.ExecContext(ctx, createTableStmt) - assert.NoError(t, err, "fail to exec %s", createTableStmt) + if err != nil { + fmt.Fprintf(os.Stderr, "fail to exec %s", createTableStmt) + os.Exit(1) + } } defer func() { dropTableStmt := "DROP TABLE demo;" _, err := db.ExecContext(ctx, dropTableStmt) if err != nil { - t.Errorf("fail to drop table: %s", err) + fmt.Fprintf(os.Stderr, "fail to drop table: %s", err) + os.Exit(1) } }() - { // FIXME: ordering issue insertValueStmt := `INSERT INTO demo VALUES (1, "bb", 3000, "2022-12-12");` // insertValueStmt := `INSERT INTO demo VALUES (1, "bb"), (2, "bb");` _, err := db.ExecContext(ctx, insertValueStmt) - assert.NoError(t, err, "fail to exec %s", insertValueStmt) - } - - t.Run("query", func(t *testing.T) { - queryStmt := `SELECT * FROM demo` - rows, err := db.QueryContext(ctx, queryStmt) - assert.NoError(t, err, "fail to query %s", queryStmt) - - var demo struct { - c1 int32 - c2 string - ts time.Time - dt openmldb.NullDate - } - { - assert.True(t, rows.Next()) - assert.NoError(t, rows.Scan(&demo.c1, &demo.c2, &demo.ts, &demo.dt)) - assert.Equal(t, struct { - c1 int32 - c2 string - ts time.Time - dt openmldb.NullDate - }{1, "bb", time.UnixMilli(3000), openmldb.NullDate{Time: time.Date(2022, time.December, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, demo) + if err != nil { + fmt.Fprintf(os.Stderr, "fail to exec: %s", insertValueStmt) + os.Exit(1) } - // { - // assert.True(t, rows.Next()) - // assert.NoError(t, rows.Scan(&demo.c1, &demo.c2)) - // assert.Equal(t, struct { - // c1 int32 - // c2 string - // }{2, "bb"}, demo) - // } - }) + } - t.Run("query with parameter", func(t *testing.T) { - parameterQueryStmt := `SELECT c1, c2 FROM demo WHERE c2 = ? AND c1 = ?;` - rows, err := db.QueryContext(ctx, parameterQueryStmt, "bb", 1) - assert.NoError(t, err, "fail to query %s", parameterQueryStmt) + return m.Run() - var demo struct { - c1 int32 - c2 string - } - { - assert.True(t, rows.Next()) - assert.NoError(t, rows.Scan(&demo.c1, &demo.c2)) - assert.Equal(t, struct { - c1 int32 - c2 string - }{1, "bb"}, demo) - } - }) } func TestMain(m *testing.M) { @@ -117,5 +218,5 @@ func TestMain(m *testing.M) { log.Fatalf("non-empty api server address required") } - os.Exit(m.Run()) + os.Exit(PrepareAndRun(m)) } diff --git a/types.go b/types.go index bb009cf..f4ad1d1 100644 --- a/types.go +++ b/types.go @@ -1,49 +1,57 @@ package openmldb +// TODO(someone): support go < 1.22 + import ( "database/sql" "database/sql/driver" - "errors" + "encoding/json" "time" ) var ( - _ sql.Scanner = (*NullDate)(nil) + _ sql.Scanner = (*NullDate)(nil) _ driver.Valuer = NullDate{} ) +// Null represents a value that may be null. +// +// declare type embedded sql.Null so we still able to +// utilize sql.Scanner and driver.Valuer in go standard, +// and customize marshal logic for api requests +type Null[T any] struct { + sql.Null[T] +} + +// NullDate represents nullable SQL date in go +// +// embedded sql.Null[time.Time] so it by default +// implements sql.Scanner and driver.Valuer, but +// distinct timestamp representation in sdk. type NullDate struct { - Time time.Time - Valid bool // Valid is true if Time is not NULL + sql.Null[time.Time] } -// Scan implements sql.Scanner for NullDate -func (dt *NullDate) Scan(src any) error { - switch val := src.(type) { - case string: - dval, err := time.Parse(time.DateOnly, val) - if err != nil { - dt.Valid = false - return err - } else { - dt.Time = dval - dt.Valid = true - return nil - } - case NullDate: - *dt = val - return nil - default: - return errors.New("scan NullDate from unsupported type") +// MarshalJSON implements json.Marshaler +func (src NullDate) MarshalJSON() ([]byte, error) { + if !src.Valid { + return json.Marshal(nil) } - + return json.Marshal(src.V.Format(time.DateOnly)) } -// Value implements driver.Valuer for NullDate -func (dt NullDate) Value() (driver.Value, error) { - if !dt.Valid { - return nil, nil +// MarshalJSON implements json.Marshaler for Null[T] +func (src Null[T]) MarshalJSON() ([]byte, error) { + if !src.Valid { + return json.Marshal(nil) } - return dt.Time, nil + var v any = src.V + switch val := v.(type) { + case time.Time: + // timestamp, marshal to int64 unix epoch time in millisecond + return json.Marshal(val.UnixMilli()) + default: + return json.Marshal(src.V) + } }