diff --git a/packages/jupyter-ai-test/jupyter_ai_test/test_slash_commands.py b/packages/jupyter-ai-test/jupyter_ai_test/test_slash_commands.py index 9cf5c6a26..f82bd5531 100644 --- a/packages/jupyter-ai-test/jupyter_ai_test/test_slash_commands.py +++ b/packages/jupyter-ai-test/jupyter_ai_test/test_slash_commands.py @@ -1,6 +1,5 @@ from jupyter_ai.chat_handlers.base import BaseChatHandler, SlashCommandRoutingType from jupyter_ai.models import HumanChatMessage -from jupyterlab_chat.ychat import YChat class TestSlashCommand(BaseChatHandler): @@ -26,5 +25,5 @@ class TestSlashCommand(BaseChatHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - async def process_message(self, message: HumanChatMessage, chat: YChat): - self.reply("This is the `/test` slash command.", chat) + async def process_message(self, message: HumanChatMessage): + self.reply("This is the `/test` slash command.") diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py index be65e745a..b5c4fa38b 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py @@ -1,9 +1,8 @@ import argparse -from typing import Dict, Optional, Type +from typing import Dict, Type from jupyter_ai.models import HumanChatMessage from jupyter_ai_magics.providers import BaseProvider -from jupyterlab_chat.ychat import YChat from langchain.chains import ConversationalRetrievalChain from langchain.memory import ConversationBufferWindowMemory from langchain_core.prompts import PromptTemplate @@ -60,19 +59,19 @@ def create_llm_chain( verbose=False, ) - async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]): - args = self.parse_args(message, chat) + async def process_message(self, message: HumanChatMessage): + args = self.parse_args(message) if args is None: return query = " ".join(args.query) if not query: - self.reply(f"{self.parser.format_usage()}", chat, message) + self.reply(f"{self.parser.format_usage()}", message) return self.get_llm_chain() try: - with self.pending("Searching learned documents", message, chat=chat): + with self.pending("Searching learned documents", message): assert self.llm_chain # TODO: migrate this class to use a LCEL `Runnable` instead of # `Chain`, then remove the below ignore comment. @@ -80,7 +79,7 @@ async def process_message(self, message: HumanChatMessage, chat: Optional[YChat] {"question": query} ) response = result["answer"] - self.reply(response, chat, message) + self.reply(response, message) except AssertionError as e: self.log.error(e) response = """Sorry, an error occurred while reading the from the learned documents. @@ -88,4 +87,4 @@ async def process_message(self, message: HumanChatMessage, chat: Optional[YChat] `/learn -d` command and then re-submitting the `learn ` to learn the documents, and then asking the question again. 
""" - self.reply(response, chat, message) + self.reply(response, message) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index 11347f1a1..233099151 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -8,7 +8,6 @@ TYPE_CHECKING, Any, Awaitable, - Callable, ClassVar, Dict, List, @@ -24,6 +23,7 @@ from dask.distributed import Client as DaskClient from jupyter_ai.callback_handlers import MetadataCallbackHandler from jupyter_ai.config_manager import ConfigManager, Logger +from jupyter_ai.constants import BOT from jupyter_ai.history import WrappedBoundedChatHistory from jupyter_ai.models import ( AgentChatMessage, @@ -158,7 +158,7 @@ def __init__( chat_handlers: Dict[str, "BaseChatHandler"], context_providers: Dict[str, "BaseCommandContextProvider"], message_interrupted: Dict[str, asyncio.Event], - write_message: Callable[[YChat, str, Optional[str]], str], + ychat: Optional[YChat], ): self.log = log self.config_manager = config_manager @@ -181,14 +181,20 @@ def __init__( self.chat_handlers = chat_handlers self.context_providers = context_providers self.message_interrupted = message_interrupted + self.ychat = ychat + self.indexes_by_id: Dict[str, str] = {} + """ + Indexes of messages in the YChat document by message ID. + + TODO: Remove this once `jupyterlab-chat` can update messages by ID + without an index. + """ self.llm: Optional[BaseProvider] = None self.llm_params: Optional[dict] = None self.llm_chain: Optional[Runnable] = None - self.write_message = write_message - - async def on_message(self, message: HumanChatMessage, chat: Optional[YChat] = None): + async def on_message(self, message: HumanChatMessage): """ Method which receives a human message, calls `self.get_llm_chain()`, and processes the message via `self.process_message()`, calling @@ -204,7 +210,6 @@ async def on_message(self, message: HumanChatMessage, chat: Optional[YChat] = No if slash_command in lm_provider_klass.unsupported_slash_commands: self.reply( "Sorry, the selected language model does not support this slash command.", - chat, ) return @@ -216,7 +221,6 @@ async def on_message(self, message: HumanChatMessage, chat: Optional[YChat] = No if not lm_provider.allows_concurrency: self.reply( "The currently selected language model can process only one request at a time. Please wait for me to reply before sending another question.", - chat, message, ) return @@ -224,24 +228,24 @@ async def on_message(self, message: HumanChatMessage, chat: Optional[YChat] = No BaseChatHandler._requests_count += 1 if self.__class__.supports_help: - args = self.parse_args(message, chat, silent=True) + args = self.parse_args(message, silent=True) if args and args.help: - self.reply(self.parser.format_help(), chat, message) + self.reply(self.parser.format_help(), message) return try: - await self.process_message(message, chat) + await self.process_message(message) except Exception as e: try: # we try/except `handle_exc()` in case it was overriden and # raises an exception by accident. 
- await self.handle_exc(e, message, chat) + await self.handle_exc(e, message) except Exception as e: - await self._default_handle_exc(e, message, chat) + await self._default_handle_exc(e, message) finally: BaseChatHandler._requests_count -= 1 - async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]): + async def process_message(self, message: HumanChatMessage): """ Processes a human message routed to this chat handler. Chat handlers (subclasses) must implement this method. Don't forget to call @@ -252,19 +256,15 @@ async def process_message(self, message: HumanChatMessage, chat: Optional[YChat] """ raise NotImplementedError("Should be implemented by subclasses.") - async def handle_exc( - self, e: Exception, message: HumanChatMessage, chat: Optional[YChat] - ): + async def handle_exc(self, e: Exception, message: HumanChatMessage): """ Handles an exception raised by `self.process_message()`. A default implementation is provided, however chat handlers (subclasses) should implement this method to provide a more helpful error response. """ - await self._default_handle_exc(e, message, chat) + await self._default_handle_exc(e, message) - async def _default_handle_exc( - self, e: Exception, message: HumanChatMessage, chat: Optional[YChat] - ): + async def _default_handle_exc(self, e: Exception, message: HumanChatMessage): """ The default definition of `handle_exc()`. This is the default used when the `handle_exc()` excepts. @@ -274,19 +274,49 @@ async def _default_handle_exc( if lm_provider and lm_provider.is_api_key_exc(e): provider_name = getattr(self.config_manager.lm_provider, "name", "") response = f"Oops! There's a problem connecting to {provider_name}. Please update your {provider_name} API key in the chat settings." - self.reply(response, chat, message) + self.reply(response, message) return formatted_e = traceback.format_exc() response = ( f"Sorry, an error occurred. Details below:\n\n```\n{formatted_e}\n```" ) - self.reply(response, chat, message) + self.reply(response, message) + + def write_message(self, body: str, id: Optional[str] = None) -> Optional[str]: + """[Jupyter Chat only] Writes a message to the YChat shared document + that this chat handler is assigned to.""" + if not self.ychat: + return + + bot = self.ychat.get_user(BOT["username"]) + if not bot: + self.ychat.set_user(BOT) + + index = self.indexes_by_id.get(id, None) + id = id if id else str(uuid4()) + new_index = self.ychat.set_message( + { + "type": "msg", + "body": body, + "id": id, + "time": time.time(), + "sender": BOT["username"], + "raw_time": False, + }, + index=index, + append=True, + ) + + self.indexes_by_id[id] = new_index + return id def broadcast_message(self, message: Message): """ Broadcasts a message to all WebSocket connections. If there are no WebSocket connections and the message is a chat message, this method directly appends to `self.chat_history`. + + TODO: Remove this after Jupyter Chat migration is complete. """ broadcast = False for websocket in self._root_chat_handlers.values(): @@ -305,15 +335,14 @@ def broadcast_message(self, message: Message): def reply( self, response: str, - chat: Optional[YChat], human_msg: Optional[HumanChatMessage] = None, ): """ Sends an agent message, usually in response to a received `HumanChatMessage`. 
""" - if chat is not None: - self.write_message(chat, response, None) + if self.ychat is not None: + self.write_message(response, None) else: agent_msg = AgentChatMessage( id=uuid4().hex, @@ -333,7 +362,6 @@ def start_pending( text: str, human_msg: Optional[HumanChatMessage] = None, *, - chat: Optional[YChat] = None, ellipsis: bool = True, ) -> PendingMessage: """ @@ -352,13 +380,13 @@ def start_pending( ellipsis=ellipsis, ) - if chat is not None and chat.awareness is not None: - chat.awareness.set_local_state_field("isWriting", True) + if self.ychat is not None and self.ychat.awareness is not None: + self.ychat.awareness.set_local_state_field("isWriting", True) else: self.broadcast_message(pending_msg) return pending_msg - def close_pending(self, pending_msg: PendingMessage, chat: Optional[YChat] = None): + def close_pending(self, pending_msg: PendingMessage): """ Closes a pending message. """ @@ -369,8 +397,8 @@ def close_pending(self, pending_msg: PendingMessage, chat: Optional[YChat] = Non id=pending_msg.id, ) - if chat is not None and chat.awareness is not None: - chat.awareness.set_local_state_field("isWriting", False) + if self.ychat is not None and self.ychat.awareness is not None: + self.ychat.awareness.set_local_state_field("isWriting", False) else: self.broadcast_message(close_pending_msg) pending_msg.closed = True @@ -381,7 +409,6 @@ def pending( text: str, human_msg: Optional[HumanChatMessage] = None, *, - chat: Optional[YChat] = None, ellipsis: bool = True, ): """ @@ -391,14 +418,12 @@ def pending( TODO: Simplify it by only modifying the awareness as soon as jupyterlab chat is the only used chat. """ - pending_msg = self.start_pending( - text, human_msg=human_msg, chat=chat, ellipsis=ellipsis - ) + pending_msg = self.start_pending(text, human_msg=human_msg, ellipsis=ellipsis) try: yield pending_msg finally: if not pending_msg.closed: - self.close_pending(pending_msg, chat=chat) + self.close_pending(pending_msg) def get_llm_chain(self): lm_provider = self.config_manager.lm_provider @@ -440,14 +465,14 @@ def create_llm_chain( ): raise NotImplementedError("Should be implemented by subclasses") - def parse_args(self, message, chat, silent=False): + def parse_args(self, message, silent=False): args = message.body.split(" ") try: args = self.parser.parse_args(args[1:]) except (argparse.ArgumentError, SystemExit) as e: if not silent: response = f"{self.parser.format_usage()}" - self.reply(response, chat, message) + self.reply(response, message) return None return args @@ -470,9 +495,7 @@ def output_dir(self) -> str: else: return self.root_dir - def send_help_message( - self, chat: Optional[YChat], human_msg: Optional[HumanChatMessage] = None - ) -> None: + def send_help_message(self, human_msg: Optional[HumanChatMessage] = None) -> None: """Sends a help message to all connected clients.""" lm_provider = self.config_manager.lm_provider unsupported_slash_commands = ( @@ -504,8 +527,8 @@ def send_help_message( context_commands_list=context_commands_list, ) - if chat is not None: - self.write_message(chat, help_message_body, None) + if self.ychat is not None: + self.write_message(help_message_body, None) else: help_message = AgentChatMessage( id=uuid4().hex, @@ -516,13 +539,13 @@ def send_help_message( ) self.broadcast_message(help_message) - def _start_stream(self, human_msg: HumanChatMessage, chat: Optional[YChat]) -> str: + def _start_stream(self, human_msg: HumanChatMessage) -> str: """ Sends an `agent-stream` message to indicate the start of a response stream. 
Returns the ID of the message, denoted as the `stream_id`. """ - if chat is not None: - stream_id = self.write_message(chat, "", None) + if self.ychat is not None: + stream_id = self.write_message("", None) else: stream_id = uuid4().hex stream_msg = AgentStreamMessage( @@ -542,7 +565,6 @@ def _send_stream_chunk( self, stream_id: str, content: str, - chat: Optional[YChat], complete: bool = False, metadata: Optional[Dict[str, Any]] = None, ) -> None: @@ -550,8 +572,8 @@ def _send_stream_chunk( Sends an `agent-stream-chunk` message containing content that should be appended to an existing `agent-stream` message with ID `stream_id`. """ - if chat is not None: - self.write_message(chat, content, stream_id) + if self.ychat is not None: + self.write_message(content, stream_id) else: if not metadata: metadata = {} @@ -568,7 +590,6 @@ async def stream_reply( self, input: Input, human_msg: HumanChatMessage, - chat: Optional[YChat], pending_msg="Generating response", config: Optional[RunnableConfig] = None, ): @@ -603,7 +624,7 @@ async def stream_reply( merged_config: RunnableConfig = merge_runnable_configs(base_config, config) # start with a pending message - with self.pending(pending_msg, human_msg, chat=chat) as pending_message: + with self.pending(pending_msg, human_msg) as pending_message: # stream response in chunks. this works even if a provider does not # implement streaming, as `astream()` defaults to yielding `_call()` # when `_stream()` is not implemented on the LLM class. @@ -613,8 +634,8 @@ async def stream_reply( if not received_first_chunk: # when receiving the first chunk, close the pending message and # start the stream. - self.close_pending(pending_message, chat=chat) - stream_id = self._start_stream(human_msg=human_msg, chat=chat) + self.close_pending(pending_message) + stream_id = self._start_stream(human_msg=human_msg) received_first_chunk = True self.message_interrupted[stream_id] = asyncio.Event() @@ -637,9 +658,9 @@ async def stream_reply( break if isinstance(chunk, AIMessageChunk) and isinstance(chunk.content, str): - self._send_stream_chunk(stream_id, chunk.content, chat=chat) + self._send_stream_chunk(stream_id, chunk.content) elif isinstance(chunk, str): - self._send_stream_chunk(stream_id, chunk, chat=chat) + self._send_stream_chunk(stream_id, chunk) else: self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}") break @@ -651,7 +672,6 @@ async def stream_reply( self._send_stream_chunk( stream_id, stream_tombstone, - chat=chat, complete=True, metadata=metadata_handler.jai_metadata, ) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py index 0ca46dfb1..d5b0ab6c7 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py @@ -1,7 +1,4 @@ -from typing import Optional - from jupyter_ai.models import ClearRequest -from jupyterlab_chat.ychat import YChat from .base import BaseChatHandler, SlashCommandRoutingType @@ -19,11 +16,11 @@ class ClearChatHandler(BaseChatHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - async def process_message(self, _, chat: Optional[YChat]): + async def process_message(self, _): # Clear chat by triggering `RootChatHandler.on_clear_request()`. 
for handler in self._root_chat_handlers.values(): if not handler: continue - handler.on_clear_request(ClearRequest(target=None)) + handler.on_clear_request(ClearRequest()) break diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 692eb67d1..266ad73ad 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -1,9 +1,8 @@ import asyncio -from typing import Dict, Optional, Type +from typing import Dict, Type from jupyter_ai.models import HumanChatMessage from jupyter_ai_magics.providers import BaseProvider -from jupyterlab_chat.ychat import YChat from langchain_core.runnables import ConfigurableFieldSpec from langchain_core.runnables.history import RunnableWithMessageHistory @@ -54,7 +53,7 @@ def create_llm_chain( ) self.llm_chain = runnable - async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]): + async def process_message(self, message: HumanChatMessage): self.get_llm_chain() assert self.llm_chain @@ -64,12 +63,12 @@ async def process_message(self, message: HumanChatMessage, chat: Optional[YChat] try: context_prompt = await self.make_context_prompt(message) except ContextProviderException as e: - self.reply(str(e), chat, message) + self.reply(str(e), message) return inputs["context"] = context_prompt inputs["input"] = self.replace_prompt(inputs["input"]) - await self.stream_reply(inputs, message, chat=chat) + await self.stream_reply(inputs, message) async def make_context_prompt(self, human_msg: HumanChatMessage) -> str: return "\n\n".join( diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py index 1f419f5f1..7323d81c1 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py @@ -1,10 +1,9 @@ import argparse import os from datetime import datetime -from typing import List, Optional +from typing import List from jupyter_ai.models import AgentChatMessage, AgentStreamMessage, HumanChatMessage -from jupyterlab_chat.ychat import YChat from .base import BaseChatHandler, SlashCommandRoutingType @@ -32,11 +31,11 @@ def chat_message_to_markdown(self, message): return "" # Write the chat history to a markdown file with a timestamp - async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]): + async def process_message(self, message: HumanChatMessage): markdown_content = "\n\n".join( self.chat_message_to_markdown(msg) for msg in self._chat_history ) - args = self.parse_args(message, chat) + args = self.parse_args(message) chat_filename = ( # if no filename, use "chat_history" + timestamp args.path[0] if (args.path and args.path[0] != "") @@ -47,4 +46,4 @@ async def process_message(self, message: HumanChatMessage, chat: Optional[YChat] ) # Do not use timestamp if filename is entered as argument with open(chat_file, "w") as chat_history: chat_history.write(markdown_content) - self.reply(f"File saved to `{chat_file}`", chat) + self.reply(f"File saved to `{chat_file}`") diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py index 33d4eeefd..390b93cf6 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py @@ -1,8 +1,7 @@ -from typing import Dict, Optional, Type +from typing import Dict, Type from jupyter_ai.models import 
CellWithErrorSelection, HumanChatMessage from jupyter_ai_magics.providers import BaseProvider -from jupyterlab_chat.ychat import YChat from langchain.prompts import PromptTemplate from .base import BaseChatHandler, SlashCommandRoutingType @@ -80,11 +79,10 @@ def create_llm_chain( runnable = prompt_template | llm # type:ignore self.llm_chain = runnable - async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]): + async def process_message(self, message: HumanChatMessage): if not (message.selection and message.selection.type == "cell-with-error"): self.reply( "`/fix` requires an active code cell with error output. Please click on a cell with error output and retry.", - chat, message, ) return @@ -105,6 +103,4 @@ async def process_message(self, message: HumanChatMessage, chat: Optional[YChat] "error_name": selection.error.name, "error_value": selection.error.value, } - await self.stream_reply( - inputs, message, pending_msg="Analyzing error", chat=chat - ) + await self.stream_reply(inputs, message, pending_msg="Analyzing error") diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py index aec98193a..a69b5ed28 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py @@ -9,7 +9,6 @@ from jupyter_ai.chat_handlers import BaseChatHandler, SlashCommandRoutingType from jupyter_ai.models import HumanChatMessage from jupyter_ai_magics.providers import BaseProvider -from jupyterlab_chat.ychat import YChat from langchain.chains import LLMChain from langchain.llms import BaseLLM from langchain.output_parsers import PydanticOutputParser @@ -263,20 +262,18 @@ async def _generate_notebook(self, prompt: str): nbformat.write(notebook, final_path) return final_path - async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]): + async def process_message(self, message: HumanChatMessage): self.get_llm_chain() # first send a verification message to user response = "👍 Great, I will get started on your notebook. It may take a few minutes, but I will reply here when the notebook is ready. In the meantime, you can continue to ask me other questions." - self.reply(response, chat, message) + self.reply(response, message) final_path = await self._generate_notebook(prompt=message.body) response = f"""🎉 I have created your notebook and saved it to the location {final_path}. I am still learning how to create notebooks, so please review all code before running it.""" - self.reply(response, chat, message) + self.reply(response, message) - async def handle_exc( - self, e: Exception, message: HumanChatMessage, chat: Optional[YChat] - ): + async def handle_exc(self, e: Exception, message: HumanChatMessage): timestamp = time.strftime("%Y-%m-%d-%H.%M.%S") default_log_dir = Path(self.output_dir) / "jupyter-ai-logs" log_dir = self.log_dir or default_log_dir @@ -286,4 +283,4 @@ async def handle_exc( traceback.print_exc(file=log) response = f"An error occurred while generating the notebook. The error details have been saved to `./{log_path}`.\n\nTry running `/generate` again, as some language models require multiple attempts before a notebook is generated." 
- self.reply(response, chat, message) + self.reply(response, message) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py index 4f54b2850..cd8556863 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py @@ -1,7 +1,4 @@ -from typing import Optional - from jupyter_ai.models import HumanChatMessage -from jupyterlab_chat.ychat import YChat from .base import BaseChatHandler, SlashCommandRoutingType @@ -18,5 +15,5 @@ class HelpChatHandler(BaseChatHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]): - self.send_help_message(chat, message) + async def process_message(self, message: HumanChatMessage): + self.send_help_message(message) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py index 55e6d8b15..c350dd1b8 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py @@ -21,7 +21,6 @@ ) from jupyter_core.paths import jupyter_data_dir from jupyter_core.utils import ensure_dir_exists -from jupyterlab_chat.ychat import YChat from langchain.schema import BaseRetriever, Document from langchain.text_splitter import ( LatexTextSplitter, @@ -129,29 +128,26 @@ def _load(self): ) self.log.error(e) - async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]): + async def process_message(self, message: HumanChatMessage): # If no embedding provider has been selected em_provider_cls, em_provider_args = self.get_embedding_provider() if not em_provider_cls: self.reply( "Sorry, please select an embedding provider before using the `/learn` command.", - chat, ) return - args = self.parse_args(message, chat) + args = self.parse_args(message) if args is None: return if args.delete: self.delete() - self.reply( - f"👍 I have deleted everything I previously learned.", chat, message - ) + self.reply(f"👍 I have deleted everything I previously learned.", message) return if args.list: - self.reply(self._build_list_response(), chat) + self.reply(self._build_list_response()) return if args.remote: @@ -162,7 +158,6 @@ async def process_message(self, message: HumanChatMessage, chat: Optional[YChat] args.path = [arxiv_to_text(id, self.output_dir)] self.reply( f"Learning arxiv file with id **{id}**, saved in **{args.path[0]}**.", - chat, message, ) except ModuleNotFoundError as e: @@ -170,7 +165,6 @@ async def process_message(self, message: HumanChatMessage, chat: Optional[YChat] self.reply( "No `arxiv` package found. " "Install with `pip install arxiv`.", - chat, ) return except Exception as e: @@ -178,7 +172,6 @@ async def process_message(self, message: HumanChatMessage, chat: Optional[YChat] self.reply( "An error occurred while processing the arXiv file. 
" f"Please verify that the arxiv id {id} is correct.", - chat, ) return @@ -194,7 +187,7 @@ async def process_message(self, message: HumanChatMessage, chat: Optional[YChat] "- Learn on files in the root directory: `/learn *`\n" "- Learn all python files under the root directory recursively: `/learn **/*.py`" ) - self.reply(f"{self.parser.format_usage()}\n\n {no_path_arg_message}", chat) + self.reply(f"{self.parser.format_usage()}\n\n {no_path_arg_message}") return short_path = args.path[0] load_path = os.path.join(self.output_dir, short_path) @@ -204,15 +197,13 @@ async def process_message(self, message: HumanChatMessage, chat: Optional[YChat] next(iglob(load_path)) except StopIteration: response = f"Sorry, that path doesn't exist: {load_path}" - self.reply(response, chat, message) + self.reply(response, message) return # delete and relearn index if embedding model was changed - await self.delete_and_relearn(chat) + await self.delete_and_relearn() - with self.pending( - f"Loading and splitting files for {load_path}", message, chat=chat - ): + with self.pending(f"Loading and splitting files for {load_path}", message): try: await self.learn_dir( load_path, args.chunk_size, args.chunk_overlap, args.all_files @@ -228,7 +219,7 @@ async def process_message(self, message: HumanChatMessage, chat: Optional[YChat] You can ask questions about these docs by prefixing your message with **/ask**.""" % ( load_path.replace("*", r"\*") ) - self.reply(response, chat, message) + self.reply(response, message) def _build_list_response(self): if not self.metadata.dirs: @@ -282,7 +273,7 @@ def _add_dir_to_metadata(self, path: str, chunk_size: int, chunk_overlap: int): ) self.metadata.dirs = dirs - async def delete_and_relearn(self, chat: Optional[YChat] = None): + async def delete_and_relearn(self): """Delete the vector store and relearn all indexed directories if necessary. If the embedding model is unchanged, this method does nothing.""" @@ -309,11 +300,11 @@ async def delete_and_relearn(self, chat: Optional[YChat] = None): documents you had previously submitted for learning. Please wait to use the **/ask** command until I am done with this task.""" - self.reply(message, chat) + self.reply(message) metadata = self.metadata self.delete() - await self.relearn(metadata, chat) + await self.relearn(metadata) self.prev_em_id = curr_em_id def delete(self): @@ -327,7 +318,7 @@ def delete(self): if os.path.isfile(path): os.remove(path) - async def relearn(self, metadata: IndexMetadata, chat: Optional[YChat]): + async def relearn(self, metadata: IndexMetadata): # Index all dirs in the metadata if not metadata.dirs: return @@ -347,7 +338,7 @@ async def relearn(self, metadata: IndexMetadata, chat: Optional[YChat]): message = f"""🎉 I am done learning docs in these directories: {dir_list} I am ready to answer questions about them. You can ask me about these documents by starting your message with **/ask**.""" - self.reply(message, chat) + self.reply(message) def create( self, diff --git a/packages/jupyter-ai/jupyter_ai/constants.py b/packages/jupyter-ai/jupyter_ai/constants.py new file mode 100644 index 000000000..ab212fb23 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/constants.py @@ -0,0 +1,8 @@ +# The BOT currently has a fixed username, because this username is used has key in chats, +# it needs to constant. Do we need to change it ? 
+BOT = { + "username": "5f6a7570-7974-6572-6e61-75742d626f74", + "name": "Jupyternaut", + "display_name": "Jupyternaut", + "initials": "J", +} diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index c4c03775c..31a2e4fe0 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -4,8 +4,9 @@ import types import uuid from functools import partial -from typing import Optional +from typing import Dict, Optional +import traitlets from dask.distributed import Client as DaskClient from importlib_metadata import entry_points from jupyter_ai.chat_handlers.learn import Retriever @@ -18,10 +19,11 @@ from jupyterlab_chat.ychat import YChat from pycrdt import ArrayEvent from tornado.web import StaticFileHandler -from traitlets import Dict, Integer, List, Unicode +from traitlets import Integer, List, Unicode from .chat_handlers import ( AskChatHandler, + BaseChatHandler, ClearChatHandler, DefaultChatHandler, ExportChatHandler, @@ -32,6 +34,7 @@ ) from .completions.handlers import DefaultInlineCompletionHandler from .config_manager import ConfigManager +from .constants import BOT from .context_providers import BaseCommandContextProvider, FileContextProvider from .handlers import ( ApiKeysHandler, @@ -66,15 +69,6 @@ JUPYTER_COLLABORATION_EVENTS_URI, ) -# The BOT currently has a fixed username, because this username is used has key in chats, -# it needs to constant. Do we need to change it ? -BOT = { - "username": "5f6a7570-7974-6572-6e61-75742d626f74", - "name": "Jupyternaut", - "display_name": "Jupyternaut", - "initials": "J", -} - DEFAULT_HELP_MESSAGE_TEMPLATE = """Hi there! I'm {persona_name}, your programming assistant. You can ask me a question using the text box below. You can also use these commands: {slash_commands_list} @@ -153,9 +147,9 @@ class AiExtension(ExtensionApp): config=True, ) - model_parameters = Dict( + model_parameters = traitlets.Dict( key_trait=Unicode(), - value_trait=Dict(), + value_trait=traitlets.Dict(), default_value={}, help="""Key-value pairs for model id and corresponding parameters that are passed to the provider class. The values are unpacked and passed to @@ -193,7 +187,7 @@ class AiExtension(ExtensionApp): config=True, ) - default_api_keys = Dict( + default_api_keys = traitlets.Dict( key_trait=Unicode(), value_trait=Unicode(), default_value=None, @@ -238,41 +232,72 @@ class AiExtension(ExtensionApp): def initialize(self): super().initialize() + + self.chat_handlers_by_room: Dict[str, Dict[str, BaseChatHandler]] = {} + """ + Nested dictionary that returns the dedicated chat handler instance that + should be used, given the room ID and command ID respectively. + + Example: `self.chat_handlers_by_room[]` yields the set of chat + handlers dedicated to the room identified by ``. + """ + + self.ychats_by_room: Dict[str, YChat] = {} + """Cache of YChat instances, indexed by room ID.""" + self.event_logger = self.serverapp.web_app.settings["event_logger"] self.event_logger.add_listener( schema_id=JUPYTER_COLLABORATION_EVENTS_URI, listener=self.connect_chat ) - # Keep the message indexes to avoid extra computation looking for a message when - # updating it. 
- self.messages_indexes = {} - async def connect_chat( self, logger: EventLogger, schema_id: str, data: dict ) -> None: - if ( + # ignore events that are not chat room initialization events + if not ( data["room"].startswith("text:chat:") and data["action"] == "initialize" and data["msg"] == "Room initialized" ): + return - self.log.info(f"Collaborative chat server is listening for {data['room']}") - chat = await self.get_chat(data["room"]) + # log room ID + room_id = data["room"] + self.log.info(f"Connecting to a chat room with room ID: {room_id}.") - if chat is None: - return + # get YChat document associated with the room + ychat = await self.get_chat(room_id) + if ychat is None: + return - # Add the bot user to the chat document awareness. - BOT["avatar_url"] = url_path_join( - self.settings.get("base_url", "/"), "api/ai/static/jupyternaut.svg" - ) - if chat.awareness is not None: - chat.awareness.set_local_state_field("user", BOT) + # Add the bot user to the chat document awareness. + BOT["avatar_url"] = url_path_join( + self.settings.get("base_url", "/"), "api/ai/static/jupyternaut.svg" + ) + if ychat.awareness is not None: + ychat.awareness.set_local_state_field("user", BOT) + + # initialize chat handlers for new chat + self.chat_handlers_by_room[room_id] = self._init_chat_handlers(ychat) - callback = partial(self.on_change, chat) - chat.ymessages.observe(callback) + callback = partial(self.on_change, room_id) + ychat.ymessages.observe(callback) async def get_chat(self, room_id: str) -> Optional[YChat]: + """ + Retrieves the YChat instance associated with a room ID. This method + is cached, i.e. successive calls with the same room ID quickly return a + cached value. + + TODO: Determine if get_chat() should ever fail under normal usage + scenarios. If not, we should just raise an exception if chat is `None`, + and indicate the return type as just `YChat` instead of + `Optional[YChat]`. This will simplify the code by removing redundant + null checks. + """ + if room_id in self.ychats_by_room: + return self.ychats_by_room[room_id] + if JCOLLAB_VERSION >= 3: collaboration = self.serverapp.web_app.settings["jupyter_server_ydoc"] document = await collaboration.get_document(room_id=room_id, copy=False) @@ -282,9 +307,11 @@ async def get_chat(self, room_id: str) -> Optional[YChat]: room = await server.get_room(room_id) document = room._document + + self.ychats_by_room[room_id] = document return document - def on_change(self, chat: YChat, events: ArrayEvent) -> None: + def on_change(self, room_id: str, events: ArrayEvent) -> None: for change in events.delta: # type:ignore[attr-defined] if not "insert" in change.keys(): continue @@ -303,12 +330,15 @@ def on_change(self, chat: YChat, events: ArrayEvent) -> None: ) if self.serverapp is not None: self.serverapp.io_loop.asyncio_loop.create_task( # type:ignore[attr-defined] - self._route(chat_message, chat) + self.route_human_message(room_id, chat_message) ) - async def _route(self, message: HumanChatMessage, chat: YChat): - """Method that routes an incoming message to the appropriate handler.""" - chat_handlers = self.settings["jai_chat_handlers"] + async def route_human_message(self, room_id: str, message: HumanChatMessage): + """ + Method that routes an incoming `HumanChatMessage` to the appropriate + chat handler. 
+ """ + chat_handlers = self.chat_handlers_by_room[room_id] default = chat_handlers["default"] # Split on any whitespace, either spaces or newlines maybe_command = message.body.split(None, 1)[0] @@ -321,39 +351,14 @@ async def _route(self, message: HumanChatMessage, chat: YChat): start = time.time() if is_command: - await chat_handlers[command].on_message(message, chat) + await chat_handlers[command].on_message(message) else: - await default.on_message(message, chat) + await default.on_message(message) latency_ms = round((time.time() - start) * 1000) command_readable = "Default" if command == "default" else command self.log.info(f"{command_readable} chat handler resolved in {latency_ms} ms.") - def write_message(self, chat: YChat, body: str, id: Optional[str] = None) -> str: - bot = chat.get_user(BOT["username"]) - if not bot: - chat.set_user(BOT) - - index = self.messages_indexes[id] if id else None - id = id if id else str(uuid.uuid4()) - new_index = chat.set_message( - { - "type": "msg", - "body": body, - "id": id if id else str(uuid.uuid4()), - "time": time.time(), - "sender": BOT["username"], - "raw_time": False, - }, - index, - True, - ) - - if new_index != index: - self.messages_indexes[id] = new_index - - return id - def initialize_settings(self): start = time.time() @@ -448,7 +453,7 @@ def initialize_settings(self): self.settings["jai_message_interrupted"] = {} # initialize chat handlers - self._init_chat_handlers() + self.settings["jai_chat_handlers"] = self._init_chat_handlers() # initialize context providers self._init_context_provders() @@ -499,7 +504,15 @@ async def _stop_extension(self): await dask_client.close() self.log.debug("Closed Dask client.") - def _init_chat_handlers(self): + def _init_chat_handlers( + self, ychat: Optional[YChat] = None + ) -> Dict[str, BaseChatHandler]: + """ + Initializes a set of chat handlers. May accept a YChat instance for + collaborative chats. + + TODO: Make `ychat` required once Jupyter Chat migration is complete. 
+ """ eps = entry_points() chat_handler_eps = eps.select(group="jupyter_ai.chat_handlers") chat_handlers = {} @@ -515,9 +528,9 @@ def _init_chat_handlers(self): "preferred_dir": self.serverapp.contents_manager.preferred_dir, "help_message_template": self.help_message_template, "chat_handlers": chat_handlers, - "write_message": self.write_message, "context_providers": self.settings["jai_context_providers"], "message_interrupted": self.settings["jai_message_interrupted"], + "ychat": ychat, } default_chat_handler = DefaultChatHandler(**chat_handler_kwargs) clear_chat_handler = ClearChatHandler(**chat_handler_kwargs) @@ -590,8 +603,7 @@ def _init_chat_handlers(self): # Make help always appear as the last command chat_handlers["/help"] = HelpChatHandler(**chat_handler_kwargs) - # bind chat handlers to settings - self.settings["jai_chat_handlers"] = chat_handlers + return chat_handlers def _init_context_provders(self): eps = entry_points() diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py index 9ce0867d9..4d851b2b2 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py @@ -19,7 +19,6 @@ Persona, ) from jupyter_ai_magics import BaseProvider -from jupyterlab_chat.ychat import YChat from langchain_community.llms import FakeListLLM from tornado.httputil import HTTPServerRequest from tornado.web import Application @@ -65,11 +64,6 @@ def broadcast_message(message: Message) -> None: root_handler = mock.create_autospec(RootChatHandler) root_handler.broadcast_message = broadcast_message - def write_message( - self, chat: YChat, body: str, id: Optional[str] = None - ) -> str: - return id or "" - super().__init__( log=logging.getLogger(__name__), config_manager=config_manager, @@ -84,7 +78,7 @@ def write_message( chat_handlers={}, context_providers={}, message_interrupted={}, - write_message=write_message, + ychat=None, ) @@ -127,7 +121,7 @@ async def test_default_closes_pending_on_success(human_chat_message): "should_raise": False, }, ) - await handler.process_message(human_chat_message, None) + await handler.process_message(human_chat_message) # >=2 because there are additional stream messages that follow assert len(handler.messages) >= 2 @@ -144,7 +138,7 @@ async def test_default_closes_pending_on_error(human_chat_message): }, ) with pytest.raises(TestException): - await handler.process_message(human_chat_message, None) + await handler.process_message(human_chat_message) assert len(handler.messages) == 2 assert isinstance(handler.messages[0], PendingMessage)
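Taken together, this patch replaces the `chat: Optional[YChat]` parameter that was previously threaded through every handler method with a single `ychat` attribute injected at construction time, so `process_message()`, `reply()`, `pending()`, and the streaming helpers each take one fewer argument. The sketch below, modeled on the `test_slash_commands.py` change at the top of this diff, shows what a custom slash command looks like against the refactored API. It is a minimal illustration only: the `EchoChatHandler` name, its `id`/`name`/`help` values, and the `uses_llm` flag are assumptions modeled on existing handlers, not part of this diff.

```python
# Minimal sketch of a custom slash command under the refactored handler API.
# Assumes jupyter-ai with this patch applied; the handler name, id, and help
# text below are hypothetical.
from jupyter_ai.chat_handlers.base import BaseChatHandler, SlashCommandRoutingType
from jupyter_ai.models import HumanChatMessage


class EchoChatHandler(BaseChatHandler):
    id = "echo"
    name = "Echo"
    help = "Echoes your message back into the chat"
    routing_type = SlashCommandRoutingType(slash_id="echo")
    uses_llm = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def process_message(self, message: HumanChatMessage):
        # `reply()` no longer takes a `chat` argument: when `self.ychat` is set
        # (Jupyter Chat rooms), it delegates to `self.write_message()`; otherwise
        # it broadcasts an `AgentChatMessage` over the WebSocket handlers.
        self.reply(f"You said: {message.body}", message)
```

Handler construction is otherwise unchanged: instances are still built from the `chat_handler_kwargs` dictionary in `_init_chat_handlers()`, which after this patch passes `ychat=ychat` instead of a `write_message` callback, and per-room handler sets are cached in `self.chat_handlers_by_room`.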