Skip to content

Commit

Permalink
Merge pull request #4 from alan-turing-institute/3-respond-to-all-req…
Browse files Browse the repository at this point in the history
…uests

Respond to all requests
  • Loading branch information
jemrobinson authored Sep 28, 2023
2 parents da5afee + 0eaab01 commit 59fc514
Show file tree
Hide file tree
Showing 10 changed files with 185 additions and 36 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

# OS files
.DS_Store

Expand Down
37 changes: 36 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,42 @@ The name is a slightly tortured acronym for: LD**A**P **pr**oxy for Open**I**D *
Start the `Apricot` server on port 8080 by running:

```bash
python run.py --client-id "<your client ID>" --client-secret "<your client secret>" --backend <your backend> --port 8080
python run.py --client-id "<your client ID>" --client-secret "<your client secret>" --backend "<your backend>" --port 8080 --domain "<your domain name>"
```

This will create an LDAP tree that looks like this:

```ldif
dn: DC=<your domain>
objectClass: dcObject
dn: OU=users,DC=<your domain>
objectClass: organizationalUnit
ou: users
dn: OU=groups,DC=<your domain>
objectClass: organizationalUnit
ou: groups
```

Each user will have an entry like

```ldif
dn: CN=<user name>,OU=users,DC=<your domain>
objectClass: organizationalPerson
objectClass: person
objectClass: top
objectClass: user
<user data fields here>
```

Each group will have an entry like

```ldif
dn: CN=<group name>,OU=groups,DC=<your domain>
objectClass: group
objectClass: top
<group data fields here>
```

