diff --git a/client_test.go b/client_test.go index 5a791f5e..806d2a40 100644 --- a/client_test.go +++ b/client_test.go @@ -1298,9 +1298,13 @@ func startContainer(t *testing.T, containerName string) { } } -func tryToExecute(session *Session, query string) (resp *ResultSet, err error) { +type executer interface { + Execute(query string) (*ResultSet, error) +} + +func tryToExecute(e executer, query string) (resp *ResultSet, err error) { for i := 3; i > 0; i-- { - resp, err = session.Execute(query) + resp, err = e.Execute(query) if err == nil && resp.IsSucceed() { return } @@ -1321,7 +1325,7 @@ func tryToExecuteWithParameter(session *Session, query string, params map[string } // creates schema -func createTestDataSchema(t *testing.T, session *Session) { +func createTestDataSchema(t *testing.T, executor executer) { createSchema := "CREATE SPACE IF NOT EXISTS test_data(vid_type = FIXED_STRING(30));" + "USE test_data; " + "CREATE TAG IF NOT EXISTS person(name string, age int8, grade int16, " + @@ -1333,7 +1337,7 @@ func createTestDataSchema(t *testing.T, session *Session) { "CREATE EDGE IF NOT EXISTS like(likeness double); " + "CREATE EDGE IF NOT EXISTS friend(start_Datetime datetime, end_Datetime datetime); " + "CREATE TAG INDEX IF NOT EXISTS person_name_index ON person(name(8));" - resultSet, err := tryToExecute(session, createSchema) + resultSet, err := tryToExecute(executor, createSchema) if err != nil { t.Fatalf(err.Error()) return @@ -1344,7 +1348,7 @@ func createTestDataSchema(t *testing.T, session *Session) { } // inserts data that used in tests -func loadTestData(t *testing.T, session *Session) { +func loadTestData(t *testing.T, e executer) { query := "INSERT VERTEX person(name, age, grade, friends, book_num," + "birthday, start_school, morning, property," + "is_girl, child_name, expend, first_out_city) VALUES" + @@ -1363,7 +1367,7 @@ func loadTestData(t *testing.T, session *Session) { "'John':('John', 10, 3, 10, 100, datetime('2010-09-10T10:08:02'), " + "date('2017-09-10'), time('07:10:00'), " + "1000.0, false, \"Hello World!\", 100.0, 1111)" - resultSet, err := tryToExecute(session, query) + resultSet, err := tryToExecute(e, query) if err != nil { t.Fatalf(err.Error()) return @@ -1374,7 +1378,7 @@ func loadTestData(t *testing.T, session *Session) { "INSERT VERTEX student(name, interval) VALUES " + "'Bob':('Bob', duration({months:1, seconds:100, microseconds:20})), 'Lily':('Lily', duration({years: 1, seconds: 0})), " + "'Tom':('Tom', duration({years: 1, seconds: 0})), 'Jerry':('Jerry', duration({years: 1, seconds: 0})), 'John':('John', duration({years: 1, seconds: 0}))" - resultSet, err = tryToExecute(session, query) + resultSet, err = tryToExecute(e, query) if err != nil { t.Fatalf(err.Error()) return @@ -1387,8 +1391,9 @@ func loadTestData(t *testing.T, session *Session) { "'Bob'->'Tom':(70.0), " + "'Jerry'->'Lily':(84.0)," + "'Tom'->'Jerry':(68.3), " + - "'Bob'->'John':(97.2)" - resultSet, err = tryToExecute(session, query) + "'Bob'->'John':(97.2), " + + "'Lily'->'Tom':(80.0)" + resultSet, err = tryToExecute(e, query) if err != nil { t.Fatalf(err.Error()) return @@ -1402,7 +1407,7 @@ func loadTestData(t *testing.T, session *Session) { "'Jerry'->'Lily':(datetime('2008-09-10T10:08:02'), datetime('2010-09-10T10:08:02')), " + "'Tom'->'Jerry':(datetime('2008-09-10T10:08:02'), datetime('2010-09-10T10:08:02')), " + "'Bob'->'John':(datetime('2008-09-10T10:08:02'), datetime('2010-09-10T10:08:02'))" - resultSet, err = tryToExecute(session, query) + resultSet, err = tryToExecute(e, query) if err != nil { t.Fatalf(err.Error()) return diff --git a/session_pool.go b/session_pool.go index 5acd6f82..7121b348 100644 --- a/session_pool.go +++ b/session_pool.go @@ -249,31 +249,12 @@ func (pool *SessionPool) Close() { // iterate all sessions for i := 0; i < idleLen; i++ { session := pool.idleSessions.Front().Value.(*pureSession) - if session.connection == nil { - pool.log.Warn("Session has been released") - pool.idleSessions.Remove(pool.idleSessions.Front()) - continue - } - - if err := session.connection.signOut(session.sessionID); err != nil { - pool.log.Warn(fmt.Sprintf("Sign out failed, %s", err.Error())) - } - // close connection - session.connection.close() + session.close() pool.idleSessions.Remove(pool.idleSessions.Front()) } for i := 0; i < activeLen; i++ { session := pool.activeSessions.Front().Value.(*pureSession) - if session.connection == nil { - pool.log.Warn("Session has been released") - pool.activeSessions.Remove(pool.activeSessions.Front()) - continue - } - if err := session.connection.signOut(session.sessionID); err != nil { - pool.log.Warn(fmt.Sprintf("Sign out failed, %s", err.Error())) - } - // close connection - session.connection.close() + session.close() pool.activeSessions.Remove(pool.activeSessions.Front()) } @@ -448,57 +429,46 @@ func (pool *SessionPool) sessionCleaner() { case <-pool.cleanerChan: // pool was closed. } - pool.rwLock.Lock() - if pool.closed { pool.cleanerChan = nil - pool.rwLock.Unlock() return } closing := pool.timeoutSessionList() - //release expired session from the pool for _, session := range closing { - if session.connection == nil { - pool.log.Warn("Session has been released") - pool.rwLock.Unlock() - return - } - if err := session.connection.signOut(session.sessionID); err != nil { - pool.log.Warn(fmt.Sprintf("Sign out failed, %s", err.Error())) - } - // close connection - session.connection.close() + session.close() } - pool.rwLock.Unlock() t.Reset(d) } } // timeoutSessionList returns a list of sessions that have been idle for longer than the idle time. func (pool *SessionPool) timeoutSessionList() (closing []*pureSession) { - if pool.conf.idleTime > 0 { - expiredSince := time.Now().Add(-pool.conf.idleTime) - var newEle *list.Element = nil - - maxCleanSize := pool.idleSessions.Len() + pool.activeSessions.Len() - pool.conf.minSize - - for ele := pool.idleSessions.Front(); ele != nil; { - if maxCleanSize == 0 { - return - } - - newEle = ele.Next() - // Check Session is expired - if !ele.Value.(*pureSession).returnedAt.Before(expiredSince) { - return - } - closing = append(closing, ele.Value.(*pureSession)) - pool.idleSessions.Remove(ele) - ele = newEle - maxCleanSize-- + if pool.conf.idleTime == 0 { + return + } + pool.rwLock.Lock() + defer pool.rwLock.Unlock() + expiredSince := time.Now().Add(-pool.conf.idleTime) + var newEle *list.Element = nil + + maxCleanSize := pool.idleSessions.Len() + pool.activeSessions.Len() - pool.conf.minSize + + for ele := pool.idleSessions.Front(); ele != nil; { + if maxCleanSize == 0 { + return + } + + newEle = ele.Next() + // Check Session is expired + if !ele.Value.(*pureSession).returnedAt.Before(expiredSince) { + return } + closing = append(closing, ele.Value.(*pureSession)) + pool.idleSessions.Remove(ele) + ele = newEle + maxCleanSize-- } return } @@ -612,12 +582,17 @@ func (session *pureSession) executeWithParameter(stmt string, params map[string] } func (session *pureSession) close() { + defer func() { + if err := recover(); err != nil { + return + } + }() if session.connection != nil { // ignore signout error _ = session.connection.signOut(session.sessionID) session.connection.close() + session.connection = nil } - session.connection = nil } // Ping checks if the session is valid diff --git a/session_pool_test.go b/session_pool_test.go index d0e10cb6..197e6a17 100644 --- a/session_pool_test.go +++ b/session_pool_test.go @@ -12,6 +12,7 @@ package nebula_go import ( "fmt" + "strings" "sync" "testing" "time" @@ -19,6 +20,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/vesoft-inc/nebula-go/v3/nebula" "github.com/vesoft-inc/nebula-go/v3/nebula/graph" + "golang.org/x/net/context" ) func TestSessionPoolInvalidConfig(t *testing.T) { @@ -53,6 +55,76 @@ func TestSessionPoolInvalidConfig(t *testing.T) { "error message should contain Service address is empty") } +func TestSessionPoolServerCheck(t *testing.T) { + prepareSpace("client_test") + defer dropSpace("client_test") + hostAddress := HostAddress{Host: address, Port: port} + testcases := []struct { + conf *SessionPoolConf + errMsg string + }{ + { + conf: &SessionPoolConf{ + username: "root", + password: "nebula", + serviceAddrs: []HostAddress{hostAddress}, + spaceName: "invalid_space", + minSize: 1, + }, + errMsg: "failed to create a new session pool, " + + "failed to initialize the session pool, " + + "failed to use space invalid_space: SpaceNotFound: SpaceName `invalid_space`", + }, + { + conf: &SessionPoolConf{ + username: "root1", + password: "nebula", + serviceAddrs: []HostAddress{hostAddress}, + spaceName: "client_test", + minSize: 1, + }, + errMsg: "failed to create a new session pool, " + + "failed to initialize the session pool, " + + "failed to authenticate the user, error code: -1001, " + + "error message: User not exist, the pool has been closed", + }, + { + conf: &SessionPoolConf{ + username: "root", + password: "nebula1", + serviceAddrs: []HostAddress{hostAddress}, + spaceName: "client_test", + minSize: 1, + }, + errMsg: "failed to create a new session pool, " + + "failed to initialize the session pool, " + + "failed to authenticate the user, error code: -1001, " + + "error message: Invalid password, the pool has been closed", + }, + { + conf: &SessionPoolConf{ + username: "root", + password: "nebula1", + serviceAddrs: []HostAddress{{"127.0.0.1", 1234}}, + spaceName: "client_test", + minSize: 1, + }, + errMsg: "failed to create a new session pool, " + + "failed to initialize the session pool, " + + "failed to open transport, " + + "error: dial tcp 127.0.0.1:1234: connect: connection refused", + }, + } + for _, tc := range testcases { + _, err := NewSessionPool(*tc.conf, DefaultLogger{}) + if err == nil { + t.Fatal("should return error") + } + assert.Equal(t, err.Error(), tc.errMsg, + fmt.Sprintf("expected error: %s, but actual error: %s", tc.errMsg, err.Error())) + } +} + func TestSessionPoolBasic(t *testing.T) { prepareSpace("client_test") defer dropSpace("client_test") @@ -493,13 +565,14 @@ func TestSessionPoolRetry(t *testing.T) { if err != nil { t.Fatal(err) } - original := *session + original := session.sessionID + conn := session.connection _, _ = sessionPool.executeWithRetry(session, tc.retryFn, 2) if tc.retry { - assert.NotEqual(t, original, *session, fmt.Sprintf("test case: %s", tc.name)) - assert.NotEqual(t, original.connection, nil, fmt.Sprintf("test case: %s", tc.name)) + assert.NotEqual(t, original, session.sessionID, fmt.Sprintf("test case: %s", tc.name)) + assert.NotEqual(t, conn, nil, fmt.Sprintf("test case: %s", tc.name)) } else { - assert.Equal(t, original, *session, fmt.Sprintf("test case: %s", tc.name)) + assert.Equal(t, original, session.sessionID, fmt.Sprintf("test case: %s", tc.name)) } } } @@ -536,3 +609,62 @@ func TestSessionPoolClose(t *testing.T) { _, err = sessionPool.Execute("SHOW HOSTS;") assert.Equal(t, err.Error(), "failed to execute: Session pool has been closed", "session pool should be closed") } + +// TestSessionPoolGetSessionTimeout tests the scenario that if all requests are timeout, +// the session pool should return timeout error, not reach the pool size limit. +func TestQueryTimeout(t *testing.T) { + hostAddress := HostAddress{Host: address, Port: port} + config, err := NewSessionPoolConf( + "root", + "nebula", + []HostAddress{hostAddress}, + "test_data") + if err != nil { + t.Errorf("failed to create session pool config, %s", err.Error()) + } + config.minSize = 0 + config.maxSize = 10 + config.retryGetSessionTimes = 1 + config.timeOut = 100 * time.Millisecond + // create session pool + sessionPool, err := NewSessionPool(*config, DefaultLogger{}) + if err != nil { + t.Fatal(err) + } + defer sessionPool.Close() + createTestDataSchema(t, sessionPool) + loadTestData(t, sessionPool) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + errCh := make(chan error, 1) + defer cancel() + var wg sync.WaitGroup + for i := 0; i < config.maxSize; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 10; j++ { + select { + case <-ctx.Done(): + return + default: + _, err := sessionPool.Execute(`go 2000 step from "Bob" over like yield tags($$)`) + if err == nil { + errCh <- fmt.Errorf("should return error") + return + } + errMsg := "i/o timeout" + if !strings.Contains(err.Error(), errMsg) { + errCh <- fmt.Errorf("expect error contains: %s, but actual: %s", errMsg, err.Error()) + return + } + } + } + }() + } + wg.Wait() + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} diff --git a/session_test.go b/session_test.go index 4f73e34b..925463ef 100644 --- a/session_test.go +++ b/session_test.go @@ -23,17 +23,13 @@ func TestSession_Execute(t *testing.T) { if err != nil { t.Fatal(err) } - - sess, err := pool.GetSession("root", "nebula") - if err != nil { - t.Fatal(err) - } + errCh := make(chan error, 1) f := func(s *Session) { time.Sleep(10 * time.Microsecond) reps, err := s.Execute("yield 1") if err != nil { - t.Fatal(err) + errCh <- err } if !reps.IsSucceed() { t.Fatal(reps.resp.ErrorMsg) @@ -42,12 +38,16 @@ func TestSession_Execute(t *testing.T) { // test Ping() err = s.Ping() if err != nil { - t.Fatal(err) + errCh <- err } } - ctx, cancel := context.WithCancel(context.TODO()) + ctx, cancel := context.WithTimeout(context.TODO(), 300*time.Millisecond) defer cancel() go func(ctx context.Context) { + sess, err := pool.GetSession("root", "nebula") + if err != nil { + errCh <- err + } for { select { case <-ctx.Done(): @@ -58,16 +58,27 @@ func TestSession_Execute(t *testing.T) { } }(ctx) go func(ctx context.Context) { + sess, err := pool.GetSession("root", "nebula") + if err != nil { + errCh <- err + } for { select { case <-ctx.Done(): - break default: f(sess) } } }(ctx) - time.Sleep(300 * time.Millisecond) + + for { + select { + case err := <-errCh: + t.Fatal(err) + case <-ctx.Done(): + return + } + } }