diff --git a/rpc/answer.go b/rpc/answer.go index 4ce4936a..a81593de 100644 --- a/rpc/answer.go +++ b/rpc/answer.go @@ -183,17 +183,17 @@ func (ans *answer) setBootstrap(c capnp.Client) error { // // The caller MUST NOT hold ans.c.lk. func (ans *answer) Return(e error) { + rl := &releaseList{} + defer rl.Release() + ans.c.lk.Lock() if e != nil { - rl := &releaseList{} ans.sendException(rl, e) ans.c.lk.Unlock() - rl.Release() ans.pcalls.Wait() ans.c.tasks.Done() // added by handleCall return } - rl := &releaseList{} if err := ans.sendReturn(rl); err != nil { select { case <-ans.c.bgctx.Done(): @@ -204,13 +204,11 @@ func (ans *answer) Return(e error) { } ans.c.lk.Unlock() - rl.Release() ans.pcalls.Wait() return } } ans.c.lk.Unlock() - rl.Release() ans.pcalls.Wait() ans.c.tasks.Done() // added by handleCall } diff --git a/rpc/releaselist.go b/rpc/releaselist.go index 111f8e52..fa135006 100644 --- a/rpc/releaselist.go +++ b/rpc/releaselist.go @@ -4,11 +4,12 @@ import "capnproto.org/go/capnp/v3" type releaseList []capnp.ReleaseFunc -func (rl releaseList) Release() { - for i, r := range rl { +func (rl *releaseList) Release() { + funcs := *rl + for i, r := range funcs { if r != nil { r() - rl[i] = nil + funcs[i] = nil } } } diff --git a/rpc/rpc.go b/rpc/rpc.go index f463da4b..2c672925 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -658,6 +658,9 @@ func (c *Conn) handleBootstrap(ctx context.Context, id answerID) error { } func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capnp.ReleaseFunc) error { + rl := &releaseList{} + defer rl.Release() + id := answerID(call.QuestionId()) // TODO(3rd-party handshake): support sending results to 3rd party vat @@ -722,11 +725,9 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn c.lk.answers[id] = ans if parseErr != nil { parseErr = rpcerr.Annotate(parseErr, "incoming call") - rl := &releaseList{} ans.sendException(rl, parseErr) c.lk.Unlock() c.er.ReportError(parseErr) - rl.Release() releaseCall() return nil } @@ -778,10 +779,8 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn } if tgtAns.flags.Contains(resultsReady) { if tgtAns.err != nil { - rl := &releaseList{} ans.sendException(rl, tgtAns.err) c.lk.Unlock() - rl.Release() releaseCall() return nil } @@ -792,10 +791,8 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn content, err := tgtAns.results.Content() if err != nil { err = rpcerr.Failedf("incoming call: read results from target answer: %w", err) - rl := &releaseList{} ans.sendException(rl, err) c.lk.Unlock() - rl.Release() releaseCall() c.er.ReportError(err) return nil @@ -803,10 +800,8 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn sub, err := capnp.Transform(content, p.target.transform) if err != nil { // Not reporting, as this is the caller's fault. - rl := &releaseList{} ans.sendException(rl, err) c.lk.Unlock() - rl.Release() releaseCall() return nil } @@ -1116,14 +1111,16 @@ type parsedReturn struct { } func (c *Conn) handleFinish(ctx context.Context, id answerID, releaseResultCaps bool) error { + rl := &releaseList{} + defer rl.Release() c.lk.Lock() + defer c.lk.Unlock() + ans := c.lk.answers[id] if ans == nil { - c.lk.Unlock() return rpcerr.Failedf("incoming finish: unknown answer ID %d", id) } if ans.flags.Contains(finishReceived) { - c.lk.Unlock() return rpcerr.Failedf("incoming finish: answer ID %d already received finish", id) } ans.flags |= finishReceived @@ -1134,15 +1131,11 @@ func (c *Conn) handleFinish(ctx context.Context, id answerID, releaseResultCaps ans.cancel() } if !ans.flags.Contains(returnSent) { - c.lk.Unlock() return nil } // Return sent and finish received: time to destroy answer. - rl := &releaseList{} err := ans.destroy(rl) - c.lk.Unlock() - rl.Release() if err != nil { return rpcerr.Annotate(err, "incoming finish: release result caps") }