From cc865fffc0e556005a6ab596717a77230ba82ee7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 11 Oct 2023 20:08:11 -0400 Subject: [PATCH] Convert user_get_threepids response to attrs. (#16468) This improves type annotations by not having a dictionary of Any values. --- changelog.d/16468.misc | 1 + synapse/handlers/account_validity.py | 4 ++-- synapse/handlers/admin.py | 4 +++- synapse/handlers/deactivate_account.py | 4 ++-- synapse/module_api/__init__.py | 2 +- synapse/rest/admin/users.py | 3 +-- synapse/rest/client/account.py | 4 +++- .../storage/databases/main/registration.py | 19 ++++++++++++++----- tests/module_api/test_api.py | 8 ++++---- 9 files changed, 31 insertions(+), 18 deletions(-) create mode 100644 changelog.d/16468.misc diff --git a/changelog.d/16468.misc b/changelog.d/16468.misc new file mode 100644 index 000000000000..93ceaeafc9b9 --- /dev/null +++ b/changelog.d/16468.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index f1a7a05df6bc..6c2a49a3b91f 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -212,8 +212,8 @@ async def _get_email_addresses_for_user(self, user_id: str) -> List[str]: addresses = [] for threepid in threepids: - if threepid["medium"] == "email": - addresses.append(threepid["address"]) + if threepid.medium == "email": + addresses.append(threepid.address) return addresses diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 97fd1fd42772..2c2baeac675e 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -16,6 +16,8 @@ import logging from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Set +import attr + from synapse.api.constants import Direction, Membership from synapse.events import EventBase from synapse.types import JsonMapping, RoomStreamToken, StateMap, UserID, UserInfo @@ -93,7 +95,7 @@ async def get_user(self, user: UserID) -> Optional[JsonMapping]: ] user_info_dict["displayname"] = profile.display_name user_info_dict["avatar_url"] = profile.avatar_url - user_info_dict["threepids"] = threepids + user_info_dict["threepids"] = [attr.asdict(t) for t in threepids] user_info_dict["external_ids"] = external_ids user_info_dict["erased"] = await self._store.is_user_erased(user.to_string()) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 67adeae6a7a6..6a8f8f2fd18a 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -117,9 +117,9 @@ async def deactivate_account( # Remove any local threepid associations for this account. local_threepids = await self.store.user_get_threepids(user_id) - for threepid in local_threepids: + for local_threepid in local_threepids: await self._auth_handler.delete_local_threepid( - user_id, threepid["medium"], threepid["address"] + user_id, local_threepid.medium, local_threepid.address ) # delete any devices belonging to the user, which will also diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 65e2aca4560a..0786d2063565 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -678,7 +678,7 @@ async def get_threepids_for_user(self, user_id: str) -> List[Dict[str, str]]: "msisdn" for phone numbers, and an "address" key which value is the threepid's address. """ - return await self._store.user_get_threepids(user_id) + return [attr.asdict(t) for t in await self._store.user_get_threepids(user_id)] def check_user_exists(self, user_id: str) -> "defer.Deferred[Optional[str]]": """Check if user exists. diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index cd995e8dbb80..7fe16130e764 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -329,9 +329,8 @@ async def on_PUT( if threepids is not None: # get changed threepids (added and removed) - # convert List[Dict[str, Any]] into Set[Tuple[str, str]] cur_threepids = { - (threepid["medium"], threepid["address"]) + (threepid.medium, threepid.address) for threepid in await self.store.user_get_threepids(user_id) } add_threepids = new_threepids - cur_threepids diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index e74a87af4d33..641390cb304d 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -24,6 +24,8 @@ from pydantic.v1 import StrictBool, StrictStr, constr else: from pydantic import StrictBool, StrictStr, constr + +import attr from typing_extensions import Literal from twisted.web.server import Request @@ -595,7 +597,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: threepids = await self.datastore.user_get_threepids(requester.user.to_string()) - return 200, {"threepids": threepids} + return 200, {"threepids": [attr.asdict(t) for t in threepids]} # NOTE(dmr): I have chosen not to use Pydantic to parse this request's body, because # the endpoint is deprecated. (If you really want to, you could do this by reusing diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 64a2c31a5d60..9e8643ae4d87 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -143,6 +143,14 @@ class LoginTokenLookupResult: """The session ID advertised by the SSO Identity Provider.""" +@attr.s(frozen=True, slots=True, auto_attribs=True) +class ThreepidResult: + medium: str + address: str + validated_at: int + added_at: int + + class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -988,13 +996,14 @@ async def user_add_threepid( {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, ) - async def user_get_threepids(self, user_id: str) -> List[Dict[str, Any]]: - return await self.db_pool.simple_select_list( + async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]: + results = await self.db_pool.simple_select_list( "user_threepids", - {"user_id": user_id}, - ["medium", "address", "validated_at", "added_at"], - "user_get_threepids", + keyvalues={"user_id": user_id}, + retcols=["medium", "address", "validated_at", "added_at"], + desc="user_get_threepids", ) + return [ThreepidResult(**r) for r in results] async def user_delete_threepid( self, user_id: str, medium: str, address: str diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 172fc3a736df..1dabf52156d4 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -94,12 +94,12 @@ def test_can_register_user(self) -> None: self.assertEqual(len(emails), 1) email = emails[0] - self.assertEqual(email["medium"], "email") - self.assertEqual(email["address"], "bob@bobinator.bob") + self.assertEqual(email.medium, "email") + self.assertEqual(email.address, "bob@bobinator.bob") # Should these be 0? - self.assertEqual(email["validated_at"], 0) - self.assertEqual(email["added_at"], 0) + self.assertEqual(email.validated_at, 0) + self.assertEqual(email.added_at, 0) # Check that the displayname was assigned displayname = self.get_success(