## OpenID Connect
Expand Down
6 changes: 5 additions & 1 deletion apricot/apricot_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(
backend: OAuthBackend,
client_id: str,
client_secret: str,
domain: str,
port: int,
**kwargs: Any,
) -> None:
Expand All @@ -25,7 +26,10 @@ def __init__(
# Initialize the appropriate OAuth client
try:
oauth_client = OAuthClientMap[backend](
client_id=client_id, client_secret=client_secret, **kwargs
client_id=client_id,
client_secret=client_secret,
domain=domain,
**kwargs,
)
except Exception as exc:
msg = f"Could not construct an OAuth client for the '{backend}' backend."
Expand Down
7 changes: 6 additions & 1 deletion apricot/ldap/oauth_ldap_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ def __str__(self) -> str:
output = bytes(self.toWire()).decode("utf-8")
for child in self._children.values():
try:
output += f"\n- {child!s}"
# Indent children by two spaces
indent = " "
output += (
f"{indent}{str(child).strip()}".replace("\n", f"\n{indent}")
+ "\n\n"
)
except TypeError:
pass
return output
Expand Down
4 changes: 2 additions & 2 deletions apricot/ldap/oauth_ldap_server_factory.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from ldaptor.protocols.ldap.ldapserver import LDAPServer
from twisted.internet.interfaces import IAddress
from twisted.internet.protocol import Protocol, ServerFactory

from apricot.oauth import OAuthClient

from .oauth_ldap_tree import OAuthLDAPTree
from .read_only_ldap_server import ReadOnlyLDAPServer


class OAuthLDAPServerFactory(ServerFactory):
protocol = LDAPServer
protocol = ReadOnlyLDAPServer

def __init__(self, oauth_client: OAuthClient):
"""
Expand Down
3 changes: 1 addition & 2 deletions apricot/ldap/oauth_ldap_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ def __init__(self, oauth_client: OAuthClient) -> None:
self.oauth_client = oauth_client

# Create a root node for the tree
root_dn = "DC=" + self.oauth_client.domain().replace(".", ",DC=")
self.root = self.build_root(
dn=root_dn, attributes={"objectClass": ["dcObject"]}
dn=self.oauth_client.root_dn, attributes={"objectClass": ["dcObject"]}
)
# Add OUs for users and groups
groups_ou = self.root.add_child(
Expand Down
111 changes: 111 additions & 0 deletions apricot/ldap/read_only_ldap_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from typing import Callable

from ldaptor.interfaces import ILDAPEntry
from ldaptor.protocols.ldap.ldaperrors import LDAPProtocolError
from ldaptor.protocols.ldap.ldapserver import LDAPServer
from ldaptor.protocols.pureldap import (
LDAPBindRequest,
LDAPControl,
LDAPSearchResultDone,
LDAPSearchResultEntry,
)
from twisted.internet import defer


class ReadOnlyLDAPServer(LDAPServer):
def getRootDSE( # noqa: N802
self,
request: LDAPBindRequest,
reply: Callable[[LDAPSearchResultEntry], None] | None,
) -> LDAPSearchResultDone:
"""Handle an LDAP Root RSE request"""
return super().getRootDSE(request, reply)

def handle_LDAPAddRequest( # noqa: N802
self,
request: LDAPBindRequest,
controls: LDAPControl | None,
reply: Callable[..., None] | None,
) -> defer.Deferred[ILDAPEntry]:
"""Refuse to handle an LDAP add request"""
id((request, controls, reply)) # ignore unused arguments
msg = "ReadOnlyLDAPServer will not handle LDAP add requests"
raise LDAPProtocolError(msg)

def handle_LDAPBindRequest( # noqa: N802
self,
request: LDAPBindRequest,
controls: LDAPControl | None,
reply: Callable[..., None] | None,
) -> defer.Deferred[ILDAPEntry]:
"""Handle an LDAP bind request"""
return super().handle_LDAPBindRequest(request, controls, reply)

def handle_LDAPCompareRequest( # noqa: N802
self,
request: LDAPBindRequest,
controls: LDAPControl | None,
reply: Callable[..., None] | None,
) -> defer.Deferred[ILDAPEntry]:
"""Handle an LDAP compare request"""
return super().handle_LDAPCompareRequest(request, controls, reply)

def handle_LDAPDelRequest( # noqa: N802
self,
request: LDAPBindRequest,
controls: LDAPControl | None,
reply: Callable[..., None] | None,
) -> defer.Deferred[ILDAPEntry]:
"""Refuse to handle an LDAP delete request"""
id((request, controls, reply)) # ignore unused arguments
msg = "ReadOnlyLDAPServer will not handle LDAP delete requests"
raise LDAPProtocolError(msg)

def handle_LDAPExtendedRequest( # noqa: N802
self,
request: LDAPBindRequest,
controls: LDAPControl | None,
reply: Callable[..., None] | None,
) -> defer.Deferred[ILDAPEntry]:
"""Handle an LDAP extended request"""
return super().handle_LDAPExtendedRequest(request, controls, reply)

def handle_LDAPModifyDNRequest( # noqa: N802
self,
request: LDAPBindRequest,
controls: LDAPControl | None,
reply: Callable[..., None] | None,
) -> defer.Deferred[ILDAPEntry]:
"""Refuse to handle an LDAP modify DN request"""
id((request, controls, reply)) # ignore unused arguments
msg = "ReadOnlyLDAPServer will not handle LDAP modify DN requests"
raise LDAPProtocolError(msg)

def handle_LDAPModifyRequest( # noqa: N802
self,
request: LDAPBindRequest,
controls: LDAPControl | None,
reply: Callable[..., None] | None,
) -> defer.Deferred[ILDAPEntry]:
"""Refuse to handle an LDAP modify request"""
id((request, controls, reply)) # ignore unused arguments
msg = "ReadOnlyLDAPServer will not handle LDAP modify requests"
raise LDAPProtocolError(msg)

def handle_LDAPUnbindRequest( # noqa: N802
self,
request: LDAPBindRequest,
controls: LDAPControl | None,
reply: Callable[..., None] | None,
) -> None:
"""Handle an LDAP unbind request"""
super().handle_LDAPUnbindRequest(request, controls, reply)

def handle_LDAPSearchRequest( # noqa: N802
self,
request: LDAPBindRequest,
controls: LDAPControl | None,
reply: Callable[[LDAPSearchResultEntry], None] | None,
) -> defer.Deferred[ILDAPEntry]:
"""Handle an LDAP search request"""
return super().handle_LDAPSearchRequest(request, controls, reply)
18 changes: 2 additions & 16 deletions apricot/oauth/microsoft_entra_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,19 @@ class MicrosoftEntraClient(OAuthClient):

def __init__(
self,
client_id: str,
client_secret: str,
entra_tenant_id: str,
**kwargs: Any,
):
del kwargs # consume any unused arguments
redirect_uri = "urn:ietf:wg:oauth:2.0:oob" # this is the "no redirect" URL
scopes = ["https://graph.microsoft.com/.default"] # this is the default scope
token_url = (
f"https://login.microsoftonline.com/{entra_tenant_id}/oauth2/v2.0/token"
)
self.tenant_id = entra_tenant_id
super().__init__(
client_id=client_id,
client_secret=client_secret,
redirect_uri=redirect_uri,
scopes=scopes,
token_url=token_url,
redirect_uri=redirect_uri, scopes=scopes, token_url=token_url, **kwargs
)

def domain(self) -> str:
users = self.users()
domains = {str(user["domain"][0]) for user in users}
if len(domains) > 1:
domains = {domain for domain in domains if "onmicrosoft.com" not in domain}
return sorted(domains)[0]

def extract_token(self, json_response: JSONDict) -> str:
return str(json_response["access_token"])

Expand Down Expand Up @@ -68,7 +54,7 @@ def users(self) -> list[LDAPAttributeDict]:
f"https://graph.microsoft.com/v1.0/users/{user_dict['id']}/memberOf"
)
attributes["memberOf"] = [
group["displayName"]
f"CN={group['displayName']},OU=groups,{self.root_dn}"
for group in group_memberships["value"]
if group["displayName"]
]
Expand Down
16 changes: 9 additions & 7 deletions apricot/oauth/oauth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,15 @@ def __init__(
self,
client_id: str,
client_secret: str,
domain: str,
redirect_uri: str,
scopes: list[str],
token_url: str,
) -> None:
# Set attributes
self.client_secret = client_secret
self.domain = domain
self.token_url = token_url
# Allow token scope to not match requested scope. (Other auth libraries allow
# this, but Requests-OAuthlib raises exception on scope mismatch by default.)
os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1" # noqa: S105
Expand All @@ -40,9 +45,6 @@ def __init__(
client_id=client_id, scope=scopes, redirect_uri=redirect_uri
)
)
# Store client secret and token URL
self.client_secret = client_secret
self.token_url = token_url
# Request a new bearer token
json_response = self.session_application.fetch_token(
token_url=self.token_url,
Expand All @@ -51,10 +53,6 @@ def __init__(
)
self.bearer_token = self.extract_token(json_response)

@abstractmethod
def domain(self) -> str:
pass

@abstractmethod
def extract_token(self, json_response: JSONDict) -> str:
pass
Expand All @@ -67,6 +65,10 @@ def groups(self) -> list[LDAPAttributeDict]:
def users(self) -> list[LDAPAttributeDict]:
pass

@property
def root_dn(self) -> str:
return "DC=" + self.domain.replace(".", ",DC=")

def query(self, url: str) -> dict[str, Any]:
result = self.session_application.request(
method="GET",
Expand Down
15 changes: 9 additions & 6 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
prog="Apricot",
description="Apricot is a proxy for delegating LDAP requests to an OpenID Connect backend.",
)
parser.add_argument("-b", "--backend", type=OAuthBackend, help="Which OAuth backend to use")
parser.add_argument("-p", "--port", type=int, default=8080, help="Port to run on")
parser.add_argument("-i", "--client-id", type=str, help="OAuth client ID")
parser.add_argument("-s", "--client-secret", type=str, help="OAuth client secret")
parser.add_argument("-t", "--entra-tenant-id", type=str, help="Microsoft Entra tenant id")

# Common options needed for all backends
parser.add_argument("-b", "--backend", type=OAuthBackend, help="Which OAuth backend to use.")
parser.add_argument("-d", "--domain", type=str, help="Which domain users belong to.")
parser.add_argument("-p", "--port", type=int, default=8080, help="Port to run on.")
parser.add_argument("-i", "--client-id", type=str, help="OAuth client ID.")
parser.add_argument("-s", "--client-secret", type=str, help="OAuth client secret.")
# Options for Microsoft Entra backend
parser.add_argument("-t", "--entra-tenant-id", type=str, help="Microsoft Entra tenant ID.", required=False)
# Parse arguments
args = parser.parse_args()

# Create the Apricot server
Expand Down

0 comments on commit 59fc514

Please sign in to comment.