Skip to content

Commit

Permalink
add session pool test (#286)
Browse files Browse the repository at this point in the history
* add session pool test

* fix

* fix
  • Loading branch information
HarrisChu authored Aug 31, 2023
1 parent 7f5a200 commit c36464d
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 81 deletions.
25 changes: 15 additions & 10 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1298,9 +1298,13 @@ func startContainer(t *testing.T, containerName string) {
}
}

func tryToExecute(session *Session, query string) (resp *ResultSet, err error) {
type executer interface {
Execute(query string) (*ResultSet, error)
}

func tryToExecute(e executer, query string) (resp *ResultSet, err error) {
for i := 3; i > 0; i-- {
resp, err = session.Execute(query)
resp, err = e.Execute(query)
if err == nil && resp.IsSucceed() {
return
}
Expand All @@ -1321,7 +1325,7 @@ func tryToExecuteWithParameter(session *Session, query string, params map[string
}

// creates schema
func createTestDataSchema(t *testing.T, session *Session) {
func createTestDataSchema(t *testing.T, executor executer) {
createSchema := "CREATE SPACE IF NOT EXISTS test_data(vid_type = FIXED_STRING(30));" +
"USE test_data; " +
"CREATE TAG IF NOT EXISTS person(name string, age int8, grade int16, " +
Expand All @@ -1333,7 +1337,7 @@ func createTestDataSchema(t *testing.T, session *Session) {
"CREATE EDGE IF NOT EXISTS like(likeness double); " +
"CREATE EDGE IF NOT EXISTS friend(start_Datetime datetime, end_Datetime datetime); " +
"CREATE TAG INDEX IF NOT EXISTS person_name_index ON person(name(8));"
resultSet, err := tryToExecute(session, createSchema)
resultSet, err := tryToExecute(executor, createSchema)
if err != nil {
t.Fatalf(err.Error())
return
Expand All @@ -1344,7 +1348,7 @@ func createTestDataSchema(t *testing.T, session *Session) {
}

// inserts data that used in tests
func loadTestData(t *testing.T, session *Session) {
func loadTestData(t *testing.T, e executer) {
query := "INSERT VERTEX person(name, age, grade, friends, book_num," +
"birthday, start_school, morning, property," +
"is_girl, child_name, expend, first_out_city) VALUES" +
Expand All @@ -1363,7 +1367,7 @@ func loadTestData(t *testing.T, session *Session) {
"'John':('John', 10, 3, 10, 100, datetime('2010-09-10T10:08:02'), " +
"date('2017-09-10'), time('07:10:00'), " +
"1000.0, false, \"Hello World!\", 100.0, 1111)"
resultSet, err := tryToExecute(session, query)
resultSet, err := tryToExecute(e, query)
if err != nil {
t.Fatalf(err.Error())
return
Expand All @@ -1374,7 +1378,7 @@ func loadTestData(t *testing.T, session *Session) {
"INSERT VERTEX student(name, interval) VALUES " +
"'Bob':('Bob', duration({months:1, seconds:100, microseconds:20})), 'Lily':('Lily', duration({years: 1, seconds: 0})), " +
"'Tom':('Tom', duration({years: 1, seconds: 0})), 'Jerry':('Jerry', duration({years: 1, seconds: 0})), 'John':('John', duration({years: 1, seconds: 0}))"
resultSet, err = tryToExecute(session, query)
resultSet, err = tryToExecute(e, query)
if err != nil {
t.Fatalf(err.Error())
return
Expand All @@ -1387,8 +1391,9 @@ func loadTestData(t *testing.T, session *Session) {
"'Bob'->'Tom':(70.0), " +
"'Jerry'->'Lily':(84.0)," +
"'Tom'->'Jerry':(68.3), " +
"'Bob'->'John':(97.2)"
resultSet, err = tryToExecute(session, query)
"'Bob'->'John':(97.2), " +
"'Lily'->'Tom':(80.0)"
resultSet, err = tryToExecute(e, query)
if err != nil {
t.Fatalf(err.Error())
return
Expand All @@ -1402,7 +1407,7 @@ func loadTestData(t *testing.T, session *Session) {
"'Jerry'->'Lily':(datetime('2008-09-10T10:08:02'), datetime('2010-09-10T10:08:02')), " +
"'Tom'->'Jerry':(datetime('2008-09-10T10:08:02'), datetime('2010-09-10T10:08:02')), " +
"'Bob'->'John':(datetime('2008-09-10T10:08:02'), datetime('2010-09-10T10:08:02'))"
resultSet, err = tryToExecute(session, query)
resultSet, err = tryToExecute(e, query)
if err != nil {
t.Fatalf(err.Error())
return
Expand Down
89 changes: 32 additions & 57 deletions session_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,31 +249,12 @@ func (pool *SessionPool) Close() {
// iterate all sessions
for i := 0; i < idleLen; i++ {
session := pool.idleSessions.Front().Value.(*pureSession)
if session.connection == nil {
pool.log.Warn("Session has been released")
pool.idleSessions.Remove(pool.idleSessions.Front())
continue
}

if err := session.connection.signOut(session.sessionID); err != nil {
pool.log.Warn(fmt.Sprintf("Sign out failed, %s", err.Error()))
}
// close connection
session.connection.close()
session.close()
pool.idleSessions.Remove(pool.idleSessions.Front())
}
for i := 0; i < activeLen; i++ {
session := pool.activeSessions.Front().Value.(*pureSession)
if session.connection == nil {
pool.log.Warn("Session has been released")
pool.activeSessions.Remove(pool.activeSessions.Front())
continue
}
if err := session.connection.signOut(session.sessionID); err != nil {
pool.log.Warn(fmt.Sprintf("Sign out failed, %s", err.Error()))
}
// close connection
session.connection.close()
session.close()
pool.activeSessions.Remove(pool.activeSessions.Front())
}

Expand Down Expand Up @@ -448,57 +429,46 @@ func (pool *SessionPool) sessionCleaner() {
case <-pool.cleanerChan: // pool was closed.
}

pool.rwLock.Lock()

if pool.closed {
pool.cleanerChan = nil
pool.rwLock.Unlock()
return
}

closing := pool.timeoutSessionList()

//release expired session from the pool
for _, session := range closing {
if session.connection == nil {
pool.log.Warn("Session has been released")
pool.rwLock.Unlock()
return
}
if err := session.connection.signOut(session.sessionID); err != nil {
pool.log.Warn(fmt.Sprintf("Sign out failed, %s", err.Error()))
}
// close connection
session.connection.close()
session.close()
}
pool.rwLock.Unlock()
t.Reset(d)
}
}

// timeoutSessionList returns a list of sessions that have been idle for longer than the idle time.
func (pool *SessionPool) timeoutSessionList() (closing []*pureSession) {
if pool.conf.idleTime > 0 {
expiredSince := time.Now().Add(-pool.conf.idleTime)
var newEle *list.Element = nil

maxCleanSize := pool.idleSessions.Len() + pool.activeSessions.Len() - pool.conf.minSize

for ele := pool.idleSessions.Front(); ele != nil; {
if maxCleanSize == 0 {
return
}

newEle = ele.Next()
// Check Session is expired
if !ele.Value.(*pureSession).returnedAt.Before(expiredSince) {
return
}
closing = append(closing, ele.Value.(*pureSession))
pool.idleSessions.Remove(ele)
ele = newEle
maxCleanSize--
if pool.conf.idleTime == 0 {
return
}
pool.rwLock.Lock()
defer pool.rwLock.Unlock()
expiredSince := time.Now().Add(-pool.conf.idleTime)
var newEle *list.Element = nil

maxCleanSize := pool.idleSessions.Len() + pool.activeSessions.Len() - pool.conf.minSize

for ele := pool.idleSessions.Front(); ele != nil; {
if maxCleanSize == 0 {
return
}

newEle = ele.Next()
// Check Session is expired
if !ele.Value.(*pureSession).returnedAt.Before(expiredSince) {
return
}
closing = append(closing, ele.Value.(*pureSession))
pool.idleSessions.Remove(ele)
ele = newEle
maxCleanSize--
}
return
}
Expand Down Expand Up @@ -612,12 +582,17 @@ func (session *pureSession) executeWithParameter(stmt string, params map[string]
}

func (session *pureSession) close() {
defer func() {
if err := recover(); err != nil {
return
}
}()
if session.connection != nil {
// ignore signout error
_ = session.connection.signOut(session.sessionID)
session.connection.close()
session.connection = nil
}
session.connection = nil
}

// Ping checks if the session is valid
Expand Down
140 changes: 136 additions & 4 deletions session_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ package nebula_go

import (
"fmt"
"strings"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/vesoft-inc/nebula-go/v3/nebula"
"github.com/vesoft-inc/nebula-go/v3/nebula/graph"
"golang.org/x/net/context"
)

func TestSessionPoolInvalidConfig(t *testing.T) {
Expand Down Expand Up @@ -53,6 +55,76 @@ func TestSessionPoolInvalidConfig(t *testing.T) {
"error message should contain Service address is empty")
}

func TestSessionPoolServerCheck(t *testing.T) {
prepareSpace("client_test")
defer dropSpace("client_test")
hostAddress := HostAddress{Host: address, Port: port}
testcases := []struct {
conf *SessionPoolConf
errMsg string
}{
{
conf: &SessionPoolConf{
username: "root",
password: "nebula",
serviceAddrs: []HostAddress{hostAddress},
spaceName: "invalid_space",
minSize: 1,
},
errMsg: "failed to create a new session pool, " +
"failed to initialize the session pool, " +
"failed to use space invalid_space: SpaceNotFound: SpaceName `invalid_space`",
},
{
conf: &SessionPoolConf{
username: "root1",
password: "nebula",
serviceAddrs: []HostAddress{hostAddress},
spaceName: "client_test",
minSize: 1,
},
errMsg: "failed to create a new session pool, " +
"failed to initialize the session pool, " +
"failed to authenticate the user, error code: -1001, " +
"error message: User not exist, the pool has been closed",
},
{
conf: &SessionPoolConf{
username: "root",
password: "nebula1",
serviceAddrs: []HostAddress{hostAddress},
spaceName: "client_test",
minSize: 1,
},
errMsg: "failed to create a new session pool, " +
"failed to initialize the session pool, " +
"failed to authenticate the user, error code: -1001, " +
"error message: Invalid password, the pool has been closed",
},
{
conf: &SessionPoolConf{
username: "root",
password: "nebula1",
serviceAddrs: []HostAddress{{"127.0.0.1", 1234}},
spaceName: "client_test",
minSize: 1,
},
errMsg: "failed to create a new session pool, " +
"failed to initialize the session pool, " +
"failed to open transport, " +
"error: dial tcp 127.0.0.1:1234: connect: connection refused",
},
}
for _, tc := range testcases {
_, err := NewSessionPool(*tc.conf, DefaultLogger{})
if err == nil {
t.Fatal("should return error")
}
assert.Equal(t, err.Error(), tc.errMsg,
fmt.Sprintf("expected error: %s, but actual error: %s", tc.errMsg, err.Error()))
}
}

func TestSessionPoolBasic(t *testing.T) {
prepareSpace("client_test")
defer dropSpace("client_test")
Expand Down Expand Up @@ -493,13 +565,14 @@ func TestSessionPoolRetry(t *testing.T) {
if err != nil {
t.Fatal(err)
}
original := *session
original := session.sessionID
conn := session.connection
_, _ = sessionPool.executeWithRetry(session, tc.retryFn, 2)
if tc.retry {
assert.NotEqual(t, original, *session, fmt.Sprintf("test case: %s", tc.name))
assert.NotEqual(t, original.connection, nil, fmt.Sprintf("test case: %s", tc.name))
assert.NotEqual(t, original, session.sessionID, fmt.Sprintf("test case: %s", tc.name))
assert.NotEqual(t, conn, nil, fmt.Sprintf("test case: %s", tc.name))
} else {
assert.Equal(t, original, *session, fmt.Sprintf("test case: %s", tc.name))
assert.Equal(t, original, session.sessionID, fmt.Sprintf("test case: %s", tc.name))
}
}
}
Expand Down Expand Up @@ -536,3 +609,62 @@ func TestSessionPoolClose(t *testing.T) {
_, err = sessionPool.Execute("SHOW HOSTS;")
assert.Equal(t, err.Error(), "failed to execute: Session pool has been closed", "session pool should be closed")
}

// TestSessionPoolGetSessionTimeout tests the scenario that if all requests are timeout,
// the session pool should return timeout error, not reach the pool size limit.
func TestQueryTimeout(t *testing.T) {
hostAddress := HostAddress{Host: address, Port: port}
config, err := NewSessionPoolConf(
"root",
"nebula",
[]HostAddress{hostAddress},
"test_data")
if err != nil {
t.Errorf("failed to create session pool config, %s", err.Error())
}
config.minSize = 0
config.maxSize = 10
config.retryGetSessionTimes = 1
config.timeOut = 100 * time.Millisecond
// create session pool
sessionPool, err := NewSessionPool(*config, DefaultLogger{})
if err != nil {
t.Fatal(err)
}
defer sessionPool.Close()
createTestDataSchema(t, sessionPool)
loadTestData(t, sessionPool)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
errCh := make(chan error, 1)
defer cancel()
var wg sync.WaitGroup
for i := 0; i < config.maxSize; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 10; j++ {
select {
case <-ctx.Done():
return
default:
_, err := sessionPool.Execute(`go 2000 step from "Bob" over like yield tags($$)`)
if err == nil {
errCh <- fmt.Errorf("should return error")
return
}
errMsg := "i/o timeout"
if !strings.Contains(err.Error(), errMsg) {
errCh <- fmt.Errorf("expect error contains: %s, but actual: %s", errMsg, err.Error())
return
}
}
}
}()
}
wg.Wait()
select {
case err := <-errCh:
t.Fatal(err)
default:
}
}
Loading

0 comments on commit c36464d

Please sign in to comment.