Skip to content

Commit

Permalink
[cmd] Add new command token in the CLI tool (#375)
Browse files Browse the repository at this point in the history
Command Description: Fetch an attestation report from GCE VM vTPM and send it to Google Attestation Service for an OIDC token.

This command improves usability for a GCE VM user.
  • Loading branch information
Ruide authored Feb 27, 2024
1 parent acbae2f commit b22bad0
Show file tree
Hide file tree
Showing 8 changed files with 502 additions and 9 deletions.
80 changes: 80 additions & 0 deletions cmd/fake_attestation_server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package cmd

import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"time"

"github.com/golang-jwt/jwt/v4"
"golang.org/x/net/http2"
)

const fakeAsHostEnv = "GOOGLE_APPLICATION_CREDENTIALS"

// attestationServer provides fake implementation for the GCE attestation server.
type attestationServer struct {
server *httptest.Server
oldFakeAsHostEnv string
}

type fakeOidcTokenPayload struct {
Audience string `json:"aud"`
IssuedAt int64 `json:"iat"`
ExpiredAt int64 `json:"exp"`
}

func (payload *fakeOidcTokenPayload) Valid() error {
return nil
}

func newMockAttestationServer() (*attestationServer, error) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
locationPath := "/v1/projects/test-project/locations/us-central"
if r.URL.Path == locationPath {
location := "{\n \"name\": \"projects/test-project/locations/us-central-1\",\n \"locationId\": \"us-central-1\"\n}\n"
w.Write([]byte(location))
}
challengePath := locationPath + "-1/challenges"
if r.URL.Path == challengePath {
challenge := "{\n \"name\": \"projects/test-project/locations/us-central-1/challenges/947b4f7b-e6d4-4cfe-971c-39ffe00268ba\",\n \"createTime\": \"2023-09-21T01:04:48.230111757Z\",\n \"expireTime\": \"2023-09-21T02:04:48.230111757Z\",\n \"tpmNonce\": \"R29vZ0F0dGVzdFYxeGtJUGlRejFPOFRfTzg4QTRjdjRpQQ==\"\n}\n"
w.Write([]byte(challenge))
}
challengeNonce := "/947b4f7b-e6d4-4cfe-971c-39ffe00268ba"
verifyAttestationPath := challengePath + challengeNonce + ":verifyAttestation"
if r.URL.Path == verifyAttestationPath {
payload := &fakeOidcTokenPayload{
Audience: "test",
IssuedAt: time.Now().Unix(),
ExpiredAt: time.Now().Add(time.Minute).Unix(),
}
jwtTokenUnsigned := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
jwtToken, err := jwtTokenUnsigned.SignedString([]byte("kcxjxnalpraetgccnnwhpnfwocxscaih"))
if err != nil {
fmt.Print("error creating test OIDC token")
}
w.Write([]byte("{\n \"oidcClaimsToken\": \"" + jwtToken + "\"\n}\n"))
}
})
httpServer := httptest.NewUnstartedServer(handler)
if err := http2.ConfigureServer(httpServer.Config, new(http2.Server)); err != nil {
return nil, fmt.Errorf("failed to configure HTTP/2 server: %v", err)
}
httpServer.Start()

old := os.Getenv(fakeAsHostEnv)
cwd, err := os.Getwd()
if err != nil {
return nil, err
}
os.Setenv(fakeAsHostEnv, cwd+"/testdata/credentials")

return &attestationServer{oldFakeAsHostEnv: old, server: httpServer}, nil
}

