Skip to content

Commit

Permalink
simplify model proxy
Browse files Browse the repository at this point in the history
- Simplified the model proxy to remove the need for a mutex.
- Handle an edge case in the model proxy where if the controller routine returned we don't stop the proxy.
  • Loading branch information
kian99 committed Sep 2, 2024
1 parent bafc43f commit fb3842d
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 28 deletions.
62 changes: 62 additions & 0 deletions internal/rpc/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,68 @@ func TestProxySockets(t *testing.T) {
<-errChan // Ensure go routines are cleaned up
}

func TestProxySocketsControllerConnectionFails(t *testing.T) {
c := qt.New(t)
ctx := context.Background()

srvController := newServer(echo)

var connController *websocket.Conn
errChan := make(chan error)
srvJIMM := newServer(func(connClient *websocket.Conn) error {
testTokenGen := testTokenGenerator{}
f := func(context.Context) (rpc.WebsocketConnectionWithMetadata, error) {
var err error
connController, err = srvController.dialer.DialWebsocket(ctx, srvController.URL)
c.Check(err, qt.IsNil)
return rpc.WebsocketConnectionWithMetadata{
Conn: connController,
ModelName: "TestName",
}, nil
}
auditLogger := func(ale *dbmodel.AuditLogEntry) {}
proxyHelpers := rpc.ProxyHelpers{
ConnClient: connClient,
TokenGen: &testTokenGen,
ConnectController: f,
AuditLog: auditLogger,
LoginService: &mockLoginService{},
}
err := rpc.ProxySockets(ctx, proxyHelpers)
c.Check(err, qt.IsNil)
errChan <- err
return err
})

defer srvController.Close()
defer srvJIMM.Close()
ws, err := srvJIMM.dialer.DialWebsocket(ctx, srvJIMM.URL)
c.Assert(err, qt.IsNil)
defer ws.Close()

p := json.RawMessage(`{"Key":"TestVal"}`)
msg := rpc.Message{RequestID: 1, Type: "TestType", Request: "TestReq", Params: p}
err = ws.WriteJSON(&msg)
c.Assert(err, qt.IsNil)
resp := rpc.Message{}
receiveChan := make(chan error)
go func() {
receiveChan <- ws.ReadJSON(&resp)
}()
select {
case err := <-receiveChan:
c.Assert(err, qt.IsNil)
case <-time.After(5 * time.Second):
c.Logf("took too long to read response")
c.FailNow()
}
c.Assert(resp.Response, qt.DeepEquals, msg.Params)

// Now close the connection to the controller and ensure the model proxy is cleaned up.
connController.Close()
<-errChan // Ensure go routines are cleaned up
}

func TestCancelProxySockets(t *testing.T) {
c := qt.New(t)

Expand Down
36 changes: 8 additions & 28 deletions internal/rpc/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,24 +121,18 @@ func ProxySockets(ctx context.Context, helpers ProxyHelpers) error {
}()
var err error
select {
// No cleanup is needed on error, when the client closes the connection
// all go routines will proceed to error and exit.
case err = <-errChan:
if err != nil {
zapctx.Debug(ctx, "Proxy error", zap.Error(err))
}
case <-ctx.Done():
err = errors.E(op, "Context cancelled")
zapctx.Debug(ctx, "Context cancelled")
helpers.ConnClient.Close()
clProxy.mu.Lock()
clProxy.closed = true
// TODO(Kian): Test removing close on dst below. The client connection should do it.
if clProxy.dst != nil {
clProxy.dst.conn.Close()
}
clProxy.mu.Unlock()
}
// Close the client connection to ensure everything is cleaned up.
// Normally the client would do this but we also do it here in case the
// connection to the controller fails and we want to trigger cleanup.
helpers.ConnClient.Close()
clProxy.wg.Wait()
return err
}
Expand Down Expand Up @@ -324,10 +318,6 @@ type clientProxy struct {
wg sync.WaitGroup
errChan chan error
createControllerConn func(context.Context) (WebsocketConnectionWithMetadata, error)
// mu synchronises changes to closed and modelproxy.dst, dst is only created
// at some unspecified point in the future after a client request.
mu sync.Mutex
closed bool
}

// start begins the client->controller proxier.
Expand All @@ -341,7 +331,7 @@ func (p *clientProxy) start(ctx context.Context) error {
zapctx.Debug(ctx, "Reading on client connection")
msg := new(message)
if err := p.src.readJson(&msg); err != nil {
// Error reading on the socket implies the client closed their connection, return without error.
//lint:ignore nilerr an error reading on the socket implies the client closed their connection, return without error.
return nil
}
zapctx.Debug(ctx, "Read message from client", zap.Any("message", msg))
Expand Down Expand Up @@ -389,23 +379,13 @@ func (p *clientProxy) start(ctx context.Context) error {
// proxying requests from the controller to the client.
func (p *clientProxy) makeControllerConnection(ctx context.Context) error {
const op = errors.Op("rpc.makeControllerConnection")
p.mu.Lock()
defer p.mu.Unlock()
if p.dst != nil {
return nil
}
// Checking closed ensures we don't have a race condition with a cancelled context.
if p.closed {
err := errors.E(op, "Client connection closed while starting controller connection")
return err
}

connWithMetadata, err := p.createControllerConn(ctx)
if err != nil {
return err
return errors.E(op, err)
}

p.msgs.controllerUUID = connWithMetadata.ControllerUUID

p.modelName = connWithMetadata.ModelName
p.dst = &writeLockConn{conn: connWithMetadata.Conn}
controllerToClient := controllerProxy{
Expand Down Expand Up @@ -439,7 +419,7 @@ func (p *controllerProxy) start(ctx context.Context) error {
zapctx.Debug(ctx, "Reading on controller connection")
msg := new(message)
if err := p.src.readJson(msg); err != nil {
// Error reading on the socket implies we've closed the connection to the controller, return without error.
//lint:ignore nilerr an error reading on the socket implies we've closed the connection to the controller, return without error.
return nil
}
zapctx.Debug(ctx, "Received message from controller", zap.Any("Message", msg))
Expand Down

0 comments on commit fb3842d

Please sign in to comment.