Skip to content

Commit

Permalink
Update LoginCommand to match SwitchCommand
Browse files Browse the repository at this point in the history
  • Loading branch information
punmechanic committed Nov 12, 2024
1 parent 5abd257 commit e4f02d8
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 65 deletions.
60 changes: 30 additions & 30 deletions command/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ func resolveApplicationInfo(cfg *Config, bypassCache bool, nameOrID string) (*Ac
}

type GetCommand struct {
Config *Config

Args []string
TimeToLive uint
TimeRemaining uint
Expand Down Expand Up @@ -108,37 +106,35 @@ func (g GetCommand) printUsage() error {
return g.UsageFunc()
}

func (g GetCommand) Execute(ctx context.Context) error {
if HasTokenExpired(g.Config.Tokens) {
if g.Login {
login := LoginCommand{
Config: g.Config,
OIDCDomain: g.OIDCDomain,
ClientID: g.ClientID,
MachineOutput: ShouldUseMachineOutput(g.Flags) || g.URLOnly,
NoBrowser: g.NoBrowser,
}

if err := login.Execute(ctx); err != nil {
return err
}
} else {
func (g GetCommand) Execute(ctx context.Context, config *Config) error {
if HasTokenExpired(config.Tokens) {
if !g.Login {
return ErrTokensExpiredOrAbsent
}
return nil

loginCommand := LoginCommand{
OIDCDomain: g.OIDCDomain,
ClientID: g.ClientID,
MachineOutput: ShouldUseMachineOutput(g.Flags) || g.URLOnly,
NoBrowser: g.NoBrowser,
}

if err := loginCommand.Execute(ctx, config); err != nil {
return err
}
}

var accountID string
if len(g.Args) > 0 {
accountID = g.Args[0]
} else if g.Config.LastUsedAccount != nil {
} else if config.LastUsedAccount != nil {
// No account specified. Can we use the most recent one?
accountID = *g.Config.LastUsedAccount
accountID = *config.LastUsedAccount
} else {
return g.printUsage()
}

account, ok := resolveApplicationInfo(g.Config, g.BypassCache, accountID)
account, ok := resolveApplicationInfo(config, g.BypassCache, accountID)
if !ok {
return UnknownAccountError(g.Args[0], FlagBypassCache)
}
Expand All @@ -151,13 +147,13 @@ func (g GetCommand) Execute(ctx context.Context) error {
g.RoleName = account.MostRecentRole
}

if g.Config.TimeRemaining != 0 && g.TimeRemaining == DefaultTimeRemaining {
g.TimeRemaining = g.Config.TimeRemaining
if config.TimeRemaining != 0 && g.TimeRemaining == DefaultTimeRemaining {
g.TimeRemaining = config.TimeRemaining
}

credentials := LoadAWSCredentialsFromEnvironment()
if !credentials.ValidUntil(account, time.Duration(g.TimeRemaining)*time.Minute) {
newCredentials, err := g.fetchNewCredentials(ctx, *account)
newCredentials, err := g.fetchNewCredentials(ctx, *account, config)
if err != nil {
return err
}
Expand All @@ -168,12 +164,12 @@ func (g GetCommand) Execute(ctx context.Context) error {
account.MostRecentRole = g.RoleName
}

g.Config.LastUsedAccount = &accountID
config.LastUsedAccount = &accountID
return echoCredentials(accountID, accountID, credentials, g.OutputType, g.ShellType, g.AWSCLIPath)
}

func (g GetCommand) fetchNewCredentials(ctx context.Context, account Account) (*CloudCredentials, error) {
samlResponse, assertionStr, err := oauth2.DiscoverConfigAndExchangeTokenForAssertion(ctx, g.Config.Tokens.AccessToken, g.Config.Tokens.IDToken, g.OIDCDomain, g.ClientID, account.ID)
func (g GetCommand) fetchNewCredentials(ctx context.Context, account Account, cfg *Config) (*CloudCredentials, error) {
samlResponse, assertionStr, err := oauth2.DiscoverConfigAndExchangeTokenForAssertion(ctx, cfg.Tokens.AccessToken, cfg.Tokens.IDToken, g.OIDCDomain, g.ClientID, account.ID)
if err != nil {
return nil, err
}
Expand All @@ -183,8 +179,8 @@ func (g GetCommand) fetchNewCredentials(ctx context.Context, account Account) (*
return nil, UnknownRoleError(g.RoleName, g.Args[0])
}

if g.TimeToLive == 1 && g.Config.TTL != 0 {
g.TimeToLive = g.Config.TTL
if g.TimeToLive == 1 && cfg.TTL != 0 {
g.TimeToLive = cfg.TTL
}

awsCfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(g.Region))
Expand Down Expand Up @@ -232,7 +228,11 @@ A role must be specified when using this command through the --role flag. You ma
return err
}

return getCmd.Execute(cmd.Context())
if err := getCmd.Validate(); err != nil {
return err
}

return getCmd.Execute(cmd.Context(), ConfigFromCommand(cmd))
},
}

Expand Down
58 changes: 29 additions & 29 deletions command/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,51 +25,51 @@ func init() {
loginCmd.Flags().BoolP(FlagNoBrowser, "b", false, "Do not open a browser window, printing the URL instead")
}

// ShouldUseMachineOutput indicates whether or not we should write to standard output as if the user is a machine.
//
// What this means is implementation specific, but this usually indicates the user is trying to use this program in a script and we should avoid user-friendly output messages associated with values a user might find useful.
func ShouldUseMachineOutput(flags *pflag.FlagSet) bool {
quiet, _ := flags.GetBool(FlagQuiet)
fi, _ := os.Stdout.Stat()
isPiped := fi.Mode()&os.ModeCharDevice == 0
return isPiped || quiet
}

var loginCmd = &cobra.Command{
Use: "login",
Short: "Authenticate with KeyConjurer.",
Long: "Login to KeyConjurer using OAuth2. You will be required to open the URL printed to the console or scan a QR code.",
RunE: func(cmd *cobra.Command, args []string) error {
config := ConfigFromCommand(cmd)
if !HasTokenExpired(config.Tokens) {
return nil
var loginCmd LoginCommand
if err := loginCmd.Parse(cmd.Flags(), args); err != nil {
return err
}

oidcDomain, _ := cmd.Flags().GetString(FlagOIDCDomain)
clientID, _ := cmd.Flags().GetString(FlagClientID)
urlOnly, _ := cmd.Flags().GetBool(FlagURLOnly)
noBrowser, _ := cmd.Flags().GetBool(FlagNoBrowser)
command := LoginCommand{
Config: config,
OIDCDomain: oidcDomain,
ClientID: clientID,
MachineOutput: ShouldUseMachineOutput(cmd.Flags()) || urlOnly,
NoBrowser: noBrowser,
}

return command.Execute(cmd.Context())
return loginCmd.Execute(cmd.Context(), ConfigFromCommand(cmd))
},
}

// ShouldUseMachineOutput indicates whether or not we should write to standard output as if the user is a machine.
//
// What this means is implementation specific, but this usually indicates the user is trying to use this program in a script and we should avoid user-friendly output messages associated with values a user might find useful.
func ShouldUseMachineOutput(flags *pflag.FlagSet) bool {
quiet, _ := flags.GetBool(FlagQuiet)
fi, _ := os.Stdout.Stat()
isPiped := fi.Mode()&os.ModeCharDevice == 0
return isPiped || quiet
}

type LoginCommand struct {
Config *Config
OIDCDomain string
ClientID string
MachineOutput bool
NoBrowser bool
}

func (c LoginCommand) Execute(ctx context.Context) error {
func (c *LoginCommand) Parse(flags *pflag.FlagSet, args []string) error {
c.OIDCDomain, _ = flags.GetString(FlagOIDCDomain)
c.ClientID, _ = flags.GetString(FlagClientID)
c.NoBrowser, _ = flags.GetBool(FlagNoBrowser)
urlOnly, _ := flags.GetBool(FlagURLOnly)
c.MachineOutput = ShouldUseMachineOutput(flags) || urlOnly
return nil
}

func (c LoginCommand) Execute(ctx context.Context, config *Config) error {
if !HasTokenExpired(config.Tokens) {
return nil
}

oauthCfg, err := oauth2.DiscoverConfig(ctx, c.OIDCDomain, c.ClientID)
if err != nil {
return err
Expand Down Expand Up @@ -111,7 +111,7 @@ func (c LoginCommand) Execute(ctx context.Context) error {
return fmt.Errorf("id_token not found in token response")
}

return c.Config.SaveOAuthToken(accessToken, idToken)
return config.SaveOAuthToken(accessToken, idToken)
}

var ErrNoPortsAvailable = errors.New("no ports available")
Expand Down
11 changes: 5 additions & 6 deletions internal/api/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@ func ServeJSON[T any](w *events.ALBTargetGroupResponse, data T) {
w.Body = string(buf)
}

func ServeJSONError(w *events.ALBTargetGroupResponse, statusCode int, msg string) {
var jsonError struct {
Message string `json:"error"`
}
type JSONError struct {
Message string `json:"error"`
}

jsonError.Message = msg
func ServeJSONError(w *events.ALBTargetGroupResponse, statusCode int, msg string) {
w.StatusCode = statusCode
ServeJSON(w, jsonError)
ServeJSON(w, JSONError{Message: msg})
}

0 comments on commit e4f02d8

Please sign in to comment.