Skip to content

Commit

Permalink
[Outlook] Feature: Office365 multi-user support
Browse files Browse the repository at this point in the history
Introduced BaseOffice365User abstract base class to standardize Office 365 user handling.
Added MultiOffice365Users to manage multiple emails from config.
Added client_emails (comma-separated) in OutlookDataSource config.
Resolved issue with fetching too many users causing SMTP server not found error.
  • Loading branch information
ilyasabdellaoui committed Sep 23, 2024
1 parent ce5a4c1 commit 897fed8
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 29 deletions.
119 changes: 108 additions & 11 deletions connectors/sources/outlook.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@

import asyncio
import os
from abc import ABC, abstractmethod
from copy import copy
from datetime import date
from functools import cached_property, partial
from typing import List

import aiofiles
import aiohttp
Expand Down Expand Up @@ -348,13 +350,13 @@ async def get_user_accounts(self):
yield user_account


class Office365Users:
"""Fetch users from Office365 Active Directory"""
class BaseOffice365User(ABC):
"""Abstract base class for Office 365 user management"""

def __init__(self, client_id, client_secret, tenant_id):
self.tenant_id = tenant_id
self.client_id = client_id
self.client_secret = client_secret
self.tenant_id = tenant_id

@cached_property
def _get_session(self):
Expand Down Expand Up @@ -403,6 +405,21 @@ async def _fetch_token(self):
except Exception as exception:
self._check_errors(response=exception)

@abstractmethod
async def get_users(self):
pass

@abstractmethod
async def get_user_accounts(self):
pass


class Office365Users(BaseOffice365User):
"""Fetch users from Office365 Active Directory"""

def __init__(self, client_id, client_secret, tenant_id):
super().__init__(client_id, client_secret, tenant_id)

@retryable(
retries=RETRIES,
interval=RETRY_INTERVAL,
Expand Down Expand Up @@ -456,6 +473,57 @@ async def get_user_accounts(self):
yield user_account


class MultiOffice365Users(BaseOffice365User):
"""Fetch multiple Office365 users based on a list of email addresses."""

def __init__(self, client_id, client_secret, tenant_id, client_emails: List[str]):
super().__init__(client_id, client_secret, tenant_id)
self.client_emails = client_emails

async def get_users(self):
access_token = await self._fetch_token()
for email in self.client_emails:
url = f"https://graph.microsoft.com/v1.0/users/{email}"
try:
async with self._get_session.get(
url=url,
headers={
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
},
) as response:
json_response = await response.json()
yield json_response
except Exception:
raise

async def get_user_accounts(self):
async for user in self.get_users():
mail = user.get("mail")
if mail is None:
continue

credentials = OAuth2Credentials(
client_id=self.client_id,
tenant_id=self.tenant_id,
client_secret=self.client_secret,
identity=Identity(primary_smtp_address=mail),
)
configuration = Configuration(
credentials=credentials,
auth_type=OAUTH2,
service_endpoint=EWS_ENDPOINT,
retry_policy=FaultTolerance(max_wait=120),
)
user_account = Account(
primary_smtp_address=mail,
config=configuration,
autodiscover=False,
access_type=IMPERSONATION,
)
yield user_account


class OutlookDocFormatter:
"""Format Outlook object documents to Elasticsearch document"""

Expand Down Expand Up @@ -583,6 +651,27 @@ def attachment_doc_formatter(self, attachment, attachment_type, timezone):
}


class UserFactory:
"""Factory class for creating Office365 user instances"""

@staticmethod
def create_user(configuration: dict) -> BaseOffice365User:
if configuration.get("client_emails"):
client_emails = [email.strip() for email in configuration["client_emails"].split(",")]
return MultiOffice365Users(
client_id=configuration["client_id"],
client_secret=configuration["client_secret"],
tenant_id=configuration["tenant_id"],
client_emails=client_emails
)
else:
return Office365Users(
client_id=configuration["client_id"],
client_secret=configuration["client_secret"],
tenant_id=configuration["tenant_id"]
)


class OutlookClient:
"""Outlook client to handle API calls made to Outlook"""

Expand All @@ -605,11 +694,7 @@ def set_logger(self, logger_):
@cached_property
def _get_user_instance(self):
if self.is_cloud:
return Office365Users(
client_id=self.configuration["client_id"],
client_secret=self.configuration["client_secret"],
tenant_id=self.configuration["tenant_id"],
)
return UserFactory.create_user(self.configuration)

