Skip to content

Commit

Permalink
add fake cloud logging server and unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Ruide committed Mar 4, 2024
1 parent 896bcf5 commit 0651c01
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 86 deletions.
139 changes: 139 additions & 0 deletions cmd/fake_cloudlogging_server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package cmd

import (
"context"
"fmt"
"log"
"net"
"regexp"
"strconv"
"strings"
"sync"
"time"

logpb "cloud.google.com/go/logging/apiv2/loggingpb"
tspb "github.com/golang/protobuf/ptypes/timestamp"
"google.golang.org/grpc"
)

// The only IDs that WriteLogEntries will accept.
const (
TestProjectID = "test-project"
)

// A fakeServer is an in-process gRPC server, listening on a system-chosen port on
// the local loopback interface. Servers are for testing only and are not
// intended to be used in production code.
type fakeServer struct {
Addr string
Port int
l net.Listener
Gsrv *grpc.Server
}

// Start causes the server to start accepting incoming connections.
// Call Start after registering handlers.
func (s *fakeServer) Start() {
go func() {
if err := s.Gsrv.Serve(s.l); err != nil {
log.Printf("fake_cloudlogging_server.fakeServer.Start: %v", err)
}
}()
}

// Close shuts down the server.
func (s *fakeServer) Close() {
s.Gsrv.Stop()
s.l.Close()
}

// newFakeServer creates a new Server. The Server will be listening for gRPC connections
// at the address named by the Addr field, without TLS.
func newFakeServer(opts ...grpc.ServerOption) (*fakeServer, error) {
return newFakeServerWithPort(0, opts...)
}

// newFakeServerWithPort creates a new Server at a specific port. The Server will be listening
// for gRPC connections at the address named by the Addr field, without TLS.
func newFakeServerWithPort(port int, opts ...grpc.ServerOption) (*fakeServer, error) {
l, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
if err != nil {
return nil, err
}
s := &fakeServer{
Addr: l.Addr().String(),
Port: parsePort(l.Addr().String()),
l: l,
Gsrv: grpc.NewServer(opts...),
}
return s, nil
}

var portParser = regexp.MustCompile(`:[0-9]+`)

func parsePort(addr string) int {
res := portParser.FindAllString(addr, -1)
if len(res) == 0 {
panic(fmt.Errorf("parsePort: found no numbers in %s", addr))
}
stringPort := res[0][1:] // strip the :
p, err := strconv.ParseInt(stringPort, 10, 32)
if err != nil {
panic(err)
}
return int(p)
}

type loggingHandler struct {
logpb.LoggingServiceV2Server

mu sync.Mutex
logs map[string][]*logpb.LogEntry // indexed by log name
}

// WriteLogEntries writes log entries to Cloud Logging. All log entries in
// Cloud Logging are written by this method.
func (h *loggingHandler) WriteLogEntries(_ context.Context, req *logpb.WriteLogEntriesRequest) (*logpb.WriteLogEntriesResponse, error) {
if !strings.HasPrefix(req.LogName, "projects/"+TestProjectID+"/") {
return nil, fmt.Errorf("bad LogName: %q", req.LogName)
}
h.mu.Lock()
defer h.mu.Unlock()
for _, e := range req.Entries {
// Assign timestamp if missing.
if e.Timestamp == nil {
e.Timestamp = &tspb.Timestamp{Seconds: time.Now().Unix(), Nanos: 0}
}
// Fill from common fields in request.
if e.LogName == "" {
e.LogName = req.LogName
}
if e.Resource == nil {
e.Resource = req.Resource
}
for k, v := range req.Labels {
if _, ok := e.Labels[k]; !ok {
e.Labels[k] = v
}
}

// Store by log name.
h.logs[e.LogName] = append(h.logs[e.LogName], e)
}
return &logpb.WriteLogEntriesResponse{}, nil
}

