Skip to content

Commit

Permalink
Convert Switch command to be struct-based
Browse files Browse the repository at this point in the history
  • Loading branch information
punmechanic committed Nov 12, 2024
1 parent 40b1f72 commit 5abd257
Showing 1 changed file with 62 additions and 31 deletions.
93 changes: 62 additions & 31 deletions command/switch.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
)

var (
Expand Down Expand Up @@ -42,45 +43,75 @@ This command will fail if you do not have active Cloud credentials.
Args: cobra.ExactArgs(1),
Aliases: []string{"switch-account"},
RunE: func(cmd *cobra.Command, args []string) error {
outputType, _ := cmd.Flags().GetString(FlagOutputType)
shellType, _ := cmd.Flags().GetString(FlagShellType)
awsCliPath, _ := cmd.Flags().GetString(FlagAWSCLIPath)
if !slices.Contains(permittedOutputTypes, outputType) {
return ValueError{Value: outputType, ValidValues: permittedOutputTypes}
}

if !slices.Contains(permittedShellTypes, shellType) {
return ValueError{Value: shellType, ValidValues: permittedShellTypes}
var switchCmd SwitchCommand
if err := switchCmd.Parse(cmd.Flags(), args); err != nil {
return err
}

// We could read the environment variable for the assumed role ARN, but it might be expired which isn't very useful to the user.
var err error
var creds CloudCredentials
sessionName, _ := cmd.Flags().GetString(FlagRoleSessionName)

creds, err = getAWSCredentials(args[0], sessionName)
if err != nil {
// If this failed, either there was a network error or the user is not authorized to assume into this role
// This can happen if the user is not authenticated using the Bastion instance.
if err := switchCmd.Validate(); err != nil {
return err
}

switch outputType {
case outputTypeEnvironmentVariable:
creds.WriteFormat(os.Stdout, shellType)
return nil
case outputTypeAWSCredentialsFile:
acc := Account{ID: args[0], Name: args[0]}
newCliEntry := NewCloudCliEntry(creds, &acc)
return SaveCloudCredentialInCLI(awsCliPath, newCliEntry)
default:
return fmt.Errorf("%s is an invalid output type", outputType)
}
return switchCmd.Execute(cmd.Context())
},
}

func getAWSCredentials(accountID, roleSessionName string) (creds CloudCredentials, err error) {
ctx := context.Background()
type SwitchCommand struct {
OutputType string
ShellType string
AWSCLIPath string
RoleSessionName string
AccountID string
}

func (s *SwitchCommand) Parse(flags *pflag.FlagSet, args []string) error {
s.OutputType, _ = flags.GetString(FlagOutputType)
s.ShellType, _ = flags.GetString(FlagShellType)
s.AWSCLIPath, _ = flags.GetString(FlagAWSCLIPath)
s.RoleSessionName, _ = flags.GetString(FlagRoleSessionName)
if len(args) == 0 {
return fmt.Errorf("account-id is required")
}

s.AccountID = args[0]
return nil
}

func (s SwitchCommand) Validate() error {
if !slices.Contains(permittedOutputTypes, s.OutputType) {
return ValueError{Value: s.OutputType, ValidValues: permittedOutputTypes}
}

if !slices.Contains(permittedShellTypes, s.ShellType) {
return ValueError{Value: s.ShellType, ValidValues: permittedShellTypes}
}

return nil
}

func (s SwitchCommand) Execute(ctx context.Context) error {
// We could read the environment variable for the assumed role ARN, but it might be expired which isn't very useful to the user.
creds, err := getAWSCredentials(ctx, s.AccountID, s.RoleSessionName)
if err != nil {
// If this failed, either there was a network error or the user is not authorized to assume into this role
// This can happen if the user is not authenticated using the Bastion instance.
return err
}

switch s.OutputType {
case outputTypeEnvironmentVariable:
creds.WriteFormat(os.Stdout, s.ShellType)
return nil
case outputTypeAWSCredentialsFile:
acc := Account{ID: s.AccountID, Name: s.AccountID}
newCliEntry := NewCloudCliEntry(creds, &acc)
return SaveCloudCredentialInCLI(s.AWSCLIPath, newCliEntry)
default:
return fmt.Errorf("%s is an invalid output type", s.OutputType)
}
}

func getAWSCredentials(ctx context.Context, accountID, roleSessionName string) (creds CloudCredentials, err error) {
cfg, err := config.LoadDefaultConfig(ctx)
if err != nil {
return
Expand Down

0 comments on commit 5abd257

Please sign in to comment.