Skip to content

Commit

Permalink
support to config the version white list (#300)
Browse files Browse the repository at this point in the history
* support to config the version white list

* update version

* add more note for version

* go fmt
  • Loading branch information
Nicole00 authored Jan 3, 2024
1 parent fde0af0 commit ae7b5d8
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 17 deletions.
26 changes: 21 additions & 5 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func logoutAndClose(conn *connection, sessionID int64) {
func TestConnection(t *testing.T) {
hostAddress := HostAddress{Host: address, Port: port}
conn := newConnection(hostAddress)
err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil)
err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil, "")
if err != nil {
t.Fatalf("fail to open connection, address: %s, port: %d, %s", address, port, err.Error())
}
Expand Down Expand Up @@ -122,7 +122,7 @@ func TestConnection(t *testing.T) {
func TestConnectionIPv6(t *testing.T) {
hostAddress := HostAddress{Host: addressIPv6, Port: port}
conn := newConnection(hostAddress)
err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil)
err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil, "")
if err != nil {
t.Fatalf("fail to open connection, address: %s, port: %d, %s", address, port, err.Error())
}
Expand Down Expand Up @@ -245,6 +245,22 @@ func TestConfigs(t *testing.T) {
}
}

func TestVersionVerify(t *testing.T) {
const (
username = "root"
password = "nebula"
)

hostAddress := HostAddress{Host: address, Port: port}

conn := newConnection(hostAddress)
err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil, "INVALID_VERSION")
if err != nil {
assert.Contains(t, err.Error(), "incompatible version between client and server")
}
defer conn.close()
}

func TestAuthentication(t *testing.T) {
const (
username = "dummy"
Expand All @@ -254,7 +270,7 @@ func TestAuthentication(t *testing.T) {
hostAddress := HostAddress{Host: address, Port: port}

conn := newConnection(hostAddress)
err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil)
err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil, "")
if err != nil {
t.Fatalf("fail to open connection, address: %s, port: %d, %s", address, port, err.Error())
}
Expand Down Expand Up @@ -1421,7 +1437,7 @@ func prepareSpace(spaceName string) error {
conn := newConnection(hostAddress)
testPoolConfig := GetDefaultConf()

err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil)
err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil, "")
if err != nil {
return fmt.Errorf("fail to open connection, address: %s, port: %d, %s", address, port, err.Error())
}
Expand Down Expand Up @@ -1458,7 +1474,7 @@ func dropSpace(spaceName string) error {
conn := newConnection(hostAddress)
testPoolConfig := GetDefaultConf()

err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil)
err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil, "")
if err != nil {
return fmt.Errorf("fail to open connection, address: %s, port: %d, %s", address, port, err.Error())
}
Expand Down
11 changes: 11 additions & 0 deletions configs.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ type PoolConfig struct {
UseHTTP2 bool
// HttpHeader is the http headers for the connection when using HTTP2
HttpHeader http.Header
// client version, make sure the client version is in the white list of NebulaGraph server 'client_white_list'
Version string
}

// validateConf validates config
Expand Down Expand Up @@ -64,6 +66,7 @@ func GetDefaultConf() PoolConfig {
MaxConnPoolSize: 10,
MinConnPoolSize: 0,
UseHTTP2: false,
Version: "",
}
}

Expand Down Expand Up @@ -138,6 +141,8 @@ type SessionPoolConf struct {
useHTTP2 bool
// httpHeader is the http headers for the connection
httpHeader http.Header
// client version, make sure the client version is in the white list of NebulaGraph server 'client_white_list'
version string
}

type SessionPoolConfOption func(*SessionPoolConf)
Expand Down Expand Up @@ -214,6 +219,12 @@ func WithHttpHeader(header http.Header) SessionPoolConfOption {
}
}

func WithVersion(version string) SessionPoolConfOption {
return func(conf *SessionPoolConf) {
conf.version = version
}
}

