From 9acd7367b4f87978284faf17e0fdfdc0cca4e36b Mon Sep 17 00:00:00 2001 From: John McBride Date: Thu, 22 Sep 2022 14:17:26 -0700 Subject: [PATCH] Remove race conditions in tests that execute go routines Using a mutex, this patch removes race conditions that can happen when attempting to increment numbers from functions that are executed async in go routines. Signed-off-by: John McBride --- updater/aws_test.go | 32 ++++++++++++++++++++++++++++++++ updater/mock_test.go | 2 ++ 2 files changed, 34 insertions(+) diff --git a/updater/aws_test.go b/updater/aws_test.go index 7e8fb35..91196bc 100644 --- a/updater/aws_test.go +++ b/updater/aws_test.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "strconv" + "sync" "testing" "github.com/aws/aws-sdk-go/aws" @@ -62,6 +63,10 @@ func TestFilterAvailableUpdates(t *testing.T) { "inst-id-4": `{"update_state": "Staged", "active_partition": { "image": { "version": "v1.1.1"}}}`, "inst-id-5": `{"update_state": "Available", "active_partition": { "image": { "version": "v1.0.5"}}}`, } + + // mutex needed to prevent race condition when incrementing counter in concurrent + // execution of WaitUntilCommandExecutedWithContextFn + var m sync.Mutex sendCommandCalls := 0 commandWaiterCalls := 0 getCommandInvocationCalls := 0 @@ -83,7 +88,9 @@ func TestFilterAvailableUpdates(t *testing.T) { }, nil }, WaitUntilCommandExecutedWithContextFn: func(ctx aws.Context, input *ssm.GetCommandInvocationInput, opts ...request.WaiterOption) error { + m.Lock() commandWaiterCalls++ + m.Unlock() assert.Equal(t, "command-id", aws.StringValue(input.CommandId)) return nil }, @@ -120,6 +127,9 @@ func TestPaginatedFilterAvailableUpdatesSuccess(t *testing.T) { }) } + // mutex needed to prevent race condition when incrementing counter in concurrent + // execution of WaitUntilCommandExecutedWithContextFn + var m sync.Mutex sendCommandCalls := 0 commandWaiterCalls := 0 getCommandInvocationCalls := 0 @@ -138,7 +148,9 @@ func TestPaginatedFilterAvailableUpdatesSuccess(t *testing.T) { }, nil }, WaitUntilCommandExecutedWithContextFn: func(ctx aws.Context, input *ssm.GetCommandInvocationInput, opts ...request.WaiterOption) error { + m.Lock() commandWaiterCalls++ + m.Unlock() assert.Equal(t, "command-id", aws.StringValue(input.CommandId)) return nil }, @@ -191,6 +203,9 @@ func TestPaginatedFilterAvailableUpdatesInPageFailures(t *testing.T) { }) } + // mutex needed to prevent race condition when incrementing counter in concurrent + // execution of WaitUntilCommandExecutedWithContextFn + var m sync.Mutex sendCommandCalls := 0 commandWaiterCalls := 0 getCommandInvocationCalls := 0 @@ -226,7 +241,9 @@ func TestPaginatedFilterAvailableUpdatesInPageFailures(t *testing.T) { }, WaitUntilCommandExecutedWithContextFn: func(ctx aws.Context, input *ssm.GetCommandInvocationInput, opts ...request.WaiterOption) error { assert.Equal(t, "command-id", aws.StringValue(input.CommandId)) + m.Lock() commandWaiterCalls++ + m.Unlock() return nil }, } @@ -264,6 +281,9 @@ func TestPaginatedFilterAvailableUpdatesSingleErr(t *testing.T) { pageErrors := []error{errors.New("Failed to send document"), nil} + // mutex needed to prevent race condition when incrementing counter in concurrent + // execution of WaitUntilCommandExecutedWithContextFn + var m sync.Mutex sendCommandCalls := 0 commandWaiterCalls := 0 getCommandInvocationCalls := 0 @@ -287,7 +307,9 @@ func TestPaginatedFilterAvailableUpdatesSingleErr(t *testing.T) { }, WaitUntilCommandExecutedWithContextFn: func(ctx aws.Context, input *ssm.GetCommandInvocationInput, opts ...request.WaiterOption) error { assert.Equal(t, "command-id", aws.StringValue(input.CommandId)) + m.Lock() commandWaiterCalls++ + m.Unlock() return nil }, } @@ -358,6 +380,9 @@ func TestGetCommandResult(t *testing.T) { func TestSendCommandSuccess(t *testing.T) { instances := []string{"inst-id-1", "inst-id-2"} + // mutex needed to prevent race condition when appending to instances slice in concurrent + // execution of WaitUntilCommandExecutedWithContextFn + var m sync.Mutex waitInstanceIDs := []string{} mockSSM := MockSSM{ SendCommandFn: func(input *ssm.SendCommandInput) (*ssm.SendCommandOutput, error) { @@ -368,7 +393,9 @@ func TestSendCommandSuccess(t *testing.T) { }, WaitUntilCommandExecutedWithContextFn: func(ctx aws.Context, input *ssm.GetCommandInvocationInput, opts ...request.WaiterOption) error { assert.Equal(t, "command-id", aws.StringValue(input.CommandId)) + m.Lock() waitInstanceIDs = append(waitInstanceIDs, aws.StringValue(input.InstanceId)) + m.Unlock() return nil }, } @@ -479,12 +506,17 @@ func TestSendCommandWaitSuccess(t *testing.T) { }) t.Run("wait all success", func(t *testing.T) { instances := []string{"inst-id-1", "inst-id-2"} + // mutex needed to prevent race condition when appending to instances slice in concurrent + // execution of WaitUntilCommandExecutedWithContextFn + var m sync.Mutex waitInstanceIDs := []string{} mockSSM := MockSSM{ SendCommandFn: mockSendCommand, WaitUntilCommandExecutedWithContextFn: func(ctx aws.Context, input *ssm.GetCommandInvocationInput, opts ...request.WaiterOption) error { assert.Equal(t, "command-id", aws.StringValue(input.CommandId)) + m.Lock() waitInstanceIDs = append(waitInstanceIDs, aws.StringValue(input.InstanceId)) + m.Unlock() return nil }, } diff --git a/updater/mock_test.go b/updater/mock_test.go index 83f5e76..3277710 100644 --- a/updater/mock_test.go +++ b/updater/mock_test.go @@ -20,6 +20,8 @@ type MockECS struct { var _ ECSAPI = (*MockECS)(nil) type MockSSM struct { + // WaitUntilCommandExecutedWithContextFn is executed concurrently through + // ECS code paths and tests should treat any data in a parallel safe manner WaitUntilCommandExecutedWithContextFn func(ctx aws.Context, input *ssm.GetCommandInvocationInput, opts ...request.WaiterOption) error SendCommandFn func(input *ssm.SendCommandInput) (*ssm.SendCommandOutput, error) GetCommandInvocationFn func(input *ssm.GetCommandInvocationInput) (*ssm.GetCommandInvocationOutput, error)