diff --git a/tpm2tools/signer.go b/tpm2tools/signer.go index bcd10b728..61d35113f 100644 --- a/tpm2tools/signer.go +++ b/tpm2tools/signer.go @@ -65,18 +65,7 @@ func (signer *tpmSigner) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts if err != nil { return nil, err } - - switch sig.Alg { - case tpm2.AlgRSASSA: - return sig.RSA.Signature, nil - case tpm2.AlgRSAPSS: - return sig.RSA.Signature, nil - case tpm2.AlgECDSA: - sigStruct := struct{ R, S *big.Int }{sig.ECC.R, sig.ECC.S} - return asn1.Marshal(sigStruct) - default: - panic("unsupported signing algorithm") - } + return getSignature(sig) } // GetSigner returns a crypto.Signer wrapping the loaded TPM Key. @@ -100,6 +89,47 @@ func (k *Key) GetSigner() (crypto.Signer, error) { return &tpmSigner{k, hash}, nil } +// SignData signs a data buffer with a TPM loaded key. Unlike GetSigner, this +// method works with restricted and unrestricted keys. If this method is called +// on a restriced key, the TPM itself will hash the provided data, failing the +// signing operation if the data begins with TPM_GENERATED_VALUE. +func (k *Key) SignData(data []byte) ([]byte, error) { + hashAlg, err := getSigningHashAlg(k) + if err != nil { + return nil, err + } + + var digest []byte + var ticket *tpm2.Ticket + if k.hasAttribute(tpm2.FlagRestricted) { + // Restricted keys can only sign data hashed by the TPM. We use the + // owner hierarchy for the Ticket, but any non-Null hierarchy would do. + digest, ticket, err = tpm2.Hash(k.rw, hashAlg, data, tpm2.HandleOwner) + if err != nil { + return nil, err + } + } else { + // Unrestricted keys can sign any digest, no need for TPM hashing. + hash, err := hashAlg.Hash() + if err != nil { + return nil, err + } + hasher := hash.New() + hasher.Write(data) + digest = hasher.Sum(nil) + } + + auth, err := k.session.Auth() + if err != nil { + return nil, err + } + sig, err := tpm2.SignWithSession(k.rw, auth.Session, k.handle, "", digest, ticket, nil) + if err != nil { + return nil, err + } + return getSignature(sig) +} + func getSigningHashAlg(k *Key) (tpm2.Algorithm, error) { if !k.hasAttribute(tpm2.FlagSign) { return tpm2.AlgNull, fmt.Errorf("non-signing key used with signing operation") @@ -125,3 +155,17 @@ func getSigningHashAlg(k *Key) (tpm2.Algorithm, error) { return tpm2.AlgNull, fmt.Errorf("unsupported signing algorithm: %v", sigScheme.Alg) } } + +func getSignature(sig *tpm2.Signature) ([]byte, error) { + switch sig.Alg { + case tpm2.AlgRSASSA: + return sig.RSA.Signature, nil + case tpm2.AlgRSAPSS: + return sig.RSA.Signature, nil + case tpm2.AlgECDSA: + sigStruct := struct{ R, S *big.Int }{sig.ECC.R, sig.ECC.S} + return asn1.Marshal(sigStruct) + default: + return nil, fmt.Errorf("unsupported signing algorithm: %v", sig.Alg) + } +} diff --git a/tpm2tools/signer_test.go b/tpm2tools/signer_test.go index e6977f49c..29bbea07f 100644 --- a/tpm2tools/signer_test.go +++ b/tpm2tools/signer_test.go @@ -84,7 +84,14 @@ func TestSign(t *testing.T) { {"Auth-ECC", crypto.SHA256, templateAuthECC(), verifyECC}, } + message := []byte("authenticated message") + // Data beginning with TPM_GENERATED_VALUE (looks like a TPM-internal message) + generatedMsg := append([]byte("\xffTCG"), message...) for _, test := range tests { + hash := test.hash.New() + hash.Write(message) + digest := hash.Sum(nil) + t.Run(test.name, func(t *testing.T) { key, err := NewKey(rwc, tpm2.HandleEndorsement, test.template) if err != nil { @@ -92,15 +99,10 @@ func TestSign(t *testing.T) { } defer key.Close() - hash := test.hash.New() - hash.Write([]byte("authenticated message")) - digest := hash.Sum(nil) - signer, err := key.GetSigner() if err != nil { t.Fatal(err) } - sig, err := signer.Sign(nil, digest, test.hash) if err != nil { t.Fatal(err) @@ -109,6 +111,48 @@ func TestSign(t *testing.T) { t.Error(err) } }) + t.Run(test.name+"-SignData", func(t *testing.T) { + key, err := NewKey(rwc, tpm2.HandleEndorsement, test.template) + if err != nil { + t.Fatal(err) + } + defer key.Close() + + sig, err := key.SignData(message) + if err != nil { + t.Fatal(err) + } + if !test.verify(key.PublicKey(), test.hash, digest, sig) { + t.Error(err) + } + + // Unrestricted keys can sign data beginning with TPM_GENERATED_VALUE + if _, err = key.SignData(generatedMsg); err != nil { + t.Error(err) + } + }) + t.Run(test.name+"-SignDataRestricted", func(t *testing.T) { + restrictedTemplate := test.template + restrictedTemplate.Attributes |= tpm2.FlagRestricted + key, err := NewKey(rwc, tpm2.HandleEndorsement, restrictedTemplate) + if err != nil { + t.Fatal(err) + } + defer key.Close() + + sig, err := key.SignData(message) + if err != nil { + t.Fatal(err) + } + if !test.verify(key.PublicKey(), test.hash, digest, sig) { + t.Error(err) + } + + // Restricted keys cannot sign data beginning with TPM_GENERATED_VALUE + if _, err = key.SignData(generatedMsg); err == nil { + t.Error("Signing TPM_GENERATED_VALUE data should fail") + } + }) } }