diff --git a/client/gosqldriver/connection.go b/client/gosqldriver/connection.go index e2186b22..848dfdfa 100644 --- a/client/gosqldriver/connection.go +++ b/client/gosqldriver/connection.go @@ -32,6 +32,27 @@ import ( var corrIDUnsetCmd = netstring.NewNetstringFrom(common.CmdClientCalCorrelationID, []byte("CorrId=NotSet")) +type heraConnectionInterface interface { + Prepare(query string) (driver.Stmt, error) + Close() error + Begin() (driver.Tx, error) + exec(cmd int, payload []byte) error + execNs(ns *netstring.Netstring) error + getResponse() (*netstring.Netstring, error) + SetShardID(shard int) error + ResetShardID() error + GetNumShards() (int, error) + SetShardKeyPayload(payload string) + ResetShardKeyPayload() + SetCalCorrID(corrID string) + SetClientInfo(poolName string, host string) error + SetClientInfoWithPoolStack(poolName string, host string, poolStack string) error + getID() string + getCorrID() *netstring.Netstring + getShardKeyPayload() []byte + setCorrID(*netstring.Netstring) +} + type heraConnection struct { id string // used for logging conn net.Conn @@ -39,7 +60,7 @@ type heraConnection struct { // for the sharding extension shardKeyPayload []byte // correlation id - corrID *netstring.Netstring + corrID *netstring.Netstring clientinfo *netstring.Netstring } @@ -52,7 +73,6 @@ func NewHeraConnection(conn net.Conn) driver.Conn { return hera } - // Prepare returns a prepared statement, bound to this connection. func (c *heraConnection) Prepare(query string) (driver.Stmt, error) { if logger.GetLogger().V(logger.Debug) { @@ -177,66 +197,82 @@ func (c *heraConnection) SetCalCorrID(corrID string) { } // SetClientInfo actually sends it over to Hera server -func (c *heraConnection) SetClientInfo(poolName string, host string)(error){ +func (c *heraConnection) SetClientInfo(poolName string, host string) error { if len(poolName) <= 0 && len(host) <= 0 { return nil } pid := os.Getpid() data := fmt.Sprintf("PID: %d, HOST: %s, Poolname: %s, Command: SetClientInfo,", pid, host, poolName) - c.clientinfo = netstring.NewNetstringFrom(common.CmdClientInfo, []byte(string(data))) - if logger.GetLogger().V(logger.Verbose) { - logger.GetLogger().Log(logger.Verbose, "SetClientInfo", c.clientinfo.Serialized) - } - - _, err := c.conn.Write(c.clientinfo.Serialized) - if err != nil { - if logger.GetLogger().V(logger.Warning) { - logger.GetLogger().Log(logger.Warning, "Failed to send client info") - } - return errors.New("Failed custom auth, failed to send client info") - } - ns, err := c.reader.ReadNext() - if err != nil { - if logger.GetLogger().V(logger.Warning) { - logger.GetLogger().Log(logger.Warning, "Failed to read server info") - } - return errors.New("Failed to read server info") - } - if logger.GetLogger().V(logger.Debug) { - logger.GetLogger().Log(logger.Debug, "Server info:", string(ns.Payload)) - } + c.clientinfo = netstring.NewNetstringFrom(common.CmdClientInfo, []byte(string(data))) + if logger.GetLogger().V(logger.Verbose) { + logger.GetLogger().Log(logger.Verbose, "SetClientInfo", c.clientinfo.Serialized) + } + + _, err := c.conn.Write(c.clientinfo.Serialized) + if err != nil { + if logger.GetLogger().V(logger.Warning) { + logger.GetLogger().Log(logger.Warning, "Failed to send client info") + } + return errors.New("Failed custom auth, failed to send client info") + } + ns, err := c.reader.ReadNext() + if err != nil { + if logger.GetLogger().V(logger.Warning) { + logger.GetLogger().Log(logger.Warning, "Failed to read server info") + } + return errors.New("Failed to read server info") + } + if logger.GetLogger().V(logger.Debug) { + logger.GetLogger().Log(logger.Debug, "Server info:", string(ns.Payload)) + } return nil } -func (c *heraConnection) SetClientInfoWithPoolStack(poolName string, host string, poolStack string)(error){ +func (c *heraConnection) SetClientInfoWithPoolStack(poolName string, host string, poolStack string) error { if len(poolName) <= 0 && len(host) <= 0 && len(poolStack) <= 0 { return nil } pid := os.Getpid() data := fmt.Sprintf("PID: %d, HOST: %s, Poolname: %s, PoolStack: %s, Command: SetClientInfo,", pid, host, poolName, poolStack) - c.clientinfo = netstring.NewNetstringFrom(common.CmdClientInfo, []byte(string(data))) - if logger.GetLogger().V(logger.Verbose) { - logger.GetLogger().Log(logger.Verbose, "SetClientInfo", c.clientinfo.Serialized) - } - - _, err := c.conn.Write(c.clientinfo.Serialized) - if err != nil { - if logger.GetLogger().V(logger.Warning) { - logger.GetLogger().Log(logger.Warning, "Failed to send client info") - } - return errors.New("Failed custom auth, failed to send client info") - } - ns, err := c.reader.ReadNext() - if err != nil { - if logger.GetLogger().V(logger.Warning) { - logger.GetLogger().Log(logger.Warning, "Failed to read server info") - } - return errors.New("Failed to read server info") - } - if logger.GetLogger().V(logger.Debug) { - logger.GetLogger().Log(logger.Debug, "Server info:", string(ns.Payload)) - } + c.clientinfo = netstring.NewNetstringFrom(common.CmdClientInfo, []byte(string(data))) + if logger.GetLogger().V(logger.Verbose) { + logger.GetLogger().Log(logger.Verbose, "SetClientInfo", c.clientinfo.Serialized) + } + + _, err := c.conn.Write(c.clientinfo.Serialized) + if err != nil { + if logger.GetLogger().V(logger.Warning) { + logger.GetLogger().Log(logger.Warning, "Failed to send client info") + } + return errors.New("Failed custom auth, failed to send client info") + } + ns, err := c.reader.ReadNext() + if err != nil { + if logger.GetLogger().V(logger.Warning) { + logger.GetLogger().Log(logger.Warning, "Failed to read server info") + } + return errors.New("Failed to read server info") + } + if logger.GetLogger().V(logger.Debug) { + logger.GetLogger().Log(logger.Debug, "Server info:", string(ns.Payload)) + } return nil -} \ No newline at end of file +} + +func (c *heraConnection) getID() string { + return c.id +} + +func (c *heraConnection) getCorrID() *netstring.Netstring { + return c.corrID +} + +func (c *heraConnection) getShardKeyPayload() []byte { + return c.shardKeyPayload +} + +func (c *heraConnection) setCorrID(newCorrID *netstring.Netstring) { + c.corrID = newCorrID +} diff --git a/client/gosqldriver/rows.go b/client/gosqldriver/rows.go index 06af1161..b41778fe 100644 --- a/client/gosqldriver/rows.go +++ b/client/gosqldriver/rows.go @@ -32,7 +32,7 @@ import ( // similar to JDBC's result set // Rows is an iterator over an executed query's results. type rows struct { - hera *heraConnection + hera heraConnectionInterface vals []driver.Value cols int currentRow int @@ -41,7 +41,7 @@ type rows struct { } // TODO: fetch chunk size -func newRows(hera *heraConnection, cols int, fetchChunkSize []byte) (*rows, error) { +func newRows(hera heraConnectionInterface, cols int, fetchChunkSize []byte) (*rows, error) { rs := &rows{hera: hera, cols: cols, currentRow: 0, fetchChunkSize: fetchChunkSize} err := rs.fetchResults() if err != nil { @@ -63,7 +63,7 @@ func (r *rows) fetchResults() error { return nil case common.RcNoMoreData: if logger.GetLogger().V(logger.Verbose) { - logger.GetLogger().Log(logger.Verbose, r.hera.id, "Rows: cols = ", r.cols, ", numValues =", len(r.vals)) + logger.GetLogger().Log(logger.Verbose, r.hera.getID(), "Rows: cols = ", r.cols, ", numValues =", len(r.vals)) } r.completed = true return nil diff --git a/client/gosqldriver/rows_test.go b/client/gosqldriver/rows_test.go new file mode 100644 index 00000000..79c6a7b4 --- /dev/null +++ b/client/gosqldriver/rows_test.go @@ -0,0 +1,169 @@ +package gosqldriver + +import ( + "database/sql/driver" + "errors" + "io" + "testing" + + "github.com/paypal/hera/common" + "github.com/paypal/hera/utility/encoding/netstring" +) + +type mockHeraConnection struct { + heraConnection + responses []netstring.Netstring + execErr error +} + +func (m *mockHeraConnection) Prepare(query string) (driver.Stmt, error) { + return nil, nil +} + +func (m *mockHeraConnection) Close() error { + return nil +} + +func (m *mockHeraConnection) Begin() (driver.Tx, error) { + return nil, nil +} + +func (m *mockHeraConnection) exec(cmd int, payload []byte) error { + return m.execErr +} + +func (m *mockHeraConnection) execNs(ns *netstring.Netstring) error { + return m.execErr +} + +func (m *mockHeraConnection) getResponse() (*netstring.Netstring, error) { + if len(m.responses) == 0 { + return &netstring.Netstring{}, io.EOF + } + response := m.responses[0] + m.responses = m.responses[1:] + return &response, nil +} + +func (m *mockHeraConnection) SetShardID(shard int) error { + return nil +} + +func (m *mockHeraConnection) ResetShardID() error { + return nil +} + +func (m *mockHeraConnection) GetNumShards() (int, error) { + return 0, nil +} + +func (m *mockHeraConnection) SetShardKeyPayload(payload string) { +} + +func (m *mockHeraConnection) ResetShardKeyPayload() { +} + +func (m *mockHeraConnection) SetCalCorrID(corrID string) { +} + +func (m *mockHeraConnection) SetClientInfo(poolName string, host string) error { + return nil +} + +func (m *mockHeraConnection) SetClientInfoWithPoolStack(poolName string, host string, poolStack string) error { + return nil +} + +func (m *mockHeraConnection) getID() string { + return "mockID" +} + +func TestNewRows(t *testing.T) { + mockHera := &mockHeraConnection{ + responses: []netstring.Netstring{ + {Cmd: common.RcValue, Payload: []byte("value1")}, + {Cmd: common.RcValue, Payload: []byte("value2")}, + {Cmd: common.RcOK}, + }, + } + + cols := 2 + fetchChunkSize := []byte("10") + rows, err := newRows(mockHera, cols, fetchChunkSize) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if rows == nil { + t.Fatalf("expected rows to be non-nil") + } + if rows.cols != cols { + t.Errorf("expected cols to be %d, got %d", cols, rows.cols) + } + if rows.currentRow != 0 { + t.Errorf("expected currentRow to be 0, got %d", rows.currentRow) + } + if string(rows.fetchChunkSize) != string(fetchChunkSize) { + t.Errorf("expected fetchChunkSize to be %s, got %s", fetchChunkSize, rows.fetchChunkSize) + } +} + +func TestColumns(t *testing.T) { + mockHera := &mockHeraConnection{} + rows := &rows{hera: mockHera, cols: 3} + + columns := rows.Columns() + if len(columns) != 3 { + t.Fatalf("expected 3 columns, got %d", len(columns)) + } + for _, col := range columns { + if col != "" { + t.Errorf("expected column name to be empty, got %s", col) + } + } +} + +// TODO: Change unit test for Close() once it has been implemented +func TestClose(t *testing.T) { + mockHera := &mockHeraConnection{} + rows := &rows{hera: mockHera} + + err := rows.Close() + if err == nil { + t.Fatalf("expected an error, got nil") + } + expectedErr := "Rows.Close() not yet implemented" + if err.Error() != expectedErr { + t.Errorf("expected error to be %s, got %s", expectedErr, err.Error()) + } +} + +func TestNext(t *testing.T) { + mockHera := &mockHeraConnection{ + responses: []netstring.Netstring{ + {Cmd: common.RcValue, Payload: []byte("value1")}, + {Cmd: common.RcValue, Payload: []byte("value2")}, + {Cmd: common.RcNoMoreData}, + }, + } + rows, err := newRows(mockHera, 2, []byte("10")) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + dest := make([]driver.Value, 2) + err = rows.Next(dest) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + expectedDest := []driver.Value{[]byte("value1"), []byte("value2")} + for i := range dest { + if string(dest[i].([]byte)) != string(expectedDest[i].([]byte)) { + t.Errorf("expected dest[%d] to be %s, got %s", i, expectedDest[i], dest[i]) + } + } + + err = rows.Next(dest) + if !errors.Is(err, io.EOF) { + t.Fatalf("expected io.EOF error, got %v", err) + } +} diff --git a/client/gosqldriver/statement.go b/client/gosqldriver/statement.go index 90217b60..e63aec15 100644 --- a/client/gosqldriver/statement.go +++ b/client/gosqldriver/statement.go @@ -33,12 +33,12 @@ import ( // implements sql/driver Stmt interface and the newer StmtQueryContext and StmtExecContext interfaces type stmt struct { - hera *heraConnection + hera heraConnectionInterface sql string fetchChunkSize []byte } -func newStmt(hera *heraConnection, sql string) *stmt { +func newStmt(hera heraConnectionInterface, sql string) *stmt { st := &stmt{hera: hera, fetchChunkSize: []byte("0")} // replace '?' with named parameters p1, p2, ... var bf bytes.Buffer @@ -57,7 +57,7 @@ func newStmt(hera *heraConnection, sql string) *stmt { } st.sql = bf.String() if logger.GetLogger().V(logger.Debug) { - logger.GetLogger().Log(logger.Debug, hera.id, "final SQL:", st.sql) + logger.GetLogger().Log(logger.Debug, hera.getID(), "final SQL:", st.sql) } return st } @@ -83,19 +83,19 @@ func (st *stmt) NumInput() int { // Exec executes a query that doesn't return rows, such as an INSERT or UPDATE. func (st *stmt) Exec(args []driver.Value) (driver.Result, error) { sk := 0 - if len(st.hera.shardKeyPayload) > 0 { + if len(st.hera.getShardKeyPayload()) > 0 { sk = 1 } crid := 0 - if st.hera.corrID != nil { + if st.hera.getCorrID() != nil { crid = 1 } binds := len(args) nss := make([]*netstring.Netstring, crid /*CmdClientCorrelationID*/ +1 /*CmdPrepare*/ +2*binds /* CmdBindName and CmdBindValue */ +sk /*CmdShardKey*/ +1 /*CmdExecute*/) idx := 0 if crid == 1 { - nss[0] = st.hera.corrID - st.hera.corrID = nil + nss[0] = st.hera.getCorrID() + st.hera.setCorrID(nil) idx++ } nss[idx] = netstring.NewNetstringFrom(common.CmdPrepareV2, []byte(st.sql)) @@ -116,12 +116,12 @@ func (st *stmt) Exec(args []driver.Value) (driver.Result, error) { nss[idx] = netstring.NewNetstringFrom(common.CmdBindValue, []byte(val)) } if logger.GetLogger().V(logger.Verbose) { - logger.GetLogger().Log(logger.Verbose, st.hera.id, "Bind name =", string(nss[idx-1].Payload), ", value=", string(nss[idx].Payload)) + logger.GetLogger().Log(logger.Verbose, st.hera.getID(), "Bind name =", string(nss[idx-1].Payload), ", value=", string(nss[idx].Payload)) } idx++ } if sk == 1 { - nss[idx] = netstring.NewNetstringFrom(common.CmdShardKey, st.hera.shardKeyPayload) + nss[idx] = netstring.NewNetstringFrom(common.CmdShardKey, st.hera.getShardKeyPayload()) idx++ } nss[idx] = netstring.NewNetstringFrom(common.CmdExecute, nil) @@ -158,7 +158,7 @@ func (st *stmt) Exec(args []driver.Value) (driver.Result, error) { return nil, err } if logger.GetLogger().V(logger.Debug) { - logger.GetLogger().Log(logger.Debug, st.hera.id, "DML successfull, rows affected:", res.nRows) + logger.GetLogger().Log(logger.Debug, st.hera.getID(), "DML successfull, rows affected:", res.nRows) } return res, nil } @@ -168,19 +168,19 @@ func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driv //TODO: refactor ExecContext / Exec to reuse code //TODO: honor the context timeout and return when it is canceled sk := 0 - if len(st.hera.shardKeyPayload) > 0 { + if len(st.hera.getShardKeyPayload()) > 0 { sk = 1 } crid := 0 - if st.hera.corrID != nil { + if st.hera.getCorrID() != nil { crid = 1 } binds := len(args) nss := make([]*netstring.Netstring, crid /*CmdClientCalCorrelationID*/ +1 /*CmdPrepare*/ +2*binds /* CmdBindName and BindValue */ +sk /*CmdShardKey*/ +1 /*CmdExecute*/) idx := 0 if crid == 1 { - nss[0] = st.hera.corrID - st.hera.corrID = nil + nss[0] = st.hera.getCorrID() + st.hera.setCorrID(nil) idx++ } nss[idx] = netstring.NewNetstringFrom(common.CmdPrepareV2, []byte(st.sql)) @@ -205,12 +205,12 @@ func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driv nss[idx] = netstring.NewNetstringFrom(common.CmdBindValue, []byte(val)) } if logger.GetLogger().V(logger.Verbose) { - logger.GetLogger().Log(logger.Verbose, st.hera.id, "Bind name =", string(nss[idx-1].Payload), ", value=", string(nss[idx].Payload)) + logger.GetLogger().Log(logger.Verbose, st.hera.getID(), "Bind name =", string(nss[idx-1].Payload), ", value=", string(nss[idx].Payload)) } idx++ } if sk == 1 { - nss[idx] = netstring.NewNetstringFrom(common.CmdShardKey, st.hera.shardKeyPayload) + nss[idx] = netstring.NewNetstringFrom(common.CmdShardKey, st.hera.getShardKeyPayload()) idx++ } nss[idx] = netstring.NewNetstringFrom(common.CmdExecute, nil) @@ -247,7 +247,7 @@ func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driv return nil, err } if logger.GetLogger().V(logger.Debug) { - logger.GetLogger().Log(logger.Debug, st.hera.id, "DML successfull, rows affected:", res.nRows) + logger.GetLogger().Log(logger.Debug, st.hera.getID(), "DML successfull, rows affected:", res.nRows) } return res, nil } @@ -256,19 +256,19 @@ func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driv // Query executes a query that may return rows, such as a SELECT. func (st *stmt) Query(args []driver.Value) (driver.Rows, error) { sk := 0 - if len(st.hera.shardKeyPayload) > 0 { + if len(st.hera.getShardKeyPayload()) > 0 { sk = 1 } crid := 0 - if st.hera.corrID != nil { + if st.hera.getCorrID() != nil { crid = 1 } binds := len(args) nss := make([]*netstring.Netstring, crid /*CmdClientCorrelationID*/ +1 /*CmdPrepare*/ +2*binds /* CmdBindName and BindValue */ +sk /*CmdShardKey*/ +1 /*CmdExecute*/ +1 /* CmdFetch */) idx := 0 if crid == 1 { - nss[0] = st.hera.corrID - st.hera.corrID = nil + nss[0] = st.hera.getCorrID() + st.hera.setCorrID(nil) idx++ } nss[idx] = netstring.NewNetstringFrom(common.CmdPrepareV2, []byte(st.sql)) @@ -289,12 +289,12 @@ func (st *stmt) Query(args []driver.Value) (driver.Rows, error) { nss[idx] = netstring.NewNetstringFrom(common.CmdBindValue, []byte(val)) } if logger.GetLogger().V(logger.Verbose) { - logger.GetLogger().Log(logger.Verbose, st.hera.id, "Bind name =", string(nss[idx-1].Payload), ", value=", string(nss[idx].Payload)) + logger.GetLogger().Log(logger.Verbose, st.hera.getID(), "Bind name =", string(nss[idx-1].Payload), ", value=", string(nss[idx].Payload)) } idx++ } if sk == 1 { - nss[idx] = netstring.NewNetstringFrom(common.CmdShardKey, st.hera.shardKeyPayload) + nss[idx] = netstring.NewNetstringFrom(common.CmdShardKey, st.hera.getShardKeyPayload()) idx++ } nss[idx] = netstring.NewNetstringFrom(common.CmdExecute, nil) @@ -317,7 +317,7 @@ Loop: switch ns.Cmd { case common.RcStillExecuting: if logger.GetLogger().V(logger.Info) { - logger.GetLogger().Log(logger.Info, st.hera.id, " Still executing ...") + logger.GetLogger().Log(logger.Info, st.hera.getID(), " Still executing ...") } // continues the loop case common.RcSQLError: @@ -349,7 +349,7 @@ Loop: return nil, err } if logger.GetLogger().V(logger.Debug) { - logger.GetLogger().Log(logger.Debug, st.hera.id, "Query successfull, num columns:", cols) + logger.GetLogger().Log(logger.Debug, st.hera.getID(), "Query successfull, num columns:", cols) } return newRows(st.hera, cols, st.fetchChunkSize) } @@ -360,19 +360,19 @@ func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (dri // TODO: refactor Query/QueryContext to reuse code // TODO: honor the context timeout and return when it is canceled sk := 0 - if len(st.hera.shardKeyPayload) > 0 { + if len(st.hera.getShardKeyPayload()) > 0 { sk = 1 } crid := 0 - if st.hera.corrID != nil { + if st.hera.getCorrID() != nil { crid = 1 } binds := len(args) nss := make([]*netstring.Netstring, crid /*ClientCalCorrelationID*/ +1 /*CmdPrepare*/ +2*binds /* CmdBindName and BindValue */ +sk /*ShardKey*/ +1 /*Execute*/ +1 /* Fetch */) idx := 0 if crid == 1 { - nss[0] = st.hera.corrID - st.hera.corrID = nil + nss[0] = st.hera.getCorrID() + st.hera.setCorrID(nil) idx++ } nss[idx] = netstring.NewNetstringFrom(common.CmdPrepareV2, []byte(st.sql)) @@ -397,12 +397,12 @@ func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (dri nss[idx] = netstring.NewNetstringFrom(common.CmdBindValue, []byte(val)) } if logger.GetLogger().V(logger.Verbose) { - logger.GetLogger().Log(logger.Verbose, st.hera.id, "Bind name =", string(nss[idx-1].Payload), ", value=", string(nss[idx].Payload)) + logger.GetLogger().Log(logger.Verbose, st.hera.getID(), "Bind name =", string(nss[idx-1].Payload), ", value=", string(nss[idx].Payload)) } idx++ } if sk == 1 { - nss[idx] = netstring.NewNetstringFrom(common.CmdShardKey, st.hera.shardKeyPayload) + nss[idx] = netstring.NewNetstringFrom(common.CmdShardKey, st.hera.getShardKeyPayload()) idx++ } nss[idx] = netstring.NewNetstringFrom(common.CmdExecute, nil) @@ -425,7 +425,7 @@ Loop: switch ns.Cmd { case common.RcStillExecuting: if logger.GetLogger().V(logger.Info) { - logger.GetLogger().Log(logger.Info, st.hera.id, " Still executing ...") + logger.GetLogger().Log(logger.Info, st.hera.getID(), " Still executing ...") } // continues the loop case common.RcSQLError: @@ -457,7 +457,7 @@ Loop: return nil, err } if logger.GetLogger().V(logger.Debug) { - logger.GetLogger().Log(logger.Debug, st.hera.id, "Query successfull, num columns:", cols) + logger.GetLogger().Log(logger.Debug, st.hera.getID(), "Query successfull, num columns:", cols) } return newRows(st.hera, cols, st.fetchChunkSize) } diff --git a/client/gosqldriver/statement_test.go b/client/gosqldriver/statement_test.go new file mode 100644 index 00000000..0c9aec08 --- /dev/null +++ b/client/gosqldriver/statement_test.go @@ -0,0 +1,626 @@ +package gosqldriver + +import ( + "context" + "database/sql/driver" + "testing" + + "github.com/paypal/hera/common" + "github.com/paypal/hera/utility/encoding/netstring" +) + +func (m *mockHeraConnection) getCorrID() *netstring.Netstring { + return m.corrID +} + +func (m *mockHeraConnection) setCorrID(corrID *netstring.Netstring) { + m.corrID = corrID +} + +func (m *mockHeraConnection) getShardKeyPayload() []byte { + return make([]byte, 1) +} + +func TestNewStmt(t *testing.T) { + mockHera := &mockHeraConnection{} + + tests := []struct { + sql string + expectedSQL string + }{ + {"SELECT * FROM table WHERE col1 = ? AND col2 = ?", "SELECT * FROM table WHERE col1 = :p1 AND col2 = :p2"}, + {"INSERT INTO table (col1, col2) VALUES (?, ?)", "INSERT INTO table (col1, col2) VALUES (:p1, :p2)"}, + {"UPDATE table SET col1 = ? WHERE col2 = ?", "UPDATE table SET col1 = :p1 WHERE col2 = :p2"}, + } + + for _, test := range tests { + st := newStmt(mockHera, test.sql) + if st.sql != test.expectedSQL { + t.Errorf("expected SQL to be %s, got %s", test.expectedSQL, st.sql) + } + if string(st.fetchChunkSize) != "0" { + t.Errorf("expected fetchChunkSize to be '0', got %s", st.fetchChunkSize) + } + } +} + +func TestStmtClose(t *testing.T) { + mockHera := &mockHeraConnection{} + st := newStmt(mockHera, "SELECT * FROM table") + + err := st.Close() + if err == nil { + t.Fatalf("expected an error, got nil") + } + expectedErr := "stmt.Close() not implemented" + if err.Error() != expectedErr { + t.Errorf("expected error to be %s, got %s", expectedErr, err.Error()) + } +} + +func TestStmtNumInput(t *testing.T) { + mockHera := &mockHeraConnection{} + st := newStmt(mockHera, "SELECT * FROM table") + + numInput := st.NumInput() + expectedNumInput := -1 + if numInput != expectedNumInput { + t.Errorf("expected NumInput to be %d, got %d", expectedNumInput, numInput) + } +} + +func TestStmtExec(t *testing.T) { + tests := []struct { + name string + args []driver.Value + setupMock func(*mockHeraConnection) + expectedError string + expectedRows int + }{ + { + name: "Exec with no args", + args: nil, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: common.RcValue, Payload: []byte("1")}, + {Cmd: common.RcValue, Payload: []byte("1")}, + } + }, + expectedError: "", + expectedRows: 1, + }, + { + name: "Exec with int args", + args: []driver.Value{1}, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: common.RcValue, Payload: []byte("1")}, + {Cmd: common.RcValue, Payload: []byte("1")}, + } + }, + expectedError: "", + expectedRows: 1, + }, + { + name: "Exec with string args", + args: []driver.Value{"test"}, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: common.RcValue, Payload: []byte("1")}, + {Cmd: common.RcValue, Payload: []byte("1")}, + } + }, + expectedError: "", + expectedRows: 1, + }, + { + name: "Exec with shard key", + args: nil, + setupMock: func(mock *mockHeraConnection) { + mock.shardKeyPayload = []byte("shard_key") + mock.responses = []netstring.Netstring{ + {Cmd: common.RcValue, Payload: []byte("1")}, + {Cmd: common.RcValue, Payload: []byte("1")}, + } + }, + expectedError: "", + expectedRows: 1, + }, + { + name: "Exec with SQL error", + args: nil, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: common.RcSQLError, Payload: []byte("SQL error")}, + } + }, + expectedError: "SQL error: SQL error", + expectedRows: 0, + }, + { + name: "Exec with internal error", + args: nil, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: common.RcError, Payload: []byte("Internal error")}, + } + }, + expectedError: "Internal hera error: Internal error", + expectedRows: 0, + }, + { + name: "Exec with unknown code", + args: nil, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: 999, Payload: []byte("Unknown code")}, + } + }, + expectedError: "Unknown code: 999, data: Unknown code", + expectedRows: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockHera := &mockHeraConnection{} + if tt.setupMock != nil { + tt.setupMock(mockHera) + } + + st := &stmt{ + hera: mockHera, + sql: "INSERT INTO table (col1) VALUES (?)", + } + + result, err := st.Exec(tt.args) + if tt.expectedError != "" { + if err == nil || err.Error() != tt.expectedError { + t.Fatalf("expected error %v, got %v", tt.expectedError, err) + } + return + } + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if result == nil { + t.Fatalf("expected result, got nil") + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + t.Fatalf("expected no error on RowsAffected, got %v", err) + } + if rowsAffected != int64(tt.expectedRows) { + t.Errorf("expected %d rows affected, got %d", tt.expectedRows, rowsAffected) + } + }) + } +} + +func TestStmtExecContext(t *testing.T) { + tests := []struct { + name string + args []driver.NamedValue + setupMock func(*mockHeraConnection) + expectedError string + expectedRows int + cancelCtx bool + }{ + { + name: "ExecContext with no args", + args: nil, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: common.RcValue, Payload: []byte("1")}, + {Cmd: common.RcValue, Payload: []byte("1")}, + } + }, + expectedError: "", + expectedRows: 1, + }, + { + name: "ExecContext with int args", + args: []driver.NamedValue{{Ordinal: 1, Value: 1}}, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: common.RcValue, Payload: []byte("1")}, + {Cmd: common.RcValue, Payload: []byte("1")}, + } + }, + expectedError: "", + expectedRows: 1, + }, + { + name: "ExecContext with string args", + args: []driver.NamedValue{{Ordinal: 1, Value: "test"}}, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: common.RcValue, Payload: []byte("1")}, + {Cmd: common.RcValue, Payload: []byte("1")}, + } + }, + expectedError: "", + expectedRows: 1, + }, + { + name: "ExecContext with shard key", + args: nil, + setupMock: func(mock *mockHeraConnection) { + mock.shardKeyPayload = []byte("shard_key") + mock.responses = []netstring.Netstring{ + {Cmd: common.RcValue, Payload: []byte("1")}, + {Cmd: common.RcValue, Payload: []byte("1")}, + } + }, + expectedError: "", + expectedRows: 1, + }, + { + name: "ExecContext with SQL error", + args: nil, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: common.RcSQLError, Payload: []byte("SQL error")}, + } + }, + expectedError: "SQL error: SQL error", + expectedRows: 0, + }, + { + name: "ExecContext with internal error", + args: nil, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: common.RcError, Payload: []byte("Internal error")}, + } + }, + expectedError: "Internal hera error: Internal error", + expectedRows: 0, + }, + { + name: "ExecContext with unknown code", + args: nil, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: 999, Payload: []byte("Unknown code")}, + } + }, + expectedError: "Unknown code: 999, data: Unknown code", + expectedRows: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockHera := &mockHeraConnection{} + if tt.setupMock != nil { + tt.setupMock(mockHera) + } + + st := &stmt{ + hera: mockHera, + sql: "INSERT INTO table (col1) VALUES (?)", + } + + ctx := context.Background() + if tt.cancelCtx { + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + cancel() + } + + result, err := st.ExecContext(ctx, tt.args) + if tt.expectedError != "" { + if err == nil || err.Error() != tt.expectedError { + t.Fatalf("expected error %v, got %v", tt.expectedError, err) + } + return + } + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if result == nil { + t.Fatalf("expected result, got nil") + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + t.Fatalf("expected no error on RowsAffected, got %v", err) + } + if rowsAffected != int64(tt.expectedRows) { + t.Errorf("expected %d rows affected, got %d", tt.expectedRows, rowsAffected) + } + }) + } +} + +func TestStmtQuery(t *testing.T) { + tests := []struct { + name string + args []driver.Value + setupMock func(*mockHeraConnection) + expectedError string + expectedCols int + }{ + { + name: "Query with no args", + args: nil, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: common.RcValue, Payload: []byte("2")}, // number of columns + {Cmd: common.RcValue, Payload: []byte("1")}, // number of row + {Cmd: common.RcValue, Payload: []byte("row1")}, + {Cmd: common.RcNoMoreData, Payload: []byte("")}, + } + }, + expectedError: "", + expectedCols: 2, + }, + { + name: "Query with int args", + args: []driver.Value{1}, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: common.RcValue, Payload: []byte("2")}, // number of columns + {Cmd: common.RcValue, Payload: []byte("1")}, // number of rows + {Cmd: common.RcValue, Payload: []byte("row1")}, + {Cmd: common.RcNoMoreData, Payload: []byte("")}, + } + }, + expectedError: "", + expectedCols: 2, + }, + { + name: "Query with string args", + args: []driver.Value{"test"}, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: common.RcValue, Payload: []byte("2")}, // number of columns + {Cmd: common.RcValue, Payload: []byte("1")}, // number of rows + {Cmd: common.RcValue, Payload: []byte("row1")}, + {Cmd: common.RcNoMoreData, Payload: []byte("")}, + } + }, + expectedError: "", + expectedCols: 2, + }, + { + name: "Query with shard key", + args: nil, + setupMock: func(mock *mockHeraConnection) { + mock.shardKeyPayload = []byte("shard_key") + mock.responses = []netstring.Netstring{ + {Cmd: common.RcValue, Payload: []byte("2")}, // number of columns + {Cmd: common.RcValue, Payload: []byte("1")}, // number of rows + {Cmd: common.RcValue, Payload: []byte("row1")}, + {Cmd: common.RcNoMoreData, Payload: []byte("")}, + } + }, + expectedError: "", + expectedCols: 2, + }, + { + name: "Query with SQL error", + args: nil, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: common.RcSQLError, Payload: []byte("SQL error")}, + } + }, + expectedError: "SQL error: SQL error", + expectedCols: 0, + }, + { + name: "Query with internal error", + args: nil, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: common.RcError, Payload: []byte("Internal error")}, + } + }, + expectedError: "Internal hera error: Internal error", + expectedCols: 0, + }, + { + name: "Query with unknown code", + args: nil, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: 999, Payload: []byte("Unknown code")}, + } + }, + expectedError: "Unknown code: 999, data: Unknown code", + expectedCols: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockHera := &mockHeraConnection{} + if tt.setupMock != nil { + tt.setupMock(mockHera) + } + + st := &stmt{ + hera: mockHera, + sql: "SELECT * FROM table WHERE col1 = ?", + } + + rows, err := st.Query(tt.args) + if tt.expectedError != "" { + if err == nil || err.Error() != tt.expectedError { + t.Fatalf("expected error %v, got %v", tt.expectedError, err) + } + return + } + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if rows == nil { + t.Fatalf("expected rows, got nil") + } + + columns := rows.Columns() + if len(columns) != tt.expectedCols { + t.Errorf("expected %d columns, got %d", tt.expectedCols, len(columns)) + } + }) + } +} + +func TestStmtQueryContext(t *testing.T) { + tests := []struct { + name string + args []driver.NamedValue + setupMock func(*mockHeraConnection) + expectedError string + expectedCols int + cancelCtx bool + }{ + { + name: "QueryContext with no args", + args: nil, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: common.RcValue, Payload: []byte("2")}, // number of columns + {Cmd: common.RcValue, Payload: []byte("1")}, // number of rows + {Cmd: common.RcValue, Payload: []byte("row1")}, + {Cmd: common.RcNoMoreData, Payload: []byte("")}, + } + }, + expectedError: "", + expectedCols: 2, + }, + { + name: "QueryContext with int args", + args: []driver.NamedValue{{Ordinal: 1, Value: 1}}, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: common.RcValue, Payload: []byte("2")}, // number of columns + {Cmd: common.RcValue, Payload: []byte("1")}, // number of rows + {Cmd: common.RcValue, Payload: []byte("row1")}, + {Cmd: common.RcNoMoreData, Payload: []byte("")}, + } + }, + expectedError: "", + expectedCols: 2, + }, + { + name: "QueryContext with string args", + args: []driver.NamedValue{{Ordinal: 1, Value: "test"}}, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: common.RcValue, Payload: []byte("2")}, // number of columns + {Cmd: common.RcValue, Payload: []byte("1")}, // number of rows + {Cmd: common.RcValue, Payload: []byte("row1")}, + {Cmd: common.RcNoMoreData, Payload: []byte("")}, + } + }, + expectedError: "", + expectedCols: 2, + }, + { + name: "QueryContext with shard key", + args: nil, + setupMock: func(mock *mockHeraConnection) { + mock.shardKeyPayload = []byte("shard_key") + mock.responses = []netstring.Netstring{ + {Cmd: common.RcValue, Payload: []byte("2")}, // number of columns + {Cmd: common.RcValue, Payload: []byte("1")}, // number of rows + {Cmd: common.RcValue, Payload: []byte("row1")}, + {Cmd: common.RcNoMoreData, Payload: []byte("")}, + } + }, + expectedError: "", + expectedCols: 2, + }, + { + name: "QueryContext with SQL error", + args: nil, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: common.RcSQLError, Payload: []byte("SQL error")}, + } + }, + expectedError: "SQL error: SQL error", + expectedCols: 0, + }, + { + name: "QueryContext with internal error", + args: nil, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: common.RcError, Payload: []byte("Internal error")}, + } + }, + expectedError: "Internal hera error: Internal error", + expectedCols: 0, + }, + { + name: "QueryContext with unknown code", + args: nil, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: 999, Payload: []byte("Unknown code")}, + } + }, + expectedError: "Unknown code: 999, data: Unknown code", + expectedCols: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockHera := &mockHeraConnection{} + if tt.setupMock != nil { + tt.setupMock(mockHera) + } + + st := &stmt{ + hera: mockHera, + sql: "SELECT * FROM table WHERE col1 = ?", + } + + ctx := context.Background() + if tt.cancelCtx { + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + cancel() + } + + rows, err := st.QueryContext(ctx, tt.args) + if tt.expectedError != "" { + if err == nil || err.Error() != tt.expectedError { + t.Fatalf("expected error %v, got %v", tt.expectedError, err) + } + return + } + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if rows == nil { + t.Fatalf("expected rows, got nil") + } + + columns := rows.Columns() + if len(columns) != tt.expectedCols { + t.Errorf("expected %d columns, got %d", tt.expectedCols, len(columns)) + } + }) + } +} + +func TestStmtSetFetchSize(t *testing.T) { + st := &stmt{} + st.SetFetchSize(10) + + expected := "10" + if string(st.fetchChunkSize) != expected { + t.Errorf("expected fetchChunkSize to be %s, got %s", expected, st.fetchChunkSize) + } +}