diff --git a/pyproject.toml b/pyproject.toml
index 6c996d0..7442bbc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,7 +18,7 @@ dependencies = [
 dev = [
     # for tests
     "pydantic >= 1.7.4, < 2.0",
-    "matrix-synapse == 1.98.0",
+    "matrix-synapse == 1.103.0",
     "tox",
     "twisted",
     "aiounittest",
diff --git a/room_access_rules/__init__.py b/room_access_rules/__init__.py
index 5c15500..b6cf9a7 100644
--- a/room_access_rules/__init__.py
+++ b/room_access_rules/__init__.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import email.utils
 import logging
 from typing import Any, Dict, List, Optional, Tuple
@@ -21,14 +22,15 @@
 from synapse.events import EventBase
 from synapse.module_api import ModuleApi, UserID
 from synapse.module_api.errors import ConfigError, SynapseError
-from synapse.types import Requester, StateMap
+from synapse.storage.database import LoggingTransaction
+from synapse.types import JsonMapping, Requester, ScheduledTask, StateMap, TaskStatus
 
 logger = logging.getLogger(__name__)
 
 ACCESS_RULES_TYPE = "im.vector.room.access_rules"
 
-LOCATION_LIVE_SHARING_EVENT_TYPE = "m.beacon_info"
-LOCATION_LIVE_SHARING_MSC_EVENT_TYPE = "org.matrix.msc3672.beacon_info"
+LOCATION_LIVE_SHARE_TYPE = "m.beacon_info"
+LOCATION_LIVE_SHARE_MSC_TYPE = "org.matrix.msc3672.beacon_info"
 
 
 class AccessRules:
@@ -61,6 +63,7 @@ class AccessRules:
 class RoomAccessRulesConfig:
     id_server: str
     domains_forbidden_when_restricted: List[str] = []
+    fix_existing_rooms_power_levels: bool = False
 
 
 class RoomAccessRules(object):
@@ -86,6 +89,29 @@ def __init__(
             check_visibility_can_be_modified=self.check_visibility_can_be_modified,
         )
 
+        self.task_scheduler = api._hs.get_task_scheduler()
+        self.store = api._hs.get_datastores().main
+
+        self.task_scheduler.register_action(
+            self.fix_existing_rooms_power_levels,
+            "fix_existing_rooms_power_levels",
+        )
+
+        # This schedules a resumable, long-running task that fixes the power levels of existing rooms.
+        # Only schedule it on the main process, and only if no such task is already in the queue.
+        if config.fix_existing_rooms_power_levels and api.worker_name is None:
+
+            async def schedule_task_if_needed() -> None:
+                existing_tasks = await self.task_scheduler.get_tasks(
+                    actions=["fix_existing_rooms_power_levels"]
+                )
+                if not existing_tasks:
+                    await self.task_scheduler.schedule_task(
+                        "fix_existing_rooms_power_levels"
+                    )
+
+            api.delayed_background_call(0, schedule_task_if_needed)
+
     @staticmethod
     def parse_config(config_dict: Dict[str, Any]) -> RoomAccessRulesConfig:
         """Parses and validates the options specified in the homeserver config.
@@ -106,6 +132,137 @@ def parse_config(config_dict: Dict[str, Any]) -> RoomAccessRulesConfig:
 
         return config
 
+    async def fix_existing_rooms_power_levels(
+        self, task: ScheduledTask
+    ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+        def get_room_ids_from(
+            txn: LoggingTransaction,
+            limit: Optional[int] = None,
+            from_id: Optional[str] = None,
+        ) -> List[str]:
+            limit_statement = ""
+            if limit is not None:
+                limit_statement = f"LIMIT {limit}"
+
+            where_statement = ""
+            if from_id:
+                where_statement = f"WHERE room_id > '{from_id}'"
+
+            txn.execute(
+                f"SELECT * FROM rooms {where_statement} ORDER BY room_id {limit_statement}"
+            )
+            rows = txn.fetchall()
+            return [r[0] for r in rows]
+
+        last_room_id = None
+        # Resume the task from the last processed room, if any
+        if task.result:
+            last_room_id = task.result.get("last_room_id")
+
+        # Iterate over all rooms in batches of 100
+        has_next = True
+        while has_next:
+            room_ids = await self.module_api.run_db_interaction(
+                "get_room_ids_from",
+                get_room_ids_from,
+                limit=100,
+                from_id=last_room_id,
+            )
+
+            # No more rooms left, stop
+            if len(room_ids) == 0:
+                has_next = False
+
+            for room_id in room_ids:
+                await self.fix_room_power_levels(room_id)
+                last_room_id = room_id
+
+            # Update the task result so it can be resumed from the last
+            # fully processed batch of rooms.
+            # We don't do it after every room, for performance reasons.
+            await self.task_scheduler.update_task(
+                task.id,
+                status=TaskStatus.ACTIVE,
+                result={"last_room_id": last_room_id},
+            )
+
+        return TaskStatus.COMPLETE, None, None
+
+    async def fix_room_power_levels(self, room_id: str) -> None:
+        logger.info(f"Fixing power levels of room {room_id}")
+
+        # Fetch the local users joined to the room
+        local_joined_users = set()
+        for user_id, membership in await self.store.get_local_users_related_to_room(
+            room_id
+        ):
+            if membership == "join":
+                local_joined_users.add(user_id)
+
+        # Fetch the current power levels event and check whether we have a local admin
+        res = await self.module_api.get_room_state(
+            room_id, [("m.room.power_levels", "")]
+        )
+        power_levels_event = res.get(("m.room.power_levels", ""))
+        if power_levels_event and power_levels_event.content:
+            content = copy.deepcopy(power_levels_event.content)
+
+            admin_user = None
+            content.setdefault("users", {})
+            for u in content["users"]:
+                if u in local_joined_users and content["users"][u] == 100:
+                    admin_user = u
+                    break
+
+            if not admin_user:
+                return
+
+            # We have an admin on this server!
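+            # (Events created through the module API can only be sent on behalf of
+            # users local to this homeserver, which is why a local admin is needed.)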
+            # Let's patch the power levels with it
+            changed = False
+
+            # Make sure there is an "events" section to patch (some existing
+            # rooms may not have one in their power levels content)
+            content.setdefault("events", {})
+
+            # Set the power level needed for live location sharing to the
+            # default events power level
+            default_events_pl = content.get("events_default", 0)
+            if content["events"].get(LOCATION_LIVE_SHARE_TYPE, None) is None:
+                content["events"][LOCATION_LIVE_SHARE_TYPE] = default_events_pl
+                changed = True
+            if content["events"].get(LOCATION_LIVE_SHARE_MSC_TYPE, None) is None:
+                content["events"][LOCATION_LIVE_SHARE_MSC_TYPE] = default_events_pl
+                changed = True
+
+            res = await self.module_api.get_room_state(
+                room_id, [("im.vector.room.access_rules", "")]
+            )
+
+            is_dm = False
+
+            access_rules_event = res.get(("im.vector.room.access_rules", ""))
+            if access_rules_event:
+                if access_rules_event.content.get("rule", None) == "direct":
+                    is_dm = True
+
+            if is_dm:
+                # It's a DM: try to fix it by making every member an admin
+                res = await self.module_api.get_room_state(
+                    room_id, [("m.room.member", None)]
+                )
+                for _, member in res:
+                    if content["users"].get(member) != 100:
+                        content["users"][member] = 100
+                        changed = True
+
+            # Send the updated power levels event into the room as the local admin
+            if changed:
+                await self.module_api.create_and_send_event_into_room(
+                    {
+                        "room_id": room_id,
+                        "type": "m.room.power_levels",
+                        "state_key": "",
+                        "sender": admin_user,
+                        "content": content,
+                    }
+                )
+
     async def on_create_room(
         self,
         requester: Requester,
@@ -254,8 +411,8 @@ def _get_default_power_levels(user_id: str) -> Dict[str, Any]:
                 EventTypes.ServerACL: 100,
                 EventTypes.RoomEncryption: 100,
                 # We want normal users to be able to use live location sharing by default
-                LOCATION_LIVE_SHARING_EVENT_TYPE: 0,
-                LOCATION_LIVE_SHARING_MSC_EVENT_TYPE: 0,
+                LOCATION_LIVE_SHARE_TYPE: 0,
+                LOCATION_LIVE_SHARE_MSC_TYPE: 0,
             },
             "events_default": 0,
             "state_default": 100,  # Admins should be the only ones to perform other tasks
diff --git a/tests/__init__.py b/tests/__init__.py
index 8a5919c..92fb28b 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -60,6 +60,14 @@ def membership(self):
         return self.content["membership"]
 
 
+class MockHomeserver:
+    def get_datastores(self):
+        return Mock(spec=["main"])
+
+    def get_task_scheduler(self):
+        return Mock(spec=["register_action"])
+
+
 def new_access_rules_event(sender: str, room_id: str, rule: str) -> MockEvent:
     return MockEvent(
         sender=sender,
@@ -78,6 +86,7 @@ def create_module(
     module_api = Mock(spec=ModuleApi)
     module_api.http_client = MockHttpClient()
     module_api.public_room_list_manager = MockPublicRoomListManager()
+    module_api._hs = MockHomeserver()
 
     if config_override is None:
         config_override = {}
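For reference, a minimal sketch of how the new option could be enabled in the homeserver configuration. The config keys come from RoomAccessRulesConfig above; the module path and the id_server value are illustrative assumptions rather than part of this change:

    modules:
      - module: room_access_rules.RoomAccessRules
        config:
          id_server: "id.example.com"        # placeholder identity server
          domains_forbidden_when_restricted: []
          # New in this change: schedule a one-off, resumable task on the main
          # process that patches the power levels of existing rooms.
          fix_existing_rooms_power_levels: true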