From a8b6193b339071c0b09aadeafba27013a1dae2dc Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sat, 30 Dec 2023 09:26:56 -0600 Subject: [PATCH 1/6] Migrate the InitServer routine to utilize the bitmask Avoid any concept of a "primary" or "current" server -- all should be configured by `InitServer`. --- cmd/cache.go | 2 +- cmd/director.go | 2 +- cmd/origin.go | 2 +- cmd/origin_reset_password_test.go | 2 +- cmd/registry.go | 2 +- config/config.go | 38 +++++++++++++++++++--------- config/config_test.go | 13 +++++++--- registry/client_commands_test.go | 2 +- server_ui/register_namespace_test.go | 2 +- web_ui/ui_test.go | 12 ++++----- xrootd/authorization_test.go | 6 ++--- xrootd/origin_test.go | 2 +- 12 files changed, 52 insertions(+), 33 deletions(-) diff --git a/cmd/cache.go b/cmd/cache.go index 976ccf280..3eae3eee2 100644 --- a/cmd/cache.go +++ b/cmd/cache.go @@ -43,7 +43,7 @@ var ( ) func initCache() error { - err := config.InitServer([]config.ServerType{config.CacheType}, config.CacheType) + err := config.InitServer(config.CacheType) cobra.CheckErr(err) metrics.SetComponentHealthStatus(metrics.OriginCache_XRootD, metrics.StatusCritical, "xrootd has not been started") metrics.SetComponentHealthStatus(metrics.OriginCache_CMSD, metrics.StatusCritical, "cmsd has not been started") diff --git a/cmd/director.go b/cmd/director.go index 3ec7420f4..d8d1c052b 100644 --- a/cmd/director.go +++ b/cmd/director.go @@ -71,7 +71,7 @@ func getDirectorEndpoint() (string, error) { } func initDirector() error { - err := config.InitServer([]config.ServerType{config.DirectorType}, config.DirectorType) + err := config.InitServer(config.DirectorType) cobra.CheckErr(err) return err diff --git a/cmd/origin.go b/cmd/origin.go index 75bd1458d..331ba94fc 100644 --- a/cmd/origin.go +++ b/cmd/origin.go @@ -99,7 +99,7 @@ func configOrigin( /*cmd*/ *cobra.Command /*args*/, []string) { } func initOrigin() error { - err := config.InitServer([]config.ServerType{config.OriginType}, config.OriginType) + err := config.InitServer(config.OriginType) cobra.CheckErr(err) metrics.SetComponentHealthStatus(metrics.OriginCache_XRootD, metrics.StatusCritical, "xrootd has not been started") metrics.SetComponentHealthStatus(metrics.OriginCache_CMSD, metrics.StatusCritical, "cmsd has not been started") diff --git a/cmd/origin_reset_password_test.go b/cmd/origin_reset_password_test.go index bcde6a3e2..ccecb16f8 100644 --- a/cmd/origin_reset_password_test.go +++ b/cmd/origin_reset_password_test.go @@ -36,7 +36,7 @@ func TestResetPassword(t *testing.T) { dirName := t.TempDir() viper.Reset() viper.Set("ConfigDir", dirName) - err := config.InitServer([]config.ServerType{config.OriginType}, config.OriginType) + err := config.InitServer(config.OriginType) require.NoError(t, err) rootCmd.SetArgs([]string{"origin", "web-ui", "reset-password", "--stdin"}) diff --git a/cmd/registry.go b/cmd/registry.go index 18e3e0d18..6a30a31f3 100644 --- a/cmd/registry.go +++ b/cmd/registry.go @@ -56,7 +56,7 @@ var ( ) func initRegistry() error { - err := config.InitServer([]config.ServerType{config.RegistryType}, config.RegistryType) + err := config.InitServer(config.RegistryType) cobra.CheckErr(err) return err diff --git a/config/config.go b/config/config.go index a5ab56de1..81aeaf3cc 100644 --- a/config/config.go +++ b/config/config.go @@ -123,23 +123,34 @@ var ( ) // Set sets a list of newServers to ServerType instance -func (sType *ServerType) Set(newServers []ServerType) { +func (sType *ServerType) SetList(newServers []ServerType) { for _, server := range newServers { *sType |= server } } +// Enable a single server type in the bitmask +func (sType *ServerType) Set(server ServerType) ServerType { + *sType |= server + return *sType +} + // IsEnabled checks if a testServer is in the ServerType instance func (sType ServerType) IsEnabled(testServer ServerType) bool { return sType&testServer == testServer } +// Clear all values in a server type +func (sType *ServerType) Clear() { + *sType = ServerType(0) +} + // setEnabledServer sets the global variable config.EnabledServers to include newServers. // Since this function should only be called in config package, we mark it "private" to avoid // reset value in other pacakge // // This will only be called once in a single process -func setEnabledServer(newServers []ServerType) { +func setEnabledServer(newServers ServerType) { setServerOnce.Do(func() { // For each process, we only want to set enabled servers once enabledServers.Set(newServers) @@ -376,7 +387,7 @@ func parseServerIssuerURL(sType ServerType) error { return errors.New("If Server.IssuerHostname is configured, you must provide a valid port") } - if sType == OriginType { + if sType.IsEnabled(OriginType) { // If Origin.Mode is set to anything that isn't "posix" or "", assume we're running a plugin and // that the origin's issuer URL actually uses the same port as OriginUI instead of XRootD. This is // because under that condition, keys are being served by the Pelican process instead of by XRootD @@ -505,20 +516,23 @@ func initConfigDir() error { return nil } -// Initialize Pelican server instance. Pass a list of "enabledServers" if you want to enable multiple servers, -// and pass your "current" server to instantiate through "currentServer" so that the functions -// knows which server it's being evoked for -func InitServer(enabledServers []ServerType, currentServer ServerType) error { +// Initialize Pelican server instance. Pass a list of `enabledServices` if you want to enable multiple services. +// Note not all configurations are supported: currently, if you enable both cache and origin then an error +// is thrown +func InitServer(enabledServices ServerType) error { if err := initConfigDir(); err != nil { return errors.Wrap(err, "Failed to initialize the server configuration") } + if enabledServices.IsEnabled(OriginType) && enabledServices.IsEnabled(CacheType) { + return errors.New("A cache and origin cannot both be enabled in the same instance") + } - setEnabledServer(enabledServers) + setEnabledServer(enabledServices) xrootdPrefix := "" - if currentServer == OriginType { + if enabledServices.IsEnabled(OriginType) { xrootdPrefix = "origin" - } else if currentServer == CacheType { + } else if enabledServices.IsEnabled(CacheType) { xrootdPrefix = "cache" } configDir := viper.GetString("ConfigDir") @@ -590,7 +604,7 @@ func InitServer(enabledServers []ServerType, currentServer ServerType) error { // they have overridden the defaults. hostname = viper.GetString("Server.Hostname") - if currentServer == CacheType { + if enabledServices.IsEnabled(CacheType) { viper.Set("Xrootd.Port", param.Cache_Port.GetInt()) } xrootdPort := param.Xrootd_Port.GetInt() @@ -650,7 +664,7 @@ func InitServer(enabledServers []ServerType, currentServer ServerType) error { // Set up the server's issuer URL so we can access that data wherever we need to find keys and whatnot // This populates Server.IssuerUrl, and can be safely fetched using server_utils.GetServerIssuerURL() - err = parseServerIssuerURL(currentServer) + err = parseServerIssuerURL(enabledServices) if err != nil { return err } diff --git a/config/config_test.go b/config/config_test.go index 2d07660c9..4753c06ff 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -202,25 +202,30 @@ func TestEnabledServers(t *testing.T) { for _, server := range allServerTypes { enabledServers = 0 // We didn't call setEnabledServer as it will only set once per process - enabledServers.Set([]ServerType{server}) + enabledServers.SetList([]ServerType{server}) assert.True(t, IsServerEnabled(server)) } }) t.Run("enable-multiple-servers", func(t *testing.T) { enabledServers = 0 - enabledServers.Set([]ServerType{OriginType, CacheType}) + enabledServers.SetList([]ServerType{OriginType, CacheType}) assert.True(t, IsServerEnabled(OriginType)) assert.True(t, IsServerEnabled(CacheType)) }) t.Run("setEnabledServer-only-set-once", func(t *testing.T) { enabledServers = 0 - setEnabledServer([]ServerType{OriginType, CacheType}) + sType := OriginType + sType.Set(CacheType) + setEnabledServer(sType) assert.True(t, IsServerEnabled(OriginType)) assert.True(t, IsServerEnabled(CacheType)) - setEnabledServer([]ServerType{DirectorType, RegistryType}) + sType.Clear() + sType.Set(DirectorType) + sType.Set(RegistryType) + setEnabledServer(sType) assert.True(t, IsServerEnabled(OriginType)) assert.True(t, IsServerEnabled(CacheType)) assert.False(t, IsServerEnabled(DirectorType)) diff --git a/registry/client_commands_test.go b/registry/client_commands_test.go index 785ac3b40..f87662752 100644 --- a/registry/client_commands_test.go +++ b/registry/client_commands_test.go @@ -38,7 +38,7 @@ func registryMockup(t *testing.T, testName string) *httptest.Server { ikey := filepath.Join(issuerTempDir, "issuer.jwk") viper.Set("IssuerKey", ikey) viper.Set("Registry.DbLocation", filepath.Join(issuerTempDir, "test.sql")) - err := config.InitServer([]config.ServerType{config.RegistryType}, config.RegistryType) + err := config.InitServer(config.RegistryType) require.NoError(t, err) err = InitializeDB() diff --git a/server_ui/register_namespace_test.go b/server_ui/register_namespace_test.go index ecd1f84cb..ba23f1ae6 100644 --- a/server_ui/register_namespace_test.go +++ b/server_ui/register_namespace_test.go @@ -55,7 +55,7 @@ func TestRegistration(t *testing.T) { config.InitConfig() viper.Set("Registry.DbLocation", filepath.Join(tempConfigDir, "test.sql")) - err := config.InitServer([]config.ServerType{config.OriginType}, config.OriginType) + err := config.InitServer(config.OriginType) require.NoError(t, err) err = registry.InitializeDB() diff --git a/web_ui/ui_test.go b/web_ui/ui_test.go index 38f542f47..80d92e315 100644 --- a/web_ui/ui_test.go +++ b/web_ui/ui_test.go @@ -71,7 +71,7 @@ func TestMain(m *testing.M) { // Ensure we load up the default configs. config.InitConfig() - if err := config.InitServer([]config.ServerType{config.OriginType}, config.OriginType); err != nil { + if err := config.InitServer(config.OriginType); err != nil { fmt.Println("Failed to configure the test module") os.Exit(1) } @@ -104,7 +104,7 @@ func TestWaitUntilLogin(t *testing.T) { viper.Reset() viper.Set("ConfigDir", dirName) config.InitConfig() - err := config.InitServer([]config.ServerType{config.OriginType}, config.OriginType) + err := config.InitServer(config.OriginType) require.NoError(t, err) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -151,7 +151,7 @@ func TestCodeBasedLogin(t *testing.T) { viper.Reset() viper.Set("ConfigDir", dirName) config.InitConfig() - err := config.InitServer([]config.ServerType{config.OriginType}, config.OriginType) + err := config.InitServer(config.OriginType) require.NoError(t, err) err = config.GeneratePrivateKey(param.IssuerKey.GetString(), elliptic.P256()) require.NoError(t, err) @@ -203,7 +203,7 @@ func TestPasswordResetAPI(t *testing.T) { viper.Reset() viper.Set("ConfigDir", dirName) viper.Set("Server.UIPasswordFile", tempPasswdFile.Name()) - err := config.InitServer([]config.ServerType{config.OriginType}, config.OriginType) + err := config.InitServer(config.OriginType) require.NoError(t, err) err = config.GeneratePrivateKey(param.IssuerKey.GetString(), elliptic.P256()) require.NoError(t, err) @@ -306,7 +306,7 @@ func TestPasswordBasedLoginAPI(t *testing.T) { viper.Reset() config.InitConfig() viper.Set("Server.UIPasswordFile", tempPasswdFile.Name()) - err := config.InitServer([]config.ServerType{config.OriginType}, config.OriginType) + err := config.InitServer(config.OriginType) require.NoError(t, err) ///////////////////////////SETUP/////////////////////////////////// @@ -420,7 +420,7 @@ func TestWhoamiAPI(t *testing.T) { config.InitConfig() viper.Set("ConfigDir", dirName) viper.Set("Server.UIPasswordFile", tempPasswdFile.Name()) - err := config.InitServer([]config.ServerType{config.OriginType}, config.OriginType) + err := config.InitServer(config.OriginType) require.NoError(t, err) err = config.GeneratePrivateKey(param.IssuerKey.GetString(), elliptic.P256()) require.NoError(t, err) diff --git a/xrootd/authorization_test.go b/xrootd/authorization_test.go index f19349ad6..43f44608d 100644 --- a/xrootd/authorization_test.go +++ b/xrootd/authorization_test.go @@ -131,7 +131,7 @@ func TestGenerateConfig(t *testing.T) { assert.Equal(t, issuer.Name, "") viper.Set("Origin.SelfTest", true) - err = config.InitServer([]config.ServerType{config.OriginType}, config.OriginType) + err = config.InitServer(config.OriginType) require.NoError(t, err) issuer, err = GenerateMonitoringIssuer() require.NoError(t, err) @@ -145,7 +145,7 @@ func TestGenerateConfig(t *testing.T) { viper.Set("Origin.SelfTest", false) viper.Set("Origin.ScitokensDefaultUser", "user1") viper.Set("Origin.ScitokensMapSubject", true) - err = config.InitServer([]config.ServerType{config.OriginType}, config.OriginType) + err = config.InitServer(config.OriginType) require.NoError(t, err) issuer, err = GenerateOriginIssuer([]string{"/foo/bar/baz", "/another/exported/path"}) require.NoError(t, err) @@ -237,7 +237,7 @@ func TestWriteOriginScitokensConfig(t *testing.T) { viper.Set("Xrootd.RunLocation", dirname) viper.Set("Xrootd.Port", 8443) viper.Set("Server.Hostname", "origin.example.com") - err := config.InitServer([]config.ServerType{config.OriginType}, config.OriginType) + err := config.InitServer(config.OriginType) require.Nil(t, err) scitokensCfg := param.Xrootd_ScitokensConfig.GetString() diff --git a/xrootd/origin_test.go b/xrootd/origin_test.go index ac1791b9b..a313899f1 100644 --- a/xrootd/origin_test.go +++ b/xrootd/origin_test.go @@ -70,7 +70,7 @@ func originMockup(t *testing.T) context.CancelFunc { // Increase the log level; otherwise, its difficult to debug failures viper.Set("Logging.Level", "Debug") config.InitConfig() - err = config.InitServer([]config.ServerType{config.OriginType}, config.OriginType) + err = config.InitServer(config.OriginType) require.NoError(t, err) err = config.GeneratePrivateKey(param.Server_TLSKey.GetString(), elliptic.P256()) From fd73edfb00d0a270f22eaee236ceea9136bf68f4 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sat, 30 Dec 2023 10:42:59 -0600 Subject: [PATCH 2/6] Add a periodic maintenance task for XRootD, copying over TLS certs This will periodically - or, when filesystem changes are noticed -- copy over the TLS certificates for use by XRootD. Note this combines both the cert chain and the key into a single file; otherwise, there is no way to update both atomically. --- cmd/cache_serve.go | 3 + cmd/origin_serve.go | 3 + go.mod | 2 +- xrootd/resources/xrootd-cache.cfg | 2 +- xrootd/resources/xrootd-origin.cfg | 2 +- xrootd/xrootd_config.go | 156 +++++++++++++++++++++++++++++ xrootd/xrootd_config_test.go | 94 +++++++++++++++++ 7 files changed, 259 insertions(+), 3 deletions(-) diff --git a/cmd/cache_serve.go b/cmd/cache_serve.go index fc3aa1a67..5338ac8a0 100644 --- a/cmd/cache_serve.go +++ b/cmd/cache_serve.go @@ -25,6 +25,7 @@ import ( "encoding/json" "net/url" "sync" + "time" "github.com/pelicanplatform/pelican/cache_ui" "github.com/pelicanplatform/pelican/config" @@ -133,6 +134,8 @@ func serveCache( /*cmd*/ *cobra.Command /*args*/, []string) error { return err } + xrootd.LaunchXrootdMaintenance(shutdownCtx, 2*time.Minute) + log.Info("Launching cache") launchers, err := xrootd.ConfigureLaunchers(false, configPath, false) if err != nil { diff --git a/cmd/origin_serve.go b/cmd/origin_serve.go index b60d13978..5c5d8800b 100644 --- a/cmd/origin_serve.go +++ b/cmd/origin_serve.go @@ -24,6 +24,7 @@ import ( "context" _ "embed" "sync" + "time" "github.com/pelicanplatform/pelican/config" "github.com/pelicanplatform/pelican/daemon" @@ -116,6 +117,8 @@ func serveOrigin( /*cmd*/ *cobra.Command /*args*/, []string) error { go origin_ui.PeriodicSelfTest() } + xrootd.LaunchXrootdMaintenance(shutdownCtx, 2*time.Minute) + privileged := param.Origin_Multiuser.GetBool() launchers, err := xrootd.ConfigureLaunchers(privileged, configPath, param.Origin_EnableCmsd.GetBool()) if err != nil { diff --git a/go.mod b/go.mod index d1e849096..446561d83 100644 --- a/go.mod +++ b/go.mod @@ -73,7 +73,7 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/edsrzf/mmap-go v1.1.0 // indirect github.com/felixge/httpsnoop v1.0.3 // indirect - github.com/fsnotify/fsnotify v1.6.0 // indirect + github.com/fsnotify/fsnotify v1.6.0 github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gin-contrib/sessions v0.0.5 github.com/gin-contrib/sse v0.1.0 // indirect diff --git a/xrootd/resources/xrootd-cache.cfg b/xrootd/resources/xrootd-cache.cfg index c9e1b7592..9917bb6f0 100644 --- a/xrootd/resources/xrootd-cache.cfg +++ b/xrootd/resources/xrootd-cache.cfg @@ -20,7 +20,7 @@ if exec xrootd fi ofs.osslib libXrdPss.so pss.cachelib libXrdPfc.so -xrd.tls {{.Server.TLSCertificate}} {{.Server.TLSKey}} +xrd.tls {{.Xrootd.RunLocation}}/copied-tls-creds.crt {{.Xrootd.RunLocation}}/copied-tls-creds.crt {{if .Server.TLSCACertificateDirectory}} xrd.tlsca certdir {{.Server.TLSCACertificateDirectory}} {{else}} diff --git a/xrootd/resources/xrootd-origin.cfg b/xrootd/resources/xrootd-origin.cfg index e7d257f38..4d4487f9a 100644 --- a/xrootd/resources/xrootd-origin.cfg +++ b/xrootd/resources/xrootd-origin.cfg @@ -22,7 +22,7 @@ if exec xrootd xrd.port {{.Xrootd.Port}} xrd.protocol http:{{.Xrootd.Port}} libXrdHttp.so fi -xrd.tls {{.Server.TLSCertificate}} {{.Server.TLSKey}} +xrd.tls {{.Xrootd.RunLocation}}/copied-tls-creds.crt {{.Xrootd.RunLocation}}/copied-tls-creds.crt {{if .Server.TLSCACertificateDirectory}} xrd.tlsca certdir {{.Server.TLSCACertificateDirectory}} {{else}} diff --git a/xrootd/xrootd_config.go b/xrootd/xrootd_config.go index 99fc10451..101a372da 100644 --- a/xrootd/xrootd_config.go +++ b/xrootd/xrootd_config.go @@ -1,21 +1,45 @@ +/*************************************************************** + * + * Copyright (C) 2023, Pelican Project, Morgridge Institute for Research + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You may + * obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ***************************************************************/ + package xrootd import ( "bytes" "context" "crypto/rand" + "crypto/tls" _ "embed" "encoding/base64" + builtin_errors "errors" "fmt" + "io" + "io/fs" "net/url" "os" "path" "path/filepath" + "reflect" "strings" "sync" "text/template" "time" + "github.com/fsnotify/fsnotify" "github.com/pelicanplatform/pelican/config" "github.com/pelicanplatform/pelican/director" "github.com/pelicanplatform/pelican/metrics" @@ -35,6 +59,8 @@ var ( xrootdCacheCfg string //go:embed resources/robots.txt robotsTxt string + + errBadKeyPair error = errors.New("Bad X509 keypair") ) type ( @@ -313,6 +339,10 @@ func CheckXrootdEnv(server server_utils.XRootDServer) error { } } + if err = CopyXrootdCertificates(); err != nil { + return err + } + if server.GetServerType().IsEnabled(config.OriginType) { exportPath, err = CheckOriginXrootdEnv(exportPath, uid, gid, groupname) } else { @@ -371,6 +401,132 @@ func CheckXrootdEnv(server server_utils.XRootDServer) error { return nil } +// Copies the server certificate/key files into the XRootD runtime +// directory. Combines the two files into a single one so the new +// certificate shows up atomically from XRootD's perspective. +// Adjusts the ownership and mode to match that expected +// by the XRootD framework. +func CopyXrootdCertificates() error { + user, err := config.GetDaemonUserInfo() + if err != nil { + return errors.Wrap(err, "Unable to copy certificates to xrootd runtime directory; failed xrootd user lookup") + } + + certFile := param.Server_TLSCertificate.GetString() + certKey := param.Server_TLSKey.GetString() + if _, err = tls.LoadX509KeyPair(certFile, certKey); err != nil { + return builtin_errors.Join(err, errBadKeyPair) + } + + destination := filepath.Join(param.Xrootd_RunLocation.GetString(), "copied-tls-creds.crt") + tmpName := destination + ".tmp" + destFile, err := os.OpenFile(tmpName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fs.FileMode(0400)) + if err != nil { + return errors.Wrap(err, "Failure when opening temporary certificate key pair file for xrootd") + } + defer destFile.Close() + + if err = os.Chown(tmpName, user.Uid, user.Gid); err != nil { + return errors.Wrap(err, "Failure when chown'ing certificate key pair file for xrootd") + } + + srcFile, err := os.Open(param.Server_TLSCertificate.GetString()) + if err != nil { + return errors.Wrap(err, "Failure when opening source certificate for xrootd") + } + defer srcFile.Close() + + if _, err = io.Copy(destFile, srcFile); err != nil { + return errors.Wrapf(err, "Failure when copying source certificate for xrootd") + } + + if _, err = destFile.Write([]byte{'\n', '\n'}); err != nil { + return errors.Wrap(err, "Failure when writing into copied key pair for xrootd") + } + + srcKeyFile, err := os.Open(param.Server_TLSKey.GetString()) + if err != nil { + return errors.Wrap(err, "Failure when opening source key for xrootd") + } + defer srcKeyFile.Close() + + if _, err = io.Copy(destFile, srcKeyFile); err != nil { + return errors.Wrapf(err, "Failure when copying source key for xrootd") + } + + if err = os.Rename(tmpName, destination); err != nil { + return errors.Wrapf(err, "Failure when moving key pair for xrootd") + } + + return nil +} + +// Launch a separate goroutine that performs the XRootD maintenance tasks. +// For maintenance that is periodic, `sleepTime` is the maintenance period. +func LaunchXrootdMaintenance(ctx context.Context, sleepTime time.Duration) { + select_count := 4 + watcher, err := fsnotify.NewWatcher() + if err != nil { + select_count -= 2 + } else if err = watcher.Add(filepath.Dir(param.Server_TLSCertificate.GetString())); err != nil { + select_count -= 2 + } + cases := make([]reflect.SelectCase, select_count) + ticker := time.NewTicker(sleepTime) + cases[0].Dir = reflect.SelectRecv + cases[0].Chan = reflect.ValueOf(ticker.C) + cases[1].Dir = reflect.SelectRecv + cases[1].Chan = reflect.ValueOf(ctx.Done()) + if err == nil { + cases[2].Dir = reflect.SelectRecv + cases[2].Chan = reflect.ValueOf(watcher.Events) + cases[3].Dir = reflect.SelectRecv + cases[3].Chan = reflect.ValueOf(watcher.Errors) + } + go func() { + defer watcher.Close() + for { + chosen, recv, ok := reflect.Select(cases) + if chosen == 0 { + if !ok { + log.Panicln("Ticker failed in the xrootd maintenance routine; exiting") + } + err := CopyXrootdCertificates() + if err != nil { + log.Warningln("Failed to update xrootd certificates during maintenance:", err) + } + } else if chosen == 1 { + log.Infoln("XRootD maintenance thread has been cancelled. Shutting down") + return + } else if chosen == 2 { // watcher.Events + if !ok { + log.Panicln("Watcher events failed in xrootd maintenance routine; exiting") + } + if event, ok := recv.Interface().(fsnotify.Event); ok { + log.Debugf("Got filesystem event (%v); will update the xrootd certificates", event) + if err = CopyXrootdCertificates(); errors.Is(err, errBadKeyPair) { + log.Debugln("Bad keypair encountered when doing xrootd certificate maintenance:", err) + } else if err != nil { + log.Warningf("Failed to update xrootd certificates based on file event %v: %v", event, err) + } + } else { + log.Panicln("Watcher returned an unknown event") + } + } else if chosen == 3 { // watcher.Errors + if !ok { + log.Panicln("Watcher error channel closed in xrootd maintenance routine; exiting") + } + if err, ok := recv.Interface().(error); ok { + log.Errorf("Watcher failure in the xrootd maintenance routine: %v", err) + } else { + log.Panicln("Watcher error channel has internal error; exiting") + } + time.Sleep(time.Second) + } + } + }() +} + func ConfigXrootd(origin bool) (string, error) { gid, err := config.GetDaemonGID() if err != nil { diff --git a/xrootd/xrootd_config_test.go b/xrootd/xrootd_config_test.go index e2fcdacd3..2bc6c6436 100644 --- a/xrootd/xrootd_config_test.go +++ b/xrootd/xrootd_config_test.go @@ -21,8 +21,16 @@ package xrootd import ( + "bytes" + "context" + "os" + "path/filepath" "testing" + "time" + "github.com/pelicanplatform/pelican/config" + "github.com/pelicanplatform/pelican/param" + "github.com/sirupsen/logrus" "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -45,3 +53,89 @@ func TestXrootDCacheConfig(t *testing.T) { require.NoError(t, err) assert.NotNil(t, configPath) } + +func TestCopyCertificates(t *testing.T) { + runDirname := t.TempDir() + configDirname := t.TempDir() + viper.Reset() + viper.Set("Logging.Level", "Debug") + viper.Set("Xrootd.RunLocation", runDirname) + viper.Set("ConfigDir", configDirname) + config.InitConfig() + + // First, invoke CopyXrootdCertificates directly, ensure it works. + err := CopyXrootdCertificates() + assert.ErrorIs(t, err, errBadKeyPair) + + err = config.InitServer(config.OriginType) + require.NoError(t, err) + err = CopyXrootdCertificates() + require.NoError(t, err) + destKeyPairName := filepath.Join(param.Xrootd_RunLocation.GetString(), "copied-tls-creds.crt") + assert.FileExists(t, destKeyPairName) + + keyPairContents, err := os.ReadFile(destKeyPairName) + require.NoError(t, err) + certName := param.Server_TLSCertificate.GetString() + firstCertContents, err := os.ReadFile(certName) + require.NoError(t, err) + keyName := param.Server_TLSKey.GetString() + firstKeyContents, err := os.ReadFile(keyName) + require.NoError(t, err) + firstKeyPairContents := append(firstCertContents, '\n', '\n') + firstKeyPairContents = append(firstKeyPairContents, firstKeyContents...) + assert.True(t, bytes.Equal(firstKeyPairContents, keyPairContents)) + + err = os.Rename(certName, certName+".orig") + require.NoError(t, err) + + err = CopyXrootdCertificates() + assert.ErrorIs(t, err, errBadKeyPair) + + err = os.Rename(keyName, keyName+".orig") + require.NoError(t, err) + + err = config.InitServer(config.OriginType) + require.NoError(t, err) + + err = CopyXrootdCertificates() + require.NoError(t, err) + + secondKeyPairContents, err := os.ReadFile(destKeyPairName) + require.NoError(t, err) + assert.False(t, bytes.Equal(firstKeyPairContents, secondKeyPairContents)) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + LaunchXrootdMaintenance(ctx, 2*time.Hour) + + // Helper function to wait for a copy of the first cert to show up + // in the destination + waitForCopy := func() bool { + for idx := 0; idx < 10; idx++ { + time.Sleep(50 * time.Millisecond) + logrus.Debug("Re-reading destination cert") + destContents, err := os.ReadFile(destKeyPairName) + require.NoError(t, err) + if bytes.Equal(destContents, firstKeyPairContents) { + return true + } + } + return false + } + + // The maintenance thread should only copy if there's a valid keypair + // Thus, if we only copy one, we shouldn't see any changes + err = os.Rename(certName+".orig", certName) + require.NoError(t, err) + logrus.Debug("Will wait to see if the new certs are not copied") + assert.False(t, waitForCopy()) + + // Now, if we overwrite the key, the maintenance thread should notice + // and overwrite the destination + err = os.Rename(keyName+".orig", keyName) + require.NoError(t, err) + logrus.Debug("Will wait to see if the new certs are copied") + assert.True(t, waitForCopy()) + +} From fbbd6269b8b6ae8439b032c424f2c28d477cc99c Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sat, 30 Dec 2023 10:45:52 -0600 Subject: [PATCH 3/6] Add missing for-loop to the CA bundle update thread --- utils/ca_utils.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/utils/ca_utils.go b/utils/ca_utils.go index 58fd5aae2..d6921d6e7 100644 --- a/utils/ca_utils.go +++ b/utils/ca_utils.go @@ -89,10 +89,12 @@ func PeriodicWriteCABundle(filename string, sleepTime time.Duration) (count int, } go func() { - time.Sleep(sleepTime) - _, err := WriteCABundle(filename) - if err != nil { - log.Warningln("Failure during periodic CA bundle update:", err) + for { + time.Sleep(sleepTime) + _, err := WriteCABundle(filename) + if err != nil { + log.Warningln("Failure during periodic CA bundle update:", err) + } } }() From 293048c87cc1dec7106ad158336dc7f59aef0822 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sat, 30 Dec 2023 11:33:14 -0600 Subject: [PATCH 4/6] Refactor the watcher routine to be separate from xrootd Will allow us to utilize the same approach to maintain the golang TLS certificate. --- server_utils/server_utils.go | 78 ++++++++++++++++++++++++++++++++++ xrootd/xrootd_config.go | 81 +++++++----------------------------- 2 files changed, 94 insertions(+), 65 deletions(-) diff --git a/server_utils/server_utils.go b/server_utils/server_utils.go index 2cc3ccfc8..2e2bf67b1 100644 --- a/server_utils/server_utils.go +++ b/server_utils/server_utils.go @@ -19,10 +19,15 @@ package server_utils import ( + "context" "net/url" + "reflect" + "time" + "github.com/fsnotify/fsnotify" "github.com/pelicanplatform/pelican/param" "github.com/pkg/errors" + log "github.com/sirupsen/logrus" ) // For calling from within the server. Returns the server's issuer URL/port @@ -38,3 +43,76 @@ func GetServerIssuerURL() (*url.URL, error) { return issuerUrl, nil } + +// Launch a maintenance goroutine. +// The maintenance routine will watch the directory `dirPath`, invoking `maintenanceFunc` whenever +// an event occurs in the directory. Note the behavior of directory watching differs across platforms; +// for example, an atomic rename might be one or two events for the destination file depending on Mac OS X or Linux. +// +// Even if the filesystem watcher fails, this will invoke `maintenanceFunc` every `sleepTime` duration. +// The maintenance function will be called with `true` if invoked due to a directory change, false otherwise +// When generating error messages, `description` will be used to describe the task. +func LaunchWatcherMaintenance(ctx context.Context, dirPath string, description string, sleepTime time.Duration, maintenanceFunc func(notifyEvent bool) error) { + select_count := 4 + watcher, err := fsnotify.NewWatcher() + if err != nil { + log.Warningf("%s routine failed to create new watcher", description) + select_count -= 2 + } else if err = watcher.Add(dirPath); err != nil { + log.Warningf("%s routine failed to add directory %s to watch: %v", description, dirPath, err) + select_count -= 2 + } + cases := make([]reflect.SelectCase, select_count) + ticker := time.NewTicker(sleepTime) + cases[0].Dir = reflect.SelectRecv + cases[0].Chan = reflect.ValueOf(ticker.C) + cases[1].Dir = reflect.SelectRecv + cases[1].Chan = reflect.ValueOf(ctx.Done()) + if err == nil { + cases[2].Dir = reflect.SelectRecv + cases[2].Chan = reflect.ValueOf(watcher.Events) + cases[3].Dir = reflect.SelectRecv + cases[3].Chan = reflect.ValueOf(watcher.Errors) + } + go func() { + defer watcher.Close() + for { + chosen, recv, ok := reflect.Select(cases) + if chosen == 0 { + if !ok { + log.Panicf("Ticker failed in the %s routine; exiting", description) + } + err := maintenanceFunc(false) + if err != nil { + log.Warningf("Failure during %s routine: %v", description, err) + } + } else if chosen == 1 { + log.Infof("%s routine has been cancelled. Shutting down", description) + return + } else if chosen == 2 { // watcher.Events + if !ok { + log.Panicf("Watcher events failed in %s routine; exiting", description) + } + if event, ok := recv.Interface().(fsnotify.Event); ok { + log.Debugf("Got filesystem event (%v); will run %s", event, description) + err := maintenanceFunc(true) + if err != nil { + log.Warningf("Failure during %s routine: %v", description, err) + } + } else { + log.Panicln("Watcher returned an unknown event") + } + } else if chosen == 3 { // watcher.Errors + if !ok { + log.Panicf("Watcher error channel closed in %s routine; exiting", description) + } + if err, ok := recv.Interface().(error); ok { + log.Errorf("Watcher failure in the %s routine: %v", description, err) + } else { + log.Panicln("Watcher error channel has internal error; exiting") + } + time.Sleep(time.Second) + } + } + }() +} diff --git a/xrootd/xrootd_config.go b/xrootd/xrootd_config.go index 101a372da..9ceedf763 100644 --- a/xrootd/xrootd_config.go +++ b/xrootd/xrootd_config.go @@ -33,13 +33,11 @@ import ( "os" "path" "path/filepath" - "reflect" "strings" "sync" "text/template" "time" - "github.com/fsnotify/fsnotify" "github.com/pelicanplatform/pelican/config" "github.com/pelicanplatform/pelican/director" "github.com/pelicanplatform/pelican/metrics" @@ -141,10 +139,10 @@ func CheckOriginXrootdEnv(exportPath string, uid int, gid int, groupname string) } volumeMountDst = filepath.Clean(volumeMountDst) if volumeMountDst == "" { - return exportPath, fmt.Errorf("Export volume %v has empty destination path", volumeMount) + return exportPath, fmt.Errorf("export volume %v has empty destination path", volumeMount) } if volumeMountDst[0:1] != "/" { - return "", fmt.Errorf("Export volume %v has a relative destination path", + return "", fmt.Errorf("export volume %v has a relative destination path", volumeMountDst) } destPath := path.Clean(filepath.Join(exportPath, volumeMountDst[1:])) @@ -176,7 +174,7 @@ func CheckOriginXrootdEnv(exportPath string, uid int, gid int, groupname string) mountPath = filepath.Clean(mountPath) namespacePrefix = filepath.Clean(namespacePrefix) if namespacePrefix[0:1] != "/" { - return exportPath, fmt.Errorf("Namespace prefix %v must have an absolute path", + return exportPath, fmt.Errorf("namespace prefix %v must have an absolute path", namespacePrefix) } destPath := path.Clean(filepath.Join(exportPath, namespacePrefix[1:])) @@ -464,67 +462,20 @@ func CopyXrootdCertificates() error { // Launch a separate goroutine that performs the XRootD maintenance tasks. // For maintenance that is periodic, `sleepTime` is the maintenance period. func LaunchXrootdMaintenance(ctx context.Context, sleepTime time.Duration) { - select_count := 4 - watcher, err := fsnotify.NewWatcher() - if err != nil { - select_count -= 2 - } else if err = watcher.Add(filepath.Dir(param.Server_TLSCertificate.GetString())); err != nil { - select_count -= 2 - } - cases := make([]reflect.SelectCase, select_count) - ticker := time.NewTicker(sleepTime) - cases[0].Dir = reflect.SelectRecv - cases[0].Chan = reflect.ValueOf(ticker.C) - cases[1].Dir = reflect.SelectRecv - cases[1].Chan = reflect.ValueOf(ctx.Done()) - if err == nil { - cases[2].Dir = reflect.SelectRecv - cases[2].Chan = reflect.ValueOf(watcher.Events) - cases[3].Dir = reflect.SelectRecv - cases[3].Chan = reflect.ValueOf(watcher.Errors) - } - go func() { - defer watcher.Close() - for { - chosen, recv, ok := reflect.Select(cases) - if chosen == 0 { - if !ok { - log.Panicln("Ticker failed in the xrootd maintenance routine; exiting") - } - err := CopyXrootdCertificates() - if err != nil { - log.Warningln("Failed to update xrootd certificates during maintenance:", err) - } - } else if chosen == 1 { - log.Infoln("XRootD maintenance thread has been cancelled. Shutting down") - return - } else if chosen == 2 { // watcher.Events - if !ok { - log.Panicln("Watcher events failed in xrootd maintenance routine; exiting") - } - if event, ok := recv.Interface().(fsnotify.Event); ok { - log.Debugf("Got filesystem event (%v); will update the xrootd certificates", event) - if err = CopyXrootdCertificates(); errors.Is(err, errBadKeyPair) { - log.Debugln("Bad keypair encountered when doing xrootd certificate maintenance:", err) - } else if err != nil { - log.Warningf("Failed to update xrootd certificates based on file event %v: %v", event, err) - } - } else { - log.Panicln("Watcher returned an unknown event") - } - } else if chosen == 3 { // watcher.Errors - if !ok { - log.Panicln("Watcher error channel closed in xrootd maintenance routine; exiting") - } - if err, ok := recv.Interface().(error); ok { - log.Errorf("Watcher failure in the xrootd maintenance routine: %v", err) - } else { - log.Panicln("Watcher error channel has internal error; exiting") - } - time.Sleep(time.Second) + server_utils.LaunchWatcherMaintenance( + ctx, + filepath.Dir(param.Server_TLSCertificate.GetString()), + "xrootd maintenance", + sleepTime, + func(notifyEvent bool) error { + err := CopyXrootdCertificates() + if notifyEvent && errors.Is(err, errBadKeyPair) { + log.Debugln("Bad keypair encountered when doing xrootd certificate maintenance:", err) + return nil } - } - }() + return err + }, + ) } func ConfigXrootd(origin bool) (string, error) { From 8e5fa31dc0af7008226272a7890ab56cfe3fa1e3 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sat, 30 Dec 2023 13:32:10 -0600 Subject: [PATCH 5/6] Implement reloads for certificate updates This adds the ability for the web interface to pick up TLS changes. There's a refactor of the `RunEngine` to allow shutdowns and arbitrary listeners, simplifying the unit test of the reload. --- cmd/cache_serve.go | 7 +++- cmd/director_serve.go | 9 ++++- cmd/origin_serve.go | 7 +++- cmd/registry_serve.go | 11 +++++- web_ui/ui.go | 86 +++++++++++++++++++++++++++++++++++++++++-- 5 files changed, 111 insertions(+), 9 deletions(-) diff --git a/cmd/cache_serve.go b/cmd/cache_serve.go index 5338ac8a0..b2532a552 100644 --- a/cmd/cache_serve.go +++ b/cmd/cache_serve.go @@ -126,7 +126,12 @@ func serveCache( /*cmd*/ *cobra.Command /*args*/, []string) error { return err } - go web_ui.RunEngine(engine) + go func() { + if err := web_ui.RunEngine(shutdownCtx, engine); err != nil { + log.Panicln("Failure when running the web engine:", err) + } + shutdownCancel() + }() go web_ui.InitServerWebLogin() configPath, err := xrootd.ConfigXrootd(false) diff --git a/cmd/director_serve.go b/cmd/director_serve.go index ac65200ff..ed4cfce3a 100644 --- a/cmd/director_serve.go +++ b/cmd/director_serve.go @@ -82,7 +82,7 @@ func serveDirector( /*cmd*/ *cobra.Command /*args*/, []string) error { // or to an origin defaultResponse := param.Director_DefaultResponse.GetString() if !(defaultResponse == "cache" || defaultResponse == "origin") { - return fmt.Errorf("The director's default response must either be set to 'cache' or 'origin',"+ + return fmt.Errorf("the director's default response must either be set to 'cache' or 'origin',"+ " but you provided %q. Was there a typo?", defaultResponse) } log.Debugf("The director will redirect to %ss by default", defaultResponse) @@ -93,7 +93,12 @@ func serveDirector( /*cmd*/ *cobra.Command /*args*/, []string) error { director.RegisterDirector(rootGroup) log.Info("Starting web engine...") - go web_ui.RunEngine(engine) + go func() { + if err := web_ui.RunEngine(shutdownCtx, engine); err != nil { + log.Panicln("Failure when running the web engine:", err) + } + shutdownCancel() + }() go web_ui.InitServerWebLogin() diff --git a/cmd/origin_serve.go b/cmd/origin_serve.go index 5c5d8800b..201a84e5e 100644 --- a/cmd/origin_serve.go +++ b/cmd/origin_serve.go @@ -102,7 +102,12 @@ func serveOrigin( /*cmd*/ *cobra.Command /*args*/, []string) error { } } - go web_ui.RunEngine(engine) + go func() { + if err := web_ui.RunEngine(shutdownCtx, engine); err != nil { + log.Panicln("Failure when running the web engine:", err) + } + shutdownCancel() + }() if param.Origin_EnableUI.GetBool() { go web_ui.InitServerWebLogin() diff --git a/cmd/registry_serve.go b/cmd/registry_serve.go index 0e4eee5bf..2da0642d0 100644 --- a/cmd/registry_serve.go +++ b/cmd/registry_serve.go @@ -19,6 +19,7 @@ package main import ( + "context" "os" "os/signal" "syscall" @@ -35,6 +36,8 @@ import ( func serveRegistry( /*cmd*/ *cobra.Command /*args*/, []string) error { log.Info("Initializing the namespace registry's database...") + shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) + defer shutdownCancel() // Initialize the registry's sqlite database err := registry.InitializeDB() @@ -76,7 +79,12 @@ func serveRegistry( /*cmd*/ *cobra.Command /*args*/, []string) error { // more complicated routing scenarios where we can't just use // a wildcard. It removes duplicate / from the resource. //engine.RemoveExtraSlash = true - go web_ui.RunEngine(engine) + go func() { + if err := web_ui.RunEngine(shutdownCtx, engine); err != nil { + log.Panicln("Failure when running the web engine:", err) + } + shutdownCancel() + }() go web_ui.InitServerWebLogin() @@ -84,6 +92,7 @@ func serveRegistry( /*cmd*/ *cobra.Command /*args*/, []string) error { signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) sig := <-sigs _ = sig + shutdownCancel() return nil } diff --git a/web_ui/ui.go b/web_ui/ui.go index db542383b..57f954440 100644 --- a/web_ui/ui.go +++ b/web_ui/ui.go @@ -20,13 +20,16 @@ package web_ui import ( "context" + "crypto/tls" "embed" "fmt" "math/rand" "mime" + "net" "net/http" "os" "os/signal" + "path/filepath" "strings" "syscall" "time" @@ -34,9 +37,11 @@ import ( "github.com/gin-gonic/gin" "github.com/pelicanplatform/pelican/metrics" "github.com/pelicanplatform/pelican/param" + "github.com/pelicanplatform/pelican/server_utils" "github.com/pkg/errors" log "github.com/sirupsen/logrus" ginprometheus "github.com/zsais/go-gin-prometheus" + "go.uber.org/atomic" "golang.org/x/term" ) @@ -279,15 +284,88 @@ func GetEngine() (*gin.Engine, error) { return engine, nil } -func RunEngine(engine *gin.Engine) { +// Run the gin engine. +// +// Will use a background golang routine to periodically reload the certificate +// utilized by the UI. +func RunEngine(ctx context.Context, engine *gin.Engine) error { + addr := fmt.Sprintf("%v:%v", param.Server_WebHost.GetString(), param.Server_WebPort.GetInt()) + + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + + defer ln.Close() + + return runEngineWithListener(ctx, ln, engine) +} + +// Run the engine with a given listener. +// This was split out from RunEngine to allow unit tests to provide a Unix domain socket' +// as a listener. +func runEngineWithListener(ctx context.Context, ln net.Listener, engine *gin.Engine) error { certFile := param.Server_TLSCertificate.GetString() keyFile := param.Server_TLSKey.GetString() - addr := fmt.Sprintf("%v:%v", param.Server_WebHost.GetString(), param.Server_WebPort.GetInt()) + port := param.Server_WebPort.GetInt() + addr := fmt.Sprintf("%v:%v", param.Server_WebHost.GetString(), port) - log.Debugln("Starting web engine at address", addr) - err := engine.RunTLS(addr, certFile, keyFile) + cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { panic(err) } + + var certPtr atomic.Pointer[tls.Certificate] + certPtr.Store(&cert) + + server_utils.LaunchWatcherMaintenance( + ctx, + filepath.Dir(param.Server_TLSCertificate.GetString()), + "server TLS maintenance", + 2*time.Minute, + func(notifyEvent bool) error { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err == nil { + log.Debugln("Loaded new X509 key pair") + certPtr.Store(&cert) + } else if notifyEvent { + log.Debugln("Failed to load new X509 key pair after filesystem event (may succeed eventually):", err) + return nil + } + return err + }, + ) + + getCert := func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { + return certPtr.Load(), nil + } + + config := &tls.Config{ + GetCertificate: getCert, + } + server := &http.Server{ + Addr: addr, + Handler: engine.Handler(), + TLSConfig: config, + } + log.Debugln("Starting web engine at address", addr) + + // Once the context has been canceled, shutdown the HTTPS server. Give it + // 10 seconds to shutdown existing requests. + go func() { + <-ctx.Done() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + err = server.Shutdown(ctx) + if err != nil { + log.Panicln("Failed to shutdown server:", err) + } + }() + + if err := server.ServeTLS(ln, "", ""); err != nil && !errors.Is(err, http.ErrServerClosed) { + return err + } + + return nil } From 49072a535f3ce05da5138226fdd3af728c09d7f8 Mon Sep 17 00:00:00 2001 From: Brian Bockelman Date: Sat, 30 Dec 2023 15:59:23 -0600 Subject: [PATCH 6/6] Add unit test for the web engine --- web_ui/engine_test.go | 181 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 181 insertions(+) create mode 100644 web_ui/engine_test.go diff --git a/web_ui/engine_test.go b/web_ui/engine_test.go new file mode 100644 index 000000000..376f3f8ce --- /dev/null +++ b/web_ui/engine_test.go @@ -0,0 +1,181 @@ +//go:build !windows + +/*************************************************************** + * + * Copyright (C) 2023, Pelican Project, Morgridge Institute for Research + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You may + * obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ***************************************************************/ + +package web_ui + +import ( + "context" + "crypto/sha256" + "crypto/tls" + "io" + "net" + "net/http" + "os" + "path/filepath" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/pelicanplatform/pelican/config" + "github.com/pelicanplatform/pelican/param" + log "github.com/sirupsen/logrus" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Setup a gin engine that will serve up a /ping endpoint on a Unix domain socket. +func setupPingEngine(t *testing.T) (chan bool, context.CancelFunc, string) { + dirname := t.TempDir() + viper.Reset() + viper.Set("Logging.Level", "Debug") + viper.Set("ConfigDir", dirname) + viper.Set("Server.WebPort", 0) + config.InitConfig() + err := config.InitServer(config.OriginType) + require.NoError(t, err) + ctx, cancel := context.WithCancel(context.Background()) + + engine, err := GetEngine() + require.NoError(t, err) + + engine.GET("/ping", func(ctx *gin.Context) { + ctx.Data(http.StatusOK, "text/plain; charset=utf-8", []byte("pong")) + }) + + // Setup a domain socket instead of listening on TCP + socketLocation := filepath.Join(dirname, "engine.sock") + ln, err := net.Listen("unix", socketLocation) + require.NoError(t, err) + + doneChan := make(chan bool) + go func() { + err = runEngineWithListener(ctx, ln, engine) + require.NoError(t, err) + doneChan <- true + }() + + transport := *config.GetTransport() + transport.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", socketLocation) + } + httpc := http.Client{ + Transport: &transport, + } + + engineReady := false + for idx := 0; idx < 20; idx++ { + time.Sleep(10 * time.Millisecond) + log.Debug("Checking for engine ready") + + var resp *http.Response + resp, err = httpc.Get("https://" + param.Server_Hostname.GetString() + "/ping") + if err != nil { + continue + } + assert.Equal(t, "200 OK", resp.Status) + var body []byte + body, err = io.ReadAll(resp.Body) + assert.Equal(t, string(body), "pong") + } + if !engineReady { + require.NoError(t, err) + } + + return doneChan, cancel, socketLocation +} + +// Test the engine startup, serving a single request using +// TLS validation, then a clean shutdown. +func TestRunEngine(t *testing.T) { + doneChan, cancel, _ := setupPingEngine(t) + + // Shutdown the engine + cancel() + timeout := time.Tick(3 * time.Second) + select { + case ok := <-doneChan: + require.True(t, ok) + case <-timeout: + require.Fail(t, "Timeout when shutting down the engine") + } +} + +// Ensure that if the TLS certificate is updated on disk then new +// connections will use the new version. +func TestUpdateCert(t *testing.T) { + _, cancel, socketLocation := setupPingEngine(t) + defer cancel() + + getCurrentFingerprint := func() [sha256.Size]byte { + + conn, err := net.Dial("unix", socketLocation) + require.NoError(t, err) + defer conn.Close() + + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + ServerName: param.Server_WebHost.GetString(), + } + tlsConn := tls.Client(conn, tlsConfig) + err = tlsConn.Handshake() + require.NoError(t, err) + + currentCert := tlsConn.ConnectionState().PeerCertificates[0] + return sha256.Sum256(currentCert.Raw) + } + + // First, compare the current fingerprint against that on disk + currentFingerprint := getCurrentFingerprint() + + certFile := param.Server_TLSCertificate.GetString() + keyFile := param.Server_TLSKey.GetString() + getDiskFingerprint := func() [sha256.Size]byte { + diskCert, err := tls.LoadX509KeyPair(certFile, keyFile) + require.NoError(t, err) + return sha256.Sum256(diskCert.Certificate[0]) + } + + diskFingerprint := getDiskFingerprint() + assert.Equal(t, currentFingerprint, diskFingerprint) + + // Next, trigger a reload of the cert + require.NoError(t, os.Remove(certFile)) + require.NoError(t, os.Remove(keyFile)) + require.NoError(t, config.InitServer(config.OriginType)) + + newDiskFingerprint := getDiskFingerprint() + assert.NotEqual(t, diskFingerprint, newDiskFingerprint) + + log.Debugln("Will look for updated TLS certificate") + sawUpdate := false + for idx := 0; idx < 10; idx++ { + time.Sleep(50 * time.Millisecond) + log.Debugln("Checking current fingerprint") + currentFingerprint := getCurrentFingerprint() + if currentFingerprint == newDiskFingerprint { + sawUpdate = true + break + } else { + require.Equal(t, currentFingerprint, diskFingerprint) + } + } + assert.True(t, sawUpdate) +}