// Stop shuts down the server.
func (s *attestationServer) Stop() {
os.Setenv(fakeAsHostEnv, s.oldFakeAsHostEnv)
s.server.Close()
}
1 change: 1 addition & 0 deletions cmd/fake_metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func NewMetadataServer(data Instance) (*MetadataServer, error) {
resp["instance/id"] = data.InstanceID
resp["instance/zone"] = data.Zone
resp["instance/name"] = data.InstanceName
resp["instance/service-accounts/default/identity"] = "test_jwt_token"

handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
path := "/computeMetadata/v1/"
Expand Down
32 changes: 32 additions & 0 deletions cmd/fake_oauth2_server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package cmd

import (
"net/http"
"net/http/httptest"
)

type oauth2Server struct {
server *httptest.Server
}

func newMockOauth2Server() *oauth2Server {
mux := http.NewServeMux()
mux.HandleFunc("/o/oauth2/auth", func(_ http.ResponseWriter, _ *http.Request) {
// Unimplemented: Should return authorization code back to the user
})

mux.HandleFunc("/token", func(w http.ResponseWriter, _ *http.Request) {
// Should return acccess token back to the user
w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
w.Write([]byte("access_token=mocktoken&scope=user&token_type=bearer"))
})

server := httptest.NewServer(mux)

return &oauth2Server{server: server}
}

// Stop shuts down the server.
func (s *oauth2Server) Stop() {
s.server.Close()
}
23 changes: 15 additions & 8 deletions cmd/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@ import (
)

var (
output string
input string
nvIndex uint32
nonce []byte
teeNonce []byte
keyAlgo = tpm2.AlgRSA
pcrs []int
format string
output string
input string
nvIndex uint32
nonce []byte
teeNonce []byte
keyAlgo = tpm2.AlgRSA
pcrs []int
format string
asAddress string
)

type pcrsFlag struct {
Expand Down Expand Up @@ -123,6 +124,12 @@ func addInputFlag(cmd *cobra.Command) {
"input file (defaults to stdin)")
}

// Lets this command specify an Attestation Server Address.
func addAsAddressFlag(cmd *cobra.Command) {
cmd.PersistentFlags().StringVar(&asAddress, "asAddr", "https://confidentialcomputing.googleapis.com",
"Attestation Service address")
}

// Lets this command specify an NVDATA index, for use with nvIndex.
func addIndexFlag(cmd *cobra.Command) {
cmd.PersistentFlags().Uint32Var(&nvIndex, "index", 0,
Expand Down
6 changes: 6 additions & 0 deletions cmd/testdata/credentials
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"client_id": "id",
"client_secret": "testdata",
"refresh_token": "testdata",
"type": "authorized_user"
}
192 changes: 192 additions & 0 deletions cmd/token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
package cmd

import (
"context"
"encoding/json"
"errors"
"fmt"
"net/url"
"strings"
"time"

"cloud.google.com/go/compute/metadata"
"github.com/containerd/containerd/namespaces"
"github.com/golang-jwt/jwt/v4"
"github.com/google/go-tpm-tools/client"
"github.com/google/go-tpm-tools/launcher/agent"
"github.com/google/go-tpm-tools/launcher/spec"
"github.com/google/go-tpm-tools/launcher/verifier"
"github.com/google/go-tpm-tools/launcher/verifier/rest"
"github.com/google/go-tpm/legacy/tpm2"
"github.com/spf13/cobra"
"golang.org/x/oauth2/google"
"google.golang.org/api/option"
)

var mdsClient *metadata.Client

// If hardware technology needs a variable length teenonce then please modify the flags description
var tokenCmd = &cobra.Command{
Use: "token",
Short: "Attest and fetch an OIDC token from Google Attestation Verification Service.",
Long: `Gather attestation report and send it to Google Attestation Verification Service for an OIDC token.
The OIDC token includes claims regarding the GCE VM, which is verified by Attestation Verification Service. Note that Confidential Computing API needs to be enabled for your account to access Google Attestation Verification Service https://pantheon.corp.google.com/apis/api/confidentialcomputing.googleapis.com.
--algo flag overrides the public key algorithm for the GCE TPM attestation key. If not provided then by default rsa is used.
`,
Args: cobra.NoArgs,
RunE: func(*cobra.Command, []string) error {
rwc, err := openTpm()
if err != nil {
return err
}
defer rwc.Close()

// Metadata Server (MDS). A GCP specific client.
mdsClient = metadata.NewClient(nil)

ctx := namespaces.WithNamespace(context.Background(), namespaces.Default)
// TODO: principalFetcher is copied from go-tpm-tools/launcher/container_runner.go, to be refactored
// Fetch GCP specific ID token with specific audience.
// See https://cloud.google.com/functions/docs/securing/authenticating#functions-bearer-token-example-go.
principalFetcher := func(audience string) ([][]byte, error) {
u := url.URL{
Path: "instance/service-accounts/default/identity",
RawQuery: url.Values{
"audience": {audience},
"format": {"full"},
}.Encode(),
}
idToken, err := mdsClient.Get(u.String())
if err != nil {
return nil, fmt.Errorf("failed to get principal tokens: %w", err)
}
fmt.Fprintf(debugOutput(), "GCP ID token fetched is: %s\n", idToken)
tokens := [][]byte{[]byte(idToken)}
return tokens, nil
}

fmt.Fprintf(debugOutput(), "Attestation Address is set to %s\n", asAddress)

region, err := getRegion(mdsClient)
if err != nil {
return fmt.Errorf("failed to fetch Region from MDS, the tool is probably not running in a GCE VM: %v", err)
}

projectID, err := mdsClient.ProjectID()
if err != nil {
return fmt.Errorf("failed to retrieve ProjectID from MDS: %v", err)
}

verifierClient, err := getRESTClient(ctx, asAddress, projectID, region)
if err != nil {
return fmt.Errorf("failed to create REST verifier client: %v", err)
}

// Supports GCE VM. Hard code the AK type. Set GCE AK (EK signing) cert
var gceAK *client.Key
var usedKeyAlgo string
if keyAlgo == tpm2.AlgRSA {
usedKeyAlgo = "RSA"
gceAK, err = client.GceAttestationKeyRSA(rwc)
}
if keyAlgo == tpm2.AlgECC {
usedKeyAlgo = "ECC"
gceAK, err = client.GceAttestationKeyECC(rwc)
}
if err != nil {
return err
}
if gceAK.Cert() == nil {
return errors.New("failed to find gceAKCert on this VM: try creating a new VM or verifying the VM has an EK cert using get-shielded-identity gcloud command. The used key algorithm is: " + usedKeyAlgo)
}
gceAK.Close()

key = "gceAK"
attestAgent := agent.CreateAttestationAgent(rwc, attestationKeys[key][keyAlgo], verifierClient, principalFetcher, nil, spec.LaunchSpec{}, nil)

fmt.Fprintf(debugOutput(), "Fetching attestation verifier OIDC token\n")
token, err := attestAgent.Attest(ctx, agent.AttestAgentOpts{})
if err != nil {
return fmt.Errorf("failed to retrieve attestation service token: %v", err)
}

// Get token expiration.
claims := &jwt.RegisteredClaims{}
_, _, err = jwt.NewParser().ParseUnverified(string(token), claims)
if err != nil {
return fmt.Errorf("failed to parse token: %w", err)
}

now := time.Now()
if !now.Before(claims.ExpiresAt.Time) {
return errors.New("token is expired")
}

// Print out the claims in the jwt payload
mapClaims := jwt.MapClaims{}
_, _, err = jwt.NewParser().ParseUnverified(string(token), mapClaims)
if err != nil {
return fmt.Errorf("failed to parse token: %w", err)
}
claimsString, err := json.MarshalIndent(mapClaims, "", " ")
if err != nil {
return fmt.Errorf("failed to format claims: %w", err)
}

if output == "" {
fmt.Fprintf(messageOutput(), string(token)+"\n")
} else {
out := []byte(token)
if _, err := dataOutput().Write(out); err != nil {
return fmt.Errorf("failed to write the token: %v", err)
}
}

fmt.Fprintf(debugOutput(), string(claimsString)+"\n"+"Note: these Claims are for debugging purpose and not verified"+"\n")
return nil
},
}

// TODO: getRESTClient is copied from go-tpm-tools/launcher/container_runner.go, to be refactored.
// getRESTClient returns a REST verifier.Client that points to the given address.
// It defaults to the Attestation Verifier instance at
// https://confidentialcomputing.googleapis.com.
func getRESTClient(ctx context.Context, asAddr string, ProjectID string, Region string) (verifier.Client, error) {
httpClient, err := google.DefaultClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client: %v", err)
}

opts := []option.ClientOption{option.WithHTTPClient(httpClient)}
if asAddr != "" {
opts = append(opts, option.WithEndpoint(asAddr))
}

restClient, err := rest.NewClient(ctx, ProjectID, Region, opts...)
if err != nil {
return nil, err
}
return restClient, nil
}

func getRegion(client *metadata.Client) (string, error) {
zone, err := client.Zone()
if err != nil {
return "", fmt.Errorf("failed to retrieve zone from MDS: %v", err)
}
lastDash := strings.LastIndex(zone, "-")
if lastDash == -1 {
return "", fmt.Errorf("got malformed zone from MDS: %v", zone)
}
return zone[:lastDash], nil
}

func init() {
RootCmd.AddCommand(tokenCmd)
addOutputFlag(tokenCmd)
addPublicKeyAlgoFlag(tokenCmd)
addAsAddressFlag(tokenCmd)
// TODO: Add TEE hardware OIDC token generation
// addTeeNonceflag(tokenCmd)
// addTeeTechnology(tokenCmd)
}
Loading

0 comments on commit b22bad0

Please sign in to comment.