Skip to content

Commit

Permalink
Allow SOPS to use custom AWS KMS and STS Endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
anandavj committed Nov 19, 2024
1 parent db5356d commit ef6ef4c
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 30 deletions.
38 changes: 35 additions & 3 deletions cmd/sops/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,14 @@ func main() {
Name: "aws-profile",
Usage: "The AWS profile to use for requests to AWS",
},
cli.StringFlag{
Name: "aws-kms-endpoint",
Usage: "The AWS KMS Endpoint to use for requests to AWS. Ex: https://kms.ap-southeast-2.amazonaws.com",
},
cli.StringFlag{
Name: "aws-sts-endpoint",
Usage: "The AWS STS Endpoint to use for requests to AWS. Ex: https://sts.ap-southeast-2.amazonaws.com",
},
cli.StringSliceFlag{
Name: "gcp-kms",
Usage: "the GCP KMS Resource ID the new group should contain. Can be specified more than once",
Expand Down Expand Up @@ -545,7 +553,7 @@ func main() {
group = append(group, pgp.NewMasterKeyFromFingerprint(fp))
}
for _, arn := range kmsArns {
group = append(group, kms.NewMasterKeyFromArn(arn, kms.ParseKMSContext(c.String("encryption-context")), c.String("aws-profile")))
group = append(group, kms.NewMasterKeyFromArn(arn, kms.ParseKMSContext(c.String("encryption-context")), c.String("aws-profile"), c.String("aws-kms-endpoint"), c.String("aws-sts-endpoint")))
}
for _, kms := range gcpKmses {
group = append(group, gcpkms.NewMasterKeyFromResourceID(kms))
Expand Down Expand Up @@ -852,6 +860,14 @@ func main() {
Name: "aws-profile",
Usage: "The AWS profile to use for requests to AWS",
},
cli.StringFlag{
Name: "aws-kms-endpoint",
Usage: "The AWS KMS Endpoint to use for requests to AWS",
},
cli.StringFlag{
Name: "aws-sts-endpoint",
Usage: "The AWS STS Endpoint to use for requests to AWS",
},
cli.StringFlag{
Name: "gcp-kms",
Usage: "comma separated list of GCP KMS resource IDs",
Expand Down Expand Up @@ -1169,6 +1185,14 @@ func main() {
Name: "aws-profile",
Usage: "The AWS profile to use for requests to AWS",
},
cli.StringFlag{
Name: "aws-kms-endpoint",
Usage: "The AWS KMS Endpoint to use for requests to AWS",
},
cli.StringFlag{
Name: "aws-sts-endpoint",
Usage: "The AWS STS Endpoint to use for requests to AWS",
},
cli.StringFlag{
Name: "gcp-kms",
Usage: "comma separated list of GCP KMS resource IDs",
Expand Down Expand Up @@ -1529,6 +1553,14 @@ func main() {
Name: "aws-profile",
Usage: "The AWS profile to use for requests to AWS",
},
cli.StringFlag{
Name: "aws-kms-endpoint",
Usage: "The AWS KMS Endpoint to use for requests to AWS",
},
cli.StringFlag{
Name: "aws-sts-endpoint",
Usage: "The AWS STS Endpoint to use for requests to AWS",
},
cli.StringFlag{
Name: "gcp-kms",
Usage: "comma separated list of GCP KMS resource IDs",
Expand Down Expand Up @@ -2006,7 +2038,7 @@ func getEncryptConfig(c *cli.Context, fileName string) (encryptConfig, error) {

func getMasterKeys(c *cli.Context, kmsEncryptionContext map[string]*string, kmsOptionName string, pgpOptionName string, gcpKmsOptionName string, azureKvOptionName string, hcVaultTransitOptionName string, ageOptionName string) ([]keys.MasterKey, error) {
var masterKeys []keys.MasterKey
for _, k := range kms.MasterKeysFromArnString(c.String(kmsOptionName), kmsEncryptionContext, c.String("aws-profile")) {
for _, k := range kms.MasterKeysFromArnString(c.String(kmsOptionName), kmsEncryptionContext, c.String("aws-profile"), c.String("aws-kms-endpoint"), c.String("aws-sts-endpoint")) {
masterKeys = append(masterKeys, k)
}
for _, k := range pgp.MasterKeysFromFingerprintString(c.String(pgpOptionName)) {
Expand Down Expand Up @@ -2185,7 +2217,7 @@ func keyGroups(c *cli.Context, file string) ([]sops.KeyGroup, error) {
return nil, common.NewExitError("Invalid KMS encryption context format", codes.ErrorInvalidKMSEncryptionContextFormat)
}
if c.String("kms") != "" {
for _, k := range kms.MasterKeysFromArnString(c.String("kms"), kmsEncryptionContext, c.String("aws-profile")) {
for _, k := range kms.MasterKeysFromArnString(c.String("kms"), kmsEncryptionContext, c.String("aws-profile"), c.String("aws-kms-endpoint"), c.String("aws-sts-endpoint")) {
kmsKeys = append(kmsKeys, k)
}
}
Expand Down
8 changes: 6 additions & 2 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ type kmsKey struct {
Role string `yaml:"role,omitempty"`
Context map[string]*string `yaml:"context"`
AwsProfile string `yaml:"aws_profile"`
AwsKmsEndpoint string `yaml:"aws_kms_endpoint"`
AwsStsEndpoint string `yaml:"aws_sts_endpoint"`
}

type azureKVKey struct {
Expand All @@ -138,6 +140,8 @@ type creationRule struct {
PathRegex string `yaml:"path_regex"`
KMS string
AwsProfile string `yaml:"aws_profile"`
AwsKmsEndpoint string `yaml:"aws_kms_endpoint"`
AwsStsEndpoint string `yaml:"aws_sts_endpoint"`
Age string `yaml:"age"`
PGP string
GCPKMS string `yaml:"gcp_kms"`
Expand Down Expand Up @@ -226,7 +230,7 @@ func extractMasterKeys(group keyGroup) (sops.KeyGroup, error) {
keyGroup = append(keyGroup, pgp.NewMasterKeyFromFingerprint(k))
}
for _, k := range group.KMS {
keyGroup = append(keyGroup, kms.NewMasterKeyWithProfile(k.Arn, k.Role, k.Context, k.AwsProfile))
keyGroup = append(keyGroup, kms.NewMasterKeyWithProfile(k.Arn, k.Role, k.Context, k.AwsProfile, k.AwsKmsEndpoint, k.AwsStsEndpoint))
}
for _, k := range group.GCPKMS {
keyGroup = append(keyGroup, gcpkms.NewMasterKeyFromResourceID(k.ResourceID))
Expand Down Expand Up @@ -269,7 +273,7 @@ func getKeyGroupsFromCreationRule(cRule *creationRule, kmsEncryptionContext map[
for _, k := range pgp.MasterKeysFromFingerprintString(cRule.PGP) {
keyGroup = append(keyGroup, k)
}
for _, k := range kms.MasterKeysFromArnString(cRule.KMS, kmsEncryptionContext, cRule.AwsProfile) {
for _, k := range kms.MasterKeysFromArnString(cRule.KMS, kmsEncryptionContext, cRule.AwsProfile, cRule.AwsKmsEndpoint, cRule.AwsStsEndpoint) {
keyGroup = append(keyGroup, k)
}
for _, k := range gcpkms.MasterKeysFromResourceIDString(cRule.GCPKMS) {
Expand Down
2 changes: 2 additions & 0 deletions keyservice/keyservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ func KeyFromMasterKey(mk keys.MasterKey) Key {
Role: mk.Role,
Context: ctx,
AwsProfile: mk.AwsProfile,
AwsKmsEndpoint: mk.AwsKmsEndpoint,
AwsStsEndpoint: mk.AwsStsEndpoint,
},
},
}
Expand Down
2 changes: 2 additions & 0 deletions keyservice/keyservice.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions keyservice/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,5 +323,7 @@ func kmsKeyToMasterKey(key *KmsKey) kms.MasterKey {
Role: key.Role,
EncryptionContext: ctx,
AwsProfile: key.AwsProfile,
AwsKmsEndpoint: key.AwsKmsEndpoint,
AwsStsEndpoint: key.AwsStsEndpoint,
}
}
37 changes: 23 additions & 14 deletions kms/keysource.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,35 +73,38 @@ type MasterKey struct {
// using CredentialsProvider.ApplyToMasterKey. If nil, the default client is used
// which utilizes runtime environmental values.
credentialsProvider aws.CredentialsProvider
// baseEndpoint can be used to override the endpoint the AWS client resolves
// to by default. This is mostly used for testing purposes as it can not be
// injected using e.g. an environment variable. The field is not publicly
// exposed, nor configurable.
baseEndpoint string
// AwsKmsEndpoint can be used to override the endpoint the AWS client resolves
// to by default. This is mostly used for custom AWS that has custom endpoint.
AwsKmsEndpoint string
// AwsStsEndpoint can be used to override the endpoint the AWS client resolves
// to by default. This is mostly used for custom AWS that has custom endpoint.
AwsStsEndpoint string
}

// NewMasterKey creates a new MasterKey from an ARN, role and context, setting
// the creation date to the current date.
func NewMasterKey(arn string, role string, context map[string]*string) *MasterKey {
func NewMasterKey(arn string, role string, context map[string]*string, awsKmsEndpoint string, awsStsEndpoint string) *MasterKey {
return &MasterKey{
Arn: arn,
Role: role,
EncryptionContext: context,
CreationDate: time.Now().UTC(),
AwsKmsEndpoint: awsKmsEndpoint,
AwsStsEndpoint: awsStsEndpoint,
}
}

// NewMasterKeyWithProfile creates a new MasterKey from an ARN, role, context
// and awsProfile, setting the creation date to the current date.
func NewMasterKeyWithProfile(arn string, role string, context map[string]*string, awsProfile string) *MasterKey {
k := NewMasterKey(arn, role, context)
func NewMasterKeyWithProfile(arn string, role string, context map[string]*string, awsProfile string, awsKmsEndpoint string, awsStsEndpoint string) *MasterKey {
k := NewMasterKey(arn, role, context, awsKmsEndpoint, awsStsEndpoint)
k.AwsProfile = awsProfile
return k
}

// NewMasterKeyFromArn takes an ARN string and returns a new MasterKey for that
// ARN.
func NewMasterKeyFromArn(arn string, context map[string]*string, awsProfile string) *MasterKey {
func NewMasterKeyFromArn(arn string, context map[string]*string, awsProfile string, awsKmsEndpoint string, awsStsEndpoint string) *MasterKey {
key := &MasterKey{}
arn = strings.Replace(arn, " ", "", -1)
key.Arn = arn
Expand All @@ -114,18 +117,20 @@ func NewMasterKeyFromArn(arn string, context map[string]*string, awsProfile stri
key.EncryptionContext = context
key.CreationDate = time.Now().UTC()
key.AwsProfile = awsProfile
key.AwsKmsEndpoint = awsKmsEndpoint
key.AwsStsEndpoint = awsStsEndpoint
return key
}

// MasterKeysFromArnString takes a comma separated list of AWS KMS ARNs, and
// returns a slice of new MasterKeys for those ARNs.
func MasterKeysFromArnString(arn string, context map[string]*string, awsProfile string) []*MasterKey {
func MasterKeysFromArnString(arn string, context map[string]*string, awsProfile string, awsKmsEndpoint string, awsStsEndpoint string) []*MasterKey {
var keys []*MasterKey
if arn == "" {
return keys
}
for _, s := range strings.Split(arn, ",") {
keys = append(keys, NewMasterKeyFromArn(s, context, awsProfile))
keys = append(keys, NewMasterKeyFromArn(s, context, awsProfile, awsKmsEndpoint, awsStsEndpoint))
}
return keys
}
Expand Down Expand Up @@ -340,8 +345,8 @@ func (key MasterKey) createKMSConfig() (*aws.Config, error) {
// createClient creates a new AWS KMS client with the provided config.
func (key MasterKey) createClient(config *aws.Config) *kms.Client {
return kms.NewFromConfig(*config, func(o *kms.Options) {
if key.baseEndpoint != "" {
o.BaseEndpoint = aws.String(key.baseEndpoint)
if key.AwsKmsEndpoint != "" {
o.BaseEndpoint = aws.String(key.AwsKmsEndpoint)
}
})
}
Expand All @@ -359,7 +364,11 @@ func (key MasterKey) createSTSConfig(config *aws.Config) (*aws.Config, error) {
RoleSessionName: &name,
}

client := sts.NewFromConfig(*config)
client := sts.NewFromConfig(*config, func(o *sts.Options) {
if key.AwsStsEndpoint != "" {
o.BaseEndpoint = aws.String(key.AwsStsEndpoint)
}
})
out, err := client.AssumeRole(context.TODO(), input)
if err != nil {
return nil, fmt.Errorf("failed to assume role '%s': %w", key.Role, err)
Expand Down
22 changes: 11 additions & 11 deletions kms/keysource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func TestNewMasterKey(t *testing.T) {
"foo": aws.String("bar"),
}
)
key := NewMasterKey(dummyARN, dummyRole, dummyEncryptionContext)
key := NewMasterKey(dummyARN, dummyRole, dummyEncryptionContext, "", "")
assert.Equal(t, dummyARN, key.Arn)
assert.Equal(t, dummyRole, key.Role)
assert.Equal(t, dummyEncryptionContext, key.EncryptionContext)
Expand All @@ -131,7 +131,7 @@ func TestNewMasterKeyWithProfile(t *testing.T) {
}
dummyProfile = "a-profile"
)
key := NewMasterKeyWithProfile(dummyARN, dummyRole, dummyEncryptionContext, dummyProfile)
key := NewMasterKeyWithProfile(dummyARN, dummyRole, dummyEncryptionContext, dummyProfile, "", "")
assert.Equal(t, dummyARN, key.Arn)
assert.Equal(t, dummyRole, key.Role)
assert.Equal(t, dummyEncryptionContext, key.EncryptionContext)
Expand All @@ -147,7 +147,7 @@ func TestNewMasterKeyFromArn(t *testing.T) {
}
dummyProfile = "a-profile"
)
key := NewMasterKeyFromArn(dummyARN, dummyEncryptionContext, dummyProfile)
key := NewMasterKeyFromArn(dummyARN, dummyEncryptionContext, dummyProfile, "", "")
assert.Equal(t, dummyARN, key.Arn)
assert.Equal(t, dummyEncryptionContext, key.EncryptionContext)
assert.Equal(t, dummyProfile, key.AwsProfile)
Expand All @@ -156,20 +156,20 @@ func TestNewMasterKeyFromArn(t *testing.T) {
})

t.Run("arn with spaces", func(t *testing.T) {
key := NewMasterKeyFromArn(" arn:aws:kms:us-west-2 :107501996527:key/612d5f 0p-p1l3-45e6-aca6-a5b00569 3a48 ", nil, "")
key := NewMasterKeyFromArn(" arn:aws:kms:us-west-2 :107501996527:key/612d5f 0p-p1l3-45e6-aca6-a5b00569 3a48 ", nil, "", "", "")
assert.Equal(t, "arn:aws:kms:us-west-2:107501996527:key/612d5f0p-p1l3-45e6-aca6-a5b005693a48", key.Arn)
})

t.Run("arn with role", func(t *testing.T) {
key := NewMasterKeyFromArn("arn:aws:kms:us-west-2:927034868273:key/fe86dd69-4132-404c-ab86-4269956b4500+arn:aws:iam::927034868273:role/sops-dev-xyz", nil, "")
key := NewMasterKeyFromArn("arn:aws:kms:us-west-2:927034868273:key/fe86dd69-4132-404c-ab86-4269956b4500+arn:aws:iam::927034868273:role/sops-dev-xyz", nil, "", "", "")
assert.Equal(t, "arn:aws:kms:us-west-2:927034868273:key/fe86dd69-4132-404c-ab86-4269956b4500", key.Arn)
assert.Equal(t, "arn:aws:iam::927034868273:role/sops-dev-xyz", key.Role)
})
}

func TestMasterKeysFromArnString(t *testing.T) {
s := "arn:aws:kms:us-east-1:656532927350:key/920aff2e-c5f1-4040-943a-047fa387b27e+arn:aws:iam::927034868273:role/sops-dev, arn:aws:kms:ap-southeast-1:656532927350:key/9006a8aa-0fa6-4c14-930e-a2dfb916de1d"
ks := MasterKeysFromArnString(s, nil, "foo")
ks := MasterKeysFromArnString(s, nil, "foo", "", "")
k1 := ks[0]
k2 := ks[1]

Expand Down Expand Up @@ -359,15 +359,15 @@ func TestMasterKey_EncryptDecrypt_RoundTrip(t *testing.T) {
}

func TestMasterKey_NeedsRotation(t *testing.T) {
key := NewMasterKeyFromArn(dummyARN, nil, "")
key := NewMasterKeyFromArn(dummyARN, nil, "", "", "")
assert.False(t, key.NeedsRotation())

key.CreationDate = key.CreationDate.Add(-(kmsTTL + time.Second))
assert.True(t, key.NeedsRotation())
}

func TestMasterKey_ToString(t *testing.T) {
key := NewMasterKeyFromArn(dummyARN, nil, "")
key := NewMasterKeyFromArn(dummyARN, nil, "", "", "")
assert.Equal(t, dummyARN, key.ToString())
}

Expand Down Expand Up @@ -518,15 +518,15 @@ func TestMasterKey_createSTSConfig(t *testing.T) {
err = fmt.Errorf("an error")
return
}
key := NewMasterKeyFromArn(dummyARN, nil, "")
key := NewMasterKeyFromArn(dummyARN, nil, "", "", "")
cfg, err := key.createSTSConfig(nil)
assert.Error(t, err)
assert.ErrorContains(t, err, "failed to construct STS session name")
assert.Nil(t, cfg)
})

t.Run("role assumption error", func(t *testing.T) {
key := NewMasterKeyFromArn(dummyARN, nil, "")
key := NewMasterKeyFromArn(dummyARN, nil, "", "", "")
key.Role = "role"
got, err := key.createSTSConfig(&aws.Config{})
assert.Error(t, err)
Expand Down Expand Up @@ -592,7 +592,7 @@ func createTestMasterKey(arn string) MasterKey {
return MasterKey{
Arn: arn,
credentialsProvider: credentials.NewStaticCredentialsProvider("id", "secret", ""),
baseEndpoint: testKMSServerURL,
AwsKmsEndpoint: testKMSServerURL,
}
}

Expand Down
6 changes: 6 additions & 0 deletions stores/stores.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ type kmskey struct {
CreatedAt string `yaml:"created_at" json:"created_at"`
EncryptedDataKey string `yaml:"enc" json:"enc"`
AwsProfile string `yaml:"aws_profile" json:"aws_profile"`
AwsKmsEndpoint string `yaml:"aws_kms_endpoint" json:"aws_kms_endpoint"`
AwsStsEndpoint string `yaml:"aws_sts_endpoint" json:"aws_sts_endpoint"`
}

type gcpkmskey struct {
Expand Down Expand Up @@ -175,6 +177,8 @@ func kmsKeysFromGroup(group sops.KeyGroup) (keys []kmskey) {
Context: key.EncryptionContext,
Role: key.Role,
AwsProfile: key.AwsProfile,
AwsKmsEndpoint: key.AwsKmsEndpoint,
AwsStsEndpoint: key.AwsStsEndpoint,
})
}
}
Expand Down Expand Up @@ -376,6 +380,8 @@ func (kmsKey *kmskey) toInternal() (*kms.MasterKey, error) {
CreationDate: creationDate,
Arn: kmsKey.Arn,
AwsProfile: kmsKey.AwsProfile,
AwsKmsEndpoint: kmsKey.AwsKmsEndpoint,
AwsStsEndpoint: kmsKey.AwsStsEndpoint,
}, nil
}

Expand Down

0 comments on commit ef6ef4c

Please sign in to comment.