diff --git a/capability.go b/capability.go index e7b6857f..eb0dae28 100644 --- a/capability.go +++ b/capability.go @@ -857,7 +857,8 @@ func (r Recv) AllocResults(sz ObjectSize) (Struct, error) { // Return ends the method call successfully, releasing the arguments. func (r Recv) Return() { r.ReleaseArgs() - r.Returner.Return(nil) + r.Returner.PrepareReturn(nil) + r.Returner.Return() } // Reject ends the method call with an error, releasing the arguments. @@ -866,7 +867,8 @@ func (r Recv) Reject(e error) { panic("Reject(nil)") } r.ReleaseArgs() - r.Returner.Return(e) + r.Returner.PrepareReturn(e) + r.Returner.Return() } // A Returner allocates and sends the results from a received @@ -879,13 +881,23 @@ type Returner interface { // ReleaseResults is called. AllocResults(sz ObjectSize) (Struct, error) - // Return resolves the method call successfully if e is nil, or failure - // otherwise. Return must be called once. + // PrepareReturn finalizes the return message. The method call will + // resolve successfully if e is nil, or otherwise it will return an + // exception to the caller. + // + // PrepareReturn must be called once. + // + // After PrepareReturn is invoked, no goroutine may modify the message + // containing the results. + PrepareReturn(e error) + + // Return resolves the method call, using the results finalized in + // PrepareReturn. Return must be called once. // // Return must wait for all ongoing pipelined calls to be delivered, // and after it returns, no new calls can be sent to the PipelineCaller // returned from Recv. - Return(e error) + Return() // ReleaseResults relinquishes the caller's access to the message // containing the results; once this is called the message may be diff --git a/capability_test.go b/capability_test.go index f6987a4d..dce001d4 100644 --- a/capability_test.go +++ b/capability_test.go @@ -306,11 +306,14 @@ func (dr *dummyReturner) AllocResults(sz ObjectSize) (Struct, error) { return dr.s, err } -func (dr *dummyReturner) Return(e error) { - dr.returned = true +func (dr *dummyReturner) PrepareReturn(e error) { dr.err = e } +func (dr *dummyReturner) Return() { + dr.returned = true +} + func (dr *dummyReturner) ReleaseResults() { } diff --git a/rpc/answer.go b/rpc/answer.go index d8fb8746..c2fe6f16 100644 --- a/rpc/answer.go +++ b/rpc/answer.go @@ -180,27 +180,33 @@ func (ans *answer) setBootstrap(c capnp.Client) error { return nil } -// Return sends the return message. -// -// The caller MUST NOT hold ans.c.lk. -func (ans *answer) Return(e error) { +// PrepareReturn implements capnp.Returner.PrepareReturn +func (ans *answer) PrepareReturn(e error) { rl := &releaseList{} defer rl.Release() - defer ans.pcalls.Wait() + ans.c.withLocked(func(c *lockedConn) { + if e == nil { + ans.prepareSendReturn(c, rl) + } else { + ans.prepareSendException(c, rl, e) + } + }) +} - if e != nil { - ans.c.withLocked(func(c *lockedConn) { - ans.sendException(c, rl, e) - }) - ans.c.tasks.Done() // added by handleCall - return - } +// Return implements capnp.Returner.Return +func (ans *answer) Return() { + rl := &releaseList{} + defer rl.Release() + defer ans.pcalls.Wait() var err error - ans.c.withLocked(func(c *lockedConn) { - err = ans.sendReturn(c, rl) + if ans.err == nil { + err = ans.completeSendReturn(c, rl) + } else { + ans.completeSendException(c, rl) + } }) ans.c.tasks.Done() // added by handleCall @@ -230,10 +236,12 @@ func (ans *answer) ReleaseResults() { // // sendReturn MUST NOT be called if sendException was previously called. func (ans *answer) sendReturn(c *lockedConn, rl *releaseList) error { - c.assertIs(ans.c) + ans.prepareSendReturn(c, rl) + return ans.completeSendReturn(c, rl) +} - ans.pcall = nil - ans.flags |= resultsReady +func (ans *answer) prepareSendReturn(c *lockedConn, rl *releaseList) { + c.assertIs(ans.c) var err error ans.exportRefs, err = c.fillPayloadCapTable(ans.results) @@ -248,8 +256,19 @@ func (ans *answer) sendReturn(c *lockedConn, rl *releaseList) error { case <-ans.c.bgctx.Done(): // We're not going to send the message after all, so don't forget to release it. ans.msgReleaser.Decr() + ans.sendMsg = nil default: - fin := ans.flags.Contains(finishReceived) + } +} + +func (ans *answer) completeSendReturn(c *lockedConn, rl *releaseList) error { + c.assertIs(ans.c) + + ans.pcall = nil + ans.flags |= resultsReady + + fin := ans.flags.Contains(finishReceived) + if ans.sendMsg != nil { if ans.promise != nil { if fin { // Can't use ans.result after a finish, but it's @@ -263,15 +282,13 @@ func (ans *answer) sendReturn(c *lockedConn, rl *releaseList) error { ans.promise = nil } ans.sendMsg() - if fin { - return ans.destroy(c, rl) - } } + ans.flags |= returnSent - if !ans.flags.Contains(finishReceived) { - return nil + if fin { + return ans.destroy(c, rl) } - return ans.destroy(c, rl) + return nil } // sendException sends an exception on the answer's return message. @@ -279,16 +296,13 @@ func (ans *answer) sendReturn(c *lockedConn, rl *releaseList) error { // The caller MUST be holding onto ans.c.lk. sendException MUST NOT // be called if sendReturn was previously called. func (ans *answer) sendException(c *lockedConn, rl *releaseList, ex error) { - c.assertIs(ans.c) + ans.prepareSendException(c, rl, ex) + ans.completeSendException(c, rl) +} +func (ans *answer) prepareSendException(c *lockedConn, rl *releaseList, ex error) { + c.assertIs(ans.c) ans.err = ex - ans.pcall = nil - ans.flags |= resultsReady - - if ans.promise != nil { - ans.promise.Reject(ex) - ans.promise = nil - } select { case <-ans.c.bgctx.Done(): @@ -296,15 +310,31 @@ func (ans *answer) sendException(c *lockedConn, rl *releaseList, ex error) { // Send exception. if e, err := ans.ret.NewException(); err != nil { ans.c.er.ReportError(exc.WrapError("send exception", err)) + ans.sendMsg = nil } else { e.SetType(rpccp.Exception_Type(exc.TypeOf(ex))) if err := e.SetReason(ex.Error()); err != nil { ans.c.er.ReportError(exc.WrapError("send exception", err)) - } else { - ans.sendMsg() + ans.sendMsg = nil } } } +} + +func (ans *answer) completeSendException(c *lockedConn, rl *releaseList) { + c.assertIs(ans.c) + + ex := ans.err + ans.pcall = nil + ans.flags |= resultsReady + + if ans.promise != nil { + ans.promise.Reject(ex) + ans.promise = nil + } + if ans.sendMsg != nil { + ans.sendMsg() + } ans.flags |= returnSent if ans.flags.Contains(finishReceived) { // destroy will never return an error because sendException does diff --git a/rpc/import.go b/rpc/import.go index 37b93673..736183fc 100644 --- a/rpc/import.go +++ b/rpc/import.go @@ -193,19 +193,23 @@ func returnAnswer(ret capnp.Returner, ans *capnp.Answer, finish func()) { defer ret.ReleaseResults() result, err := ans.Struct() if err != nil { - ret.Return(err) + ret.PrepareReturn(err) + ret.Return() return } recvResult, err := ret.AllocResults(result.Size()) if err != nil { - ret.Return(err) + ret.PrepareReturn(err) + ret.Return() return } if err := recvResult.CopyFrom(result); err != nil { - ret.Return(err) + ret.PrepareReturn(err) + ret.Return() return } - ret.Return(nil) + ret.PrepareReturn(nil) + ret.Return() } func (ic *importClient) Brand() capnp.Brand { diff --git a/server/answer.go b/server/answer.go index 763df7ac..1deace10 100644 --- a/server/answer.go +++ b/server/answer.go @@ -207,14 +207,20 @@ func (sr *structReturner) AllocResults(sz capnp.ObjectSize) (capnp.Struct, error sr.msg = s.Message() return s, nil } +func (sr *structReturner) PrepareReturn(e error) { + sr.mu.Lock() + defer sr.mu.Unlock() + sr.err = e +} -func (sr *structReturner) Return(e error) { +func (sr *structReturner) Return() { sr.mu.Lock() if sr.returned { sr.mu.Unlock() panic("structReturner.Return called twice") } sr.returned = true + e := sr.err if e == nil { sr.mu.Unlock() if sr.p != nil { @@ -222,7 +228,6 @@ func (sr *structReturner) Return(e error) { } } else { sr.result = capnp.Struct{} - sr.err = e sr.mu.Unlock() if sr.p != nil { sr.p.Reject(e) diff --git a/server/server.go b/server/server.go index 5a7b6229..0f29334e 100644 --- a/server/server.go +++ b/server/server.go @@ -216,12 +216,13 @@ func (srv *Server) handleCall(ctx context.Context, c *Call) { err := c.method.Impl(ctx, c) c.recv.ReleaseArgs() - c.recv.Returner.Return(err) + c.recv.Returner.PrepareReturn(err) if err == nil { c.aq.fulfill(c.results) } else { c.aq.reject(err) } + c.recv.Returner.Return() c.recv.Returner.ReleaseResults() }