diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 40d229f6..27196b2b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,8 +10,6 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - with: - ref: ${{ github.head_ref }} - uses: hashicorp/setup-terraform@v2 with: terraform_version: "1.3.7" @@ -30,33 +28,15 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - with: - ref: ${{ github.head_ref }} - uses: actions/setup-node@v3 with: node-version: "16.17.0" - # Tests just need the file to exist with the appropriate exports - the values do not matter. - run: cd frontend && npm install && npm test - api-test: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - with: - ref: ${{ github.head_ref }} - - uses: actions/setup-go@v3 - with: - go-version: "1.19" - - run: cd api && go test ./... - cli-test: + go-test: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - with: - ref: ${{ github.head_ref }} - uses: actions/setup-go@v3 with: go-version: "1.19" - # CLI test requires a dummy ~/.aws/credentials and config file exists. - - run: mkdir -p ~/.aws - - run: touch ~/.aws/{credentials,config} - - run: cd cli && go test ./... + - run: go test ./... diff --git a/Makefile b/Makefile index cc32d054..89f9e5f6 100644 --- a/Makefile +++ b/Makefile @@ -62,6 +62,7 @@ cli/keyconjurer: cd cli && \ go build \ -ldflags "\ + -s -w \ -X main.Version=$(shell git rev-parse --short HEAD)-$(RELEASE) \ -X main.ClientID=$(CLIENT_ID) \ -X main.OIDCDomain=$(OIDC_DOMAIN) \ diff --git a/cli/accounts.go b/cli/accounts.go index 6b861ca7..39445186 100644 --- a/cli/accounts.go +++ b/cli/accounts.go @@ -43,7 +43,7 @@ var accountsCmd = &cobra.Command{ } serverAddr, _ := cmd.Flags().GetString(FlagServerAddress) - serverAddrUri, err := url.Parse(serverAddr) + serverAddrURI, err := url.Parse(serverAddr) if err != nil { cmd.PrintErrf("--%s had an invalid value: %s\n", FlagServerAddress, err) return nil @@ -62,7 +62,7 @@ var accountsCmd = &cobra.Command{ TokenType: config.Tokens.TokenType, } - accounts, err := refreshAccounts(cmd.Context(), serverAddrUri, &tok) + accounts, err := refreshAccounts(cmd.Context(), serverAddrURI, &tok) if err != nil { cmd.PrintErrf("Error refreshing accounts: %s\n", err) cmd.PrintErrln("If you don't need to refresh your accounts, consider adding the --no-refresh flag") diff --git a/cli/awsconfig.go b/cli/awsconfig.go index b803134f..a26bec4c 100644 --- a/cli/awsconfig.go +++ b/cli/awsconfig.go @@ -12,7 +12,7 @@ import ( // Intentionally missing the `ini` notation sections,keys, and values are being handled by the ini library type CloudCliEntry struct { profileName string - keyId string + keyID string key string token string } @@ -25,7 +25,7 @@ func NewCloudCliEntry(c CloudCredentials, a *Account) CloudCliEntry { return CloudCliEntry{ profileName: name, - keyId: c.AccessKeyID, + keyID: c.AccessKeyID, key: c.SecretAccessKey, token: c.SessionToken, } @@ -56,11 +56,11 @@ func ResolveAWSCredentialsPath(rootPath string) string { func saveCredentialEntry(file *ini.File, entry CloudCliEntry, cloud string) error { section := file.Section(entry.profileName) if cloud == cloudAws { - section.Key("aws_access_key_id").SetValue(entry.keyId) + section.Key("aws_access_key_id").SetValue(entry.keyID) section.Key("aws_secret_access_key").SetValue(entry.key) section.Key("aws_session_token").SetValue(entry.token) } else if cloud == cloudTencent { - section.Key("tencent_access_key_id").SetValue(entry.keyId) + section.Key("tencent_access_key_id").SetValue(entry.keyID) section.Key("tencent_secret_access_key").SetValue(entry.key) section.Key("tencent_session_token").SetValue(entry.token) } diff --git a/cli/awsconfig_test.go b/cli/awsconfig_test.go index d4f6c26c..2e5ac322 100644 --- a/cli/awsconfig_test.go +++ b/cli/awsconfig_test.go @@ -15,7 +15,7 @@ func TestAddAWSCliEntry(t *testing.T) { entry := CloudCliEntry{ profileName: "test-profile", - keyId: "notanid", + keyID: "notanid", key: "notakey", token: "notatoken", } diff --git a/cli/config.go b/cli/config.go index d126fde2..bc1c8392 100644 --- a/cli/config.go +++ b/cli/config.go @@ -150,13 +150,13 @@ func (a *accountSet) ReplaceWith(other []Account) { m := map[string]struct{}{} for _, acc := range other { - copy := acc + clone := acc // Preserve the alias if the account ID is the same and it already exists if entry, ok := a.accounts[acc.ID]; ok { // The name is the only thing that might change. entry.Name = acc.Name } else { - a.accounts[acc.ID] = © + a.accounts[acc.ID] = &clone } m[acc.ID] = struct{}{} @@ -169,10 +169,10 @@ func (a *accountSet) ReplaceWith(other []Account) { } } -func (s accountSet) WriteTable(w io.Writer) { +func (a accountSet) WriteTable(w io.Writer) { tbl := csv.NewWriter(w) tbl.Write([]string{"id,name,alias"}) - s.ForEach(func(id string, acc Account, alias string) { + a.ForEach(func(id string, acc Account, alias string) { tbl.Write([]string{id, acc.Name, alias}) }) tbl.Flush() diff --git a/cli/consts.go b/cli/consts.go index 22943ee6..7dfd8e32 100644 --- a/cli/consts.go +++ b/cli/consts.go @@ -5,9 +5,9 @@ var ( ClientID string OIDCDomain string ServerAddress string - Version string = "TBD" - BuildTimestamp string = "BuildTimestamp is not set" - DownloadURL string = "URL not set yet" + Version = "TBD" + BuildTimestamp = "BuildTimestamp is not set" + DownloadURL = "URL not set yet" ) const ( diff --git a/cli/credentials.go b/cli/credentials.go index c00137f7..420bda0a 100644 --- a/cli/credentials.go +++ b/cli/credentials.go @@ -68,7 +68,7 @@ func LoadAWSCredentialsFromEnvironment() CloudCredentials { } } -func (c *CloudCredentials) ValidUntil(account *Account, cloudFlag string, dur time.Duration) bool { +func (c *CloudCredentials) ValidUntil(account *Account, dur time.Duration) bool { if account == nil || c == nil { return false } @@ -86,7 +86,7 @@ func (c *CloudCredentials) ValidUntil(account *Account, cloudFlag string, dur ti } const ( - aws_shellTypePowershell = `$Env:AWS_ACCESS_KEY_ID = "%v" + awsShellTypePowershell = `$Env:AWS_ACCESS_KEY_ID = "%v" $Env:AWS_SECRET_ACCESS_KEY = "%v" $Env:AWS_SESSION_TOKEN = "%v" $Env:AWS_SECURITY_TOKEN = "%v" @@ -96,7 +96,7 @@ $Env:TF_VAR_token = $Env:AWS_SESSION_TOKEN $Env:AWSKEY_EXPIRATION = "%v" $Env:AWSKEY_ACCOUNT = "%v" ` - tencent_shellTypePowershell = `$Env:TENCENTCLOUD_SECRET_ID = "%v" + tencentShellTypePowershell = `$Env:TENCENTCLOUD_SECRET_ID = "%v" $Env:TENCENTCLOUD_SECRET_KEY = "%v" $Env:TENCENTCLOUD_TOKEN = "%v" $Env:TENCENTCLOUD_SECURITY_TOKEN = "%v" @@ -106,7 +106,7 @@ $Env:TF_VAR_token = $Env:TENCENTCLOUD_TOKEN $Env:TENCENT_KEY_EXPIRATION = "%v" $Env:TENCENT_KEY_ACCOUNT = "%v" ` - aws_shellTypeBasic = `SET AWS_ACCESS_KEY_ID=%v + awsShellTypeBasic = `SET AWS_ACCESS_KEY_ID=%v SET AWS_SECRET_ACCESS_KEY=%v SET AWS_SESSION_TOKEN=%v SET AWS_SECURITY_TOKEN=%v @@ -116,7 +116,7 @@ SET TF_VAR_token=%%AWS_SESSION_TOKEN%% SET AWSKEY_EXPIRATION=%v SET AWSKEY_ACCOUNT=%v ` - tencent_shellTypeBasic = `SET TENCENTCLOUD_SECRET_ID=%v + tencentShellTypeBasic = `SET TENCENTCLOUD_SECRET_ID=%v SET TENCENTCLOUD_SECRET_KEY=%v SET TENCENTCLOUD_TOKEN=%v SET TENCENTCLOUD_SECURITY_TOKEN=%v @@ -125,7 +125,7 @@ SET TF_VAR_secret_key=%%TENCENTCLOUD_SECRET_KEY%% SET TF_VAR_token=%%TENCENTCLOUD_TOKEN%% SET TENCENTKEY_EXPIRATION=%v SET TENCENTKEY_ACCOUNT=%v` - aws_shellTypeBash = `export AWS_ACCESS_KEY_ID=%v + awsShellTypeBash = `export AWS_ACCESS_KEY_ID=%v export AWS_SECRET_ACCESS_KEY=%v export AWS_SESSION_TOKEN=%v export AWS_SECURITY_TOKEN=%v @@ -135,7 +135,7 @@ export TF_VAR_token=$AWS_SESSION_TOKEN export AWSKEY_EXPIRATION=%v export AWSKEY_ACCOUNT=%v ` - tencent_shellTypeBash = `export TENCENTCLOUD_SECRET_ID=%v + tencentShellTypeBash = `export TENCENTCLOUD_SECRET_ID=%v export TENCENTCLOUD_SECRET_KEY=%v export TENCENTCLOUD_TOKEN=%v export TENCENT_SECURITY_TOKEN=%v @@ -155,19 +155,19 @@ func (c CloudCredentials) WriteFormat(w io.Writer, format ShellType) (int, error switch format { case shellTypePowershell: - str = aws_shellTypePowershell + str = awsShellTypePowershell if c.credentialsType == cloudTencent { - str = tencent_shellTypePowershell + str = tencentShellTypePowershell } case shellTypeBasic: - str = aws_shellTypeBasic + str = awsShellTypeBasic if c.credentialsType == cloudTencent { - str = tencent_shellTypeBasic + str = tencentShellTypeBasic } case shellTypeBash: - str = aws_shellTypeBash + str = awsShellTypeBash if c.credentialsType == cloudTencent { - str = tencent_shellTypeBash + str = tencentShellTypeBash } } diff --git a/cli/credentials_test.go b/cli/credentials_test.go index bbcd1186..e7c84585 100644 --- a/cli/credentials_test.go +++ b/cli/credentials_test.go @@ -7,23 +7,6 @@ import ( "github.com/stretchr/testify/assert" ) -/* -interesting thread on using ENV in unit testing -https://www.reddit.com/r/golang/comments/ar5z3i/how_to_set_env_variables_while_unit_testing/ -*/ - -var envsToUse []string - -func init() { - envsToUse = []string{ - "AWS_ACCESS_KEY_ID", - "AWS_SECRET_ACCESS_KEY", - "AWS_SESSION_TOKEN", - "AWSKEY_EXPIRATION", - "AWSKEY_ACCOUNT", - } -} - func setEnv(t *testing.T, valid bool) *Account { t.Setenv("AWS_ACCESS_KEY_ID", "1234") t.Setenv("AWS_SECRET_ACCESS_KEY", "accesskey") @@ -47,7 +30,7 @@ func setEnv(t *testing.T, valid bool) *Account { func TestGetValidEnvCreds(t *testing.T) { account := setEnv(t, true) creds := LoadAWSCredentialsFromEnvironment() - assert.True(t, creds.ValidUntil(account, "aws", 0), "credentials should be valid") + assert.True(t, creds.ValidUntil(account, 0), "credentials should be valid") } func TestGetInvalidEnvCreds(t *testing.T) { @@ -56,18 +39,18 @@ func TestGetInvalidEnvCreds(t *testing.T) { // test incorrect time first t.Log("testing expired timestamp for key") creds := LoadAWSCredentialsFromEnvironment() - assert.False(t, creds.ValidUntil(account, "aws", 0), "credentials should be invalid due to timestamp") + assert.False(t, creds.ValidUntil(account, 0), "credentials should be invalid due to timestamp") account = setEnv(t, true) account.ID = "" creds = LoadAWSCredentialsFromEnvironment() - assert.False(t, creds.ValidUntil(account, "aws", 0), "credentials should be invalid due to non-matching id") + assert.False(t, creds.ValidUntil(account, 0), "credentials should be invalid due to non-matching id") account = setEnv(t, true) t.Setenv("AWSKEY_EXPIRATION", "definitely not a timestamp") creds = LoadAWSCredentialsFromEnvironment() - assert.False(t, creds.ValidUntil(account, "aws", 0), "credentials should be invalid due to non-parsable timestamp") + assert.False(t, creds.ValidUntil(account, 0), "credentials should be invalid due to non-parsable timestamp") } func TestTimeWindowEnvCreds(t *testing.T) { @@ -75,12 +58,12 @@ func TestTimeWindowEnvCreds(t *testing.T) { t.Log("testing minutes window still within 1hr period for test creds") creds := LoadAWSCredentialsFromEnvironment() - assert.True(t, creds.ValidUntil(account, "aws", 0), "credentials should be valid") - assert.True(t, creds.ValidUntil(account, "aws", 5), "credentials should be valid") - assert.True(t, creds.ValidUntil(account, "aws", 30), "credentials should be valid") - assert.True(t, creds.ValidUntil(account, "aws", 58), "credentials should be valid") + assert.True(t, creds.ValidUntil(account, 0), "credentials should be valid") + assert.True(t, creds.ValidUntil(account, 5), "credentials should be valid") + assert.True(t, creds.ValidUntil(account, 30), "credentials should be valid") + assert.True(t, creds.ValidUntil(account, 58), "credentials should be valid") t.Log("testing minutes window is outside 1hr period for test creds") - assert.False(t, creds.ValidUntil(account, "aws", 60*time.Minute), "credentials should be valid") - assert.False(t, creds.ValidUntil(account, "aws", 61*time.Minute), "credentials should be valid") + assert.False(t, creds.ValidUntil(account, 60*time.Minute), "credentials should be valid") + assert.False(t, creds.ValidUntil(account, 61*time.Minute), "credentials should be valid") } diff --git a/cli/get.go b/cli/get.go index 6d1dc880..b5824b6d 100644 --- a/cli/get.go +++ b/cli/get.go @@ -55,12 +55,11 @@ func isMemberOfSlice(slice []string, val string) bool { return false } -func resolveApplicationInfo(cfg *Config, bypassCache bool, nameOrId string) (*Account, bool) { +func resolveApplicationInfo(cfg *Config, bypassCache bool, nameOrID string) (*Account, bool) { if bypassCache { - return &Account{ID: nameOrId, Name: nameOrId}, true - } else { - return cfg.FindAccount(nameOrId) + return &Account{ID: nameOrID, Name: nameOrID}, true } + return cfg.FindAccount(nameOrID) } var getCmd = &cobra.Command{ @@ -128,11 +127,11 @@ A role must be specified when using this command through the --role flag. You ma credentials = LoadTencentCredentialsFromEnvironment() } - if credentials.ValidUntil(account, cloudType, time.Duration(timeRemaining)*time.Minute) { - return echoCredentials(args[0], args[0], credentials, outputType, shellType, awsCliPath, tencentCliPath, cloudType) + if credentials.ValidUntil(account, time.Duration(timeRemaining)*time.Minute) { + return echoCredentials(args[0], args[0], credentials, outputType, shellType, awsCliPath, tencentCliPath) } - oauthCfg, _, err := DiscoverOAuth2Config(cmd.Context(), oidcDomain, clientID) + oauthCfg, err := DiscoverOAuth2Config(cmd.Context(), oidcDomain, clientID) if err != nil { cmd.PrintErrf("could not discover oauth2 config: %s\n", err) return nil @@ -157,7 +156,7 @@ A role must be specified when using this command through the --role flag. You ma return nil } - pair, _, ok := FindRoleInSAML(roleName, samlResponse) + pair, ok := FindRoleInSAML(roleName, samlResponse) if !ok { cmd.PrintErrf("you do not have access to the role %s on application %s\n", roleName, args[0]) return nil @@ -199,10 +198,10 @@ A role must be specified when using this command through the --role flag. You ma account.MostRecentRole = roleName } - return echoCredentials(args[0], args[0], credentials, outputType, shellType, awsCliPath, tencentCliPath, cloudType) + return echoCredentials(args[0], args[0], credentials, outputType, shellType, awsCliPath, tencentCliPath) }} -func echoCredentials(id, name string, credentials CloudCredentials, outputType, shellType, awsCliPath, tencentCliPath, cloudFlag string) error { +func echoCredentials(id, name string, credentials CloudCredentials, outputType, shellType, awsCliPath, tencentCliPath string) error { switch outputType { case outputTypeEnvironmentVariable: credentials.WriteFormat(os.Stdout, shellType) diff --git a/cli/log.go b/cli/log.go index 1ef2fe15..094c5604 100644 --- a/cli/log.go +++ b/cli/log.go @@ -6,11 +6,11 @@ import ( "golang.org/x/exp/slog" ) -type logRoundTripper struct { - rt http.RoundTripper +type LogRoundTripper struct { + RoundTripper http.RoundTripper } -func findOktaHeaders(r *http.Response) []slog.Attr { +func FindOktaHeaders(r *http.Response) []slog.Attr { var attrs []slog.Attr if hdr := r.Header.Get("X-Okta-Request-Id"); hdr != "" { attrs = append(attrs, slog.String("okta_request_id", hdr)) @@ -18,9 +18,9 @@ func findOktaHeaders(r *http.Response) []slog.Attr { return attrs } -func (t logRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { +func (t LogRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { slog.Debug("HTTP Request", slog.String("url", r.URL.String())) - resp, err := t.rt.RoundTrip(r) + resp, err := t.RoundTripper.RoundTrip(r) if err != nil { return nil, err } @@ -32,14 +32,10 @@ func (t logRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { slog.Bool("ok", resp.StatusCode == http.StatusOK), } - for _, attr := range findOktaHeaders(resp) { + for _, attr := range FindOktaHeaders(resp) { attrs = append(attrs, any(attr)) } slog.Debug("HTTP Response", attrs...) return resp, nil } - -func LogRoundTripper(rt http.RoundTripper) logRoundTripper { - return logRoundTripper{rt} -} diff --git a/cli/login.go b/cli/login.go index 614ce0d0..26651263 100644 --- a/cli/login.go +++ b/cli/login.go @@ -2,7 +2,6 @@ package main import ( "context" - "net/http" "os" "github.com/spf13/cobra" @@ -40,7 +39,7 @@ var loginCmd = &cobra.Command{ clientID, _ := cmd.Flags().GetString(FlagClientID) urlOnly, _ := cmd.Flags().GetBool(FlagURLOnly) isMachineOutput := ShouldUseMachineOutput(cmd.Flags()) || urlOnly - token, err := Login(cmd.Context(), NewHTTPClient(), oidcDomain, clientID, isMachineOutput) + token, err := Login(cmd.Context(), oidcDomain, clientID, isMachineOutput) if err != nil { return err } @@ -49,8 +48,8 @@ var loginCmd = &cobra.Command{ }, } -func Login(ctx context.Context, client *http.Client, domain, clientID string, machineOutput bool) (*oauth2.Token, error) { - oauthCfg, _, err := DiscoverOAuth2Config(ctx, domain, clientID) +func Login(ctx context.Context, domain, clientID string, machineOutput bool) (*oauth2.Token, error) { + oauthCfg, err := DiscoverOAuth2Config(ctx, domain, clientID) if err != nil { return nil, err } diff --git a/cli/oauth2.go b/cli/oauth2.go index 2c7a3c1c..7307033c 100644 --- a/cli/oauth2.go +++ b/cli/oauth2.go @@ -30,13 +30,13 @@ func NewHTTPClient() *http.Client { } } - return &http.Client{Transport: LogRoundTripper(tr)} + return &http.Client{Transport: LogRoundTripper{tr}} } -func DiscoverOAuth2Config(ctx context.Context, domain, clientID string) (*oauth2.Config, *oidc.Provider, error) { +func DiscoverOAuth2Config(ctx context.Context, domain, clientID string) (*oauth2.Config, error) { provider, err := oidc.NewProvider(ctx, domain) if err != nil { - return nil, nil, fmt.Errorf("couldn't discover OIDC configuration for %s: %w", domain, err) + return nil, fmt.Errorf("couldn't discover OIDC configuration for %s: %w", domain, err) } cfg := oauth2.Config{ @@ -45,7 +45,7 @@ func DiscoverOAuth2Config(ctx context.Context, domain, clientID string) (*oauth2 Scopes: []string{"openid", "profile", "okta.apps.read", "okta.apps.sso"}, } - return &cfg, provider, nil + return &cfg, nil } type OAuth2CallbackInfo struct { @@ -173,7 +173,7 @@ func RedirectionFlow(ctx context.Context, oauthCfg *oauth2.Config, state, codeCh return oauthCfg.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", codeVerifier)) } -func ExchangeAccessTokenForWebSSOToken(ctx context.Context, client *http.Client, oauthCfg *oauth2.Config, token *TokenSet, applicationId string) (*oauth2.Token, error) { +func ExchangeAccessTokenForWebSSOToken(ctx context.Context, client *http.Client, oauthCfg *oauth2.Config, token *TokenSet, applicationID string) (*oauth2.Token, error) { if client == nil { client = http.DefaultClient } @@ -187,7 +187,7 @@ func ExchangeAccessTokenForWebSSOToken(ctx context.Context, client *http.Client, "grant_type": {"urn:ietf:params:oauth:grant-type:token-exchange"}, // https://www.linkedin.com/pulse/oktas-aws-cli-app-mysterious-case-powerful-okta-apis-chaim-sanders/ "requested_token_type": {"urn:okta:oauth:token-type:web_sso_token"}, - "audience": {fmt.Sprintf("urn:okta:apps:%s", applicationId)}, + "audience": {fmt.Sprintf("urn:okta:apps:%s", applicationID)}, } body := strings.NewReader(data.Encode()) req, err := http.NewRequestWithContext(ctx, http.MethodPost, oauthCfg.Endpoint.TokenURL, body) diff --git a/cli/roles.go b/cli/roles.go index 6b1e58ad..9f291ec8 100644 --- a/cli/roles.go +++ b/cli/roles.go @@ -26,7 +26,7 @@ var rolesCmd = cobra.Command{ applicationID = account.ID } - oauthCfg, _, err := DiscoverOAuth2Config(cmd.Context(), oidcDomain, clientID) + oauthCfg, err := DiscoverOAuth2Config(cmd.Context(), oidcDomain, clientID) if err != nil { cmd.PrintErrf("could not discover oauth2 config: %s\n", err) return nil diff --git a/cli/saml.go b/cli/saml.go index 09ab67eb..0b973a44 100644 --- a/cli/saml.go +++ b/cli/saml.go @@ -12,10 +12,8 @@ type RoleProviderPair struct { } const ( - awsRoleUrl = "https://aws.amazon.com/SAML/Attributes/Role" - tencentRoleUrl = "https://cloud.tencent.com/SAML/Attributes/Role" - awsFlag = 0 - tencentFlag = 1 + awsFlag = 0 + tencentFlag = 1 ) func ListSAMLRoles(response *saml.Response) []string { @@ -23,15 +21,15 @@ func ListSAMLRoles(response *saml.Response) []string { return nil } - roleUrl := awsRoleUrl + roleURL := "https://aws.amazon.com/SAML/Attributes/Role" roleSubstr := "role/" - if response.GetAttribute(roleUrl) == "" { - roleUrl = tencentRoleUrl + if response.GetAttribute(roleURL) == "" { + roleURL = "https://cloud.tencent.com/SAML/Attributes/Role" roleSubstr = "roleName/" } var names []string - for _, v := range response.GetAttributeValues(roleUrl) { + for _, v := range response.GetAttributeValues(roleURL) { p := getARN(v) idx := strings.Index(p.RoleARN, roleSubstr) parts := strings.Split(p.RoleARN[idx:], "/") @@ -41,38 +39,31 @@ func ListSAMLRoles(response *saml.Response) []string { return names } -func FindRoleInSAML(roleName string, response *saml.Response) (RoleProviderPair, int, bool) { +func FindRoleInSAML(roleName string, response *saml.Response) (RoleProviderPair, bool) { if response == nil { - return RoleProviderPair{}, 0, false + return RoleProviderPair{}, false } - cloud := awsFlag - roleUrl := awsRoleUrl + roleURL := "https://aws.amazon.com/SAML/Attributes/Role" roleSubstr := "role/" - if response.GetAttribute(roleUrl) == "" { - cloud = tencentFlag - roleUrl = tencentRoleUrl + attrs := response.GetAttributeValues(roleURL) + if len(attrs) == 0 { + attrs = response.GetAttributeValues("https://cloud.tencent.com/SAML/Attributes/Role") roleSubstr = "roleName/" } - if roleName == "" && cloud == awsFlag { - // This is for legacy support. - // Legacy clients would always retrieve the first two ARNs in the list, which would be - // AWS: - // arn:cloud:iam::[account-id]:role/[onelogin_role] - // arn:cloud:iam::[account-id]:saml-provider/[saml-provider] - // If we get weird breakages with Key Conjurer when it's deployed alongside legacy clients, this is almost certainly a culprit! - pair := getARN(response.GetAttribute(roleUrl)) - return pair, cloud, false + if len(attrs) == 0 { + // The SAML assertoin contains no known roles for AWS or Tencent. + return RoleProviderPair{}, false } var pairs []RoleProviderPair - for _, v := range response.GetAttributeValues(roleUrl) { + for _, v := range response.GetAttributeValues(roleURL) { pairs = append(pairs, getARN(v)) } if len(pairs) == 0 { - return RoleProviderPair{}, cloud, false + return RoleProviderPair{}, false } var pair RoleProviderPair @@ -85,10 +76,10 @@ func FindRoleInSAML(roleName string, response *saml.Response) (RoleProviderPair, } if pair.RoleARN == "" { - return RoleProviderPair{}, cloud, false + return RoleProviderPair{}, false } - return pair, cloud, true + return pair, true } func getARN(value string) RoleProviderPair { diff --git a/cli/saml_test.go b/cli/saml_test.go index b8546c73..3797a1e7 100644 --- a/cli/saml_test.go +++ b/cli/saml_test.go @@ -11,11 +11,11 @@ func TestAwsFindRoleDoesntBreakIfYouHaveMultipleRoles(t *testing.T) { resp := saml.Response{} resp.AddAttribute("https://aws.amazon.com/SAML/Attributes/Role", "arn:cloud:iam::1234:saml-provider/Okta,arn:cloud:iam::1234:role/Admin") resp.AddAttribute("https://aws.amazon.com/SAML/Attributes/Role", "arn:cloud:iam::1234:saml-provider/Okta,arn:cloud:iam::1234:role/Power") - pair, _, err := FindRoleInSAML("Power", &resp) + pair, err := FindRoleInSAML("Power", &resp) require.True(t, err) require.Equal(t, "arn:cloud:iam::1234:saml-provider/Okta", pair.ProviderARN) require.Equal(t, "arn:cloud:iam::1234:role/Power", pair.RoleARN) - pair, _, err = FindRoleInSAML("Admin", &resp) + pair, err = FindRoleInSAML("Admin", &resp) require.True(t, err) require.Equal(t, "arn:cloud:iam::1234:saml-provider/Okta", pair.ProviderARN) require.Equal(t, "arn:cloud:iam::1234:role/Admin", pair.RoleARN) diff --git a/cli/switch.go b/cli/switch.go index 2b42a6ff..7e261111 100644 --- a/cli/switch.go +++ b/cli/switch.go @@ -89,7 +89,7 @@ This command will fail if you do not have active Cloud credentials. }, } -func getTencentCredentials(accountId, roleSessionName string) (creds CloudCredentials, err error) { +func getTencentCredentials(accountID, roleSessionName string) (creds CloudCredentials, err error) { region := os.Getenv("TENCENT_REGION") stsClient, err := tencent.NewSTSClient(region) if err != nil { @@ -102,15 +102,15 @@ func getTencentCredentials(accountId, roleSessionName string) (creds CloudCreden } arn := response.Response.Arn - roleId := "" + roleID := "" if (*arn) != "" { arns := strings.Split(*arn, ":") if len(arns) >= 5 && len(strings.Split(arns[4], "/")) >= 2 { - roleId = strings.Split(arns[4], "/")[1] + roleID = strings.Split(arns[4], "/")[1] } } - if roleId == "" { - err = fmt.Errorf("roleId is null") + if roleID == "" { + err = fmt.Errorf("roleID is null") return } @@ -118,17 +118,17 @@ func getTencentCredentials(accountId, roleSessionName string) (creds CloudCreden if err != nil { return } - roleName, err := camClient.GetRoleName(roleId) + roleName, err := camClient.GetRoleName(roleID) if err != nil { return } - resp, err := stsClient.AssumeRole(fmt.Sprintf("qcs::cam::uin/%s:roleName/%s", accountId, roleName), roleSessionName) + resp, err := stsClient.AssumeRole(fmt.Sprintf("qcs::cam::uin/%s:roleName/%s", accountID, roleName), roleSessionName) if err != nil { return } creds = CloudCredentials{ - AccountID: accountId, + AccountID: accountID, AccessKeyID: *resp.Response.Credentials.TmpSecretId, SecretAccessKey: *resp.Response.Credentials.TmpSecretKey, SessionToken: *resp.Response.Credentials.Token, @@ -139,7 +139,7 @@ func getTencentCredentials(accountId, roleSessionName string) (creds CloudCreden return creds, nil } -func getAWSCredentials(accountId, roleSessionName string) (creds CloudCredentials, err error) { +func getAWSCredentials(accountID, roleSessionName string) (creds CloudCredentials, err error) { ctx := context.Background() sess, err := session.NewSession(aws.NewConfig()) if err != nil { @@ -160,7 +160,7 @@ func getAWSCredentials(accountId, roleSessionName string) (creds CloudCredential parts := strings.Split(id.Resource, "/") arn := arn.ARN{ - AccountID: accountId, + AccountID: accountID, Partition: "aws", Service: "iam", Resource: fmt.Sprintf("role/%s", parts[1]), @@ -178,7 +178,7 @@ func getAWSCredentials(accountId, roleSessionName string) (creds CloudCredential } creds = CloudCredentials{ - AccountID: accountId, + AccountID: accountID, AccessKeyID: *resp.Credentials.AccessKeyId, SecretAccessKey: *resp.Credentials.SecretAccessKey, SessionToken: *resp.Credentials.SessionToken, diff --git a/internal/api/okta.go b/internal/api/okta.go index 71ad2e53..886c5a47 100644 --- a/internal/api/okta.go +++ b/internal/api/okta.go @@ -12,24 +12,24 @@ import ( "github.com/okta/okta-sdk-golang/v2/okta" ) -type oktaService struct { +type Okta struct { Domain *url.URL Token string client *http.Client oktaClient *okta.Client } -func NewOktaService(domain *url.URL, token string) oktaService { +func NewOktaService(domain *url.URL, token string) Okta { _, oktaClient, _ := okta.NewClient( context.Background(), okta.WithToken(token), okta.WithOrgUrl(domain.String()), ) - return oktaService{domain, token, http.DefaultClient, oktaClient} + return Okta{domain, token, http.DefaultClient, oktaClient} } -func (o oktaService) ListApplicationsForUser(ctx context.Context, user string) ([]*okta.AppLink, error) { +func (o Okta) ListApplicationsForUser(ctx context.Context, user string) ([]*okta.AppLink, error) { links, resp, err := o.oktaClient.User.ListAppLinks(ctx, user) if err != nil { return nil, err @@ -59,7 +59,7 @@ type OktaUserInfo struct { } // GetUserInfo returns user information about the given token -func (o oktaService) GetUserInfo(ctx context.Context, token string) (info OktaUserInfo, err error) { +func (o Okta) GetUserInfo(ctx context.Context, token string) (info OktaUserInfo, err error) { if o.client == nil { o.client = http.DefaultClient } diff --git a/internal/api/settings.go b/internal/api/settings.go index 95689e77..3275d43a 100644 --- a/internal/api/settings.go +++ b/internal/api/settings.go @@ -50,7 +50,7 @@ func NewSettings(ctx context.Context) (*Settings, error) { return entry.FetchSettings(ctx) } -func RetrieveSettingsFromEnv(ctx context.Context) (*Settings, error) { +func RetrieveSettingsFromEnv(_ context.Context) (*Settings, error) { s := Settings{ OktaHost: os.Getenv("OKTA_HOST"), OktaToken: os.Getenv("OKTA_TOKEN"), @@ -66,7 +66,7 @@ type VaultRetriever struct { SecretPath string } -func (v VaultRetriever) FetchSettings(ctx context.Context) (*Settings, error) { +func (v VaultRetriever) FetchSettings(_ context.Context) (*Settings, error) { var settings Settings client, err := vault.NewClient(vault.DefaultConfig()) if err != nil { diff --git a/internal/lambdaify.go b/internal/lambdaify.go index 6360a0dd..b16baa49 100644 --- a/internal/lambdaify.go +++ b/internal/lambdaify.go @@ -37,7 +37,7 @@ type lambda2HttpHandler struct { next http.Handler } -func (h lambda2HttpHandler) Invoke(ctx context.Context, b []byte) ([]byte, error) { +func (h lambda2HttpHandler) Invoke(_ context.Context, b []byte) ([]byte, error) { var inboundReq events.ALBTargetGroupRequest if err := json.Unmarshal(b, &inboundReq); err != nil { return nil, err diff --git a/internal/tencent/provider.go b/internal/tencent/provider.go index 4011399c..719adf1c 100644 --- a/internal/tencent/provider.go +++ b/internal/tencent/provider.go @@ -100,9 +100,9 @@ func NewCAMClient(region string) (*CAMClient, error) { } // APIļ¼š GetRoleName -func (c *CAMClient) GetRoleName(roleId string) (roleName string, err error) { +func (c *CAMClient) GetRoleName(roleID string) (roleName string, err error) { req := cam.NewGetRoleRequest() - req.RoleId = &roleId + req.RoleId = &roleID roleRsp, err := c.client.GetRole(req) fmt.Println(roleRsp.ToJsonString()) if err != nil { @@ -121,35 +121,35 @@ func ChainedCredsToCli() (common.CredentialIface, error) { // for tools login to STS auth type EnvProvider struct { - secretIdENV string - secretKeyENV string - tokenENV string + secretID string + secretKey string + token string } // DefaultEnvProvider return a default provider // The default environment variable name are TENCENTCLOUD_SECRET_ID and TENCENTCLOUD_SECRET_KEY and TOKEN func DefaultEnvProvider() *EnvProvider { return &EnvProvider{ - secretIdENV: "TENCENTCLOUD_SECRET_ID", - secretKeyENV: "TENCENTCLOUD_SECRET_KEY", - tokenENV: "TENCENTCLOUD_TOKEN", + secretID: "TENCENTCLOUD_SECRET_ID", + secretKey: "TENCENTCLOUD_SECRET_KEY", + token: "TENCENTCLOUD_TOKEN", } } // GetCredential func (p *EnvProvider) GetCredential() (common.CredentialIface, error) { - secretId, ok1 := os.LookupEnv(p.secretIdENV) - secretKey, ok2 := os.LookupEnv(p.secretKeyENV) - token, ok3 := os.LookupEnv(p.tokenENV) + secretID, ok1 := os.LookupEnv(p.secretID) + secretKey, ok2 := os.LookupEnv(p.secretKey) + token, ok3 := os.LookupEnv(p.token) if !ok1 || !ok2 || !ok3 { return nil, envNotSet } - if secretId == "" || secretKey == "" || token == "" { + if secretID == "" || secretKey == "" || token == "" { return nil, tcerr.NewTencentCloudSDKError(creErr, - "Environmental variable ("+p.secretIdENV+" or "+ - p.secretKeyENV+" or "+p.secretKeyENV+") is empty", "") + "Environmental variable ("+p.secretID+" or "+ + p.secretKey+" or "+p.secretKey+") is empty", "") } - return common.NewTokenCredential(secretId, secretKey, token), nil + return common.NewTokenCredential(secretID, secretKey, token), nil } var creErr = "ClientError.CredentialError"