// newMockCloudLoggingServer creates a new in-memory fake server implementing the logging service.
// It returns the address of the server.
func newMockCloudLoggingServer() (string, error) {
srv, err := newFakeServer()
if err != nil {
return "", err
}
logpb.RegisterLoggingServiceV2Server(srv.Gsrv, &loggingHandler{
logs: make(map[string][]*logpb.LogEntry),
})

srv.Start()
return srv.Addr, nil
}
6 changes: 0 additions & 6 deletions cmd/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ var (
asAddress string
audience string
cloudLog bool
unitTest bool
)

type pcrsFlag struct {
Expand Down Expand Up @@ -138,11 +137,6 @@ func addCloudLoggingFlag(cmd *cobra.Command) {
cmd.Flags().BoolVar(&cloudLog, "cloud-log", false, "logs the attestation and token to Cloud Logging for auditing purposes. Requires the audience flag.")
}

// Lets this command enable unit test
func addUnitTestFlag(cmd *cobra.Command) {
cmd.Flags().BoolVar(&unitTest, "unit-test", false, "logs the attestation and token to local for unit test purposes.")
}

// Lets this command specify custom audience field of the attestation token
func addAudienceFlag(cmd *cobra.Command) {
cmd.PersistentFlags().StringVar(&audience, "audience", "",
Expand Down
30 changes: 21 additions & 9 deletions cmd/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import (
"encoding/json"
"errors"
"fmt"
"log"
"net/url"
"os"
"strings"
"time"

Expand All @@ -23,9 +23,12 @@ import (
"github.com/spf13/cobra"
"golang.org/x/oauth2/google"
"google.golang.org/api/option"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)

var mdsClient *metadata.Client
var mockCloudLoggingServerAddress string

const toolName = "gotpm"

Expand Down Expand Up @@ -108,16 +111,26 @@ The OIDC token includes claims regarding the GCE VM, which is verified by Attest
var cloudLogClient *logging.Client
var cloudLogger *logging.Logger
if cloudLog {
cloudLogClient, err = logging.NewClient(ctx, projectID)
if err != nil {
return fmt.Errorf("failed to create Cloud Logging client: %w", err)
if audience == "" {
return errors.New("cloud logging requires the --audience flag")
}

if unitTest {
cloudLogger = cloudLogClient.Logger(toolName, logging.RedirectAsJSON(os.Stdout))
if mockCloudLoggingServerAddress != "" {
conn, err := grpc.Dial(mockCloudLoggingServerAddress, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
log.Fatalf("dialing %q: %v", mockCloudLoggingServerAddress, err)
}
cloudLogClient, err = logging.NewClient(ctx, TestProjectID, option.WithGRPCConn(conn))
if err != nil {
return fmt.Errorf("failed to create cloud logging client for mock cloud logging server: %w", err)
}
} else {
cloudLogger = cloudLogClient.Logger(toolName)
cloudLogClient, err = logging.NewClient(ctx, projectID)
if err != nil {
return fmt.Errorf("failed to create cloud logging client: %w", err)
}
}

cloudLogger = cloudLogClient.Logger(toolName)
fmt.Fprintf(debugOutput(), "cloudLogger created for project: "+projectID+"\n")
}

Expand Down Expand Up @@ -218,7 +231,6 @@ func init() {
addAsAddressFlag(tokenCmd)
addCloudLoggingFlag(tokenCmd)
addAudienceFlag(tokenCmd)
addUnitTestFlag(tokenCmd)
// TODO: Add TEE hardware OIDC token generation
// addTeeNonceflag(tokenCmd)
// addTeeTechnology(tokenCmd)
Expand Down
75 changes: 4 additions & 71 deletions cmd/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,14 @@ import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"io"
"math/big"
"os"
"strings"
"sync"
"testing"
"time"

"github.com/golang-jwt/jwt/v4"
"github.com/google/go-tpm-tools/client"
"github.com/google/go-tpm-tools/internal/test"
"github.com/google/go-tpm-tools/launcher/verifier"
pb "github.com/google/go-tpm-tools/proto/attest"
"github.com/google/go-tpm/legacy/tpm2"
"github.com/google/go-tpm/tpmutil"
"golang.org/x/oauth2"
Expand All @@ -31,8 +25,6 @@ func TestTokenWithGCEAK(t *testing.T) {
ExternalTPM = rwc
secretFile1 := makeOutputFile(t, "token")
defer os.RemoveAll(secretFile1)
// match the semantics of instrumenting cloud logging logs one time
var instrumentOnce = new(sync.Once)
var template = map[string]tpm2.Public{
"rsa": GCEAKTemplateRSA(),
"ecc": GCEAKTemplateECC(),
Expand Down Expand Up @@ -80,78 +72,19 @@ func TestTokenWithGCEAK(t *testing.T) {
}
defer mockAttestationServer.Stop()

//redirect cloud log from http request to stdout
old := os.Stdout
r, w, _ := os.Pipe()
os.Stdout = w

RootCmd.SetArgs([]string{"token", "--algo", op.algo, "--output", secretFile1, "--verifier-endpoint", mockAttestationServer.server.URL, "--cloud-log", "--audience", "https://api.test.com", "--unit-test"})
if err := RootCmd.Execute(); err != nil {
t.Error(err)
}

w.Close()
out, _ := io.ReadAll(r)
os.Stdout = old
arrays := strings.Split(string(out), "\n")
instrumentOnce.Do(func() {
// remove cloud log one-time instrumentation
arrays = append(arrays[:1], arrays[2:]...)
})

// parse json redirected
var gotChallenge challengeLogEntry
var gotAttestationRequest attestationRequestLogEntry
var gotToken tokenLogEntry
var gotClaims claimsLogEntry
err = json.Unmarshal([]byte(arrays[0]), &gotChallenge)
if err != nil {
t.Error(err)
}
err = json.Unmarshal([]byte(arrays[1]), &gotAttestationRequest)
if err != nil {
t.Error(err)
}
err = json.Unmarshal([]byte(arrays[2]), &gotToken)
mockCloudLoggingServerAddress, err = newMockCloudLoggingServer()
if err != nil {
t.Error(err)
}
err = json.Unmarshal([]byte(arrays[3]), &gotClaims)
if err != nil {

RootCmd.SetArgs([]string{"token", "--algo", op.algo, "--output", secretFile1, "--verifier-endpoint", mockAttestationServer.server.URL, "--cloud-log", "--audience", "https://api.test.com"})
if err := RootCmd.Execute(); err != nil {
t.Error(err)
}
})
}
}

type challengeLogEntry struct {
Message verifier.Challenge `json:"message"`
Severity string `json:"severity"`
Timestamp string `json:"timestamp"`
}

type attestationRequestLogEntry struct {
Message *pb.Attestation `json:"message"`
Severity string `json:"severity"`
Timestamp string `json:"timestamp"`
}

type Message struct {
Token string `json:"token"`
}

type tokenLogEntry struct {
Message Message `json:"message"`
Severity string `json:"severity"`
Timestamp string `json:"timestamp"`
}

type claimsLogEntry struct {
Message jwt.RegisteredClaims `json:"message"`
Severity string `json:"severity"`
Timestamp string `json:"timestamp"`
}

// Need to call tpm2.NVUndefinespace twice on the handle with authHandle tpm2.HandlePlatform.
// e.g defer tpm2.NVUndefineSpace(rwc, "", tpm2.HandlePlatform, tpmutil.Handle(client.GceAKTemplateNVIndexRSA))
// defer tpm2.NVUndefineSpace(rwc, "", tpm2.HandlePlatform, tpmutil.Handle(client.GceAKCertNVIndexRSA))
Expand Down

0 comments on commit 0651c01

Please sign in to comment.