Skip to content

Commit

Permalink
add fake AS and oauth2, done unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Ruide committed Feb 7, 2024
1 parent 9c99095 commit 0ea1235
Show file tree
Hide file tree
Showing 6 changed files with 296 additions and 18 deletions.
80 changes: 80 additions & 0 deletions cmd/fake_attestation_server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package cmd

import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"time"

"github.com/golang-jwt/jwt/v4"
"golang.org/x/net/http2"
)

const fakeAsHostEnv = "GOOGLE_APPLICATION_CREDENTIALS"

// attestationServer provides fake implementation for the GCE attestation server.
type attestationServer struct {
server *httptest.Server
oldFakeAsHostEnv string
}

type fakeOidcTokenPayload struct {
Audience string `json:"aud"`
IssuedAt int64 `json:"iat"`
ExpiredAt int64 `json:"exp"`
}

func (payload *fakeOidcTokenPayload) Valid() error {
return nil
}

func newMockAttestationServer() (*attestationServer, error) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
locationPath := "/v1/projects/test-project/locations/us-central"
if r.URL.Path == locationPath {
location := "{\n \"name\": \"projects/test-project/locations/us-central-1\",\n \"locationId\": \"us-central-1\"\n}\n"
w.Write([]byte(location))
}
challengePath := locationPath + "-1/challenges"
if r.URL.Path == challengePath {
challenge := "{\n \"name\": \"projects/test-project/locations/us-central-1/challenges/947b4f7b-e6d4-4cfe-971c-39ffe00268ba\",\n \"createTime\": \"2023-09-21T01:04:48.230111757Z\",\n \"expireTime\": \"2023-09-21T02:04:48.230111757Z\",\n \"tpmNonce\": \"R29vZ0F0dGVzdFYxeGtJUGlRejFPOFRfTzg4QTRjdjRpQQ==\"\n}\n"
w.Write([]byte(challenge))
}
challengeNonce := "/947b4f7b-e6d4-4cfe-971c-39ffe00268ba"
verifyAttestationPath := challengePath + challengeNonce + ":verifyAttestation"
if r.URL.Path == verifyAttestationPath {
payload := &fakeOidcTokenPayload{
Audience: "test",
IssuedAt: time.Now().Unix(),
ExpiredAt: time.Now().Add(time.Minute).Unix(),
}
jwtTokenUnsigned := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
jwtToken, err := jwtTokenUnsigned.SignedString([]byte("kcxjxnalpraetgccnnwhpnfwocxscaih"))
if err != nil {
fmt.Print("error creating test OIDC token")
}
w.Write([]byte("{\n \"oidcClaimsToken\": \"" + jwtToken + "\"\n}\n"))
}
})
httpServer := httptest.NewUnstartedServer(handler)
if err := http2.ConfigureServer(httpServer.Config, new(http2.Server)); err != nil {
return nil, fmt.Errorf("failed to configure HTTP/2 server: %v", err)
}
httpServer.Start()

old := os.Getenv(fakeAsHostEnv)
cwd, err := os.Getwd()
if err != nil {
return nil, err
}
os.Setenv(fakeAsHostEnv, cwd+"/testdata/credentials")

return &attestationServer{oldFakeAsHostEnv: old, server: httpServer}, nil
}

// Stop shuts down the server.
func (s *attestationServer) Stop() {
os.Setenv(fakeAsHostEnv, s.oldFakeAsHostEnv)
s.server.Close()
}
1 change: 1 addition & 0 deletions cmd/fake_metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func NewMetadataServer(data Instance) (*MetadataServer, error) {
resp["instance/id"] = data.InstanceID
resp["instance/zone"] = data.Zone
resp["instance/name"] = data.InstanceName
resp["instance/service-accounts/default/identity"] = "test_jwt_token"

handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
path := "/computeMetadata/v1/"
Expand Down
32 changes: 32 additions & 0 deletions cmd/fake_oauth2_server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package cmd

import (
"net/http"
"net/http/httptest"
)

type oauth2Server struct {
server *httptest.Server
}

func newMockOauth2Server() *oauth2Server {
mux := http.NewServeMux()
mux.HandleFunc("/o/oauth2/auth", func(w http.ResponseWriter, r *http.Request) {
// Unimplemented: Should return authorization code back to the user
})

mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
// Should return acccess token back to the user
w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
w.Write([]byte("access_token=mocktoken&scope=user&token_type=bearer"))
})

server := httptest.NewServer(mux)

return &oauth2Server{server: server}
}

// Stop shuts down the server.
func (s *oauth2Server) Stop() {
s.server.Close()
}
21 changes: 3 additions & 18 deletions cmd/gen_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ The OIDC token includes claims regarding the authentication of the user by the a
return tokens, nil
}

