Skip to content

Commit

Permalink
Add context.Context support for operations
Browse files Browse the repository at this point in the history
  • Loading branch information
egasimov committed Sep 9, 2024
1 parent d12bd94 commit 90ac1a7
Show file tree
Hide file tree
Showing 55 changed files with 108,932 additions and 113,378 deletions.
259 changes: 144 additions & 115 deletions client_test.go

Large diffs are not rendered by default.

121 changes: 70 additions & 51 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
"strconv"
"time"

"github.com/vesoft-inc/fbthrift/thrift/lib/go/thrift"
"github.com/apache/thrift/lib/go/thrift"
"github.com/vesoft-inc/nebula-go/v3/nebula"
"github.com/vesoft-inc/nebula-go/v3/nebula/graph"
"golang.org/x/net/http2"
Expand All @@ -32,6 +32,7 @@ type connection struct {
httpHeader http.Header
handshakeKey string
graph *graph.GraphServiceClient
transport thrift.TTransport
}

func newConnection(severAddress HostAddress) *connection {
Expand All @@ -42,12 +43,13 @@ func newConnection(severAddress HostAddress) *connection {
sslConfig: nil,
handshakeKey: "",
graph: nil,
transport: nil,
}
}

// open opens a transport for the connection
// open opens transport for the connection
// if sslConfig is not nil, an SSL transport will be created
func (cn *connection) open(hostAddress HostAddress, timeout time.Duration, sslConfig *tls.Config,
func (cn *connection) open(ctx context.Context, hostAddress HostAddress, timeout time.Duration, sslConfig *tls.Config,
useHTTP2 bool, httpHeader http.Header, handshakeKey string) error {
ip := hostAddress.Host
port := hostAddress.Port
Expand All @@ -58,20 +60,21 @@ func (cn *connection) open(hostAddress HostAddress, timeout time.Duration, sslCo

var (
err error
transport thrift.Transport
pf thrift.ProtocolFactory
transport thrift.TTransport
pf thrift.TProtocolFactory
)
if useHTTP2 {
if sslConfig != nil {
transport, err = thrift.NewHTTPPostClientWithOptions("https://"+newAdd, thrift.HTTPClientOptions{
Client: &http.Client{
Transport: &http2.Transport{
TLSClientConfig: sslConfig,
transport, err = thrift.NewTHttpClientWithOptions("https://"+newAdd,
thrift.THttpClientOptions{
Client: &http.Client{
Transport: &http2.Transport{
TLSClientConfig: sslConfig,
},
},
},
})
})
} else {
transport, err = thrift.NewHTTPPostClientWithOptions("http://"+newAdd, thrift.HTTPClientOptions{
transport, err = thrift.NewTHttpClientWithOptions("https://"+newAdd, thrift.THttpClientOptions{
Client: &http.Client{
Transport: &http2.Transport{
// So http2.Transport doesn't complain the URL scheme isn't 'https'
Expand All @@ -85,13 +88,15 @@ func (cn *connection) open(hostAddress HostAddress, timeout time.Duration, sslCo
},
},
})

}
if err != nil {
return fmt.Errorf("failed to create a net.Conn-backed Transport,: %s", err.Error())
}
pf = thrift.NewBinaryProtocolFactoryDefault()

pf = thrift.NewTBinaryProtocolFactoryDefault()
if httpHeader != nil {
client, ok := transport.(*thrift.HTTPClient)
client, ok := transport.(*thrift.THttpClient)
if !ok {
return fmt.Errorf("failed to get thrift http client")
}
Expand All @@ -109,37 +114,43 @@ func (cn *connection) open(hostAddress HostAddress, timeout time.Duration, sslCo
} else {
bufferSize := 128 << 10

var sock thrift.Transport
var sock thrift.TTransport
if sslConfig != nil {
sock, err = thrift.NewSSLSocketTimeout(newAdd, sslConfig, timeout)
//TODO replace the deprecated thrift structures later
sock, err = thrift.NewTSSLSocketTimeout(newAdd, sslConfig, timeout, timeout)
} else {
sock, err = thrift.NewSocket(thrift.SocketAddr(newAdd), thrift.SocketTimeout(timeout))
sock, err = thrift.NewTSocketTimeout(newAdd, timeout, timeout)
}
if err != nil {
return fmt.Errorf("failed to create a net.Conn-backed Transport,: %s", err.Error())
}
// Set transport
bufferedTranFactory := thrift.NewBufferedTransportFactory(bufferSize)
transport = thrift.NewHeaderTransport(bufferedTranFactory.GetTransport(sock))
pf = thrift.NewHeaderProtocolFactory()
bufferedTransFactory := thrift.NewTBufferedTransportFactory(bufferSize)
buffTransport, _ := bufferedTransFactory.GetTransport(sock)
transport = thrift.NewTHeaderTransport(buffTransport)
pf = thrift.NewTHeaderProtocolFactory()
}

cn.transport = transport
cn.graph = graph.NewGraphServiceClientFactory(transport, pf)
if err = cn.graph.Open(); err != nil {

if err = cn.transport.Open(); err != nil {
return fmt.Errorf("failed to open transport, error: %s", err.Error())
}
if !cn.graph.IsOpen() {

if !cn.transport.IsOpen() {
return fmt.Errorf("transport is off")
}
return cn.verifyClientVersion()

return cn.verifyClientVersion(ctx)
}

func (cn *connection) verifyClientVersion() error {
func (cn *connection) verifyClientVersion(ctx context.Context) error {
req := graph.NewVerifyClientVersionReq()
if cn.handshakeKey != "" {
req.SetVersion([]byte(cn.handshakeKey))
req.Version = []byte(cn.handshakeKey)
}
resp, err := cn.graph.VerifyClientVersion(req)
resp, err := cn.graph.VerifyClientVersion(ctx, req)
if err != nil {
cn.close()
return fmt.Errorf("failed to verify client handshakeKey: %s", err.Error())
Expand All @@ -154,17 +165,18 @@ func (cn *connection) verifyClientVersion() error {
// Because the code generated by Fbthrift does not handle the seqID,
// the message will be dislocated when the timeout occurs, resulting in unexpected response.
// When the timeout occurs, the connection will be reopened to avoid the impact of the message.
func (cn *connection) reopen() error {
func (cn *connection) reopen(ctx context.Context) error {
cn.close()
return cn.open(cn.severAddress, cn.timeout, cn.sslConfig, cn.useHTTP2, cn.httpHeader, cn.handshakeKey)
return cn.open(ctx, cn.severAddress, cn.timeout, cn.sslConfig, cn.useHTTP2, cn.httpHeader, cn.handshakeKey)
}

// Authenticate
func (cn *connection) authenticate(username, password string) (*graph.AuthResponse, error) {
resp, err := cn.graph.Authenticate([]byte(username), []byte(password))
func (cn *connection) authenticate(ctx context.Context, username, password string) (*graph.AuthResponse, error) {
resp, err := cn.graph.Authenticate(ctx, []byte(username), []byte(password))
if err != nil {
err = fmt.Errorf("authentication fails, %s", err.Error())
if e := cn.graph.Close(); e != nil {

if e := cn.transport.Close(); e != nil {
err = fmt.Errorf("fail to close transport, error: %s", e.Error())
}
return nil, err
Expand All @@ -173,39 +185,41 @@ func (cn *connection) authenticate(username, password string) (*graph.AuthRespon
return resp, nil
}

func (cn *connection) execute(sessionID int64, stmt string) (*graph.ExecutionResponse, error) {
return cn.executeWithParameter(sessionID, stmt, map[string]*nebula.Value{})
func (cn *connection) execute(ctx context.Context, sessionID int64, stmt string) (*graph.ExecutionResponse, error) {
return cn.executeWithParameter(ctx, sessionID, stmt, map[string]*nebula.Value{})
}

func (cn *connection) executeWithParameter(sessionID int64, stmt string,
func (cn *connection) executeWithParameter(ctx context.Context, sessionID int64, stmt string,
params map[string]*nebula.Value) (*graph.ExecutionResponse, error) {
resp, err := cn.graph.ExecuteWithParameter(sessionID, []byte(stmt), params)
resp, err := cn.graph.ExecuteWithParameter(ctx, sessionID, []byte(stmt), params)
if err != nil {
return nil, err
}

return resp, nil
}

func (cn *connection) executeWithParameterTimeout(sessionID int64, stmt string, params map[string]*nebula.Value, timeoutMs int64) (*graph.ExecutionResponse, error) {
return cn.graph.ExecuteWithTimeout(sessionID, []byte(stmt), params, timeoutMs)
func (cn *connection) executeWithParameterTimeout(ctx context.Context, sessionID int64, stmt string, params map[string]*nebula.Value, timeoutMs int64) (*graph.ExecutionResponse, error) {
//TODO handle timeout value later
return cn.graph.ExecuteWithParameter(ctx, sessionID, []byte(stmt), params)
//return cn.graph.ExecuteWithTimeout(ctx, sessionID, []byte(stmt), params, timeoutMs)
}

func (cn *connection) executeJson(sessionID int64, stmt string) ([]byte, error) {
return cn.ExecuteJsonWithParameter(sessionID, stmt, map[string]*nebula.Value{})
func (cn *connection) executeJson(ctx context.Context, sessionID int64, stmt string) ([]byte, error) {
return cn.ExecuteJsonWithParameter(ctx, sessionID, stmt, map[string]*nebula.Value{})
}

func (cn *connection) ExecuteJsonWithParameter(sessionID int64, stmt string, params map[string]*nebula.Value) ([]byte, error) {
jsonResp, err := cn.graph.ExecuteJsonWithParameter(sessionID, []byte(stmt), params)
func (cn *connection) ExecuteJsonWithParameter(ctx context.Context, sessionID int64, stmt string, params map[string]*nebula.Value) ([]byte, error) {
jsonResp, err := cn.graph.ExecuteJsonWithParameter(ctx, sessionID, []byte(stmt), params)
if err != nil {
// reopen the connection if timeout
if _, ok := err.(thrift.TransportException); ok {
if err.(thrift.TransportException).TypeID() == thrift.TIMED_OUT {
reopenErr := cn.reopen()
if _, ok := err.(thrift.TTransportException); ok {
if err.(thrift.TTransportException).TypeId() == thrift.TIMED_OUT {
reopenErr := cn.reopen(ctx)
if reopenErr != nil {
return nil, reopenErr
}
return cn.graph.ExecuteJsonWithParameter(sessionID, []byte(stmt), params)
return cn.graph.ExecuteJsonWithParameter(ctx, sessionID, []byte(stmt), params)
}
}
}
Expand All @@ -214,15 +228,15 @@ func (cn *connection) ExecuteJsonWithParameter(sessionID int64, stmt string, par
}

// Check connection to host address
func (cn *connection) ping() bool {
_, err := cn.execute(0, "YIELD 1")
func (cn *connection) ping(ctx context.Context) bool {
_, err := cn.execute(ctx, 0, "YIELD 1")
return err == nil
}

// Sign out and release session ID
func (cn *connection) signOut(sessionID int64) error {
func (cn *connection) signOut(ctx context.Context, sessionID int64) error {
// Release session ID to graphd
return cn.graph.Signout(sessionID)
return cn.graph.Signout(ctx, sessionID)
}

// Update returnedAt for cleaner
Expand All @@ -231,6 +245,11 @@ func (cn *connection) release() {
}

// Close transport
func (cn *connection) close() {
cn.graph.Close()
func (cn *connection) close() error {
if e := cn.transport.Close(); e != nil {
err := fmt.Errorf("fail to close transport, error: %s", e.Error())

return err
}
return nil
}
Loading

0 comments on commit 90ac1a7

Please sign in to comment.