diff --git a/client_test.go b/client_test.go index 3bfc7c4d..cae58a84 100644 --- a/client_test.go +++ b/client_test.go @@ -65,7 +65,7 @@ func logoutAndClose(conn *connection, sessionID int64) { func TestConnection(t *testing.T) { hostAddress := HostAddress{Host: address, Port: port} conn := newConnection(hostAddress) - err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil) + err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil, "") if err != nil { t.Fatalf("fail to open connection, address: %s, port: %d, %s", address, port, err.Error()) } @@ -122,7 +122,7 @@ func TestConnection(t *testing.T) { func TestConnectionIPv6(t *testing.T) { hostAddress := HostAddress{Host: addressIPv6, Port: port} conn := newConnection(hostAddress) - err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil) + err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil, "") if err != nil { t.Fatalf("fail to open connection, address: %s, port: %d, %s", address, port, err.Error()) } @@ -245,6 +245,22 @@ func TestConfigs(t *testing.T) { } } +func TestVersionVerify(t *testing.T) { + const ( + username = "root" + password = "nebula" + ) + + hostAddress := HostAddress{Host: address, Port: port} + + conn := newConnection(hostAddress) + err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil, "INVALID_VERSION") + if err != nil { + assert.Contains(t, err.Error(), "incompatible version between client and server") + } + defer conn.close() +} + func TestAuthentication(t *testing.T) { const ( username = "dummy" @@ -254,7 +270,7 @@ func TestAuthentication(t *testing.T) { hostAddress := HostAddress{Host: address, Port: port} conn := newConnection(hostAddress) - err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil) + err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil, "") if err != nil { t.Fatalf("fail to open connection, address: %s, port: %d, %s", address, port, err.Error()) } @@ -1421,7 +1437,7 @@ func prepareSpace(spaceName string) error { conn := newConnection(hostAddress) testPoolConfig := GetDefaultConf() - err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil) + err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil, "") if err != nil { return fmt.Errorf("fail to open connection, address: %s, port: %d, %s", address, port, err.Error()) } @@ -1458,7 +1474,7 @@ func dropSpace(spaceName string) error { conn := newConnection(hostAddress) testPoolConfig := GetDefaultConf() - err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil) + err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil, "") if err != nil { return fmt.Errorf("fail to open connection, address: %s, port: %d, %s", address, port, err.Error()) } diff --git a/configs.go b/configs.go index fec4c3fb..2a4d60ef 100644 --- a/configs.go +++ b/configs.go @@ -34,6 +34,8 @@ type PoolConfig struct { UseHTTP2 bool // HttpHeader is the http headers for the connection when using HTTP2 HttpHeader http.Header + // client version, make sure the client version is in the white list of NebulaGraph server 'client_white_list' + Version string } // validateConf validates config @@ -64,6 +66,7 @@ func GetDefaultConf() PoolConfig { MaxConnPoolSize: 10, MinConnPoolSize: 0, UseHTTP2: false, + Version: "", } } @@ -138,6 +141,8 @@ type SessionPoolConf struct { useHTTP2 bool // httpHeader is the http headers for the connection httpHeader http.Header + // client version, make sure the client version is in the white list of NebulaGraph server 'client_white_list' + version string } type SessionPoolConfOption func(*SessionPoolConf) @@ -214,6 +219,12 @@ func WithHttpHeader(header http.Header) SessionPoolConfOption { } } +func WithVersion(version string) SessionPoolConfOption { + return func(conf *SessionPoolConf) { + conf.version = version + } +} + func (conf *SessionPoolConf) checkMandatoryFields() error { // Check mandatory fields if conf.username == "" { diff --git a/connection.go b/connection.go index fdfe188e..c6ee5bf2 100644 --- a/connection.go +++ b/connection.go @@ -30,6 +30,7 @@ type connection struct { sslConfig *tls.Config useHTTP2 bool httpHeader http.Header + version string graph *graph.GraphServiceClient } @@ -39,6 +40,7 @@ func newConnection(severAddress HostAddress) *connection { timeout: 0 * time.Millisecond, returnedAt: time.Now(), sslConfig: nil, + version: "", graph: nil, } } @@ -46,12 +48,13 @@ func newConnection(severAddress HostAddress) *connection { // open opens a transport for the connection // if sslConfig is not nil, an SSL transport will be created func (cn *connection) open(hostAddress HostAddress, timeout time.Duration, sslConfig *tls.Config, - useHTTP2 bool, httpHeader http.Header) error { + useHTTP2 bool, httpHeader http.Header, version string) error { ip := hostAddress.Host port := hostAddress.Port newAdd := net.JoinHostPort(ip, strconv.Itoa(port)) cn.timeout = timeout cn.useHTTP2 = useHTTP2 + cn.version = version var ( err error @@ -133,6 +136,9 @@ func (cn *connection) open(hostAddress HostAddress, timeout time.Duration, sslCo func (cn *connection) verifyClientVersion() error { req := graph.NewVerifyClientVersionReq() + if cn.version != "" { + req.SetVersion([]byte(cn.version)) + } resp, err := cn.graph.VerifyClientVersion(req) if err != nil { cn.close() @@ -150,7 +156,7 @@ func (cn *connection) verifyClientVersion() error { // When the timeout occurs, the connection will be reopened to avoid the impact of the message. func (cn *connection) reopen() error { cn.close() - return cn.open(cn.severAddress, cn.timeout, cn.sslConfig, cn.useHTTP2, cn.httpHeader) + return cn.open(cn.severAddress, cn.timeout, cn.sslConfig, cn.useHTTP2, cn.httpHeader, cn.version) } // Authenticate diff --git a/connection_pool.go b/connection_pool.go index bb29efd3..50bce81b 100644 --- a/connection_pool.go +++ b/connection_pool.go @@ -66,7 +66,7 @@ func NewSslConnectionPool(addresses []HostAddress, conf PoolConfig, sslConfig *t // initPool initializes the connection pool func (pool *ConnectionPool) initPool() error { if err := checkAddresses(pool.conf.TimeOut, pool.addresses, pool.sslConfig, - pool.conf.UseHTTP2, pool.conf.HttpHeader); err != nil { + pool.conf.UseHTTP2, pool.conf.HttpHeader, pool.conf.Version); err != nil { return fmt.Errorf("failed to open connection, error: %s ", err.Error()) } @@ -76,7 +76,7 @@ func (pool *ConnectionPool) initPool() error { // Open connection to host if err := newConn.open(newConn.severAddress, pool.conf.TimeOut, pool.sslConfig, - pool.conf.UseHTTP2, pool.conf.HttpHeader); err != nil { + pool.conf.UseHTTP2, pool.conf.HttpHeader, pool.conf.Version); err != nil { // If initialization failed, clean idle queue idleLen := pool.idleConnectionQueue.Len() for i := 0; i < idleLen; i++ { @@ -194,7 +194,7 @@ func (pool *ConnectionPool) releaseAndBack(conn *connection, pushBack bool) { // Ping checks availability of host func (pool *ConnectionPool) Ping(host HostAddress, timeout time.Duration) error { - return pingAddress(host, timeout, pool.sslConfig, pool.conf.UseHTTP2, pool.conf.HttpHeader) + return pingAddress(host, timeout, pool.sslConfig, pool.conf.UseHTTP2, pool.conf.HttpHeader, pool.conf.Version) } // Close closes all connection @@ -246,7 +246,7 @@ func (pool *ConnectionPool) newConnToHost() (*connection, error) { newConn := newConnection(host) // Open connection to host if err := newConn.open(newConn.severAddress, pool.conf.TimeOut, pool.sslConfig, - pool.conf.UseHTTP2, pool.conf.HttpHeader); err != nil { + pool.conf.UseHTTP2, pool.conf.HttpHeader, pool.conf.Version); err != nil { return nil, err } // Add connection to active queue @@ -354,13 +354,13 @@ func (pool *ConnectionPool) timeoutConnectionList() (closing []*connection) { // It opens a temporary connection to each address and closes it immediately. // If no error is returned, the addresses are available. func checkAddresses(confTimeout time.Duration, addresses []HostAddress, sslConfig *tls.Config, - useHTTP2 bool, httpHeader http.Header) error { + useHTTP2 bool, httpHeader http.Header, version string) error { var timeout = 3 * time.Second if confTimeout != 0 && confTimeout < timeout { timeout = confTimeout } for _, address := range addresses { - if err := pingAddress(address, timeout, sslConfig, useHTTP2, httpHeader); err != nil { + if err := pingAddress(address, timeout, sslConfig, useHTTP2, httpHeader, version); err != nil { return err } } @@ -368,10 +368,10 @@ func checkAddresses(confTimeout time.Duration, addresses []HostAddress, sslConfi } func pingAddress(address HostAddress, timeout time.Duration, sslConfig *tls.Config, - useHTTP2 bool, httpHeader http.Header) error { + useHTTP2 bool, httpHeader http.Header, version string) error { newConn := newConnection(address) // Open connection to host - if err := newConn.open(newConn.severAddress, timeout, sslConfig, useHTTP2, httpHeader); err != nil { + if err := newConn.open(newConn.severAddress, timeout, sslConfig, useHTTP2, httpHeader, version); err != nil { return err } defer newConn.close() diff --git a/examples/basic_example/graph_client_basic_example.go b/examples/basic_example/graph_client_basic_example.go index 976aefc3..3ecebbfd 100644 --- a/examples/basic_example/graph_client_basic_example.go +++ b/examples/basic_example/graph_client_basic_example.go @@ -33,6 +33,7 @@ func main() { // Create configs for connection pool using default values testPoolConfig := nebula.GetDefaultConf() testPoolConfig.UseHTTP2 = useHTTP2 + testPoolConfig.Version = "3.0.0" // Initialize connection pool pool, err := nebula.NewConnectionPool(hostList, testPoolConfig, log) diff --git a/nebula-docker-compose/docker-compose.yaml b/nebula-docker-compose/docker-compose.yaml index 820ac387..bf2e8f04 100644 --- a/nebula-docker-compose/docker-compose.yaml +++ b/nebula-docker-compose/docker-compose.yaml @@ -282,6 +282,8 @@ services: - --system_memory_high_watermark_ratio=0.99 - --heartbeat_interval_secs=1 - --max_sessions_per_ip_per_user=1200 + - --enable_client_white_list=true + - --client_white_list=3.0.0:test # ssl - --ca_path=${ca_path} - --cert_path=${cert_path} @@ -328,6 +330,8 @@ services: - --system_memory_high_watermark_ratio=0.99 - --heartbeat_interval_secs=1 - --max_sessions_per_ip_per_user=1200 + - --enable_client_white_list=true + - --client_white_list=3.0.0:test # ssl - --ca_path=${ca_path} - --cert_path=${cert_path} @@ -374,6 +378,8 @@ services: - --system_memory_high_watermark_ratio=0.99 - --heartbeat_interval_secs=1 - --max_sessions_per_ip_per_user=1200 + - --enable_client_white_list=true + - --client_white_list=3.0.0:test # ssl - --ca_path=${ca_path} - --cert_path=${cert_path} diff --git a/session_pool.go b/session_pool.go index 877dcab6..e25312ac 100644 --- a/session_pool.go +++ b/session_pool.go @@ -79,7 +79,7 @@ func NewSessionPool(conf SessionPoolConf, log Logger) (*SessionPool, error) { func (pool *SessionPool) init() error { // check the hosts status if err := checkAddresses(pool.conf.timeOut, pool.conf.serviceAddrs, pool.conf.sslConfig, - pool.conf.useHTTP2, pool.conf.httpHeader); err != nil { + pool.conf.useHTTP2, pool.conf.httpHeader, pool.conf.version); err != nil { return fmt.Errorf("failed to initialize the session pool, %s", err.Error()) } @@ -287,7 +287,7 @@ func (pool *SessionPool) newSession() (*pureSession, error) { // open a new connection if err := cn.open(cn.severAddress, pool.conf.timeOut, pool.conf.sslConfig, - pool.conf.useHTTP2, pool.conf.httpHeader); err != nil { + pool.conf.useHTTP2, pool.conf.httpHeader, pool.conf.version); err != nil { return nil, fmt.Errorf("failed to create a net.Conn-backed Transport,: %s", err.Error()) } diff --git a/session_pool_test.go b/session_pool_test.go index 197e6a17..aa03afb4 100644 --- a/session_pool_test.go +++ b/session_pool_test.go @@ -125,6 +125,28 @@ func TestSessionPoolServerCheck(t *testing.T) { } } +func TestSessionPoolInvalidVersion(t *testing.T) { + prepareSpace("client_test") + defer dropSpace("client_test") + hostAddress := HostAddress{Host: address, Port: port} + + // wrong version info + versionConfig, err := NewSessionPoolConf( + "root", + "nebula", + []HostAddress{hostAddress}, + "client_test", + ) + versionConfig.version = "INVALID_VERSION" + versionConfig.minSize = 1 + + // create session pool + _, err = NewSessionPool(*versionConfig, DefaultLogger{}) + if err != nil { + assert.Contains(t, err.Error(), "incompatible version between client and server") + } +} + func TestSessionPoolBasic(t *testing.T) { prepareSpace("client_test") defer dropSpace("client_test")