From 0ea12351fd4811c2c20236a99a6a5266dce65a55 Mon Sep 17 00:00:00 2001 From: Ruide Zhang Date: Thu, 21 Sep 2023 13:56:06 -0700 Subject: [PATCH] add fake AS and oauth2, done unit test --- cmd/fake_attestation_server.go | 80 +++++++++++++++ cmd/fake_metadata.go | 1 + cmd/fake_oauth2_server.go | 32 ++++++ cmd/gen_token.go | 21 +--- cmd/gen_token_test.go | 174 +++++++++++++++++++++++++++++++++ cmd/testdata/credentials | 6 ++ 6 files changed, 296 insertions(+), 18 deletions(-) create mode 100644 cmd/fake_attestation_server.go create mode 100644 cmd/fake_oauth2_server.go create mode 100644 cmd/testdata/credentials diff --git a/cmd/fake_attestation_server.go b/cmd/fake_attestation_server.go new file mode 100644 index 00000000..df896983 --- /dev/null +++ b/cmd/fake_attestation_server.go @@ -0,0 +1,80 @@ +package cmd + +import ( + "fmt" + "net/http" + "net/http/httptest" + "os" + "time" + + "github.com/golang-jwt/jwt/v4" + "golang.org/x/net/http2" +) + +const fakeAsHostEnv = "GOOGLE_APPLICATION_CREDENTIALS" + +// attestationServer provides fake implementation for the GCE attestation server. +type attestationServer struct { + server *httptest.Server + oldFakeAsHostEnv string +} + +type fakeOidcTokenPayload struct { + Audience string `json:"aud"` + IssuedAt int64 `json:"iat"` + ExpiredAt int64 `json:"exp"` +} + +func (payload *fakeOidcTokenPayload) Valid() error { + return nil +} + +func newMockAttestationServer() (*attestationServer, error) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + locationPath := "/v1/projects/test-project/locations/us-central" + if r.URL.Path == locationPath { + location := "{\n \"name\": \"projects/test-project/locations/us-central-1\",\n \"locationId\": \"us-central-1\"\n}\n" + w.Write([]byte(location)) + } + challengePath := locationPath + "-1/challenges" + if r.URL.Path == challengePath { + challenge := "{\n \"name\": \"projects/test-project/locations/us-central-1/challenges/947b4f7b-e6d4-4cfe-971c-39ffe00268ba\",\n \"createTime\": \"2023-09-21T01:04:48.230111757Z\",\n \"expireTime\": \"2023-09-21T02:04:48.230111757Z\",\n \"tpmNonce\": \"R29vZ0F0dGVzdFYxeGtJUGlRejFPOFRfTzg4QTRjdjRpQQ==\"\n}\n" + w.Write([]byte(challenge)) + } + challengeNonce := "/947b4f7b-e6d4-4cfe-971c-39ffe00268ba" + verifyAttestationPath := challengePath + challengeNonce + ":verifyAttestation" + if r.URL.Path == verifyAttestationPath { + payload := &fakeOidcTokenPayload{ + Audience: "test", + IssuedAt: time.Now().Unix(), + ExpiredAt: time.Now().Add(time.Minute).Unix(), + } + jwtTokenUnsigned := jwt.NewWithClaims(jwt.SigningMethodHS256, payload) + jwtToken, err := jwtTokenUnsigned.SignedString([]byte("kcxjxnalpraetgccnnwhpnfwocxscaih")) + if err != nil { + fmt.Print("error creating test OIDC token") + } + w.Write([]byte("{\n \"oidcClaimsToken\": \"" + jwtToken + "\"\n}\n")) + } + }) + httpServer := httptest.NewUnstartedServer(handler) + if err := http2.ConfigureServer(httpServer.Config, new(http2.Server)); err != nil { + return nil, fmt.Errorf("failed to configure HTTP/2 server: %v", err) + } + httpServer.Start() + + old := os.Getenv(fakeAsHostEnv) + cwd, err := os.Getwd() + if err != nil { + return nil, err + } + os.Setenv(fakeAsHostEnv, cwd+"/testdata/credentials") + + return &attestationServer{oldFakeAsHostEnv: old, server: httpServer}, nil +} + +// Stop shuts down the server. +func (s *attestationServer) Stop() { + os.Setenv(fakeAsHostEnv, s.oldFakeAsHostEnv) + s.server.Close() +} diff --git a/cmd/fake_metadata.go b/cmd/fake_metadata.go index 46a6a2a2..f1b378a6 100644 --- a/cmd/fake_metadata.go +++ b/cmd/fake_metadata.go @@ -37,6 +37,7 @@ func NewMetadataServer(data Instance) (*MetadataServer, error) { resp["instance/id"] = data.InstanceID resp["instance/zone"] = data.Zone resp["instance/name"] = data.InstanceName + resp["instance/service-accounts/default/identity"] = "test_jwt_token" handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { path := "/computeMetadata/v1/" diff --git a/cmd/fake_oauth2_server.go b/cmd/fake_oauth2_server.go new file mode 100644 index 00000000..9c17dbb9 --- /dev/null +++ b/cmd/fake_oauth2_server.go @@ -0,0 +1,32 @@ +package cmd + +import ( + "net/http" + "net/http/httptest" +) + +type oauth2Server struct { + server *httptest.Server +} + +func newMockOauth2Server() *oauth2Server { + mux := http.NewServeMux() + mux.HandleFunc("/o/oauth2/auth", func(w http.ResponseWriter, r *http.Request) { + // Unimplemented: Should return authorization code back to the user + }) + + mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + // Should return acccess token back to the user + w.Header().Set("Content-Type", "application/x-www-form-urlencoded") + w.Write([]byte("access_token=mocktoken&scope=user&token_type=bearer")) + }) + + server := httptest.NewServer(mux) + + return &oauth2Server{server: server} +} + +// Stop shuts down the server. +func (s *oauth2Server) Stop() { + s.server.Close() +} diff --git a/cmd/gen_token.go b/cmd/gen_token.go index dc34ab77..f6f98e95 100644 --- a/cmd/gen_token.go +++ b/cmd/gen_token.go @@ -65,7 +65,6 @@ The OIDC token includes claims regarding the authentication of the user by the a return tokens, nil } - // TODO: make this an optional flag for generalization if asAddress == "" { asAddress = "https://confidentialcomputing.googleapis.com" } @@ -86,22 +85,10 @@ The OIDC token includes claims regarding the authentication of the user by the a return fmt.Errorf("failed to create REST verifier client: %v", err) } - // Now only supports GCE VM. Hard code the AK type. + // supports GCE VM. Hard code the AK type. key = "gceAK" fmt.Fprintf(debugOutput(), "key is set to gceAK\n") - // Set AK (EK signing) cert - if key == "AK" { - ak, err := client.AttestationKeyECC(rwc) - if err != nil { - return err - } - if ak.Cert() == nil { - return errors.New("failed to find AKCert on this VM: try creating a new VM or contacting support") - } - ak.Close() - } - // Set GCE AK (EK signing) cert if key == "gceAK" { var gceAK *client.Key @@ -123,7 +110,7 @@ The OIDC token includes claims regarding the authentication of the user by the a attestAgent := agent.CreateAttestationAgent(rwc, attestationKeys[key][keyAlgo], verifierClient, principalFetcher) - fmt.Fprintf(messageOutput(), "Fetching attestation verifier OIDC token") + fmt.Fprintf(messageOutput(), "Fetching attestation verifier OIDC token\n") token, err := attestAgent.Attest(ctx) if err != nil { return fmt.Errorf("failed to retrieve attestation service token: %v", err) @@ -153,7 +140,7 @@ The OIDC token includes claims regarding the authentication of the user by the a } if output == "" { - fmt.Fprintf(messageOutput(), string(claimsString)) + fmt.Fprintf(messageOutput(), string(claimsString)+"\n") } if output != "" { @@ -205,8 +192,6 @@ func init() { addOutputFlag(gentokenCmd) addPublicKeyAlgoFlag(gentokenCmd) addAsAdressFlag(gentokenCmd) - // TODO: Alow AK certificate from other parties than gceAK - // addKeyFlag(gentokenCmd) // TODO: Add TEE hardware OIDC token generation // addTeeNonceflag(gentokenCmd) // addTeeTechnology(gentokenCmd) diff --git a/cmd/gen_token_test.go b/cmd/gen_token_test.go index 1d619dd0..21227684 100644 --- a/cmd/gen_token_test.go +++ b/cmd/gen_token_test.go @@ -1 +1,175 @@ package cmd + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "io" + "math/big" + "os" + "testing" + "time" + + "github.com/google/go-tpm-tools/client" + "github.com/google/go-tpm-tools/internal/test" + "github.com/google/go-tpm/legacy/tpm2" + "github.com/google/go-tpm/tpmutil" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" +) + +func TestGenTokenWithGCEAK(t *testing.T) { + rwc := test.GetTPM(t) + defer client.CheckedClose(t, rwc) + ExternalTPM = rwc + secretFile1 := makeOutputFile(t, "gentoken") + defer os.RemoveAll(secretFile1) + var template = map[string]tpm2.Public{ + "rsa": GCEAKTemplateRSA(), + "ecc": GCEAKTemplateECC(), + } + tests := []struct { + name string + algo string + }{ + {"gceAK:RSA", "rsa"}, + {"gceAK:ECC", "ecc"}, + } + for _, op := range tests { + t.Run(op.name, func(t *testing.T) { + gceAkTemplate, err := template[op.algo].Encode() + if err != nil { + t.Fatalf("failed to encode GCEAKTemplateRSA: %v", err) + } + err = setGCEAKCertTemplate(t, rwc, op.algo, gceAkTemplate) + if err != nil { + t.Error(err) + } + defer tpm2.NVUndefineSpace(rwc, "", tpm2.HandlePlatform, tpmutil.Handle(getIndex[op.algo])) + defer tpm2.NVUndefineSpace(rwc, "", tpm2.HandlePlatform, tpmutil.Handle(getCertIndex[op.algo])) + + var dummyMetaInstance = Instance{ProjectID: "test-project", ProjectNumber: "1922337278274", Zone: "us-central-1a", InstanceID: "12345678", InstanceName: "default"} + mockMdsServer, err := NewMetadataServer(dummyMetaInstance) + if err != nil { + t.Error(err) + } + defer mockMdsServer.Stop() + + mockOauth2Server := newMockOauth2Server() + defer mockOauth2Server.Stop() + + // Endpoint is Google's OAuth 2.0 default endpoint. Change to mock server. + google.Endpoint = oauth2.Endpoint{ + AuthURL: mockOauth2Server.server.URL + "/o/oauth2/auth", + TokenURL: mockOauth2Server.server.URL + "/token", + AuthStyle: oauth2.AuthStyleInParams, + } + + mockAttestationServer, err := newMockAttestationServer() + if err != nil { + t.Error(err) + } + defer mockAttestationServer.Stop() + + RootCmd.SetArgs([]string{"gentoken", "--algo", op.algo, "--output", secretFile1, "--asAddr", mockAttestationServer.server.URL}) + if err := RootCmd.Execute(); err != nil { + t.Error(err) + } + }) + } +} + +// Need to call tpm2.NVUndefinespace twice on the handle with authHandle tpm2.HandlePlatform. +// e.g defer tpm2.NVUndefineSpace(rwc, "", tpm2.HandlePlatform, tpmutil.Handle(client.GceAKTemplateNVIndexRSA)) +// defer tpm2.NVUndefineSpace(rwc, "", tpm2.HandlePlatform, tpmutil.Handle(client.GceAKCertNVIndexRSA)) +func setGCEAKCertTemplate(tb testing.TB, rwc io.ReadWriteCloser, algo string, akTemplate []byte) error { + var err error + // Write AK template to NV memory + if err := tpm2.NVDefineSpace(rwc, tpm2.HandlePlatform, tpmutil.Handle(getIndex[algo]), + "", "", nil, + tpm2.AttrPPWrite|tpm2.AttrPPRead|tpm2.AttrWriteDefine|tpm2.AttrOwnerRead|tpm2.AttrAuthRead|tpm2.AttrPlatformCreate|tpm2.AttrNoDA, + uint16(len(akTemplate))); err != nil { + tb.Fatalf("NVDefineSpace failed: %v", err) + } + err = tpm2.NVWrite(rwc, tpm2.HandlePlatform, tpmutil.Handle(getIndex[algo]), "", akTemplate, 0) + if err != nil { + tb.Fatalf("failed to write NVIndex: %v", err) + } + + // create self-signed AK cert + getAttestationKeyFunc := getAttestationKey[algo] + attestKey, err := getAttestationKeyFunc(rwc) + if err != nil { + tb.Fatalf("Unable to create key: %v", err) + } + defer attestKey.Close() + // create self-signed Root CA + ca, caKey := getTestCert(tb, nil, nil, nil) + // sign the attestation key certificate + akCert, _ := getTestCert(tb, attestKey.PublicKey(), ca, caKey) + if err = attestKey.SetCert(akCert); err != nil { + tb.Errorf("SetCert() returned error: %v", err) + } + + // write test AK cert. + // size need to be less than 1024 (MAX_NV_BUFFER_SIZE). If not, split before write. + certASN1 := akCert.Raw + // write to gceAK slot in NV memory + if err := tpm2.NVDefineSpace(rwc, tpm2.HandlePlatform, tpmutil.Handle(getCertIndex[algo]), + "", "", nil, + tpm2.AttrPPWrite|tpm2.AttrPPRead|tpm2.AttrWriteDefine|tpm2.AttrOwnerRead|tpm2.AttrAuthRead|tpm2.AttrPlatformCreate|tpm2.AttrNoDA, + uint16(len(certASN1))); err != nil { + tb.Fatalf("NVDefineSpace failed: %v", err) + } + err = tpm2.NVWrite(rwc, tpm2.HandlePlatform, tpmutil.Handle(getCertIndex[algo]), "", certASN1, 0) + if err != nil { + tb.Fatalf("failed to write NVIndex: %v", err) + } + + return nil +} + +var getCertIndex = map[string]uint32{ + "rsa": client.GceAKCertNVIndexRSA, + "ecc": client.GceAKCertNVIndexECC, +} + +var getAttestationKey = map[string]func(rw io.ReadWriter) (*client.Key, error){ + "rsa": client.GceAttestationKeyRSA, + "ecc": client.GceAttestationKeyECC, +} + +// Returns an x509 Certificate for the provided pubkey, signed with the provided parent certificate and key. +// If the provided fields are nil, will create a self-signed certificate. +func getTestCert(tb testing.TB, pubKey crypto.PublicKey, parentCert *x509.Certificate, parentKey *rsa.PrivateKey) (*x509.Certificate, *rsa.PrivateKey) { + certKey, _ := rsa.GenerateKey(rand.Reader, 2048) + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IsCA: true, + MaxPathLenZero: true, + } + + if pubKey == nil && parentCert == nil && parentKey == nil { + pubKey = certKey.Public() + parentCert = template + parentKey = certKey + } + + certBytes, err := x509.CreateCertificate(rand.Reader, template, parentCert, pubKey, parentKey) + if err != nil { + tb.Fatalf("Unable to create test certificate: %v", err) + } + + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + tb.Fatalf("Unable to parse test certificate: %v", err) + } + + return cert, certKey +} diff --git a/cmd/testdata/credentials b/cmd/testdata/credentials new file mode 100644 index 00000000..229c3227 --- /dev/null +++ b/cmd/testdata/credentials @@ -0,0 +1,6 @@ +{ + "client_id": "id", + "client_secret": "testdata", + "refresh_token": "testdata", + "type": "authorized_user" +} \ No newline at end of file