diff --git a/cmd/fake_cloudlogging_server.go b/cmd/fake_cloudlogging_server.go new file mode 100644 index 000000000..2f58a96c2 --- /dev/null +++ b/cmd/fake_cloudlogging_server.go @@ -0,0 +1,139 @@ +package cmd + +import ( + "context" + "fmt" + "log" + "net" + "regexp" + "strconv" + "strings" + "sync" + "time" + + logpb "cloud.google.com/go/logging/apiv2/loggingpb" + tspb "github.com/golang/protobuf/ptypes/timestamp" + "google.golang.org/grpc" +) + +// The only IDs that WriteLogEntries will accept. +const ( + TestProjectID = "test-project" +) + +// A fakeServer is an in-process gRPC server, listening on a system-chosen port on +// the local loopback interface. Servers are for testing only and are not +// intended to be used in production code. +type fakeServer struct { + Addr string + Port int + l net.Listener + Gsrv *grpc.Server +} + +// Start causes the server to start accepting incoming connections. +// Call Start after registering handlers. +func (s *fakeServer) Start() { + go func() { + if err := s.Gsrv.Serve(s.l); err != nil { + log.Printf("fake_cloudlogging_server.fakeServer.Start: %v", err) + } + }() +} + +// Close shuts down the server. +func (s *fakeServer) Close() { + s.Gsrv.Stop() + s.l.Close() +} + +// newFakeServer creates a new Server. The Server will be listening for gRPC connections +// at the address named by the Addr field, without TLS. +func newFakeServer(opts ...grpc.ServerOption) (*fakeServer, error) { + return newFakeServerWithPort(0, opts...) +} + +// newFakeServerWithPort creates a new Server at a specific port. The Server will be listening +// for gRPC connections at the address named by the Addr field, without TLS. +func newFakeServerWithPort(port int, opts ...grpc.ServerOption) (*fakeServer, error) { + l, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + return nil, err + } + s := &fakeServer{ + Addr: l.Addr().String(), + Port: parsePort(l.Addr().String()), + l: l, + Gsrv: grpc.NewServer(opts...), + } + return s, nil +} + +var portParser = regexp.MustCompile(`:[0-9]+`) + +func parsePort(addr string) int { + res := portParser.FindAllString(addr, -1) + if len(res) == 0 { + panic(fmt.Errorf("parsePort: found no numbers in %s", addr)) + } + stringPort := res[0][1:] // strip the : + p, err := strconv.ParseInt(stringPort, 10, 32) + if err != nil { + panic(err) + } + return int(p) +} + +type loggingHandler struct { + logpb.LoggingServiceV2Server + + mu sync.Mutex + logs map[string][]*logpb.LogEntry // indexed by log name +} + +// WriteLogEntries writes log entries to Cloud Logging. All log entries in +// Cloud Logging are written by this method. +func (h *loggingHandler) WriteLogEntries(_ context.Context, req *logpb.WriteLogEntriesRequest) (*logpb.WriteLogEntriesResponse, error) { + if !strings.HasPrefix(req.LogName, "projects/"+TestProjectID+"/") { + return nil, fmt.Errorf("bad LogName: %q", req.LogName) + } + h.mu.Lock() + defer h.mu.Unlock() + for _, e := range req.Entries { + // Assign timestamp if missing. + if e.Timestamp == nil { + e.Timestamp = &tspb.Timestamp{Seconds: time.Now().Unix(), Nanos: 0} + } + // Fill from common fields in request. + if e.LogName == "" { + e.LogName = req.LogName + } + if e.Resource == nil { + e.Resource = req.Resource + } + for k, v := range req.Labels { + if _, ok := e.Labels[k]; !ok { + e.Labels[k] = v + } + } + + // Store by log name. + h.logs[e.LogName] = append(h.logs[e.LogName], e) + } + return &logpb.WriteLogEntriesResponse{}, nil +} + +// newMockCloudLoggingServer creates a new in-memory fake server implementing the logging service. +// It returns the address of the server. +func newMockCloudLoggingServer() (string, error) { + srv, err := newFakeServer() + if err != nil { + return "", err + } + logpb.RegisterLoggingServiceV2Server(srv.Gsrv, &loggingHandler{ + logs: make(map[string][]*logpb.LogEntry), + }) + + srv.Start() + return srv.Addr, nil +} diff --git a/cmd/flags.go b/cmd/flags.go index dc966510a..54f12dd4e 100644 --- a/cmd/flags.go +++ b/cmd/flags.go @@ -25,7 +25,6 @@ var ( asAddress string audience string cloudLog bool - unitTest bool ) type pcrsFlag struct { @@ -138,11 +137,6 @@ func addCloudLoggingFlag(cmd *cobra.Command) { cmd.Flags().BoolVar(&cloudLog, "cloud-log", false, "logs the attestation and token to Cloud Logging for auditing purposes. Requires the audience flag.") } -// Lets this command enable unit test -func addUnitTestFlag(cmd *cobra.Command) { - cmd.Flags().BoolVar(&unitTest, "unit-test", false, "logs the attestation and token to local for unit test purposes.") -} - // Lets this command specify custom audience field of the attestation token func addAudienceFlag(cmd *cobra.Command) { cmd.PersistentFlags().StringVar(&audience, "audience", "", diff --git a/cmd/token.go b/cmd/token.go index 15aa187fd..95e88bfb8 100644 --- a/cmd/token.go +++ b/cmd/token.go @@ -5,8 +5,8 @@ import ( "encoding/json" "errors" "fmt" + "log" "net/url" - "os" "strings" "time" @@ -23,9 +23,12 @@ import ( "github.com/spf13/cobra" "golang.org/x/oauth2/google" "google.golang.org/api/option" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" ) var mdsClient *metadata.Client +var mockCloudLoggingServerAddress string const toolName = "gotpm" @@ -108,16 +111,26 @@ The OIDC token includes claims regarding the GCE VM, which is verified by Attest var cloudLogClient *logging.Client var cloudLogger *logging.Logger if cloudLog { - cloudLogClient, err = logging.NewClient(ctx, projectID) - if err != nil { - return fmt.Errorf("failed to create Cloud Logging client: %w", err) + if audience == "" { + return errors.New("cloud logging requires the --audience flag") } - - if unitTest { - cloudLogger = cloudLogClient.Logger(toolName, logging.RedirectAsJSON(os.Stdout)) + if mockCloudLoggingServerAddress != "" { + conn, err := grpc.Dial(mockCloudLoggingServerAddress, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + log.Fatalf("dialing %q: %v", mockCloudLoggingServerAddress, err) + } + cloudLogClient, err = logging.NewClient(ctx, TestProjectID, option.WithGRPCConn(conn)) + if err != nil { + return fmt.Errorf("failed to create cloud logging client for mock cloud logging server: %w", err) + } } else { - cloudLogger = cloudLogClient.Logger(toolName) + cloudLogClient, err = logging.NewClient(ctx, projectID) + if err != nil { + return fmt.Errorf("failed to create cloud logging client: %w", err) + } } + + cloudLogger = cloudLogClient.Logger(toolName) fmt.Fprintf(debugOutput(), "cloudLogger created for project: "+projectID+"\n") } @@ -218,7 +231,6 @@ func init() { addAsAddressFlag(tokenCmd) addCloudLoggingFlag(tokenCmd) addAudienceFlag(tokenCmd) - addUnitTestFlag(tokenCmd) // TODO: Add TEE hardware OIDC token generation // addTeeNonceflag(tokenCmd) // addTeeTechnology(tokenCmd) diff --git a/cmd/token_test.go b/cmd/token_test.go index 8f471bfff..31e171f94 100644 --- a/cmd/token_test.go +++ b/cmd/token_test.go @@ -5,20 +5,14 @@ import ( "crypto/rand" "crypto/rsa" "crypto/x509" - "encoding/json" "io" "math/big" "os" - "strings" - "sync" "testing" "time" - "github.com/golang-jwt/jwt/v4" "github.com/google/go-tpm-tools/client" "github.com/google/go-tpm-tools/internal/test" - "github.com/google/go-tpm-tools/launcher/verifier" - pb "github.com/google/go-tpm-tools/proto/attest" "github.com/google/go-tpm/legacy/tpm2" "github.com/google/go-tpm/tpmutil" "golang.org/x/oauth2" @@ -31,8 +25,6 @@ func TestTokenWithGCEAK(t *testing.T) { ExternalTPM = rwc secretFile1 := makeOutputFile(t, "token") defer os.RemoveAll(secretFile1) - // match the semantics of instrumenting cloud logging logs one time - var instrumentOnce = new(sync.Once) var template = map[string]tpm2.Public{ "rsa": GCEAKTemplateRSA(), "ecc": GCEAKTemplateECC(), @@ -80,78 +72,19 @@ func TestTokenWithGCEAK(t *testing.T) { } defer mockAttestationServer.Stop() - //redirect cloud log from http request to stdout - old := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w - - RootCmd.SetArgs([]string{"token", "--algo", op.algo, "--output", secretFile1, "--verifier-endpoint", mockAttestationServer.server.URL, "--cloud-log", "--audience", "https://api.test.com", "--unit-test"}) - if err := RootCmd.Execute(); err != nil { - t.Error(err) - } - - w.Close() - out, _ := io.ReadAll(r) - os.Stdout = old - arrays := strings.Split(string(out), "\n") - instrumentOnce.Do(func() { - // remove cloud log one-time instrumentation - arrays = append(arrays[:1], arrays[2:]...) - }) - - // parse json redirected - var gotChallenge challengeLogEntry - var gotAttestationRequest attestationRequestLogEntry - var gotToken tokenLogEntry - var gotClaims claimsLogEntry - err = json.Unmarshal([]byte(arrays[0]), &gotChallenge) - if err != nil { - t.Error(err) - } - err = json.Unmarshal([]byte(arrays[1]), &gotAttestationRequest) - if err != nil { - t.Error(err) - } - err = json.Unmarshal([]byte(arrays[2]), &gotToken) + mockCloudLoggingServerAddress, err = newMockCloudLoggingServer() if err != nil { t.Error(err) } - err = json.Unmarshal([]byte(arrays[3]), &gotClaims) - if err != nil { + + RootCmd.SetArgs([]string{"token", "--algo", op.algo, "--output", secretFile1, "--verifier-endpoint", mockAttestationServer.server.URL, "--cloud-log", "--audience", "https://api.test.com"}) + if err := RootCmd.Execute(); err != nil { t.Error(err) } }) } } -type challengeLogEntry struct { - Message verifier.Challenge `json:"message"` - Severity string `json:"severity"` - Timestamp string `json:"timestamp"` -} - -type attestationRequestLogEntry struct { - Message *pb.Attestation `json:"message"` - Severity string `json:"severity"` - Timestamp string `json:"timestamp"` -} - -type Message struct { - Token string `json:"token"` -} - -type tokenLogEntry struct { - Message Message `json:"message"` - Severity string `json:"severity"` - Timestamp string `json:"timestamp"` -} - -type claimsLogEntry struct { - Message jwt.RegisteredClaims `json:"message"` - Severity string `json:"severity"` - Timestamp string `json:"timestamp"` -} - // Need to call tpm2.NVUndefinespace twice on the handle with authHandle tpm2.HandlePlatform. // e.g defer tpm2.NVUndefineSpace(rwc, "", tpm2.HandlePlatform, tpmutil.Handle(client.GceAKTemplateNVIndexRSA)) // defer tpm2.NVUndefineSpace(rwc, "", tpm2.HandlePlatform, tpmutil.Handle(client.GceAKCertNVIndexRSA))