Skip to content

Commit

Permalink
feat(broadcast): enabled cancellation
Browse files Browse the repository at this point in the history
  • Loading branch information
aleksander-vedvik committed May 8, 2024
1 parent cc75006 commit 9e09109
Show file tree
Hide file tree
Showing 19 changed files with 278 additions and 162 deletions.
5 changes: 3 additions & 2 deletions broadcast.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,12 @@ func newBroadcastServer(logger *slog.Logger, withMetrics bool) *broadcastServer
if withMetrics {
//m = broadcast.NewMetric()
}
return &broadcastServer{
manager: broadcast.NewBroadcastManager(logger, m, createClient),
srv := &broadcastServer{
logger: logger,
metrics: m,
}
srv.manager = broadcast.NewBroadcastManager(logger, m, createClient, srv.canceler)
return srv
}

//func newBroadcastServer(logger *slog.Logger, withMetrics bool) *broadcastServer {
Expand Down
35 changes: 26 additions & 9 deletions broadcast/manager.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package broadcast

import (
"context"
"errors"
"log/slog"
"time"
Expand All @@ -10,9 +11,10 @@ import (
)

type Manager interface {
Process(Content) error
Process(Content) (error, context.Context)
Broadcast(uint64, protoreflect.ProtoMessage, string, ...BroadcastOptions)
SendToClient(uint64, protoreflect.ProtoMessage, error)
Cancel(uint64, []string)
NewBroadcastID() uint64
AddAddr(id uint32, addr string)
AddHandler(method string, handler any)
Expand All @@ -28,9 +30,9 @@ type manager struct {
logger *slog.Logger
}

func NewBroadcastManager(logger *slog.Logger, m *Metric, createClient func(addr string, dialOpts []grpc.DialOption) (*Client, error)) Manager {
func NewBroadcastManager(logger *slog.Logger, m *Metric, createClient func(addr string, dialOpts []grpc.DialOption) (*Client, error), canceler func(broadcastID uint64, srvAddrs []string)) Manager {
state := NewState(logger, m)
router := NewRouter(logger, m, createClient, state)
router := NewRouter(logger, m, createClient, state, canceler)
state.RunShards(router)
return &manager{
state: state,
Expand All @@ -40,24 +42,24 @@ func NewBroadcastManager(logger *slog.Logger, m *Metric, createClient func(addr
}
}

func (mgr *manager) Process(msg Content) error {
func (mgr *manager) Process(msg Content) (error, context.Context) {
_, shardID, _, _ := DecodeBroadcastID(msg.BroadcastID)
shardID = shardID % NumShards
shard := mgr.state.shards[shardID]

// we only need a single response
receiveChan := make(chan error, 1)
receiveChan := make(chan shardResponse, 1)
msg.ReceiveChan = receiveChan
select {
case <-shard.ctx.Done():
return errors.New("shard is down")
return errors.New("shard is down"), nil
case shard.sendChan <- msg:
}
select {
case <-shard.ctx.Done():
return errors.New("shard is down")
case err := <-receiveChan:
return err
return errors.New("shard is down"), nil
case resp := <-receiveChan:
return resp.err, resp.reqCtx
}
}

Expand Down Expand Up @@ -96,6 +98,21 @@ func (mgr *manager) SendToClient(broadcastID uint64, resp protoreflect.ProtoMess
}
}

func (mgr *manager) Cancel(broadcastID uint64, srvAddrs []string) {
_, shardID, _, _ := DecodeBroadcastID(broadcastID)
shardID = shardID % NumShards
shard := mgr.state.shards[shardID]
select {
case shard.broadcastChan <- Msg{
Cancellation: &cancellation{
srvAddrs: srvAddrs,
},
BroadcastID: broadcastID,
}:
case <-shard.ctx.Done():
}
}

func (mgr *manager) NewBroadcastID() uint64 {
return mgr.state.snowflake.NewBroadcastID()
}
Expand Down
49 changes: 32 additions & 17 deletions broadcast/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ import (
)

type BroadcastRequest struct {
broadcastChan chan Msg
sendChan chan Content
ctx context.Context
cancelFunc context.CancelFunc
started time.Time
ended time.Time
//sync.Once
broadcastChan chan Msg
sendChan chan Content
ctx context.Context
cancelFunc context.CancelFunc
started time.Time
ended time.Time
cancellationCtx context.Context
cancellationCtxCancel context.CancelFunc // should only be called by the shard
}

// func (req *BroadcastRequest) handle(router *BroadcastRouter, broadcastID uint64, msg Content, metrics *Metric) {
Expand All @@ -35,6 +36,7 @@ func (req *BroadcastRequest) handle(router Router, broadcastID uint64, msg Conte
//metrics.RemoveGoroutine(broadcastID, "req")
//}
req.cancelFunc()
req.cancellationCtxCancel()
}()
for {
select {
Expand All @@ -47,6 +49,10 @@ func (req *BroadcastRequest) handle(router Router, broadcastID uint64, msg Conte
if broadcastID != bMsg.BroadcastID {
continue
}
if bMsg.Cancellation != nil {
_ = router.Send(broadcastID, "", "", bMsg.Cancellation)
return
}
if bMsg.Broadcast {
// check if msg has already been broadcasted for this method
if alreadyBroadcasted(methods, bMsg.Method) {
Expand All @@ -67,11 +73,6 @@ func (req *BroadcastRequest) handle(router Router, broadcastID uint64, msg Conte
return
}
// QuorumCall if origin addr is empty.
//if !msg.hasReceivedClientRequest() {
//// Has not received client request
//continue
//}
// Has received client request
err := msg.send(bMsg.Reply.Response, bMsg.Reply.Err)
if err != nil {
// add response if not already done
Expand All @@ -89,7 +90,15 @@ func (req *BroadcastRequest) handle(router Router, broadcastID uint64, msg Conte
}
case new := <-req.sendChan:
if new.BroadcastID != broadcastID {
new.ReceiveChan <- BroadcastIDErr{}
new.ReceiveChan <- shardResponse{
err: BroadcastIDErr{},
}
continue
}
if msg.IsCancellation {
new.ReceiveChan <- shardResponse{
err: nil,
}
continue
}
if msg.OriginAddr == "" && new.OriginAddr != "" {
Expand All @@ -107,15 +116,21 @@ func (req *BroadcastRequest) handle(router Router, broadcastID uint64, msg Conte
if sent && !msg.isBroadcastCall() {
err := msg.send(respMsg, respErr)
if err != nil {
new.ReceiveChan <- err
// should return here?
new.ReceiveChan <- shardResponse{
err: err,
}
return
}
//new.ReceiveChan <- errors.New("req is done and should be returned immediately to client")
new.ReceiveChan <- AlreadyProcessedErr{}
new.ReceiveChan <- shardResponse{
err: AlreadyProcessedErr{},
}
return
}
new.ReceiveChan <- nil
new.ReceiveChan <- shardResponse{
err: nil,
reqCtx: req.cancellationCtx,
}
}
}
}
Expand Down
68 changes: 38 additions & 30 deletions broadcast/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,23 @@ func TestHandleBroadcastOption(t *testing.T) {
in: Content{
BroadcastID: broadcastID,
IsBroadcastClient: false,
ReceiveChan: make(chan error),
ReceiveChan: make(chan shardResponse),
},
out: nil,
},
{
in: Content{
BroadcastID: snowflake.NewBroadcastID(),
IsBroadcastClient: false,
ReceiveChan: make(chan error),
ReceiveChan: make(chan shardResponse),
},
out: BroadcastIDErr{},
},
{
in: Content{
BroadcastID: broadcastID,
IsBroadcastClient: false,
ReceiveChan: make(chan error),
ReceiveChan: make(chan shardResponse),
},
out: nil,
},
Expand All @@ -72,29 +72,33 @@ func TestHandleBroadcastOption(t *testing.T) {
msg := Content{
BroadcastID: broadcastID,
OriginMethod: "testMethod",
ReceiveChan: make(chan error),
ReceiveChan: make(chan shardResponse),
}

router := &mockRouter{
returnError: false,
}

cancelCtx, cancelCancel := context.WithTimeout(context.Background(), 1*time.Minute)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()
defer cancelCancel()
req := &BroadcastRequest{
ctx: ctx,
cancelFunc: cancel,
sendChan: make(chan Content),
broadcastChan: make(chan Msg, 5),
started: time.Now(),
ctx: ctx,
cancelFunc: cancel,
sendChan: make(chan Content),
broadcastChan: make(chan Msg, 5),
started: time.Now(),
cancellationCtx: cancelCtx,
cancellationCtxCancel: cancelCancel,
}
go req.handle(router, msg.BroadcastID, msg)

for _, tt := range tests {
req.sendChan <- tt.in
err := <-tt.in.ReceiveChan
if err != tt.out {
t.Fatalf("wrong error returned.\n\tgot: %v, want: %v", tt.out, err)
resp := <-tt.in.ReceiveChan
if resp.err != tt.out {
t.Fatalf("wrong error returned.\n\tgot: %v, want: %v", tt.out, resp.err)
}
}

Expand Down Expand Up @@ -122,13 +126,13 @@ func TestHandleBroadcastOption(t *testing.T) {
BroadcastID: broadcastID,
IsBroadcastClient: true,
SendFn: func(resp protoreflect.ProtoMessage, err error) {},
ReceiveChan: make(chan error),
ReceiveChan: make(chan shardResponse),
}
req.sendChan <- clientMsg
err := <-clientMsg.ReceiveChan
resp := <-clientMsg.ReceiveChan
expectedErr := AlreadyProcessedErr{}
if err != expectedErr {
t.Fatalf("wrong error returned.\n\tgot: %v, want: %v", err, expectedErr)
if resp.err != expectedErr {
t.Fatalf("wrong error returned.\n\tgot: %v, want: %v", resp.err, expectedErr)
}

select {
Expand All @@ -150,23 +154,23 @@ func TestHandleBroadcastCall(t *testing.T) {
in: Content{
BroadcastID: broadcastID,
IsBroadcastClient: false,
ReceiveChan: make(chan error, 1),
ReceiveChan: make(chan shardResponse, 1),
},
out: nil,
},
{
in: Content{
BroadcastID: snowflake.NewBroadcastID(),
IsBroadcastClient: false,
ReceiveChan: make(chan error, 1),
ReceiveChan: make(chan shardResponse, 1),
},
out: BroadcastIDErr{},
},
{
in: Content{
BroadcastID: broadcastID,
IsBroadcastClient: false,
ReceiveChan: make(chan error, 1),
ReceiveChan: make(chan shardResponse, 1),
},
out: nil,
},
Expand All @@ -177,29 +181,33 @@ func TestHandleBroadcastCall(t *testing.T) {
IsBroadcastClient: false,
OriginAddr: "127.0.0.1:8080",
OriginMethod: "testMethod",
ReceiveChan: make(chan error),
ReceiveChan: make(chan shardResponse),
}

router := &mockRouter{
returnError: false,
}

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
cancelCtx, cancelCancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()
defer cancelCancel()
req := &BroadcastRequest{
ctx: ctx,
cancelFunc: cancel,
sendChan: make(chan Content),
broadcastChan: make(chan Msg, 5),
started: time.Now(),
ctx: ctx,
cancelFunc: cancel,
sendChan: make(chan Content),
broadcastChan: make(chan Msg, 5),
started: time.Now(),
cancellationCtx: cancelCtx,
cancellationCtxCancel: cancelCancel,
}
go req.handle(router, msg.BroadcastID, msg)

for _, tt := range tests {
req.sendChan <- tt.in
err := <-tt.in.ReceiveChan
if err != tt.out {
t.Fatalf("wrong error returned.\n\tgot: %v, want: %v", tt.out, err)
resp := <-tt.in.ReceiveChan
if resp.err != tt.out {
t.Fatalf("wrong error returned.\n\tgot: %v, want: %v", tt.out, resp.err)
}
}

Expand Down Expand Up @@ -228,7 +236,7 @@ func TestHandleBroadcastCall(t *testing.T) {
IsBroadcastClient: true,
OriginAddr: "127.0.0.1:8080",
OriginMethod: "testMethod",
ReceiveChan: make(chan error),
ReceiveChan: make(chan shardResponse),
}
select {
case <-req.ctx.Done():
Expand Down Expand Up @@ -263,7 +271,7 @@ func BenchmarkHandle(b *testing.B) {
IsBroadcastClient: true,
SendFn: sendFn,
OriginMethod: originMethod,
ReceiveChan: make(chan error, 1),
ReceiveChan: make(chan shardResponse, 1),
}

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
Expand Down
Loading

0 comments on commit 9e09109

Please sign in to comment.