diff --git a/aws_account.go b/aws_account.go deleted file mode 100644 index ff28c3abf..000000000 --- a/aws_account.go +++ /dev/null @@ -1,84 +0,0 @@ -package saml2aws - -import ( - "bytes" - "fmt" - "io" - "net/http" - "net/url" - - "github.com/PuerkitoBio/goquery" - "github.com/pkg/errors" -) - -// AWSAccount holds the AWS account name and roles -type AWSAccount struct { - Name string - Roles []*AWSRole -} - -// ParseAWSAccounts extract the aws accounts from the saml assertion -func ParseAWSAccounts(audience string, samlAssertion string) ([]*AWSAccount, error) { - res, err := http.PostForm(audience, url.Values{"SAMLResponse": {samlAssertion}}) - if err != nil { - return nil, errors.Wrap(err, "error retrieving AWS login form") - } - - data, err := io.ReadAll(res.Body) - if err != nil { - return nil, errors.Wrap(err, "error retrieving AWS login body") - } - - return ExtractAWSAccounts(data) -} - -// ExtractAWSAccounts extract the accounts from the AWS html page -func ExtractAWSAccounts(data []byte) ([]*AWSAccount, error) { - accounts := []*AWSAccount{} - - doc, err := goquery.NewDocumentFromReader(bytes.NewBuffer(data)) - if err != nil { - return nil, errors.Wrap(err, "failed to build document from response") - } - - doc.Find("fieldset > div.saml-account").Each(func(i int, s *goquery.Selection) { - account := new(AWSAccount) - account.Name = s.Find("div.saml-account-name").Text() - s.Find("label").Each(func(i int, s *goquery.Selection) { - role := new(AWSRole) - role.Name = s.Text() - role.RoleARN, _ = s.Attr("for") - account.Roles = append(account.Roles, role) - }) - accounts = append(accounts, account) - }) - - return accounts, nil -} - -// AssignPrincipals assign principal from roles -func AssignPrincipals(awsRoles []*AWSRole, awsAccounts []*AWSAccount) { - - awsPrincipalARNs := make(map[string]string) - for _, awsRole := range awsRoles { - awsPrincipalARNs[awsRole.RoleARN] = awsRole.PrincipalARN - } - - for _, awsAccount := range awsAccounts { - for _, awsRole := range awsAccount.Roles { - awsRole.PrincipalARN = awsPrincipalARNs[awsRole.RoleARN] - } - } - -} - -// LocateRole locate role by name -func LocateRole(awsRoles []*AWSRole, roleName string) (*AWSRole, error) { - for _, awsRole := range awsRoles { - if awsRole.RoleARN == roleName { - return awsRole, nil - } - } - - return nil, fmt.Errorf("Supplied RoleArn not found in saml assertion: %s", roleName) -} diff --git a/aws_role.go b/aws_role.go deleted file mode 100644 index 60ff2246b..000000000 --- a/aws_role.go +++ /dev/null @@ -1,60 +0,0 @@ -package saml2aws - -import ( - "fmt" - "regexp" - "strings" -) - -// AWSRole aws role attributes -type AWSRole struct { - RoleARN string - PrincipalARN string - Name string -} - -// ParseAWSRoles parses and splits the roles while also validating the contents -func ParseAWSRoles(roles []string) ([]*AWSRole, error) { - awsRoles := make([]*AWSRole, len(roles)) - - for i, role := range roles { - awsRole, err := parseRole(role) - if err != nil { - return nil, err - } - - awsRoles[i] = awsRole - } - - return awsRoles, nil -} - -func parseRole(role string) (*AWSRole, error) { - r, _ := regexp.Compile("arn:([^:\n]*):([^:\n]*):([^:\n]*):([^:\n]*):(([^:/\n]*)[:/])?([^:,\n]*)") - tokens := r.FindAllString(role, -1) - - if len(tokens) != 2 { - return nil, fmt.Errorf("Invalid role string only %d tokens", len(tokens)) - } - - awsRole := &AWSRole{} - - for _, token := range tokens { - if strings.Contains(token, ":saml-provider") { - awsRole.PrincipalARN = strings.TrimSpace(token) - } - if strings.Contains(token, ":role") { - awsRole.RoleARN = strings.TrimSpace(token) - } - } - - if awsRole.PrincipalARN == "" { - return nil, fmt.Errorf("Unable to locate PrincipalARN in: %s", role) - } - - if awsRole.RoleARN == "" { - return nil, fmt.Errorf("Unable to locate RoleARN in: %s", role) - } - - return awsRole, nil -} diff --git a/aws_role_test.go b/aws_role_test.go deleted file mode 100644 index ea9eb665d..000000000 --- a/aws_role_test.go +++ /dev/null @@ -1,32 +0,0 @@ -package saml2aws - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestParseRoles(t *testing.T) { - - roles := []string{ - "arn:aws:iam::456456456456:saml-provider/example-idp,arn:aws:iam::456456456456:role/admin", - "arn:aws:iam::456456456456:role/admin,arn:aws:iam::456456456456:saml-provider/example-idp", - } - - awsRoles, err := ParseAWSRoles(roles) - - assert.Nil(t, err) - assert.Len(t, awsRoles, 2) - - for _, awsRole := range awsRoles { - assert.Equal(t, "arn:aws:iam::456456456456:saml-provider/example-idp", awsRole.PrincipalARN) - assert.Equal(t, "arn:aws:iam::456456456456:role/admin", awsRole.RoleARN) - } - - roles = []string{""} - awsRoles, err = ParseAWSRoles(roles) - - assert.NotNil(t, err) - assert.Nil(t, awsRoles) - -} diff --git a/cloud_account.go b/cloud_account.go new file mode 100644 index 000000000..0a5321982 --- /dev/null +++ b/cloud_account.go @@ -0,0 +1,136 @@ +package saml2aws + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/url" + + "github.com/PuerkitoBio/goquery" + "github.com/pkg/errors" + "github.com/versent/saml2aws/v2/pkg/cloud" +) + +// CloudAccount holds the AWS account name and roles +type CloudAccount struct { + Name string + Roles []*CloudRole +} + +func AssignAWSAccounts(awsRoles []*CloudRole, samlXml []byte, samlAssertion string) ([]*CloudRole, error) { + roleArnMap := make(map[string]*CloudRole) + for _, role := range awsRoles { + roleArnMap[role.RoleARN] = role + } + + aud, err := ExtractDestinationURL(samlXml) + if err != nil { + return nil, errors.Wrap(err, "Error parsing destination URL.") + } + + res, err := http.PostForm(aud, url.Values{"SAMLResponse": {samlAssertion}}) + if err != nil { + return nil, errors.Wrap(err, "Error retrieving cloud SAML login form.") + } + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Unexpected status code: %d", res.StatusCode) + } + + data, err := io.ReadAll(res.Body) + if err != nil { + return nil, errors.Wrap(err, "Error retrieving AWS login body.") + } + + doc, err := goquery.NewDocumentFromReader(bytes.NewBuffer(data)) + if err != nil { + return nil, errors.Wrap(err, "failed to build document from response") + } + + doc.Find("fieldset > div.saml-account").Each(func(i int, s *goquery.Selection) { + name := s.Find("div.saml-account-name").Text() + s.Find("label").Each(func(i int, s *goquery.Selection) { + arn, _ := s.Attr("for") + roleArnMap[arn].Account = name + // log.Println("Marked role", arn, "as belonging to account", name) + }) + }) + + return awsRoles, nil +} + +// ParseCloudAccounts extract the aws accounts from the saml assertion +func ParseCloudAccounts(provider cloud.Provider, audience string, samlAssertion string) ([]*CloudAccount, error) { + + switch provider { + case cloud.AWS: + res, err := http.PostForm(audience, url.Values{"SAMLResponse": {samlAssertion}}) + if err != nil { + return nil, errors.Wrap(err, "error retrieving cloud SAML login form") + } + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", res.StatusCode) + } + + data, err := io.ReadAll(res.Body) + if err != nil { + return nil, errors.Wrap(err, "error retrieving AWS login body") + } + return ExtractAWSAccounts(data) + case cloud.TencentCloud: + return nil, nil + default: + return nil, fmt.Errorf("unsupported cloud provider: %s", provider) + } +} + +// ExtractAWSAccounts extract the accounts from the AWS html page +func ExtractAWSAccounts(data []byte) ([]*CloudAccount, error) { + accounts := make([]*CloudAccount, 0) + + doc, err := goquery.NewDocumentFromReader(bytes.NewBuffer(data)) + if err != nil { + return nil, errors.Wrap(err, "failed to build document from response") + } + + doc.Find("fieldset > div.saml-account").Each(func(i int, s *goquery.Selection) { + account := new(CloudAccount) + account.Name = s.Find("div.saml-account-name").Text() + s.Find("label").Each(func(i int, s *goquery.Selection) { + role := new(CloudRole) + role.Name = s.Text() + role.RoleARN, _ = s.Attr("for") + account.Roles = append(account.Roles, role) + }) + accounts = append(accounts, account) + }) + + return accounts, nil +} + +// AssignPrincipals assign principal from roles +func AssignPrincipals(awsRoles []*CloudRole, cloudAccounts []*CloudAccount) { + + awsPrincipalARNs := make(map[string]string) + for _, awsRole := range awsRoles { + awsPrincipalARNs[awsRole.RoleARN] = awsRole.PrincipalARN + } + + for _, awsAccount := range cloudAccounts { + for _, awsRole := range awsAccount.Roles { + awsRole.PrincipalARN = awsPrincipalARNs[awsRole.RoleARN] + } + } + +} + +// LocateRole locate role by name +func LocateRole(awsRoles []*CloudRole, roleName string) (*CloudRole, error) { + for _, awsRole := range awsRoles { + if awsRole.RoleARN == roleName { + return awsRole, nil + } + } + + return nil, fmt.Errorf("Supplied RoleArn not found in saml assertion: %s", roleName) +} diff --git a/aws_account_test.go b/cloud_account_test.go similarity index 94% rename from aws_account_test.go rename to cloud_account_test.go index b5c86c43b..60ba0b250 100644 --- a/aws_account_test.go +++ b/cloud_account_test.go @@ -36,16 +36,16 @@ func TestExtractAWSAccounts(t *testing.T) { } func TestAssignPrincipals(t *testing.T) { - awsRoles := []*AWSRole{ + awsRoles := []*CloudRole{ { PrincipalARN: "arn:aws:iam::000000000001:saml-provider/test-idp", RoleARN: "arn:aws:iam::000000000001:role/Development", }, } - awsAccounts := []*AWSAccount{ + awsAccounts := []*CloudAccount{ { - Roles: []*AWSRole{ + Roles: []*CloudRole{ { RoleARN: "arn:aws:iam::000000000001:role/Development", }, @@ -59,7 +59,7 @@ func TestAssignPrincipals(t *testing.T) { } func TestLocateRole(t *testing.T) { - awsRoles := []*AWSRole{ + awsRoles := []*CloudRole{ { PrincipalARN: "arn:aws:iam::000000000001:saml-provider/test-idp", RoleARN: "arn:aws:iam::000000000001:role/Development", diff --git a/cloud_role.go b/cloud_role.go new file mode 100644 index 000000000..b1587441b --- /dev/null +++ b/cloud_role.go @@ -0,0 +1,81 @@ +package saml2aws + +import ( + "fmt" + "regexp" + "strings" + + "github.com/versent/saml2aws/v2/pkg/cloud" +) + +// CloudRole aws role attributes +type CloudRole struct { + Provider cloud.Provider + RoleARN string + PrincipalARN string + Name string + Account string +} + +// ParseCloudRoles parses and splits the roles while also validating the contents +func ParseCloudRoles(roles []string, cp cloud.Provider) ([]*CloudRole, error) { + awsRoles := make([]*CloudRole, len(roles)) + + for i, role := range roles { + awsRole, err := parseRole(role, cp) + if err != nil { + return nil, err + } + + awsRoles[i] = awsRole + } + + return awsRoles, nil +} + +func parseRole(role string, cp cloud.Provider) (*CloudRole, error) { + var r *regexp.Regexp + switch cp { + case cloud.AWS: + r, _ = regexp.Compile("arn:([^:\n]*):([^:\n]*):([^:\n]*):([^:\n]*):(([^:/\n]*)[:/])?([^:,\n]*)") + case cloud.TencentCloud: + r, _ = regexp.Compile("qcs::([^:]*):([^:]*):([^:]*):([^:/]*)(/[^,]*)?") + + default: + return nil, fmt.Errorf("Invalid provider:") + } + + // log.Println("Parsing role: ", role) + + tokens := r.FindAllString(role, -1) + if len(tokens) != 2 { + return nil, fmt.Errorf("Invalid role string only %d tokens", len(tokens)) + } + + providerRole := &CloudRole{} + for _, token := range tokens { + if strings.Contains(token, ":saml-provider") { + providerRole.PrincipalARN = strings.TrimSpace(token) + } + if strings.Contains(token, ":role") { + providerRole.RoleARN = strings.TrimSpace(token) + if cp == cloud.AWS { + providerRole.Name = strings.Split(token, "/")[1] + } else if cp == cloud.TencentCloud { + providerRole.Name = strings.Split(token, "/")[2] + providerRole.Account = strings.Split(strings.Split(token, "/")[1], ":")[0] + } + } + } + providerRole.Provider = cp + + if providerRole.PrincipalARN == "" { + return nil, fmt.Errorf("Unable to locate PrincipalARN in: %s", role) + } + + if providerRole.RoleARN == "" { + return nil, fmt.Errorf("Unable to locate RoleARN in: %s", role) + } + + return providerRole, nil +} diff --git a/cloud_role_test.go b/cloud_role_test.go new file mode 100644 index 000000000..8cad9faef --- /dev/null +++ b/cloud_role_test.go @@ -0,0 +1,54 @@ +package saml2aws + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseRoles(t *testing.T) { + + roles := []string{ + "arn:aws:iam::456456456456:saml-provider/example-idp,arn:aws:iam::456456456456:role/admin", + "arn:aws:iam::456456456456:role/admin,arn:aws:iam::456456456456:saml-provider/example-idp", + } + + awsRoles, err := ParseCloudRoles(roles, "AWS") + + assert.Nil(t, err) + assert.Len(t, awsRoles, 2) + + for _, awsRole := range awsRoles { + assert.Equal(t, "arn:aws:iam::456456456456:saml-provider/example-idp", awsRole.PrincipalARN) + assert.Equal(t, "arn:aws:iam::456456456456:role/admin", awsRole.RoleARN) + } + + roles = []string{""} + awsRoles, err = ParseCloudRoles(roles, "AWS") + + assert.NotNil(t, err) + assert.Nil(t, awsRoles) + + // TencentCloud Roles + roles = []string{ + "qcs::cam::uin/888888888888:roleName/dage,qcs::cam::uin/888888888888:saml-provider/example-provider-idp", + "qcs::cam::uin/888888888888:saml-provider/example-provider-idp,qcs::cam::uin/888888888888:roleName/dage", + } + + tencentcloudRoles, err := ParseCloudRoles(roles, "TencentCloud") + + assert.Nil(t, err) + assert.Len(t, tencentcloudRoles, 2) + + for _, tencentcloudRole := range tencentcloudRoles { + assert.Equal(t, "qcs::cam::uin/888888888888:saml-provider/example-provider-idp", tencentcloudRole.PrincipalARN) + assert.Equal(t, "qcs::cam::uin/888888888888:roleName/dage", tencentcloudRole.RoleARN) + } + + roles = []string{""} + tencentcloudRoles, err = ParseCloudRoles(roles, "TencentCloud") + + assert.NotNil(t, err) + assert.Nil(t, tencentcloudRoles) + +} diff --git a/cmd/saml2aws/commands/list_roles.go b/cmd/saml2aws/commands/list_roles.go index 75c522606..f62ee22c6 100644 --- a/cmd/saml2aws/commands/list_roles.go +++ b/cmd/saml2aws/commands/list_roles.go @@ -2,6 +2,7 @@ package commands import ( b64 "encoding/base64" + "encoding/json" "fmt" "log" "os" @@ -10,6 +11,7 @@ import ( "github.com/sirupsen/logrus" "github.com/versent/saml2aws/v2" "github.com/versent/saml2aws/v2/helper/credentials" + "github.com/versent/saml2aws/v2/pkg/cloud" "github.com/versent/saml2aws/v2/pkg/flags" "github.com/versent/saml2aws/v2/pkg/samlcache" ) @@ -89,62 +91,46 @@ func ListRoles(loginFlags *flags.LoginExecFlags) error { } } - data, err := b64.StdEncoding.DecodeString(samlAssertion) - if err != nil { - return errors.Wrap(err, "error decoding saml assertion") - } - - roles, err := saml2aws.ExtractAwsRoles(data) - if err != nil { - return errors.Wrap(err, "error parsing aws roles") - } - - if len(roles) == 0 { - log.Println("No roles to assume") - os.Exit(1) - } - - awsRoles, err := saml2aws.ParseAWSRoles(roles) - if err != nil { - return errors.Wrap(err, "error parsing aws roles") - } - - if err := listRoles(awsRoles, samlAssertion, loginFlags); err != nil { - return errors.Wrap(err, "Failed to list roles") - } - - return nil -} - -func listRoles(awsRoles []*saml2aws.AWSRole, samlAssertion string, loginFlags *flags.LoginExecFlags) error { - if len(awsRoles) == 0 { - return errors.New("no roles available") - } - - samlAssertionData, err := b64.StdEncoding.DecodeString(samlAssertion) - if err != nil { - return errors.Wrap(err, "error decoding saml assertion") + samlAssertions := make(map[cloud.Provider]string) + if loginDetails.TencentCloudURL != "" { + // If TencentCloud is configured, unmarshal the SAML assertion for both AWS and TencentCloud + if err = json.Unmarshal([]byte(samlAssertion), &samlAssertions); err != nil { + return errors.Wrap(err, "error unmarshalling saml assertion. (Devsisters custom implementation)") + } + } else { + // Only AWS is configured, proceed with normal saml2aws flow + samlAssertions[cloud.AWS] = samlAssertion } - aud, err := saml2aws.ExtractDestinationURL(samlAssertionData) - if err != nil { - return errors.Wrap(err, "error parsing destination url") - } + cloudRoles := make([]*saml2aws.CloudRole, 0) + for cloud, assertion := range samlAssertions { + data, err := b64.StdEncoding.DecodeString(assertion) + if err != nil { + return errors.Wrap(err, "error decoding SAML assertion") + } - awsAccounts, err := saml2aws.ParseAWSAccounts(aud, samlAssertion) - if err != nil { - return errors.Wrap(err, "error parsing aws role accounts") - } + roleArns, err := saml2aws.ExtractCloudRoles(data) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("error extracting %v role arns", cloud)) + } + if len(roleArns) == 0 { + // log.Println("No", cloud, "roles to assyyume") + continue + } - saml2aws.AssignPrincipals(awsRoles, awsAccounts) + roles, err := saml2aws.ParseCloudRoles(roleArns, cloud) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("error parsing %s roles", cloud)) + } + cloudRoles = append(cloudRoles, roles...) - log.Println("") - for _, account := range awsAccounts { - fmt.Println(account.Name) - for _, role := range account.Roles { - fmt.Println(role.RoleARN) + for _, role := range roles { + log.Println(fmt.Sprintf("%v (%v)", role.RoleARN, cloud)) } - fmt.Println("") + } + if len(cloudRoles) == 0 { + log.Println("No cloud provider roles to assume") + os.Exit(1) } return nil diff --git a/cmd/saml2aws/commands/login.go b/cmd/saml2aws/commands/login.go index e1cc12166..a276d3517 100644 --- a/cmd/saml2aws/commands/login.go +++ b/cmd/saml2aws/commands/login.go @@ -7,20 +7,24 @@ import ( "log" "os" "strings" - "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sts" "github.com/pkg/errors" "github.com/sirupsen/logrus" + "github.com/tencentcloud/tencentcloud-sdk-go-intl-en/tencentcloud/common" + "github.com/tencentcloud/tencentcloud-sdk-go-intl-en/tencentcloud/common/profile" + tcsts "github.com/tencentcloud/tencentcloud-sdk-go-intl-en/tencentcloud/sts/v20180813" "github.com/versent/saml2aws/v2" "github.com/versent/saml2aws/v2/helper/credentials" "github.com/versent/saml2aws/v2/pkg/awsconfig" "github.com/versent/saml2aws/v2/pkg/cfg" + "github.com/versent/saml2aws/v2/pkg/cloud" "github.com/versent/saml2aws/v2/pkg/creds" "github.com/versent/saml2aws/v2/pkg/flags" "github.com/versent/saml2aws/v2/pkg/samlcache" + "github.com/versent/saml2aws/v2/pkg/tcconfig" ) // Login login to ADFS @@ -33,49 +37,18 @@ func Login(loginFlags *flags.LoginExecFlags) error { return errors.Wrap(err, "Error building login details.") } - sharedCreds := awsconfig.NewSharedCredentials(account.Profile, account.CredentialsFile) // creates a cacheProvider, only used when --cache is set cacheProvider := &samlcache.SAMLCacheProvider{ Account: account.Name, Filename: account.SAMLCacheFile, } - logger.Debug("Check if creds exist.") - - // this checks if the credentials file has been created yet - exist, err := sharedCreds.CredsExists() - if err != nil { - return errors.Wrap(err, "Error loading credentials.") - } - if !exist { - log.Println("Unable to load credentials. Login required to create them.") - return nil - } - - if !sharedCreds.Expired() && !loginFlags.Force { - logger.Debug("Credentials are not expired. Skipping.") - previousCreds, err := sharedCreds.Load() - if err != nil { - log.Println("Unable to load cached credentials.") - } else { - logger.Debug("Credentials will expire at ", previousCreds.Expires) - } - if loginFlags.CredentialProcess { - err = PrintCredentialProcess(previousCreds) - if err != nil { - return err - } - } - return nil - } - loginDetails, err := resolveLoginDetails(account, loginFlags) if err != nil { - log.Printf("%+v", err) - os.Exit(1) + return err } - logger.WithField("idpAccount", account).Debug("building provider") + logger.WithField("idpAccount", account).Debug("building samlProvider") provider, err := saml2aws.NewSAMLClient(account) if err != nil { @@ -120,7 +93,7 @@ func Login(loginFlags *flags.LoginExecFlags) error { log.Println("Response did not contain a valid SAML assertion.") log.Println("Please check that your username and password is correct.") log.Println("To see the output follow the instructions in https://github.com/versent/saml2aws#debugging-issues-with-idps") - os.Exit(1) + return errors.New("Response did not contain a valid SAML assertion.") } if !loginFlags.CommonFlags.DisableKeychain { @@ -130,37 +103,62 @@ func Login(loginFlags *flags.LoginExecFlags) error { } } - role, err := selectAwsRole(samlAssertion, account) - if err != nil { - return errors.Wrap(err, "Failed to assume role. Please check whether you are permitted to assume the given role for the AWS service.") - } + // log.Println("SAML assertion:", samlAssertion) - log.Println("Selected role:", role.RoleARN) + samlAssertions := make(map[cloud.Provider]string) + if loginDetails.TencentCloudURL != "" { + // If TencentCloud is configured, unmarshal the SAML assertion for both AWS and TencentCloud + if err = json.Unmarshal([]byte(samlAssertion), &samlAssertions); err != nil { + return errors.Wrap(err, "error unmarshalling saml assertion. (Devsisters custom implementation)") + } + } else { + // Only AWS is configured, proceed with normal saml2aws flow + samlAssertions[cloud.AWS] = samlAssertion + } - awsCreds, err := loginToStsUsingRole(account, role, samlAssertion) + role, err := selectCloudRole(samlAssertions, account) if err != nil { - return errors.Wrap(err, "Error logging into AWS role using SAML assertion.") + return errors.Wrap(err, "Error resolving role.") } + log.Println("Selected role:", role.RoleARN) - // print credential process if needed - if loginFlags.CredentialProcess { - err = PrintCredentialProcess(awsCreds) + switch role.Provider { + case cloud.AWS: + creds, err := assumeAwsRoleWithSAML(account, role, samlAssertions[role.Provider]) if err != nil { + return errors.Wrap(err, "Error logging into AWS role using SAML assertion.") + } + cp := awsconfig.NewSharedCredentials(account.Profile, account.CredentialsFile) + if err := cp.Save(creds); err != nil { return err } - } else { - err = saveCredentials(awsCreds, sharedCreds) + + log.Println("Logged in as:", creds.PrincipalARN) + log.Println("") + log.Println("Your new access key pair has been stored in the AWS configuration.") + log.Printf("Note that it will expire at %v", creds.Expires) + if account.Profile != "default" { + log.Println("To use this credential, call the AWS CLI with the --profile option (e.g. aws --profile", account.Profile, "ec2 describe-instances).") + } + case cloud.TencentCloud: + creds, err := assumeTencentRoleWithSAML(account, role, samlAssertions[role.Provider]) if err != nil { + return errors.Wrap(err, "Error logging into TencentCloud role using SAML assertion.") + } + cp := tcconfig.NewSharedCredentials(account.Profile, account.CredentialsFile) + if err := cp.Save(creds); err != nil { return err } - log.Println("Logged in as:", awsCreds.PrincipalARN) + log.Println("Logged in as:", creds.PrincipalARN) log.Println("") - log.Println("Your new access key pair has been stored in the AWS configuration.") - log.Printf("Note that it will expire at %v", awsCreds.Expires) - if sharedCreds.Profile != "default" { - log.Println("To use this credential, call the AWS CLI with the --profile option (e.g. aws --profile", sharedCreds.Profile, "ec2 describe-instances).") + log.Println("Your new secret key pair has been stored in the TencentCloud configuration.") + log.Printf("Note that it will expire at %v", creds.Expires) + if account.Profile != "default" { + log.Println("To use this credential, call the TC CLI with the --profile option (e.g. tccli --profile", account.Profile, "cvm DescribeInstances).") } + default: + return errors.Wrap(err, "Error resolving role (unknown provider).") } return nil @@ -180,8 +178,7 @@ func buildIdpAccount(loginFlags *flags.LoginExecFlags) (*cfg.IDPAccount, error) // update username and hostname if supplied flags.ApplyFlagOverrides(loginFlags.CommonFlags, account) - err = account.Validate() - if err != nil { + if err := account.Validate(); err != nil { return nil, errors.Wrap(err, "Failed to validate account.") } @@ -192,7 +189,7 @@ func resolveLoginDetails(account *cfg.IDPAccount, loginFlags *flags.LoginExecFla // log.Printf("loginFlags %+v", loginFlags) - loginDetails := &creds.LoginDetails{URL: account.URL, Username: account.Username, MFAToken: loginFlags.CommonFlags.MFAToken, DuoMFAOption: loginFlags.DuoMFAOption} + loginDetails := &creds.LoginDetails{URL: account.URL, TencentCloudURL: account.TencentCloudURL, Username: account.Username, MFAToken: loginFlags.CommonFlags.MFAToken, DuoMFAOption: loginFlags.DuoMFAOption} log.Printf("Using IdP Account %s to access %s %s", loginFlags.CommonFlags.IdpAccount, account.Provider, account.URL) @@ -263,69 +260,58 @@ func resolveLoginDetails(account *cfg.IDPAccount, loginFlags *flags.LoginExecFla return loginDetails, nil } -func selectAwsRole(samlAssertion string, account *cfg.IDPAccount) (*saml2aws.AWSRole, error) { - data, err := b64.StdEncoding.DecodeString(samlAssertion) - if err != nil { - return nil, errors.Wrap(err, "Error decoding SAML assertion.") - } +func selectCloudRole(samlAssertions map[cloud.Provider]string, account *cfg.IDPAccount) (*saml2aws.CloudRole, error) { + cloudRoles := make([]*saml2aws.CloudRole, 0) + for cloudProvider, assertion := range samlAssertions { + data, err := b64.StdEncoding.DecodeString(assertion) + if err != nil { + return nil, errors.Wrap(err, "Error decoding SAML assertion.") + } - roles, err := saml2aws.ExtractAwsRoles(data) - if err != nil { - return nil, errors.Wrap(err, "Error parsing AWS roles.") - } + roleArns, err := saml2aws.ExtractCloudRoles(data) + if err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("Error extracting %v roles arns.", cloudProvider)) + } + if len(roleArns) == 0 { + log.Println("No", cloudProvider, "roles to assume.") + continue + } - if len(roles) == 0 { - log.Println("No roles to assume.") - log.Println("Please check you are permitted to assume roles for the AWS service.") - os.Exit(1) + roles, err := saml2aws.ParseCloudRoles(roleArns, cloudProvider) + if err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("Error parsing %v roles.", cloudProvider)) + } + + if cloudProvider == cloud.AWS { + roles, err = saml2aws.AssignAWSAccounts(roles, data, assertion) + if err != nil { + return nil, errors.Wrap(err, "Error assigning AWS accounts to roles.") + } + } + + cloudRoles = append(cloudRoles, roles...) } - awsRoles, err := saml2aws.ParseAWSRoles(roles) - if err != nil { - return nil, errors.Wrap(err, "Error parsing AWS roles.") + if len(cloudRoles) == 0 { + log.Println("Please check you are permitted to assume roles for the AWS or TencentCloud service.") + os.Exit(1) } - return resolveRole(awsRoles, samlAssertion, account) + return resolveRole(cloudRoles, account) } -func resolveRole(awsRoles []*saml2aws.AWSRole, samlAssertion string, account *cfg.IDPAccount) (*saml2aws.AWSRole, error) { - var role = new(saml2aws.AWSRole) - - if len(awsRoles) == 1 { +func resolveRole(cloudRoles []*saml2aws.CloudRole, account *cfg.IDPAccount) (role *saml2aws.CloudRole, err error) { + if len(cloudRoles) == 1 { if account.RoleARN != "" { - return saml2aws.LocateRole(awsRoles, account.RoleARN) + return saml2aws.LocateRole(cloudRoles, account.RoleARN) } - return awsRoles[0], nil - } else if len(awsRoles) == 0 { + return cloudRoles[0], nil + } else if len(cloudRoles) == 0 { return nil, errors.New("No roles available.") } - samlAssertionData, err := b64.StdEncoding.DecodeString(samlAssertion) - if err != nil { - return nil, errors.Wrap(err, "Error decoding SAML assertion.") - } - - aud, err := saml2aws.ExtractDestinationURL(samlAssertionData) - if err != nil { - return nil, errors.Wrap(err, "Error parsing destination URL.") - } - - awsAccounts, err := saml2aws.ParseAWSAccounts(aud, samlAssertion) - if err != nil { - return nil, errors.Wrap(err, "Error parsing AWS role accounts.") - } - if len(awsAccounts) == 0 { - return nil, errors.New("No accounts available.") - } - - saml2aws.AssignPrincipals(awsRoles, awsAccounts) - - if account.RoleARN != "" { - return saml2aws.LocateRole(awsRoles, account.RoleARN) - } - for { - role, err = saml2aws.PromptForAWSRoleSelection(awsAccounts) + role, err = saml2aws.PromptForCloudRoleSelection(cloudRoles) if err == nil { break } @@ -335,7 +321,7 @@ func resolveRole(awsRoles []*saml2aws.AWSRole, samlAssertion string, account *cf return role, nil } -func loginToStsUsingRole(account *cfg.IDPAccount, role *saml2aws.AWSRole, samlAssertion string) (*awsconfig.AWSCredentials, error) { +func assumeAwsRoleWithSAML(account *cfg.IDPAccount, role *saml2aws.CloudRole, samlAssertion string) (*awsconfig.AWSCredentials, error) { sess, err := session.NewSession(&aws.Config{ Region: &account.Region, @@ -371,50 +357,109 @@ func loginToStsUsingRole(account *cfg.IDPAccount, role *saml2aws.AWSRole, samlAs }, nil } -func saveCredentials(awsCreds *awsconfig.AWSCredentials, sharedCreds *awsconfig.CredentialsProvider) error { - err := sharedCreds.Save(awsCreds) - if err != nil { - return errors.Wrap(err, "Error saving credentials.") - } +func assumeTencentRoleWithSAML(account *cfg.IDPAccount, role *saml2aws.CloudRole, samlAssertion string) (*tcconfig.TCCredentials, error) { - return nil -} + credential := common.NewCredential("", "") -// CredentialsToCredentialProcess -// Returns a Json output that is compatible with the AWS credential_process -// https://github.com/awslabs/awsprocesscreds -func CredentialsToCredentialProcess(awsCreds *awsconfig.AWSCredentials) (string, error) { + clientProfile := profile.NewClientProfile() - type AWSCredentialProcess struct { - Version int - AccessKeyId string - SecretAccessKey string - SessionToken string - Expiration string + client, err := tcsts.NewClient(credential, "", clientProfile) + if err != nil { + log.Fatalf("Failed to create sts client: %v", err) } - - cred_process := AWSCredentialProcess{ - Version: 1, - AccessKeyId: awsCreds.AWSAccessKey, - SecretAccessKey: awsCreds.AWSSecretKey, - SessionToken: awsCreds.AWSSessionToken, - Expiration: awsCreds.Expires.Format(time.RFC3339), + region, ok := convertAWSRegionToTencentCloud(account.Region) + if !ok { + log.Println("Selected region %v is unknown or not available in TencentCloud. Selecting %v in best effort.", account.Region, region) } + client.Init(region) + + log.Println("Requesting TencentCloud credentials using SAML assertion.") - p, err := json.Marshal(cred_process) + samlRequest := tcsts.NewAssumeRoleWithSAMLRequest() + sessionDuration := uint64(account.SessionDuration) + samlRequest.SAMLAssertion = &samlAssertion + samlRequest.PrincipalArn = &role.PrincipalARN + samlRequest.RoleArn = &role.RoleARN + samlRequest.DurationSeconds = &sessionDuration + samlRequest.RoleSessionName = &account.Username + + // log.Println(fmt.Sprintf("tccli sts AssumeRoleWithSAML --PrincipalArn %v --RoleArn %v --SAMLAssertion %v --DurationSeconds %v --RoleSessionName %v", role.PrincipalARN, role.RoleARN, samlAssertion, sessionDuration, account.Username)) + + resp, err := client.AssumeRoleWithSAML(samlRequest) if err != nil { - return "", errors.Wrap(err, "Error while marshalling the credential process.") + return nil, errors.Wrap(err, "Error retrieving STS credentials using SAML.") } - return string(p), nil + return &tcconfig.TCCredentials{ + SecretID: aws.StringValue(resp.Response.Credentials.TmpSecretId), + SecretKey: aws.StringValue(resp.Response.Credentials.TmpSecretKey), + Token: aws.StringValue(resp.Response.Credentials.Token), + Region: account.Region, + Expires: aws.StringValue(resp.Response.Expiration), + PrincipalARN: role.PrincipalARN, + }, nil } -// PrintCredentialProcess Prints a Json output that is compatible with the AWS credential_process -// https://github.com/awslabs/awsprocesscreds -func PrintCredentialProcess(awsCreds *awsconfig.AWSCredentials) error { - jsonData, err := CredentialsToCredentialProcess(awsCreds) - if err == nil { - fmt.Println(jsonData) +// convertAWSRegionToTencentCloud converts AWS regions to TencentCloud regions. Returns the TencentCloud region and a boolean indicating if the region is directly supported in TencentCloud. +func convertAWSRegionToTencentCloud(region string) (string, bool) { + switch region { + case "us-east-1": + return "na-ashburn", true + case "us-east-2": + return "na-toronto", true + case "us-west-1": + return "na-siliconvalley", true + case "us-west-2": + return "na-siliconvalley", false + case "af-south-1": + return "ap-mumbai", false + case "ap-east-1": + return "ap-hongkong", true + case "ap-south-1": + return "ap-mumbai", true + case "ap-south-2": + return "ap-mumbai", false + case "ap-southeast-1": + return "ap-singapore", true + case "ap-southeast-2": + return "ap-jakarta", false + case "ap-southeast-3": + return "ap-jakarta", true + case "ap-southeast-4": + return "ap-jakarta", false + case "ap-northeast-1": + return "ap-tokyo", true + case "ap-northeast-2": + return "ap-seoul", true + case "ap-northeast-3": + return "ap-tokyo", false + case "ca-central-1": + return "na-toronto", false + case "ca-west-1": + return "na-siliconvalley", false + case "eu-central-1": + return "eu-frankfurt", true + case "eu-central-2": + return "eu-frankfurt", false + case "eu-west-1": + return "eu-frankfurt", false + case "eu-west-2": + return "eu-frankfurt", false + case "eu-south-1": + return "eu-frankfurt", false + case "eu-south-2": + return "eu-frankfurt", false + case "eu-north-1": + return "eu-frankfurt", false + case "il-central-1": + return "ap-mumbai", false + case "me-south-1": + return "ap-mumbai", false + case "me-central-1": + return "ap-mumbai", false + case "sa-east-1": + return "sa-saopaulo", true + default: + return "ap-tokyo", false } - return err } diff --git a/cmd/saml2aws/commands/login_test.go b/cmd/saml2aws/commands/login_test.go index bca0442cf..6d217b4bb 100644 --- a/cmd/saml2aws/commands/login_test.go +++ b/cmd/saml2aws/commands/login_test.go @@ -65,13 +65,13 @@ func TestOktaResolveLoginDetailsWithFlags(t *testing.T) { func TestResolveRoleSingleEntry(t *testing.T) { - adminRole := &saml2aws.AWSRole{ + adminRole := &saml2aws.CloudRole{ Name: "admin", RoleARN: "arn:aws:iam::456456456456:saml-provider/example-idp,arn:aws:iam::456456456456:role/admin", PrincipalARN: "arn:aws:iam::456456456456:role/admin,arn:aws:iam::456456456456:saml-provider/example-idp", } - awsRoles := []*saml2aws.AWSRole{ + awsRoles := []*saml2aws.CloudRole{ adminRole, } diff --git a/cmd/saml2aws/main.go b/cmd/saml2aws/main.go index bc0619f5d..f3ab669aa 100644 --- a/cmd/saml2aws/main.go +++ b/cmd/saml2aws/main.go @@ -77,7 +77,8 @@ func main() { app.Flag("browser-autofill", "Configures browser to autofill the username and password. (env: SAML2AWS_BROWSER_AUTOFILL)").Envar("SAML2AWS_BROWSER_AUTOFILL").BoolVar(&commonFlags.BrowserAutoFill) app.Flag("mfa", "The name of the mfa. (env: SAML2AWS_MFA)").Envar("SAML2AWS_MFA").StringVar(&commonFlags.MFA) app.Flag("skip-verify", "Skip verification of server certificate. (env: SAML2AWS_SKIP_VERIFY)").Envar("SAML2AWS_SKIP_VERIFY").Short('s').BoolVar(&commonFlags.SkipVerify) - app.Flag("url", "The URL of the SAML IDP server used to login. (env: SAML2AWS_URL)").Envar("SAML2AWS_URL").StringVar(&commonFlags.URL) + app.Flag("url", "The URL of the AWS SAML IDP server used to login. (env: SAML2AWS_URL)").Envar("SAML2AWS_URL").StringVar(&commonFlags.URL) + app.Flag("tencentcloud-url", "The URL of the TencentCloud SAML IDP server used to login. (env: SAML2AWS_TENCENTCLOUD_URL)").Envar("SAML2AWS_TENCENTCLOUD_URL").StringVar(&commonFlags.TencentCloudURL) app.Flag("username", "The username used to login. (env: SAML2AWS_USERNAME)").Envar("SAML2AWS_USERNAME").StringVar(&commonFlags.Username) app.Flag("password", "The password used to login. (env: SAML2AWS_PASSWORD)").Envar("SAML2AWS_PASSWORD").StringVar(&commonFlags.Password) app.Flag("mfa-token", "The current MFA token (supported in Keycloak, ADFS, GoogleApps). (env: SAML2AWS_MFA_TOKEN)").Envar("SAML2AWS_MFA_TOKEN").StringVar(&commonFlags.MFAToken) diff --git a/go.mod b/go.mod index 7d738e995..490fd5553 100644 --- a/go.mod +++ b/go.mod @@ -50,12 +50,13 @@ require ( github.com/mtibben/percent v0.2.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.5.2 // indirect + github.com/tencentcloud/tencentcloud-sdk-go-intl-en v3.0.961+incompatible // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.22.0 // indirect golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc // indirect - golang.org/x/sys v0.19.0 // indirect + golang.org/x/sys v0.20.0 // indirect golang.org/x/term v0.19.0 // indirect golang.org/x/text v0.14.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index e40786d01..b87680ab6 100644 --- a/go.sum +++ b/go.sum @@ -179,6 +179,8 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tencentcloud/tencentcloud-sdk-go-intl-en v3.0.961+incompatible h1:F0j3fWCiiFZY4/zEzCCPZ0P9xwMub6LH7blNfoC5EWw= +github.com/tencentcloud/tencentcloud-sdk-go-intl-en v3.0.961+incompatible/go.mod h1:72Wo6Gt6F8d8V+njrAmduVoT9QjPwCyXktpqCWr7PUc= github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U= github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= @@ -247,8 +249,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= -golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= diff --git a/input.go b/input.go index f4c3dc00d..ba268af14 100644 --- a/input.go +++ b/input.go @@ -3,7 +3,6 @@ package saml2aws import ( "fmt" "log" - "sort" "github.com/pkg/errors" "github.com/versent/saml2aws/v2/pkg/cfg" @@ -40,6 +39,7 @@ func PromptForConfigurationDetails(idpAccount *cfg.IDPAccount) error { idpAccount.Profile = prompter.String("AWS Profile", idpAccount.Profile) idpAccount.URL = prompter.String("URL", idpAccount.URL) + idpAccount.TencentCloudURL = prompter.String("TencentCloud URL (Optional)", idpAccount.TencentCloudURL) idpAccount.Username = prompter.String("Username", idpAccount.Username) switch idpAccount.Provider { @@ -87,26 +87,21 @@ func PromptForLoginDetails(loginDetails *creds.LoginDetails, provider string) er return nil } -// PromptForAWSRoleSelection present a list of roles to the user for selection -func PromptForAWSRoleSelection(accounts []*AWSAccount) (*AWSRole, error) { +// PromptForCloudRoleSelections present a list of roles to the user for selection +func PromptForCloudRoleSelection(roles []*CloudRole) (*CloudRole, error) { - roles := map[string]*AWSRole{} - var roleOptions []string - - for _, account := range accounts { - for _, role := range account.Roles { - name := fmt.Sprintf("%s / %s", account.Name, role.Name) - roles[name] = role - roleOptions = append(roleOptions, name) - } + roleMap := make(map[string]*CloudRole) + roleOptions := make([]string, len(roles)) + for i, role := range roles { + name := fmt.Sprintf("%s %s / %s", role.Provider, role.Account, role.Name) + roleOptions[i] = name + roleMap[name] = role } - sort.Strings(roleOptions) - selectedRole, err := prompter.ChooseWithDefault("Please choose the role", roleOptions[0], roleOptions) if err != nil { return nil, errors.Wrap(err, "Role selection failed") } - return roles[selectedRole], nil + return roleMap[selectedRole], nil } diff --git a/pkg/cfg/cfg.go b/pkg/cfg/cfg.go index 01644303c..28694c422 100644 --- a/pkg/cfg/cfg.go +++ b/pkg/cfg/cfg.go @@ -37,6 +37,7 @@ type IDPAccount struct { Name string `ini:"name"` AppID string `ini:"app_id"` // used by OneLogin and AzureAD URL string `ini:"url"` + TencentCloudURL string `ini:"tencentcloud_url"` Username string `ini:"username"` Provider string `ini:"provider"` BrowserType string `ini:"browser_type,omitempty"` // used by 'Browser' Provider @@ -89,6 +90,7 @@ func (ia IDPAccount) String() string { return fmt.Sprintf(`account {%s%s%s URL: %s + TencentCloudURL: %s Username: %s Provider: %s MFA: %s @@ -98,7 +100,7 @@ func (ia IDPAccount) String() string { Profile: %s RoleARN: %s Region: %s -}`, appID, policyID, oktaCfg, ia.URL, ia.Username, ia.Provider, ia.MFA, ia.SkipVerify, ia.AmazonWebservicesURN, ia.SessionDuration, ia.Profile, ia.RoleARN, ia.Region) +}`, appID, policyID, oktaCfg, ia.URL, ia.TencentCloudURL, ia.Username, ia.Provider, ia.MFA, ia.SkipVerify, ia.AmazonWebservicesURN, ia.SessionDuration, ia.Profile, ia.RoleARN, ia.Region) } // Validate validate the required / expected fields are set diff --git a/pkg/cloud/types.go b/pkg/cloud/types.go new file mode 100644 index 000000000..d19b79d3d --- /dev/null +++ b/pkg/cloud/types.go @@ -0,0 +1,12 @@ +package cloud + +type Provider int + +const ( + AWS Provider = iota + TencentCloud +) + +func (p Provider) String() string { + return [...]string{"AWS", "TencentCloud"}[p] +} diff --git a/pkg/creds/creds.go b/pkg/creds/creds.go index 5aa697103..4936487fc 100644 --- a/pkg/creds/creds.go +++ b/pkg/creds/creds.go @@ -11,6 +11,7 @@ type LoginDetails struct { MFAToken string DuoMFAOption string URL string + TencentCloudURL string StateToken string // used by Okta OktaSessionCookie string // used by Okta } diff --git a/pkg/flags/flags.go b/pkg/flags/flags.go index 0e7607ef7..aa7365b5f 100644 --- a/pkg/flags/flags.go +++ b/pkg/flags/flags.go @@ -19,6 +19,7 @@ type CommonFlags struct { MFAIPAddress string MFAToken string URL string + TencentCloudURL string Username string Password string RoleArn string @@ -64,6 +65,10 @@ func ApplyFlagOverrides(commonFlags *CommonFlags, account *cfg.IDPAccount) { account.URL = commonFlags.URL } + if commonFlags.TencentCloudURL != "" { + account.TencentCloudURL = commonFlags.TencentCloudURL + } + if commonFlags.Username != "" { account.Username = commonFlags.Username } diff --git a/pkg/provider/keycloak/keycloak.go b/pkg/provider/keycloak/keycloak.go index c0395c3a6..0e694463a 100644 --- a/pkg/provider/keycloak/keycloak.go +++ b/pkg/provider/keycloak/keycloak.go @@ -3,6 +3,7 @@ package keycloak import ( "bytes" "encoding/base64" + "encoding/json" "fmt" "io" "log" @@ -16,6 +17,7 @@ import ( "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/versent/saml2aws/v2/pkg/cfg" + "github.com/versent/saml2aws/v2/pkg/cloud" "github.com/versent/saml2aws/v2/pkg/creds" "github.com/versent/saml2aws/v2/pkg/prompter" "github.com/versent/saml2aws/v2/pkg/provider" @@ -57,57 +59,93 @@ func (kc *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) return kc.doAuthenticate(&authContext{loginDetails.MFAToken, 0, true}, loginDetails) } +// NOTE(devsisters): Retrieve "multiple" SAML responses from all configured providers func (kc *Client) doAuthenticate(authCtx *authContext, loginDetails *creds.LoginDetails) (string, error) { - authSubmitURL, authForm, err := kc.getLoginForm(loginDetails) + awsAuthSubmitURL, awsAuthSubmitForm, err := kc.getLoginForm(loginDetails) if err != nil { - return "", errors.Wrap(err, "error retrieving login form from idp") + return "", errors.Wrap(err, "error retrieving aws login form from idp") } + if awsAuthSubmitURL == "" { + return "", errors.Wrap(err, "error retrieving aws login url from idp") + } + + // log.Println(awsAuthSubmitURL) - data, err := kc.postLoginForm(authSubmitURL, authForm) + awsdata, err := kc.postLoginForm(awsAuthSubmitURL, awsAuthSubmitForm) if err != nil { - return "", fmt.Errorf("error submitting login form") - } - if authSubmitURL == "" { - return "", fmt.Errorf("error submitting login form") + return "", errors.Wrap(err, "error submitting aws login form") } - doc, err := goquery.NewDocumentFromReader(bytes.NewBuffer(data)) + awsdoc, err := goquery.NewDocumentFromReader(bytes.NewBuffer(awsdata)) if err != nil { return "", errors.Wrap(err, "error parsing document") } - if containsTotpForm(doc) { - totpSubmitURL, err := extractSubmitURL(doc) + if containsTotpForm(awsdoc) { + totpSubmitURL, err := extractSubmitURL(awsdoc) if err != nil { return "", errors.Wrap(err, "unable to locate IDP totp form submit URL") } - doc, err = kc.postTotpForm(authCtx, totpSubmitURL, doc) + awsdoc, err = kc.postTotpForm(authCtx, totpSubmitURL, awsdoc) if err != nil { return "", errors.Wrap(err, "error posting totp form") } - } else if containsWebauthnForm(doc) { - credentialIDs, challenge, rpId, err := extractWebauthnParameters(doc) + } else if containsWebauthnForm(awsdoc) { + credentialIDs, challenge, rpId, err := extractWebauthnParameters(awsdoc) if err != nil { return "", errors.Wrap(err, "could not extract Webauthn parameters") } - webauthnSubmitURL, err := extractSubmitURL(doc) + webauthnSubmitURL, err := extractSubmitURL(awsdoc) if err != nil { return "", errors.Wrap(err, "unable to locate IDP Webauthn form submit URL") } - doc, err = kc.postWebauthnForm(webauthnSubmitURL, credentialIDs, challenge, rpId) + awsdoc, err = kc.postWebauthnForm(webauthnSubmitURL, credentialIDs, challenge, rpId) if err != nil { return "", errors.Wrap(err, "error posting Webauthn form") } } - samlResponse, err := extractSamlResponse(doc) - if err != nil && authCtx.authenticatorIndexValid && passwordValid(doc) { + awsSamlResponse, err := extractSamlResponse(awsdoc) + if err != nil && authCtx.authenticatorIndexValid && passwordValid(awsdoc) { return kc.doAuthenticate(authCtx, loginDetails) } - return samlResponse, err + // log.Println("SAML response successfully retrieved for AWS") + // log.Println(awsSamlResponse) + + // If configured, retrieve TencentCloud SAML Response with the same authCtx + if loginDetails.TencentCloudURL != "" { + tcdoc, err := kc.getAdditionalLoginForm(loginDetails.TencentCloudURL) + if err != nil { + return "", errors.Wrap(err, "error retrieving tencentcloud login form from idp") + } + + tcSamlResponse, err := extractSamlResponse(tcdoc) + if err != nil && authCtx.authenticatorIndexValid && passwordValid(tcdoc) { + return kc.doAuthenticate(authCtx, loginDetails) + } + // log.Println("SAML response successfully retrieved for TencentCloud") + // log.Println(tcSamlResponse) + + if awsSamlResponse == "" && tcSamlResponse == "" { + return "", errors.Wrap(err, "no SAML response retrieved from keycloak") + } + + // Return both AWS and TencentCloud SAML responses + samlResponses := make(map[cloud.Provider]string) + samlResponses[cloud.AWS] = awsSamlResponse + samlResponses[cloud.TencentCloud] = tcSamlResponse + jsonSamlResponses, err := json.Marshal(samlResponses) + if err != nil { + return "", errors.Wrap(err, "error marshalling SAML responses from keycloak") + } + return string(jsonSamlResponses), err + } + + // Return normally + return awsSamlResponse, err } func extractWebauthnParameters(doc *goquery.Document) (credentialIDs []string, challenge string, rpID string, err error) { @@ -182,6 +220,21 @@ func (kc *Client) getLoginForm(loginDetails *creds.LoginDetails) (string, url.Va return authSubmitURL, authForm, nil } +func (kc *Client) getAdditionalLoginForm(idpUrl string) (*goquery.Document, error) { + + res, err := kc.client.Get(idpUrl) + if err != nil { + return nil, errors.Wrap(err, "error retrieving second form") + } + + doc, err := goquery.NewDocumentFromReader(res.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to build document from response") + } + + return doc, nil +} + func (kc *Client) postLoginForm(authSubmitURL string, authForm url.Values) ([]byte, error) { req, err := http.NewRequest("POST", authSubmitURL, strings.NewReader(authForm.Encode())) diff --git a/pkg/tcconfig/tcconfig.go b/pkg/tcconfig/tcconfig.go new file mode 100644 index 000000000..db1c6458c --- /dev/null +++ b/pkg/tcconfig/tcconfig.go @@ -0,0 +1,118 @@ +package tcconfig + +import ( + "encoding/json" + "os" + "path/filepath" + "runtime" + + "github.com/mitchellh/go-homedir" + "github.com/pkg/errors" +) + +var ( + ErrCredentialsHomeNotFound = errors.New("user home directory not found") + ErrCredentialsNotFound = errors.New("tc credentials not found") +) + +type TCCredentials struct { + SecretID string `json:"secretId,omitempty"` + SecretKey string `json:"secretKey,omitempty"` + Token string `json:"token,omitempty"` + Region string `json:"region,omitempty"` + Expires string `json:"x_security_token_expires,omitempty"` + PrincipalARN string `json:"-"` +} + +type CredentialsProvider struct { + Filename string + Profile string +} + +func NewSharedCredentials(profile string, filename string) *CredentialsProvider { + return &CredentialsProvider{ + Filename: filename, + Profile: profile, + } +} + +func (p *CredentialsProvider) Save(creds *TCCredentials) error { + filename, err := p.resolveFilename() + if err != nil { + return err + } + + if _, err = os.Stat(filename); err != nil { + if os.IsNotExist(err) { + dir := filepath.Dir(filename) + if err = os.MkdirAll(dir, os.ModePerm); err != nil { + return err + } + } + } + + bytes, err := json.MarshalIndent(creds, "", " ") + if err != nil { + return errors.Wrap(err, "unable to marshal credentials") + } + bytes = append(bytes, '\n') + + file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return errors.Wrap(err, "unable to load file") + } + if _, err := file.Write(bytes); err != nil { + return errors.Wrap(err, "unable to write to file") + } + defer file.Close() + + return nil +} + +func (p *CredentialsProvider) resolveFilename() (string, error) { + if p.Filename == "" { + filename, err := p.locateConfigFile() + if err != nil { + return "", err + } + p.Filename = filename + } + + return p.Filename, nil +} + +func (p *CredentialsProvider) locateConfigFile() (string, error) { + filename := os.Getenv("TENCENTCLOUD_CREDENTIALS_FILE") + if filename != "" { + return filename, nil + } + + // Default location for credentials file is ~/.tccli/{profile}.credentials + var name string + var err error + if runtime.GOOS == "windows" { + panic("error locating credentials file on windows: not implemented") + } else { + if name, err = homedir.Expand("~/.tccli/" + p.Profile + ".credential"); err != nil { + return "", ErrCredentialsHomeNotFound + } + // log.Println("config file:", name) + } + + if name, err = resolveSymlink(name); err != nil { + return "", errors.Wrap(err, "unable to resolve symlink") + } + + return name, nil +} + +func resolveSymlink(filename string) (string, error) { + sympath, err := filepath.EvalSymlinks(filename) + if os.IsNotExist(err) { + return filename, nil + } + if err != nil { + return "", err + } + return sympath, nil +} diff --git a/saml.go b/saml.go index 282aedd86..58f050392 100644 --- a/saml.go +++ b/saml.go @@ -52,7 +52,7 @@ func ExtractSessionDuration(data []byte) (int64, error) { // log.Printf("tag: %s", assertionElement.Tag) - //Get the actual assertion attributes + // Get the actual assertion attributes attributeStatement := assertionElement.FindElement(childPath(assertionElement.Space, attributeStatementTag)) if attributeStatement == nil { return 0, ErrMissingElement{Tag: attributeStatementTag} @@ -133,14 +133,13 @@ func ExtractMFATokenExpiryTime(data []byte) (time.Time, error) { return time.Parse(time.RFC3339, ValidUntilString) } -// ExtractAwsRoles given an assertion document extract the aws roles -func ExtractAwsRoles(data []byte) ([]string, error) { - - awsroles := []string{} +// ExtractCloudRoles given an assertion document extract the aws roles +func ExtractCloudRoles(data []byte) ([]string, error) { + cloudroles := []string{} doc := etree.NewDocument() if err := doc.ReadFromBytes(data); err != nil { - return awsroles, err + return cloudroles, err } // log.Printf("root tag: %s", doc.Root().Tag) @@ -152,7 +151,7 @@ func ExtractAwsRoles(data []byte) ([]string, error) { // log.Printf("tag: %s", assertionElement.Tag) - //Get the actual assertion attributes + // Get the actual assertion attributes attributeStatement := assertionElement.FindElement(childPath(assertionElement.Space, attributeStatementTag)) if attributeStatement == nil { return nil, ErrMissingElement{Tag: attributeStatementTag} @@ -162,22 +161,29 @@ func ExtractAwsRoles(data []byte) ([]string, error) { attributes := attributeStatement.FindElements(childPath(assertionElement.Space, attributeTag)) for _, attribute := range attributes { - if attribute.SelectAttrValue("Name", "") != "https://aws.amazon.com/SAML/Attributes/Role" { + if attribute.SelectAttrValue("Name", "") != "https://aws.amazon.com/SAML/Attributes/Role" && attribute.SelectAttrValue("Name", "") != "https://cloud.tencent.com/SAML/Attributes/Role" { continue } + + // if attribute.SelectAttrValue("Name", "") == "https://aws.amazon.com/SAML/Attributes/Role" { + // log.Printf("found aws assertion") + // } else if attribute.SelectAttrValue("Name", "") == "https://cloud.tencent.com/SAML/Attributes/Role" { + // log.Printf("found tencent assertion") + // } + atributeValues := attribute.FindElements(childPath(assertionElement.Space, attributeValueTag)) for _, attrValue := range atributeValues { - awsroles = append(awsroles, attrValue.Text()) + cloudroles = append(cloudroles, attrValue.Text()) } } - return awsroles, nil + return cloudroles, nil } func childPath(space, tag string) string { if space == "" { return "./" + tag } - //log.Printf("query = %s", "./"+space+":"+tag) + // log.Printf("query = %s", "./"+space+":"+tag) return "./" + space + ":" + tag } diff --git a/saml_test.go b/saml_test.go index 38338dcfb..1ed45e75a 100644 --- a/saml_test.go +++ b/saml_test.go @@ -12,7 +12,7 @@ func TestExtractAwsRoles(t *testing.T) { data, err := os.ReadFile("testdata/assertion.xml") assert.Nil(t, err) - roles, err := ExtractAwsRoles(data) + roles, err := ExtractCloudRoles(data) assert.Nil(t, err) assert.Len(t, roles, 2) } @@ -21,7 +21,7 @@ func TestExtractAwsRolesFail(t *testing.T) { data, err := os.ReadFile("testdata/notxml.xml") assert.Nil(t, err) - _, err = ExtractAwsRoles(data) + _, err = ExtractCloudRoles(data) assert.Error(t, err) }