From e4f02d8bcc24dd515a51117d74f00e5c6b73eaba Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Tue, 12 Nov 2024 08:18:53 -0800 Subject: [PATCH] Update LoginCommand to match SwitchCommand --- command/get.go | 60 ++++++++++++++++++++++---------------------- command/login.go | 58 +++++++++++++++++++++--------------------- internal/api/json.go | 11 ++++---- 3 files changed, 64 insertions(+), 65 deletions(-) diff --git a/command/get.go b/command/get.go index 09a92dbe..30bdc398 100644 --- a/command/get.go +++ b/command/get.go @@ -56,8 +56,6 @@ func resolveApplicationInfo(cfg *Config, bypassCache bool, nameOrID string) (*Ac } type GetCommand struct { - Config *Config - Args []string TimeToLive uint TimeRemaining uint @@ -108,37 +106,35 @@ func (g GetCommand) printUsage() error { return g.UsageFunc() } -func (g GetCommand) Execute(ctx context.Context) error { - if HasTokenExpired(g.Config.Tokens) { - if g.Login { - login := LoginCommand{ - Config: g.Config, - OIDCDomain: g.OIDCDomain, - ClientID: g.ClientID, - MachineOutput: ShouldUseMachineOutput(g.Flags) || g.URLOnly, - NoBrowser: g.NoBrowser, - } - - if err := login.Execute(ctx); err != nil { - return err - } - } else { +func (g GetCommand) Execute(ctx context.Context, config *Config) error { + if HasTokenExpired(config.Tokens) { + if !g.Login { return ErrTokensExpiredOrAbsent } - return nil + + loginCommand := LoginCommand{ + OIDCDomain: g.OIDCDomain, + ClientID: g.ClientID, + MachineOutput: ShouldUseMachineOutput(g.Flags) || g.URLOnly, + NoBrowser: g.NoBrowser, + } + + if err := loginCommand.Execute(ctx, config); err != nil { + return err + } } var accountID string if len(g.Args) > 0 { accountID = g.Args[0] - } else if g.Config.LastUsedAccount != nil { + } else if config.LastUsedAccount != nil { // No account specified. Can we use the most recent one? - accountID = *g.Config.LastUsedAccount + accountID = *config.LastUsedAccount } else { return g.printUsage() } - account, ok := resolveApplicationInfo(g.Config, g.BypassCache, accountID) + account, ok := resolveApplicationInfo(config, g.BypassCache, accountID) if !ok { return UnknownAccountError(g.Args[0], FlagBypassCache) } @@ -151,13 +147,13 @@ func (g GetCommand) Execute(ctx context.Context) error { g.RoleName = account.MostRecentRole } - if g.Config.TimeRemaining != 0 && g.TimeRemaining == DefaultTimeRemaining { - g.TimeRemaining = g.Config.TimeRemaining + if config.TimeRemaining != 0 && g.TimeRemaining == DefaultTimeRemaining { + g.TimeRemaining = config.TimeRemaining } credentials := LoadAWSCredentialsFromEnvironment() if !credentials.ValidUntil(account, time.Duration(g.TimeRemaining)*time.Minute) { - newCredentials, err := g.fetchNewCredentials(ctx, *account) + newCredentials, err := g.fetchNewCredentials(ctx, *account, config) if err != nil { return err } @@ -168,12 +164,12 @@ func (g GetCommand) Execute(ctx context.Context) error { account.MostRecentRole = g.RoleName } - g.Config.LastUsedAccount = &accountID + config.LastUsedAccount = &accountID return echoCredentials(accountID, accountID, credentials, g.OutputType, g.ShellType, g.AWSCLIPath) } -func (g GetCommand) fetchNewCredentials(ctx context.Context, account Account) (*CloudCredentials, error) { - samlResponse, assertionStr, err := oauth2.DiscoverConfigAndExchangeTokenForAssertion(ctx, g.Config.Tokens.AccessToken, g.Config.Tokens.IDToken, g.OIDCDomain, g.ClientID, account.ID) +func (g GetCommand) fetchNewCredentials(ctx context.Context, account Account, cfg *Config) (*CloudCredentials, error) { + samlResponse, assertionStr, err := oauth2.DiscoverConfigAndExchangeTokenForAssertion(ctx, cfg.Tokens.AccessToken, cfg.Tokens.IDToken, g.OIDCDomain, g.ClientID, account.ID) if err != nil { return nil, err } @@ -183,8 +179,8 @@ func (g GetCommand) fetchNewCredentials(ctx context.Context, account Account) (* return nil, UnknownRoleError(g.RoleName, g.Args[0]) } - if g.TimeToLive == 1 && g.Config.TTL != 0 { - g.TimeToLive = g.Config.TTL + if g.TimeToLive == 1 && cfg.TTL != 0 { + g.TimeToLive = cfg.TTL } awsCfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(g.Region)) @@ -232,7 +228,11 @@ A role must be specified when using this command through the --role flag. You ma return err } - return getCmd.Execute(cmd.Context()) + if err := getCmd.Validate(); err != nil { + return err + } + + return getCmd.Execute(cmd.Context(), ConfigFromCommand(cmd)) }, } diff --git a/command/login.go b/command/login.go index 27d8162a..b6c69233 100644 --- a/command/login.go +++ b/command/login.go @@ -25,51 +25,51 @@ func init() { loginCmd.Flags().BoolP(FlagNoBrowser, "b", false, "Do not open a browser window, printing the URL instead") } -// ShouldUseMachineOutput indicates whether or not we should write to standard output as if the user is a machine. -// -// What this means is implementation specific, but this usually indicates the user is trying to use this program in a script and we should avoid user-friendly output messages associated with values a user might find useful. -func ShouldUseMachineOutput(flags *pflag.FlagSet) bool { - quiet, _ := flags.GetBool(FlagQuiet) - fi, _ := os.Stdout.Stat() - isPiped := fi.Mode()&os.ModeCharDevice == 0 - return isPiped || quiet -} - var loginCmd = &cobra.Command{ Use: "login", Short: "Authenticate with KeyConjurer.", Long: "Login to KeyConjurer using OAuth2. You will be required to open the URL printed to the console or scan a QR code.", RunE: func(cmd *cobra.Command, args []string) error { - config := ConfigFromCommand(cmd) - if !HasTokenExpired(config.Tokens) { - return nil + var loginCmd LoginCommand + if err := loginCmd.Parse(cmd.Flags(), args); err != nil { + return err } - oidcDomain, _ := cmd.Flags().GetString(FlagOIDCDomain) - clientID, _ := cmd.Flags().GetString(FlagClientID) - urlOnly, _ := cmd.Flags().GetBool(FlagURLOnly) - noBrowser, _ := cmd.Flags().GetBool(FlagNoBrowser) - command := LoginCommand{ - Config: config, - OIDCDomain: oidcDomain, - ClientID: clientID, - MachineOutput: ShouldUseMachineOutput(cmd.Flags()) || urlOnly, - NoBrowser: noBrowser, - } - - return command.Execute(cmd.Context()) + return loginCmd.Execute(cmd.Context(), ConfigFromCommand(cmd)) }, } +// ShouldUseMachineOutput indicates whether or not we should write to standard output as if the user is a machine. +// +// What this means is implementation specific, but this usually indicates the user is trying to use this program in a script and we should avoid user-friendly output messages associated with values a user might find useful. +func ShouldUseMachineOutput(flags *pflag.FlagSet) bool { + quiet, _ := flags.GetBool(FlagQuiet) + fi, _ := os.Stdout.Stat() + isPiped := fi.Mode()&os.ModeCharDevice == 0 + return isPiped || quiet +} + type LoginCommand struct { - Config *Config OIDCDomain string ClientID string MachineOutput bool NoBrowser bool } -func (c LoginCommand) Execute(ctx context.Context) error { +func (c *LoginCommand) Parse(flags *pflag.FlagSet, args []string) error { + c.OIDCDomain, _ = flags.GetString(FlagOIDCDomain) + c.ClientID, _ = flags.GetString(FlagClientID) + c.NoBrowser, _ = flags.GetBool(FlagNoBrowser) + urlOnly, _ := flags.GetBool(FlagURLOnly) + c.MachineOutput = ShouldUseMachineOutput(flags) || urlOnly + return nil +} + +func (c LoginCommand) Execute(ctx context.Context, config *Config) error { + if !HasTokenExpired(config.Tokens) { + return nil + } + oauthCfg, err := oauth2.DiscoverConfig(ctx, c.OIDCDomain, c.ClientID) if err != nil { return err @@ -111,7 +111,7 @@ func (c LoginCommand) Execute(ctx context.Context) error { return fmt.Errorf("id_token not found in token response") } - return c.Config.SaveOAuthToken(accessToken, idToken) + return config.SaveOAuthToken(accessToken, idToken) } var ErrNoPortsAvailable = errors.New("no ports available") diff --git a/internal/api/json.go b/internal/api/json.go index 0b28abf9..8a00cee6 100644 --- a/internal/api/json.go +++ b/internal/api/json.go @@ -20,12 +20,11 @@ func ServeJSON[T any](w *events.ALBTargetGroupResponse, data T) { w.Body = string(buf) } -func ServeJSONError(w *events.ALBTargetGroupResponse, statusCode int, msg string) { - var jsonError struct { - Message string `json:"error"` - } +type JSONError struct { + Message string `json:"error"` +} - jsonError.Message = msg +func ServeJSONError(w *events.ALBTargetGroupResponse, statusCode int, msg string) { w.StatusCode = statusCode - ServeJSON(w, jsonError) + ServeJSON(w, JSONError{Message: msg}) }