return ExchangeUsers(
ad_server=self.configuration["active_directory_server"],
Expand Down Expand Up @@ -666,9 +751,12 @@ async def get_tasks(self, account):
yield task

async def get_contacts(self, account):
folder = account.root / "Top of Information Store" / "Contacts"
for contact in await asyncio.to_thread(folder.all().only, *CONTACT_FIELDS):
yield contact
try:
folder = account.root / "Top of Information Store" / "Contacts"
for contact in await asyncio.to_thread(folder.all().only, *CONTACT_FIELDS):
yield contact
except Exception:
raise


class OutlookDataSource(BaseDataSource):
Expand Down Expand Up @@ -735,6 +823,13 @@ def get_default_configuration(cls):
"sensitive": True,
"type": "str",
},
"client_emails": {
"depends_on": [{"field": "data_source", "value": OUTLOOK_CLOUD}],
"label": "Client Email Addresses (comma-separated)",
"order": 5,
"required": False,
"type": "str",
},
"exchange_server": {
"depends_on": [{"field": "data_source", "value": OUTLOOK_SERVER}],
"label": "Exchange Server",
Expand Down Expand Up @@ -1072,9 +1167,11 @@ async def get_docs(self, filtering=None):
dictionary: dictionary containing meta-data of the files.
"""
async for account in self.client._get_user_instance.get_user_accounts():
self._logger.debug(f"Processing account: {account}")
timezone = account.default_timezone or DEFAULT_TIMEZONE

async for mail in self._fetch_mails(account=account, timezone=timezone):
self._logger.debug(f"Fetched mail: {mail}")
yield mail

async for contact in self._fetch_contacts(
Expand Down
99 changes: 81 additions & 18 deletions tests/sources/test_outlook.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ async def create_outlook_source(
tenant_id="foo",
client_id="bar",
client_secret="faa",
client_emails=None,
exchange_server="127.0.0.1",
active_directory_server="127.0.0.1",
username="fee",
Expand All @@ -383,12 +384,16 @@ async def create_outlook_source(
ssl_ca="",
use_text_extraction_service=False,
):
if client_emails is None:
client_emails = ""

async with create_source(
OutlookDataSource,
data_source=data_source,
tenant_id=tenant_id,
client_id=client_id,
client_secret=client_secret,
client_emails=client_emails,
exchange_server=exchange_server,
active_directory_server=active_directory_server,
username=username,
Expand All @@ -415,26 +420,36 @@ def get_stream_reader():
return async_mock


def side_effect_function(url, headers):
def side_effect_function(client_emails=None):
"""Dynamically changing return values for API calls
Args:
url, ssl: Params required for get call
client_emails: Optional string of comma-separated email addresses
"""
if url == "https://graph.microsoft.com/v1.0/users?$top=999":
return get_json_mock(
mock_response={
"@odata.nextLink": "https://graph.microsoft.com/v1.0/users?$top=999&$skipToken=fake-skip-token",
"value": [{"mail": "[email protected]"}],
},
status=200,
)
elif (
url
== "https://graph.microsoft.com/v1.0/users?$top=999&$skipToken=fake-skip-token"
):
return get_json_mock(
mock_response={"value": [{"mail": "[email protected]"}]}, status=200
)
def inner(url, headers):
if client_emails:
emails = [email.strip() for email in client_emails.split(",")]
for email in emails:
if url == f"https://graph.microsoft.com/v1.0/users/{email}":
users_response = {"value": [{"mail": email}]}
return get_json_mock(mock_response=users_response, status=200)
elif url == "https://graph.microsoft.com/v1.0/users?$top=999":
return get_json_mock(
mock_response={
"@odata.nextLink": "https://graph.microsoft.com/v1.0/users?$top=999&$skipToken=fake-skip-token",
"value": [{"mail": "[email protected]"}],
},
status=200,
)
elif (
url
== "https://graph.microsoft.com/v1.0/users?$top=999&$skipToken=fake-skip-token"
):
return get_json_mock(
mock_response={"value": [{"mail": "[email protected]"}]}, status=200
)

return inner


@pytest.mark.asyncio
Expand All @@ -459,6 +474,7 @@ def side_effect_function(url, headers):
"tenant_id": "foo",
"client_id": "bar",
"client_secret": "",
"client_emails": None,
}
),
],
Expand Down Expand Up @@ -497,6 +513,17 @@ async def test_validate_configuration_with_invalid_dependency_fields_raises_erro
"tenant_id": "foo",
"client_id": "bar",
"client_secret": "foo.bar",
"client_emails": None
}
),
(
# Outlook Cloud with non-blank dependent fields & client_emails provided
{
"data_source": OUTLOOK_CLOUD,
"tenant_id": "foo",
"client_id": "bar",
"client_secret": "foo.bar",
"client_emails": "[email protected]"
}
),
],
Expand Down Expand Up @@ -552,7 +579,7 @@ async def test_ping_for_cloud():
):
with mock.patch(
"aiohttp.ClientSession.get",
side_effect=side_effect_function,
side_effect=side_effect_function(),
):
await source.ping()

Expand Down Expand Up @@ -597,13 +624,49 @@ async def test_get_users_for_cloud():
):
with mock.patch(
"aiohttp.ClientSession.get",
side_effect=side_effect_function,
side_effect=side_effect_function(),
):
async for response in source.client._get_user_instance.get_users():
user_mails = [user["mail"] for user in response["value"]]
users.extend(user_mails)
assert users == ["[email protected]", "[email protected]"]

client_emails = "[email protected]"
async with create_outlook_source(client_emails=client_emails) as source:
users = []
with mock.patch(
"aiohttp.ClientSession.post",
return_value=get_json_mock(
mock_response={"access_token": "fake-token"}, status=200
),
):
with mock.patch(
"aiohttp.ClientSession.get",
side_effect=side_effect_function(client_emails),
):
async for response in source.client._get_user_instance.get_users():
user_mails = [user["mail"] for user in response["value"]]
users.extend(user_mails)
assert users == ["[email protected]"]

client_emails = "[email protected], [email protected]"
async with create_outlook_source(client_emails=client_emails) as source:
users = []
with mock.patch(
"aiohttp.ClientSession.post",
return_value=get_json_mock(
mock_response={"access_token": "fake-token"}, status=200
),
):
with mock.patch(
"aiohttp.ClientSession.get",
side_effect=side_effect_function(client_emails),
):
async for response in source.client._get_user_instance.get_users():
user_mails = [user["mail"] for user in response["value"]]
users.extend(user_mails)
assert set(users) == {"[email protected]", "[email protected]"}


@pytest.mark.asyncio
@patch("connectors.sources.outlook.Connection")
Expand Down

0 comments on commit 897fed8

Please sign in to comment.