diff --git a/pkg/runtime/backend.go b/pkg/runtime/backend.go index a086b9f..6e9c38b 100644 --- a/pkg/runtime/backend.go +++ b/pkg/runtime/backend.go @@ -4,7 +4,7 @@ import ( "context" "fmt" "strings" - "sync/atomic" + "sync" "time" "github.com/acorn-io/baaah/pkg/backend" @@ -24,7 +24,8 @@ type Backend struct { cacheFactory SharedControllerFactory cache cache.Cache - started atomic.Bool + startedLock *sync.RWMutex + started bool } func newBackend(cacheFactory SharedControllerFactory, client *cacheClient, cache cache.Cache) *Backend { @@ -32,13 +33,16 @@ func newBackend(cacheFactory SharedControllerFactory, client *cacheClient, cache cacheClient: client, cacheFactory: cacheFactory, cache: cache, + startedLock: new(sync.RWMutex), } } func (b *Backend) Start(ctx context.Context) (err error) { + b.startedLock.Lock() + defer b.startedLock.Unlock() defer func() { if err == nil { - b.started.Store(true) + b.started = true } }() if err := b.cacheFactory.Start(ctx, 5); err != nil { @@ -47,9 +51,10 @@ func (b *Backend) Start(ctx context.Context) (err error) { if !b.cache.WaitForCacheSync(ctx) { return fmt.Errorf("failed to wait for caches to sync") } - if !b.started.Load() { + if !b.started { b.cacheClient.startPurge(ctx) } + return nil } @@ -89,7 +94,7 @@ func (b *Backend) addIndexer(ctx context.Context, gvk schema.GroupVersionKind) e indexers := map[string]kcache.IndexFunc{} for _, field := range f.FieldNames() { field := field - indexers["field:"+field] = kcache.IndexFunc(func(obj interface{}) ([]string, error) { + indexers["field:"+field] = func(obj interface{}) ([]string, error) { f, ok := obj.(fields.Fields) if !ok { return nil, nil @@ -103,7 +108,7 @@ func (b *Backend) addIndexer(ctx context.Context, gvk schema.GroupVersionKind) e vals = append(vals, keyFunc(ko.GetNamespace(), v)) } return vals, nil - }) + } } return cache.AddIndexers(indexers) } @@ -121,7 +126,7 @@ func (b *Backend) Watch(ctx context.Context, gvk schema.GroupVersionKind, name s }) c.RegisterHandler(ctx, fmt.Sprintf("%s %v", name, gvk), handler) - if b.started.Load() { + if b.hasStarted() { return c.Start(ctx, 5) } return nil @@ -146,3 +151,9 @@ func (b *Backend) GetInformerForKind(ctx context.Context, gvk schema.GroupVersio } return i.(kcache.SharedIndexInformer), nil } + +func (b *Backend) hasStarted() bool { + b.startedLock.RLock() + defer b.startedLock.RUnlock() + return b.started +}