From 9944ddac3d372621c7f7f8bc0c6d9701a0fbea89 Mon Sep 17 00:00:00 2001 From: abhishek9686 Date: Sun, 8 Dec 2024 21:56:08 +0400 Subject: [PATCH] add user acl rules on target node --- logic/acls.go | 140 +++++++++++++++++++++++++++++++-------------- logic/extpeers.go | 18 +++--- logic/nodes.go | 21 +++++++ logic/peers.go | 2 +- logic/user_mgmt.go | 19 ++++++ 5 files changed, 147 insertions(+), 53 deletions(-) diff --git a/logic/acls.go b/logic/acls.go index d0ca5012e..7168058c7 100644 --- a/logic/acls.go +++ b/logic/acls.go @@ -4,7 +4,6 @@ import ( "encoding/json" "errors" "fmt" - "net" "sort" "sync" "time" @@ -503,6 +502,18 @@ func listDevicePolicies(netID models.NetworkID) []models.Acl { return deviceAcls } +// listUserPolicies - lists all user policies in a network +func listUserPolicies(netID models.NetworkID) []models.Acl { + allAcls := ListAcls() + deviceAcls := []models.Acl{} + for _, acl := range allAcls { + if acl.NetworkID == netID && acl.RuleType == models.UserPolicy { + deviceAcls = append(deviceAcls, acl) + } + } + return deviceAcls +} + // ListAcls - lists all acl policies func ListAclsByNetwork(netID models.NetworkID) ([]models.Acl, error) { @@ -770,27 +781,83 @@ func RemoveDeviceTagFromAclPolicies(tagID models.TagID, netID models.NetworkID) return nil } -func GetAclRulesForNode(node *models.Node) (rules map[string]models.AclRule) { +func getUserAclRulesForNode(targetnode *models.Node, + rules map[string]models.AclRule) map[string]models.AclRule { + userNodes := GetStaticUserNodesByNetwork(models.NetworkID(targetnode.Network)) + userGrpMap := GetUserGrpMap() + allowedUsers := make(map[string]models.Acl) + acls := listUserPolicies(models.NetworkID(targetnode.Network)) + for nodeTag := range targetnode.Tags { + for _, acl := range acls { + if !acl.Enabled { + continue + } + dstTags := convAclTagToValueMap(acl.Dst) + if _, ok := dstTags[nodeTag.String()]; ok { + // get all src tags + for _, srcAcl := range acl.Src { + if srcAcl.ID == models.UserAclID { + allowedUsers[srcAcl.Value] = acl + } else if srcAcl.ID == models.UserGroupAclID { + // fetch all users in the group + if usersMap, ok := userGrpMap[models.UserGroupID(srcAcl.Value)]; ok { + for userName := range usersMap { + allowedUsers[userName] = acl + } + } + } + } - rules = make(map[string]models.AclRule) - defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy) - if err == nil && defaultPolicy.Enabled { - rules[defaultPolicy.ID] = models.AclRule{ - IPList: []net.IPNet{node.NetworkRange}, - IP6List: []net.IPNet{node.NetworkRange6}, - AllowedProtocol: models.ALL, - Direction: models.TrafficDirectionBi, + } + } + } + for _, userNode := range userNodes { + if !userNode.StaticNode.Enabled { + continue + } + acl, ok := allowedUsers[userNode.StaticNode.OwnerID] + if !ok { + continue + } + if !acl.Enabled { + continue + } + + r := models.AclRule{ + ID: acl.ID, + AllowedProtocol: acl.Proto, + AllowedPorts: acl.Port, + Direction: acl.AllowedDirection, Allowed: true, } - return + // Get peers in the tags and add allowed rules + if userNode.StaticNode.Address != "" { + r.IPList = append(r.IPList, userNode.StaticNode.AddressIPNet4()) + } + if userNode.StaticNode.Address6 != "" { + r.IP6List = append(r.IP6List, userNode.StaticNode.AddressIPNet6()) + } + if aclRule, ok := rules[acl.ID]; ok { + aclRule.IPList = append(aclRule.IPList, r.IPList...) + aclRule.IP6List = append(aclRule.IP6List, r.IP6List...) + rules[acl.ID] = aclRule + } else { + rules[acl.ID] = r + } } + return rules +} - taggedNodes := GetTagMapWithNodesByNetwork(models.NetworkID(node.Network)) - acls := listDevicePolicies(models.NetworkID(node.Network)) - //allowedNodeUniqueMap := make(map[string]struct{}) - for nodeTag := range node.Tags { +func GetAclRulesForNode(targetnode *models.Node) (rules map[string]models.AclRule) { + defer func() { + rules = getUserAclRulesForNode(targetnode, rules) + }() + rules = make(map[string]models.AclRule) + taggedNodes := GetTagMapWithNodesByNetwork(models.NetworkID(targetnode.Network)) + acls := listDevicePolicies(models.NetworkID(targetnode.Network)) + for nodeTag := range targetnode.Tags { for _, acl := range acls { - if acl.Default || !acl.Enabled { + if !acl.Enabled { continue } srcTags := convAclTagToValueMap(acl.Src) @@ -805,31 +872,6 @@ func GetAclRulesForNode(node *models.Node) (rules map[string]models.AclRule) { if acl.AllowedDirection == models.TrafficDirectionBi { var existsInSrcTag bool var existsInDstTag bool - // if contains all resources, return entire cidr - if _, ok := srcTags["*"]; ok { - return map[string]models.AclRule{ - acl.ID: { - IPList: []net.IPNet{node.NetworkRange}, - IP6List: []net.IPNet{node.NetworkRange6}, - AllowedProtocol: models.ALL, - AllowedPorts: acl.Port, - Direction: acl.AllowedDirection, - Allowed: true, - }, - } - } - if _, ok := dstTags["*"]; ok { - return map[string]models.AclRule{ - acl.ID: { - IPList: []net.IPNet{node.NetworkRange}, - IP6List: []net.IPNet{node.NetworkRange6}, - AllowedProtocol: models.ALL, - AllowedPorts: acl.Port, - Direction: acl.AllowedDirection, - Allowed: true, - }, - } - } if _, ok := srcTags[nodeTag.String()]; ok { existsInSrcTag = true @@ -838,7 +880,7 @@ func GetAclRulesForNode(node *models.Node) (rules map[string]models.AclRule) { existsInDstTag = true } - if existsInSrcTag { + if existsInSrcTag && !existsInDstTag { // get all dst tags for dst := range dstTags { if dst == nodeTag.String() { @@ -847,6 +889,9 @@ func GetAclRulesForNode(node *models.Node) (rules map[string]models.AclRule) { // Get peers in the tags and add allowed rules nodes := taggedNodes[models.TagID(dst)] for _, node := range nodes { + if node.ID == targetnode.ID { + continue + } if node.Address.IP != nil { aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4()) } @@ -862,7 +907,7 @@ func GetAclRulesForNode(node *models.Node) (rules map[string]models.AclRule) { } } } - if existsInDstTag { + if existsInDstTag && !existsInSrcTag { // get all src tags for src := range srcTags { if src == nodeTag.String() { @@ -871,6 +916,9 @@ func GetAclRulesForNode(node *models.Node) (rules map[string]models.AclRule) { // Get peers in the tags and add allowed rules nodes := taggedNodes[models.TagID(src)] for _, node := range nodes { + if node.ID == targetnode.ID { + continue + } if node.Address.IP != nil { aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4()) } @@ -889,6 +937,9 @@ func GetAclRulesForNode(node *models.Node) (rules map[string]models.AclRule) { if existsInDstTag && existsInSrcTag { nodes := taggedNodes[nodeTag] for _, node := range nodes { + if node.ID == targetnode.ID { + continue + } if node.Address.IP != nil { aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4()) } @@ -913,6 +964,9 @@ func GetAclRulesForNode(node *models.Node) (rules map[string]models.AclRule) { // Get peers in the tags and add allowed rules nodes := taggedNodes[models.TagID(src)] for _, node := range nodes { + if node.ID == targetnode.ID { + continue + } if node.Address.IP != nil { aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4()) } diff --git a/logic/extpeers.go b/logic/extpeers.go index d8a9aad29..aead8c1dc 100644 --- a/logic/extpeers.go +++ b/logic/extpeers.go @@ -737,15 +737,15 @@ func GetExtPeers(node, peer *models.Node) ([]wgtypes.PeerConfig, []models.IDandA if !IsClientNodeAllowed(&extPeer, peer.ID.String()) { continue } - if extPeer.RemoteAccessClientID == "" { - if ok, _ := IsNodeAllowedToCommunicate(extPeer.ConvertToStaticNode(), *peer); !ok { - continue - } - } else { - if ok, _ := IsUserAllowedToCommunicate(extPeer.OwnerID, *peer); !ok { - continue - } - } + // if extPeer.RemoteAccessClientID == "" { + // if ok, _ := IsNodeAllowedToCommunicate(extPeer.ConvertToStaticNode(), *peer); !ok { + // continue + // } + // } else { + // if ok, _ := IsUserAllowedToCommunicate(extPeer.OwnerID, *peer); !ok { + // continue + // } + // } pubkey, err := wgtypes.ParseKey(extPeer.PublicKey) if err != nil { diff --git a/logic/nodes.go b/logic/nodes.go index 63bbe5dce..9cb73c47c 100644 --- a/logic/nodes.go +++ b/logic/nodes.go @@ -790,6 +790,27 @@ func AddTagMapWithStaticNodes(netID models.NetworkID, return tagNodesMap } +func AddTagMapWithStaticNodesWithUsers(netID models.NetworkID, + tagNodesMap map[models.TagID][]models.Node) map[models.TagID][]models.Node { + extclients, err := GetNetworkExtClients(netID.String()) + if err != nil { + return tagNodesMap + } + for _, extclient := range extclients { + if extclient.Tags == nil { + continue + } + for tagID := range extclient.Tags { + tagNodesMap[tagID] = append(tagNodesMap[tagID], models.Node{ + IsStatic: true, + StaticNode: extclient, + }) + } + + } + return tagNodesMap +} + func GetNodesWithTag(tagID models.TagID) map[string]models.Node { nMap := make(map[string]models.Node) tag, err := GetTag(tagID) diff --git a/logic/peers.go b/logic/peers.go index af5d66a56..098467ced 100644 --- a/logic/peers.go +++ b/logic/peers.go @@ -169,7 +169,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N hostPeerUpdate.FwUpdate.AllowAll = false } hostPeerUpdate.FwUpdate.AclRules = GetAclRulesForNode(&node) - if host.Name == "Test-Server" { + if host.Name == "lon-1" { fmt.Println("##### DEF POL ", defaultDevicePolicy.Enabled, defaultUserPolicy.Enabled) fmt.Printf("ACL Rules: %+v\n", hostPeerUpdate.FwUpdate.AclRules) } diff --git a/logic/user_mgmt.go b/logic/user_mgmt.go index 56395c78c..f6eccac20 100644 --- a/logic/user_mgmt.go +++ b/logic/user_mgmt.go @@ -98,6 +98,25 @@ func ListPlatformRoles() ([]models.UserRolePermissionTemplate, error) { return userRoles, nil } +func GetUserGrpMap() map[models.UserGroupID]map[string]struct{} { + grpUsersMap := make(map[models.UserGroupID]map[string]struct{}) + users, _ := GetUsersDB() + for _, user := range users { + for gID := range user.UserGroups { + if grpUsers, ok := grpUsersMap[gID]; ok { + grpUsers[user.UserName] = struct{}{} + grpUsersMap[gID] = grpUsers + } else { + grpUsersMap[gID] = make(map[string]struct{}) + grpUsersMap[gID][user.UserName] = struct{}{} + } + } + + } + + return grpUsersMap +} + func userRolesInit() { d, _ := json.Marshal(SuperAdminPermissionTemplate) database.Insert(SuperAdminPermissionTemplate.ID.String(), string(d), database.USER_PERMISSIONS_TABLE_NAME)