diff --git a/backend/src/backend/auth/auth_helper.py b/backend/src/backend/auth/auth_helper.py index 57fb1e969..041e4be4a 100644 --- a/backend/src/backend/auth/auth_helper.py +++ b/backend/src/backend/auth/auth_helper.py @@ -101,7 +101,6 @@ async def _authorized_callback_route(self, request: Request) -> Response: def get_authenticated_user( request_with_session: Request, ) -> Optional[AuthenticatedUser]: - timer = PerfTimer() # We may already have created and stored the AuthenticatedUser object on the request @@ -173,6 +172,9 @@ def get_authenticated_user( # print("-------------------------------------------------") smda_token = token_dict.get("access_token") if token_dict else None + token_dict = cca.acquire_token_silent(scopes=config.GRAPH_SCOPES, account=accounts[0]) + graph_token = token_dict.get("access_token") if token_dict else None + # print(f" get tokens {timer.lap_ms():.1f}ms") _save_token_cache_in_session(request_with_session, token_cache) @@ -187,10 +189,13 @@ def get_authenticated_user( authenticated_user = AuthenticatedUser( user_id=user_id, username=user_name, - sumo_access_token=sumo_token, - smda_access_token=smda_token, - pdm_access_token=None, - ssdl_access_token=None, + access_tokens={ + "graph_access_token": graph_token, + "sumo_access_token": sumo_token, + "smda_access_token": smda_token, + "pdm_access_token": None, + "ssdl_access_token": None, + }, ) request_with_session.state.authenticated_user_obj = authenticated_user @@ -203,7 +208,6 @@ def get_authenticated_user( def _create_msal_confidential_client_app( token_cache: msal.TokenCache, ) -> msal.ConfidentialClientApplication: - authority = f"https://login.microsoftonline.com/{config.TENANT_ID}" return msal.ConfidentialClientApplication( client_id=config.CLIENT_ID, @@ -217,7 +221,6 @@ def _create_msal_confidential_client_app( # Note that this function will NOT return the token itself, but rather a dict # that typically has an "access_token" key def _get_token_dict_from_session_token_cache(request_with_session: Request, scopes: List[str]) -> Optional[dict]: - token_cache = _load_token_cache_from_session(request_with_session) cca = _create_msal_confidential_client_app(token_cache) diff --git a/backend/src/backend/primary/routers/general.py b/backend/src/backend/primary/routers/general.py index 304c0d3ee..29e9c7f9f 100644 --- a/backend/src/backend/primary/routers/general.py +++ b/backend/src/backend/primary/routers/general.py @@ -1,19 +1,24 @@ +import asyncio import datetime import logging +import httpx import starsessions from starlette.responses import StreamingResponse -from fastapi import APIRouter, HTTPException, Request, status, Depends +from fastapi import APIRouter, HTTPException, Request, status, Depends, Query from pydantic import BaseModel from src.backend.auth.auth_helper import AuthHelper, AuthenticatedUser from src.backend.primary.user_session_proxy import proxy_to_user_session +from src.services.graph_access.graph_access import GraphApiAccess LOGGER = logging.getLogger(__name__) class UserInfo(BaseModel): username: str + display_name: str | None + avatar_b64str: str | None has_sumo_access: bool has_smda_access: bool @@ -34,7 +39,12 @@ def alive_protected() -> str: @router.get("/logged_in_user", response_model=UserInfo) -async def logged_in_user(request: Request) -> UserInfo: +async def logged_in_user( + request: Request, + includeGraphApiInfo: bool = Query( + False, description="Set to true to include user avatar and display name from Microsoft Graph API" + ), +) -> UserInfo: print("entering logged_in_user route") await starsessions.load_session(request) @@ -47,10 +57,29 @@ async def logged_in_user(request: Request) -> UserInfo: user_info = UserInfo( username=authenticated_user.get_username(), + avatar_b64str=None, + display_name=None, has_sumo_access=authenticated_user.has_sumo_access_token(), has_smda_access=authenticated_user.has_smda_access_token(), ) + if authenticated_user.has_graph_access_token() and includeGraphApiInfo: + graph_api_access = GraphApiAccess(authenticated_user.get_graph_access_token()) + try: + avatar_b64str_future = asyncio.create_task(graph_api_access.get_user_profile_photo()) + graph_user_info_future = asyncio.create_task(graph_api_access.get_user_info()) + + avatar_b64str = await avatar_b64str_future + graph_user_info = await graph_user_info_future + + user_info.avatar_b64str = avatar_b64str + if graph_user_info is not None: + user_info.display_name = graph_user_info.get("displayName", None) + except httpx.HTTPError as e: + print("Error while fetching user avatar and info from Microsoft Graph API (HTTP error):\n", e) + except httpx.InvalidURL as e: + print("Error while fetching user avatar and info from Microsoft Graph API (Invalid URL):\n", e) + return user_info diff --git a/backend/src/services/graph_access/graph_access.py b/backend/src/services/graph_access/graph_access.py new file mode 100644 index 000000000..fe3b92c6e --- /dev/null +++ b/backend/src/services/graph_access/graph_access.py @@ -0,0 +1,39 @@ +import base64 +from typing import Mapping + +# Using the same http client as sumo +import httpx + + +class GraphApiAccess: + def __init__(self, access_token: str): + self._access_token = access_token + + def _make_headers(self) -> Mapping[str, str]: + return {"Authorization": f"Bearer {self._access_token}"} + + async def _request(self, url: str) -> httpx.Response: + async with httpx.AsyncClient() as client: + response = await client.get( + url, + headers=self._make_headers(), + ) + return response + + async def get_user_profile_photo(self) -> str | None: + print("entering get_user_profile_photo") + response = await self._request("https://graph.microsoft.com/v1.0/me/photo/$value") + + if response.status_code == 200: + return base64.b64encode(response.content).decode("utf-8") + else: + return None + + async def get_user_info(self) -> Mapping[str, str] | None: + print("entering get_user_info") + response = await self._request("https://graph.microsoft.com/v1.0/me") + + if response.status_code == 200: + return response.json() + else: + return None diff --git a/backend/src/services/utils/authenticated_user.py b/backend/src/services/utils/authenticated_user.py index 97f3ed7df..3b28aa515 100644 --- a/backend/src/services/utils/authenticated_user.py +++ b/backend/src/services/utils/authenticated_user.py @@ -1,6 +1,14 @@ # pylint: disable=bare-except -from typing import Any, Optional +from typing import Any, Optional, TypedDict + + +class AccessTokens(TypedDict): + graph_access_token: Optional[str] + sumo_access_token: Optional[str] + smda_access_token: Optional[str] + pdm_access_token: Optional[str] + ssdl_access_token: Optional[str] class AuthenticatedUser: @@ -8,17 +16,15 @@ def __init__( self, user_id: str, username: str, - sumo_access_token: Optional[str], - smda_access_token: Optional[str], - pdm_access_token: Optional[str], - ssdl_access_token: Optional[str], + access_tokens: AccessTokens, ) -> None: self._user_id = user_id self._username = username - self._sumo_access_token = sumo_access_token - self._smda_access_token = smda_access_token - self._pdm_access_token = pdm_access_token - self._ssdl_access_token = ssdl_access_token + self._graph_access_token = access_tokens.get("graph_access_token") + self._sumo_access_token = access_tokens.get("sumo_access_token") + self._smda_access_token = access_tokens.get("smda_access_token") + self._pdm_access_token = access_tokens.get("pdm_access_token") + self._ssdl_access_token = access_tokens.get("ssdl_access_token") def __hash__(self) -> int: return hash(self._user_id) @@ -29,6 +35,19 @@ def __eq__(self, other: Any) -> bool: def get_username(self) -> str: return self._username + def get_graph_access_token(self) -> str: + if isinstance(self._graph_access_token, str) and self._graph_access_token: + return self._graph_access_token + + raise ValueError("User has no graph access token") + + def has_graph_access_token(self) -> bool: + try: + self.get_graph_access_token() + return True + except ValueError: + return False + def get_sumo_access_token(self) -> str: if isinstance(self._sumo_access_token, str) and len(self._sumo_access_token) > 0: return self._sumo_access_token diff --git a/frontend/src/api/models/UserInfo.ts b/frontend/src/api/models/UserInfo.ts index 1d80b9237..efffcd292 100644 --- a/frontend/src/api/models/UserInfo.ts +++ b/frontend/src/api/models/UserInfo.ts @@ -4,6 +4,8 @@ export type UserInfo = { username: string; + display_name: (string | null); + avatar_b64str: (string | null); has_sumo_access: boolean; has_smda_access: boolean; }; diff --git a/frontend/src/api/services/DefaultService.ts b/frontend/src/api/services/DefaultService.ts index 6cc8012e0..71efe87fa 100644 --- a/frontend/src/api/services/DefaultService.ts +++ b/frontend/src/api/services/DefaultService.ts @@ -69,13 +69,22 @@ export class DefaultService { /** * Logged In User + * @param includeGraphApiInfo Set to true to include user avatar and display name from Microsoft Graph API * @returns UserInfo Successful Response * @throws ApiError */ - public loggedInUser(): CancelablePromise { + public loggedInUser( + includeGraphApiInfo: boolean = false, + ): CancelablePromise { return this.httpRequest.request({ method: 'GET', url: '/logged_in_user', + query: { + 'includeGraphApiInfo': includeGraphApiInfo, + }, + errors: { + 422: `Validation Error`, + }, }); } diff --git a/frontend/src/framework/internal/components/LoginButton/loginButton.tsx b/frontend/src/framework/internal/components/LoginButton/loginButton.tsx index 6e514c6dc..55d994d32 100644 --- a/frontend/src/framework/internal/components/LoginButton/loginButton.tsx +++ b/frontend/src/framework/internal/components/LoginButton/loginButton.tsx @@ -9,6 +9,19 @@ import { getTextWidth } from "@lib/utils/textSize"; import { Dropdown, MenuButton } from "@mui/base"; import { AccountCircle, Login, Logout } from "@mui/icons-material"; +function makeInitials(name: string): string | null { + const regExp = new RegExp(/([^()]+)(\([\w ]+\))/); + const match = regExp.exec(name); + + if (match) { + const names = match[1].trim().split(" "); + if (names.length > 1) { + return names[0].charAt(0) + names[names.length - 1].charAt(0); + } + } + return null; +} + export type LoginButtonProps = { className?: string; showText?: boolean; @@ -29,7 +42,27 @@ export const LoginButton: React.FC = (props) => { function makeIcon() { if (authState === AuthState.LoggedIn) { - return ; + if (userInfo?.avatar_b64str) { + return ( + Avatar + ); + } + if (userInfo?.display_name) { + const initials = makeInitials(userInfo.display_name); + if (initials) { + return ( +
+ {initials} +
+ ); + } + } + return ; + } else if (authState === AuthState.NotLoggedIn) { return ; } else { @@ -39,7 +72,7 @@ export const LoginButton: React.FC = (props) => { function makeText() { if (authState === AuthState.LoggedIn) { - return userInfo?.username || "Unknown user"; + return userInfo?.display_name || userInfo?.username || "Unknown user"; } else if (authState === AuthState.NotLoggedIn) { return "Sign in"; } else { @@ -71,7 +104,7 @@ export const LoginButton: React.FC = (props) => { > {makeIcon()} diff --git a/frontend/src/framework/internal/providers/AuthProvider.tsx b/frontend/src/framework/internal/providers/AuthProvider.tsx index 120dce7aa..454d0c622 100644 --- a/frontend/src/framework/internal/providers/AuthProvider.tsx +++ b/frontend/src/framework/internal/providers/AuthProvider.tsx @@ -43,7 +43,7 @@ export const AuthProvider: React.FC<{ children: React.ReactElement }> = (props) } apiService.default - .loggedInUser() + .loggedInUser(true) .then((user) => { if (user) { setAuthState(AuthState.LoggedIn);