// TODO: make this an optional flag for generalization
if asAddress == "" {
asAddress = "https://confidentialcomputing.googleapis.com"
}
Expand All @@ -86,22 +85,10 @@ The OIDC token includes claims regarding the authentication of the user by the a
return fmt.Errorf("failed to create REST verifier client: %v", err)
}

// Now only supports GCE VM. Hard code the AK type.
// supports GCE VM. Hard code the AK type.
key = "gceAK"
fmt.Fprintf(debugOutput(), "key is set to gceAK\n")

// Set AK (EK signing) cert
if key == "AK" {
ak, err := client.AttestationKeyECC(rwc)
if err != nil {
return err
}
if ak.Cert() == nil {
return errors.New("failed to find AKCert on this VM: try creating a new VM or contacting support")
}
ak.Close()
}

// Set GCE AK (EK signing) cert
if key == "gceAK" {
var gceAK *client.Key
Expand All @@ -123,7 +110,7 @@ The OIDC token includes claims regarding the authentication of the user by the a

attestAgent := agent.CreateAttestationAgent(rwc, attestationKeys[key][keyAlgo], verifierClient, principalFetcher)

fmt.Fprintf(messageOutput(), "Fetching attestation verifier OIDC token")
fmt.Fprintf(messageOutput(), "Fetching attestation verifier OIDC token\n")
token, err := attestAgent.Attest(ctx)
if err != nil {
return fmt.Errorf("failed to retrieve attestation service token: %v", err)
Expand Down Expand Up @@ -153,7 +140,7 @@ The OIDC token includes claims regarding the authentication of the user by the a
}

if output == "" {
fmt.Fprintf(messageOutput(), string(claimsString))
fmt.Fprintf(messageOutput(), string(claimsString)+"\n")
}

if output != "" {
Expand Down Expand Up @@ -205,8 +192,6 @@ func init() {
addOutputFlag(gentokenCmd)
addPublicKeyAlgoFlag(gentokenCmd)
addAsAdressFlag(gentokenCmd)
// TODO: Alow AK certificate from other parties than gceAK
// addKeyFlag(gentokenCmd)
// TODO: Add TEE hardware OIDC token generation
// addTeeNonceflag(gentokenCmd)
// addTeeTechnology(gentokenCmd)
Expand Down
174 changes: 174 additions & 0 deletions cmd/gen_token_test.go
Original file line number Diff line number Diff line change
@@ -1 +1,175 @@
package cmd

import (
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"io"
"math/big"
"os"
"testing"
"time"

"github.com/google/go-tpm-tools/client"
"github.com/google/go-tpm-tools/internal/test"
"github.com/google/go-tpm/legacy/tpm2"
"github.com/google/go-tpm/tpmutil"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)

func TestGenTokenWithGCEAK(t *testing.T) {
rwc := test.GetTPM(t)
defer client.CheckedClose(t, rwc)
ExternalTPM = rwc
secretFile1 := makeOutputFile(t, "gentoken")
defer os.RemoveAll(secretFile1)
var template = map[string]tpm2.Public{
"rsa": GCEAKTemplateRSA(),
"ecc": GCEAKTemplateECC(),
}
tests := []struct {
name string
algo string
}{
{"gceAK:RSA", "rsa"},
{"gceAK:ECC", "ecc"},
}
for _, op := range tests {
t.Run(op.name, func(t *testing.T) {
gceAkTemplate, err := template[op.algo].Encode()
if err != nil {
t.Fatalf("failed to encode GCEAKTemplateRSA: %v", err)
}
err = setGCEAKCertTemplate(t, rwc, op.algo, gceAkTemplate)
if err != nil {
t.Error(err)
}
defer tpm2.NVUndefineSpace(rwc, "", tpm2.HandlePlatform, tpmutil.Handle(getIndex[op.algo]))
defer tpm2.NVUndefineSpace(rwc, "", tpm2.HandlePlatform, tpmutil.Handle(getCertIndex[op.algo]))

var dummyMetaInstance = Instance{ProjectID: "test-project", ProjectNumber: "1922337278274", Zone: "us-central-1a", InstanceID: "12345678", InstanceName: "default"}
mockMdsServer, err := NewMetadataServer(dummyMetaInstance)
if err != nil {
t.Error(err)
}
defer mockMdsServer.Stop()

mockOauth2Server := newMockOauth2Server()
defer mockOauth2Server.Stop()

// Endpoint is Google's OAuth 2.0 default endpoint. Change to mock server.
google.Endpoint = oauth2.Endpoint{
AuthURL: mockOauth2Server.server.URL + "/o/oauth2/auth",
TokenURL: mockOauth2Server.server.URL + "/token",
AuthStyle: oauth2.AuthStyleInParams,
}

mockAttestationServer, err := newMockAttestationServer()
if err != nil {
t.Error(err)
}
defer mockAttestationServer.Stop()

RootCmd.SetArgs([]string{"gentoken", "--algo", op.algo, "--output", secretFile1, "--asAddr", mockAttestationServer.server.URL})
if err := RootCmd.Execute(); err != nil {
t.Error(err)
}
})
}
}

// Need to call tpm2.NVUndefinespace twice on the handle with authHandle tpm2.HandlePlatform.
// e.g defer tpm2.NVUndefineSpace(rwc, "", tpm2.HandlePlatform, tpmutil.Handle(client.GceAKTemplateNVIndexRSA))
// defer tpm2.NVUndefineSpace(rwc, "", tpm2.HandlePlatform, tpmutil.Handle(client.GceAKCertNVIndexRSA))
func setGCEAKCertTemplate(tb testing.TB, rwc io.ReadWriteCloser, algo string, akTemplate []byte) error {
var err error
// Write AK template to NV memory
if err := tpm2.NVDefineSpace(rwc, tpm2.HandlePlatform, tpmutil.Handle(getIndex[algo]),
"", "", nil,
tpm2.AttrPPWrite|tpm2.AttrPPRead|tpm2.AttrWriteDefine|tpm2.AttrOwnerRead|tpm2.AttrAuthRead|tpm2.AttrPlatformCreate|tpm2.AttrNoDA,
uint16(len(akTemplate))); err != nil {
tb.Fatalf("NVDefineSpace failed: %v", err)
}
err = tpm2.NVWrite(rwc, tpm2.HandlePlatform, tpmutil.Handle(getIndex[algo]), "", akTemplate, 0)
if err != nil {
tb.Fatalf("failed to write NVIndex: %v", err)
}

// create self-signed AK cert
getAttestationKeyFunc := getAttestationKey[algo]
attestKey, err := getAttestationKeyFunc(rwc)
if err != nil {
tb.Fatalf("Unable to create key: %v", err)
}
defer attestKey.Close()
// create self-signed Root CA
ca, caKey := getTestCert(tb, nil, nil, nil)
// sign the attestation key certificate
akCert, _ := getTestCert(tb, attestKey.PublicKey(), ca, caKey)
if err = attestKey.SetCert(akCert); err != nil {
tb.Errorf("SetCert() returned error: %v", err)
}

// write test AK cert.
// size need to be less than 1024 (MAX_NV_BUFFER_SIZE). If not, split before write.
certASN1 := akCert.Raw
// write to gceAK slot in NV memory
if err := tpm2.NVDefineSpace(rwc, tpm2.HandlePlatform, tpmutil.Handle(getCertIndex[algo]),
"", "", nil,
tpm2.AttrPPWrite|tpm2.AttrPPRead|tpm2.AttrWriteDefine|tpm2.AttrOwnerRead|tpm2.AttrAuthRead|tpm2.AttrPlatformCreate|tpm2.AttrNoDA,
uint16(len(certASN1))); err != nil {
tb.Fatalf("NVDefineSpace failed: %v", err)
}
err = tpm2.NVWrite(rwc, tpm2.HandlePlatform, tpmutil.Handle(getCertIndex[algo]), "", certASN1, 0)
if err != nil {
tb.Fatalf("failed to write NVIndex: %v", err)
}

return nil
}

var getCertIndex = map[string]uint32{
"rsa": client.GceAKCertNVIndexRSA,
"ecc": client.GceAKCertNVIndexECC,
}

var getAttestationKey = map[string]func(rw io.ReadWriter) (*client.Key, error){
"rsa": client.GceAttestationKeyRSA,
"ecc": client.GceAttestationKeyECC,
}

// Returns an x509 Certificate for the provided pubkey, signed with the provided parent certificate and key.
// If the provided fields are nil, will create a self-signed certificate.
func getTestCert(tb testing.TB, pubKey crypto.PublicKey, parentCert *x509.Certificate, parentKey *rsa.PrivateKey) (*x509.Certificate, *rsa.PrivateKey) {
certKey, _ := rsa.GenerateKey(rand.Reader, 2048)
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IsCA: true,
MaxPathLenZero: true,
}

if pubKey == nil && parentCert == nil && parentKey == nil {
pubKey = certKey.Public()
parentCert = template
parentKey = certKey
}

certBytes, err := x509.CreateCertificate(rand.Reader, template, parentCert, pubKey, parentKey)
if err != nil {
tb.Fatalf("Unable to create test certificate: %v", err)
}

cert, err := x509.ParseCertificate(certBytes)
if err != nil {
tb.Fatalf("Unable to parse test certificate: %v", err)
}

return cert, certKey
}
6 changes: 6 additions & 0 deletions cmd/testdata/credentials
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"client_id": "id",
"client_secret": "testdata",
"refresh_token": "testdata",
"type": "authorized_user"
}

0 comments on commit 0ea1235

Please sign in to comment.