diff --git a/db.go b/db.go index f73f2da..5a45108 100644 --- a/db.go +++ b/db.go @@ -11,7 +11,8 @@ import ( // DB is a client for the SurrealDB database that holds the connection. type DB struct { - conn conn.Connection + conn conn.Connection + ExitSignal chan error } // Auth is a struct that holds surrealdb auth data for login. @@ -25,11 +26,12 @@ type Auth struct { // New creates a new SurrealDB client. func New(url string, connection conn.Connection) (*DB, error) { - connection, err := connection.Connect(url) + exitSignal := make(chan error) + connection, err := connection.Connect(url, exitSignal) if err != nil { return nil, err } - return &DB{connection}, nil + return &DB{connection, exitSignal}, nil } // -------------------------------------------------- diff --git a/db_test.go b/db_test.go index 874c567..01df3fc 100644 --- a/db_test.go +++ b/db_test.go @@ -121,6 +121,18 @@ func (s *SurrealDBTestSuite) openConnection() *surrealdb.DB { impl := s.connImplementations[s.name] require.NotNil(s.T(), impl) db, err := surrealdb.New(url, impl) + + go func(s *SurrealDBTestSuite) { + connErr := <-db.ExitSignal + if connErr != nil { + if connErr == conn.ErrConnectionClosed { + return // Connection closed is expected + } + fmt.Println(connErr) + require.NoError(s.T(), connErr) + } + }(s) + s.Require().NoError(err) return db } diff --git a/internal/mock/mock.go b/internal/mock/mock.go index a3b7e3e..dc0d081 100644 --- a/internal/mock/mock.go +++ b/internal/mock/mock.go @@ -7,10 +7,9 @@ import ( "github.com/surrealdb/surrealdb.go/pkg/model" ) -type ws struct { -} +type ws struct{} -func (w *ws) Connect(url string) (conn.Connection, error) { +func (w *ws) Connect(url string, exitSignal chan error) (conn.Connection, error) { return w, nil } diff --git a/pkg/conn/conn.go b/pkg/conn/conn.go index dc619ce..7a1ba00 100644 --- a/pkg/conn/conn.go +++ b/pkg/conn/conn.go @@ -1,10 +1,22 @@ package conn -import "github.com/surrealdb/surrealdb.go/pkg/model" +import ( + "errors" + "io" + + "github.com/surrealdb/surrealdb.go/pkg/model" +) type Connection interface { - Connect(url string) (Connection, error) + Connect(url string, exitChannel chan error) (Connection, error) Send(method string, params []interface{}) (interface{}, error) Close() error LiveNotifications(id string) (chan model.Notification, error) } + +var ( + ErrConnectionClosed = errors.New("connection closed") + ErrConnectionNotConnected = errors.New("connection not connected") + ErrTimeout = errors.New("timeout") + ErrClosedPipe = io.ErrClosedPipe +) diff --git a/pkg/conn/gorilla/gorilla.go b/pkg/conn/gorilla/gorilla.go index dcfbb37..23ad851 100644 --- a/pkg/conn/gorilla/gorilla.go +++ b/pkg/conn/gorilla/gorilla.go @@ -56,7 +56,7 @@ func Create() *WebSocket { } } -func (ws *WebSocket) Connect(url string) (conn.Connection, error) { +func (ws *WebSocket) Connect(url string, exitSignal chan error) (conn.Connection, error) { dialer := gorilla.DefaultDialer dialer.EnableCompression = true @@ -73,7 +73,8 @@ func (ws *WebSocket) Connect(url string) (conn.Connection, error) { } } - ws.initialize() + go ws.initialize(exitSignal) + return ws, nil } @@ -234,26 +235,29 @@ func (ws *WebSocket) write(v interface{}) error { return ws.Conn.WriteMessage(gorilla.TextMessage, data) } -func (ws *WebSocket) initialize() { - go func() { - for { - select { - case <-ws.close: - return - default: - var res rpc.RPCResponse - err := ws.read(&res) - if err != nil { - if errors.Is(err, net.ErrClosed) { - break - } - ws.logger.Error(err.Error()) - continue +func (ws *WebSocket) initialize(exitSignal chan error) { + for { + select { + case <-ws.close: + exitSignal <- conn.ErrConnectionClosed + default: + var res rpc.RPCResponse + err := ws.read(&res) + if err != nil { + // this needed because gorilla not shudown gracefully + if errors.Is(err, net.ErrClosed) { + exitSignal <- conn.ErrConnectionClosed } - go ws.handleResponse(res) + // returns error if connection drop in fly + if gorilla.IsUnexpectedCloseError(err) { + exitSignal <- conn.ErrClosedPipe + } + ws.logger.Error(err.Error()) + continue } + go ws.handleResponse(res) } - }() + } } func (ws *WebSocket) handleResponse(res rpc.RPCResponse) {