Skip to content

Commit

Permalink
fetch SAs from apiserver
Browse files Browse the repository at this point in the history
  • Loading branch information
modulitos committed Nov 14, 2024
1 parent feac6cc commit 34280cc
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 29 deletions.
1 change: 1 addition & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ func main() {
saInformer,
cmInformer,
composeRoleArnCache,
clientset.CoreV1(),
)
stop := make(chan struct{})
informerFactory.Start(stop)
Expand Down
72 changes: 52 additions & 20 deletions pkg/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,23 @@
package cache

import (
"context"
"encoding/json"
"fmt"
"regexp"
"strconv"
"strings"
"sync"
"time"

"github.com/aws/amazon-eks-pod-identity-webhook/pkg"
"github.com/prometheus/client_golang/prometheus"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
coreinformers "k8s.io/client-go/informers/core/v1"
"k8s.io/client-go/kubernetes"
corev1 "k8s.io/client-go/kubernetes/typed/core/v1"
"k8s.io/client-go/tools/cache"
"k8s.io/klog/v2"
)
Expand Down Expand Up @@ -80,8 +84,7 @@ type serviceAccountCache struct {
composeRoleArn ComposeRoleArn
defaultTokenExpiration int64
webhookUsage prometheus.Gauge
notificationHandlers map[string]chan struct{}
handlerMu sync.Mutex
notifications *notifications
}

