diff --git a/kms/awskms/signer.go b/kms/awskms/signer.go index e31535a3..feb26c01 100644 --- a/kms/awskms/signer.go +++ b/kms/awskms/signer.go @@ -15,6 +15,48 @@ import ( "go.step.sm/crypto/pemutil" ) +// AWSOptions implements the crypto.SignerOpts interface, it provides a Raw +// boolean field to indicate to the AWS KMS operation that the MessageType is +// RAW. +// +// Example: +// +// // Sign a raw message with KMS +// client := kms.NewFromConfig(cfg) +// kmsSigner, err := awskms.NewSigner(client, "my-key-id") +// if err != nil { +// // handle error ... +// } +// raw := []byte("my raw message") +// sig, err := kmsSigner.Sign(rand.Reader, raw, &awskms.AWSOptions{ +// Raw: true, +// Options: crypto.SHA256, +// }) +// if err != nil { +// // handle error ... +// } +type AWSOptions struct { + // Raw specifies to the AWS KMS operation that MessageType is RAW. + Raw bool + Options crypto.SignerOpts +} + +// HashFunc implements crypto.SignerOpts. +func (a *AWSOptions) HashFunc() crypto.Hash { + // The GoLang [crypto.SignerOpt] interfaces states that if the [HashFunc] + // returns 0, then it indicates to the [Sign] function that no hashing + // has occurred over the message. + // However, the AWS KMS Sign operation always requires that a + // SigningAlgorithm is specified. + // As such, the AWSOptions HashFunc() must return a valid (non-zero) Hash, + // such that the [getMessageTypeAndSigningAlgorithm] function can return a valid AWS KMS + // [types.SigningAlgorithmSpec] + return a.Options.HashFunc() +} + +// compile time check that AWSOptions implements crypto.SignerOpts +var _ crypto.SignerOpts = (*AWSOptions)(nil) + // Signer implements a crypto.Signer using the AWS KMS. type Signer struct { client KeyManagementClient @@ -63,7 +105,7 @@ func (s *Signer) Public() crypto.PublicKey { // Sign signs digest with the private key stored in the AWS KMS. func (s *Signer) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { - alg, err := getSigningAlgorithm(s.Public(), opts) + messageType, alg, err := getMessageTypeAndSigningAlgorithm(s.Public(), opts) if err != nil { return nil, err } @@ -72,7 +114,7 @@ func (s *Signer) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) ([]byt KeyId: pointer(s.keyID), SigningAlgorithm: alg, Message: digest, - MessageType: types.MessageTypeDigest, + MessageType: messageType, } ctx, cancel := defaultContext() @@ -86,41 +128,49 @@ func (s *Signer) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) ([]byt return resp.Signature, nil } -func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (types.SigningAlgorithmSpec, error) { +func getMessageTypeAndSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (types.MessageType, types.SigningAlgorithmSpec, error) { + messageType := types.MessageTypeDigest + if awsOpts, ok := opts.(*AWSOptions); ok { + if awsOpts.Raw { + messageType = types.MessageTypeRaw + } + opts = awsOpts.Options + } + switch key.(type) { case *rsa.PublicKey: _, isPSS := opts.(*rsa.PSSOptions) switch h := opts.HashFunc(); h { case crypto.SHA256: if isPSS { - return types.SigningAlgorithmSpecRsassaPssSha256, nil + return messageType, types.SigningAlgorithmSpecRsassaPssSha256, nil } - return types.SigningAlgorithmSpecRsassaPkcs1V15Sha256, nil + return messageType, types.SigningAlgorithmSpecRsassaPkcs1V15Sha256, nil case crypto.SHA384: if isPSS { - return types.SigningAlgorithmSpecRsassaPssSha384, nil + return messageType, types.SigningAlgorithmSpecRsassaPssSha384, nil } - return types.SigningAlgorithmSpecRsassaPkcs1V15Sha384, nil + return messageType, types.SigningAlgorithmSpecRsassaPkcs1V15Sha384, nil case crypto.SHA512: if isPSS { - return types.SigningAlgorithmSpecRsassaPssSha512, nil + return messageType, types.SigningAlgorithmSpecRsassaPssSha512, nil } - return types.SigningAlgorithmSpecRsassaPkcs1V15Sha512, nil + return messageType, types.SigningAlgorithmSpecRsassaPkcs1V15Sha512, nil default: - return "", errors.Errorf("unsupported hash function %v", h) + return messageType, "", errors.Errorf("unsupported hash function %v", h) } case *ecdsa.PublicKey: switch h := opts.HashFunc(); h { case crypto.SHA256: - return types.SigningAlgorithmSpecEcdsaSha256, nil + return messageType, types.SigningAlgorithmSpecEcdsaSha256, nil case crypto.SHA384: - return types.SigningAlgorithmSpecEcdsaSha384, nil + return messageType, types.SigningAlgorithmSpecEcdsaSha384, nil case crypto.SHA512: - return types.SigningAlgorithmSpecEcdsaSha512, nil + return messageType, types.SigningAlgorithmSpecEcdsaSha512, nil default: - return "", errors.Errorf("unsupported hash function %v", h) + return messageType, "", errors.Errorf("unsupported hash function %v", h) } default: - return "", errors.Errorf("unsupported key type %T", key) + return messageType, "", errors.Errorf("unsupported key type %T", key) } } diff --git a/kms/awskms/signer_test.go b/kms/awskms/signer_test.go index 6d7edea0..6b9c5199 100644 --- a/kms/awskms/signer_test.go +++ b/kms/awskms/signer_test.go @@ -125,7 +125,9 @@ func TestSigner_Sign(t *testing.T) { wantErr bool }{ {"ok", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.SHA256}, signature, false}, - {"fail alg", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.MD5}, nil, true}, + {"(raw) ok", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), &AWSOptions{Raw: true, Options: crypto.SHA256}}, signature, false}, + {"fail alg", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), &AWSOptions{Raw: true, Options: crypto.MD5}}, nil, true}, + {"(raw) fail alg", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.MD5}, nil, true}, {"fail key", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", []byte("key")}, args{rand.Reader, []byte("digest"), crypto.SHA256}, nil, true}, {"fail sign", fields{&MockClient{ sign: func(ctx context.Context, input *kms.SignInput, opts ...func(*kms.Options)) (*kms.SignOutput, error) { @@ -152,39 +154,52 @@ func TestSigner_Sign(t *testing.T) { } } -func Test_getSigningAlgorithm(t *testing.T) { +func Test_getMessageTypeAndSigningAlgorithm(t *testing.T) { type args struct { key crypto.PublicKey opts crypto.SignerOpts } tests := []struct { - name string - args args - want types.SigningAlgorithmSpec - wantErr bool + name string + args args + wantMessageType types.MessageType + wantAlgo types.SigningAlgorithmSpec + wantErr bool }{ - {"rsa+sha256", args{&rsa.PublicKey{}, crypto.SHA256}, "RSASSA_PKCS1_V1_5_SHA_256", false}, - {"rsa+sha384", args{&rsa.PublicKey{}, crypto.SHA384}, "RSASSA_PKCS1_V1_5_SHA_384", false}, - {"rsa+sha512", args{&rsa.PublicKey{}, crypto.SHA512}, "RSASSA_PKCS1_V1_5_SHA_512", false}, - {"pssrsa+sha256", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA256.HashFunc()}}, "RSASSA_PSS_SHA_256", false}, - {"pssrsa+sha384", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA384.HashFunc()}}, "RSASSA_PSS_SHA_384", false}, - {"pssrsa+sha512", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA512.HashFunc()}}, "RSASSA_PSS_SHA_512", false}, - {"P256", args{&ecdsa.PublicKey{}, crypto.SHA256}, "ECDSA_SHA_256", false}, - {"P384", args{&ecdsa.PublicKey{}, crypto.SHA384}, "ECDSA_SHA_384", false}, - {"P521", args{&ecdsa.PublicKey{}, crypto.SHA512}, "ECDSA_SHA_512", false}, - {"fail type", args{[]byte("key"), crypto.SHA256}, "", true}, - {"fail rsa alg", args{&rsa.PublicKey{}, crypto.MD5}, "", true}, - {"fail ecdsa alg", args{&ecdsa.PublicKey{}, crypto.MD5}, "", true}, + {"rsa+sha256", args{&rsa.PublicKey{}, crypto.SHA256}, types.MessageTypeDigest, "RSASSA_PKCS1_V1_5_SHA_256", false}, + {"rsa+sha384", args{&rsa.PublicKey{}, crypto.SHA384}, types.MessageTypeDigest, "RSASSA_PKCS1_V1_5_SHA_384", false}, + {"rsa+sha512", args{&rsa.PublicKey{}, crypto.SHA512}, types.MessageTypeDigest, "RSASSA_PKCS1_V1_5_SHA_512", false}, + {"pssrsa+sha256", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA256.HashFunc()}}, types.MessageTypeDigest, "RSASSA_PSS_SHA_256", false}, + {"pssrsa+sha384", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA384.HashFunc()}}, types.MessageTypeDigest, "RSASSA_PSS_SHA_384", false}, + {"pssrsa+sha512", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA512.HashFunc()}}, types.MessageTypeDigest, "RSASSA_PSS_SHA_512", false}, + {"P256", args{&ecdsa.PublicKey{}, crypto.SHA256}, types.MessageTypeDigest, "ECDSA_SHA_256", false}, + {"P384", args{&ecdsa.PublicKey{}, crypto.SHA384}, types.MessageTypeDigest, "ECDSA_SHA_384", false}, + {"P521", args{&ecdsa.PublicKey{}, crypto.SHA512}, types.MessageTypeDigest, "ECDSA_SHA_512", false}, + {"(raw)rsa+sha256", args{&rsa.PublicKey{}, &AWSOptions{Raw: true, Options: crypto.SHA256}}, types.MessageTypeRaw, "RSASSA_PKCS1_V1_5_SHA_256", false}, + {"(raw)rsa+sha384", args{&rsa.PublicKey{}, &AWSOptions{Raw: true, Options: crypto.SHA384}}, types.MessageTypeRaw, "RSASSA_PKCS1_V1_5_SHA_384", false}, + {"(raw)rsa+sha512", args{&rsa.PublicKey{}, &AWSOptions{Raw: true, Options: crypto.SHA512}}, types.MessageTypeRaw, "RSASSA_PKCS1_V1_5_SHA_512", false}, + {"(raw)pssrsa+sha256", args{&rsa.PublicKey{}, &AWSOptions{Raw: true, Options: &rsa.PSSOptions{Hash: crypto.SHA256.HashFunc()}}}, types.MessageTypeRaw, "RSASSA_PSS_SHA_256", false}, + {"(raw)pssrsa+sha384", args{&rsa.PublicKey{}, &AWSOptions{Raw: true, Options: &rsa.PSSOptions{Hash: crypto.SHA384.HashFunc()}}}, types.MessageTypeRaw, "RSASSA_PSS_SHA_384", false}, + {"(raw)pssrsa+sha512", args{&rsa.PublicKey{}, &AWSOptions{Raw: true, Options: &rsa.PSSOptions{Hash: crypto.SHA512.HashFunc()}}}, types.MessageTypeRaw, "RSASSA_PSS_SHA_512", false}, + {"(raw)P256", args{&ecdsa.PublicKey{}, &AWSOptions{Raw: true, Options: crypto.SHA256}}, types.MessageTypeRaw, "ECDSA_SHA_256", false}, + {"(raw)P384", args{&ecdsa.PublicKey{}, &AWSOptions{Raw: true, Options: crypto.SHA384}}, types.MessageTypeRaw, "ECDSA_SHA_384", false}, + {"(raw)P521", args{&ecdsa.PublicKey{}, &AWSOptions{Raw: true, Options: crypto.SHA512}}, types.MessageTypeRaw, "ECDSA_SHA_512", false}, + {"fail type", args{[]byte("key"), crypto.SHA256}, types.MessageTypeDigest, "", true}, + {"fail rsa alg", args{&rsa.PublicKey{}, crypto.MD5}, types.MessageTypeDigest, "", true}, + {"fail ecdsa alg", args{&ecdsa.PublicKey{}, crypto.MD5}, types.MessageTypeDigest, "", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := getSigningAlgorithm(tt.args.key, tt.args.opts) + gotMessageType, gotAlgo, err := getMessageTypeAndSigningAlgorithm(tt.args.key, tt.args.opts) if (err != nil) != tt.wantErr { - t.Errorf("getSigningAlgorithm() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("getMessageTypeAndSigningAlgorithm() error = %v, wantErr %v", err, tt.wantErr) return } - if got != tt.want { - t.Errorf("getSigningAlgorithm() = %v, want %v", got, tt.want) + if gotMessageType != tt.wantMessageType { + t.Errorf("getMessageTypeAndSigningAlgorithm() (message type) = %v, want %v", gotMessageType, tt.wantMessageType) + } + if gotAlgo != tt.wantAlgo { + t.Errorf("getMessageTypeAndSigningAlgorithm() (algorithm) = %v, want %v", gotAlgo, tt.wantAlgo) } }) }