diff --git a/cmd/fed_serve_cache_test.go b/cmd/fed_serve_cache_test.go index 22aa46d10..22a7b29c5 100644 --- a/cmd/fed_serve_cache_test.go +++ b/cmd/fed_serve_cache_test.go @@ -109,6 +109,7 @@ func TestFedServeCache(t *testing.T) { viper.Set("Registry.RequireOriginApproval", false) viper.Set("Registry.RequireCacheApproval", false) viper.Set("Origin.EnablePublicReads", false) + viper.Set("Director.DbLocation", filepath.Join(t.TempDir(), "director.sqlite")) require.NoError(t, err) diff --git a/cmd/fed_serve_test.go b/cmd/fed_serve_test.go index a5f39c646..e27f91e4f 100644 --- a/cmd/fed_serve_test.go +++ b/cmd/fed_serve_test.go @@ -93,6 +93,8 @@ func TestFedServePosixOrigin(t *testing.T) { viper.Set("Registry.DbLocation", filepath.Join(t.TempDir(), "ns-registry.sqlite")) viper.Set("Registry.RequireOriginApproval", false) viper.Set("Registry.RequireCacheApproval", false) + viper.Set("Director.DbLocation", filepath.Join(t.TempDir(), "director.sqlite")) + defer cancel() _, fedCancel, err := launchers.LaunchModules(ctx, modules) diff --git a/cmd/plugin_test.go b/cmd/plugin_test.go index 0fa2c32e4..be5d98401 100644 --- a/cmd/plugin_test.go +++ b/cmd/plugin_test.go @@ -191,6 +191,7 @@ func (f *FedTest) Spinup() { viper.Set("Origin.Port", 0) viper.Set("Server.WebPort", 0) viper.Set("Origin.RunLocation", tmpPath) + viper.Set("Director.DbLocation", filepath.Join(f.T.TempDir(), "director.sqlite")) err = config.InitServer(ctx, modules) require.NoError(f.T, err) diff --git a/config/config.go b/config/config.go index 5314e12b3..255131041 100644 --- a/config/config.go +++ b/config/config.go @@ -1030,6 +1030,7 @@ func InitServer(ctx context.Context, currentServers server_structs.ServerType) e viper.SetDefault("Origin.Multiuser", true) viper.SetDefault(param.Origin_DbLocation.GetName(), "/var/lib/pelican/origin.sqlite") + viper.SetDefault(param.Director_DbLocation.GetName(), "/var/lib/pelican/director.sqlite") viper.SetDefault("Director.GeoIPLocation", "/var/cache/pelican/maxmind/GeoLite2-City.mmdb") viper.SetDefault("Registry.DbLocation", "/var/lib/pelican/registry.sqlite") // The lotman db will actually take this path and create the lot at /path/.lot/lotman_cpp.sqlite @@ -1040,6 +1041,7 @@ func InitServer(ctx context.Context, currentServers server_structs.ServerType) e viper.SetDefault(param.Origin_GlobusConfigLocation.GetName(), filepath.Join("/run", "pelican", "xrootd", "origin", "globus")) } else { viper.SetDefault(param.Origin_DbLocation.GetName(), filepath.Join(configDir, "origin.sqlite")) + viper.SetDefault(param.Director_DbLocation.GetName(), filepath.Join(configDir, "director.sqlite")) viper.SetDefault("Director.GeoIPLocation", filepath.Join(configDir, "maxmind", "GeoLite2-City.mmdb")) viper.SetDefault("Registry.DbLocation", filepath.Join(configDir, "ns-registry.sqlite")) // Lotdb will live at /.lot/lotman_cpp.sqlite diff --git a/director/director_api.go b/director/director_api.go index f91c16ec8..a427c92c2 100644 --- a/director/director_api.go +++ b/director/director_api.go @@ -66,7 +66,7 @@ func listAdvertisement(serverTypes []server_structs.ServerType) []*server_struct func checkFilter(serverName string) (bool, filterType) { filteredServersMutex.RLock() defer filteredServersMutex.RUnlock() - + log.Debugf("Checking for a downtime filter applied to server %s", serverName) status, exists := filteredServers[serverName] // No filter entry if !exists { @@ -223,17 +223,29 @@ func hookServerAdsCache() { }) } -// Populate internal filteredServers map by Director.FilteredServers +// Populate internal filteredServers map using Director.FilteredServers param and director db func ConfigFilterdServers() { filteredServersMutex.Lock() defer filteredServersMutex.Unlock() - if !param.Director_FilteredServers.IsSet() { - return + if param.Director_FilteredServers.IsSet() { + for _, sn := range param.Director_FilteredServers.GetStringSlice() { + filteredServers[sn] = permFiltered + } + log.Debugln("Loaded server downtime configuration from the Director.FilteredServers parameter:", filteredServers) } - for _, sn := range param.Director_FilteredServers.GetStringSlice() { - filteredServers[sn] = permFiltered + if param.Director_DbLocation.GetString() != "" { + persistedServerDowntimes, err := getAllServerDowntimes() + if err != nil { + log.Error("Failed to read persisted server downtimes from director db:", err) + return + } + for _, serverDowntime := range persistedServerDowntimes { + filteredServers[serverDowntime.Name] = serverDowntime.FilterType + } + log.Debugln("Loaded filtered servers config from director db:", filteredServers) + // if a filtered server config rule is set in both Director.FilteredServers param and director db, the latter one will eventually be used } } diff --git a/director/director_db.go b/director/director_db.go new file mode 100644 index 000000000..47a22ce69 --- /dev/null +++ b/director/director_db.go @@ -0,0 +1,157 @@ +/*************************************************************** + * + * Copyright (C) 2024, Pelican Project, Morgridge Institute for Research + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You may + * obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ***************************************************************/ +package director + +import ( + "embed" + "time" + + "github.com/google/uuid" + "github.com/pkg/errors" + "gorm.io/gorm" + "gorm.io/gorm/logger" + + "github.com/pelicanplatform/pelican/param" + "github.com/pelicanplatform/pelican/server_utils" +) + +type ServerDowntime struct { + UUID string `gorm:"primaryKey"` + Name string `gorm:"not null;unique"` + FilterType filterType `gorm:"type:text;not null"` + // We don't use gorm default gorm.Model to change ID type to string + CreatedAt time.Time + UpdatedAt time.Time +} + +var db *gorm.DB + +//go:embed migrations/*.sql +var embedMigrations embed.FS + +// Initialize the Director's sqlite database, which is used to persist information about server downtimes +func InitializeDB() error { + dbPath := param.Director_DbLocation.GetString() + tdb, err := server_utils.InitSQLiteDB(dbPath) + if err != nil { + return errors.Wrap(err, "failed to initialize the Director's sqlite database") + } + db = tdb + sqldb, err := db.DB() + if err != nil { + return errors.Wrapf(err, "failed to get sql.DB from gorm DB: %s", dbPath) + } + // Run database migrations + if err := server_utils.MigrateDB(sqldb, embedMigrations); err != nil { + return errors.Wrap(err, "failed to migrate the Director's sqlite database using embedded migration files") + } + return nil +} + +// Shut down the Director's sqlite database +func shutdownDirectorDB() error { + return server_utils.ShutdownDB(db) +} + +// Create a new db entry representing the downtime info of a server +func createServerDowntime(serverName string, filterType filterType) error { + id, err := uuid.NewV7() + if err != nil { + return errors.Wrap(err, "unable to create new UUID for new entry in server status table") + } + serverDowntime := ServerDowntime{ + UUID: id.String(), + Name: serverName, + FilterType: filterType, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + if err := db.Create(serverDowntime).Error; err != nil { + return errors.Wrap(err, "unable to create server downtime table") + } + return nil +} + +// Retrieve the downtime info of a given server (filter applied to the server) +func getServerDowntime(serverName string) (filterType, error) { + var serverDowntime ServerDowntime + err := db.First(&serverDowntime, "name = ?", serverName).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return "", errors.Wrapf(err, "%s is not found in the Director db", serverName) + } + return "", errors.Wrapf(err, "unable to get the downtime of %s", serverName) + } + return filterType(serverDowntime.FilterType), nil +} + +// Retrieve the downtime info of all servers saved in the Director's sqlite database +func getAllServerDowntimes() ([]ServerDowntime, error) { + var statuses []ServerDowntime + result := db.Find(&statuses) + + if result.Error != nil { + return nil, errors.Wrap(result.Error, "unable to get the downtime of all servers") + } + return statuses, nil +} + +// Set the downtime info (filterType) of a given server +func setServerDowntime(serverName string, filterType filterType) error { + var serverDowntime ServerDowntime + // slience the logger for this query because there's definitely an ErrRecordNotFound when a new downtime info entry inserted + err := db.Session(&gorm.Session{Logger: db.Logger.LogMode(logger.Silent)}).First(&serverDowntime, "name = ?", serverName).Error + + // If the server doesn't exist in director db, create a new entry for it + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return createServerDowntime(serverName, filterType) + } + + return errors.Wrapf(err, "unable to retrieve downtime status for server %s", serverName) + } + + serverDowntime.FilterType = filterType + serverDowntime.UpdatedAt = time.Now() + + if err := db.Save(&serverDowntime).Error; err != nil { + return errors.Wrap(err, "unable to update") + } + return nil +} + +// Define a function type for setServerDowntime +type setServerDowntimeFunc func(string, filterType) error + +// Make the function a variable so it can be mocked in tests +var setServerDowntimeFn setServerDowntimeFunc = setServerDowntime + +// Delete the downtime info of a given server from the Director's sqlite database +func deleteServerDowntime(serverName string) error { + if err := db.Where("name = ?", serverName).Delete(&ServerDowntime{}).Error; err != nil { + return errors.Wrap(err, "failed to delete an entry in Server Status table") + } + return nil +} + +// Define a function type for deleteServerDowntime +type deleteServerDowntimeFunc func(string) error + +// Make the function a variable so it can be mocked in tests +var deleteServerDowntimeFn deleteServerDowntimeFunc = deleteServerDowntime diff --git a/director/director_db_test.go b/director/director_db_test.go new file mode 100644 index 000000000..094ce59e8 --- /dev/null +++ b/director/director_db_test.go @@ -0,0 +1,82 @@ +package director + +import ( + "testing" + + "github.com/glebarez/sqlite" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + + "github.com/pelicanplatform/pelican/server_utils" +) + +var ( + mockSS []ServerDowntime = []ServerDowntime{ + {UUID: uuid.NewString(), Name: "/4a334d532d69:8443", FilterType: tempAllowed}, + {UUID: uuid.NewString(), Name: "/my-origin.com/foo/Bar", FilterType: permFiltered}, + {UUID: uuid.NewString(), Name: "/my-cache.com/chtc", FilterType: permFiltered}, + } +) + +func SetupMockDirectorDB(t *testing.T) { + mockDB, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + db = mockDB + require.NoError(t, err, "Error setting up mock origin DB") + err = db.AutoMigrate(&ServerDowntime{}) + require.NoError(t, err, "Failed to migrate DB for Globus table") +} + +func TeardownMockDirectorDB(t *testing.T) { + err := shutdownDirectorDB() + require.NoError(t, err, "Error tearing down mock director DB") +} + +func insertMockDBData(ss []ServerDowntime) error { + return db.Create(&ss).Error +} + +func TestDirectorDBBasics(t *testing.T) { + server_utils.ResetTestState() + SetupMockDirectorDB(t) + t.Cleanup(func() { + TeardownMockDirectorDB(t) + }) + err := insertMockDBData(mockSS) + require.NoError(t, err) + + t.Run("get-downtime", func(t *testing.T) { + filterType, err := getServerDowntime(mockSS[1].Name) + assert.Equal(t, filterType, permFiltered) + require.NoError(t, err) + }) + + t.Run("get-all-downtime", func(t *testing.T) { + statuses, err := getAllServerDowntimes() + require.NoError(t, err) + assert.Len(t, statuses, len(mockSS)) + }) + + t.Run("set-downtime", func(t *testing.T) { + err = setServerDowntime(mockSS[1].Name, tempAllowed) + require.NoError(t, err) + filterType, err := getServerDowntime(mockSS[1].Name) + assert.Equal(t, filterType, tempAllowed) + require.NoError(t, err) + }) + + t.Run("duplicate-name-insert", func(t *testing.T) { + err := createServerDowntime(mockSS[1].Name, tempAllowed) + require.Error(t, err) + assert.Contains(t, err.Error(), "UNIQUE constraint failed") + }) + + t.Run("delete-downtime-entry-from-directory-db", func(t *testing.T) { + err = deleteServerDowntime(mockSS[0].Name) + require.NoError(t, err, "Error deleting server status") + + _, err = getServerDowntime(mockSS[0].Name) + assert.Error(t, err, "Expected error retrieving deleted server status") + }) +} diff --git a/director/director_test.go b/director/director_test.go index 418deadfd..8ea01a203 100644 --- a/director/director_test.go +++ b/director/director_test.go @@ -1735,7 +1735,9 @@ func TestHandleFilterServer(t *testing.T) { filteredServersMutex.Lock() defer filteredServersMutex.Unlock() filteredServers = map[string]filterType{} + TeardownMockDirectorDB(t) }) + SetupMockDirectorDB(t) router := gin.Default() router.GET("/servers/filter/*name", handleFilterServer) @@ -1751,26 +1753,43 @@ func TestHandleFilterServer(t *testing.T) { // Check the response require.Equal(t, 200, w.Code) + // Check the in-memory cache storage filteredServersMutex.RLock() defer filteredServersMutex.RUnlock() assert.Equal(t, tempFiltered, filteredServers["mock-dne"]) + + // Check the Director database + filterType, err := getServerDowntime("mock-dne") + assert.Equal(t, tempFiltered, filterType) + require.NoError(t, err) }) t.Run("filter-server-w-permFiltered", func(t *testing.T) { // Create a request to the endpoint w := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/servers/filter/mock-pf", nil) + + // Tweak the downtime status (filter type) to permFiltered filteredServersMutex.Lock() filteredServers["mock-pf"] = permFiltered filteredServersMutex.Unlock() + err := setServerDowntime("mock-pf", permFiltered) + require.NoError(t, err) + router.ServeHTTP(w, req) // Check the response require.Equal(t, 400, w.Code) + // Check the in-memory cache storage filteredServersMutex.RLock() defer filteredServersMutex.RUnlock() assert.Equal(t, permFiltered, filteredServers["mock-pf"]) + // Check the Director database + filterType, err := getServerDowntime("mock-pf") + assert.Equal(t, permFiltered, filterType) + require.NoError(t, err) + resB, err := io.ReadAll(w.Body) require.NoError(t, err) assert.Contains(t, string(resB), "Can't filter a server that already has been fitlered") @@ -1782,15 +1801,24 @@ func TestHandleFilterServer(t *testing.T) { filteredServersMutex.Lock() filteredServers["mock-tf"] = tempFiltered filteredServersMutex.Unlock() + err := setServerDowntime("mock-tf", tempFiltered) + require.NoError(t, err) + router.ServeHTTP(w, req) // Check the response require.Equal(t, 400, w.Code) + // Check the in-memory cache storage filteredServersMutex.RLock() defer filteredServersMutex.RUnlock() assert.Equal(t, tempFiltered, filteredServers["mock-tf"]) + // Check the Director database + filterType, err := getServerDowntime("mock-tf") + assert.Equal(t, tempFiltered, filterType) + require.NoError(t, err) + resB, err := io.ReadAll(w.Body) require.NoError(t, err) assert.Contains(t, string(resB), "Can't filter a server that already has been fitlered") @@ -1807,9 +1835,15 @@ func TestHandleFilterServer(t *testing.T) { // Check the response require.Equal(t, 200, w.Code) + // Check the in-memory cache storage filteredServersMutex.RLock() defer filteredServersMutex.RUnlock() assert.Equal(t, permFiltered, filteredServers["mock-ta"]) + + // Check the Director database + filterType, err := getServerDowntime("mock-ta") + assert.Equal(t, permFiltered, filterType) + require.NoError(t, err) }) t.Run("filter-with-invalid-name", func(t *testing.T) { // Create a request to the endpoint @@ -1825,12 +1859,61 @@ func TestHandleFilterServer(t *testing.T) { }) } +func TestHandleFilterServerDataIntegrity(t *testing.T) { + t.Cleanup(func() { + filteredServersMutex.Lock() + defer filteredServersMutex.Unlock() + filteredServers = map[string]filterType{} + TeardownMockDirectorDB(t) + }) + SetupMockDirectorDB(t) + + router := gin.Default() + router.GET("/servers/filter/*name", handleFilterServer) + + t.Run("db-error-when-setting-downtime", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/servers/filter/mock-error", nil) + + // Set up original filter type + filteredServersMutex.Lock() + filteredServers["mock-error"] = tempAllowed + filteredServersMutex.Unlock() + + // Mock setServerDowntime to return error + origSetServerDowntime := setServerDowntime + defer func() { setServerDowntimeFn = origSetServerDowntime }() + setServerDowntimeFn = func(serverName string, ft filterType) error { + return fmt.Errorf("mock db error") + } + + router.ServeHTTP(w, req) + + // Check response code + require.Equal(t, http.StatusInternalServerError, w.Code) + + // Verify the server was reverted to original filter type + filteredServersMutex.RLock() + actualType, exists := filteredServers["mock-error"] + filteredServersMutex.RUnlock() + assert.True(t, exists, "Server should exist in filteredServers") + assert.Equal(t, tempAllowed, actualType, "Filter type should be reverted to original value") + + // Check error message + resB, err := io.ReadAll(w.Body) + require.NoError(t, err) + assert.Contains(t, string(resB), "Failed to persist server downtime due to database error") + }) +} + func TestHandleAllowServer(t *testing.T) { t.Cleanup(func() { filteredServersMutex.Lock() defer filteredServersMutex.Unlock() filteredServers = map[string]filterType{} + TeardownMockDirectorDB(t) }) + SetupMockDirectorDB(t) router := gin.Default() router.GET("/servers/allow/*name", handleAllowServer) @@ -1838,9 +1921,15 @@ func TestHandleAllowServer(t *testing.T) { // Create a request to the endpoint w := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/servers/allow/mock-dne", nil) + + // Server is not in the in-memory downtime map filteredServersMutex.Lock() delete(filteredServers, "mock-dne") filteredServersMutex.Unlock() + // Server is also not in the Director db downtime table + err := deleteServerDowntime("mock-dne") + require.NoError(t, err) + // Note: Both map deletion and db deletion do not trigger an error if there’s nothing to delete router.ServeHTTP(w, req) // Check the response @@ -1861,9 +1950,15 @@ func TestHandleAllowServer(t *testing.T) { // Check the response require.Equal(t, 200, w.Code) + // Check the in-memory cache storage filteredServersMutex.RLock() defer filteredServersMutex.RUnlock() assert.Equal(t, tempAllowed, filteredServers["mock-pf"]) + + // Check the Director database + filterType, err := getServerDowntime("mock-pf") + assert.Equal(t, tempAllowed, filterType) + require.NoError(t, err) }) t.Run("allow-server-w-tempFiltered", func(t *testing.T) { // Create a request to the endpoint @@ -1877,9 +1972,15 @@ func TestHandleAllowServer(t *testing.T) { // Check the response require.Equal(t, 200, w.Code) + // Check the in-memory cache storage filteredServersMutex.RLock() defer filteredServersMutex.RUnlock() assert.Empty(t, filteredServers["mock-tf"]) + + // Check the Director database + filterType, err := getServerDowntime("mock-tf") + assert.Equal(t, "", string(filterType)) + assert.Contains(t, err.Error(), "is not found in the Director db") }) t.Run("allow-server-w-tempAllowed", func(t *testing.T) { // Create a request to the endpoint @@ -1888,11 +1989,15 @@ func TestHandleAllowServer(t *testing.T) { filteredServersMutex.Lock() filteredServers["mock-ta"] = tempAllowed filteredServersMutex.Unlock() + err := setServerDowntime("mock-ta", tempAllowed) + require.NoError(t, err) + router.ServeHTTP(w, req) // Check the response require.Equal(t, 400, w.Code) + // Check the in-memory cache storage filteredServersMutex.RLock() defer filteredServersMutex.RUnlock() assert.Equal(t, tempAllowed, filteredServers["mock-ta"]) @@ -1900,6 +2005,11 @@ func TestHandleAllowServer(t *testing.T) { resB, err := io.ReadAll(w.Body) require.NoError(t, err) assert.Contains(t, string(resB), "Can't allow server mock-ta that is not being filtered") + + // Check the Director database + filterType, err := getServerDowntime("mock-ta") + assert.Equal(t, tempAllowed, filterType) + require.NoError(t, err) }) t.Run("allow-with-invalid-name", func(t *testing.T) { // Create a request to the endpoint @@ -1915,6 +2025,126 @@ func TestHandleAllowServer(t *testing.T) { }) } +func TestHandleAllowServerDataIntegrity(t *testing.T) { + t.Cleanup(func() { + filteredServersMutex.Lock() + defer filteredServersMutex.Unlock() + filteredServers = map[string]filterType{} + TeardownMockDirectorDB(t) + }) + SetupMockDirectorDB(t) + + router := gin.Default() + router.GET("/servers/allow/*name", handleAllowServer) + + // Sub-test 1: When server is permanently filtered + t.Run("permFiltered-db-error", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/servers/allow/mock-error", nil) + + // Set up initial state as permFiltered + filteredServersMutex.Lock() + filteredServers["mock-error"] = permFiltered + filteredServersMutex.Unlock() + + // Mock setServerDowntimeFn to return error + origSetServerDowntime := setServerDowntimeFn + defer func() { setServerDowntimeFn = origSetServerDowntime }() + setServerDowntimeFn = func(serverName string, ft filterType) error { + return fmt.Errorf("mock db error") + } + + router.ServeHTTP(w, req) + + // Check response code + require.Equal(t, http.StatusInternalServerError, w.Code) + + // Verify the server maintained original filter type + filteredServersMutex.RLock() + actualType, exists := filteredServers["mock-error"] + filteredServersMutex.RUnlock() + assert.True(t, exists, "Server should still exist in filteredServers") + assert.Equal(t, permFiltered, actualType, "Filter type should remain as permFiltered") + + // Check error message + resB, err := io.ReadAll(w.Body) + require.NoError(t, err) + assert.Contains(t, string(resB), "Failed to remove the downtime of server mock-error in director db") + }) + + // Sub-test 2: When server is temporarily filtered and deletion fails + t.Run("tempFiltered-deletion-error", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/servers/allow/mock-error", nil) + + // Set up initial state as tempFiltered + filteredServersMutex.Lock() + filteredServers["mock-error"] = tempFiltered + filteredServersMutex.Unlock() + + // Mock deleteServerDowntime to return error + origDeleteServerDowntime := deleteServerDowntimeFn + defer func() { deleteServerDowntimeFn = origDeleteServerDowntime }() + deleteServerDowntimeFn = func(serverName string) error { + return fmt.Errorf("mock deletion error") + } + + router.ServeHTTP(w, req) + + // Check response code + require.Equal(t, http.StatusInternalServerError, w.Code) + + // Verify the server maintained original filter type + filteredServersMutex.RLock() + actualType, exists := filteredServers["mock-error"] + filteredServersMutex.RUnlock() + assert.True(t, exists, "Server should still exist in filteredServers") + assert.Equal(t, tempFiltered, actualType, "Filter type should remain as tempFiltered") + + // Check error message + resB, err := io.ReadAll(w.Body) + require.NoError(t, err) + assert.Contains(t, string(resB), "Failed to remove the downtime of server mock-error in director db") + }) + + // Sub-test 3: When server is already tempAllowed and db error occurs + t.Run("tempAllowed-db-error", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/servers/allow/mock-error", nil) + + // Set up initial state as tempAllowed + filteredServersMutex.Lock() + filteredServers["mock-error"] = tempAllowed + filteredServersMutex.Unlock() + err := setServerDowntime("mock-error", tempAllowed) + require.NoError(t, err) + + // Mock setServerDowntimeFn to return error + origSetServerDowntime := setServerDowntimeFn + defer func() { setServerDowntimeFn = origSetServerDowntime }() + setServerDowntimeFn = func(serverName string, ft filterType) error { + return fmt.Errorf("mock db error") + } + + router.ServeHTTP(w, req) + + // Should return 400 as the server is already tempAllowed + require.Equal(t, http.StatusBadRequest, w.Code) + + // Verify the server maintained original filter type + filteredServersMutex.RLock() + actualType, exists := filteredServers["mock-error"] + filteredServersMutex.RUnlock() + assert.True(t, exists, "Server should still exist in filteredServers") + assert.Equal(t, tempAllowed, actualType, "Filter type should remain as tempAllowed") + + // Check error message + resB, err := io.ReadAll(w.Body) + require.NoError(t, err) + assert.Contains(t, string(resB), "Can't allow server") + }) +} + func TestGetRedirectUrl(t *testing.T) { adFromTopo := server_structs.ServerAd{ URL: url.URL{ diff --git a/director/director_ui.go b/director/director_ui.go index fe70b0753..2930a217a 100644 --- a/director/director_ui.go +++ b/director/director_ui.go @@ -248,12 +248,33 @@ func handleFilterServer(ctx *gin.Context) { filteredServersMutex.Lock() defer filteredServersMutex.Unlock() + // Backup the original filter type to revert in case of failure + originalFilterType, hasOriginalFilter := filteredServers[sn] + + // Decide new filter type and update map // If we previously temporarily allowed a server, we switch to permFiltered (reset) + newFilterType := tempFiltered if filterType == tempAllowed { - filteredServers[sn] = permFiltered - } else { - filteredServers[sn] = tempFiltered + newFilterType = permFiltered } + filteredServers[sn] = newFilterType + + // Attempt to persist change in the database + if err := setServerDowntimeFn(sn, newFilterType); err != nil { + // Revert the change in filteredServers if SetServerDowntime fails + if hasOriginalFilter { + filteredServers[sn] = originalFilterType + } else { + delete(filteredServers, sn) + } + + ctx.JSON(http.StatusInternalServerError, server_structs.SimpleApiResp{ + Status: server_structs.RespFailed, + Msg: "Failed to persist server downtime due to database error", + }) + return + } + ctx.JSON(http.StatusOK, server_structs.SimpleApiResp{Status: server_structs.RespOK, Msg: "success"}) } @@ -281,13 +302,48 @@ func handleAllowServer(ctx *gin.Context) { filteredServersMutex.Lock() defer filteredServersMutex.Unlock() + // Backup the original filter (downtime) type to revert in case of failure + originalFilterType, hasOriginalFilter := filteredServers[sn] + + // Perform actions based on the current filter type if ft == tempFiltered { - // For temporarily filtered server, allowing them by removing the server from the map + // Temporarily filtered server: allow it by removing from map delete(filteredServers, sn) + + if err := deleteServerDowntimeFn(sn); err != nil { + // Revert the change in filteredServers if DeleteServerDowntime fails + if hasOriginalFilter { + filteredServers[sn] = originalFilterType + } else { + delete(filteredServers, sn) + } + + ctx.JSON(http.StatusInternalServerError, server_structs.SimpleApiResp{ + Status: server_structs.RespFailed, + Msg: fmt.Sprintf("Failed to remove the downtime of server %s in director db", sn), + }) + return + } } else if ft == permFiltered { - // For servers to filter from the config, temporarily allow the server + // Permanently filtered server: temporarily allow it filteredServers[sn] = tempAllowed + + if err := setServerDowntimeFn(sn, tempAllowed); err != nil { + // Revert the change in filteredServers if SetServerDowntime fails + if hasOriginalFilter { + filteredServers[sn] = originalFilterType + } else { + delete(filteredServers, sn) + } + + ctx.JSON(http.StatusInternalServerError, server_structs.SimpleApiResp{ + Status: server_structs.RespFailed, + Msg: fmt.Sprintf("Failed to remove the downtime of server %s in director db", sn), + }) + return + } } else if ft == topoFiltered { + // Server is disabled by OSG Topology ctx.JSON(http.StatusBadRequest, server_structs.SimpleApiResp{ Status: server_structs.RespFailed, Msg: fmt.Sprintf("Can't allow server %s that is disabled by the OSG Topology. Contact OSG admin at support@osg-htc.org to enable the server.", sn), diff --git a/director/maxmind.go b/director/maxmind.go index 0b985fcd9..340334be2 100644 --- a/director/maxmind.go +++ b/director/maxmind.go @@ -144,7 +144,7 @@ func periodicMaxMindReload(ctx context.Context) { } } -func InitializeDB(ctx context.Context) { +func InitializeGeoIPDB(ctx context.Context) { go periodicMaxMindReload(ctx) localFile := param.Director_GeoIPLocation.GetString() localReader, err := geoip2.Open(localFile) diff --git a/director/migrations/20241017135850_create_db_tables.sql b/director/migrations/20241017135850_create_db_tables.sql new file mode 100644 index 000000000..1460aac84 --- /dev/null +++ b/director/migrations/20241017135850_create_db_tables.sql @@ -0,0 +1,14 @@ +-- +goose Up +-- +goose StatementBegin +CREATE TABLE server_downtimes ( + uuid TEXT PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + filter_type TEXT NOT NULL, + created_at DATETIME NOT NULL, + updated_at DATETIME NOT NULL +); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +-- +goose StatementEnd diff --git a/docs/parameters.yaml b/docs/parameters.yaml index 2fae2123e..47e263d9c 100644 --- a/docs/parameters.yaml +++ b/docs/parameters.yaml @@ -1252,6 +1252,14 @@ components: ["cache"] ############################ # Director-level configs # ############################ +name: Director.DbLocation +description: |+ + A filepath to the intended location of the director's database, where server downtime info is stored. +type: filename +root_default: /var/lib/pelican/director.sqlite +default: $ConfigBase/director.sqlite +components: ["director"] +--- name: Director.DefaultResponse description: |+ The default response type of a redirect for a director instance. Can be either "cache" or "origin". If a director diff --git a/fed_test_utils/fed.go b/fed_test_utils/fed.go index 2dc26416b..300ffe27d 100644 --- a/fed_test_utils/fed.go +++ b/fed_test_utils/fed.go @@ -147,6 +147,7 @@ func NewFedTest(t *testing.T, originConfig string) (ft *FedTest) { viper.Set("Registry.RequireOriginApproval", false) viper.Set("Registry.RequireCacheApproval", false) viper.Set("Director.CacheSortMethod", "distance") + viper.Set("Director.DbLocation", filepath.Join(t.TempDir(), "director.sqlite")) err = config.InitServer(ctx, modules) require.NoError(t, err) diff --git a/github_scripts/get_put_test.sh b/github_scripts/get_put_test.sh index 873646783..1db9c7a72 100755 --- a/github_scripts/get_put_test.sh +++ b/github_scripts/get_put_test.sh @@ -34,7 +34,8 @@ export PELICAN_ORIGIN_ENABLEDIRECTREADS=true export PELICAN_SERVER_ENABLEUI=false export PELICAN_ORIGIN_RUNLOCATION=$PWD/xrootdRunLocation export PELICAN_CONFIGDIR=$PWD/get_put_tmp/config -export PELICAN_REGISTRY_DBLOCATION=$PWD/get_put_tmp/config/test.sql +export PELICAN_REGISTRY_DBLOCATION=$PWD/get_put_tmp/config/test-registry.sql +export PELICAN_DIRECTOR_DBLOCATION=$PWD/get_put_tmp/config/test-director.sql export PELICAN_OIDC_CLIENTID="sometexthere" export PELICAN_ORIGIN_FEDERATIONPREFIX="/test" export PELICAN_ORIGIN_STORAGEPREFIX="$PWD/get_put_tmp/origin" diff --git a/github_scripts/stat_test.sh b/github_scripts/stat_test.sh index d1e0ab8e5..ea7ec67be 100755 --- a/github_scripts/stat_test.sh +++ b/github_scripts/stat_test.sh @@ -32,7 +32,8 @@ export PELICAN_SERVER_ENABLEUI=false export PELICAN_ORIGIN_RUNLOCATION=/tmp/pelican-test/stat_test/xrootdRunLocation export PELICAN_CONFIGDIR=/tmp/pelican-test/stat_test -export PELICAN_REGISTRY_DBLOCATION=/tmp/pelican-test/stat_test/test.sql +export PELICAN_REGISTRY_DBLOCATION=/tmp/pelican-test/stat_test/test-registry.sql +export PELICAN_DIRECTOR_DBLOCATION=/tmp/pelican-test/stat_test/test-director.sql export PELICAN_OIDC_CLIENTID="sometexthere" export PELICAN_OIDC_CLIENTSECRETFILE=/tmp/pelican-test/stat_test/oidc-secret echo "Placeholder OIDC secret" > /tmp/pelican-test/stat_test/oidc-secret @@ -63,6 +64,7 @@ cleanup() { unset PELICAN_FEDERATION_REGISTRYURL unset PELICAN_TLSSKIPVERIFY unset PELICAN_REGISTRY_DBLOCATION + unset PELICAN_DIRECTOR_DBLOCATION unset PELICAN_SERVER_ENABLEUI unset PELICAN_OIDC_CLIENTID unset PELICAN_OIDC_CLIENTSECRETFILE diff --git a/github_scripts/x509_test.sh b/github_scripts/x509_test.sh index ebe533672..f0b5ff18a 100755 --- a/github_scripts/x509_test.sh +++ b/github_scripts/x509_test.sh @@ -38,7 +38,8 @@ export PELICAN_SERVER_ENABLEUI=false export PELICAN_ORIGIN_RUNLOCATION=$PWD/xrootdRunLocation export PELICAN_CACHE_RUNLOCATION=$PWD/xrootdCacheRunLocation export PELICAN_CONFIGDIR=$PWD/x509/config -export PELICAN_REGISTRY_DBLOCATION=$PWD/x509/config/test.sql +export PELICAN_REGISTRY_DBLOCATION=$PWD/x509/config/test-registry.sql +export PELICAN_DIRECTOR_DBLOCATION=$PWD/x509/config/test-director.sql export PELICAN_OIDC_CLIENTID="sometexthere" export PELICAN_ORIGIN_EXPORTVOLUMES="$PWD/x509/origin:/test $PWD/x509/defer:/defer/" export PELICAN_DIRECTOR_X509CLIENTAUTHENTICATIONPREFIXES="/defer" diff --git a/launchers/director_serve.go b/launchers/director_serve.go index 44c72b462..4fde00fc8 100644 --- a/launchers/director_serve.go +++ b/launchers/director_serve.go @@ -37,8 +37,11 @@ import ( func DirectorServe(ctx context.Context, engine *gin.Engine, egrp *errgroup.Group) error { log.Info("Initializing Director GeoIP database...") - director.InitializeDB(ctx) + director.InitializeGeoIPDB(ctx) + if err := director.InitializeDB(); err != nil { + return errors.Wrap(err, "failed to initialize director sqlite database") + } director.ConfigFilterdServers() director.LaunchTTLCache(ctx, egrp) diff --git a/param/parameters.go b/param/parameters.go index 06cbcf3fa..f915935ef 100644 --- a/param/parameters.go +++ b/param/parameters.go @@ -152,6 +152,7 @@ var ( Cache_Url = StringParam{"Cache.Url"} Cache_XRootDPrefix = StringParam{"Cache.XRootDPrefix"} Director_CacheSortMethod = StringParam{"Director.CacheSortMethod"} + Director_DbLocation = StringParam{"Director.DbLocation"} Director_DefaultResponse = StringParam{"Director.DefaultResponse"} Director_GeoIPLocation = StringParam{"Director.GeoIPLocation"} Director_MaxMindKeyFile = StringParam{"Director.MaxMindKeyFile"} diff --git a/param/parameters_struct.go b/param/parameters_struct.go index 212dc4688..5e077de6d 100644 --- a/param/parameters_struct.go +++ b/param/parameters_struct.go @@ -69,6 +69,7 @@ type Config struct { CachesPullFromCaches bool `mapstructure:"cachespullfromcaches"` CheckCachePresence bool `mapstructure:"checkcachepresence"` CheckOriginPresence bool `mapstructure:"checkoriginpresence"` + DbLocation string `mapstructure:"dblocation"` DefaultResponse string `mapstructure:"defaultresponse"` EnableBroker bool `mapstructure:"enablebroker"` EnableOIDC bool `mapstructure:"enableoidc"` @@ -372,6 +373,7 @@ type configWithType struct { CachesPullFromCaches struct { Type string; Value bool } CheckCachePresence struct { Type string; Value bool } CheckOriginPresence struct { Type string; Value bool } + DbLocation struct { Type string; Value string } DefaultResponse struct { Type string; Value string } EnableBroker struct { Type string; Value bool } EnableOIDC struct { Type string; Value bool }