type ComposeRoleArn struct {
Expand Down Expand Up @@ -156,20 +159,13 @@ func (c *serviceAccountCache) GetCommonConfigurations(name, namespace string) (u
return false, pkg.DefaultTokenExpiration
}

func (c *serviceAccountCache) getSA(req Request) (*Entry, chan struct{}) {
func (c *serviceAccountCache) getSA(req Request) (*Entry, <-chan struct{}) {
c.mu.RLock()
defer c.mu.RUnlock()
entry, ok := c.saCache[req.CacheKey()]
if !ok && req.RequestNotification {
klog.V(5).Infof("Service Account %s not found in cache, adding notification handler", req.CacheKey())
c.handlerMu.Lock()
defer c.handlerMu.Unlock()
notifier, found := c.notificationHandlers[req.CacheKey()]
if !found {
notifier = make(chan struct{})
c.notificationHandlers[req.CacheKey()] = notifier
}
return nil, notifier
return nil, c.notifications.create(req)
}
return entry, nil
}
Expand Down Expand Up @@ -264,13 +260,7 @@ func (c *serviceAccountCache) setSA(name, namespace string, entry *Entry) {
klog.V(5).Infof("Adding SA %q to SA cache: %+v", key, entry)
c.saCache[key] = entry

c.handlerMu.Lock()
defer c.handlerMu.Unlock()
if handler, found := c.notificationHandlers[key]; found {
klog.V(5).Infof("Notifying handlers for %q", key)
close(handler)
delete(c.notificationHandlers, key)
}
c.notifications.broadcast(key)
}

func (c *serviceAccountCache) setCM(name, namespace string, entry *Entry) {
Expand All @@ -280,7 +270,15 @@ func (c *serviceAccountCache) setCM(name, namespace string, entry *Entry) {
c.cmCache[namespace+"/"+name] = entry
}

func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenExpiration int64, saInformer coreinformers.ServiceAccountInformer, cmInformer coreinformers.ConfigMapInformer, composeRoleArn ComposeRoleArn) ServiceAccountCache {
func New(defaultAudience,
prefix string,
defaultRegionalSTS bool,
defaultTokenExpiration int64,
saInformer coreinformers.ServiceAccountInformer,
cmInformer coreinformers.ConfigMapInformer,
composeRoleArn ComposeRoleArn,
SAGetter corev1.ServiceAccountsGetter,
) ServiceAccountCache {
hasSynced := func() bool {
if cmInformer != nil {
return saInformer.Informer().HasSynced() && cmInformer.Informer().HasSynced()
Expand All @@ -289,6 +287,8 @@ func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenEx
}
}

// Rate limit to 10 concurrent requests against the API server.
saFetchRequests := make(chan *Request, 10)
c := &serviceAccountCache{
saCache: map[string]*Entry{},
cmCache: map[string]*Entry{},
Expand All @@ -299,9 +299,20 @@ func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenEx
defaultTokenExpiration: defaultTokenExpiration,
hasSynced: hasSynced,
webhookUsage: webhookUsage,
notificationHandlers: map[string]chan struct{}{},
notifications: newNotifications(saFetchRequests),
}

go func() {
for req := range saFetchRequests {
sa, err := fetchFromAPI(SAGetter, req)
if err != nil {
klog.Errorf("fetching SA: %s, but got error from API: %v", req.CacheKey(), err)
continue
}
c.addSA(sa)
}
}()

saInformer.Informer().AddEventHandler(
cache.ResourceEventHandlerFuncs{
AddFunc: func(obj interface{}) {
Expand Down Expand Up @@ -351,6 +362,27 @@ func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenEx
return c
}

func fetchFromAPI(getter corev1.ServiceAccountsGetter, req *Request) (*v1.ServiceAccount, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
saList, err := getter.ServiceAccounts(req.Namespace).List(
ctx,
metav1.ListOptions{},
)
if err != nil {
return nil, err
}

// Find the ServiceAccount
for _, sa := range saList.Items {
if sa.Name == req.Name {
return &sa, nil

}
}
return nil, fmt.Errorf("no SA found in namespace: %s", req.CacheKey())
}

func (c *serviceAccountCache) populateCacheFromCM(oldCM, newCM *v1.ConfigMap) error {
if newCM.Name != "pod-identity-webhook" {
return nil
Expand Down
97 changes: 88 additions & 9 deletions pkg/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ func TestSaCache(t *testing.T) {
defaultAudience: "sts.amazonaws.com",
annotationPrefix: "eks.amazonaws.com",
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
notifications: newNotifications(make(chan *Request, 10)),
}

resp := cache.Get(Request{Name: "default", Namespace: "default"})
Expand Down Expand Up @@ -69,9 +70,9 @@ func TestNotification(t *testing.T) {

t.Run("with one notification handler", func(t *testing.T) {
cache := &serviceAccountCache{
saCache: map[string]*Entry{},
notificationHandlers: map[string]chan struct{}{},
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
saCache: map[string]*Entry{},
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
notifications: newNotifications(make(chan *Request, 10)),
}

// test that the requested SA is not in the cache
Expand Down Expand Up @@ -106,9 +107,9 @@ func TestNotification(t *testing.T) {

t.Run("with 10 notification handlers", func(t *testing.T) {
cache := &serviceAccountCache{
saCache: map[string]*Entry{},
notificationHandlers: map[string]chan struct{}{},
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
saCache: map[string]*Entry{},
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
notifications: newNotifications(make(chan *Request, 5)),
}

// test that the requested SA is not in the cache
Expand Down Expand Up @@ -153,6 +154,63 @@ func TestNotification(t *testing.T) {
})
}

func TestFetchFromAPIServer(t *testing.T) {
testSA := &v1.ServiceAccount{
ObjectMeta: metav1.ObjectMeta{
Name: "default",
Namespace: "default",
Annotations: map[string]string{
"eks.amazonaws.com/role-arn": "arn:aws:iam::111122223333:role/s3-reader",
"eks.amazonaws.com/token-expiration": "3600",
},
},
}
fakeSAClient := fake.NewSimpleClientset(testSA)

// use an empty informer to simulate the need to fetch SA from api server:
fakeEmptyClient := fake.NewSimpleClientset()
emptyInformerFactory := informers.NewSharedInformerFactory(fakeEmptyClient, 0)
emptyInformer := emptyInformerFactory.Core().V1().ServiceAccounts()

cache := New(
"sts.amazonaws.com",
"eks.amazonaws.com",
true,
86400,
emptyInformer,
nil,
ComposeRoleArn{},
fakeSAClient.CoreV1(),
)

stop := make(chan struct{})
emptyInformerFactory.Start(stop)
emptyInformerFactory.WaitForCacheSync(stop)
cache.Start(stop)
defer close(stop)

err := wait.ExponentialBackoff(wait.Backoff{Duration: 10 * time.Millisecond, Factor: 1.0, Steps: 3}, func() (bool, error) {
return len(fakeEmptyClient.Actions()) != 0, nil
})
if err != nil {
t.Fatalf("informer never called client: %v", err)
}

resp := cache.Get(Request{Name: "default", Namespace: "default", RequestNotification: true})
assert.False(t, resp.FoundInCache, "Expected cache entry to not be found")

// wait for the notification while we fetch the SA from the API server:
select {
case <-resp.Notifier:
// expected
// test that the requested SA is now in the cache
resp := cache.Get(Request{Name: "default", Namespace: "default", RequestNotification: false})
assert.True(t, resp.FoundInCache, "Expected cache entry to be found in cache")
case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for notification")
}
}

func TestNonRegionalSTS(t *testing.T) {
trueStr := "true"
falseStr := "false"
Expand Down Expand Up @@ -237,7 +295,16 @@ func TestNonRegionalSTS(t *testing.T) {

testComposeRoleArn := ComposeRoleArn{}

cache := New(audience, "eks.amazonaws.com", tc.defaultRegionalSTS, 86400, informer, nil, testComposeRoleArn)
cache := New(
audience,
"eks.amazonaws.com",
tc.defaultRegionalSTS,
86400,
informer,
nil,
testComposeRoleArn,
fakeClient.CoreV1(),
)
stop := make(chan struct{})
informerFactory.Start(stop)
informerFactory.WaitForCacheSync(stop)
Expand Down Expand Up @@ -295,7 +362,8 @@ func TestPopulateCacheFromCM(t *testing.T) {
}

c := serviceAccountCache{
cmCache: make(map[string]*Entry),
cmCache: make(map[string]*Entry),
notifications: newNotifications(make(chan *Request, 10)),
}

{
Expand Down Expand Up @@ -353,6 +421,7 @@ func TestSAAnnotationRemoval(t *testing.T) {
saCache: make(map[string]*Entry),
annotationPrefix: "eks.amazonaws.com",
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
notifications: newNotifications(make(chan *Request, 10)),
}

c.addSA(oldSA)
Expand Down Expand Up @@ -416,6 +485,7 @@ func TestCachePrecedence(t *testing.T) {
defaultTokenExpiration: pkg.DefaultTokenExpiration,
annotationPrefix: "eks.amazonaws.com",
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
notifications: newNotifications(make(chan *Request, 10)),
}

{
Expand Down Expand Up @@ -514,7 +584,15 @@ func TestRoleArnComposition(t *testing.T) {
informerFactory := informers.NewSharedInformerFactory(fakeClient, 0)
informer := informerFactory.Core().V1().ServiceAccounts()

cache := New(audience, "eks.amazonaws.com", true, 86400, informer, nil, testComposeRoleArn)
cache := New(audience,
"eks.amazonaws.com",
true,
86400,
informer,
nil,
testComposeRoleArn,
fakeClient.CoreV1(),
)
stop := make(chan struct{})
informerFactory.Start(stop)
informerFactory.WaitForCacheSync(stop)
Expand Down Expand Up @@ -613,6 +691,7 @@ func TestGetCommonConfigurations(t *testing.T) {
defaultAudience: "sts.amazonaws.com",
annotationPrefix: "eks.amazonaws.com",
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
notifications: newNotifications(make(chan *Request, 10)),
}

if tc.serviceAccount != nil {
Expand Down
60 changes: 60 additions & 0 deletions pkg/cache/notifications.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package cache

import (
"sync"

"k8s.io/klog/v2"

"github.com/prometheus/client_golang/prometheus"
)

var notificationUsage = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "pod_identity_cache_notifications",
Help: "Counter of SA notifications",
},
[]string{"method"},
)

func init() {
prometheus.MustRegister(notificationUsage)
}

type notifications struct {
handlers map[string]chan struct{}
mu sync.Mutex
fetchRequests chan<- *Request
}

func newNotifications(saFetchRequests chan<- *Request) *notifications {
return &notifications{
handlers: map[string]chan struct{}{},
fetchRequests: saFetchRequests,
}
}

func (n *notifications) create(req Request) <-chan struct{} {
n.mu.Lock()
defer n.mu.Unlock()

notificationUsage.WithLabelValues("used").Inc()
notifier, found := n.handlers[req.CacheKey()]
if !found {
notifier = make(chan struct{})
n.handlers[req.CacheKey()] = notifier
notificationUsage.WithLabelValues("created").Inc()
n.fetchRequests <- &req
}
return notifier
}

func (n *notifications) broadcast(key string) {
n.mu.Lock()
defer n.mu.Unlock()
if handler, found := n.handlers[key]; found {
klog.V(5).Infof("Notifying handlers for %q", key)
notificationUsage.WithLabelValues("broadcast").Inc()
close(handler)
delete(n.handlers, key)
}
}

0 comments on commit 34280cc

Please sign in to comment.