Skip to content

Commit

Permalink
Merge pull request #168 from josephlr/cert_log
Browse files Browse the repository at this point in the history
client: Have client.Attest return an error on bad certs
  • Loading branch information
josephlr authored Mar 8, 2022
2 parents 580daaf + 1bfd1e9 commit 32227c4
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 47 deletions.
74 changes: 39 additions & 35 deletions client/attest.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@ import (
"crypto/x509"
"fmt"
"io/ioutil"
"log"
"net/http"

pb "github.com/google/go-tpm-tools/proto/attest"
)

const (
maxIssuingCertificateURLs = 3
maxCertChainDepth = 4
maxCertChainLength = 4
)

// AttestOpts allows for customizing the functionality of Attest.
Expand All @@ -39,74 +38,75 @@ type AttestOpts struct {
FetchCertChain bool
}

// Given a certificate, iterates through its IssuingCertificateURLs and returns the certificate
// that signed it. If unable to find an intermediate certificate, it returns a nil.
func fetchIssuingCertificate(cert *x509.Certificate) *x509.Certificate {
if cert == nil {
return nil
// Given a certificate, iterates through its IssuingCertificateURLs and returns
// the certificate that signed it. If the certificate lacks an
// IssuingCertificateURL, return nil. If fetching the certificates fails or the
// cert chain is malformed, return an error.
func fetchIssuingCertificate(cert *x509.Certificate) (*x509.Certificate, error) {
// Check if we should event attempt fetching.
if cert == nil || len(cert.IssuingCertificateURL) == 0 {
return nil, nil
}
// For each URL, fetch and parse the certificate, then verify whether it signed cert.
// If successful, return the parsed certificate. If any step in this process fails, try the next url.
// If all the URLs fail, return the last error we got.
// TODO(Issue #169): Return a multi-error here
var lastErr error
for i, url := range cert.IssuingCertificateURL {
// Limit the number of attempts.
if i == maxIssuingCertificateURLs {
log.Printf("Reached the maximum number of attempts (%v)", maxIssuingCertificateURLs)
return nil
if i >= maxIssuingCertificateURLs {
break
}
resp, err := http.Get(url)
if err != nil {
log.Printf("failed to retrieve certificate at %v: %v\n", url, err)
lastErr = fmt.Errorf("failed to retrieve certificate at %v: %w", url, err)
continue
}

if resp.StatusCode != http.StatusOK {
log.Printf("certificate retrieval from %s returned non-OK status: %v\n", url, resp.StatusCode)
lastErr = fmt.Errorf("certificate retrieval from %s returned non-OK status: %v", url, resp.StatusCode)
continue
}
certBytes, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
log.Printf("failed to read response body from %s: %v\n", url, err)
lastErr = fmt.Errorf("failed to read response body from %s: %w", url, err)
continue
}

parsedCert, err := x509.ParseCertificate(certBytes)
if err != nil {
log.Printf("failed to parse response from %s into a certificate: %v\n", url, err)
lastErr = fmt.Errorf("failed to parse response from %s into a certificate: %w", url, err)
continue
}

// Check if the parsed certificate signed the current one.
if err = cert.CheckSignatureFrom(parsedCert); err == nil {
return parsedCert
if err = cert.CheckSignatureFrom(parsedCert); err != nil {
lastErr = fmt.Errorf("parent certificate from %s did not sign child: %w", url, err)
continue
}
return parsedCert, nil
}

log.Println("failed to find intermediate certificate")
return nil
return nil, lastErr
}

// Constructs the certificate chain for the key's certificate, using the provided HTTP client.
// Constructs the certificate chain for the key's certificate.
// If an error is encountered in the process, return what has been constructed so far.
func (k *Key) getCertificateChain() [][]byte {
func (k *Key) getCertificateChain() ([][]byte, error) {
var certs [][]byte
currentCert := k.cert
for i := 0; i < maxCertChainDepth; i++ {
issuingCert := fetchIssuingCertificate(currentCert)
for len(certs) <= maxCertChainLength {
issuingCert, err := fetchIssuingCertificate(currentCert)
if err != nil {
return nil, err
}
if issuingCert == nil {
break
return certs, nil
}

certs = append(certs, issuingCert.Raw)
// Stop searching if no IssuingCertificateURLs found.
if len(issuingCert.IssuingCertificateURL) == 0 {
break
}

currentCert = issuingCert
}

return certs
return nil, fmt.Errorf("max certificate chain length (%v) exceeded", maxCertChainLength)
}

// Attest generates an Attestation containing the TCG Event Log and a Quote over
Expand Down Expand Up @@ -146,9 +146,13 @@ func (k *Key) Attest(opts AttestOpts) (*pb.Attestation, error) {
attestation.CanonicalEventLog = opts.CanonicalEventLog
}

// Construct certficate chain if AK cert is present and contains intermediate cert URLs.
if opts.FetchCertChain && k.cert != nil && len(k.cert.IssuingCertificateURL) > 0 {
attestation.IntermediateCerts = k.getCertificateChain()
// Attempt to construct certificate chain. fetchIssuingCertificate checks if
// AK cert is present and contains intermediate cert URLs.
if opts.FetchCertChain {
attestation.IntermediateCerts, err = k.getCertificateChain()
if err != nil {
return nil, fmt.Errorf("fetching certificate chain: %w", err)
}
}

return &attestation, nil
Expand Down
5 changes: 4 additions & 1 deletion client/attest_network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ func TestNetworkFetchIssuingCertificate(t *testing.T) {

key := &Key{cert: akCert}

certChain := key.getCertificateChain()
certChain, err := key.getCertificateChain()
if err != nil {
t.Error(err)
}
if len(certChain) == 0 {
t.Error("Did not retrieve any certificates.")
}
Expand Down
19 changes: 11 additions & 8 deletions client/attest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@ func TestFetchIssuingCertificateSucceeds(t *testing.T) {

leafCert, _ := getTestCert(t, []string{"invalid.URL", ts.URL}, testCA, caKey)

cert := fetchIssuingCertificate(leafCert)
if cert == nil {
t.Errorf("fetchIssuingCertificate() did not find valid intermediate cert")
cert, err := fetchIssuingCertificate(leafCert)
if err != nil || cert == nil {
t.Errorf("fetchIssuingCertificate() did not find valid intermediate cert: %v", err)
}
}

func TestFetchIssuingCertificateReturnsErrorIfNoValidCertificateFound(t *testing.T) {
func TestFetchIssuingCertificateReturnsErrorIfMalformedCertificateFound(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusOK)
rw.Write([]byte("these are some random bytes"))
Expand All @@ -78,9 +78,9 @@ func TestFetchIssuingCertificateReturnsErrorIfNoValidCertificateFound(t *testing
testCA, caKey := getTestCert(t, nil, nil, nil)
leafCert, _ := getTestCert(t, []string{ts.URL}, testCA, caKey)

cert := fetchIssuingCertificate(leafCert)
if cert != nil {
t.Error("fetchIssuingCertificate returned non-nil certificate, but expected nil.")
_, err := fetchIssuingCertificate(leafCert)
if err == nil {
t.Fatal("expected fetchIssuingCertificate to fail with malformed cert")
}
}

Expand Down Expand Up @@ -109,7 +109,10 @@ func TestGetCertificateChainSucceeds(t *testing.T) {

key := &Key{cert: leafCert}

certChain := key.getCertificateChain()
certChain, err := key.getCertificateChain()
if err != nil {
t.Fatal(err)
}
if len(certChain) != 2 {
t.Fatalf("getCertificateChain did not return the expected number of certificates: got %v, want 2", len(certChain))
}
Expand Down
5 changes: 2 additions & 3 deletions server/policy_constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
_ "embed" // Necessary to use go:embed
"errors"
"fmt"
"log"
"strconv"

"github.com/google/certificate-transparency-go/x509"
Expand Down Expand Up @@ -80,11 +79,11 @@ func init() {
var err error
GceEKRoots, err = getPool([][]byte{gceEKRootCA})
if err != nil {
log.Panicf("failed to create the root cert pool: %v", err)
panic(fmt.Sprintf("failed to create the root cert pool: %v", err))
}
GceEKIntermediates, err = getPool([][]byte{gceEKIntermediateCA2})
if err != nil {
log.Panicf("failed to create the intermediate cert pool: %v", err)
panic(fmt.Sprintf("failed to create the intermediate cert pool: %v", err))
}
}
func getPool(certs [][]byte) (*x509.CertPool, error) {
Expand Down

0 comments on commit 32227c4

Please sign in to comment.