Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup model proxy #1345

Merged
merged 5 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 64 additions & 2 deletions internal/rpc/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ func TestProxySockets(t *testing.T) {
LoginService: &mockLoginService{},
}
err := rpc.ProxySockets(ctx, proxyHelpers)
c.Check(err, qt.ErrorMatches, "error reading from (client|controller).*")
c.Check(err, qt.IsNil)
errChan <- err
return err
})
Expand Down Expand Up @@ -298,6 +298,68 @@ func TestProxySockets(t *testing.T) {
<-errChan // Ensure go routines are cleaned up
}

func TestProxySocketsControllerConnectionFails(t *testing.T) {
c := qt.New(t)
ctx := context.Background()

srvController := newServer(echo)

var connController *websocket.Conn
errChan := make(chan error)
srvJIMM := newServer(func(connClient *websocket.Conn) error {
testTokenGen := testTokenGenerator{}
f := func(context.Context) (rpc.WebsocketConnectionWithMetadata, error) {
var err error
connController, err = srvController.dialer.DialWebsocket(ctx, srvController.URL)
c.Check(err, qt.IsNil)
return rpc.WebsocketConnectionWithMetadata{
Conn: connController,
ModelName: "TestName",
}, nil
}
auditLogger := func(ale *dbmodel.AuditLogEntry) {}
proxyHelpers := rpc.ProxyHelpers{
ConnClient: connClient,
TokenGen: &testTokenGen,
ConnectController: f,
AuditLog: auditLogger,
LoginService: &mockLoginService{},
}
err := rpc.ProxySockets(ctx, proxyHelpers)
c.Check(err, qt.IsNil)
errChan <- err
return err
})

defer srvController.Close()
defer srvJIMM.Close()
ws, err := srvJIMM.dialer.DialWebsocket(ctx, srvJIMM.URL)
c.Assert(err, qt.IsNil)
defer ws.Close()

p := json.RawMessage(`{"Key":"TestVal"}`)
msg := rpc.Message{RequestID: 1, Type: "TestType", Request: "TestReq", Params: p}
err = ws.WriteJSON(&msg)
c.Assert(err, qt.IsNil)
resp := rpc.Message{}
receiveChan := make(chan error)
go func() {
receiveChan <- ws.ReadJSON(&resp)
}()
select {
case err := <-receiveChan:
c.Assert(err, qt.IsNil)
case <-time.After(5 * time.Second):
c.Logf("took too long to read response")
c.FailNow()
}
c.Assert(resp.Response, qt.DeepEquals, msg.Params)

// Now close the connection to the controller and ensure the model proxy is cleaned up.
connController.Close()
<-errChan // Ensure go routines are cleaned up
}

func TestCancelProxySockets(t *testing.T) {
c := qt.New(t)

Expand Down Expand Up @@ -368,7 +430,7 @@ func TestProxySocketsAuditLogs(t *testing.T) {
LoginService: &mockLoginService{},
}
err := rpc.ProxySockets(ctx, proxyHelpers)
c.Check(err, qt.ErrorMatches, `error reading from (client|controller).*`)
c.Check(err, qt.IsNil)
errChan <- err
return err
})
Expand Down
111 changes: 56 additions & 55 deletions internal/rpc/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"sync"
"time"

"github.com/gorilla/websocket"
"github.com/juju/juju/rpc/params"
"github.com/juju/names/v5"
"github.com/juju/zaputil/zapctx"
Expand Down Expand Up @@ -121,22 +122,18 @@ func ProxySockets(ctx context.Context, helpers ProxyHelpers) error {
}()
var err error
select {
// No cleanup is needed on error, when the client closes the connection
// all go routines will proceed to error and exit.
case err = <-errChan:
zapctx.Debug(ctx, "Proxy error", zap.Error(err))
if err != nil {
zapctx.Debug(ctx, "Proxy error", zap.Error(err))
}
case <-ctx.Done():
err = errors.E(op, "Context cancelled")
zapctx.Debug(ctx, "Context cancelled")
helpers.ConnClient.Close()
clProxy.mu.Lock()
clProxy.closed = true
// TODO(Kian): Test removing close on dst below. The client connection should do it.
if clProxy.dst != nil {
clProxy.dst.conn.Close()
}
clProxy.mu.Unlock()
}
// Close the client connection to ensure everything is cleaned up.
// Normally the client would do this but we also do it here in case the
// connection to the controller fails and we want to trigger cleanup.
helpers.ConnClient.Close()
clProxy.wg.Wait()
return err
}
Expand Down Expand Up @@ -316,16 +313,22 @@ func (p *modelProxy) auditLogMessage(msg *message, isResponse bool) error {
return nil
}

