Skip to content

Commit

Permalink
add user acl rules on target node
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishek9686 committed Dec 8, 2024
1 parent 602a8ec commit 9944dda
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 53 deletions.
140 changes: 97 additions & 43 deletions logic/acls.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"encoding/json"
"errors"
"fmt"
"net"
"sort"
"sync"
"time"
Expand Down Expand Up @@ -503,6 +502,18 @@ func listDevicePolicies(netID models.NetworkID) []models.Acl {
return deviceAcls
}

// listUserPolicies - lists all user policies in a network
func listUserPolicies(netID models.NetworkID) []models.Acl {
allAcls := ListAcls()
deviceAcls := []models.Acl{}
for _, acl := range allAcls {
if acl.NetworkID == netID && acl.RuleType == models.UserPolicy {
deviceAcls = append(deviceAcls, acl)
}
}
return deviceAcls
}

// ListAcls - lists all acl policies
func ListAclsByNetwork(netID models.NetworkID) ([]models.Acl, error) {

Expand Down Expand Up @@ -770,27 +781,83 @@ func RemoveDeviceTagFromAclPolicies(tagID models.TagID, netID models.NetworkID)
return nil
}

func GetAclRulesForNode(node *models.Node) (rules map[string]models.AclRule) {
func getUserAclRulesForNode(targetnode *models.Node,
rules map[string]models.AclRule) map[string]models.AclRule {
userNodes := GetStaticUserNodesByNetwork(models.NetworkID(targetnode.Network))
userGrpMap := GetUserGrpMap()
allowedUsers := make(map[string]models.Acl)
acls := listUserPolicies(models.NetworkID(targetnode.Network))
for nodeTag := range targetnode.Tags {
for _, acl := range acls {
if !acl.Enabled {
continue
}
dstTags := convAclTagToValueMap(acl.Dst)
if _, ok := dstTags[nodeTag.String()]; ok {
// get all src tags
for _, srcAcl := range acl.Src {
if srcAcl.ID == models.UserAclID {
allowedUsers[srcAcl.Value] = acl
} else if srcAcl.ID == models.UserGroupAclID {
// fetch all users in the group
if usersMap, ok := userGrpMap[models.UserGroupID(srcAcl.Value)]; ok {
for userName := range usersMap {
allowedUsers[userName] = acl
}
}
}
}

rules = make(map[string]models.AclRule)
defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
if err == nil && defaultPolicy.Enabled {
rules[defaultPolicy.ID] = models.AclRule{
IPList: []net.IPNet{node.NetworkRange},
IP6List: []net.IPNet{node.NetworkRange6},
AllowedProtocol: models.ALL,
Direction: models.TrafficDirectionBi,
}
}
}
for _, userNode := range userNodes {
if !userNode.StaticNode.Enabled {
continue
}
acl, ok := allowedUsers[userNode.StaticNode.OwnerID]
if !ok {
continue
}
if !acl.Enabled {
continue
}

r := models.AclRule{
ID: acl.ID,
AllowedProtocol: acl.Proto,
AllowedPorts: acl.Port,
Direction: acl.AllowedDirection,
Allowed: true,
}
return
// Get peers in the tags and add allowed rules
if userNode.StaticNode.Address != "" {
r.IPList = append(r.IPList, userNode.StaticNode.AddressIPNet4())
}
if userNode.StaticNode.Address6 != "" {
r.IP6List = append(r.IP6List, userNode.StaticNode.AddressIPNet6())
}
if aclRule, ok := rules[acl.ID]; ok {
aclRule.IPList = append(aclRule.IPList, r.IPList...)
aclRule.IP6List = append(aclRule.IP6List, r.IP6List...)
rules[acl.ID] = aclRule
} else {
rules[acl.ID] = r
}
}
return rules
}

taggedNodes := GetTagMapWithNodesByNetwork(models.NetworkID(node.Network))
acls := listDevicePolicies(models.NetworkID(node.Network))
//allowedNodeUniqueMap := make(map[string]struct{})
for nodeTag := range node.Tags {
func GetAclRulesForNode(targetnode *models.Node) (rules map[string]models.AclRule) {
defer func() {
rules = getUserAclRulesForNode(targetnode, rules)
}()
rules = make(map[string]models.AclRule)
taggedNodes := GetTagMapWithNodesByNetwork(models.NetworkID(targetnode.Network))
acls := listDevicePolicies(models.NetworkID(targetnode.Network))
for nodeTag := range targetnode.Tags {
for _, acl := range acls {
if acl.Default || !acl.Enabled {
if !acl.Enabled {
continue
}
srcTags := convAclTagToValueMap(acl.Src)
Expand All @@ -805,31 +872,6 @@ func GetAclRulesForNode(node *models.Node) (rules map[string]models.AclRule) {
if acl.AllowedDirection == models.TrafficDirectionBi {
var existsInSrcTag bool
var existsInDstTag bool
// if contains all resources, return entire cidr
if _, ok := srcTags["*"]; ok {
return map[string]models.AclRule{
acl.ID: {
IPList: []net.IPNet{node.NetworkRange},
IP6List: []net.IPNet{node.NetworkRange6},
AllowedProtocol: models.ALL,
AllowedPorts: acl.Port,
Direction: acl.AllowedDirection,
Allowed: true,
},
}
}
if _, ok := dstTags["*"]; ok {
return map[string]models.AclRule{
acl.ID: {
IPList: []net.IPNet{node.NetworkRange},
IP6List: []net.IPNet{node.NetworkRange6},
AllowedProtocol: models.ALL,
AllowedPorts: acl.Port,
Direction: acl.AllowedDirection,
Allowed: true,
},
}
}

if _, ok := srcTags[nodeTag.String()]; ok {
existsInSrcTag = true
Expand All @@ -838,7 +880,7 @@ func GetAclRulesForNode(node *models.Node) (rules map[string]models.AclRule) {
existsInDstTag = true
}

if existsInSrcTag {
if existsInSrcTag && !existsInDstTag {
// get all dst tags
for dst := range dstTags {
if dst == nodeTag.String() {
Expand All @@ -847,6 +889,9 @@ func GetAclRulesForNode(node *models.Node) (rules map[string]models.AclRule) {
// Get peers in the tags and add allowed rules
nodes := taggedNodes[models.TagID(dst)]
for _, node := range nodes {
if node.ID == targetnode.ID {
continue
}
if node.Address.IP != nil {
aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
}
Expand All @@ -862,7 +907,7 @@ func GetAclRulesForNode(node *models.Node) (rules map[string]models.AclRule) {
}
}
}
if existsInDstTag {
if existsInDstTag && !existsInSrcTag {
// get all src tags
for src := range srcTags {
if src == nodeTag.String() {
Expand All @@ -871,6 +916,9 @@ func GetAclRulesForNode(node *models.Node) (rules map[string]models.AclRule) {
// Get peers in the tags and add allowed rules
nodes := taggedNodes[models.TagID(src)]
for _, node := range nodes {
if node.ID == targetnode.ID {
continue
}
if node.Address.IP != nil {
aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
}
Expand All @@ -889,6 +937,9 @@ func GetAclRulesForNode(node *models.Node) (rules map[string]models.AclRule) {
if existsInDstTag && existsInSrcTag {
nodes := taggedNodes[nodeTag]
for _, node := range nodes {
if node.ID == targetnode.ID {
continue
}
if node.Address.IP != nil {
aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
}
Expand All @@ -913,6 +964,9 @@ func GetAclRulesForNode(node *models.Node) (rules map[string]models.AclRule) {
// Get peers in the tags and add allowed rules
nodes := taggedNodes[models.TagID(src)]
for _, node := range nodes {
if node.ID == targetnode.ID {
continue
}
if node.Address.IP != nil {
aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
}
Expand Down
18 changes: 9 additions & 9 deletions logic/extpeers.go
Original file line number Diff line number Diff line change
Expand Up @@ -737,15 +737,15 @@ func GetExtPeers(node, peer *models.Node) ([]wgtypes.PeerConfig, []models.IDandA
if !IsClientNodeAllowed(&extPeer, peer.ID.String()) {
continue
}
if extPeer.RemoteAccessClientID == "" {
if ok, _ := IsNodeAllowedToCommunicate(extPeer.ConvertToStaticNode(), *peer); !ok {
continue
}
} else {
if ok, _ := IsUserAllowedToCommunicate(extPeer.OwnerID, *peer); !ok {
continue
}
}
// if extPeer.RemoteAccessClientID == "" {
// if ok, _ := IsNodeAllowedToCommunicate(extPeer.ConvertToStaticNode(), *peer); !ok {
// continue
// }
// } else {
// if ok, _ := IsUserAllowedToCommunicate(extPeer.OwnerID, *peer); !ok {
// continue
// }
// }

pubkey, err := wgtypes.ParseKey(extPeer.PublicKey)
if err != nil {
Expand Down
21 changes: 21 additions & 0 deletions logic/nodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,27 @@ func AddTagMapWithStaticNodes(netID models.NetworkID,
return tagNodesMap
}

func AddTagMapWithStaticNodesWithUsers(netID models.NetworkID,
tagNodesMap map[models.TagID][]models.Node) map[models.TagID][]models.Node {
extclients, err := GetNetworkExtClients(netID.String())
if err != nil {
return tagNodesMap
}
for _, extclient := range extclients {
if extclient.Tags == nil {
continue
}
for tagID := range extclient.Tags {
tagNodesMap[tagID] = append(tagNodesMap[tagID], models.Node{
IsStatic: true,
StaticNode: extclient,
})
}

}
return tagNodesMap
}

func GetNodesWithTag(tagID models.TagID) map[string]models.Node {
nMap := make(map[string]models.Node)
tag, err := GetTag(tagID)
Expand Down
2 changes: 1 addition & 1 deletion logic/peers.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
hostPeerUpdate.FwUpdate.AllowAll = false
}
hostPeerUpdate.FwUpdate.AclRules = GetAclRulesForNode(&node)
if host.Name == "Test-Server" {
if host.Name == "lon-1" {
fmt.Println("##### DEF POL ", defaultDevicePolicy.Enabled, defaultUserPolicy.Enabled)
fmt.Printf("ACL Rules: %+v\n", hostPeerUpdate.FwUpdate.AclRules)
}
Expand Down
19 changes: 19 additions & 0 deletions logic/user_mgmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,25 @@ func ListPlatformRoles() ([]models.UserRolePermissionTemplate, error) {
return userRoles, nil
}

func GetUserGrpMap() map[models.UserGroupID]map[string]struct{} {
grpUsersMap := make(map[models.UserGroupID]map[string]struct{})
users, _ := GetUsersDB()
for _, user := range users {
for gID := range user.UserGroups {
if grpUsers, ok := grpUsersMap[gID]; ok {
grpUsers[user.UserName] = struct{}{}
grpUsersMap[gID] = grpUsers
} else {
grpUsersMap[gID] = make(map[string]struct{})
grpUsersMap[gID][user.UserName] = struct{}{}
}
}

}

return grpUsersMap
}

func userRolesInit() {
d, _ := json.Marshal(SuperAdminPermissionTemplate)
database.Insert(SuperAdminPermissionTemplate.ID.String(), string(d), database.USER_PERMISSIONS_TABLE_NAME)
Expand Down

0 comments on commit 9944dda

Please sign in to comment.