diff --git a/src/go/cmd/token-vendor/BUILD.bazel b/src/go/cmd/token-vendor/BUILD.bazel index 1a8b948c2..7a24b3d3a 100644 --- a/src/go/cmd/token-vendor/BUILD.bazel +++ b/src/go/cmd/token-vendor/BUILD.bazel @@ -14,6 +14,7 @@ go_library( "//src/go/cmd/token-vendor/api/v1:go_default_library", "//src/go/cmd/token-vendor/app:go_default_library", "//src/go/cmd/token-vendor/oauth:go_default_library", + "//src/go/cmd/token-vendor/repository:go_default_library", "//src/go/cmd/token-vendor/repository/k8s:go_default_library", "//src/go/cmd/token-vendor/repository/memory:go_default_library", "//src/go/cmd/token-vendor/tokensource:go_default_library", diff --git a/src/go/cmd/token-vendor/api/v1/v1.go b/src/go/cmd/token-vendor/api/v1/v1.go index 8f475b82b..49b0ad589 100644 --- a/src/go/cmd/token-vendor/api/v1/v1.go +++ b/src/go/cmd/token-vendor/api/v1/v1.go @@ -293,7 +293,7 @@ func (h *HandlerContext) verifyJWTHandler(w http.ResponseWriter, r *http.Request // Authorization: ... jwtString := strings.TrimPrefix(authHeader[0], "Bearer ") - if _, err := h.tv.ValidateJWT(r.Context(), jwtString); err != nil { + if _, _, err := h.tv.ValidateJWT(r.Context(), jwtString); err != nil { slog.WarnContext(r.Context(), "JWT failed validation", ilog.Err(err)) api.ErrResponse(w, http.StatusForbidden, "JWT not valid") return diff --git a/src/go/cmd/token-vendor/app/BUILD.bazel b/src/go/cmd/token-vendor/app/BUILD.bazel index 1e8e3d97a..ce9d97032 100644 --- a/src/go/cmd/token-vendor/app/BUILD.bazel +++ b/src/go/cmd/token-vendor/app/BUILD.bazel @@ -8,6 +8,7 @@ go_library( deps = [ "//src/go/cmd/token-vendor/oauth:go_default_library", "//src/go/cmd/token-vendor/oauth/jwt:go_default_library", + "//src/go/cmd/token-vendor/repository:go_default_library", "//src/go/cmd/token-vendor/tokensource:go_default_library", "@com_github_pkg_errors//:go_default_library", "@com_github_prometheus_client_golang//prometheus:go_default_library", diff --git a/src/go/cmd/token-vendor/app/tokenvendor.go b/src/go/cmd/token-vendor/app/tokenvendor.go index 358ef793e..4eb162baf 100644 --- a/src/go/cmd/token-vendor/app/tokenvendor.go +++ b/src/go/cmd/token-vendor/app/tokenvendor.go @@ -27,27 +27,19 @@ import ( "github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/oauth" "github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/oauth/jwt" + "github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/repository" "github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/tokensource" "github.com/pkg/errors" ) -type PubKeyRepository interface { - - // LookupKey retrieves the public key of a device from the repository. - // An empty string return indicates that no key exists for the given identifier or - // that the device is blocked. - LookupKey(ctx context.Context, deviceID string) (string, error) - PublishKey(ctx context.Context, deviceID, publicKey string) error -} - type TokenVendor struct { - repo PubKeyRepository + repo repository.PubKeyRepository v *oauth.TokenVerifier ts *tokensource.GCPTokenSource accAud string } -func NewTokenVendor(ctx context.Context, repo PubKeyRepository, v *oauth.TokenVerifier, ts *tokensource.GCPTokenSource, acceptedAudience string) (*TokenVendor, error) { +func NewTokenVendor(ctx context.Context, repo repository.PubKeyRepository, v *oauth.TokenVerifier, ts *tokensource.GCPTokenSource, acceptedAudience string) (*TokenVendor, error) { if acceptedAudience == "" { return nil, errors.New("accepted audience must not be empty") } @@ -61,7 +53,11 @@ func (tv *TokenVendor) PublishPublicKey(ctx context.Context, deviceID, publicKey func (tv *TokenVendor) ReadPublicKey(ctx context.Context, deviceID string) (string, error) { slog.Debug("Returning public key", slog.String("DeviceID", deviceID)) - return tv.repo.LookupKey(ctx, deviceID) + key, err := tv.repo.LookupKey(ctx, deviceID) + if key != nil { + return key.PublicKey, nil + } + return "", err } var ( @@ -95,48 +91,48 @@ func (tv *TokenVendor) GetOAuth2Token(ctx context.Context, jwtk string) (*tokens return r, err } -func (tv *TokenVendor) ValidateJWT(ctx context.Context, jwtk string) (string, error) { +func (tv *TokenVendor) ValidateJWT(ctx context.Context, jwtk string) (string, string, error) { p, err := jwt.PayloadUnsafe(jwtk) if err != nil { - return "", errors.Wrap(err, "failed to extract JWT payload") + return "", "", errors.Wrap(err, "failed to extract JWT payload") } exp := time.Unix(p.Exp, 0) if exp.Before(time.Now()) { - return "", fmt.Errorf("JWT has expired %v, %v ago (iss: %q)", + return "", "", fmt.Errorf("JWT has expired %v, %v ago (iss: %q)", exp, time.Since(exp), p.Iss) } if err := acceptedAudience(p.Aud, tv.accAud); err != nil { - return "", errors.Wrapf(err, "validation of JWT audience failed (iss: %q)", p.Iss) + return "", "", errors.Wrapf(err, "validation of JWT audience failed (iss: %q)", p.Iss) } if !IsValidDeviceID(p.Iss) { - return "", fmt.Errorf("missing or invalid device identifier (`iss`: %q)", p.Iss) + return "", "", fmt.Errorf("missing or invalid device identifier (`iss`: %q)", p.Iss) } deviceID := p.Iss - pubKey, err := tv.repo.LookupKey(ctx, deviceID) + k, err := tv.repo.LookupKey(ctx, deviceID) if err != nil { - return "", errors.Wrapf(err, "failed to retrieve public key for device %q", deviceID) + return "", "", errors.Wrapf(err, "failed to retrieve public key for device %q", deviceID) } - if pubKey == "" { - return "", errors.Errorf("no public key found for device %q", deviceID) + if k.PublicKey == "" { + return "", "", errors.Errorf("no public key found for device %q", deviceID) } - err = jwt.VerifySignature(jwtk, pubKey) + err = jwt.VerifySignature(jwtk, k.PublicKey) if err != nil { - return "", errors.Wrapf(err, "failed to verify signature for device %q", deviceID) + return "", "", errors.Wrapf(err, "failed to verify signature for device %q", deviceID) } - return deviceID, nil + return deviceID, k.SAName, nil } func (tv *TokenVendor) getOAuth2Token(ctx context.Context, jwtk string) (*tokensource.TokenResponse, error) { - deviceID, err := tv.ValidateJWT(ctx, jwtk) + deviceID, sa, err := tv.ValidateJWT(ctx, jwtk) if err != nil { return nil, err } - cloudToken, err := tv.ts.Token(ctx) + cloudToken, err := tv.ts.Token(ctx, sa) if err != nil { return nil, errors.Wrapf(err, "failed to retrieve a cloud token for device %q", deviceID) } - slog.Info("Handing out cloud token", slog.String("DeviceID", deviceID)) + slog.Info("Handing out cloud token", slog.String("DeviceID", deviceID), slog.String("ServiceAccount", sa)) return cloudToken, nil } diff --git a/src/go/cmd/token-vendor/main.go b/src/go/cmd/token-vendor/main.go index 6bc2d8bdc..c8ae4f46a 100644 --- a/src/go/cmd/token-vendor/main.go +++ b/src/go/cmd/token-vendor/main.go @@ -32,6 +32,7 @@ import ( apiv1 "github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/api/v1" "github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/app" "github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/oauth" + "github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/repository" "github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/repository/k8s" "github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/repository/memory" "github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/tokensource" @@ -87,7 +88,7 @@ var ( "", "Endpoint URL of the token vendor. Used for verification of JWTs send by robots.") scopes = scopeFlags{} robotName = flag.String("service_account", "robot-service", - "Name of the service account to generate cloud access tokens for.") + "Name of the service account to generate cloud access tokens for (unless specified per on-prem robot).") ) func main() { @@ -101,7 +102,7 @@ func main() { slog.SetDefault(slog.New(logHandler)) // init components ctx := context.Background() - var rep app.PubKeyRepository + var rep repository.PubKeyRepository var err error if *keyStore == Kubernetes { config, err := rest.InClusterConfig() diff --git a/src/go/cmd/token-vendor/repository/BUILD.bazel b/src/go/cmd/token-vendor/repository/BUILD.bazel index e69de29bb..1461d88aa 100644 --- a/src/go/cmd/token-vendor/repository/BUILD.bazel +++ b/src/go/cmd/token-vendor/repository/BUILD.bazel @@ -0,0 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "go_default_library", + srcs = ["repository.go"], + importpath = "github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/repository", + visibility = ["//visibility:public"], +) diff --git a/src/go/cmd/token-vendor/repository/k8s/BUILD.bazel b/src/go/cmd/token-vendor/repository/k8s/BUILD.bazel index 4de42dbe9..88d752531 100644 --- a/src/go/cmd/token-vendor/repository/k8s/BUILD.bazel +++ b/src/go/cmd/token-vendor/repository/k8s/BUILD.bazel @@ -11,6 +11,7 @@ go_library( "@io_k8s_apimachinery//pkg/api/errors:go_default_library", "@io_k8s_apimachinery//pkg/apis/meta/v1:go_default_library", "@io_k8s_client_go//kubernetes:go_default_library", + "//src/go/cmd/token-vendor/repository:go_default_library", ], ) @@ -18,5 +19,8 @@ go_test( name = "go_default_test", srcs = ["k8s_test.go"], embed = [":go_default_library"], - deps = ["@io_k8s_client_go//kubernetes/fake:go_default_library"], + deps = [ + "@io_k8s_client_go//kubernetes/fake:go_default_library", + "//src/go/cmd/token-vendor/repository:go_default_library", + ], ) diff --git a/src/go/cmd/token-vendor/repository/k8s/k8s.go b/src/go/cmd/token-vendor/repository/k8s/k8s.go index fb06f706b..3062ca3bc 100644 --- a/src/go/cmd/token-vendor/repository/k8s/k8s.go +++ b/src/go/cmd/token-vendor/repository/k8s/k8s.go @@ -24,6 +24,8 @@ import ( kerrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" + + "github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/repository" ) // K8sRepository uses Kubernetes configmaps as public key backend for devices. @@ -43,6 +45,8 @@ func NewK8sRepository(ctx context.Context, kcl kubernetes.Interface, ns string) const ( pubKey = "pubKey" // Configmap key for the public key + // Configmap annotation specifies the service account to use (optional) + serviceAccountAnnotation = "cloudrobotics.com/gcp-service-account" ) // ListAllDeviceIDs returns a slice of all device identifiers found in the namespace. @@ -63,21 +67,22 @@ func (k *K8sRepository) ListAllDeviceIDs(ctx context.Context) ([]string, error) // The public key is stored under a specific key in the configmap. If the configmap // does not exist, we return an empty string. For any other error or a malformed // configmap we return an error. -func (k *K8sRepository) LookupKey(ctx context.Context, deviceID string) (string, error) { +func (k *K8sRepository) LookupKey(ctx context.Context, deviceID string) (*repository.Key, error) { slog.Debug("looking up public key", slog.String("Namespace", k.ns), slog.String("ConfigMap", deviceID)) cm, err := k.kcl.CoreV1().ConfigMaps(k.ns).Get(ctx, deviceID, metav1.GetOptions{}) if kerrors.IsNotFound(err) { slog.Debug("ConfigMap not found", slog.String("Namespace", k.ns), slog.String("ConfigMap", deviceID)) - return "", nil + return nil, nil } if err != nil { - return "", errors.Wrapf(err, "failed to retrieve configmap %q/%q", k.ns, deviceID) + return nil, errors.Wrapf(err, "failed to retrieve configmap %q/%q", k.ns, deviceID) } key, found := cm.Data[pubKey] if !found { - return "", fmt.Errorf("configmap %q/%q does not contain key %q", k.ns, deviceID, pubKey) + return nil, fmt.Errorf("configmap %q/%q does not contain key %q", k.ns, deviceID, pubKey) } - return key, nil + sa, _ := cm.ObjectMeta.Annotations[serviceAccountAnnotation] + return &repository.Key{key, sa}, nil } // PublishKey sets or updates a public key for a given device identifier. diff --git a/src/go/cmd/token-vendor/repository/k8s/k8s_test.go b/src/go/cmd/token-vendor/repository/k8s/k8s_test.go index c697ff8ca..e5706f2d6 100644 --- a/src/go/cmd/token-vendor/repository/k8s/k8s_test.go +++ b/src/go/cmd/token-vendor/repository/k8s/k8s_test.go @@ -64,8 +64,8 @@ func TestPublishKeyUpdate(t *testing.T) { if err != nil { t.Fatal(err) } - if k != key2 { - t.Fatalf("LookupKey(..) = %q, want %q", k, key2) + if k.PublicKey != key2 { + t.Fatalf("LookupKey(..) = %q, want %q", k.PublicKey, key2) } } @@ -80,7 +80,7 @@ func TestLookupDoesNotExist(t *testing.T) { if err != nil { t.Fatalf("LookupKey produced error %v, want nil", err) } - if k != "" { + if k != nil { t.Fatalf("LookupKey(..) = %q, want empty string", k) } } diff --git a/src/go/cmd/token-vendor/repository/memory/BUILD.bazel b/src/go/cmd/token-vendor/repository/memory/BUILD.bazel index c5cb99a79..778b76302 100644 --- a/src/go/cmd/token-vendor/repository/memory/BUILD.bazel +++ b/src/go/cmd/token-vendor/repository/memory/BUILD.bazel @@ -5,6 +5,9 @@ go_library( srcs = ["memory.go"], importpath = "github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/repository/memory", visibility = ["//visibility:public"], + deps = [ + "//src/go/cmd/token-vendor/repository:go_default_library", + ] ) go_test( diff --git a/src/go/cmd/token-vendor/repository/memory/memory.go b/src/go/cmd/token-vendor/repository/memory/memory.go index 92a6ce163..bcdb4d088 100644 --- a/src/go/cmd/token-vendor/repository/memory/memory.go +++ b/src/go/cmd/token-vendor/repository/memory/memory.go @@ -17,6 +17,8 @@ package memory import ( "context" "log/slog" + + "github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/repository" ) // MemoryRepository uses a in-memory datastructure to store the keys. @@ -35,8 +37,12 @@ func (m *MemoryRepository) PublishKey(ctx context.Context, deviceID, publicKey s return nil } -func (m *MemoryRepository) LookupKey(ctx context.Context, deviceID string) (string, error) { +func (m *MemoryRepository) LookupKey(ctx context.Context, deviceID string) (*repository.Key, error) { slog.Debug("LookupKey", slog.String("DeviceID", deviceID)) // key not found does not need to be an error - return m.keys[deviceID], nil + k, found := m.keys[deviceID] + if !found { + return nil, nil + } + return &repository.Key{k, ""}, nil } diff --git a/src/go/cmd/token-vendor/repository/memory/memory_test.go b/src/go/cmd/token-vendor/repository/memory/memory_test.go index 56438f85e..5a8246b90 100644 --- a/src/go/cmd/token-vendor/repository/memory/memory_test.go +++ b/src/go/cmd/token-vendor/repository/memory/memory_test.go @@ -31,17 +31,18 @@ func TestMemoryBackend(t *testing.T) { if err := m.PublishKey(context.TODO(), "b", "bkey"); err != nil { t.Fatal(err) } - var k string - if k, err = m.LookupKey(context.TODO(), "a"); err != nil { + k, err := m.LookupKey(context.TODO(), "a") + if err != nil { t.Fatal(err) } - if k != "akey" { + if k.PublicKey != "akey" { t.Fatalf("Key for a: got %q, want %q", k, "akey") } - if k, err = m.LookupKey(context.TODO(), "b"); err != nil { + k, err = m.LookupKey(context.TODO(), "b") + if err != nil { t.Fatal(err) } - if k != "bkey" { + if k.PublicKey != "bkey" { t.Fatalf("Key for b: got %q, want %q", k, "bkey") } } @@ -51,11 +52,11 @@ func TestMemoryNotFound(t *testing.T) { if err != nil { t.Fatal(err) } - var k string - if k, err = m.LookupKey(context.TODO(), "a"); err != nil { + k, err := m.LookupKey(context.TODO(), "a") + if err != nil { t.Fatal(err) } - if k != "" { + if k != nil { t.Fatalf("LookupKey: got %q, expected empty response", k) } } diff --git a/src/go/cmd/token-vendor/repository/repository.go b/src/go/cmd/token-vendor/repository/repository.go new file mode 100644 index 000000000..a580d1685 --- /dev/null +++ b/src/go/cmd/token-vendor/repository/repository.go @@ -0,0 +1,37 @@ +// Copyright 2024 The Cloud Robotics Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package repository defines the api for the pub key stores +package repository + +import ( + "context" +) + +// Key holds data + metadata of a public key entry +type Key struct { + // PublicKey contains the public key data + PublicKey string + // SAName is the optional GCP IAM service-account that has been associated. + SAName string +} + +// PubKeyRepository defines the api for the pub key stores +type PubKeyRepository interface { + // LookupKey retrieves the public key of a device from the repository. + // An empty string return indicates that no key exists for the given identifier or + // that the device is blocked. + LookupKey(ctx context.Context, deviceID string) (*Key, error) + PublishKey(ctx context.Context, deviceID, publicKey string) error +} diff --git a/src/go/cmd/token-vendor/tokensource/gcp.go b/src/go/cmd/token-vendor/tokensource/gcp.go index 660624a8d..d6caf99e3 100644 --- a/src/go/cmd/token-vendor/tokensource/gcp.go +++ b/src/go/cmd/token-vendor/tokensource/gcp.go @@ -16,8 +16,8 @@ import ( type GCPTokenSource struct { service *iam.Service // the FQN of the service account - resource string - scopes []string + defaultResource string + scopes []string } type TokenResponse struct { @@ -32,27 +32,33 @@ type TokenResponse struct { // `client` parameter is optional. If you supply your own client, you have to make // sure you set the correct authentication headers yourself. If no client is given, // authentication information is looked up from the environment. -func NewGCPTokenSource(ctx context.Context, client *http.Client, project, sa string, scopes []string) (*GCPTokenSource, error) { +// `defaultSAName` specifies the GCP IAM service accoutn name to use if no +// dedicated service account is configurred on the key. +func NewGCPTokenSource(ctx context.Context, client *http.Client, project, defaultSAName string, scopes []string) (*GCPTokenSource, error) { service, err := iam.NewService(ctx, option.WithHTTPClient(client)) if err != nil { return nil, errors.Wrap(err, "failed to create IAM service client") } - resource := fmt.Sprintf("projects/-/serviceAccounts/%s@%s.iam.gserviceaccount.com", sa, project) - return &GCPTokenSource{service: service, resource: resource, scopes: scopes}, nil + resource := fmt.Sprintf("projects/-/serviceAccounts/%s@%s.iam.gserviceaccount.com", defaultSAName, project) + return &GCPTokenSource{service: service, defaultResource: resource, scopes: scopes}, nil } // Token returns an access token for the configured service account and scopes. // // API: https://cloud.google.com/iam/docs/reference/credentials/rest/v1/projects.serviceAccounts/generateAccessToken -func (g *GCPTokenSource) Token(ctx context.Context) (*TokenResponse, error) { +func (g *GCPTokenSource) Token(ctx context.Context, saName string) (*TokenResponse, error) { req := iam.GenerateAccessTokenRequest{Scope: g.scopes} + resource := g.defaultResource + if saName != "" { + resource = "projects/-/serviceAccounts/" + saName + } // We don't set a 'lifetime' on the request, so we get the default value (3600 sec = 1h). // This needs to be in sync with the min(cookie-expire,cookie-refresh) duration // configured on oauth2-proxy. resp, err := g.service.Projects.ServiceAccounts. - GenerateAccessToken(g.resource, &req).Context(ctx).Do() + GenerateAccessToken(resource, &req).Context(ctx).Do() if err != nil { - return nil, errors.Wrapf(err, "GenerateAccessToken(..) for %q failed", g.resource) + return nil, errors.Wrapf(err, "GenerateAccessToken(..) for %q failed", resource) } tok, err := tokenResponse(resp, g.scopes, time.Now()) if err != nil {