Skip to content

Commit

Permalink
Split Returner.Return in two.
Browse files Browse the repository at this point in the history
See the comments. No functional change yet, since all call sites
currently call PrepareReturn and Return in immediate succession, but
presently this separation will be used to fix a race condition.
  • Loading branch information
zenhack committed Jan 13, 2023
1 parent b6cb9d1 commit d93f6b2
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 28 deletions.
20 changes: 15 additions & 5 deletions capability.go
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,8 @@ func (r Recv) AllocResults(sz ObjectSize) (Struct, error) {
// Return ends the method call successfully, releasing the arguments.
func (r Recv) Return() {
r.ReleaseArgs()
r.Returner.Return(nil)
r.Returner.PrepareReturn(nil)
r.Returner.Return()
}

// Reject ends the method call with an error, releasing the arguments.
Expand All @@ -866,7 +867,8 @@ func (r Recv) Reject(e error) {
panic("Reject(nil)")
}
r.ReleaseArgs()
r.Returner.Return(e)
r.Returner.PrepareReturn(e)
r.Returner.Return()
}

// A Returner allocates and sends the results from a received
Expand All @@ -879,13 +881,21 @@ type Returner interface {
// ReleaseResults is called.
AllocResults(sz ObjectSize) (Struct, error)

// Return resolves the method call successfully if e is nil, or failure
// otherwise. Return must be called once.
// PrepareReturn finalizes the return message. The method call will
// resolve successfully if e is nil, or otherwise it will fail.
// PrepareReturn must be called once.
//
// After PrepareReturn is invoked, no goroutine may modify the message
// containing the results.
PrepareReturn(e error)

// Return resolves the method call, using the results finalized in
// PrepareReturn. Return must be called once.
//
// Return must wait for all ongoing pipelined calls to be delivered,
// and after it returns, no new calls can be sent to the PipelineCaller
// returned from Recv.
Return(e error)
Return()

// ReleaseResults relinquishes the caller's access to the message
// containing the results; once this is called the message may be
Expand Down
7 changes: 5 additions & 2 deletions capability_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,11 +306,14 @@ func (dr *dummyReturner) AllocResults(sz ObjectSize) (Struct, error) {
return dr.s, err
}

func (dr *dummyReturner) Return(e error) {
dr.returned = true
func (dr *dummyReturner) PrepareReturn(e error) {
dr.err = e
}

func (dr *dummyReturner) Return() {
dr.returned = true
}

func (dr *dummyReturner) ReleaseResults() {
}

Expand Down
34 changes: 20 additions & 14 deletions rpc/answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,27 +180,33 @@ func (ans *answer) setBootstrap(c capnp.Client) error {
return nil
}

// Return sends the return message.
//
// The caller MUST NOT hold ans.c.lk.
func (ans *answer) Return(e error) {
// PrepareReturn implements capnp.Returner.PrepareReturn
func (ans *answer) PrepareReturn(e error) {
rl := &releaseList{}
defer rl.Release()

defer ans.pcalls.Wait()
ans.c.withLocked(func(c *lockedConn) {
if e == nil {
ans.prepareSendReturn(c, rl)
} else {
ans.prepareSendException(c, rl, e)
}
})
}

if e != nil {
ans.c.withLocked(func(c *lockedConn) {
ans.sendException(c, rl, e)
})
ans.c.tasks.Done() // added by handleCall
return
}
// Return implements capnp.Returner.Return
func (ans *answer) Return() {
rl := &releaseList{}
defer rl.Release()
defer ans.pcalls.Wait()

var err error

ans.c.withLocked(func(c *lockedConn) {
err = ans.sendReturn(c, rl)
if ans.err == nil {
err = ans.completeSendReturn(c, rl)
} else {
ans.completeSendException(c, rl)
}
})
ans.c.tasks.Done() // added by handleCall

Expand Down
12 changes: 8 additions & 4 deletions rpc/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,19 +193,23 @@ func returnAnswer(ret capnp.Returner, ans *capnp.Answer, finish func()) {
defer ret.ReleaseResults()
result, err := ans.Struct()
if err != nil {
ret.Return(err)
ret.PrepareReturn(err)
ret.Return()
return
}
recvResult, err := ret.AllocResults(result.Size())
if err != nil {
ret.Return(err)
ret.PrepareReturn(err)
ret.Return()
return
}
if err := recvResult.CopyFrom(result); err != nil {
ret.Return(err)
ret.PrepareReturn(err)
ret.Return()
return
}
ret.Return(nil)
ret.PrepareReturn(nil)
ret.Return()
}

func (ic *importClient) Brand() capnp.Brand {
Expand Down
9 changes: 7 additions & 2 deletions server/answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,22 +207,27 @@ func (sr *structReturner) AllocResults(sz capnp.ObjectSize) (capnp.Struct, error
sr.msg = s.Message()
return s, nil
}
func (sr *structReturner) PrepareReturn(e error) {
sr.mu.Lock()
defer sr.mu.Unlock()
sr.err = e
}

func (sr *structReturner) Return(e error) {
func (sr *structReturner) Return() {
sr.mu.Lock()
if sr.returned {
sr.mu.Unlock()
panic("structReturner.Return called twice")
}
sr.returned = true
e := sr.err
if e == nil {
sr.mu.Unlock()
if sr.p != nil {
sr.p.Fulfill(sr.result.ToPtr())
}
} else {
sr.result = capnp.Struct{}
sr.err = e
sr.mu.Unlock()
if sr.p != nil {
sr.p.Reject(e)
Expand Down
3 changes: 2 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ func (srv *Server) handleCall(ctx context.Context, c *Call) {
err := c.method.Impl(ctx, c)

c.recv.ReleaseArgs()
c.recv.Returner.Return(err)
c.recv.Returner.PrepareReturn(err)
c.recv.Returner.Return()
if err == nil {
c.aq.fulfill(c.results)
} else {
Expand Down

0 comments on commit d93f6b2

Please sign in to comment.