diff --git a/command/switch.go b/command/switch.go index 8deec003..35c0e08b 100644 --- a/command/switch.go +++ b/command/switch.go @@ -13,6 +13,7 @@ import ( "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/spf13/cobra" + "github.com/spf13/pflag" ) var ( @@ -42,45 +43,75 @@ This command will fail if you do not have active Cloud credentials. Args: cobra.ExactArgs(1), Aliases: []string{"switch-account"}, RunE: func(cmd *cobra.Command, args []string) error { - outputType, _ := cmd.Flags().GetString(FlagOutputType) - shellType, _ := cmd.Flags().GetString(FlagShellType) - awsCliPath, _ := cmd.Flags().GetString(FlagAWSCLIPath) - if !slices.Contains(permittedOutputTypes, outputType) { - return ValueError{Value: outputType, ValidValues: permittedOutputTypes} - } - - if !slices.Contains(permittedShellTypes, shellType) { - return ValueError{Value: shellType, ValidValues: permittedShellTypes} + var switchCmd SwitchCommand + if err := switchCmd.Parse(cmd.Flags(), args); err != nil { + return err } - // We could read the environment variable for the assumed role ARN, but it might be expired which isn't very useful to the user. - var err error - var creds CloudCredentials - sessionName, _ := cmd.Flags().GetString(FlagRoleSessionName) - - creds, err = getAWSCredentials(args[0], sessionName) - if err != nil { - // If this failed, either there was a network error or the user is not authorized to assume into this role - // This can happen if the user is not authenticated using the Bastion instance. + if err := switchCmd.Validate(); err != nil { return err } - switch outputType { - case outputTypeEnvironmentVariable: - creds.WriteFormat(os.Stdout, shellType) - return nil - case outputTypeAWSCredentialsFile: - acc := Account{ID: args[0], Name: args[0]} - newCliEntry := NewCloudCliEntry(creds, &acc) - return SaveCloudCredentialInCLI(awsCliPath, newCliEntry) - default: - return fmt.Errorf("%s is an invalid output type", outputType) - } + return switchCmd.Execute(cmd.Context()) }, } -func getAWSCredentials(accountID, roleSessionName string) (creds CloudCredentials, err error) { - ctx := context.Background() +type SwitchCommand struct { + OutputType string + ShellType string + AWSCLIPath string + RoleSessionName string + AccountID string +} + +func (s *SwitchCommand) Parse(flags *pflag.FlagSet, args []string) error { + s.OutputType, _ = flags.GetString(FlagOutputType) + s.ShellType, _ = flags.GetString(FlagShellType) + s.AWSCLIPath, _ = flags.GetString(FlagAWSCLIPath) + s.RoleSessionName, _ = flags.GetString(FlagRoleSessionName) + if len(args) == 0 { + return fmt.Errorf("account-id is required") + } + + s.AccountID = args[0] + return nil +} + +func (s SwitchCommand) Validate() error { + if !slices.Contains(permittedOutputTypes, s.OutputType) { + return ValueError{Value: s.OutputType, ValidValues: permittedOutputTypes} + } + + if !slices.Contains(permittedShellTypes, s.ShellType) { + return ValueError{Value: s.ShellType, ValidValues: permittedShellTypes} + } + + return nil +} + +func (s SwitchCommand) Execute(ctx context.Context) error { + // We could read the environment variable for the assumed role ARN, but it might be expired which isn't very useful to the user. + creds, err := getAWSCredentials(ctx, s.AccountID, s.RoleSessionName) + if err != nil { + // If this failed, either there was a network error or the user is not authorized to assume into this role + // This can happen if the user is not authenticated using the Bastion instance. + return err + } + + switch s.OutputType { + case outputTypeEnvironmentVariable: + creds.WriteFormat(os.Stdout, s.ShellType) + return nil + case outputTypeAWSCredentialsFile: + acc := Account{ID: s.AccountID, Name: s.AccountID} + newCliEntry := NewCloudCliEntry(creds, &acc) + return SaveCloudCredentialInCLI(s.AWSCLIPath, newCliEntry) + default: + return fmt.Errorf("%s is an invalid output type", s.OutputType) + } +} + +func getAWSCredentials(ctx context.Context, accountID, roleSessionName string) (creds CloudCredentials, err error) { cfg, err := config.LoadDefaultConfig(ctx) if err != nil { return