func unexpectedReadError(err error) bool {
closeError := websocket.IsUnexpectedCloseError(err,
websocket.CloseNormalClosure,
websocket.CloseNoStatusReceived,
websocket.CloseAbnormalClosure)
_, unmarshalError := err.(*json.InvalidUnmarshalError)
return closeError || unmarshalError
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like this logic, don't do this. Just check if we have a close error, if we do return else check a marshal error. The name is also confusing, as this is an unexpected closure versus unmarshalling.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will fix this in a follow-up.

}

// clientProxy proxies messages from client->controller.
type clientProxy struct {
modelProxy
wg sync.WaitGroup
errChan chan error
createControllerConn func(context.Context) (WebsocketConnectionWithMetadata, error)
// mu synchronises changes to closed and modelproxy.dst, dst is is only created
// at some unspecified point in the future after a client request.
mu sync.Mutex
closed bool
connectController sync.Once
kian99 marked this conversation as resolved.
Show resolved Hide resolved
}

// start begins the client->controller proxier.
Expand All @@ -339,8 +342,11 @@ func (p *clientProxy) start(ctx context.Context) error {
zapctx.Debug(ctx, "Reading on client connection")
msg := new(message)
if err := p.src.readJson(&msg); err != nil {
// Error reading on the socket implies it is closed, simply return.
return fmt.Errorf("error reading from client: %w", err)
if unexpectedReadError(err) {
zapctx.Error(ctx, "unexpected client read error", zap.Error(err))
return err
}
return nil
}
zapctx.Debug(ctx, "Read message from client", zap.Any("message", msg))
err := p.makeControllerConnection(ctx)
Expand Down Expand Up @@ -387,43 +393,35 @@ func (p *clientProxy) start(ctx context.Context) error {
// proxying requests from the controller to the client.
func (p *clientProxy) makeControllerConnection(ctx context.Context) error {
const op = errors.Op("rpc.makeControllerConnection")
p.mu.Lock()
defer p.mu.Unlock()
if p.dst != nil {
return nil
}
// Checking closed ensures we don't have a race condition with a cancelled context.
if p.closed {
err := errors.E(op, "Client connection closed while starting controller connection")
return err
}
connWithMetadata, err := p.createControllerConn(ctx)
if err != nil {
return err
}

p.msgs.controllerUUID = connWithMetadata.ControllerUUID
var createConnErr error
// Create the controller connection once.
p.connectController.Do(func() {
connWithMetadata, err := p.createControllerConn(ctx)
if err != nil {
createConnErr = errors.E(op, err)
}

p.modelName = connWithMetadata.ModelName
p.dst = &writeLockConn{conn: connWithMetadata.Conn}
controllerToClient := controllerProxy{
modelProxy: modelProxy{
src: p.dst,
dst: p.src,
msgs: p.msgs,
auditLog: p.auditLog,
tokenGen: p.tokenGen,
modelName: p.modelName,
conversationId: p.conversationId,
},
}
p.wg.Add(1)
go func() {
defer p.wg.Done()
p.errChan <- controllerToClient.start(ctx)
}()
zapctx.Debug(ctx, "Successfully made controller connection")
return nil
p.msgs.controllerUUID = connWithMetadata.ControllerUUID
p.modelName = connWithMetadata.ModelName
p.dst = &writeLockConn{conn: connWithMetadata.Conn}
controllerToClient := controllerProxy{
modelProxy: modelProxy{
src: p.dst,
dst: p.src,
msgs: p.msgs,
auditLog: p.auditLog,
tokenGen: p.tokenGen,
modelName: p.modelName,
conversationId: p.conversationId,
},
}
p.wg.Add(1)
go func() {
defer p.wg.Done()
p.errChan <- controllerToClient.start(ctx)
}()
})
return createConnErr
}

// controllerProxy proxies messages from controller->client with the caveat that
Expand All @@ -438,8 +436,11 @@ func (p *controllerProxy) start(ctx context.Context) error {
zapctx.Debug(ctx, "Reading on controller connection")
msg := new(message)
if err := p.src.readJson(msg); err != nil {
// Error reading on the socket implies it is closed, simply return.
return fmt.Errorf("error reading from controller: %w", err)
if unexpectedReadError(err) {
zapctx.Error(ctx, "unexpected controller read error", zap.Error(err))
return err
}
return nil
}
zapctx.Debug(ctx, "Received message from controller", zap.Any("Message", msg))
permissionsRequired, err := checkPermissionsRequired(ctx, msg)
Expand Down
Loading