diff --git a/fingerproxy.go b/fingerproxy.go index ebd48e8..ed28567 100644 --- a/fingerproxy.go +++ b/fingerproxy.go @@ -14,6 +14,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/wi1dcard/fingerproxy/pkg/certwatcher" "github.com/wi1dcard/fingerproxy/pkg/debug" "github.com/wi1dcard/fingerproxy/pkg/fingerprint" "github.com/wi1dcard/fingerproxy/pkg/proxyserver" @@ -36,6 +37,7 @@ var ( PrometheusLog = log.New(os.Stderr, "[metrics] ", logFlags) ReverseProxyLog = log.New(os.Stderr, "[reverseproxy] ", logFlags) FingerprintLog = log.New(os.Stderr, "[fingerprint] ", logFlags) + CertWatcherLog = log.New(os.Stderr, "[certwatcher] ", logFlags) DefaultLog = log.New(os.Stderr, "[fingerproxy] ", logFlags) // The Prometheus metric registry used by fingerproxy @@ -90,8 +92,7 @@ func defaultReverseProxyHTTPHandler(forwardTo *url.URL, headerInjectors []revers return handler } -func defaultProxyServer(handler http.Handler, tlsConfig *tls.Config) *proxyserver.Server { - ctx, _ := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) +func defaultProxyServer(ctx context.Context, handler http.Handler, tlsConfig *tls.Config) *proxyserver.Server { svr := proxyserver.NewServer(ctx, handler, tlsConfig) svr.VerboseLogs = *flagVerboseLogs @@ -108,6 +109,25 @@ func defaultProxyServer(handler http.Handler, tlsConfig *tls.Config) *proxyserve return svr } +func initCertWatcher() *certwatcher.CertWatcher { + certwatcher.Logger = CertWatcherLog + certwatcher.VerboseLogs = *flagVerboseLogs + cw, err := certwatcher.New(*flagCertFilename, *flagKeyFilename) + if err != nil { + DefaultLog.Fatalf(`invalid cert filename "%s" or certkey filename "%s": %s`, *flagCertFilename, *flagKeyFilename, err) + } + return cw +} + +func defaultTLSConfig(cw *certwatcher.CertWatcher) *tls.Config { + return &tls.Config{ + NextProtos: []string{"h2", "http/1.1"}, + MinVersion: tls.VersionTLS12, + MaxVersion: tls.VersionTLS13, + GetCertificate: cw.GetCertificate, + } +} + func initFingerprint() { fingerprint.Logger = FingerprintLog fingerprint.VerboseLogs = *flagVerboseLogs @@ -124,20 +144,25 @@ func Run() { // fingerprint package initFingerprint() + // tls cert watcher + cw := initCertWatcher() + + // signal cancels context + ctx, _ := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + // main TLS server server := defaultProxyServer( + ctx, defaultReverseProxyHTTPHandler( parseForwardURL(), GetHeaderInjectors(), ), - &tls.Config{ - NextProtos: []string{"h2", "http/1.1"}, - MinVersion: tls.VersionTLS12, - MaxVersion: tls.VersionTLS13, - Certificates: []tls.Certificate{parseTLSCerts()}, - }, + defaultTLSConfig(cw), ) + // start cert watcher + go cw.Start(ctx) + // metrics server PrometheusLog.Printf("server listening on %s", *flagMetricsListenAddr) go http.ListenAndServe( diff --git a/flags.go b/flags.go index 9d39cb0..00907d3 100644 --- a/flags.go +++ b/flags.go @@ -1,7 +1,6 @@ package fingerproxy import ( - "crypto/tls" "flag" "fmt" "net/url" @@ -141,14 +140,6 @@ func parseForwardURL() *url.URL { return forwardURL } -func parseTLSCerts() tls.Certificate { - tlsCert, err := tls.LoadX509KeyPair(*flagCertFilename, *flagKeyFilename) - if err != nil { - DefaultLog.Fatalf(`invalid cert filename "%s" or certkey filename "%s": %s`, *flagCertFilename, *flagKeyFilename, err) - } - return tlsCert -} - func parseDurationMetricBuckets() []float64 { bucketStrings := strings.Split(*flagDurationMetricBuckets, ",") buckets := []float64{} diff --git a/go.mod b/go.mod index b132721..a3c39f2 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/cloudflare/circl v1.3.7 // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/google/pprof v0.0.0-20231212022811-ec68065c825e // indirect github.com/klauspost/compress v1.17.4 // indirect diff --git a/go.sum b/go.sum index 20add7a..5921505 100644 --- a/go.sum +++ b/go.sum @@ -11,6 +11,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dreadl0ck/tlsx v1.0.1-google-gopacket h1:/P3y+CGRiCQbW0nZU2jWkEwKfXLkpEgHNhbbqlnrTTM= github.com/dreadl0ck/tlsx v1.0.1-google-gopacket/go.mod h1:amAb73WEEgPHWniMfwro6UpN6St3e5ypgq2tXM89IOo= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/go-logr/logr v1.3.0 h1:2y3SDp0ZXuc6/cjLSZ+Q3ir+QB9T/iG5yYRXqsagWSY= github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= diff --git a/pkg/certwatcher/certwatcher.go b/pkg/certwatcher/certwatcher.go new file mode 100644 index 0000000..35f9b9e --- /dev/null +++ b/pkg/certwatcher/certwatcher.go @@ -0,0 +1,179 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package certwatcher + +import ( + "context" + "crypto/tls" + "log" + "sync" + + "github.com/fsnotify/fsnotify" +) + +var ( + VerboseLogs bool + Logger *log.Logger +) + +func logf(format string, args ...any) { + if Logger != nil { + Logger.Printf(format, args...) + } else { + log.Printf(format, args...) + } +} + +func vlogf(format string, args ...any) { + if VerboseLogs { + logf(format, args...) + } +} + +// CertWatcher watches certificate and key files for changes. When either file +// changes, it reads and parses both and calls an optional callback with the new +// certificate. +type CertWatcher struct { + sync.RWMutex + + currentCert *tls.Certificate + watcher *fsnotify.Watcher + + certPath string + keyPath string +} + +// New returns a new CertWatcher watching the given certificate and key. +func New(certPath, keyPath string) (*CertWatcher, error) { + var err error + + cw := &CertWatcher{ + certPath: certPath, + keyPath: keyPath, + } + + // Initial read of certificate and key. + if err := cw.ReadCertificate(); err != nil { + return nil, err + } + + cw.watcher, err = fsnotify.NewWatcher() + if err != nil { + return nil, err + } + + return cw, nil +} + +// GetCertificate fetches the currently loaded certificate, which may be nil. +func (cw *CertWatcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { + cw.RLock() + defer cw.RUnlock() + return cw.currentCert, nil +} + +// Start starts the watch on the certificate and key files. +func (cw *CertWatcher) Start(ctx context.Context) error { + files := []string{cw.certPath, cw.keyPath} + + for _, f := range files { + if err := cw.watcher.Add(f); err != nil { + logf("error watching file: %s", err) + return err + } + } + + go cw.Watch() + + // Block until the context is done. + <-ctx.Done() + + return cw.watcher.Close() +} + +// Watch reads events from the watcher's channel and reacts to changes. +func (cw *CertWatcher) Watch() { + for { + select { + case event, ok := <-cw.watcher.Events: + // Channel is closed. + if !ok { + return + } + + cw.handleEvent(event) + + case err, ok := <-cw.watcher.Errors: + // Channel is closed. + if !ok { + return + } + + logf("certificate watch error: %s", err) + } + } +} + +// ReadCertificate reads the certificate and key files from disk, parses them, +// and updates the current certificate on the watcher. If a callback is set, it +// is invoked with the new certificate. +func (cw *CertWatcher) ReadCertificate() error { + cert, err := tls.LoadX509KeyPair(cw.certPath, cw.keyPath) + if err != nil { + return err + } + + cw.Lock() + cw.currentCert = &cert + cw.Unlock() + + vlogf("updated current TLS certificate") + + return nil +} + +func (cw *CertWatcher) handleEvent(event fsnotify.Event) { + // Only care about events which may modify the contents of the file. + if !(isWrite(event) || isRemove(event) || isCreate(event)) { + return + } + + vlogf("certificate event: %s", event) + + // If the file was removed, re-add the watch. + if isRemove(event) { + if err := cw.watcher.Add(event.Name); err != nil { + logf("error re-watching file: %s", err) + } + } + + if err := cw.ReadCertificate(); err != nil { + logf("error re-reading certificate: %s", err) + } +} + +func isWrite(event fsnotify.Event) bool { + return event.Op&fsnotify.Write == fsnotify.Write +} + +func isCreate(event fsnotify.Event) bool { + return event.Op&fsnotify.Create == fsnotify.Create +} + +func isRemove(event fsnotify.Event) bool { + return event.Op&fsnotify.Remove == fsnotify.Remove +}