From 31b1387ae46ac252037c7e88992f52754a73639a Mon Sep 17 00:00:00 2001 From: Seng Lin Shee Date: Thu, 28 May 2020 01:12:33 +0000 Subject: [PATCH] Create default SG for Task Run and Service Create - Added SetDefaultSecurityGroupID call to check for default security groups for run task and create service. - Run task and create service creates default security group if default security group does not exist. - Similar function in vpc_operation.go is refactored into EC2 client and shared with service_create.go and task_run.go. - Added console logs in EC2 client, service_create.go, task_run.go. - Unit tests for refactored code --- cmd/lb_create_test.go | 49 +--------- cmd/service_create.go | 61 ++++++++----- cmd/task_run.go | 25 +++++- cmd/vpc_operation.go | 20 +---- cmd/vpc_operation_test.go | 44 ++------- ec2/main.go | 1 + ec2/mock/client/client.go | 15 ++++ ec2/vpc.go | 32 +++++++ ec2/vpc_test.go | 182 +++++++++++++++++++++++++++++++++++++- 9 files changed, 300 insertions(+), 129 deletions(-) diff --git a/cmd/lb_create_test.go b/cmd/lb_create_test.go index 449ed1f..5c7d07c 100644 --- a/cmd/lb_create_test.go +++ b/cmd/lb_create_test.go @@ -5,13 +5,13 @@ import ( "reflect" "testing" - "github.com/golang/mock/gomock" "github.com/awslabs/fargatecli/acm" acmclient "github.com/awslabs/fargatecli/acm/mock/client" "github.com/awslabs/fargatecli/cmd/mock" ec2client "github.com/awslabs/fargatecli/ec2/mock/client" "github.com/awslabs/fargatecli/elbv2" elbv2client "github.com/awslabs/fargatecli/elbv2/mock/client" + "github.com/golang/mock/gomock" ) var ( @@ -756,48 +756,7 @@ func TestNewLBCreateOperationDefaults(t *testing.T) { mockEC2.EXPECT().GetSubnetVPCID("subnet-1234567").Return("vpc-1234567", nil) mockEC2.EXPECT().GetDefaultSubnetIDs().Return([]string{"subnet-1234567", "subnet-abcdef"}, nil) - mockEC2.EXPECT().GetDefaultSecurityGroupID().Return("sg-abcdef", nil) - - o, errs := newLBCreateOperation( - "web", - "internet-facing", - []string{}, - []string{"80"}, - []string{}, - []string{}, - mockOutput, - mockACM, - mockEC2, - mockELBV2, - ) - - if len(errs) > 0 { - t.Fatalf("expected no error, got: %v", errs) - } - - if o.securityGroupIDs[0] != "sg-abcdef" { - t.Errorf("expected security group ID == sg-abcdef, got: %v", o.securityGroupIDs) - } - - if o.subnetIDs[0] != "subnet-1234567" || o.subnetIDs[1] != "subnet-abcdef" { - t.Errorf("expected subnet ID == subnet-1234567, subnet-abcdef, got: %v", o.subnetIDs) - } -} - -func TestNewLBCreateOperationDefaultsWithSGCreate(t *testing.T) { - mockOutput := &mock.Output{} - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - mockEC2 := ec2client.NewMockClient(mockCtrl) - mockACM := acmclient.NewMockClient(mockCtrl) - mockELBV2 := elbv2client.NewMockClient(mockCtrl) - - mockEC2.EXPECT().GetSubnetVPCID("subnet-1234567").Return("vpc-1234567", nil) - mockEC2.EXPECT().GetDefaultSubnetIDs().Return([]string{"subnet-1234567", "subnet-abcdef"}, nil) - mockEC2.EXPECT().GetDefaultSecurityGroupID().Return("", nil) - mockEC2.EXPECT().CreateDefaultSecurityGroup().Return("sg-abcdef", nil) - mockEC2.EXPECT().AuthorizeAllSecurityGroupIngress("sg-abcdef").Return(nil) + mockEC2.EXPECT().SetDefaultSecurityGroupID().Return("sg-abcdef", nil) //SGCreate fallback is tested in vpc_test.go o, errs := newLBCreateOperation( "web", @@ -985,7 +944,7 @@ func TestNewLBCreateOperationUseDefaultSG(t *testing.T) { defer mockCtrl.Finish() ec2 := ec2client.NewMockClient(mockCtrl) - ec2.EXPECT().GetDefaultSecurityGroupID().Return("sg-1234567", nil) + ec2.EXPECT().SetDefaultSecurityGroupID().Return("sg-1234567", nil) ec2.EXPECT().GetSubnetVPCID(gomock.Any()).Return("vpc-1234567", nil) o, err := newLBCreateOperation( @@ -1016,7 +975,7 @@ func TestNewLBCreateOperationDefaultSGError(t *testing.T) { defer mockCtrl.Finish() ec2 := ec2client.NewMockClient(mockCtrl) - ec2.EXPECT().GetDefaultSecurityGroupID().Return("", errors.New("boom")) + ec2.EXPECT().SetDefaultSecurityGroupID().Return("", errors.New("boom")) ec2.EXPECT().GetSubnetVPCID(gomock.Any()).Return("vpc-1234567", nil) _, errs := newLBCreateOperation( diff --git a/cmd/service_create.go b/cmd/service_create.go index 9b070f2..8197960 100644 --- a/cmd/service_create.go +++ b/cmd/service_create.go @@ -20,21 +20,22 @@ import ( const typeService = "service" type ServiceCreateOperation struct { - Cpu string - EnvVars []ECS.EnvVar - Image string - LoadBalancerArn string - LoadBalancerName string - Memory string - Num int64 - Port Port - Rules []ELBV2.Rule - SecurityGroupIds []string - ServiceName string - SubnetIds []string - TaskRole string - TaskCommand []string - AssignPublicIPEnabled bool + Cpu string + EnvVars []ECS.EnvVar + Image string + LoadBalancerArn string + LoadBalancerName string + Memory string + Num int64 + Port Port + Rules []ELBV2.Rule + SecurityGroupIds []string + ServiceName string + SubnetIds []string + TaskRole string + TaskCommand []string + AssignPublicIPEnabled bool + output Output } func (o *ServiceCreateOperation) SetPort(inputPort string) { @@ -149,7 +150,7 @@ var ( flagServiceCreateSecurityGroupIds []string flagServiceCreateSubnetIds []string flagServiceCreateTaskRole string - flagServiceCreateTaskCommand []string + flagServiceCreateTaskCommand []string flagServiceAssignPublicIP bool ) @@ -241,8 +242,9 @@ Services can be configured to have only private ip address via the ServiceName: args[0], SubnetIds: flagServiceCreateSubnetIds, TaskRole: flagServiceCreateTaskRole, - TaskCommand: flagServiceCreateTaskCommand, + TaskCommand: flagServiceCreateTaskCommand, AssignPublicIPEnabled: flagServiceAssignPublicIP, + output: output, } if flagServiceCreatePort != "" { @@ -262,7 +264,12 @@ Services can be configured to have only private ip address via the } operation.Validate() - createService(operation) + errs := createService(operation) + + if len(errs) > 0 { + output.Fatals(errs, "Errors found while executing [COMMAND=service Action=create]") + return + } }, } @@ -279,11 +286,11 @@ func init() { serviceCreateCmd.Flags().StringSliceVar(&flagServiceCreateSubnetIds, "subnet-id", []string{}, "ID of a subnet in which to place the service (can be specified multiple times)") serviceCreateCmd.Flags().StringVarP(&flagServiceCreateTaskRole, "task-role", "", "", "Name or ARN of an IAM role that the service's tasks can assume") serviceCreateCmd.Flags().StringSliceVar(&flagServiceCreateTaskCommand, "task-command", []string{}, "Command to run inside container instead of the one specified in the docker image") - serviceCreateCmd.Flags().BoolVarP(&flagServiceAssignPublicIP, "assign-public-ip", "", true, "Assign public ip address") + serviceCreateCmd.Flags().BoolVarP(&flagServiceAssignPublicIP, "assign-public-ip", "", true, "Assign public ip address") serviceCmd.AddCommand(serviceCreateCmd) } -func createService(operation *ServiceCreateOperation) { +func createService(operation *ServiceCreateOperation) (errors []error) { var targetGroupArn string cwl := CWL.New(sess) @@ -296,12 +303,20 @@ func createService(operation *ServiceCreateOperation) { logGroupName := cwl.CreateLogGroup(serviceLogGroupFormat, operation.ServiceName) if len(operation.SecurityGroupIds) == 0 { - defaultSecurityGroupID, _ := ec2.GetDefaultSecurityGroupID() + operation.output.Debug("Find the default security group, and creates it if it does not exist [COMMAND=task Action=run]") + defaultSecurityGroupID, err := ec2.SetDefaultSecurityGroupID() + if err != nil { + errors = append(errors, err) + } operation.SecurityGroupIds = []string{defaultSecurityGroupID} } if len(operation.SubnetIds) == 0 { - operation.SubnetIds, _ = ec2.GetDefaultSubnetIDs() + var err error + operation.SubnetIds, err = ec2.GetDefaultSubnetIDs() + if err != nil { + errors = append(errors, err) + } } if operation.Image == "" { @@ -381,4 +396,6 @@ func createService(operation *ServiceCreateOperation) { ) console.Info("Created service %s", operation.ServiceName) + + return } diff --git a/cmd/task_run.go b/cmd/task_run.go index eecbe05..8fe66b7 100644 --- a/cmd/task_run.go +++ b/cmd/task_run.go @@ -25,6 +25,7 @@ type TaskRunOperation struct { TaskName string TaskRole string TaskCommand []string + output Output } func (o *TaskRunOperation) Validate() { @@ -125,12 +126,18 @@ the requirements of the docker CMD syntax`, TaskName: args[0], TaskRole: flagTaskRunTaskRole, TaskCommand: flagTaskRunTaskCommand, + output: output, } operation.SetEnvVars(flagTaskRunEnvVars) operation.Validate() - runTask(operation) + errs := runTask(operation) + + if len(errs) > 0 { + output.Fatals(errs, "Errors found while executing [COMMAND=task Action=run]") + return + } }, } @@ -147,7 +154,7 @@ func init() { taskCmd.AddCommand(taskRunCmd) } -func runTask(operation *TaskRunOperation) { +func runTask(operation *TaskRunOperation) (errors []error) { cwl := CWL.New(sess) ec2 := EC2.New(sess) ecr := ECR.New(sess) @@ -157,12 +164,20 @@ func runTask(operation *TaskRunOperation) { logGroupName := cwl.CreateLogGroup(taskLogGroupFormat, operation.TaskName) if len(operation.SecurityGroupIds) == 0 { - defaultSecurityGroupID, _ := ec2.GetDefaultSecurityGroupID() + operation.output.Debug("Find the default security group, and creates it if it does not exist [COMMAND=task Action=run]") + defaultSecurityGroupID, err := ec2.SetDefaultSecurityGroupID() + if err != nil { + errors = append(errors, err) + } operation.SecurityGroupIds = []string{defaultSecurityGroupID} } if len(operation.SubnetIds) == 0 { - operation.SubnetIds, _ = ec2.GetDefaultSubnetIDs() + var err error + operation.SubnetIds, err = ec2.GetDefaultSubnetIDs() + if err != nil { + errors = append(errors, err) + } } if operation.Image == "" { @@ -218,4 +233,6 @@ func runTask(operation *TaskRunOperation) { ) console.Info("Running task %s", operation.TaskName) + + return } diff --git a/cmd/vpc_operation.go b/cmd/vpc_operation.go index 5285b38..0c9094f 100644 --- a/cmd/vpc_operation.go +++ b/cmd/vpc_operation.go @@ -31,29 +31,13 @@ func (o *vpcOperation) setSecurityGroupIDs(securityGroupIDs []string) { } func (o *vpcOperation) setDefaultSecurityGroupID() error { - o.output.Debug("Finding default security group [API=ec2 Action=DescribeSecurityGroups]") - defaultSecurityGroupID, err := o.ec2.GetDefaultSecurityGroupID() + // setting of default security group id is delegated to the ec2 module + defaultSecurityGroupID, err := o.ec2.SetDefaultSecurityGroupID() if err != nil { return err } - if defaultSecurityGroupID == "" { - o.output.Debug("Creating default security group [API=ec2 Action=CreateSecurityGroup]") - defaultSecurityGroupID, err = o.ec2.CreateDefaultSecurityGroup() - - if err != nil { - return err - } - - o.output.Debug("Created default security group [ID=%s]", defaultSecurityGroupID) - - o.output.Debug("Configuring default security group [API=ec2 Action=AuthorizeSecurityGroupIngress]") - if err := o.ec2.AuthorizeAllSecurityGroupIngress(defaultSecurityGroupID); err != nil { - return err - } - } - o.securityGroupIDs = []string{defaultSecurityGroupID} return nil diff --git a/cmd/vpc_operation_test.go b/cmd/vpc_operation_test.go index 9434f21..246685e 100644 --- a/cmd/vpc_operation_test.go +++ b/cmd/vpc_operation_test.go @@ -4,9 +4,9 @@ import ( "errors" "testing" - "github.com/golang/mock/gomock" "github.com/awslabs/fargatecli/cmd/mock" ec2client "github.com/awslabs/fargatecli/ec2/mock/client" + "github.com/golang/mock/gomock" ) func TestSetSubnetIDs(t *testing.T) { @@ -74,7 +74,7 @@ func TestSetDefaultSecurityGroupID(t *testing.T) { mockEC2Client := ec2client.NewMockClient(mockCtrl) mockOutput := &mock.Output{} - mockEC2Client.EXPECT().GetDefaultSecurityGroupID().Return("sg-1234567", nil) + mockEC2Client.EXPECT().SetDefaultSecurityGroupID().Return("sg-1234567", nil) //SGCreate fallback is tested in vpc_test.go operation := vpcOperation{ ec2: mockEC2Client, @@ -103,7 +103,7 @@ func TestSetDefaultSecurityGroupIDLookupError(t *testing.T) { mockEC2Client := ec2client.NewMockClient(mockCtrl) mockOutput := &mock.Output{} - mockEC2Client.EXPECT().GetDefaultSecurityGroupID().Return("", errors.New("boom")) + mockEC2Client.EXPECT().SetDefaultSecurityGroupID().Return("", errors.New("boom")) operation := vpcOperation{ ec2: mockEC2Client, @@ -121,37 +121,6 @@ func TestSetDefaultSecurityGroupIDLookupError(t *testing.T) { } } -func TestSetDefaultSecurityGroupIDWithCreate(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - mockEC2Client := ec2client.NewMockClient(mockCtrl) - mockOutput := &mock.Output{} - - mockEC2Client.EXPECT().GetDefaultSecurityGroupID().Return("", nil) - mockEC2Client.EXPECT().CreateDefaultSecurityGroup().Return("sg-1234567", nil) - mockEC2Client.EXPECT().AuthorizeAllSecurityGroupIngress("sg-1234567").Return(nil) - - operation := vpcOperation{ - ec2: mockEC2Client, - output: mockOutput, - } - - err := operation.setDefaultSecurityGroupID() - - if err != nil { - t.Errorf("expected no error, got: %v", err) - } - - if len(operation.securityGroupIDs) != 1 { - t.Fatalf("expected 1 security group ID, got: %d", len(operation.securityGroupIDs)) - } - - if expected := "sg-1234567"; operation.securityGroupIDs[0] != expected { - t.Errorf("expected: %s, got: %s", expected, operation.securityGroupIDs[0]) - } -} - func TestSetDefaultSecurityGroupIDWithCreateError(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() @@ -159,8 +128,7 @@ func TestSetDefaultSecurityGroupIDWithCreateError(t *testing.T) { mockEC2Client := ec2client.NewMockClient(mockCtrl) mockOutput := &mock.Output{} - mockEC2Client.EXPECT().GetDefaultSecurityGroupID().Return("", nil) - mockEC2Client.EXPECT().CreateDefaultSecurityGroup().Return("", errors.New("boom")) + mockEC2Client.EXPECT().SetDefaultSecurityGroupID().Return("", errors.New("boom")) operation := vpcOperation{ ec2: mockEC2Client, @@ -185,9 +153,7 @@ func TestSetDefaultSecurityGroupIDWithAuthorizeError(t *testing.T) { mockEC2Client := ec2client.NewMockClient(mockCtrl) mockOutput := &mock.Output{} - mockEC2Client.EXPECT().GetDefaultSecurityGroupID().Return("", nil) - mockEC2Client.EXPECT().CreateDefaultSecurityGroup().Return("sg-1234567", nil) - mockEC2Client.EXPECT().AuthorizeAllSecurityGroupIngress("sg-1234567").Return(errors.New("boom")) + mockEC2Client.EXPECT().SetDefaultSecurityGroupID().Return("sg-1234567", errors.New("boom")) operation := vpcOperation{ ec2: mockEC2Client, diff --git a/ec2/main.go b/ec2/main.go index 3a8c045..f26db7f 100644 --- a/ec2/main.go +++ b/ec2/main.go @@ -16,6 +16,7 @@ type Client interface { GetDefaultSecurityGroupID() (string, error) GetDefaultSubnetIDs() ([]string, error) GetSubnetVPCID(string) (string, error) + SetDefaultSecurityGroupID() (string, error) } // SDKClient implements access to EC2 via the AWS SDK. diff --git a/ec2/mock/client/client.go b/ec2/mock/client/client.go index e7ba83d..c3cfcb5 100644 --- a/ec2/mock/client/client.go +++ b/ec2/mock/client/client.go @@ -105,3 +105,18 @@ func (mr *MockClientMockRecorder) GetSubnetVPCID(arg0 interface{}) *gomock.Call mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubnetVPCID", reflect.TypeOf((*MockClient)(nil).GetSubnetVPCID), arg0) } + +// SetDefaultSecurityGroupID mocks base method +func (m *MockClient) SetDefaultSecurityGroupID() (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetDefaultSecurityGroupID") + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SetDefaultSecurityGroupID indicates an expected call of SetDefaultSecurityGroupID +func (mr *MockClientMockRecorder) SetDefaultSecurityGroupID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDefaultSecurityGroupID", reflect.TypeOf((*MockClient)(nil).SetDefaultSecurityGroupID)) +} diff --git a/ec2/vpc.go b/ec2/vpc.go index 6882380..ed60a3f 100644 --- a/ec2/vpc.go +++ b/ec2/vpc.go @@ -6,6 +6,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" awsec2 "github.com/aws/aws-sdk-go/service/ec2" + "github.com/awslabs/fargatecli/console" ) const ( @@ -15,6 +16,32 @@ const ( defaultSecurityGroupIngressProtocol = "-1" ) +// SetDefaultSecurityGroupID tries to find the default security group, and creates it if it does not exist. +func (ec2 SDKClient) SetDefaultSecurityGroupID() (string, error) { + var defaultSecurityGroupID string + defaultSecurityGroupID, err := ec2.GetDefaultSecurityGroupID() + + if err != nil { + return defaultSecurityGroupID, err + } + + if defaultSecurityGroupID == "" { + defaultSecurityGroupID, err = ec2.CreateDefaultSecurityGroup() + + if err != nil { + return defaultSecurityGroupID, err + } + + console.Debug("Created default security group [ID=%s]", defaultSecurityGroupID) + + if err := ec2.AuthorizeAllSecurityGroupIngress(defaultSecurityGroupID); err != nil { + return defaultSecurityGroupID, err + } + } + + return defaultSecurityGroupID, nil +} + // GetDefaultSubnetIDs finds and returns the subnet IDs marked as default. func (ec2 SDKClient) GetDefaultSubnetIDs() ([]string, error) { var subnetIDs []string @@ -24,6 +51,7 @@ func (ec2 SDKClient) GetDefaultSubnetIDs() ([]string, error) { Values: aws.StringSlice([]string{"true"}), } + console.Debug("Retrieving subnet information [API=ec2 Action=DescribeSubnets]") resp, err := ec2.client.DescribeSubnets( &awsec2.DescribeSubnetsInput{ Filters: []*awsec2.Filter{defaultFilter}, @@ -43,6 +71,7 @@ func (ec2 SDKClient) GetDefaultSubnetIDs() ([]string, error) { // GetDefaultSecurityGroupID returns the ID of the permissive security group created by default. func (ec2 SDKClient) GetDefaultSecurityGroupID() (string, error) { + console.Debug("Retrieving security group information [API=ec2 Action=DescribeSecurityGroups]") resp, err := ec2.client.DescribeSecurityGroups( &awsec2.DescribeSecurityGroupsInput{ GroupNames: aws.StringSlice([]string{defaultSecurityGroupName}), @@ -64,6 +93,7 @@ func (ec2 SDKClient) GetDefaultSecurityGroupID() (string, error) { // GetSubnetVPCID returns the VPC ID for a given subnet ID. func (ec2 SDKClient) GetSubnetVPCID(subnetID string) (string, error) { + console.Debug("Retrieving subnet information [API=ec2 Action=DescribeSubnets]") resp, err := ec2.client.DescribeSubnets( &awsec2.DescribeSubnetsInput{ SubnetIds: aws.StringSlice([]string{subnetID}), @@ -82,6 +112,7 @@ func (ec2 SDKClient) GetSubnetVPCID(subnetID string) (string, error) { // CreateDefaultSecurityGroup creates a new security group for use as the default. func (ec2 SDKClient) CreateDefaultSecurityGroup() (string, error) { + console.Debug("Creating security group [API=ec2 Action=CreateSecurityGroup]") resp, err := ec2.client.CreateSecurityGroup( &awsec2.CreateSecurityGroupInput{ GroupName: aws.String(defaultSecurityGroupName), @@ -98,6 +129,7 @@ func (ec2 SDKClient) CreateDefaultSecurityGroup() (string, error) { // AuthorizeAllSecurityGroupIngress configures a security group to allow all ingress traffic. func (ec2 SDKClient) AuthorizeAllSecurityGroupIngress(groupID string) error { + console.Debug("Configuring default security group [API=ec2 Action=AuthorizeSecurityGroupIngress]") _, err := ec2.client.AuthorizeSecurityGroupIngress( &awsec2.AuthorizeSecurityGroupIngressInput{ CidrIp: aws.String(defaultSecurityGroupIngressCIDR), diff --git a/ec2/vpc_test.go b/ec2/vpc_test.go index 4aa5c86..4314abc 100644 --- a/ec2/vpc_test.go +++ b/ec2/vpc_test.go @@ -7,8 +7,8 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" awsec2 "github.com/aws/aws-sdk-go/service/ec2" - "github.com/golang/mock/gomock" "github.com/awslabs/fargatecli/ec2/mock/sdk" + "github.com/golang/mock/gomock" ) func TestGetDefaultSubnetIDs(t *testing.T) { @@ -138,6 +138,186 @@ func TestGetDefaultSecurityGroupIDGroupNotFound(t *testing.T) { } } +func TestSetDefaultSecurityGroupID(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + securityGroupID := "sg-abcdef" + securityGroup := &awsec2.SecurityGroup{ + GroupId: aws.String(securityGroupID), + } + input := &awsec2.DescribeSecurityGroupsInput{ + GroupNames: aws.StringSlice([]string{"fargate-default"}), + } + output := &awsec2.DescribeSecurityGroupsOutput{ + SecurityGroups: []*awsec2.SecurityGroup{securityGroup}, + } + + mockEC2Client := sdk.NewMockEC2API(mockCtrl) + ec2 := SDKClient{client: mockEC2Client} + + mockEC2Client.EXPECT().DescribeSecurityGroups(input).Return(output, nil) + + out, err := ec2.SetDefaultSecurityGroupID() + + if err != nil { + t.Errorf("expected no error, got %v", err) + } + + if out != securityGroupID { + t.Errorf("expected %s, got %s", securityGroupID, out) + } + +} + +func TestSetDefaultSecurityGroupIDWithErrorFromDescribeSecurityGroups(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + securityGroupID := "sg-abcdef" + securityGroup := &awsec2.SecurityGroup{ + GroupId: aws.String(securityGroupID), + } + input := &awsec2.DescribeSecurityGroupsInput{ + GroupNames: aws.StringSlice([]string{"fargate-default"}), + } + output := &awsec2.DescribeSecurityGroupsOutput{ + SecurityGroups: []*awsec2.SecurityGroup{securityGroup}, + } + + mockEC2Client := sdk.NewMockEC2API(mockCtrl) + ec2 := SDKClient{client: mockEC2Client} + + mockEC2Client.EXPECT().DescribeSecurityGroups(input).Return(output, errors.New("boom")) + + out, err := ec2.SetDefaultSecurityGroupID() + + if out != "" { + t.Errorf("expected no result, got %v", out) + } + + if err == nil { + t.Errorf("expected error, got none") + } + +} + +func TestSetDefaultSecurityGroupIDWithSGCreate(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + mockEC2Client := sdk.NewMockEC2API(mockCtrl) + ec2 := SDKClient{client: mockEC2Client} + + securityGroupID := "sg-abcdef" + inputDescribeSecurityGroups := &awsec2.DescribeSecurityGroupsInput{ + GroupNames: aws.StringSlice([]string{"fargate-default"}), + } + errorDescribeSecurityGroups := awserr.New("InvalidGroup.NotFound", "Group not found", errors.New("boom")) + mockEC2Client.EXPECT().DescribeSecurityGroups(inputDescribeSecurityGroups).Return(&awsec2.DescribeSecurityGroupsOutput{}, errorDescribeSecurityGroups) + + inputCreateSecurityGroup := &awsec2.CreateSecurityGroupInput{ + GroupName: aws.String("fargate-default"), + Description: aws.String("Default Fargate CLI SG"), + } + outputCreateSecurityGroup := &awsec2.CreateSecurityGroupOutput{ + GroupId: aws.String(securityGroupID), + } + mockEC2Client.EXPECT().CreateSecurityGroup(inputCreateSecurityGroup).Return(outputCreateSecurityGroup, nil) + + inputAuthorizeSecurityGroupIngress := &awsec2.AuthorizeSecurityGroupIngressInput{ + CidrIp: aws.String("0.0.0.0/0"), + GroupId: aws.String(securityGroupID), + IpProtocol: aws.String("-1"), + } + mockEC2Client.EXPECT().AuthorizeSecurityGroupIngress(inputAuthorizeSecurityGroupIngress).Return(&awsec2.AuthorizeSecurityGroupIngressOutput{}, nil) + + out, err := ec2.SetDefaultSecurityGroupID() + + if err != nil { + t.Errorf("expected no error, got %v", err) + } + + if out != securityGroupID { + t.Errorf("expected %s, got %s", securityGroupID, out) + } +} + +func TestSetDefaultSecurityGroupIDWithSGCreateErrorFromCreateSecurityGroup(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + mockEC2Client := sdk.NewMockEC2API(mockCtrl) + ec2 := SDKClient{client: mockEC2Client} + + securityGroupID := "sg-abcdef" + inputDescribeSecurityGroups := &awsec2.DescribeSecurityGroupsInput{ + GroupNames: aws.StringSlice([]string{"fargate-default"}), + } + errorDescribeSecurityGroups := awserr.New("InvalidGroup.NotFound", "Group not found", errors.New("boom")) + mockEC2Client.EXPECT().DescribeSecurityGroups(inputDescribeSecurityGroups).Return(&awsec2.DescribeSecurityGroupsOutput{}, errorDescribeSecurityGroups) + + inputCreateSecurityGroup := &awsec2.CreateSecurityGroupInput{ + GroupName: aws.String("fargate-default"), + Description: aws.String("Default Fargate CLI SG"), + } + outputCreateSecurityGroup := &awsec2.CreateSecurityGroupOutput{ + GroupId: aws.String(securityGroupID), + } + mockEC2Client.EXPECT().CreateSecurityGroup(inputCreateSecurityGroup).Return(outputCreateSecurityGroup, errors.New("boom")) + + out, err := ec2.SetDefaultSecurityGroupID() + + if out != "" { + t.Errorf("expected no result, got %v", out) + } + + if err == nil { + t.Errorf("expected error, got none") + } +} + +func TestSetDefaultSecurityGroupIDWithSGCreateErrorFromAuthorizeSecurityGroupIngress(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + mockEC2Client := sdk.NewMockEC2API(mockCtrl) + ec2 := SDKClient{client: mockEC2Client} + + securityGroupID := "sg-abcdef" + inputDescribeSecurityGroups := &awsec2.DescribeSecurityGroupsInput{ + GroupNames: aws.StringSlice([]string{"fargate-default"}), + } + errorDescribeSecurityGroups := awserr.New("InvalidGroup.NotFound", "Group not found", errors.New("boom")) + mockEC2Client.EXPECT().DescribeSecurityGroups(inputDescribeSecurityGroups).Return(&awsec2.DescribeSecurityGroupsOutput{}, errorDescribeSecurityGroups) + + inputCreateSecurityGroup := &awsec2.CreateSecurityGroupInput{ + GroupName: aws.String("fargate-default"), + Description: aws.String("Default Fargate CLI SG"), + } + outputCreateSecurityGroup := &awsec2.CreateSecurityGroupOutput{ + GroupId: aws.String(securityGroupID), + } + mockEC2Client.EXPECT().CreateSecurityGroup(inputCreateSecurityGroup).Return(outputCreateSecurityGroup, nil) + + inputAuthorizeSecurityGroupIngress := &awsec2.AuthorizeSecurityGroupIngressInput{ + CidrIp: aws.String("0.0.0.0/0"), + GroupId: aws.String(securityGroupID), + IpProtocol: aws.String("-1"), + } + mockEC2Client.EXPECT().AuthorizeSecurityGroupIngress(inputAuthorizeSecurityGroupIngress).Return(&awsec2.AuthorizeSecurityGroupIngressOutput{}, errors.New("boom")) + + out, err := ec2.SetDefaultSecurityGroupID() + + if out != securityGroupID { + t.Errorf("expected %s, got %s", securityGroupID, out) + } + + if err == nil { + t.Errorf("expected error, got none") + } +} + func TestGetSubnetVPCID(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish()