diff --git a/examples/scratch-env/go.mod b/examples/scratch-env/go.mod index 5d921c30ce..bd7fc50656 100644 --- a/examples/scratch-env/go.mod +++ b/examples/scratch-env/go.mod @@ -14,6 +14,7 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect github.com/evanphx/json-patch/v5 v5.9.0 // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/zapr v1.3.0 // indirect diff --git a/examples/scratch-env/go.sum b/examples/scratch-env/go.sum index 3c11d595c5..63a151e33f 100644 --- a/examples/scratch-env/go.sum +++ b/examples/scratch-env/go.sum @@ -13,6 +13,8 @@ github.com/evanphx/json-patch v0.5.2 h1:xVCHIVMUu1wtM/VkR9jVZ45N3FhZfYMMYGorLCR8 github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ= github.com/evanphx/json-patch/v5 v5.9.0 h1:kcBlZQbplgElYIlo/n1hJbls2z/1awpXxpRi0/FOJfg= github.com/evanphx/json-patch/v5 v5.9.0/go.mod h1:VNkHZ/282BpEyt/tObQO8s5CMPmYYq14uClGH4abBuQ= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= diff --git a/go.mod b/go.mod index 724eabfde7..ae141ccb72 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.23.0 require ( github.com/evanphx/json-patch/v5 v5.9.0 + github.com/fsnotify/fsnotify v1.7.0 github.com/go-logr/logr v1.4.2 github.com/go-logr/zapr v1.3.0 github.com/google/btree v1.1.3 @@ -42,7 +43,6 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/jsonpointer v0.21.0 // indirect diff --git a/pkg/certwatcher/certwatcher.go b/pkg/certwatcher/certwatcher.go index d295b29864..c323240982 100644 --- a/pkg/certwatcher/certwatcher.go +++ b/pkg/certwatcher/certwatcher.go @@ -20,10 +20,15 @@ import ( "bytes" "context" "crypto/tls" + "fmt" "os" "sync" "time" + "github.com/fsnotify/fsnotify" + kerrors "k8s.io/apimachinery/pkg/util/errors" + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/apimachinery/pkg/util/wait" "sigs.k8s.io/controller-runtime/pkg/certwatcher/metrics" logf "sigs.k8s.io/controller-runtime/pkg/internal/log" ) @@ -40,6 +45,7 @@ type CertWatcher struct { sync.RWMutex currentCert *tls.Certificate + watcher *fsnotify.Watcher interval time.Duration certPath string @@ -53,13 +59,25 @@ type CertWatcher struct { // New returns a new CertWatcher watching the given certificate and key. func New(certPath, keyPath string) (*CertWatcher, error) { + var err error + cw := &CertWatcher{ certPath: certPath, keyPath: keyPath, interval: defaultWatchInterval, } - return cw, cw.ReadCertificate() + // Initial read of certificate and key. + if err := cw.ReadCertificate(); err != nil { + return nil, err + } + + cw.watcher, err = fsnotify.NewWatcher() + if err != nil { + return nil, err + } + + return cw, nil } // WithWatchInterval sets the watch interval and returns the CertWatcher pointer @@ -88,14 +106,35 @@ func (cw *CertWatcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, // Start starts the watch on the certificate and key files. func (cw *CertWatcher) Start(ctx context.Context) error { + files := sets.New(cw.certPath, cw.keyPath) + + { + var watchErr error + if err := wait.PollUntilContextTimeout(ctx, 1*time.Second, 10*time.Second, true, func(ctx context.Context) (done bool, err error) { + for _, f := range files.UnsortedList() { + if err := cw.watcher.Add(f); err != nil { + watchErr = err + return false, nil //nolint:nilerr // We want to keep trying. + } + // We've added the watch, remove it from the set. + files.Delete(f) + } + return true, nil + }); err != nil { + return fmt.Errorf("failed to add watches: %w", kerrors.NewAggregate([]error{err, watchErr})) + } + } + + go cw.Watch() + ticker := time.NewTicker(cw.interval) defer ticker.Stop() - log.Info("Starting certificate watcher") + log.Info("Starting certificate poll+watcher", "interval", cw.interval) for { select { case <-ctx.Done(): - return nil + return cw.watcher.Close() case <-ticker.C: if err := cw.ReadCertificate(); err != nil { log.Error(err, "failed read certificate") @@ -104,6 +143,28 @@ func (cw *CertWatcher) Start(ctx context.Context) error { } } +// Watch reads events from the watcher's channel and reacts to changes. +func (cw *CertWatcher) Watch() { + for { + select { + case event, ok := <-cw.watcher.Events: + // Channel is closed. + if !ok { + return + } + + cw.handleEvent(event) + case err, ok := <-cw.watcher.Errors: + // Channel is closed. + if !ok { + return + } + + log.Error(err, "certificate watch error") + } + } +} + // updateCachedCertificate checks if the new certificate differs from the cache, // updates it and returns the result if it was updated or not func (cw *CertWatcher) updateCachedCertificate(cert *tls.Certificate, keyPEMBlock []byte) bool { @@ -159,3 +220,23 @@ func (cw *CertWatcher) ReadCertificate() error { } return nil } + +func (cw *CertWatcher) handleEvent(event fsnotify.Event) { + // Only care about events which may modify the contents of the file. + switch { + case event.Op.Has(fsnotify.Write): + case event.Op.Has(fsnotify.Create): + case event.Op.Has(fsnotify.Chmod), event.Op.Has(fsnotify.Remove): + // If the file was removed or renamed, re-add the watch to the previous name + if err := cw.watcher.Add(event.Name); err != nil { + log.Error(err, "error re-watching file") + } + default: + return + } + + log.V(1).Info("certificate event", "event", event) + if err := cw.ReadCertificate(); err != nil { + log.Error(err, "error re-reading certificate") + } +} diff --git a/pkg/certwatcher/certwatcher_test.go b/pkg/certwatcher/certwatcher_test.go index f3388f096e..b8018dbdc0 100644 --- a/pkg/certwatcher/certwatcher_test.go +++ b/pkg/certwatcher/certwatcher_test.go @@ -76,12 +76,12 @@ var _ = Describe("CertWatcher", func() { Expect(err).ToNot(HaveOccurred()) }) - startWatcher := func() (done <-chan struct{}) { + startWatcher := func(interval time.Duration) (done <-chan struct{}) { doneCh := make(chan struct{}) go func() { defer GinkgoRecover() defer close(doneCh) - Expect(watcher.WithWatchInterval(time.Second).Start(ctx)).To(Succeed()) + Expect(watcher.WithWatchInterval(interval).Start(ctx)).To(Succeed()) }() // wait till we read first cert Eventually(func() error { @@ -92,14 +92,16 @@ var _ = Describe("CertWatcher", func() { } It("should read the initial cert/key", func() { - doneCh := startWatcher() + // This test verifies the initial read succeeded. So interval doesn't matter. + doneCh := startWatcher(10 * time.Second) ctxCancel() Eventually(doneCh, "4s").Should(BeClosed()) }) It("should reload currentCert when changed", func() { - doneCh := startWatcher() + // This test verifies fsnotify detects the cert change. So interval doesn't matter. + doneCh := startWatcher(10 * time.Second) called := atomic.Int64{} watcher.RegisterCallback(func(crt tls.Certificate) { called.Add(1) @@ -123,7 +125,8 @@ var _ = Describe("CertWatcher", func() { }) It("should reload currentCert when changed with rename", func() { - doneCh := startWatcher() + // This test verifies fsnotify detects the cert change. So interval doesn't matter. + doneCh := startWatcher(10 * time.Second) called := atomic.Int64{} watcher.RegisterCallback(func(crt tls.Certificate) { called.Add(1) @@ -153,7 +156,8 @@ var _ = Describe("CertWatcher", func() { }) It("should reload currentCert after move out", func() { - doneCh := startWatcher() + // This test verifies poll works, so we'll use 1s as interval (fsnotify doesn't detect this change). + doneCh := startWatcher(1 * time.Second) called := atomic.Int64{} watcher.RegisterCallback(func(crt tls.Certificate) { called.Add(1) @@ -189,7 +193,8 @@ var _ = Describe("CertWatcher", func() { }) It("should get updated on successful certificate read", func() { - doneCh := startWatcher() + // This test verifies fsnotify, so interval doesn't matter. + doneCh := startWatcher(10 * time.Second) Eventually(func() error { readCertificateTotalAfter := testutil.ToFloat64(metrics.ReadCertificateTotal) @@ -204,7 +209,8 @@ var _ = Describe("CertWatcher", func() { }) It("should get updated on read certificate errors", func() { - doneCh := startWatcher() + // This test works with fsnotify, so interval doesn't matter. + doneCh := startWatcher(10 * time.Second) Eventually(func() error { readCertificateTotalAfter := testutil.ToFloat64(metrics.ReadCertificateTotal)