Skip to content

Commit

Permalink
Merge pull request #106 from jpmcb/unflake-aws-tests
Browse files Browse the repository at this point in the history
Remove race conditions in tests that execute `go` routines
  • Loading branch information
jpmcb authored Sep 23, 2022
2 parents cdaf5f7 + 9acd736 commit 20b98bf
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
32 changes: 32 additions & 0 deletions updater/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"strconv"
"sync"
"testing"

"github.com/aws/aws-sdk-go/aws"
Expand Down Expand Up @@ -62,6 +63,10 @@ func TestFilterAvailableUpdates(t *testing.T) {
"inst-id-4": `{"update_state": "Staged", "active_partition": { "image": { "version": "v1.1.1"}}}`,
"inst-id-5": `{"update_state": "Available", "active_partition": { "image": { "version": "v1.0.5"}}}`,
}

// mutex needed to prevent race condition when incrementing counter in concurrent
// execution of WaitUntilCommandExecutedWithContextFn
var m sync.Mutex
sendCommandCalls := 0
commandWaiterCalls := 0
getCommandInvocationCalls := 0
Expand All @@ -83,7 +88,9 @@ func TestFilterAvailableUpdates(t *testing.T) {
}, nil
},
WaitUntilCommandExecutedWithContextFn: func(ctx aws.Context, input *ssm.GetCommandInvocationInput, opts ...request.WaiterOption) error {
m.Lock()
commandWaiterCalls++
m.Unlock()
assert.Equal(t, "command-id", aws.StringValue(input.CommandId))
return nil
},
Expand Down Expand Up @@ -120,6 +127,9 @@ func TestPaginatedFilterAvailableUpdatesSuccess(t *testing.T) {
})
}

// mutex needed to prevent race condition when incrementing counter in concurrent
// execution of WaitUntilCommandExecutedWithContextFn
var m sync.Mutex
sendCommandCalls := 0
commandWaiterCalls := 0
getCommandInvocationCalls := 0
Expand All @@ -138,7 +148,9 @@ func TestPaginatedFilterAvailableUpdatesSuccess(t *testing.T) {
}, nil
},
WaitUntilCommandExecutedWithContextFn: func(ctx aws.Context, input *ssm.GetCommandInvocationInput, opts ...request.WaiterOption) error {
m.Lock()
commandWaiterCalls++
m.Unlock()
assert.Equal(t, "command-id", aws.StringValue(input.CommandId))
return nil
},
Expand Down Expand Up @@ -191,6 +203,9 @@ func TestPaginatedFilterAvailableUpdatesInPageFailures(t *testing.T) {
})
}

// mutex needed to prevent race condition when incrementing counter in concurrent
// execution of WaitUntilCommandExecutedWithContextFn
var m sync.Mutex
sendCommandCalls := 0
commandWaiterCalls := 0
getCommandInvocationCalls := 0
Expand Down Expand Up @@ -226,7 +241,9 @@ func TestPaginatedFilterAvailableUpdatesInPageFailures(t *testing.T) {
},
WaitUntilCommandExecutedWithContextFn: func(ctx aws.Context, input *ssm.GetCommandInvocationInput, opts ...request.WaiterOption) error {
assert.Equal(t, "command-id", aws.StringValue(input.CommandId))
m.Lock()
commandWaiterCalls++
m.Unlock()
return nil
},
}
Expand Down Expand Up @@ -264,6 +281,9 @@ func TestPaginatedFilterAvailableUpdatesSingleErr(t *testing.T) {

pageErrors := []error{errors.New("Failed to send document"), nil}

// mutex needed to prevent race condition when incrementing counter in concurrent
// execution of WaitUntilCommandExecutedWithContextFn
var m sync.Mutex
sendCommandCalls := 0
commandWaiterCalls := 0
getCommandInvocationCalls := 0
Expand All @@ -287,7 +307,9 @@ func TestPaginatedFilterAvailableUpdatesSingleErr(t *testing.T) {
},
WaitUntilCommandExecutedWithContextFn: func(ctx aws.Context, input *ssm.GetCommandInvocationInput, opts ...request.WaiterOption) error {
assert.Equal(t, "command-id", aws.StringValue(input.CommandId))
m.Lock()
commandWaiterCalls++
m.Unlock()
return nil
},
}
Expand Down Expand Up @@ -358,6 +380,9 @@ func TestGetCommandResult(t *testing.T) {

func TestSendCommandSuccess(t *testing.T) {
instances := []string{"inst-id-1", "inst-id-2"}
// mutex needed to prevent race condition when appending to instances slice in concurrent
// execution of WaitUntilCommandExecutedWithContextFn
var m sync.Mutex
waitInstanceIDs := []string{}
mockSSM := MockSSM{
SendCommandFn: func(input *ssm.SendCommandInput) (*ssm.SendCommandOutput, error) {
Expand All @@ -368,7 +393,9 @@ func TestSendCommandSuccess(t *testing.T) {
},
WaitUntilCommandExecutedWithContextFn: func(ctx aws.Context, input *ssm.GetCommandInvocationInput, opts ...request.WaiterOption) error {
assert.Equal(t, "command-id", aws.StringValue(input.CommandId))
m.Lock()
waitInstanceIDs = append(waitInstanceIDs, aws.StringValue(input.InstanceId))
m.Unlock()
return nil
},
}
Expand Down Expand Up @@ -479,12 +506,17 @@ func TestSendCommandWaitSuccess(t *testing.T) {
})
t.Run("wait all success", func(t *testing.T) {
instances := []string{"inst-id-1", "inst-id-2"}
// mutex needed to prevent race condition when appending to instances slice in concurrent
// execution of WaitUntilCommandExecutedWithContextFn
var m sync.Mutex
waitInstanceIDs := []string{}
mockSSM := MockSSM{
SendCommandFn: mockSendCommand,
WaitUntilCommandExecutedWithContextFn: func(ctx aws.Context, input *ssm.GetCommandInvocationInput, opts ...request.WaiterOption) error {
assert.Equal(t, "command-id", aws.StringValue(input.CommandId))
m.Lock()
waitInstanceIDs = append(waitInstanceIDs, aws.StringValue(input.InstanceId))
m.Unlock()
return nil
},
}
Expand Down
2 changes: 2 additions & 0 deletions updater/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ type MockECS struct {
var _ ECSAPI = (*MockECS)(nil)

type MockSSM struct {
// WaitUntilCommandExecutedWithContextFn is executed concurrently through
// ECS code paths and tests should treat any data in a parallel safe manner
WaitUntilCommandExecutedWithContextFn func(ctx aws.Context, input *ssm.GetCommandInvocationInput, opts ...request.WaiterOption) error
SendCommandFn func(input *ssm.SendCommandInput) (*ssm.SendCommandOutput, error)
GetCommandInvocationFn func(input *ssm.GetCommandInvocationInput) (*ssm.GetCommandInvocationOutput, error)
Expand Down

0 comments on commit 20b98bf

Please sign in to comment.