diff --git a/Casbin/Rbac/DefaultRoleManager.cs b/Casbin/Rbac/DefaultRoleManager.cs index 1e3452a8..657d72fb 100644 --- a/Casbin/Rbac/DefaultRoleManager.cs +++ b/Casbin/Rbac/DefaultRoleManager.cs @@ -230,15 +230,6 @@ private bool HasLinkInDomain(string name1, string name2, string domain) public virtual void AddLink(string name1, string name2, string domain = null) { domain ??= _defaultDomain; - if (HasDomainPattern) - { - foreach (string matchDomain in GetPatternDomains(domain)) - { - AddLinkInDomain(name1, name2, matchDomain); - } - _cachedAllDomains = null; - return; - } _cachedAllDomains = null; AddLinkInDomain(name1, name2, domain); } @@ -249,10 +240,25 @@ private void AddLinkInDomain(string name1, string name2, string domain) ? _defaultRoles : _allDomains.GetOrAdd(domain, new ConcurrentDictionary()); + bool role1IsNew = roles.ContainsKey(name1) is false; + bool role2IsNew = roles.ContainsKey(name2) is false; + Role role1 = roles.GetOrAdd(name1, new Role(name1, domain)); Role role2 = roles.GetOrAdd(name2, new Role(name2, domain)); role1.AddRole(role2); + if (HasDomainPattern) + { + if (role1IsNew) + { + AddLinksFromMatchingDomains(role1); + } + if (role2IsNew) + { + AddLinksToMatchingDomains(role2); + } + } + if (HasPattern is false) { return; @@ -276,6 +282,87 @@ private void AddLinkInDomain(string name1, string name2, string domain) } } + private void AddLinksFromMatchingDomains(Role role) + { + if (HasDomainPattern is false) + { + return; + } + + IEnumerable matchingDomains = GetMatchingDomains(role.Domain); + + foreach (string domain in matchingDomains) + { + if (domain == role.Domain || _allDomains.TryGetValue(domain, out var matchingDomain) is false) + { + continue; + } + + if (HasPattern is false) + { + if (matchingDomain.TryGetValue(role.Name, out var matchingRole) && role != matchingRole) + { + matchingRole.AddRole(role); + }; + continue; + } + + var roleNames = matchingDomain.Keys; + foreach (string roleName in roleNames) + { + if (MatchingFunc(roleName, role.Name) is false) + { + continue; + } + if (matchingDomain.TryGetValue(roleName, out var matchingRole) && role != matchingRole) + { + matchingRole.AddRole(role); + } + } + } + } + + private void AddLinksToMatchingDomains(Role role) + { + if (HasDomainPattern is false) + { + return; + } + + IEnumerable matchingDomains = GetPatternDomains(role.Domain); + + foreach (string domain in matchingDomains) + { + if (domain == role.Domain || _allDomains.TryGetValue(domain, out var matchingDomain) is false) + { + continue; + } + + if (HasPattern is false) + { + if (matchingDomain.TryGetValue(role.Name, out var matchingRole) && role != matchingRole) + { + role.AddRole(matchingRole); + }; + continue; + } + + var roleNames = matchingDomain.Keys; + foreach (string roleName in roleNames) + { + if (MatchingFunc(role.Name, roleName) is false) + { + continue; + } + if (matchingDomain.TryGetValue(roleName, out var matchingRole) && matchingRole != role) + { + role.AddRole(matchingRole); + } + } + + } + } + public virtual void DeleteLink(string name1, string name2, string domain = null) { domain ??= _defaultDomain; @@ -321,5 +408,17 @@ private IEnumerable GetPatternDomains(string domain) } return matchDomains; } + + private IEnumerable GetMatchingDomains(string domainPattern) + { + List matchDomains = new() { domainPattern }; + _cachedAllDomains ??= _allDomains.Keys; + if (HasDomainPattern) + { + matchDomains.AddRange(_cachedAllDomains.Where(key => + DomainMatchingFunc(key, domainPattern) && key != domainPattern)); + } + return matchDomains; + } } }