diff --git a/python/jupyterlab-chat/jupyterlab_chat/tests/test_ychat.py b/python/jupyterlab-chat/jupyterlab_chat/tests/test_ychat.py index b812549e..8e8ee0e2 100644 --- a/python/jupyterlab-chat/jupyterlab_chat/tests/test_ychat.py +++ b/python/jupyterlab-chat/jupyterlab_chat/tests/test_ychat.py @@ -167,6 +167,70 @@ def test_update_message_should_append_content(): assert message_dict["sender"] == msg.sender +def test_update_message_includes_mentions(): + chat = YChat() + chat.set_user(USER) + chat.set_user(USER2) + chat.set_user(USER3) + + # Add a message with one mention + new_msg = create_new_message(f"@{USER2.mention_name} Hello!") + msg_id = chat.add_message(new_msg) + msg = chat.get_message(msg_id) + assert msg + assert set(msg.mentions) == set([USER2.username]) + + # Update the message to mention a different user + msg.body = f"@{USER3.mention_name} Goodbye!" + chat.update_message(msg, find_mentions=True) + updated_msg = chat.get_message(msg_id) + assert updated_msg + assert set(updated_msg.mentions) == set([USER3.username]) + + +def test_update_message_append_includes_mentions(): + chat = YChat() + chat.set_user(USER) + chat.set_user(USER2) + chat.set_user(USER3) + + # Add a message with one mention + new_msg = create_new_message(f"@{USER2.mention_name} Hello!") + msg_id = chat.add_message(new_msg) + msg = chat.get_message(msg_id) + assert msg + assert set(msg.mentions) == set([USER2.username]) + + # Append content with another mention + msg.body = f" and @{USER3.mention_name}!" + chat.update_message(msg, append=True, find_mentions=True) + updated_msg = chat.get_message(msg_id) + assert updated_msg + # Should now mention both users + assert set(updated_msg.mentions) == set([USER2.username, USER3.username]) + + +def test_update_message_append_no_duplicate_mentions(): + chat = YChat() + chat.set_user(USER) + chat.set_user(USER2) + + # Add a message with a mention + new_msg = create_new_message(f"@{USER2.mention_name} Hello!") + msg_id = chat.add_message(new_msg) + msg = chat.get_message(msg_id) + assert msg + assert set(msg.mentions) == set([USER2.username]) + + # Append content that mentions the same user again + msg.body = f" @{USER2.mention_name} again!" + chat.update_message(msg, append=True, find_mentions=True) + updated_msg = chat.get_message(msg_id) + assert updated_msg + # Should only have one mention despite appearing twice + assert set(updated_msg.mentions) == set([USER2.username]) + + def test_indexes_by_id(): chat = YChat() msg = create_new_message() diff --git a/python/jupyterlab-chat/jupyterlab_chat/ychat.py b/python/jupyterlab-chat/jupyterlab_chat/ychat.py index a89e4fe0..55ffc0a7 100644 --- a/python/jupyterlab-chat/jupyterlab_chat/ychat.py +++ b/python/jupyterlab-chat/jupyterlab_chat/ychat.py @@ -118,6 +118,20 @@ def get_messages(self) -> list[Message]: message_dicts = self._get_messages() return [Message(**message_dict) for message_dict in message_dicts] + def _find_mentions(self, body: str) -> list[str]: + """ + Extract mentioned usernames from a message body. + Finds all @mentions in the body and returns the corresponding usernames. + """ + mention_pattern = re.compile(r"@([\w-]+):?") + mentioned_names: Set[str] = set(re.findall(mention_pattern, body)) + users = self.get_users() + mentioned_usernames = [] + for username, user in users.items(): + if user.mention_name in mentioned_names and user.username not in mentioned_usernames: + mentioned_usernames.append(username) + return mentioned_usernames + def _get_messages(self) -> list[dict]: """ Returns the messages of the document as dict. @@ -137,14 +151,7 @@ def add_message(self, new_message: NewMessage) -> str: ) # find all mentioned users and add them as message mentions - mention_pattern = re.compile("@([\w-]+):?") - mentioned_names: Set[str] = set(re.findall(mention_pattern, message.body)) - users = self.get_users() - mentioned_usernames = [] - for username, user in users.items(): - if user.mention_name in mentioned_names and user.username not in mentioned_usernames: - mentioned_usernames.append(username) - message.mentions = mentioned_usernames + message.mentions = self._find_mentions(message.body) with self._ydoc.transaction(): index = len(self._ymessages) - next((i for i, v in enumerate(self._get_messages()[::-1]) if v["time"] < timestamp), len(self._ymessages)) @@ -155,10 +162,11 @@ def add_message(self, new_message: NewMessage) -> str: return uid - def update_message(self, message: Message, append: bool = False): + def update_message(self, message: Message, append: bool = False, find_mentions: bool = False): """ Update a message of the document. - If append is True, the content will be append to the previous content. + If append is True, the content will be appended to the previous content. + If find_mentions is True, mentions will be extracted and notifications triggered (use for streaming completion). """ with self._ydoc.transaction(): index = self._indexes_by_id[message.id] @@ -166,6 +174,11 @@ def update_message(self, message: Message, append: bool = False): message.time = initial_message["time"] # type:ignore[index] if append: message.body = initial_message["body"] + message.body # type:ignore[index] + + # Extract and update mentions from the message body + if find_mentions: + message.mentions = self._find_mentions(message.body) + self._ymessages[index] = asdict(message, dict_factory=message_asdict_factory) def get_attachments(self) -> dict[str, Union[FileAttachment, NotebookAttachment]]: