Skip to content

Commit

Permalink
add unit tests for rows.go and statement.go
Browse files Browse the repository at this point in the history
  • Loading branch information
mazchew committed Aug 8, 2024
1 parent ff9c489 commit 9b921e1
Show file tree
Hide file tree
Showing 5 changed files with 916 additions and 85 deletions.
134 changes: 85 additions & 49 deletions client/gosqldriver/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,35 @@ import (

var corrIDUnsetCmd = netstring.NewNetstringFrom(common.CmdClientCalCorrelationID, []byte("CorrId=NotSet"))

type heraConnectionInterface interface {
Prepare(query string) (driver.Stmt, error)
Close() error
Begin() (driver.Tx, error)
exec(cmd int, payload []byte) error
execNs(ns *netstring.Netstring) error
getResponse() (*netstring.Netstring, error)
SetShardID(shard int) error
ResetShardID() error
GetNumShards() (int, error)
SetShardKeyPayload(payload string)
ResetShardKeyPayload()
SetCalCorrID(corrID string)
SetClientInfo(poolName string, host string) error
SetClientInfoWithPoolStack(poolName string, host string, poolStack string) error
getID() string
getCorrID() *netstring.Netstring
getShardKeyPayload() []byte
setCorrID(*netstring.Netstring)
}

type heraConnection struct {
id string // used for logging
conn net.Conn
reader *netstring.Reader
// for the sharding extension
shardKeyPayload []byte
// correlation id
corrID *netstring.Netstring
corrID *netstring.Netstring
clientinfo *netstring.Netstring
}

Expand All @@ -52,7 +73,6 @@ func NewHeraConnection(conn net.Conn) driver.Conn {
return hera
}


// Prepare returns a prepared statement, bound to this connection.
func (c *heraConnection) Prepare(query string) (driver.Stmt, error) {
if logger.GetLogger().V(logger.Debug) {
Expand Down Expand Up @@ -177,66 +197,82 @@ func (c *heraConnection) SetCalCorrID(corrID string) {
}

// SetClientInfo actually sends it over to Hera server
func (c *heraConnection) SetClientInfo(poolName string, host string)(error){
func (c *heraConnection) SetClientInfo(poolName string, host string) error {
if len(poolName) <= 0 && len(host) <= 0 {
return nil
}

pid := os.Getpid()
data := fmt.Sprintf("PID: %d, HOST: %s, Poolname: %s, Command: SetClientInfo,", pid, host, poolName)
c.clientinfo = netstring.NewNetstringFrom(common.CmdClientInfo, []byte(string(data)))
if logger.GetLogger().V(logger.Verbose) {
logger.GetLogger().Log(logger.Verbose, "SetClientInfo", c.clientinfo.Serialized)
}

_, err := c.conn.Write(c.clientinfo.Serialized)
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to send client info")
}
return errors.New("Failed custom auth, failed to send client info")
}
ns, err := c.reader.ReadNext()
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to read server info")
}
return errors.New("Failed to read server info")
}
if logger.GetLogger().V(logger.Debug) {
logger.GetLogger().Log(logger.Debug, "Server info:", string(ns.Payload))
}
c.clientinfo = netstring.NewNetstringFrom(common.CmdClientInfo, []byte(string(data)))
if logger.GetLogger().V(logger.Verbose) {
logger.GetLogger().Log(logger.Verbose, "SetClientInfo", c.clientinfo.Serialized)
}

_, err := c.conn.Write(c.clientinfo.Serialized)
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to send client info")
}
return errors.New("Failed custom auth, failed to send client info")
}
ns, err := c.reader.ReadNext()
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to read server info")
}
return errors.New("Failed to read server info")
}
if logger.GetLogger().V(logger.Debug) {
logger.GetLogger().Log(logger.Debug, "Server info:", string(ns.Payload))
}
return nil
}

func (c *heraConnection) SetClientInfoWithPoolStack(poolName string, host string, poolStack string)(error){
func (c *heraConnection) SetClientInfoWithPoolStack(poolName string, host string, poolStack string) error {
if len(poolName) <= 0 && len(host) <= 0 && len(poolStack) <= 0 {
return nil
}

pid := os.Getpid()
data := fmt.Sprintf("PID: %d, HOST: %s, Poolname: %s, PoolStack: %s, Command: SetClientInfo,", pid, host, poolName, poolStack)
c.clientinfo = netstring.NewNetstringFrom(common.CmdClientInfo, []byte(string(data)))
if logger.GetLogger().V(logger.Verbose) {
logger.GetLogger().Log(logger.Verbose, "SetClientInfo", c.clientinfo.Serialized)
}

_, err := c.conn.Write(c.clientinfo.Serialized)
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to send client info")
}
return errors.New("Failed custom auth, failed to send client info")
}
ns, err := c.reader.ReadNext()
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to read server info")
}
return errors.New("Failed to read server info")
}
if logger.GetLogger().V(logger.Debug) {
logger.GetLogger().Log(logger.Debug, "Server info:", string(ns.Payload))
}
c.clientinfo = netstring.NewNetstringFrom(common.CmdClientInfo, []byte(string(data)))
if logger.GetLogger().V(logger.Verbose) {
logger.GetLogger().Log(logger.Verbose, "SetClientInfo", c.clientinfo.Serialized)
}

_, err := c.conn.Write(c.clientinfo.Serialized)
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to send client info")
}
return errors.New("Failed custom auth, failed to send client info")
}
ns, err := c.reader.ReadNext()
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to read server info")
}
return errors.New("Failed to read server info")
}
if logger.GetLogger().V(logger.Debug) {
logger.GetLogger().Log(logger.Debug, "Server info:", string(ns.Payload))
}
return nil
}
}

func (c *heraConnection) getID() string {
return c.id
}

func (c *heraConnection) getCorrID() *netstring.Netstring {
return c.corrID
}

func (c *heraConnection) getShardKeyPayload() []byte {
return c.shardKeyPayload
}

func (c *heraConnection) setCorrID(newCorrID *netstring.Netstring) {
c.corrID = newCorrID
}
6 changes: 3 additions & 3 deletions client/gosqldriver/rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import (
// similar to JDBC's result set
// Rows is an iterator over an executed query's results.
type rows struct {
hera *heraConnection
hera heraConnectionInterface
vals []driver.Value
cols int
currentRow int
Expand All @@ -41,7 +41,7 @@ type rows struct {
}

// TODO: fetch chunk size
func newRows(hera *heraConnection, cols int, fetchChunkSize []byte) (*rows, error) {
func newRows(hera heraConnectionInterface, cols int, fetchChunkSize []byte) (*rows, error) {
rs := &rows{hera: hera, cols: cols, currentRow: 0, fetchChunkSize: fetchChunkSize}
err := rs.fetchResults()
if err != nil {
Expand All @@ -63,7 +63,7 @@ func (r *rows) fetchResults() error {
return nil
case common.RcNoMoreData:
if logger.GetLogger().V(logger.Verbose) {
logger.GetLogger().Log(logger.Verbose, r.hera.id, "Rows: cols = ", r.cols, ", numValues =", len(r.vals))
logger.GetLogger().Log(logger.Verbose, r.hera.getID(), "Rows: cols = ", r.cols, ", numValues =", len(r.vals))
}
r.completed = true
return nil
Expand Down
169 changes: 169 additions & 0 deletions client/gosqldriver/rows_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
package gosqldriver

import (
"database/sql/driver"
"errors"
"io"
"testing"

"github.com/paypal/hera/common"
"github.com/paypal/hera/utility/encoding/netstring"
)

type mockHeraConnection struct {
heraConnection
responses []netstring.Netstring
execErr error
}

func (m *mockHeraConnection) Prepare(query string) (driver.Stmt, error) {
return nil, nil
}

func (m *mockHeraConnection) Close() error {
return nil
}

func (m *mockHeraConnection) Begin() (driver.Tx, error) {
return nil, nil
}

func (m *mockHeraConnection) exec(cmd int, payload []byte) error {
return m.execErr
}

func (m *mockHeraConnection) execNs(ns *netstring.Netstring) error {
return m.execErr
}

func (m *mockHeraConnection) getResponse() (*netstring.Netstring, error) {
if len(m.responses) == 0 {
return &netstring.Netstring{}, io.EOF
}
response := m.responses[0]
m.responses = m.responses[1:]
return &response, nil
}

func (m *mockHeraConnection) SetShardID(shard int) error {
return nil
}

func (m *mockHeraConnection) ResetShardID() error {
return nil
}

func (m *mockHeraConnection) GetNumShards() (int, error) {
return 0, nil
}

func (m *mockHeraConnection) SetShardKeyPayload(payload string) {
}

func (m *mockHeraConnection) ResetShardKeyPayload() {
}

func (m *mockHeraConnection) SetCalCorrID(corrID string) {
}

func (m *mockHeraConnection) SetClientInfo(poolName string, host string) error {
return nil
}

func (m *mockHeraConnection) SetClientInfoWithPoolStack(poolName string, host string, poolStack string) error {
return nil
}

func (m *mockHeraConnection) getID() string {
return "mockID"
}

func TestNewRows(t *testing.T) {
mockHera := &mockHeraConnection{
responses: []netstring.Netstring{
{Cmd: common.RcValue, Payload: []byte("value1")},
{Cmd: common.RcValue, Payload: []byte("value2")},
{Cmd: common.RcOK},
},
}

cols := 2
fetchChunkSize := []byte("10")
rows, err := newRows(mockHera, cols, fetchChunkSize)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if rows == nil {
t.Fatalf("expected rows to be non-nil")
}
if rows.cols != cols {
t.Errorf("expected cols to be %d, got %d", cols, rows.cols)
}
if rows.currentRow != 0 {
t.Errorf("expected currentRow to be 0, got %d", rows.currentRow)
}
if string(rows.fetchChunkSize) != string(fetchChunkSize) {
t.Errorf("expected fetchChunkSize to be %s, got %s", fetchChunkSize, rows.fetchChunkSize)
}
}

func TestColumns(t *testing.T) {
mockHera := &mockHeraConnection{}
rows := &rows{hera: mockHera, cols: 3}

columns := rows.Columns()
if len(columns) != 3 {
t.Fatalf("expected 3 columns, got %d", len(columns))
}
for _, col := range columns {
if col != "" {
t.Errorf("expected column name to be empty, got %s", col)
}
}
}

// TODO: Change unit test for Close() once it has been implemented
func TestClose(t *testing.T) {
mockHera := &mockHeraConnection{}
rows := &rows{hera: mockHera}

err := rows.Close()
if err == nil {
t.Fatalf("expected an error, got nil")
}
expectedErr := "Rows.Close() not yet implemented"
if err.Error() != expectedErr {
t.Errorf("expected error to be %s, got %s", expectedErr, err.Error())
}
}

func TestNext(t *testing.T) {
mockHera := &mockHeraConnection{
responses: []netstring.Netstring{
{Cmd: common.RcValue, Payload: []byte("value1")},
{Cmd: common.RcValue, Payload: []byte("value2")},
{Cmd: common.RcNoMoreData},
},
}
rows, err := newRows(mockHera, 2, []byte("10"))
if err != nil {
t.Fatalf("expected no error, got %v", err)
}

dest := make([]driver.Value, 2)
err = rows.Next(dest)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
expectedDest := []driver.Value{[]byte("value1"), []byte("value2")}
for i := range dest {
if string(dest[i].([]byte)) != string(expectedDest[i].([]byte)) {
t.Errorf("expected dest[%d] to be %s, got %s", i, expectedDest[i], dest[i])
}
}

err = rows.Next(dest)
if !errors.Is(err, io.EOF) {
t.Fatalf("expected io.EOF error, got %v", err)
}
}
Loading

0 comments on commit 9b921e1

Please sign in to comment.