Skip to content

Commit

Permalink
Merge pull request #424 from zenhack/420
Browse files Browse the repository at this point in the history
Fix #420
  • Loading branch information
zenhack authored Jan 13, 2023
2 parents 9ea6bb7 + 2877e72 commit 5022e46
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 48 deletions.
22 changes: 17 additions & 5 deletions capability.go
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,8 @@ func (r Recv) AllocResults(sz ObjectSize) (Struct, error) {
// Return ends the method call successfully, releasing the arguments.
func (r Recv) Return() {
r.ReleaseArgs()
r.Returner.Return(nil)
r.Returner.PrepareReturn(nil)
r.Returner.Return()
}

// Reject ends the method call with an error, releasing the arguments.
Expand All @@ -866,7 +867,8 @@ func (r Recv) Reject(e error) {
panic("Reject(nil)")
}
r.ReleaseArgs()
r.Returner.Return(e)
r.Returner.PrepareReturn(e)
r.Returner.Return()
}

// A Returner allocates and sends the results from a received
Expand All @@ -879,13 +881,23 @@ type Returner interface {
// ReleaseResults is called.
AllocResults(sz ObjectSize) (Struct, error)

// Return resolves the method call successfully if e is nil, or failure
// otherwise. Return must be called once.
// PrepareReturn finalizes the return message. The method call will
// resolve successfully if e is nil, or otherwise it will return an
// exception to the caller.
//
// PrepareReturn must be called once.
//
// After PrepareReturn is invoked, no goroutine may modify the message
// containing the results.
PrepareReturn(e error)

// Return resolves the method call, using the results finalized in
// PrepareReturn. Return must be called once.
//
// Return must wait for all ongoing pipelined calls to be delivered,
// and after it returns, no new calls can be sent to the PipelineCaller
// returned from Recv.
Return(e error)
Return()

// ReleaseResults relinquishes the caller's access to the message
// containing the results; once this is called the message may be
Expand Down
7 changes: 5 additions & 2 deletions capability_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,11 +306,14 @@ func (dr *dummyReturner) AllocResults(sz ObjectSize) (Struct, error) {
return dr.s, err
}

func (dr *dummyReturner) Return(e error) {
dr.returned = true
func (dr *dummyReturner) PrepareReturn(e error) {
dr.err = e
}

func (dr *dummyReturner) Return() {
dr.returned = true
}

func (dr *dummyReturner) ReleaseResults() {
}

Expand Down
98 changes: 64 additions & 34 deletions rpc/answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,27 +180,33 @@ func (ans *answer) setBootstrap(c capnp.Client) error {
return nil
}

// Return sends the return message.
//
// The caller MUST NOT hold ans.c.lk.
func (ans *answer) Return(e error) {
// PrepareReturn implements capnp.Returner.PrepareReturn
func (ans *answer) PrepareReturn(e error) {
rl := &releaseList{}
defer rl.Release()

defer ans.pcalls.Wait()
ans.c.withLocked(func(c *lockedConn) {
if e == nil {
ans.prepareSendReturn(c, rl)
} else {
ans.prepareSendException(c, rl, e)
}
})
}

if e != nil {
ans.c.withLocked(func(c *lockedConn) {
ans.sendException(c, rl, e)
})
ans.c.tasks.Done() // added by handleCall
return
}
// Return implements capnp.Returner.Return
func (ans *answer) Return() {
rl := &releaseList{}
defer rl.Release()
defer ans.pcalls.Wait()

var err error

ans.c.withLocked(func(c *lockedConn) {
err = ans.sendReturn(c, rl)
if ans.err == nil {
err = ans.completeSendReturn(c, rl)
} else {
ans.completeSendException(c, rl)
}
})
ans.c.tasks.Done() // added by handleCall

Expand Down Expand Up @@ -230,10 +236,12 @@ func (ans *answer) ReleaseResults() {
//
// sendReturn MUST NOT be called if sendException was previously called.
func (ans *answer) sendReturn(c *lockedConn, rl *releaseList) error {
c.assertIs(ans.c)
ans.prepareSendReturn(c, rl)
return ans.completeSendReturn(c, rl)
}

ans.pcall = nil
ans.flags |= resultsReady
func (ans *answer) prepareSendReturn(c *lockedConn, rl *releaseList) {
c.assertIs(ans.c)

var err error
ans.exportRefs, err = c.fillPayloadCapTable(ans.results)
Expand All @@ -248,8 +256,19 @@ func (ans *answer) sendReturn(c *lockedConn, rl *releaseList) error {
case <-ans.c.bgctx.Done():
// We're not going to send the message after all, so don't forget to release it.
ans.msgReleaser.Decr()
ans.sendMsg = nil
default:
fin := ans.flags.Contains(finishReceived)
}
}

func (ans *answer) completeSendReturn(c *lockedConn, rl *releaseList) error {
c.assertIs(ans.c)

ans.pcall = nil
ans.flags |= resultsReady

fin := ans.flags.Contains(finishReceived)
if ans.sendMsg != nil {
if ans.promise != nil {
if fin {
// Can't use ans.result after a finish, but it's
Expand All @@ -263,48 +282,59 @@ func (ans *answer) sendReturn(c *lockedConn, rl *releaseList) error {
ans.promise = nil
}
ans.sendMsg()
if fin {
return ans.destroy(c, rl)
}
}

ans.flags |= returnSent
if !ans.flags.Contains(finishReceived) {
return nil
if fin {
return ans.destroy(c, rl)
}
return ans.destroy(c, rl)
return nil
}

// sendException sends an exception on the answer's return message.
//
// The caller MUST be holding onto ans.c.lk. sendException MUST NOT
// be called if sendReturn was previously called.
func (ans *answer) sendException(c *lockedConn, rl *releaseList, ex error) {
c.assertIs(ans.c)
ans.prepareSendException(c, rl, ex)
ans.completeSendException(c, rl)
}

func (ans *answer) prepareSendException(c *lockedConn, rl *releaseList, ex error) {
c.assertIs(ans.c)
ans.err = ex
ans.pcall = nil
ans.flags |= resultsReady

if ans.promise != nil {
ans.promise.Reject(ex)
ans.promise = nil
}

select {
case <-ans.c.bgctx.Done():
default:
// Send exception.
if e, err := ans.ret.NewException(); err != nil {
ans.c.er.ReportError(exc.WrapError("send exception", err))
ans.sendMsg = nil
} else {
e.SetType(rpccp.Exception_Type(exc.TypeOf(ex)))
if err := e.SetReason(ex.Error()); err != nil {
ans.c.er.ReportError(exc.WrapError("send exception", err))
} else {
ans.sendMsg()
ans.sendMsg = nil
}
}
}
}

func (ans *answer) completeSendException(c *lockedConn, rl *releaseList) {
c.assertIs(ans.c)

ex := ans.err
ans.pcall = nil
ans.flags |= resultsReady

if ans.promise != nil {
ans.promise.Reject(ex)
ans.promise = nil
}
if ans.sendMsg != nil {
ans.sendMsg()
}
ans.flags |= returnSent
if ans.flags.Contains(finishReceived) {
// destroy will never return an error because sendException does
Expand Down
12 changes: 8 additions & 4 deletions rpc/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,19 +193,23 @@ func returnAnswer(ret capnp.Returner, ans *capnp.Answer, finish func()) {
defer ret.ReleaseResults()
result, err := ans.Struct()
if err != nil {
ret.Return(err)
ret.PrepareReturn(err)
ret.Return()
return
}
recvResult, err := ret.AllocResults(result.Size())
if err != nil {
ret.Return(err)
ret.PrepareReturn(err)
ret.Return()
return
}
if err := recvResult.CopyFrom(result); err != nil {
ret.Return(err)
ret.PrepareReturn(err)
ret.Return()
return
}
ret.Return(nil)
ret.PrepareReturn(nil)
ret.Return()
}

func (ic *importClient) Brand() capnp.Brand {
Expand Down
9 changes: 7 additions & 2 deletions server/answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,22 +207,27 @@ func (sr *structReturner) AllocResults(sz capnp.ObjectSize) (capnp.Struct, error
sr.msg = s.Message()
return s, nil
}
func (sr *structReturner) PrepareReturn(e error) {
sr.mu.Lock()
defer sr.mu.Unlock()
sr.err = e
}

func (sr *structReturner) Return(e error) {
func (sr *structReturner) Return() {
sr.mu.Lock()
if sr.returned {
sr.mu.Unlock()
panic("structReturner.Return called twice")
}
sr.returned = true
e := sr.err
if e == nil {
sr.mu.Unlock()
if sr.p != nil {
sr.p.Fulfill(sr.result.ToPtr())
}
} else {
sr.result = capnp.Struct{}
sr.err = e
sr.mu.Unlock()
if sr.p != nil {
sr.p.Reject(e)
Expand Down
3 changes: 2 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,13 @@ func (srv *Server) handleCall(ctx context.Context, c *Call) {
err := c.method.Impl(ctx, c)

c.recv.ReleaseArgs()
c.recv.Returner.Return(err)
c.recv.Returner.PrepareReturn(err)
if err == nil {
c.aq.fulfill(c.results)
} else {
c.aq.reject(err)
}
c.recv.Returner.Return()
c.recv.Returner.ReleaseResults()
}

Expand Down

0 comments on commit 5022e46

Please sign in to comment.