Skip to content

Commit

Permalink
Channel version of panic-fix of surrealdb#135
Browse files Browse the repository at this point in the history
  • Loading branch information
ElecTwix committed May 15, 2024
1 parent 7c2584a commit 908a377
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 27 deletions.
8 changes: 5 additions & 3 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import (

// DB is a client for the SurrealDB database that holds the connection.
type DB struct {
conn conn.Connection
conn conn.Connection
ExitSignal chan error
}

// Auth is a struct that holds surrealdb auth data for login.
Expand All @@ -25,11 +26,12 @@ type Auth struct {

// New creates a new SurrealDB client.
func New(url string, connection conn.Connection) (*DB, error) {
connection, err := connection.Connect(url)
exitSignal := make(chan error)
connection, err := connection.Connect(url, exitSignal)
if err != nil {
return nil, err
}
return &DB{connection}, nil
return &DB{connection, exitSignal}, nil
}

// --------------------------------------------------
Expand Down
12 changes: 12 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,18 @@ func (s *SurrealDBTestSuite) openConnection() *surrealdb.DB {
impl := s.connImplementations[s.name]
require.NotNil(s.T(), impl)
db, err := surrealdb.New(url, impl)

go func(s *SurrealDBTestSuite) {
connErr := <-db.ExitSignal
if connErr != nil {
if connErr == conn.ErrConnectionClosed {
return // Connection closed is expected
}
fmt.Println(connErr)
require.NoError(s.T(), connErr)
}
}(s)

s.Require().NoError(err)
return db
}
Expand Down
5 changes: 2 additions & 3 deletions internal/mock/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@ import (
"github.com/surrealdb/surrealdb.go/pkg/model"
)

type ws struct {
}
type ws struct{}

func (w *ws) Connect(url string) (conn.Connection, error) {
func (w *ws) Connect(url string, exitSignal chan error) (conn.Connection, error) {
return w, nil
}

Expand Down
16 changes: 14 additions & 2 deletions pkg/conn/conn.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
package conn

import "github.com/surrealdb/surrealdb.go/pkg/model"
import (
"errors"
"io"

"github.com/surrealdb/surrealdb.go/pkg/model"
)

type Connection interface {
Connect(url string) (Connection, error)
Connect(url string, exitChannel chan error) (Connection, error)
Send(method string, params []interface{}) (interface{}, error)
Close() error
LiveNotifications(id string) (chan model.Notification, error)
}

var (
ErrConnectionClosed = errors.New("connection closed")
ErrConnectionNotConnected = errors.New("connection not connected")
ErrTimeout = errors.New("timeout")
ErrClosedPipe = io.ErrClosedPipe
)
42 changes: 23 additions & 19 deletions pkg/conn/gorilla/gorilla.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func Create() *WebSocket {
}
}

func (ws *WebSocket) Connect(url string) (conn.Connection, error) {
func (ws *WebSocket) Connect(url string, exitSignal chan error) (conn.Connection, error) {
dialer := gorilla.DefaultDialer
dialer.EnableCompression = true

Expand All @@ -73,7 +73,8 @@ func (ws *WebSocket) Connect(url string) (conn.Connection, error) {
}
}

ws.initialize()
go ws.initialize(exitSignal)

return ws, nil
}

Expand Down Expand Up @@ -234,26 +235,29 @@ func (ws *WebSocket) write(v interface{}) error {
return ws.Conn.WriteMessage(gorilla.TextMessage, data)
}

func (ws *WebSocket) initialize() {
go func() {
for {
select {
case <-ws.close:
return
default:
var res rpc.RPCResponse
err := ws.read(&res)
if err != nil {
if errors.Is(err, net.ErrClosed) {
break
}
ws.logger.Error(err.Error())
continue
func (ws *WebSocket) initialize(exitSignal chan error) {
for {
select {
case <-ws.close:
exitSignal <- conn.ErrConnectionClosed
default:
var res rpc.RPCResponse
err := ws.read(&res)
if err != nil {
// this needed because gorilla not shudown gracefully
if errors.Is(err, net.ErrClosed) {
exitSignal <- conn.ErrConnectionClosed
}
go ws.handleResponse(res)
// returns error if connection drop in fly
if gorilla.IsUnexpectedCloseError(err) {
exitSignal <- conn.ErrClosedPipe
}
ws.logger.Error(err.Error())
continue
}
go ws.handleResponse(res)
}
}()
}
}

func (ws *WebSocket) handleResponse(res rpc.RPCResponse) {
Expand Down

0 comments on commit 908a377

Please sign in to comment.