diff --git a/lib/tbot/service_spiffe_workload_api.go b/lib/tbot/service_spiffe_workload_api.go index 748e2e3cbd13f..b8a5675673657 100644 --- a/lib/tbot/service_spiffe_workload_api.go +++ b/lib/tbot/service_spiffe_workload_api.go @@ -52,6 +52,7 @@ import ( "github.com/gravitational/teleport" machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/observability/metrics" @@ -227,13 +228,27 @@ func (s *SPIFFEWorkloadAPIService) Run(ctx context.Context) error { ) workloadpb.RegisterSpiffeWorkloadAPIServer(srv, s) sdsHandler := &spiffeSDSHandler{ - log: s.log, - cfg: s.cfg, - botCfg: s.botCfg, - - trustBundleCache: s.trustBundleCache, - clientAuthenticator: s.authenticateClient, - svidFetcher: s.fetchX509SVIDs, + log: s.log, + botCfg: s.botCfg, + trustBundleCache: s.trustBundleCache, + clientAuthenticator: func(ctx context.Context) (*slog.Logger, svidFetcher, error) { + log, attrs, err := s.authenticateClient(ctx) + if err != nil { + return log, nil, trace.Wrap(err, "authenticating client") + } + fetchSVIDs := func( + ctx context.Context, + localBundle *spiffebundle.Bundle, + ) ([]*workloadpb.X509SVID, error) { + return s.fetchX509SVIDs( + ctx, + log, + localBundle, + filterSVIDRequests(ctx, log, s.cfg.SVIDs, attrs), + ) + } + return log, fetchSVIDs, nil + }, } secretv3pb.RegisterSecretDiscoveryServiceServer(srv, sdsHandler) @@ -373,7 +388,7 @@ func filterSVIDRequests( ctx context.Context, log *slog.Logger, svidRequests []config.SVIDRequestWithRules, - att workloadattest.Attestation, + att *workloadidentityv1pb.WorkloadAttrs, ) []config.SVIDRequest { var filtered []config.SVIDRequest for _, req := range svidRequests { @@ -413,67 +428,67 @@ func filterSVIDRequests( "Evaluating rule against workload attestation", ) if rule.Unix.UID != nil { - if !att.Unix.Attested { + if !att.GetUnix().GetAttested() { logNotAttested("unix") continue } - if *rule.Unix.UID != att.Unix.UID { - logMismatch("unix.uid", *rule.Unix.UID, att.Unix.UID) + if *rule.Unix.UID != int(att.GetUnix().GetUid()) { + logMismatch("unix.uid", *rule.Unix.UID, att.GetUnix().GetUid()) continue } // Rule field matched! } if rule.Unix.PID != nil { - if !att.Unix.Attested { + if !att.GetUnix().GetAttested() { logNotAttested("unix") continue } - if *rule.Unix.PID != att.Unix.PID { - logMismatch("unix.pid", *rule.Unix.PID, att.Unix.PID) + if *rule.Unix.PID != int(att.GetUnix().GetPid()) { + logMismatch("unix.pid", *rule.Unix.PID, att.GetUnix().GetPid()) continue } // Rule field matched! } if rule.Unix.GID != nil { - if !att.Unix.Attested { + if !att.GetUnix().GetAttested() { logNotAttested("unix") continue } - if *rule.Unix.GID != att.Unix.GID { - logMismatch("unix.gid", *rule.Unix.GID, att.Unix.GID) + if *rule.Unix.GID != int(att.GetUnix().GetGid()) { + logMismatch("unix.gid", *rule.Unix.GID, att.GetUnix().GetGid()) continue } // Rule field matched! } if rule.Kubernetes.Namespace != "" { - if !att.Kubernetes.Attested { + if !att.GetKubernetes().GetAttested() { logNotAttested("kubernetes") continue } - if rule.Kubernetes.Namespace != att.Kubernetes.Namespace { - logMismatch("kubernetes.namespace", rule.Kubernetes.Namespace, att.Kubernetes.Namespace) + if rule.Kubernetes.Namespace != att.GetKubernetes().GetNamespace() { + logMismatch("kubernetes.namespace", rule.Kubernetes.Namespace, att.GetKubernetes().GetNamespace()) continue } // Rule field matched! } if rule.Kubernetes.PodName != "" { - if !att.Kubernetes.Attested { + if !att.GetKubernetes().GetAttested() { logNotAttested("kubernetes") continue } - if rule.Kubernetes.PodName != att.Kubernetes.PodName { - logMismatch("kubernetes.pod_name", rule.Kubernetes.PodName, att.Kubernetes.PodName) + if rule.Kubernetes.PodName != att.GetKubernetes().GetPodName() { + logMismatch("kubernetes.pod_name", rule.Kubernetes.PodName, att.GetKubernetes().GetPodName()) continue } // Rule field matched! } if rule.Kubernetes.ServiceAccount != "" { - if !att.Kubernetes.Attested { + if !att.GetKubernetes().GetAttested() { logNotAttested("kubernetes") continue } - if rule.Kubernetes.ServiceAccount != att.Kubernetes.ServiceAccount { - logMismatch("kubernetes.service_account", rule.Kubernetes.ServiceAccount, att.Kubernetes.ServiceAccount) + if rule.Kubernetes.ServiceAccount != att.GetKubernetes().GetServiceAccount() { + logMismatch("kubernetes.service_account", rule.Kubernetes.ServiceAccount, att.GetKubernetes().GetServiceAccount()) continue } // Rule field matched! @@ -499,10 +514,10 @@ func filterSVIDRequests( func (s *SPIFFEWorkloadAPIService) authenticateClient( ctx context.Context, -) (*slog.Logger, workloadattest.Attestation, error) { +) (*slog.Logger, *workloadidentityv1pb.WorkloadAttrs, error) { p, ok := peer.FromContext(ctx) if !ok { - return nil, workloadattest.Attestation{}, trace.BadParameter("peer not found in context") + return nil, nil, trace.BadParameter("peer not found in context") } log := s.log @@ -516,7 +531,7 @@ func (s *SPIFFEWorkloadAPIService) authenticateClient( // We expect Creds to be nil/unset if the client is connecting via TCP and // therefore there is no workload attestation that can be completed. if !ok || authInfo.Creds == nil { - return log, workloadattest.Attestation{}, nil + return log, nil, nil } // For a UDS, sometimes we are unable to determine the PID of the calling @@ -528,7 +543,7 @@ func (s *SPIFFEWorkloadAPIService) authenticateClient( if authInfo.Creds.PID == 0 { log.DebugContext( ctx, "Failed to determine the PID of the calling workload. TBot may be running in a different process namespace to the workload. Workload attestation will not be completed.") - return log, workloadattest.Attestation{}, nil + return log, nil, nil } att, err := s.attestor.Attest(ctx, authInfo.Creds.PID) @@ -541,10 +556,10 @@ func (s *SPIFFEWorkloadAPIService) authenticateClient( "error", err, "pid", authInfo.Creds.PID, ) - return log, workloadattest.Attestation{}, nil + return log, nil, nil } log = log.With( - "workload", slog.LogValuer(att), + "workload", att, ) return log, att, nil diff --git a/lib/tbot/service_spiffe_workload_api_sds.go b/lib/tbot/service_spiffe_workload_api_sds.go index a74379e52383c..23bd84ad512d5 100644 --- a/lib/tbot/service_spiffe_workload_api_sds.go +++ b/lib/tbot/service_spiffe_workload_api_sds.go @@ -40,7 +40,6 @@ import ( "github.com/gravitational/teleport/lib/tbot/config" "github.com/gravitational/teleport/lib/tbot/workloadidentity" - "github.com/gravitational/teleport/lib/tbot/workloadidentity/workloadattest" "github.com/gravitational/teleport/lib/utils" ) @@ -63,23 +62,18 @@ type bundleSetGetter interface { GetBundleSet(ctx context.Context) (*workloadidentity.BundleSet, error) } +type svidFetcher func(ctx context.Context, localBundle *spiffebundle.Bundle) ([]*workloadpb.X509SVID, error) + // spiffeSDSHandler implements an Envoy SDS API. // // This effectively replaces the Workload API for Envoy, but functions in a // very similar way. type spiffeSDSHandler struct { log *slog.Logger - cfg *config.SPIFFEWorkloadAPIService botCfg *config.BotConfig trustBundleCache bundleSetGetter - clientAuthenticator func(ctx context.Context) (*slog.Logger, workloadattest.Attestation, error) - svidFetcher func( - ctx context.Context, - log *slog.Logger, - localBundle *spiffebundle.Bundle, - svidRequests []config.SVIDRequest, - ) ([]*workloadpb.X509SVID, error) + clientAuthenticator func(ctx context.Context) (*slog.Logger, svidFetcher, error) } // FetchSecrets implements @@ -97,7 +91,7 @@ func (s *spiffeSDSHandler) FetchSecrets( return nil, trace.Wrap(err) } - log, creds, err := s.clientAuthenticator(ctx) + log, fetchSVIDs, err := s.clientAuthenticator(ctx) if err != nil { return nil, trace.Wrap(err, "authenticating client") } @@ -114,11 +108,7 @@ func (s *spiffeSDSHandler) FetchSecrets( return nil, trace.Wrap(err, "getting trust bundle set") } - // Filter SVIDs down to those accessible to this workload - svids, err := s.svidFetcher( - ctx, - log, - bundleSet.Local, filterSVIDRequests(ctx, log, s.cfg.SVIDs, creds)) + svids, err := fetchSVIDs(ctx, bundleSet.Local) if err != nil { return nil, trace.Wrap(err, "fetching X509 SVIDs") } @@ -174,7 +164,7 @@ func (s *spiffeSDSHandler) StreamSecrets( srv secretv3pb.SecretDiscoveryService_StreamSecretsServer, ) error { ctx := srv.Context() - log, creds, err := s.clientAuthenticator(ctx) + log, fetchSVIDs, err := s.clientAuthenticator(ctx) if err != nil { return trace.Wrap(err, "authenticating client") } @@ -216,9 +206,6 @@ func (s *spiffeSDSHandler) StreamSecrets( renewalTimer.Stop() defer renewalTimer.Stop() - // Filter SVIDs down to those accessible to this workload - availableSVIDs := filterSVIDRequests(ctx, log, s.cfg.SVIDs, creds) - // Track the last response and last request to allow us to handle ACK/NACK // and versioning. var ( @@ -311,7 +298,7 @@ func (s *spiffeSDSHandler) StreamSecrets( // Fetch the SVIDs if necessary if svids == nil { - svids, err = s.svidFetcher(ctx, log, bundleSet.Local, availableSVIDs) + svids, err = fetchSVIDs(ctx, bundleSet.Local) if err != nil { return trace.Wrap(err, "fetching X509 SVIDs") } diff --git a/lib/tbot/service_spiffe_workload_api_sds_test.go b/lib/tbot/service_spiffe_workload_api_sds_test.go index 2428c89b8a3d4..da52f683140b9 100644 --- a/lib/tbot/service_spiffe_workload_api_sds_test.go +++ b/lib/tbot/service_spiffe_workload_api_sds_test.go @@ -35,7 +35,6 @@ import ( discoveryv3pb "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" secretv3pb "github.com/envoyproxy/go-control-plane/envoy/service/secret/v3" "github.com/google/go-cmp/cmp" - "github.com/gravitational/trace" "github.com/spiffe/go-spiffe/v2/bundle/spiffebundle" workloadpb "github.com/spiffe/go-spiffe/v2/proto/spiffe/workload" "github.com/spiffe/go-spiffe/v2/spiffeid" @@ -51,7 +50,6 @@ import ( "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/tbot/config" "github.com/gravitational/teleport/lib/tbot/workloadidentity" - "github.com/gravitational/teleport/lib/tbot/workloadidentity/workloadattest" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/testutils/golden" "github.com/gravitational/teleport/tool/teleport/testenv" @@ -80,14 +78,22 @@ func TestSDS_FetchSecrets(t *testing.T) { ca, err := x509.ParseCertificate(b.Bytes) require.NoError(t, err) - uid := 100 - notUID := 200 - clientAuthenticator := func(ctx context.Context) (*slog.Logger, workloadattest.Attestation, error) { - return log, workloadattest.Attestation{ - Unix: workloadattest.UnixAttestation{ - Attested: true, - UID: uid, - }, + clientAuthenticator := func(ctx context.Context) (*slog.Logger, svidFetcher, error) { + return log, func(ctx context.Context, localBundle *spiffebundle.Bundle) ([]*workloadpb.X509SVID, error) { + return []*workloadpb.X509SVID{ + { + SpiffeId: "spiffe://example.com/default", + X509Svid: []byte("CERT-spiffe://example.com/default"), + X509SvidKey: []byte("KEY-spiffe://example.com/default"), + Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()), + }, + { + SpiffeId: "spiffe://example.com/second", + X509Svid: []byte("CERT-spiffe://example.com/second"), + X509SvidKey: []byte("KEY-spiffe://example.com/second"), + Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()), + }, + }, nil }, nil } @@ -105,72 +111,9 @@ func TestSDS_FetchSecrets(t *testing.T) { }, }, } - svidFetcher := func( - ctx context.Context, - log *slog.Logger, - localBundle *spiffebundle.Bundle, - svidRequests []config.SVIDRequest) ([]*workloadpb.X509SVID, error) { - if len(svidRequests) != 2 { - return nil, trace.BadParameter("expected 2 svids requested") - } - return []*workloadpb.X509SVID{ - { - SpiffeId: "spiffe://example.com/default", - X509Svid: []byte("CERT-spiffe://example.com/default"), - X509SvidKey: []byte("KEY-spiffe://example.com/default"), - Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()), - }, - { - SpiffeId: "spiffe://example.com/second", - X509Svid: []byte("CERT-spiffe://example.com/second"), - X509SvidKey: []byte("KEY-spiffe://example.com/second"), - Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()), - }, - }, nil - } botConfig := &config.BotConfig{ RenewalInterval: time.Minute, } - cfg := &config.SPIFFEWorkloadAPIService{ - SVIDs: []config.SVIDRequestWithRules{ - { - SVIDRequest: config.SVIDRequest{ - Path: "/default", - }, - Rules: []config.SVIDRequestRule{ - { - Unix: config.SVIDRequestRuleUnix{ - UID: &uid, - }, - }, - }, - }, - { - SVIDRequest: config.SVIDRequest{ - Path: "/second", - }, - Rules: []config.SVIDRequestRule{ - { - Unix: config.SVIDRequestRuleUnix{ - UID: &uid, - }, - }, - }, - }, - { - SVIDRequest: config.SVIDRequest{ - Path: "/not-matching", - }, - Rules: []config.SVIDRequestRule{ - { - Unix: config.SVIDRequestRuleUnix{ - UID: ¬UID, - }, - }, - }, - }, - }, - } tests := []struct { name string @@ -231,12 +174,10 @@ func TestSDS_FetchSecrets(t *testing.T) { t.Run(tt.name, func(t *testing.T) { sds := &spiffeSDSHandler{ log: log, - cfg: cfg, botCfg: botConfig, trustBundleCache: mockBundleCache, clientAuthenticator: clientAuthenticator, - svidFetcher: svidFetcher, } req := &discoveryv3pb.DiscoveryRequest{ diff --git a/lib/tbot/service_spiffe_workload_api_test.go b/lib/tbot/service_spiffe_workload_api_test.go index 3c4c10927b994..1a2b4227c9572 100644 --- a/lib/tbot/service_spiffe_workload_api_test.go +++ b/lib/tbot/service_spiffe_workload_api_test.go @@ -34,9 +34,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/tbot/config" - "github.com/gravitational/teleport/lib/tbot/workloadidentity/workloadattest" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/tool/teleport/testenv" ) @@ -52,7 +52,7 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests(t *testing.T) { log := utils.NewSlogLoggerForTests() tests := []struct { name string - att workloadattest.Attestation + att *workloadidentityv1pb.WorkloadAttrs in []config.SVIDRequestWithRules want []config.SVIDRequest }{ @@ -81,12 +81,12 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests(t *testing.T) { }, { name: "no rules with attestation", - att: workloadattest.Attestation{ - Unix: workloadattest.UnixAttestation{ + att: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ Attested: true, - UID: 1000, - GID: 1001, - PID: 1002, + Uid: 1000, + Gid: 1001, + Pid: 1002, }, }, in: []config.SVIDRequestWithRules{ @@ -112,15 +112,15 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests(t *testing.T) { }, { name: "no rules with attestation", - att: workloadattest.Attestation{ - Unix: workloadattest.UnixAttestation{ + att: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ // We don't expect that workloadattest will ever return // Attested: false and include UID/PID/GID but we want to // ensure we handle this by failing regardless. Attested: false, - UID: 1000, - GID: 1001, - PID: 1002, + Uid: 1000, + Gid: 1001, + Pid: 1002, }, }, in: []config.SVIDRequestWithRules{ @@ -141,12 +141,12 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests(t *testing.T) { }, { name: "no matching rules with attestation", - att: workloadattest.Attestation{ - Unix: workloadattest.UnixAttestation{ + att: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ Attested: true, - UID: 1000, - GID: 1001, - PID: 1002, + Uid: 1000, + Gid: 1001, + Pid: 1002, }, }, in: []config.SVIDRequestWithRules{ @@ -220,12 +220,12 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests(t *testing.T) { }, { name: "some matching rules with uds", - att: workloadattest.Attestation{ - Unix: workloadattest.UnixAttestation{ + att: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ Attested: true, - UID: 1000, - GID: 1001, - PID: 1002, + Uid: 1000, + Gid: 1001, + Pid: 1002, }, }, in: []config.SVIDRequestWithRules{ @@ -290,8 +290,8 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests_field(t *testing.T) { log := utils.NewSlogLoggerForTests() tests := []struct { field string - matching workloadattest.Attestation - nonMatching workloadattest.Attestation + matching *workloadidentityv1pb.WorkloadAttrs + nonMatching *workloadidentityv1pb.WorkloadAttrs rule config.SVIDRequestRule }{ { @@ -301,16 +301,16 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests_field(t *testing.T) { PID: ptr(1000), }, }, - matching: workloadattest.Attestation{ - Unix: workloadattest.UnixAttestation{ + matching: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ Attested: true, - PID: 1000, + Pid: 1000, }, }, - nonMatching: workloadattest.Attestation{ - Unix: workloadattest.UnixAttestation{ + nonMatching: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ Attested: true, - PID: 200, + Pid: 200, }, }, }, @@ -321,16 +321,16 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests_field(t *testing.T) { UID: ptr(1000), }, }, - matching: workloadattest.Attestation{ - Unix: workloadattest.UnixAttestation{ + matching: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ Attested: true, - UID: 1000, + Uid: 1000, }, }, - nonMatching: workloadattest.Attestation{ - Unix: workloadattest.UnixAttestation{ + nonMatching: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ Attested: true, - UID: 200, + Uid: 200, }, }, }, @@ -341,16 +341,16 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests_field(t *testing.T) { GID: ptr(1000), }, }, - matching: workloadattest.Attestation{ - Unix: workloadattest.UnixAttestation{ + matching: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ Attested: true, - GID: 1000, + Gid: 1000, }, }, - nonMatching: workloadattest.Attestation{ - Unix: workloadattest.UnixAttestation{ + nonMatching: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ Attested: true, - GID: 200, + Gid: 200, }, }, }, @@ -361,14 +361,14 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests_field(t *testing.T) { Namespace: "foo", }, }, - matching: workloadattest.Attestation{ - Kubernetes: workloadattest.KubernetesAttestation{ + matching: &workloadidentityv1pb.WorkloadAttrs{ + Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{ Attested: true, Namespace: "foo", }, }, - nonMatching: workloadattest.Attestation{ - Kubernetes: workloadattest.KubernetesAttestation{ + nonMatching: &workloadidentityv1pb.WorkloadAttrs{ + Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{ Attested: true, Namespace: "bar", }, @@ -381,14 +381,14 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests_field(t *testing.T) { ServiceAccount: "foo", }, }, - matching: workloadattest.Attestation{ - Kubernetes: workloadattest.KubernetesAttestation{ + matching: &workloadidentityv1pb.WorkloadAttrs{ + Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{ Attested: true, ServiceAccount: "foo", }, }, - nonMatching: workloadattest.Attestation{ - Kubernetes: workloadattest.KubernetesAttestation{ + nonMatching: &workloadidentityv1pb.WorkloadAttrs{ + Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{ Attested: true, ServiceAccount: "bar", }, @@ -401,14 +401,14 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests_field(t *testing.T) { PodName: "foo", }, }, - matching: workloadattest.Attestation{ - Kubernetes: workloadattest.KubernetesAttestation{ + matching: &workloadidentityv1pb.WorkloadAttrs{ + Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{ Attested: true, PodName: "foo", }, }, - nonMatching: workloadattest.Attestation{ - Kubernetes: workloadattest.KubernetesAttestation{ + nonMatching: &workloadidentityv1pb.WorkloadAttrs{ + Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{ Attested: true, PodName: "bar", }, diff --git a/lib/tbot/workloadidentity/workloadattest/attest.go b/lib/tbot/workloadidentity/workloadattest/attest.go index a50721b847ca1..de587dbad1c3c 100644 --- a/lib/tbot/workloadidentity/workloadattest/attest.go +++ b/lib/tbot/workloadidentity/workloadattest/attest.go @@ -23,32 +23,9 @@ import ( "log/slog" "github.com/gravitational/trace" -) - -// Attestation holds the results of the attestation process carried out on a -// PID by the attestor. -// -// The zero value of this type indicates that no attestation was performed or -// was successful. -type Attestation struct { - Unix UnixAttestation - Kubernetes KubernetesAttestation -} -// LogValue implements slog.LogValue to provide a nicely formatted set of -// log keys for a given attestation. -func (a Attestation) LogValue() slog.Value { - return slog.GroupValue( - slog.Attr{ - Key: "unix", - Value: a.Unix.LogValue(), - }, - slog.Attr{ - Key: "kubernetes", - Value: a.Kubernetes.LogValue(), - }, - ) -} + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" +) type attestor[T any] interface { Attest(ctx context.Context, pid int) (T, error) @@ -58,8 +35,8 @@ type attestor[T any] interface { // key information about the process. type Attestor struct { log *slog.Logger - kubernetes attestor[KubernetesAttestation] - unix attestor[UnixAttestation] + kubernetes attestor[*workloadidentityv1pb.WorkloadAttrsKubernetes] + unix attestor[*workloadidentityv1pb.WorkloadAttrsUnix] } // Config is the configuration for Attestor @@ -83,30 +60,27 @@ func NewAttestor(log *slog.Logger, cfg Config) (*Attestor, error) { return att, nil } -func (a *Attestor) Attest(ctx context.Context, pid int) (Attestation, error) { +func (a *Attestor) Attest(ctx context.Context, pid int) (*workloadidentityv1pb.WorkloadAttrs, error) { a.log.DebugContext(ctx, "Starting workload attestation", "pid", pid) defer a.log.DebugContext(ctx, "Finished workload attestation", "pid", pid) - var ( - att Attestation - err error - ) - + var err error + attrs := &workloadidentityv1pb.WorkloadAttrs{} // We always perform the unix attestation first - att.Unix, err = a.unix.Attest(ctx, pid) + attrs.Unix, err = a.unix.Attest(ctx, pid) if err != nil { - return att, err + return attrs, err } // Then we can perform the optionally configured attestations // For these, failure is soft. If it fails, we log, but still return the // successfully attested data. if a.kubernetes != nil { - att.Kubernetes, err = a.kubernetes.Attest(ctx, pid) + attrs.Kubernetes, err = a.kubernetes.Attest(ctx, pid) if err != nil { a.log.WarnContext(ctx, "Failed to perform Kubernetes workload attestation", "error", err) } } - return att, nil + return attrs, nil } diff --git a/lib/tbot/workloadidentity/workloadattest/kubernetes.go b/lib/tbot/workloadidentity/workloadattest/kubernetes.go index afadbab5c45e4..eb1a8dce77fe2 100644 --- a/lib/tbot/workloadidentity/workloadattest/kubernetes.go +++ b/lib/tbot/workloadidentity/workloadattest/kubernetes.go @@ -19,54 +19,9 @@ package workloadattest import ( - "log/slog" - "github.com/gravitational/trace" ) -// KubernetesAttestation holds the Kubernetes pod information retrieved from -// the workload attestation process. -type KubernetesAttestation struct { - // Attested is true if the PID was successfully attested to a Kubernetes - // pod. This indicates the validity of the rest of the fields. - Attested bool - // Namespace is the namespace of the pod. - Namespace string - // ServiceAccount is the service account of the pod. - ServiceAccount string - // PodName is the name of the pod. - PodName string - // PodUID is the UID of the pod. - PodUID string - // Labels is a map of labels on the pod. - Labels map[string]string -} - -// LogValue implements slog.LogValue to provide a nicely formatted set of -// log keys for a given attestation. -func (a KubernetesAttestation) LogValue() slog.Value { - values := []slog.Attr{ - slog.Bool("attested", a.Attested), - } - if a.Attested { - labels := []slog.Attr{} - for k, v := range a.Labels { - labels = append(labels, slog.String(k, v)) - } - values = append(values, - slog.String("namespace", a.Namespace), - slog.String("service_account", a.ServiceAccount), - slog.String("pod_name", a.PodName), - slog.String("pod_uid", a.PodUID), - slog.Attr{ - Key: "labels", - Value: slog.GroupValue(labels...), - }, - ) - } - return slog.GroupValue(values...) -} - // KubernetesAttestorConfig holds the configuration for the KubernetesAttestor. type KubernetesAttestorConfig struct { // Enabled is true if the KubernetesAttestor is enabled. If false, diff --git a/lib/tbot/workloadidentity/workloadattest/kubernetes_unix.go b/lib/tbot/workloadidentity/workloadattest/kubernetes_unix.go index 567b33d337d00..55e4b64b5314e 100644 --- a/lib/tbot/workloadidentity/workloadattest/kubernetes_unix.go +++ b/lib/tbot/workloadidentity/workloadattest/kubernetes_unix.go @@ -41,6 +41,8 @@ import ( "github.com/gravitational/trace" v1 "k8s.io/api/core/v1" "k8s.io/utils/mount" + + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" ) // KubernetesAttestor attests a workload to a Kubernetes pod. @@ -75,27 +77,27 @@ func NewKubernetesAttestor(cfg KubernetesAttestorConfig, log *slog.Logger) *Kube // Attest resolves the Kubernetes pod information from the // PID of the workload. -func (a *KubernetesAttestor) Attest(ctx context.Context, pid int) (KubernetesAttestation, error) { +func (a *KubernetesAttestor) Attest(ctx context.Context, pid int) (*workloadidentityv1pb.WorkloadAttrsKubernetes, error) { a.log.DebugContext(ctx, "Starting Kubernetes workload attestation", "pid", pid) podID, containerID, err := a.getContainerAndPodID(pid) if err != nil { - return KubernetesAttestation{}, trace.Wrap(err, "determining pod and container ID") + return nil, trace.Wrap(err, "determining pod and container ID") } a.log.DebugContext(ctx, "Found pod and container ID", "pod_id", podID, "container_id", containerID) pod, err := a.getPodForID(ctx, podID) if err != nil { - return KubernetesAttestation{}, trace.Wrap(err, "finding pod by ID") + return nil, trace.Wrap(err, "finding pod by ID") } a.log.DebugContext(ctx, "Found pod", "pod_name", pod.Name) - att := KubernetesAttestation{ + att := &workloadidentityv1pb.WorkloadAttrsKubernetes{ Attested: true, Namespace: pod.Namespace, ServiceAccount: pod.Spec.ServiceAccountName, PodName: pod.Name, - PodUID: string(pod.UID), + PodUid: string(pod.UID), Labels: pod.Labels, } a.log.DebugContext(ctx, "Finished Kubernetes workload attestation", "attestation", att) diff --git a/lib/tbot/workloadidentity/workloadattest/kubernetes_unix_test.go b/lib/tbot/workloadidentity/workloadattest/kubernetes_unix_test.go index 79704cb775cf8..4e9a1831bc975 100644 --- a/lib/tbot/workloadidentity/workloadattest/kubernetes_unix_test.go +++ b/lib/tbot/workloadidentity/workloadattest/kubernetes_unix_test.go @@ -31,12 +31,15 @@ import ( "strconv" "testing" + "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/testing/protocmp" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/lib/utils" ) @@ -165,14 +168,14 @@ func TestKubernetesAttestor_Attest(t *testing.T) { att, err := attestor.Attest(ctx, mockPID) assert.NoError(t, err) - assert.Equal(t, KubernetesAttestation{ + assert.Empty(t, cmp.Diff(&workloadidentityv1pb.WorkloadAttrsKubernetes{ Attested: true, ServiceAccount: "my-service-account", Namespace: "default", PodName: "my-pod", - PodUID: mockPodID, + PodUid: mockPodID, Labels: map[string]string{ "my-label": "my-label-value", }, - }, att) + }, att, protocmp.Transform())) } diff --git a/lib/tbot/workloadidentity/workloadattest/kubernetes_windows.go b/lib/tbot/workloadidentity/workloadattest/kubernetes_windows.go index 27b11b13227ca..a51e8f99417da 100644 --- a/lib/tbot/workloadidentity/workloadattest/kubernetes_windows.go +++ b/lib/tbot/workloadidentity/workloadattest/kubernetes_windows.go @@ -25,14 +25,16 @@ import ( "log/slog" "github.com/gravitational/trace" + + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" ) // WindowsKubernetesAttestor is the windows stub for KubernetesAttestor. type WindowsKubernetesAttestor struct { } -func (a WindowsKubernetesAttestor) Attest(_ context.Context, _ int) (KubernetesAttestation, error) { - return KubernetesAttestation{}, trace.NotImplemented("kubernetes attestation is not supported on windows") +func (a WindowsKubernetesAttestor) Attest(_ context.Context, _ int) (*workloadidentityv1pb.WorkloadAttrsKubernetes, error) { + return nil, trace.NotImplemented("kubernetes attestation is not supported on windows") } // NewKubernetesAttestor creates a new KubernetesAttestor. diff --git a/lib/tbot/workloadidentity/workloadattest/unix.go b/lib/tbot/workloadidentity/workloadattest/unix.go index 2f67fd7f6bad2..a0fc277ca4393 100644 --- a/lib/tbot/workloadidentity/workloadattest/unix.go +++ b/lib/tbot/workloadidentity/workloadattest/unix.go @@ -20,41 +20,12 @@ package workloadattest import ( "context" - "log/slog" "github.com/gravitational/trace" "github.com/shirou/gopsutil/v4/process" -) - -// UnixAttestation holds the Unix process information retrieved from the -// workload attestation process. -type UnixAttestation struct { - // Attested is true if the PID was successfully attested to a Unix - // process. This indicates the validity of the rest of the fields. - Attested bool - // PID is the process ID of the attested process. - PID int - // UID is the primary user ID of the attested process. - UID int - // GID is the primary group ID of the attested process. - GID int -} -// LogValue implements slog.LogValue to provide a nicely formatted set of -// log keys for a given attestation. -func (a UnixAttestation) LogValue() slog.Value { - values := []slog.Attr{ - slog.Bool("attested", a.Attested), - } - if a.Attested { - values = append(values, - slog.Int("uid", a.UID), - slog.Int("pid", a.PID), - slog.Int("gid", a.GID), - ) - } - return slog.GroupValue(values...) -} + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" +) // UnixAttestor attests a process id to a Unix process. type UnixAttestor struct { @@ -66,15 +37,15 @@ func NewUnixAttestor() *UnixAttestor { } // Attest attests a process id to a Unix process. -func (a *UnixAttestor) Attest(ctx context.Context, pid int) (UnixAttestation, error) { +func (a *UnixAttestor) Attest(ctx context.Context, pid int) (*workloadidentityv1pb.WorkloadAttrsUnix, error) { p, err := process.NewProcessWithContext(ctx, int32(pid)) if err != nil { - return UnixAttestation{}, trace.Wrap(err, "getting process") + return nil, trace.Wrap(err, "getting process") } - att := UnixAttestation{ + att := &workloadidentityv1pb.WorkloadAttrsUnix{ Attested: true, - PID: pid, + Pid: int32(pid), } // On Linux: // Real, effective, saved, and file system GIDs @@ -82,19 +53,19 @@ func (a *UnixAttestor) Attest(ctx context.Context, pid int) (UnixAttestation, er // Effective, effective, saved GIDs gids, err := p.Gids() if err != nil { - return UnixAttestation{}, trace.Wrap(err, "getting gids") + return nil, trace.Wrap(err, "getting gids") } // We generally want to select the effective GID. switch len(gids) { case 0: // error as none returned - return UnixAttestation{}, trace.BadParameter("no gids returned") + return nil, trace.BadParameter("no gids returned") case 1: // Only one GID - this is unusual but let's take it. - att.GID = int(gids[0]) + att.Gid = gids[0] default: // Take the index 1 entry as this is effective - att.GID = int(gids[1]) + att.Gid = gids[1] } // On Linux: @@ -103,19 +74,19 @@ func (a *UnixAttestor) Attest(ctx context.Context, pid int) (UnixAttestation, er // Effective uids, err := p.Uids() if err != nil { - return UnixAttestation{}, trace.Wrap(err, "getting uids") + return nil, trace.Wrap(err, "getting uids") } // We generally want to select the effective GID. switch len(uids) { case 0: // error as none returned - return UnixAttestation{}, trace.BadParameter("no uids returned") + return nil, trace.BadParameter("no uids returned") case 1: // Only one UID, we expect this on Darwin to be the Effective UID - att.UID = int(uids[0]) + att.Uid = uids[0] default: // Take the index 1 entry as this is Effective UID on Linux - att.UID = int(uids[1]) + att.Uid = uids[1] } return att, nil diff --git a/lib/tbot/workloadidentity/workloadattest/unix_test.go b/lib/tbot/workloadidentity/workloadattest/unix_test.go index 667fdcffb2634..a14be0c1a8619 100644 --- a/lib/tbot/workloadidentity/workloadattest/unix_test.go +++ b/lib/tbot/workloadidentity/workloadattest/unix_test.go @@ -23,7 +23,11 @@ import ( "os" "testing" + "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/testing/protocmp" + + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" ) func TestUnixAttestor_Attest(t *testing.T) { @@ -37,10 +41,10 @@ func TestUnixAttestor_Attest(t *testing.T) { attestor := NewUnixAttestor() att, err := attestor.Attest(ctx, pid) require.NoError(t, err) - require.Equal(t, UnixAttestation{ + require.Empty(t, cmp.Diff(&workloadidentityv1pb.WorkloadAttrsUnix{ Attested: true, - PID: pid, - UID: uid, - GID: gid, - }, att) + Pid: int32(pid), + Uid: uint32(uid), + Gid: uint32(gid), + }, att, protocmp.Transform())) }