From e8d3e80b081b774fdfc3c2b3012d7119ad61d8bd Mon Sep 17 00:00:00 2001 From: Max Ma Date: Fri, 29 Nov 2024 11:35:34 +0100 Subject: [PATCH] initialize cache in startup --- controllers/acls.go | 2 +- logic/acls.go | 18 +++++++++--------- main.go | 9 +++++++++ migrate/migrate.go | 2 -- 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/controllers/acls.go b/controllers/acls.go index 727811fb5..99d67876d 100644 --- a/controllers/acls.go +++ b/controllers/acls.go @@ -91,7 +91,7 @@ func getAcls(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } - acls, err := logic.ListAcls(models.NetworkID(netID)) + acls, err := logic.ListAclsByNetwork(models.NetworkID(netID)) if err != nil { logger.Log(0, r.Header.Get("user"), "failed to get all network acl entries: ", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) diff --git a/logic/acls.go b/logic/acls.go index 9a4dcb4aa..f48f0ee27 100644 --- a/logic/acls.go +++ b/logic/acls.go @@ -23,7 +23,7 @@ func CreateDefaultAclNetworkPolicies(netID models.NetworkID) { if netID.String() == "" { return } - _, _ = ListAcls(netID) + _, _ = ListAclsByNetwork(netID) if !IsAclExists(fmt.Sprintf("%s.%s", netID, "all-nodes")) { defaultDeviceAcl := models.Acl{ ID: fmt.Sprintf("%s.%s", netID, "all-nodes"), @@ -106,7 +106,7 @@ func CreateDefaultAclNetworkPolicies(netID models.NetworkID) { // DeleteDefaultNetworkPolicies - deletes all default network acl policies func DeleteDefaultNetworkPolicies(netId models.NetworkID) { - acls, _ := ListAcls(netId) + acls, _ := ListAclsByNetwork(netId) for _, acl := range acls { if acl.NetworkID == netId && acl.Default { DeleteAcl(acl) @@ -347,7 +347,7 @@ func GetDefaultPolicy(netID models.NetworkID, ruleType models.AclPolicyType) (mo return acl, nil } // check if there are any custom all policies - policies, _ := ListAcls(netID) + policies, _ := ListAclsByNetwork(netID) for _, policy := range policies { if !policy.Enabled { continue @@ -367,7 +367,7 @@ func GetDefaultPolicy(netID models.NetworkID, ruleType models.AclPolicyType) (mo return acl, nil } -func listAcls() (acls []models.Acl) { +func ListAcls() (acls []models.Acl) { if servercfg.CacheEnabled() && len(aclCacheMap) > 0 { return listAclFromCache() } @@ -393,7 +393,7 @@ func listAcls() (acls []models.Acl) { // ListUserPolicies - lists all acl policies enforced on an user func ListUserPolicies(u models.User) []models.Acl { - allAcls := listAcls() + allAcls := ListAcls() userAcls := []models.Acl{} for _, acl := range allAcls { @@ -418,7 +418,7 @@ func ListUserPolicies(u models.User) []models.Acl { // listPoliciesOfUser - lists all user acl policies applied to user in an network func listPoliciesOfUser(user models.User, netID models.NetworkID) []models.Acl { - allAcls := listAcls() + allAcls := ListAcls() userAcls := []models.Acl{} for _, acl := range allAcls { if acl.NetworkID == netID && acl.RuleType == models.UserPolicy { @@ -447,7 +447,7 @@ func listPoliciesOfUser(user models.User, netID models.NetworkID) []models.Acl { // listDevicePolicies - lists all device policies in a network func listDevicePolicies(netID models.NetworkID) []models.Acl { - allAcls := listAcls() + allAcls := ListAcls() deviceAcls := []models.Acl{} for _, acl := range allAcls { if acl.NetworkID == netID && acl.RuleType == models.DevicePolicy { @@ -458,9 +458,9 @@ func listDevicePolicies(netID models.NetworkID) []models.Acl { } // ListAcls - lists all acl policies -func ListAcls(netID models.NetworkID) ([]models.Acl, error) { +func ListAclsByNetwork(netID models.NetworkID) ([]models.Acl, error) { - allAcls := listAcls() + allAcls := ListAcls() netAcls := []models.Acl{} for _, acl := range allAcls { if acl.NetworkID == netID { diff --git a/main.go b/main.go index 10bb52b8b..33eebb71c 100644 --- a/main.go +++ b/main.go @@ -99,6 +99,15 @@ func initialize() { // Client Mode Prereq Check logger.FatalLog("Error connecting to database: ", err.Error()) } logger.Log(0, "database successfully connected") + + //initialize cache + _, _ = logic.GetNetworks() + _, _ = logic.GetAllNodes() + _, _ = logic.GetAllHosts() + _, _ = logic.GetAllExtClients() + _ = logic.ListAcls() + _, _ = logic.GetAllEnrollmentKeys() + migrate.Run() logic.SetJWTSecret() diff --git a/migrate/migrate.go b/migrate/migrate.go index b4a866ab7..19e9232aa 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -20,8 +20,6 @@ import ( // Run - runs all migrations func Run() { - _, _ = logic.GetAllNodes() - _, _ = logic.GetAllHosts() updateEnrollmentKeys() assignSuperAdmin() createDefaultTagsAndPolicies()