From 1503c9585628322addb7899fe206b13753eebd0f Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Thu, 15 Dec 2022 22:49:18 -0500 Subject: [PATCH] Fix #382 This reworks the locking such that newPipelineCallMessage & newImportCallMessage expect the caller to already be holding the connection lock, and then removes the call to syncutil.Without in order to comply with sendMessage's contract. There is a big downside here: it means that it is no longer OK for PlaceArgs to call into the RPC subsystem, as it could cause a deadlock. We're planning on reworking this interface to get rid of PlaceArgs anyway (see #64), and this kind of code is not that common and generally easy to avoid, so in the interest of finishing out one batch of refactoring before starting the next, I am of the opinion that this is probably the right thing to do. --- rpc/import.go | 48 ++++++++++++++++++++++------------------------- rpc/question.go | 50 +++++++++++++++++++++++-------------------------- 2 files changed, 45 insertions(+), 53 deletions(-) diff --git a/rpc/import.go b/rpc/import.go index 8b6fbef0..d8c381f3 100644 --- a/rpc/import.go +++ b/rpc/import.go @@ -98,27 +98,25 @@ func (ic *importClient) Send(ctx context.Context, s capnp.Send) (*capnp.Answer, q := ic.c.newQuestion(s.Method) // Send call message. - syncutil.Without(&ic.c.lk, func() { - ic.c.sendMessage(ctx, func(m rpccp.Message) error { - return ic.c.newImportCallMessage(m, ic.id, q.id, s) - }, func(err error) { - if err != nil { - syncutil.With(&ic.c.lk, func() { - ic.c.lk.questions[q.id] = nil - }) - q.p.Reject(rpcerr.Failedf("send message: %w", err)) - syncutil.With(&ic.c.lk, func() { - ic.c.lk.questionID.remove(uint32(q.id)) - }) - return - } - - q.c.tasks.Add(1) - go func() { - defer q.c.tasks.Done() - q.handleCancel(ctx) - }() - }) + ic.c.sendMessage(ctx, func(m rpccp.Message) error { + return ic.c.newImportCallMessage(m, ic.id, q.id, s) + }, func(err error) { + if err != nil { + syncutil.With(&ic.c.lk, func() { + ic.c.lk.questions[q.id] = nil + }) + q.p.Reject(rpcerr.Failedf("send message: %w", err)) + syncutil.With(&ic.c.lk, func() { + ic.c.lk.questionID.remove(uint32(q.id)) + }) + return + } + + q.c.tasks.Add(1) + go func() { + defer q.c.tasks.Done() + q.handleCancel(ctx) + }() }) ans := q.p.Answer() @@ -131,7 +129,7 @@ func (ic *importClient) Send(ctx context.Context, s capnp.Send) (*capnp.Answer, // newImportCallMessage builds a Call message targeted to an import. // -// The caller MUST NOT hold c.mu. +// The caller MUST hold c.mu. func (c *Conn) newImportCallMessage(msg rpccp.Message, imp importID, qid questionID, s capnp.Send) error { call, err := msg.NewCall() if err != nil { @@ -163,10 +161,8 @@ func (c *Conn) newImportCallMessage(msg rpccp.Message, imp importID, qid questio if err := s.PlaceArgs(args); err != nil { return rpcerr.Failedf("place arguments: %w", err) } - syncutil.With(&c.lk, func() { - // TODO(soon): save param refs - _, err = c.fillPayloadCapTable(payload) - }) + // TODO(soon): save param refs + _, err = c.fillPayloadCapTable(payload) if err != nil { return rpcerr.Annotatef(err, "build call message") } diff --git a/rpc/question.go b/rpc/question.go index b8a1015d..db955a1e 100644 --- a/rpc/question.go +++ b/rpc/question.go @@ -151,28 +151,26 @@ func (q *question) PipelineSend(ctx context.Context, transform []capnp.PipelineO q.mark(transform) q2 := q.c.newQuestion(s.Method) - syncutil.Without(&q.c.lk, func() { - // Send call message. - q.c.sendMessage(ctx, func(m rpccp.Message) error { - return q.c.newPipelineCallMessage(m, q.id, transform, q2.id, s) - }, func(err error) { - if err != nil { - syncutil.With(&q.c.lk, func() { - q.c.lk.questions[q2.id] = nil - }) - q2.p.Reject(rpcerr.Failedf("send message: %w", err)) - syncutil.With(&q.c.lk, func() { - q.c.lk.questionID.remove(uint32(q2.id)) - }) - return - } - - q2.c.tasks.Add(1) - go func() { - defer q2.c.tasks.Done() - q2.handleCancel(ctx) - }() - }) + // Send call message. + q.c.sendMessage(ctx, func(m rpccp.Message) error { + return q.c.newPipelineCallMessage(m, q.id, transform, q2.id, s) + }, func(err error) { + if err != nil { + syncutil.With(&q.c.lk, func() { + q.c.lk.questions[q2.id] = nil + }) + q2.p.Reject(rpcerr.Failedf("send message: %w", err)) + syncutil.With(&q.c.lk, func() { + q.c.lk.questionID.remove(uint32(q2.id)) + }) + return + } + + q2.c.tasks.Add(1) + go func() { + defer q2.c.tasks.Done() + q2.handleCancel(ctx) + }() }) ans := q2.p.Answer() @@ -185,7 +183,7 @@ func (q *question) PipelineSend(ctx context.Context, transform []capnp.PipelineO // newPipelineCallMessage builds a Call message targeted to a promised answer.. // -// The caller MUST NOT hold c.mu. +// The caller MUST hold c.mu. func (c *Conn) newPipelineCallMessage(msg rpccp.Message, tgt questionID, transform []capnp.PipelineOp, qid questionID, s capnp.Send) error { call, err := msg.NewCall() if err != nil { @@ -230,10 +228,8 @@ func (c *Conn) newPipelineCallMessage(msg rpccp.Message, tgt questionID, transfo if err := s.PlaceArgs(args); err != nil { return rpcerr.Failedf("place arguments: %w", err) } - syncutil.With(&c.lk, func() { - // TODO(soon): save param refs - _, err = c.fillPayloadCapTable(payload) - }) + // TODO(soon): save param refs + _, err = c.fillPayloadCapTable(payload) if err != nil { return rpcerr.Annotatef(err, "build call message")