Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Lack of context.Context support #350 #351

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 144 additions & 115 deletions client_test.go

Large diffs are not rendered by default.

121 changes: 70 additions & 51 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
"strconv"
"time"

"github.com/vesoft-inc/fbthrift/thrift/lib/go/thrift"
"github.com/apache/thrift/lib/go/thrift"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry that we cannot replace the thrift with apache thrift, we need to keep the rpc protocol the same with the server to avoid some incompatibility issues.

"github.com/vesoft-inc/nebula-go/v3/nebula"
"github.com/vesoft-inc/nebula-go/v3/nebula/graph"
"golang.org/x/net/http2"
Expand All @@ -32,6 +32,7 @@ type connection struct {
httpHeader http.Header
handshakeKey string
graph *graph.GraphServiceClient
transport thrift.TTransport
}

func newConnection(severAddress HostAddress) *connection {
Expand All @@ -42,12 +43,13 @@ func newConnection(severAddress HostAddress) *connection {
sslConfig: nil,
handshakeKey: "",
graph: nil,
transport: nil,
}
}

// open opens a transport for the connection
// open opens transport for the connection
// if sslConfig is not nil, an SSL transport will be created
func (cn *connection) open(hostAddress HostAddress, timeout time.Duration, sslConfig *tls.Config,
func (cn *connection) open(ctx context.Context, hostAddress HostAddress, timeout time.Duration, sslConfig *tls.Config,
useHTTP2 bool, httpHeader http.Header, handshakeKey string) error {
ip := hostAddress.Host
port := hostAddress.Port
Expand All @@ -58,20 +60,21 @@ func (cn *connection) open(hostAddress HostAddress, timeout time.Duration, sslCo

var (
err error
transport thrift.Transport
pf thrift.ProtocolFactory
transport thrift.TTransport
pf thrift.TProtocolFactory
)
if useHTTP2 {
if sslConfig != nil {
transport, err = thrift.NewHTTPPostClientWithOptions("https://"+newAdd, thrift.HTTPClientOptions{
Client: &http.Client{
Transport: &http2.Transport{
TLSClientConfig: sslConfig,
transport, err = thrift.NewTHttpClientWithOptions("https://"+newAdd,
thrift.THttpClientOptions{
Client: &http.Client{
Transport: &http2.Transport{
TLSClientConfig: sslConfig,
},
},
},
})
})
} else {
transport, err = thrift.NewHTTPPostClientWithOptions("http://"+newAdd, thrift.HTTPClientOptions{
transport, err = thrift.NewTHttpClientWithOptions("https://"+newAdd, thrift.THttpClientOptions{
Client: &http.Client{
Transport: &http2.Transport{
// So http2.Transport doesn't complain the URL scheme isn't 'https'
Expand All @@ -85,13 +88,15 @@ func (cn *connection) open(hostAddress HostAddress, timeout time.Duration, sslCo
},
},
})

}
if err != nil {
return fmt.Errorf("failed to create a net.Conn-backed Transport,: %s", err.Error())
}
pf = thrift.NewBinaryProtocolFactoryDefault()

pf = thrift.NewTBinaryProtocolFactoryDefault()
if httpHeader != nil {
client, ok := transport.(*thrift.HTTPClient)
client, ok := transport.(*thrift.THttpClient)
if !ok {
return fmt.Errorf("failed to get thrift http client")
}
Expand All @@ -109,37 +114,43 @@ func (cn *connection) open(hostAddress HostAddress, timeout time.Duration, sslCo
} else {
bufferSize := 128 << 10

var sock thrift.Transport
var sock thrift.TTransport
if sslConfig != nil {
sock, err = thrift.NewSSLSocketTimeout(newAdd, sslConfig, timeout)
//TODO replace the deprecated thrift structures later
sock, err = thrift.NewTSSLSocketTimeout(newAdd, sslConfig, timeout, timeout)
} else {
sock, err = thrift.NewSocket(thrift.SocketAddr(newAdd), thrift.SocketTimeout(timeout))
sock, err = thrift.NewTSocketTimeout(newAdd, timeout, timeout)
}
if err != nil {
return fmt.Errorf("failed to create a net.Conn-backed Transport,: %s", err.Error())
}
// Set transport
bufferedTranFactory := thrift.NewBufferedTransportFactory(bufferSize)
transport = thrift.NewHeaderTransport(bufferedTranFactory.GetTransport(sock))
pf = thrift.NewHeaderProtocolFactory()
bufferedTransFactory := thrift.NewTBufferedTransportFactory(bufferSize)
buffTransport, _ := bufferedTransFactory.GetTransport(sock)
transport = thrift.NewTHeaderTransport(buffTransport)
pf = thrift.NewTHeaderProtocolFactory()
}

cn.transport = transport
cn.graph = graph.NewGraphServiceClientFactory(transport, pf)
if err = cn.graph.Open(); err != nil {

if err = cn.transport.Open(); err != nil {
return fmt.Errorf("failed to open transport, error: %s", err.Error())
}
if !cn.graph.IsOpen() {

if !cn.transport.IsOpen() {
return fmt.Errorf("transport is off")
}
return cn.verifyClientVersion()

return cn.verifyClientVersion(ctx)
}

func (cn *connection) verifyClientVersion() error {
func (cn *connection) verifyClientVersion(ctx context.Context) error {
req := graph.NewVerifyClientVersionReq()
if cn.handshakeKey != "" {
req.SetVersion([]byte(cn.handshakeKey))
req.Version = []byte(cn.handshakeKey)
}
resp, err := cn.graph.VerifyClientVersion(req)
resp, err := cn.graph.VerifyClientVersion(ctx, req)
if err != nil {
cn.close()
return fmt.Errorf("failed to verify client handshakeKey: %s", err.Error())
Expand All @@ -154,17 +165,18 @@ func (cn *connection) verifyClientVersion() error {
// Because the code generated by Fbthrift does not handle the seqID,
// the message will be dislocated when the timeout occurs, resulting in unexpected response.
// When the timeout occurs, the connection will be reopened to avoid the impact of the message.
func (cn *connection) reopen() error {
func (cn *connection) reopen(ctx context.Context) error {
cn.close()
return cn.open(cn.severAddress, cn.timeout, cn.sslConfig, cn.useHTTP2, cn.httpHeader, cn.handshakeKey)
return cn.open(ctx, cn.severAddress, cn.timeout, cn.sslConfig, cn.useHTTP2, cn.httpHeader, cn.handshakeKey)
}

// Authenticate
func (cn *connection) authenticate(username, password string) (*graph.AuthResponse, error) {
resp, err := cn.graph.Authenticate([]byte(username), []byte(password))
func (cn *connection) authenticate(ctx context.Context, username, password string) (*graph.AuthResponse, error) {
resp, err := cn.graph.Authenticate(ctx, []byte(username), []byte(password))
if err != nil {
err = fmt.Errorf("authentication fails, %s", err.Error())
if e := cn.graph.Close(); e != nil {

if e := cn.transport.Close(); e != nil {
err = fmt.Errorf("fail to close transport, error: %s", e.Error())
}
return nil, err
Expand All @@ -173,39 +185,41 @@ func (cn *connection) authenticate(username, password string) (*graph.AuthRespon
return resp, nil
}

func (cn *connection) execute(sessionID int64, stmt string) (*graph.ExecutionResponse, error) {
return cn.executeWithParameter(sessionID, stmt, map[string]*nebula.Value{})
func (cn *connection) execute(ctx context.Context, sessionID int64, stmt string) (*graph.ExecutionResponse, error) {
return cn.executeWithParameter(ctx, sessionID, stmt, map[string]*nebula.Value{})
}

func (cn *connection) executeWithParameter(sessionID int64, stmt string,
func (cn *connection) executeWithParameter(ctx context.Context, sessionID int64, stmt string,
params map[string]*nebula.Value) (*graph.ExecutionResponse, error) {
resp, err := cn.graph.ExecuteWithParameter(sessionID, []byte(stmt), params)
resp, err := cn.graph.ExecuteWithParameter(ctx, sessionID, []byte(stmt), params)
if err != nil {
return nil, err
}

return resp, nil
}

func (cn *connection) executeWithParameterTimeout(sessionID int64, stmt string, params map[string]*nebula.Value, timeoutMs int64) (*graph.ExecutionResponse, error) {
return cn.graph.ExecuteWithTimeout(sessionID, []byte(stmt), params, timeoutMs)
func (cn *connection) executeWithParameterTimeout(ctx context.Context, sessionID int64, stmt string, params map[string]*nebula.Value, timeoutMs int64) (*graph.ExecutionResponse, error) {
//TODO handle timeout value later
return cn.graph.ExecuteWithParameter(ctx, sessionID, []byte(stmt), params)
//return cn.graph.ExecuteWithTimeout(ctx, sessionID, []byte(stmt), params, timeoutMs)
}

func (cn *connection) executeJson(sessionID int64, stmt string) ([]byte, error) {
return cn.ExecuteJsonWithParameter(sessionID, stmt, map[string]*nebula.Value{})
func (cn *connection) executeJson(ctx context.Context, sessionID int64, stmt string) ([]byte, error) {
return cn.ExecuteJsonWithParameter(ctx, sessionID, stmt, map[string]*nebula.Value{})
}

func (cn *connection) ExecuteJsonWithParameter(sessionID int64, stmt string, params map[string]*nebula.Value) ([]byte, error) {
jsonResp, err := cn.graph.ExecuteJsonWithParameter(sessionID, []byte(stmt), params)
func (cn *connection) ExecuteJsonWithParameter(ctx context.Context, sessionID int64, stmt string, params map[string]*nebula.Value) ([]byte, error) {
jsonResp, err := cn.graph.ExecuteJsonWithParameter(ctx, sessionID, []byte(stmt), params)
if err != nil {
// reopen the connection if timeout
if _, ok := err.(thrift.TransportException); ok {
if err.(thrift.TransportException).TypeID() == thrift.TIMED_OUT {
reopenErr := cn.reopen()
if _, ok := err.(thrift.TTransportException); ok {
if err.(thrift.TTransportException).TypeId() == thrift.TIMED_OUT {
reopenErr := cn.reopen(ctx)
if reopenErr != nil {
return nil, reopenErr
}
return cn.graph.ExecuteJsonWithParameter(sessionID, []byte(stmt), params)
return cn.graph.ExecuteJsonWithParameter(ctx, sessionID, []byte(stmt), params)
}
}
}
Expand All @@ -214,15 +228,15 @@ func (cn *connection) ExecuteJsonWithParameter(sessionID int64, stmt string, par
}

// Check connection to host address
func (cn *connection) ping() bool {
_, err := cn.execute(0, "YIELD 1")
func (cn *connection) ping(ctx context.Context) bool {
_, err := cn.execute(ctx, 0, "YIELD 1")
return err == nil
}

// Sign out and release session ID
func (cn *connection) signOut(sessionID int64) error {
func (cn *connection) signOut(ctx context.Context, sessionID int64) error {
// Release session ID to graphd
return cn.graph.Signout(sessionID)
return cn.graph.Signout(ctx, sessionID)
}

// Update returnedAt for cleaner
Expand All @@ -231,6 +245,11 @@ func (cn *connection) release() {
}

// Close transport
func (cn *connection) close() {
cn.graph.Close()
func (cn *connection) close() error {
if e := cn.transport.Close(); e != nil {
err := fmt.Errorf("fail to close transport, error: %s", e.Error())

return err
}
return nil
}
Loading