Skip to content

Commit

Permalink
Allow using a custom service account on an ipc-registration. (#461)
Browse files Browse the repository at this point in the history
One can already configure a custom service account by using the
annotation on the configmap. The api will be added in teh next changes.
  • Loading branch information
ensonic authored Nov 25, 2024
1 parent 54f144e commit 19b76a6
Show file tree
Hide file tree
Showing 14 changed files with 126 additions and 57 deletions.
1 change: 1 addition & 0 deletions src/go/cmd/token-vendor/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ go_library(
"//src/go/cmd/token-vendor/api/v1:go_default_library",
"//src/go/cmd/token-vendor/app:go_default_library",
"//src/go/cmd/token-vendor/oauth:go_default_library",
"//src/go/cmd/token-vendor/repository:go_default_library",
"//src/go/cmd/token-vendor/repository/k8s:go_default_library",
"//src/go/cmd/token-vendor/repository/memory:go_default_library",
"//src/go/cmd/token-vendor/tokensource:go_default_library",
Expand Down
2 changes: 1 addition & 1 deletion src/go/cmd/token-vendor/api/v1/v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ func (h *HandlerContext) verifyJWTHandler(w http.ResponseWriter, r *http.Request
// Authorization: ...
jwtString := strings.TrimPrefix(authHeader[0], "Bearer ")

if _, err := h.tv.ValidateJWT(r.Context(), jwtString); err != nil {
if _, _, err := h.tv.ValidateJWT(r.Context(), jwtString); err != nil {
slog.WarnContext(r.Context(), "JWT failed validation", ilog.Err(err))
api.ErrResponse(w, http.StatusForbidden, "JWT not valid")
return
Expand Down
1 change: 1 addition & 0 deletions src/go/cmd/token-vendor/app/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ go_library(
deps = [
"//src/go/cmd/token-vendor/oauth:go_default_library",
"//src/go/cmd/token-vendor/oauth/jwt:go_default_library",
"//src/go/cmd/token-vendor/repository:go_default_library",
"//src/go/cmd/token-vendor/tokensource:go_default_library",
"@com_github_pkg_errors//:go_default_library",
"@com_github_prometheus_client_golang//prometheus:go_default_library",
Expand Down
50 changes: 23 additions & 27 deletions src/go/cmd/token-vendor/app/tokenvendor.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,19 @@ import (

"github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/oauth"
"github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/oauth/jwt"
"github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/repository"
"github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/tokensource"
"github.com/pkg/errors"
)

type PubKeyRepository interface {

// LookupKey retrieves the public key of a device from the repository.
// An empty string return indicates that no key exists for the given identifier or
// that the device is blocked.
LookupKey(ctx context.Context, deviceID string) (string, error)
PublishKey(ctx context.Context, deviceID, publicKey string) error
}

type TokenVendor struct {
repo PubKeyRepository
repo repository.PubKeyRepository
v *oauth.TokenVerifier
ts *tokensource.GCPTokenSource
accAud string
}

func NewTokenVendor(ctx context.Context, repo PubKeyRepository, v *oauth.TokenVerifier, ts *tokensource.GCPTokenSource, acceptedAudience string) (*TokenVendor, error) {
func NewTokenVendor(ctx context.Context, repo repository.PubKeyRepository, v *oauth.TokenVerifier, ts *tokensource.GCPTokenSource, acceptedAudience string) (*TokenVendor, error) {
if acceptedAudience == "" {
return nil, errors.New("accepted audience must not be empty")
}
Expand All @@ -61,7 +53,11 @@ func (tv *TokenVendor) PublishPublicKey(ctx context.Context, deviceID, publicKey

func (tv *TokenVendor) ReadPublicKey(ctx context.Context, deviceID string) (string, error) {
slog.Debug("Returning public key", slog.String("DeviceID", deviceID))
return tv.repo.LookupKey(ctx, deviceID)
key, err := tv.repo.LookupKey(ctx, deviceID)
if key != nil {
return key.PublicKey, nil
}
return "", err
}

var (
Expand Down Expand Up @@ -95,48 +91,48 @@ func (tv *TokenVendor) GetOAuth2Token(ctx context.Context, jwtk string) (*tokens
return r, err
}

func (tv *TokenVendor) ValidateJWT(ctx context.Context, jwtk string) (string, error) {
func (tv *TokenVendor) ValidateJWT(ctx context.Context, jwtk string) (string, string, error) {
p, err := jwt.PayloadUnsafe(jwtk)
if err != nil {
return "", errors.Wrap(err, "failed to extract JWT payload")
return "", "", errors.Wrap(err, "failed to extract JWT payload")
}
exp := time.Unix(p.Exp, 0)
if exp.Before(time.Now()) {
return "", fmt.Errorf("JWT has expired %v, %v ago (iss: %q)",
return "", "", fmt.Errorf("JWT has expired %v, %v ago (iss: %q)",
exp, time.Since(exp), p.Iss)
}
if err := acceptedAudience(p.Aud, tv.accAud); err != nil {
return "", errors.Wrapf(err, "validation of JWT audience failed (iss: %q)", p.Iss)
return "", "", errors.Wrapf(err, "validation of JWT audience failed (iss: %q)", p.Iss)
}
if !IsValidDeviceID(p.Iss) {
return "", fmt.Errorf("missing or invalid device identifier (`iss`: %q)", p.Iss)
return "", "", fmt.Errorf("missing or invalid device identifier (`iss`: %q)", p.Iss)
}
deviceID := p.Iss
pubKey, err := tv.repo.LookupKey(ctx, deviceID)
k, err := tv.repo.LookupKey(ctx, deviceID)
if err != nil {
return "", errors.Wrapf(err, "failed to retrieve public key for device %q", deviceID)
return "", "", errors.Wrapf(err, "failed to retrieve public key for device %q", deviceID)
}
if pubKey == "" {
return "", errors.Errorf("no public key found for device %q", deviceID)
if k.PublicKey == "" {
return "", "", errors.Errorf("no public key found for device %q", deviceID)
}
err = jwt.VerifySignature(jwtk, pubKey)
err = jwt.VerifySignature(jwtk, k.PublicKey)
if err != nil {
return "", errors.Wrapf(err, "failed to verify signature for device %q", deviceID)
return "", "", errors.Wrapf(err, "failed to verify signature for device %q", deviceID)
}

return deviceID, nil
return deviceID, k.SAName, nil
}

func (tv *TokenVendor) getOAuth2Token(ctx context.Context, jwtk string) (*tokensource.TokenResponse, error) {
deviceID, err := tv.ValidateJWT(ctx, jwtk)
deviceID, sa, err := tv.ValidateJWT(ctx, jwtk)
if err != nil {
return nil, err
}
cloudToken, err := tv.ts.Token(ctx)
cloudToken, err := tv.ts.Token(ctx, sa)
if err != nil {
return nil, errors.Wrapf(err, "failed to retrieve a cloud token for device %q", deviceID)
}
slog.Info("Handing out cloud token", slog.String("DeviceID", deviceID))
slog.Info("Handing out cloud token", slog.String("DeviceID", deviceID), slog.String("ServiceAccount", sa))
return cloudToken, nil
}

Expand Down
5 changes: 3 additions & 2 deletions src/go/cmd/token-vendor/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
apiv1 "github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/api/v1"
"github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/app"
"github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/oauth"
"github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/repository"
"github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/repository/k8s"
"github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/repository/memory"
"github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/tokensource"
Expand Down Expand Up @@ -87,7 +88,7 @@ var (
"", "Endpoint URL of the token vendor. Used for verification of JWTs send by robots.")
scopes = scopeFlags{}
robotName = flag.String("service_account", "robot-service",
"Name of the service account to generate cloud access tokens for.")
"Name of the service account to generate cloud access tokens for (unless specified per on-prem robot).")
)

func main() {
Expand All @@ -101,7 +102,7 @@ func main() {
slog.SetDefault(slog.New(logHandler))
// init components
ctx := context.Background()
var rep app.PubKeyRepository
var rep repository.PubKeyRepository
var err error
if *keyStore == Kubernetes {
config, err := rest.InClusterConfig()
Expand Down
8 changes: 8 additions & 0 deletions src/go/cmd/token-vendor/repository/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")

go_library(
name = "go_default_library",
srcs = ["repository.go"],
importpath = "github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/repository",
visibility = ["//visibility:public"],
)
6 changes: 5 additions & 1 deletion src/go/cmd/token-vendor/repository/k8s/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,16 @@ go_library(
"@io_k8s_apimachinery//pkg/api/errors:go_default_library",
"@io_k8s_apimachinery//pkg/apis/meta/v1:go_default_library",
"@io_k8s_client_go//kubernetes:go_default_library",
"//src/go/cmd/token-vendor/repository:go_default_library",
],
)

go_test(
name = "go_default_test",
srcs = ["k8s_test.go"],
embed = [":go_default_library"],
deps = ["@io_k8s_client_go//kubernetes/fake:go_default_library"],
deps = [
"@io_k8s_client_go//kubernetes/fake:go_default_library",
"//src/go/cmd/token-vendor/repository:go_default_library",
],
)
15 changes: 10 additions & 5 deletions src/go/cmd/token-vendor/repository/k8s/k8s.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
kerrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"

"github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/repository"
)

// K8sRepository uses Kubernetes configmaps as public key backend for devices.
Expand All @@ -43,6 +45,8 @@ func NewK8sRepository(ctx context.Context, kcl kubernetes.Interface, ns string)

const (
pubKey = "pubKey" // Configmap key for the public key
// Configmap annotation specifies the service account to use (optional)
serviceAccountAnnotation = "cloudrobotics.com/gcp-service-account"
)

// ListAllDeviceIDs returns a slice of all device identifiers found in the namespace.
Expand All @@ -63,21 +67,22 @@ func (k *K8sRepository) ListAllDeviceIDs(ctx context.Context) ([]string, error)
// The public key is stored under a specific key in the configmap. If the configmap
// does not exist, we return an empty string. For any other error or a malformed
// configmap we return an error.
func (k *K8sRepository) LookupKey(ctx context.Context, deviceID string) (string, error) {
func (k *K8sRepository) LookupKey(ctx context.Context, deviceID string) (*repository.Key, error) {
slog.Debug("looking up public key", slog.String("Namespace", k.ns), slog.String("ConfigMap", deviceID))
cm, err := k.kcl.CoreV1().ConfigMaps(k.ns).Get(ctx, deviceID, metav1.GetOptions{})
if kerrors.IsNotFound(err) {
slog.Debug("ConfigMap not found", slog.String("Namespace", k.ns), slog.String("ConfigMap", deviceID))
return "", nil
return nil, nil
}
if err != nil {
return "", errors.Wrapf(err, "failed to retrieve configmap %q/%q", k.ns, deviceID)
return nil, errors.Wrapf(err, "failed to retrieve configmap %q/%q", k.ns, deviceID)
}
key, found := cm.Data[pubKey]
if !found {
return "", fmt.Errorf("configmap %q/%q does not contain key %q", k.ns, deviceID, pubKey)
return nil, fmt.Errorf("configmap %q/%q does not contain key %q", k.ns, deviceID, pubKey)
}
return key, nil
sa, _ := cm.ObjectMeta.Annotations[serviceAccountAnnotation]
return &repository.Key{key, sa}, nil
}

// PublishKey sets or updates a public key for a given device identifier.
Expand Down
6 changes: 3 additions & 3 deletions src/go/cmd/token-vendor/repository/k8s/k8s_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ func TestPublishKeyUpdate(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if k != key2 {
t.Fatalf("LookupKey(..) = %q, want %q", k, key2)
if k.PublicKey != key2 {
t.Fatalf("LookupKey(..) = %q, want %q", k.PublicKey, key2)
}
}

Expand All @@ -80,7 +80,7 @@ func TestLookupDoesNotExist(t *testing.T) {
if err != nil {
t.Fatalf("LookupKey produced error %v, want nil", err)
}
if k != "" {
if k != nil {
t.Fatalf("LookupKey(..) = %q, want empty string", k)
}
}
3 changes: 3 additions & 0 deletions src/go/cmd/token-vendor/repository/memory/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ go_library(
srcs = ["memory.go"],
importpath = "github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/repository/memory",
visibility = ["//visibility:public"],
deps = [
"//src/go/cmd/token-vendor/repository:go_default_library",
]
)

go_test(
Expand Down
10 changes: 8 additions & 2 deletions src/go/cmd/token-vendor/repository/memory/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package memory
import (
"context"
"log/slog"

"github.com/googlecloudrobotics/core/src/go/cmd/token-vendor/repository"
)

// MemoryRepository uses a in-memory datastructure to store the keys.
Expand All @@ -35,8 +37,12 @@ func (m *MemoryRepository) PublishKey(ctx context.Context, deviceID, publicKey s
return nil
}

func (m *MemoryRepository) LookupKey(ctx context.Context, deviceID string) (string, error) {
func (m *MemoryRepository) LookupKey(ctx context.Context, deviceID string) (*repository.Key, error) {
slog.Debug("LookupKey", slog.String("DeviceID", deviceID))
// key not found does not need to be an error
return m.keys[deviceID], nil
k, found := m.keys[deviceID]
if !found {
return nil, nil
}
return &repository.Key{k, ""}, nil
}
17 changes: 9 additions & 8 deletions src/go/cmd/token-vendor/repository/memory/memory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,18 @@ func TestMemoryBackend(t *testing.T) {
if err := m.PublishKey(context.TODO(), "b", "bkey"); err != nil {
t.Fatal(err)
}
var k string
if k, err = m.LookupKey(context.TODO(), "a"); err != nil {
k, err := m.LookupKey(context.TODO(), "a")
if err != nil {
t.Fatal(err)
}
if k != "akey" {
if k.PublicKey != "akey" {
t.Fatalf("Key for a: got %q, want %q", k, "akey")
}
if k, err = m.LookupKey(context.TODO(), "b"); err != nil {
k, err = m.LookupKey(context.TODO(), "b")
if err != nil {
t.Fatal(err)
}
if k != "bkey" {
if k.PublicKey != "bkey" {
t.Fatalf("Key for b: got %q, want %q", k, "bkey")
}
}
Expand All @@ -51,11 +52,11 @@ func TestMemoryNotFound(t *testing.T) {
if err != nil {
t.Fatal(err)
}
var k string
if k, err = m.LookupKey(context.TODO(), "a"); err != nil {
k, err := m.LookupKey(context.TODO(), "a")
if err != nil {
t.Fatal(err)
}
if k != "" {
if k != nil {
t.Fatalf("LookupKey: got %q, expected empty response", k)
}
}
37 changes: 37 additions & 0 deletions src/go/cmd/token-vendor/repository/repository.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright 2024 The Cloud Robotics Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package repository defines the api for the pub key stores
package repository

import (
"context"
)

// Key holds data + metadata of a public key entry
type Key struct {
// PublicKey contains the public key data
PublicKey string
// SAName is the optional GCP IAM service-account that has been associated.
SAName string
}

// PubKeyRepository defines the api for the pub key stores
type PubKeyRepository interface {
// LookupKey retrieves the public key of a device from the repository.
// An empty string return indicates that no key exists for the given identifier or
// that the device is blocked.
LookupKey(ctx context.Context, deviceID string) (*Key, error)
PublishKey(ctx context.Context, deviceID, publicKey string) error
}
Loading

0 comments on commit 19b76a6

Please sign in to comment.