From a8c2d4bb02e0418cab437cf9857d4870cefb1290 Mon Sep 17 00:00:00 2001 From: Dan <2939173+punmechanic@users.noreply.github.com> Date: Thu, 1 Feb 2024 12:04:21 -0800 Subject: [PATCH] Login to Okta when using Get if user passes --login (#105) if the user passes the --login flag and uses get when their session has expired, automatically log in, instead of asking them to log in --- cli/get.go | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/cli/get.go b/cli/get.go index 3cbd4bc8..f985559b 100644 --- a/cli/get.go +++ b/cli/get.go @@ -17,6 +17,7 @@ var ( FlagTimeRemaining = "time-remaining" FlagTimeToLive = "ttl" FlagBypassCache = "bypass-cache" + FlagLogin = "login" ) var ( @@ -42,6 +43,7 @@ func init() { getCmd.Flags().String(FlagTencentCLIPath, "~/.tencent/", "Path for directory used by the tencent-cli tool. Default is \"~/.tencent\".") getCmd.Flags().String(FlagCloudType, "aws", "Choose a cloud vendor. Default is aws. Can choose aws or tencent") getCmd.Flags().Bool(FlagBypassCache, false, "Do not check the cache for accounts and send the application ID as-is to Okta. This is useful if you have an ID you know is an Okta application ID and it is not stored in your local account cache.") + getCmd.Flags().Bool(FlagLogin, false, "Login to Okta before running the command") } func isMemberOfSlice(slice []string, val string) bool { @@ -70,8 +72,22 @@ A role must be specified when using this command through the --role flag. You ma RunE: func(cmd *cobra.Command, args []string) error { config := ConfigFromCommand(cmd) ctx := cmd.Context() + oidcDomain, _ := cmd.Flags().GetString(FlagOIDCDomain) + clientID, _ := cmd.Flags().GetString(FlagClientID) if HasTokenExpired(config.Tokens) { - return ErrTokensExpiredOrAbsent + if ok, _ := cmd.Flags().GetBool(FlagLogin); ok { + token, err := Login(ctx, oidcDomain, clientID, LoginOutputModeBrowser{}) + if err != nil { + return err + } + if err := config.SaveOAuthToken(token); err != nil { + return err + } + + } else { + return ErrTokensExpiredOrAbsent + } + return nil } ttl, _ := cmd.Flags().GetUint(FlagTimeToLive) @@ -80,8 +96,6 @@ A role must be specified when using this command through the --role flag. You ma shellType, _ := cmd.Flags().GetString(FlagShellType) roleName, _ := cmd.Flags().GetString(FlagRoleName) cloudType, _ := cmd.Flags().GetString(FlagCloudType) - oidcDomain, _ := cmd.Flags().GetString(FlagOIDCDomain) - clientID, _ := cmd.Flags().GetString(FlagClientID) awsCliPath, _ := cmd.Flags().GetString(FlagAWSCLIPath) tencentCliPath, _ := cmd.Flags().GetString(FlagTencentCLIPath)