Skip to content
This repository has been archived by the owner on May 6, 2024. It is now read-only.

Create default SG for Task Run and Service Create #113

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 4 additions & 45 deletions cmd/lb_create_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ import (
"reflect"
"testing"

"github.com/golang/mock/gomock"
"github.com/awslabs/fargatecli/acm"
acmclient "github.com/awslabs/fargatecli/acm/mock/client"
"github.com/awslabs/fargatecli/cmd/mock"
ec2client "github.com/awslabs/fargatecli/ec2/mock/client"
"github.com/awslabs/fargatecli/elbv2"
elbv2client "github.com/awslabs/fargatecli/elbv2/mock/client"
"github.com/golang/mock/gomock"
)

var (
Expand Down Expand Up @@ -756,48 +756,7 @@ func TestNewLBCreateOperationDefaults(t *testing.T) {

mockEC2.EXPECT().GetSubnetVPCID("subnet-1234567").Return("vpc-1234567", nil)
mockEC2.EXPECT().GetDefaultSubnetIDs().Return([]string{"subnet-1234567", "subnet-abcdef"}, nil)
mockEC2.EXPECT().GetDefaultSecurityGroupID().Return("sg-abcdef", nil)

o, errs := newLBCreateOperation(
"web",
"internet-facing",
[]string{},
[]string{"80"},
[]string{},
[]string{},
mockOutput,
mockACM,
mockEC2,
mockELBV2,
)

if len(errs) > 0 {
t.Fatalf("expected no error, got: %v", errs)
}

if o.securityGroupIDs[0] != "sg-abcdef" {
t.Errorf("expected security group ID == sg-abcdef, got: %v", o.securityGroupIDs)
}

if o.subnetIDs[0] != "subnet-1234567" || o.subnetIDs[1] != "subnet-abcdef" {
t.Errorf("expected subnet ID == subnet-1234567, subnet-abcdef, got: %v", o.subnetIDs)
}
}

func TestNewLBCreateOperationDefaultsWithSGCreate(t *testing.T) {
mockOutput := &mock.Output{}
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

mockEC2 := ec2client.NewMockClient(mockCtrl)
mockACM := acmclient.NewMockClient(mockCtrl)
mockELBV2 := elbv2client.NewMockClient(mockCtrl)

mockEC2.EXPECT().GetSubnetVPCID("subnet-1234567").Return("vpc-1234567", nil)
mockEC2.EXPECT().GetDefaultSubnetIDs().Return([]string{"subnet-1234567", "subnet-abcdef"}, nil)
mockEC2.EXPECT().GetDefaultSecurityGroupID().Return("", nil)
mockEC2.EXPECT().CreateDefaultSecurityGroup().Return("sg-abcdef", nil)
mockEC2.EXPECT().AuthorizeAllSecurityGroupIngress("sg-abcdef").Return(nil)
mockEC2.EXPECT().SetDefaultSecurityGroupID().Return("sg-abcdef", nil) //SGCreate fallback is tested in vpc_test.go

o, errs := newLBCreateOperation(
"web",
Expand Down Expand Up @@ -985,7 +944,7 @@ func TestNewLBCreateOperationUseDefaultSG(t *testing.T) {
defer mockCtrl.Finish()

ec2 := ec2client.NewMockClient(mockCtrl)
ec2.EXPECT().GetDefaultSecurityGroupID().Return("sg-1234567", nil)
ec2.EXPECT().SetDefaultSecurityGroupID().Return("sg-1234567", nil)
ec2.EXPECT().GetSubnetVPCID(gomock.Any()).Return("vpc-1234567", nil)

o, err := newLBCreateOperation(
Expand Down Expand Up @@ -1016,7 +975,7 @@ func TestNewLBCreateOperationDefaultSGError(t *testing.T) {
defer mockCtrl.Finish()

ec2 := ec2client.NewMockClient(mockCtrl)
ec2.EXPECT().GetDefaultSecurityGroupID().Return("", errors.New("boom"))
ec2.EXPECT().SetDefaultSecurityGroupID().Return("", errors.New("boom"))
ec2.EXPECT().GetSubnetVPCID(gomock.Any()).Return("vpc-1234567", nil)

_, errs := newLBCreateOperation(
Expand Down
61 changes: 39 additions & 22 deletions cmd/service_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,22 @@ import (
const typeService = "service"

type ServiceCreateOperation struct {
Cpu string
EnvVars []ECS.EnvVar
Image string
LoadBalancerArn string
LoadBalancerName string
Memory string
Num int64
Port Port
Rules []ELBV2.Rule
SecurityGroupIds []string
ServiceName string
SubnetIds []string
TaskRole string
TaskCommand []string
AssignPublicIPEnabled bool
Cpu string
EnvVars []ECS.EnvVar
Image string
LoadBalancerArn string
LoadBalancerName string
Memory string
Num int64
Port Port
Rules []ELBV2.Rule
SecurityGroupIds []string
ServiceName string
SubnetIds []string
TaskRole string
TaskCommand []string
AssignPublicIPEnabled bool
output Output
}

func (o *ServiceCreateOperation) SetPort(inputPort string) {
Expand Down Expand Up @@ -149,7 +150,7 @@ var (
flagServiceCreateSecurityGroupIds []string
flagServiceCreateSubnetIds []string
flagServiceCreateTaskRole string
flagServiceCreateTaskCommand []string
flagServiceCreateTaskCommand []string
flagServiceAssignPublicIP bool
)

Expand Down Expand Up @@ -241,8 +242,9 @@ Services can be configured to have only private ip address via the
ServiceName: args[0],
SubnetIds: flagServiceCreateSubnetIds,
TaskRole: flagServiceCreateTaskRole,
TaskCommand: flagServiceCreateTaskCommand,
TaskCommand: flagServiceCreateTaskCommand,
AssignPublicIPEnabled: flagServiceAssignPublicIP,
output: output,
}

if flagServiceCreatePort != "" {
Expand All @@ -262,7 +264,12 @@ Services can be configured to have only private ip address via the
}

operation.Validate()
createService(operation)
errs := createService(operation)

if len(errs) > 0 {
output.Fatals(errs, "Errors found while executing [COMMAND=service Action=create]")
return
}
},
}

Expand All @@ -279,11 +286,11 @@ func init() {
serviceCreateCmd.Flags().StringSliceVar(&flagServiceCreateSubnetIds, "subnet-id", []string{}, "ID of a subnet in which to place the service (can be specified multiple times)")
serviceCreateCmd.Flags().StringVarP(&flagServiceCreateTaskRole, "task-role", "", "", "Name or ARN of an IAM role that the service's tasks can assume")
serviceCreateCmd.Flags().StringSliceVar(&flagServiceCreateTaskCommand, "task-command", []string{}, "Command to run inside container instead of the one specified in the docker image")
serviceCreateCmd.Flags().BoolVarP(&flagServiceAssignPublicIP, "assign-public-ip", "", true, "Assign public ip address")
serviceCreateCmd.Flags().BoolVarP(&flagServiceAssignPublicIP, "assign-public-ip", "", true, "Assign public ip address")
serviceCmd.AddCommand(serviceCreateCmd)
}

func createService(operation *ServiceCreateOperation) {
func createService(operation *ServiceCreateOperation) (errors []error) {
var targetGroupArn string

cwl := CWL.New(sess)
Expand All @@ -296,12 +303,20 @@ func createService(operation *ServiceCreateOperation) {
logGroupName := cwl.CreateLogGroup(serviceLogGroupFormat, operation.ServiceName)

if len(operation.SecurityGroupIds) == 0 {
defaultSecurityGroupID, _ := ec2.GetDefaultSecurityGroupID()
operation.output.Debug("Find the default security group, and creates it if it does not exist [COMMAND=task Action=run]")
defaultSecurityGroupID, err := ec2.SetDefaultSecurityGroupID()
if err != nil {
errors = append(errors, err)
}
operation.SecurityGroupIds = []string{defaultSecurityGroupID}
}

if len(operation.SubnetIds) == 0 {
operation.SubnetIds, _ = ec2.GetDefaultSubnetIDs()
var err error
operation.SubnetIds, err = ec2.GetDefaultSubnetIDs()
if err != nil {
errors = append(errors, err)
}
}

if operation.Image == "" {
Expand Down Expand Up @@ -381,4 +396,6 @@ func createService(operation *ServiceCreateOperation) {
)

console.Info("Created service %s", operation.ServiceName)

return
}
25 changes: 21 additions & 4 deletions cmd/task_run.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type TaskRunOperation struct {
TaskName string
TaskRole string
TaskCommand []string
output Output
}

func (o *TaskRunOperation) Validate() {
Expand Down Expand Up @@ -125,12 +126,18 @@ the requirements of the docker CMD syntax`,
TaskName: args[0],
TaskRole: flagTaskRunTaskRole,
TaskCommand: flagTaskRunTaskCommand,
output: output,
}

operation.SetEnvVars(flagTaskRunEnvVars)
operation.Validate()

runTask(operation)
errs := runTask(operation)

if len(errs) > 0 {
output.Fatals(errs, "Errors found while executing [COMMAND=task Action=run]")
return
}
},
}

Expand All @@ -147,7 +154,7 @@ func init() {
taskCmd.AddCommand(taskRunCmd)
}

func runTask(operation *TaskRunOperation) {
func runTask(operation *TaskRunOperation) (errors []error) {
cwl := CWL.New(sess)
ec2 := EC2.New(sess)
ecr := ECR.New(sess)
Expand All @@ -157,12 +164,20 @@ func runTask(operation *TaskRunOperation) {
logGroupName := cwl.CreateLogGroup(taskLogGroupFormat, operation.TaskName)

if len(operation.SecurityGroupIds) == 0 {
defaultSecurityGroupID, _ := ec2.GetDefaultSecurityGroupID()
operation.output.Debug("Find the default security group, and creates it if it does not exist [COMMAND=task Action=run]")
defaultSecurityGroupID, err := ec2.SetDefaultSecurityGroupID()
if err != nil {
errors = append(errors, err)
}
operation.SecurityGroupIds = []string{defaultSecurityGroupID}
}

if len(operation.SubnetIds) == 0 {
operation.SubnetIds, _ = ec2.GetDefaultSubnetIDs()
var err error
operation.SubnetIds, err = ec2.GetDefaultSubnetIDs()
if err != nil {
errors = append(errors, err)
}
}

if operation.Image == "" {
Expand Down Expand Up @@ -218,4 +233,6 @@ func runTask(operation *TaskRunOperation) {
)

console.Info("Running task %s", operation.TaskName)

return
}
20 changes: 2 additions & 18 deletions cmd/vpc_operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,29 +31,13 @@ func (o *vpcOperation) setSecurityGroupIDs(securityGroupIDs []string) {
}

func (o *vpcOperation) setDefaultSecurityGroupID() error {
o.output.Debug("Finding default security group [API=ec2 Action=DescribeSecurityGroups]")
defaultSecurityGroupID, err := o.ec2.GetDefaultSecurityGroupID()

// setting of default security group id is delegated to the ec2 module
defaultSecurityGroupID, err := o.ec2.SetDefaultSecurityGroupID()
if err != nil {
return err
}

if defaultSecurityGroupID == "" {
o.output.Debug("Creating default security group [API=ec2 Action=CreateSecurityGroup]")
defaultSecurityGroupID, err = o.ec2.CreateDefaultSecurityGroup()

if err != nil {
return err
}

o.output.Debug("Created default security group [ID=%s]", defaultSecurityGroupID)

o.output.Debug("Configuring default security group [API=ec2 Action=AuthorizeSecurityGroupIngress]")
if err := o.ec2.AuthorizeAllSecurityGroupIngress(defaultSecurityGroupID); err != nil {
return err
}
}

o.securityGroupIDs = []string{defaultSecurityGroupID}

return nil
Expand Down
44 changes: 5 additions & 39 deletions cmd/vpc_operation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import (
"errors"
"testing"

"github.com/golang/mock/gomock"
"github.com/awslabs/fargatecli/cmd/mock"
ec2client "github.com/awslabs/fargatecli/ec2/mock/client"
"github.com/golang/mock/gomock"
)

func TestSetSubnetIDs(t *testing.T) {
Expand Down Expand Up @@ -74,7 +74,7 @@ func TestSetDefaultSecurityGroupID(t *testing.T) {
mockEC2Client := ec2client.NewMockClient(mockCtrl)
mockOutput := &mock.Output{}

mockEC2Client.EXPECT().GetDefaultSecurityGroupID().Return("sg-1234567", nil)
mockEC2Client.EXPECT().SetDefaultSecurityGroupID().Return("sg-1234567", nil) //SGCreate fallback is tested in vpc_test.go

operation := vpcOperation{
ec2: mockEC2Client,
Expand Down Expand Up @@ -103,7 +103,7 @@ func TestSetDefaultSecurityGroupIDLookupError(t *testing.T) {
mockEC2Client := ec2client.NewMockClient(mockCtrl)
mockOutput := &mock.Output{}

mockEC2Client.EXPECT().GetDefaultSecurityGroupID().Return("", errors.New("boom"))
mockEC2Client.EXPECT().SetDefaultSecurityGroupID().Return("", errors.New("boom"))

operation := vpcOperation{
ec2: mockEC2Client,
Expand All @@ -121,46 +121,14 @@ func TestSetDefaultSecurityGroupIDLookupError(t *testing.T) {
}
}

func TestSetDefaultSecurityGroupIDWithCreate(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

mockEC2Client := ec2client.NewMockClient(mockCtrl)
mockOutput := &mock.Output{}

mockEC2Client.EXPECT().GetDefaultSecurityGroupID().Return("", nil)
mockEC2Client.EXPECT().CreateDefaultSecurityGroup().Return("sg-1234567", nil)
mockEC2Client.EXPECT().AuthorizeAllSecurityGroupIngress("sg-1234567").Return(nil)

operation := vpcOperation{
ec2: mockEC2Client,
output: mockOutput,
}

err := operation.setDefaultSecurityGroupID()

if err != nil {
t.Errorf("expected no error, got: %v", err)
}

if len(operation.securityGroupIDs) != 1 {
t.Fatalf("expected 1 security group ID, got: %d", len(operation.securityGroupIDs))
}

if expected := "sg-1234567"; operation.securityGroupIDs[0] != expected {
t.Errorf("expected: %s, got: %s", expected, operation.securityGroupIDs[0])
}
}

func TestSetDefaultSecurityGroupIDWithCreateError(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

mockEC2Client := ec2client.NewMockClient(mockCtrl)
mockOutput := &mock.Output{}

mockEC2Client.EXPECT().GetDefaultSecurityGroupID().Return("", nil)
mockEC2Client.EXPECT().CreateDefaultSecurityGroup().Return("", errors.New("boom"))
mockEC2Client.EXPECT().SetDefaultSecurityGroupID().Return("", errors.New("boom"))

operation := vpcOperation{
ec2: mockEC2Client,
Expand All @@ -185,9 +153,7 @@ func TestSetDefaultSecurityGroupIDWithAuthorizeError(t *testing.T) {
mockEC2Client := ec2client.NewMockClient(mockCtrl)
mockOutput := &mock.Output{}

mockEC2Client.EXPECT().GetDefaultSecurityGroupID().Return("", nil)
mockEC2Client.EXPECT().CreateDefaultSecurityGroup().Return("sg-1234567", nil)
mockEC2Client.EXPECT().AuthorizeAllSecurityGroupIngress("sg-1234567").Return(errors.New("boom"))
mockEC2Client.EXPECT().SetDefaultSecurityGroupID().Return("sg-1234567", errors.New("boom"))

operation := vpcOperation{
ec2: mockEC2Client,
Expand Down
1 change: 1 addition & 0 deletions ec2/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type Client interface {
GetDefaultSecurityGroupID() (string, error)
GetDefaultSubnetIDs() ([]string, error)
GetSubnetVPCID(string) (string, error)
SetDefaultSecurityGroupID() (string, error)
}

// SDKClient implements access to EC2 via the AWS SDK.
Expand Down
Loading