Skip to content

Commit

Permalink
test(broadcast): added tests for shard and broadcast request
Browse files Browse the repository at this point in the history
  • Loading branch information
aleksander-vedvik committed May 1, 2024
1 parent 5778dee commit bafd2ce
Show file tree
Hide file tree
Showing 4 changed files with 451 additions and 9 deletions.
5 changes: 3 additions & 2 deletions broadcast/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func NewBroadcastManager(logger *slog.Logger, m *Metric, createClient func(addr
router := NewRouter(logger, m, createClient)
state := NewState(logger, m)
for _, shard := range state.shards {
go shard.run(router, state.reqTTL, state.sendBuffer, m)
go shard.run(router, state.reqTTL, state.sendBuffer)
}
return &broadcastManager{
state: state,
Expand All @@ -47,7 +47,8 @@ func (mgr *broadcastManager) Process(msg Content) error {
shardID = shardID % NumShards
shard := mgr.state.shards[shardID]

receiveChan := make(chan error)
// we only need a single response
receiveChan := make(chan error, 1)
msg.ReceiveChan = receiveChan
select {
case <-shard.ctx.Done():
Expand Down
154 changes: 150 additions & 4 deletions broadcast/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ import (
"google.golang.org/protobuf/reflect/protoreflect"
)

type mockReq struct{}
type mockResp struct{}

func (mockReq) ProtoReflect() protoreflect.Message {
func (mockResp) ProtoReflect() protoreflect.Message {
return nil
}

Expand All @@ -35,7 +35,7 @@ func (r *mockRouter) Send(broadcastID uint64, addr, method string, req any) erro
return nil
}

func TestHandle(t *testing.T) {
func TestHandleBroadcastOption(t *testing.T) {
snowflake := NewSnowflake("127.0.0.1:8080")
broadcastID := snowflake.NewBroadcastID()

Expand Down Expand Up @@ -106,7 +106,7 @@ func TestHandle(t *testing.T) {

req.broadcastChan <- Msg{
Reply: &reply{
Response: mockReq{},
Response: mockResp{},
Err: nil,
},
BroadcastID: broadcastID,
Expand Down Expand Up @@ -137,3 +137,149 @@ func TestHandle(t *testing.T) {
case <-req.ctx.Done():
}
}

func TestHandleBroadcastCall(t *testing.T) {
snowflake := NewSnowflake("127.0.0.1:8080")
broadcastID := snowflake.NewBroadcastID()

var tests = []struct {
in Content
out error
}{
{
in: Content{
BroadcastID: broadcastID,
IsBroadcastClient: false,
ReceiveChan: make(chan error, 1),
},
out: nil,
},
{
in: Content{
BroadcastID: snowflake.NewBroadcastID(),
IsBroadcastClient: false,
ReceiveChan: make(chan error, 1),
},
out: BroadcastIDErr{},
},
{
in: Content{
BroadcastID: broadcastID,
IsBroadcastClient: false,
ReceiveChan: make(chan error, 1),
},
out: nil,
},
}

msg := Content{
BroadcastID: broadcastID,
IsBroadcastClient: false,
OriginAddr: "127.0.0.1:8080",
OriginMethod: "testMethod",
ReceiveChan: make(chan error),
}

router := &mockRouter{
returnError: false,
}

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()
req := &BroadcastRequest{
ctx: ctx,
cancelFunc: cancel,
sendChan: make(chan Content),
broadcastChan: make(chan Msg, 5),
started: time.Now(),
}
go req.handle(router, msg.BroadcastID, msg)

for _, tt := range tests {
req.sendChan <- tt.in
err := <-tt.in.ReceiveChan
if err != tt.out {
t.Fatalf("wrong error returned.\n\tgot: %v, want: %v", tt.out, err)
}
}

select {
case <-time.After(100 * time.Millisecond):
case <-req.ctx.Done():
t.Fatalf("the request is not done yet. SendToClient has not been called.")
}

req.broadcastChan <- Msg{
Reply: &reply{
Response: mockResp{},
Err: nil,
},
BroadcastID: broadcastID,
}

select {
case <-time.After(1 * time.Second):
t.Fatalf("the request is done. SendToClient has been called and this is a BroadcastCall, meaning it should respond regardless of the client request.")
case <-req.ctx.Done():
}

clientMsg := Content{
BroadcastID: broadcastID,
IsBroadcastClient: true,
OriginAddr: "127.0.0.1:8080",
OriginMethod: "testMethod",
ReceiveChan: make(chan error),
}
select {
case <-req.ctx.Done():
case req.sendChan <- clientMsg:
t.Fatalf("the request is done. SendToClient has been called so this message should be dropped.")
}
}

func BenchmarkHandle(b *testing.B) {
snowflake := NewSnowflake("127.0.0.1:8080")
originMethod := "testMethod"
router := &mockRouter{
returnError: false,
}
// not important to use unique broadcastID because we are
// not using shards in this test
broadcastID := snowflake.NewBroadcastID()
resp := Msg{
Reply: &reply{
Response: mockResp{},
Err: nil,
},
BroadcastID: broadcastID,
}
sendFn := func(resp protoreflect.ProtoMessage, err error) {}

b.ResetTimer()
b.Run("RequestHandler", func(b *testing.B) {
for i := 0; i < b.N; i++ {
msg := Content{
BroadcastID: broadcastID,
IsBroadcastClient: true,
SendFn: sendFn,
OriginMethod: originMethod,
ReceiveChan: make(chan error, 1),
}

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
req := &BroadcastRequest{
ctx: ctx,
cancelFunc: cancel,
sendChan: make(chan Content),
broadcastChan: make(chan Msg, 5),
started: time.Now(),
}
go req.handle(router, msg.BroadcastID, msg)

req.broadcastChan <- resp

<-req.ctx.Done()
cancel()
}
})
}
14 changes: 11 additions & 3 deletions broadcast/shard.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package broadcast
import (
"context"
"errors"
"fmt"
"time"
)

Expand Down Expand Up @@ -50,7 +49,7 @@ func createShards(ctx context.Context, shardBuffer int) []*shard {
return shards
}

func (s *shard) run(router Router, reqTTL time.Duration, sendBuffer int, metrics *Metric) {
func (s *shard) run(router Router, reqTTL time.Duration, sendBuffer int) {
for {
select {
case <-s.ctx.Done():
Expand All @@ -62,17 +61,26 @@ func (s *shard) run(router Router, reqTTL time.Duration, sendBuffer int, metrics
//metrics.AddShardDistribution(s.id)
//}
if req, ok := s.reqs[msg.BroadcastID]; ok {
// must check if the req is done first to prevent
// unecessarily running the server handler
select {
case <-req.ctx.Done():
s.metrics.droppedMsgs++
msg.ReceiveChan <- AlreadyProcessedErr{}
default:
}
if !msg.IsBroadcastClient {
// no need to send it to the broadcast request goroutine.
// the first request should contain all info needed
// except for the routing info given in the client req.
msg.ReceiveChan <- nil
continue
}
// must check if the req is done to prevent deadlock
select {
case <-req.ctx.Done():
s.metrics.droppedMsgs++
msg.ReceiveChan <- fmt.Errorf("req is done. broadcastID: %v", msg.BroadcastID)
msg.ReceiveChan <- AlreadyProcessedErr{}
case req.sendChan <- msg:
}
} else {
Expand Down
Loading

0 comments on commit bafd2ce

Please sign in to comment.