func (conf *SessionPoolConf) checkMandatoryFields() error {
// Check mandatory fields
if conf.username == "" {
Expand Down
10 changes: 8 additions & 2 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type connection struct {
sslConfig *tls.Config
useHTTP2 bool
httpHeader http.Header
version string
graph *graph.GraphServiceClient
}

Expand All @@ -39,19 +40,21 @@ func newConnection(severAddress HostAddress) *connection {
timeout: 0 * time.Millisecond,
returnedAt: time.Now(),
sslConfig: nil,
version: "",
graph: nil,
}
}

// open opens a transport for the connection
// if sslConfig is not nil, an SSL transport will be created
func (cn *connection) open(hostAddress HostAddress, timeout time.Duration, sslConfig *tls.Config,
useHTTP2 bool, httpHeader http.Header) error {
useHTTP2 bool, httpHeader http.Header, version string) error {
ip := hostAddress.Host
port := hostAddress.Port
newAdd := net.JoinHostPort(ip, strconv.Itoa(port))
cn.timeout = timeout
cn.useHTTP2 = useHTTP2
cn.version = version

var (
err error
Expand Down Expand Up @@ -133,6 +136,9 @@ func (cn *connection) open(hostAddress HostAddress, timeout time.Duration, sslCo

func (cn *connection) verifyClientVersion() error {
req := graph.NewVerifyClientVersionReq()
if cn.version != "" {
req.SetVersion([]byte(cn.version))
}
resp, err := cn.graph.VerifyClientVersion(req)
if err != nil {
cn.close()
Expand All @@ -150,7 +156,7 @@ func (cn *connection) verifyClientVersion() error {
// When the timeout occurs, the connection will be reopened to avoid the impact of the message.
func (cn *connection) reopen() error {
cn.close()
return cn.open(cn.severAddress, cn.timeout, cn.sslConfig, cn.useHTTP2, cn.httpHeader)
return cn.open(cn.severAddress, cn.timeout, cn.sslConfig, cn.useHTTP2, cn.httpHeader, cn.version)
}

// Authenticate
Expand Down
16 changes: 8 additions & 8 deletions connection_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func NewSslConnectionPool(addresses []HostAddress, conf PoolConfig, sslConfig *t
// initPool initializes the connection pool
func (pool *ConnectionPool) initPool() error {
if err := checkAddresses(pool.conf.TimeOut, pool.addresses, pool.sslConfig,
pool.conf.UseHTTP2, pool.conf.HttpHeader); err != nil {
pool.conf.UseHTTP2, pool.conf.HttpHeader, pool.conf.Version); err != nil {
return fmt.Errorf("failed to open connection, error: %s ", err.Error())
}

Expand All @@ -76,7 +76,7 @@ func (pool *ConnectionPool) initPool() error {

// Open connection to host
if err := newConn.open(newConn.severAddress, pool.conf.TimeOut, pool.sslConfig,
pool.conf.UseHTTP2, pool.conf.HttpHeader); err != nil {
pool.conf.UseHTTP2, pool.conf.HttpHeader, pool.conf.Version); err != nil {
// If initialization failed, clean idle queue
idleLen := pool.idleConnectionQueue.Len()
for i := 0; i < idleLen; i++ {
Expand Down Expand Up @@ -194,7 +194,7 @@ func (pool *ConnectionPool) releaseAndBack(conn *connection, pushBack bool) {

// Ping checks availability of host
func (pool *ConnectionPool) Ping(host HostAddress, timeout time.Duration) error {
return pingAddress(host, timeout, pool.sslConfig, pool.conf.UseHTTP2, pool.conf.HttpHeader)
return pingAddress(host, timeout, pool.sslConfig, pool.conf.UseHTTP2, pool.conf.HttpHeader, pool.conf.Version)
}

// Close closes all connection
Expand Down Expand Up @@ -246,7 +246,7 @@ func (pool *ConnectionPool) newConnToHost() (*connection, error) {
newConn := newConnection(host)
// Open connection to host
if err := newConn.open(newConn.severAddress, pool.conf.TimeOut, pool.sslConfig,
pool.conf.UseHTTP2, pool.conf.HttpHeader); err != nil {
pool.conf.UseHTTP2, pool.conf.HttpHeader, pool.conf.Version); err != nil {
return nil, err
}
// Add connection to active queue
Expand Down Expand Up @@ -354,24 +354,24 @@ func (pool *ConnectionPool) timeoutConnectionList() (closing []*connection) {
// It opens a temporary connection to each address and closes it immediately.
// If no error is returned, the addresses are available.
func checkAddresses(confTimeout time.Duration, addresses []HostAddress, sslConfig *tls.Config,
useHTTP2 bool, httpHeader http.Header) error {
useHTTP2 bool, httpHeader http.Header, version string) error {
var timeout = 3 * time.Second
if confTimeout != 0 && confTimeout < timeout {
timeout = confTimeout
}
for _, address := range addresses {
if err := pingAddress(address, timeout, sslConfig, useHTTP2, httpHeader); err != nil {
if err := pingAddress(address, timeout, sslConfig, useHTTP2, httpHeader, version); err != nil {
return err
}
}
return nil
}

func pingAddress(address HostAddress, timeout time.Duration, sslConfig *tls.Config,
useHTTP2 bool, httpHeader http.Header) error {
useHTTP2 bool, httpHeader http.Header, version string) error {
newConn := newConnection(address)
// Open connection to host
if err := newConn.open(newConn.severAddress, timeout, sslConfig, useHTTP2, httpHeader); err != nil {
if err := newConn.open(newConn.severAddress, timeout, sslConfig, useHTTP2, httpHeader, version); err != nil {
return err
}
defer newConn.close()
Expand Down
1 change: 1 addition & 0 deletions examples/basic_example/graph_client_basic_example.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func main() {
// Create configs for connection pool using default values
testPoolConfig := nebula.GetDefaultConf()
testPoolConfig.UseHTTP2 = useHTTP2
testPoolConfig.Version = "3.0.0"

// Initialize connection pool
pool, err := nebula.NewConnectionPool(hostList, testPoolConfig, log)
Expand Down
6 changes: 6 additions & 0 deletions nebula-docker-compose/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ services:
- --system_memory_high_watermark_ratio=0.99
- --heartbeat_interval_secs=1
- --max_sessions_per_ip_per_user=1200
- --enable_client_white_list=true
- --client_white_list=3.0.0:test
# ssl
- --ca_path=${ca_path}
- --cert_path=${cert_path}
Expand Down Expand Up @@ -328,6 +330,8 @@ services:
- --system_memory_high_watermark_ratio=0.99
- --heartbeat_interval_secs=1
- --max_sessions_per_ip_per_user=1200
- --enable_client_white_list=true
- --client_white_list=3.0.0:test
# ssl
- --ca_path=${ca_path}
- --cert_path=${cert_path}
Expand Down Expand Up @@ -374,6 +378,8 @@ services:
- --system_memory_high_watermark_ratio=0.99
- --heartbeat_interval_secs=1
- --max_sessions_per_ip_per_user=1200
- --enable_client_white_list=true
- --client_white_list=3.0.0:test
# ssl
- --ca_path=${ca_path}
- --cert_path=${cert_path}
Expand Down
4 changes: 2 additions & 2 deletions session_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func NewSessionPool(conf SessionPoolConf, log Logger) (*SessionPool, error) {
func (pool *SessionPool) init() error {
// check the hosts status
if err := checkAddresses(pool.conf.timeOut, pool.conf.serviceAddrs, pool.conf.sslConfig,
pool.conf.useHTTP2, pool.conf.httpHeader); err != nil {
pool.conf.useHTTP2, pool.conf.httpHeader, pool.conf.version); err != nil {
return fmt.Errorf("failed to initialize the session pool, %s", err.Error())
}

Expand Down Expand Up @@ -287,7 +287,7 @@ func (pool *SessionPool) newSession() (*pureSession, error) {

// open a new connection
if err := cn.open(cn.severAddress, pool.conf.timeOut, pool.conf.sslConfig,
pool.conf.useHTTP2, pool.conf.httpHeader); err != nil {
pool.conf.useHTTP2, pool.conf.httpHeader, pool.conf.version); err != nil {
return nil, fmt.Errorf("failed to create a net.Conn-backed Transport,: %s", err.Error())
}

Expand Down
22 changes: 22 additions & 0 deletions session_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,28 @@ func TestSessionPoolServerCheck(t *testing.T) {
}
}

func TestSessionPoolInvalidVersion(t *testing.T) {
prepareSpace("client_test")
defer dropSpace("client_test")
hostAddress := HostAddress{Host: address, Port: port}

// wrong version info
versionConfig, err := NewSessionPoolConf(
"root",
"nebula",
[]HostAddress{hostAddress},
"client_test",
)
versionConfig.version = "INVALID_VERSION"
versionConfig.minSize = 1

// create session pool
_, err = NewSessionPool(*versionConfig, DefaultLogger{})
if err != nil {
assert.Contains(t, err.Error(), "incompatible version between client and server")
}
}

func TestSessionPoolBasic(t *testing.T) {
prepareSpace("client_test")
defer dropSpace("client_test")
Expand Down

0 comments on commit ae7b5d8

Please sign in to comment.