From 2ad85bceda620b9d1754e4c6fd5371ccdaf0b0c3 Mon Sep 17 00:00:00 2001 From: Nicolas Brichet Date: Tue, 10 Sep 2024 11:51:08 +0200 Subject: [PATCH 01/24] Very first version of the AI working in jupyterlab_collaborative_chat --- .../jupyter_ai/chat_handlers/__init__.py | 5 + .../jupyter_ai/chat_handlers/ask.py | 11 +- .../jupyter_ai/chat_handlers/base.py | 51 ++++--- .../jupyter_ai/chat_handlers/clear.py | 3 +- .../jupyter_ai/chat_handlers/default.py | 22 +-- .../jupyter_ai/chat_handlers/export.py | 7 +- .../jupyter_ai/chat_handlers/fix.py | 6 +- .../jupyter_ai/chat_handlers/generate.py | 11 +- .../jupyter_ai/chat_handlers/help.py | 3 +- .../jupyter_ai/chat_handlers/learn.py | 37 ++--- packages/jupyter-ai/jupyter_ai/extension.py | 130 ++++++++++++++++++ packages/jupyter-ai/jupyter_ai/models.py | 2 +- packages/jupyter-ai/pyproject.toml | 1 + 13 files changed, 224 insertions(+), 65 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py index a8fe9eb50..a46b74e4d 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py @@ -1,3 +1,8 @@ +# The following import is to make sure jupyter_ydoc is imported before +# jupyterlab_collaborative_chat, otherwise it leads to circular import because of the +# YChat relying on YBaseDoc, and jupyter_ydoc registering YChat from the entry point. +import jupyter_ydoc + from .ask import AskChatHandler from .base import BaseChatHandler, SlashCommandRoutingType from .clear import ClearChatHandler diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py index 5c3026685..5bda7cef2 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py @@ -1,5 +1,6 @@ import argparse from typing import Dict, Type +from jupyterlab_collaborative_chat.ychat import YChat from jupyter_ai.models import HumanChatMessage from jupyter_ai_magics.providers import BaseProvider @@ -59,13 +60,13 @@ def create_llm_chain( verbose=False, ) - async def process_message(self, message: HumanChatMessage): - args = self.parse_args(message) + async def process_message(self, message: HumanChatMessage, chat: YChat): + args = self.parse_args(message, chat) if args is None: return query = " ".join(args.query) if not query: - self.reply(f"{self.parser.format_usage()}", message) + self.reply(f"{self.parser.format_usage()}", chat, message) return self.get_llm_chain() @@ -74,7 +75,7 @@ async def process_message(self, message: HumanChatMessage): with self.pending("Searching learned documents", message): result = await self.llm_chain.acall({"question": query}) response = result["answer"] - self.reply(response, message) + self.reply(response, chat, message) except AssertionError as e: self.log.error(e) response = """Sorry, an error occurred while reading the from the learned documents. @@ -82,4 +83,4 @@ async def process_message(self, message: HumanChatMessage): `/learn -d` command and then re-submitting the `learn ` to learn the documents, and then asking the question again. 
""" - self.reply(response, message) + self.reply(response, chat, message) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index b97015518..985ff264e 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -6,6 +6,7 @@ from typing import ( TYPE_CHECKING, Awaitable, + Callable, ClassVar, Dict, List, @@ -15,6 +16,7 @@ Union, ) from uuid import uuid4 +from jupyterlab_collaborative_chat.ychat import YChat from dask.distributed import Client as DaskClient from jupyter_ai.config_manager import ConfigManager, Logger @@ -132,6 +134,7 @@ def __init__( dask_client_future: Awaitable[DaskClient], help_message_template: str, chat_handlers: Dict[str, "BaseChatHandler"], + write_message: Callable[[YChat, str], None] ): self.log = log self.config_manager = config_manager @@ -157,13 +160,16 @@ def __init__( self.llm_params = None self.llm_chain = None - async def on_message(self, message: HumanChatMessage): + self.write_message = write_message + + async def on_message(self, message: HumanChatMessage, chat: YChat): """ Method which receives a human message, calls `self.get_llm_chain()`, and processes the message via `self.process_message()`, calling `self.handle_exc()` when an exception is raised. This method is called by RootChatHandler when it routes a human message to this chat handler. """ + self.log.warn(f"MESSAGE SENT {message.body}") lm_provider_klass = self.config_manager.lm_provider # ensure the current slash command is supported @@ -173,7 +179,8 @@ async def on_message(self, message: HumanChatMessage): ) if slash_command in lm_provider_klass.unsupported_slash_commands: self.reply( - "Sorry, the selected language model does not support this slash command." + "Sorry, the selected language model does not support this slash command.", + chat ) return @@ -185,6 +192,7 @@ async def on_message(self, message: HumanChatMessage): if not lm_provider.allows_concurrency: self.reply( "The currently selected language model can process only one request at a time. Please wait for me to reply before sending another question.", + chat, message, ) return @@ -192,43 +200,43 @@ async def on_message(self, message: HumanChatMessage): BaseChatHandler._requests_count += 1 if self.__class__.supports_help: - args = self.parse_args(message, silent=True) + args = self.parse_args(message, chat, silent=True) if args and args.help: - self.reply(self.parser.format_help(), message) + self.reply(self.parser.format_help(), chat, message) return try: - await self.process_message(message) + await self.process_message(message, chat) except Exception as e: try: # we try/except `handle_exc()` in case it was overriden and # raises an exception by accident. - await self.handle_exc(e, message) + await self.handle_exc(e, message, chat) except Exception as e: await self._default_handle_exc(e, message) finally: BaseChatHandler._requests_count -= 1 - async def process_message(self, message: HumanChatMessage): + async def process_message(self, message: HumanChatMessage, chat: YChat): """ Processes a human message routed to this chat handler. Chat handlers (subclasses) must implement this method. Don't forget to call - `self.reply(, message)` at the end! + `self.reply(, chat, message)` at the end! The method definition does not need to be wrapped in a try/except block; any exceptions raised here are caught by `self.handle_exc()`. 
""" raise NotImplementedError("Should be implemented by subclasses.") - async def handle_exc(self, e: Exception, message: HumanChatMessage): + async def handle_exc(self, e: Exception, message: HumanChatMessage, chat: YChat): """ Handles an exception raised by `self.process_message()`. A default implementation is provided, however chat handlers (subclasses) should implement this method to provide a more helpful error response. """ - await self._default_handle_exc(e, message) + await self._default_handle_exc(e, message, chat) - async def _default_handle_exc(self, e: Exception, message: HumanChatMessage): + async def _default_handle_exc(self, e: Exception, message: HumanChatMessage, chat: YChat): """ The default definition of `handle_exc()`. This is the default used when the `handle_exc()` excepts. @@ -238,15 +246,15 @@ async def _default_handle_exc(self, e: Exception, message: HumanChatMessage): if lm_provider and lm_provider.is_api_key_exc(e): provider_name = getattr(self.config_manager.lm_provider, "name", "") response = f"Oops! There's a problem connecting to {provider_name}. Please update your {provider_name} API key in the chat settings." - self.reply(response, message) + self.reply(response, chat, message) return formatted_e = traceback.format_exc() response = ( f"Sorry, an error occurred. Details below:\n\n```\n{formatted_e}\n```" ) - self.reply(response, message) + self.reply(response, chat, message) - def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None): + def reply(self, response: str, chat: YChat, human_msg: Optional[HumanChatMessage] = None): """ Sends an agent message, usually in response to a received `HumanChatMessage`. @@ -259,12 +267,13 @@ def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None): persona=self.persona, ) - for handler in self._root_chat_handlers.values(): - if not handler: - continue + self.write_message(chat, response) + # for handler in self._root_chat_handlers.values(): + # if not handler: + # continue - handler.broadcast_message(agent_msg) - break + # handler.broadcast_message(agent_msg) + # break @property def persona(self): @@ -380,14 +389,14 @@ def create_llm_chain( ): raise NotImplementedError("Should be implemented by subclasses") - def parse_args(self, message, silent=False): + def parse_args(self, message, chat, silent=False): args = message.body.split(" ") try: args = self.parser.parse_args(args[1:]) except (argparse.ArgumentError, SystemExit) as e: if not silent: response = f"{self.parser.format_usage()}" - self.reply(response, message) + self.reply(response, chat, message) return None return args diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py index a05bc3e57..29e7a0596 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py @@ -1,4 +1,5 @@ from jupyter_ai.models import ClearMessage +from jupyterlab_collaborative_chat.ychat import YChat from .base import BaseChatHandler, SlashCommandRoutingType @@ -16,7 +17,7 @@ class ClearChatHandler(BaseChatHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - async def process_message(self, _): + async def process_message(self, _, chat: YChat): # Clear chat for handler in self._root_chat_handlers.values(): if not handler: diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 3e8f194b3..3cd1782fb 100644 --- 
a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -1,6 +1,7 @@ import time from typing import Dict, Type from uuid import uuid4 +from jupyterlab_collaborative_chat.ychat import YChat from jupyter_ai.models import ( AgentStreamChunkMessage, @@ -81,7 +82,7 @@ def _start_stream(self, human_msg: HumanChatMessage) -> str: return stream_id - def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = False): + def _send_stream_chunk(self, stream_id: str, content: str, chat: YChat, complete: bool = False): """ Sends an `agent-stream-chunk` message containing content that should be appended to an existing `agent-stream` message with ID `stream_id`. @@ -89,15 +90,16 @@ def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = Fals stream_chunk_msg = AgentStreamChunkMessage( id=stream_id, content=content, stream_complete=complete ) + self.write_message(chat, stream_chunk_msg.content) + # for handler in self._root_chat_handlers.values(): + # if not handler: + # continue - for handler in self._root_chat_handlers.values(): - if not handler: - continue - - handler.broadcast_message(stream_chunk_msg) - break + # handler.broadcast_message(stream_chunk_msg) + # break - async def process_message(self, message: HumanChatMessage): + async def process_message(self, message: HumanChatMessage, chat: YChat): + self.log.warning("PROCESS IN DEFAULT HANDLER") self.get_llm_chain() received_first_chunk = False @@ -120,10 +122,10 @@ async def process_message(self, message: HumanChatMessage): if isinstance(chunk, AIMessageChunk): self._send_stream_chunk(stream_id, chunk.content) elif isinstance(chunk, str): - self._send_stream_chunk(stream_id, chunk) + self._send_stream_chunk(stream_id, chunk, chat) else: self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}") break # complete stream after all chunks have been streamed - self._send_stream_chunk(stream_id, "", complete=True) + self._send_stream_chunk(stream_id, "", chat, complete=True) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py index ed478f57e..cf84581fa 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py @@ -2,6 +2,7 @@ import os from datetime import datetime from typing import List +from jupyterlab_collaborative_chat.ychat import YChat from jupyter_ai.models import AgentChatMessage, HumanChatMessage @@ -31,11 +32,11 @@ def chat_message_to_markdown(self, message): return "" # Write the chat history to a markdown file with a timestamp - async def process_message(self, message: HumanChatMessage): + async def process_message(self, message: HumanChatMessage, chat: YChat): markdown_content = "\n\n".join( self.chat_message_to_markdown(msg) for msg in self._chat_history ) - args = self.parse_args(message) + args = self.parse_args(message, chat) chat_filename = ( # if no filename, use "chat_history" + timestamp args.path[0] if (args.path and args.path[0] != "") @@ -46,4 +47,4 @@ async def process_message(self, message: HumanChatMessage): ) # Do not use timestamp if filename is entered as argument with open(chat_file, "w") as chat_history: chat_history.write(markdown_content) - self.reply(f"File saved to `{chat_file}`") + self.reply(f"File saved to `{chat_file}`", chat) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py index 
d6ecc6d81..44210b14a 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py @@ -1,4 +1,5 @@ from typing import Dict, Type +from jupyterlab_collaborative_chat.ychat import YChat from jupyter_ai.models import CellWithErrorSelection, HumanChatMessage from jupyter_ai_magics.providers import BaseProvider @@ -77,10 +78,11 @@ def create_llm_chain( self.llm = llm self.llm_chain = LLMChain(llm=llm, prompt=FIX_PROMPT_TEMPLATE, verbose=True) - async def process_message(self, message: HumanChatMessage): + async def process_message(self, message: HumanChatMessage, chat: YChat): if not (message.selection and message.selection.type == "cell-with-error"): self.reply( "`/fix` requires an active code cell with error output. Please click on a cell with error output and retry.", + chat, message, ) return @@ -101,4 +103,4 @@ async def process_message(self, message: HumanChatMessage): error_value=selection.error.value, traceback="\n".join(selection.error.traceback), ) - self.reply(response, message) + self.reply(response, chat, message) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py index 52398eabe..da26d0896 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py @@ -4,6 +4,7 @@ import traceback from pathlib import Path from typing import Dict, List, Optional, Type +from jupyterlab_collaborative_chat.ychat import YChat import nbformat from jupyter_ai.chat_handlers import BaseChatHandler, SlashCommandRoutingType @@ -261,18 +262,18 @@ async def _generate_notebook(self, prompt: str): nbformat.write(notebook, final_path) return final_path - async def process_message(self, message: HumanChatMessage): + async def process_message(self, message: HumanChatMessage, chat: YChat): self.get_llm_chain() # first send a verification message to user response = "👍 Great, I will get started on your notebook. It may take a few minutes, but I will reply here when the notebook is ready. In the meantime, you can continue to ask me other questions." - self.reply(response, message) + self.reply(response, chat, message) final_path = await self._generate_notebook(prompt=message.body) response = f"""🎉 I have created your notebook and saved it to the location {final_path}. I am still learning how to create notebooks, so please review all code before running it.""" - self.reply(response, message) + self.reply(response, chat, message) - async def handle_exc(self, e: Exception, message: HumanChatMessage): + async def handle_exc(self, e: Exception, message: HumanChatMessage, chat: YChat): timestamp = time.strftime("%Y-%m-%d-%H.%M.%S") default_log_dir = Path(self.output_dir) / "jupyter-ai-logs" log_dir = self.log_dir or default_log_dir @@ -282,4 +283,4 @@ async def handle_exc(self, e: Exception, message: HumanChatMessage): traceback.print_exc(file=log) response = f"An error occurred while generating the notebook. The error details have been saved to `./{log_path}`.\n\nTry running `/generate` again, as some language models require multiple attempts before a notebook is generated." 
- self.reply(response, message) + self.reply(response, chat, message) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py index cd8556863..6c5a2fa9b 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py @@ -1,4 +1,5 @@ from jupyter_ai.models import HumanChatMessage +from jupyterlab_collaborative_chat.ychat import YChat from .base import BaseChatHandler, SlashCommandRoutingType @@ -15,5 +16,5 @@ class HelpChatHandler(BaseChatHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - async def process_message(self, message: HumanChatMessage): + async def process_message(self, message: HumanChatMessage, chat: YChat): self.send_help_message(message) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py index 29e147f22..1ed118c97 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py @@ -2,6 +2,7 @@ import json import os from typing import Any, Coroutine, List, Optional, Tuple +from jupyterlab_collaborative_chat.ychat import YChat from dask.distributed import Client as DaskClient from jupyter_ai.document_loaders.directory import ( @@ -127,26 +128,27 @@ def _load(self): ) self.log.error(e) - async def process_message(self, message: HumanChatMessage): + async def process_message(self, message: HumanChatMessage, chat: YChat): # If no embedding provider has been selected em_provider_cls, em_provider_args = self.get_embedding_provider() if not em_provider_cls: self.reply( - "Sorry, please select an embedding provider before using the `/learn` command." + "Sorry, please select an embedding provider before using the `/learn` command.", + chat ) return - args = self.parse_args(message) + args = self.parse_args(message, chat) if args is None: return if args.delete: self.delete() - self.reply(f"👍 I have deleted everything I previously learned.", message) + self.reply(f"👍 I have deleted everything I previously learned.", chat, message) return if args.list: - self.reply(self._build_list_response()) + self.reply(self._build_list_response(), chat) return if args.remote: @@ -157,35 +159,38 @@ async def process_message(self, message: HumanChatMessage): args.path = [arxiv_to_text(id, self.output_dir)] self.reply( f"Learning arxiv file with id **{id}**, saved in **{args.path[0]}**.", + chat, message, ) except ModuleNotFoundError as e: self.log.error(e) self.reply( - "No `arxiv` package found. " "Install with `pip install arxiv`." + "No `arxiv` package found. " "Install with `pip install arxiv`.", + chat ) return except Exception as e: self.log.error(e) self.reply( "An error occurred while processing the arXiv file. " - f"Please verify that the arxiv id {id} is correct." + f"Please verify that the arxiv id {id} is correct.", + chat ) return # Make sure the path exists. 
if not len(args.path) == 1: - self.reply(f"{self.parser.format_usage()}", message) + self.reply(f"{self.parser.format_usage()}", chat, message) return short_path = args.path[0] load_path = os.path.join(self.output_dir, short_path) if not os.path.exists(load_path): response = f"Sorry, that path doesn't exist: {load_path}" - self.reply(response, message) + self.reply(response, chat, message) return # delete and relearn index if embedding model was changed - await self.delete_and_relearn() + await self.delete_and_relearn(chat) with self.pending(f"Loading and splitting files for {load_path}", message): try: @@ -198,7 +203,7 @@ async def process_message(self, message: HumanChatMessage): self.save() response = f"""🎉 I have learned documents at **{load_path}** and I am ready to answer questions about them. You can ask questions about these docs by prefixing your message with **/ask**.""" - self.reply(response, message) + self.reply(response, chat, message) def _build_list_response(self): if not self.metadata.dirs: @@ -250,7 +255,7 @@ def _add_dir_to_metadata(self, path: str, chunk_size: int, chunk_overlap: int): ) self.metadata.dirs = dirs - async def delete_and_relearn(self): + async def delete_and_relearn(self, chat: YChat): """Delete the vector store and relearn all indexed directories if necessary. If the embedding model is unchanged, this method does nothing.""" @@ -277,11 +282,11 @@ async def delete_and_relearn(self): documents you had previously submitted for learning. Please wait to use the **/ask** command until I am done with this task.""" - self.reply(message) + self.reply(message, chat) metadata = self.metadata self.delete() - await self.relearn(metadata) + await self.relearn(metadata, chat) self.prev_em_id = curr_em_id def delete(self): @@ -295,7 +300,7 @@ def delete(self): if os.path.isfile(path): os.remove(path) - async def relearn(self, metadata: IndexMetadata): + async def relearn(self, metadata: IndexMetadata, chat: YChat): # Index all dirs in the metadata if not metadata.dirs: return @@ -315,7 +320,7 @@ async def relearn(self, metadata: IndexMetadata): message = f"""🎉 I am done learning docs in these directories: {dir_list} I am ready to answer questions about them. 
You can ask me about these documents by starting your message with **/ask**.""" - self.reply(message) + self.reply(message, chat) def create( self, diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 78c4dfefa..ae9fa85e6 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -2,15 +2,27 @@ import re import time import types +import uuid from dask.distributed import Client as DaskClient from importlib_metadata import entry_points from jupyter_ai.chat_handlers.learn import Retriever +from jupyter_ai.models import HumanChatMessage from jupyter_ai_magics import BaseProvider, JupyternautPersona from jupyter_ai_magics.utils import get_em_providers, get_lm_providers from jupyter_server.extension.application import ExtensionApp from tornado.web import StaticFileHandler from traitlets import Dict, Integer, List, Unicode +from jupyter_collaboration import __version__ as jupyter_collaboration_version +from jupyterlab_collaborative_chat.ychat import YChat + +from functools import partial +from jupyter_collaboration.utils import JUPYTER_COLLABORATION_EVENTS_URI +from jupyter_events import EventLogger +from jupyter_server.extension.application import ExtensionApp +from jupyter_server.utils import url_path_join + +from pycrdt import ArrayEvent from .chat_handlers import ( AskChatHandler, @@ -41,6 +53,17 @@ ) +if int(jupyter_collaboration_version[0]) >= 3: + COLLAB_VERSION = 3 +else: + COLLAB_VERSION = 2 + +BOT = { + "username": str(uuid.uuid4()), + "name": "Jupyternaut", + "display_name": "Jupyternaut" +} + DEFAULT_HELP_MESSAGE_TEMPLATE = """Hi there! I'm {persona_name}, your programming assistant. You can ask me a question using the text box below. You can also use these commands: {slash_commands_list} @@ -198,6 +221,112 @@ class AiExtension(ExtensionApp): config=True, ) + def initialize(self): + super().initialize() + self.event_logger = self.serverapp.web_app.settings["event_logger"] + self.event_logger.add_listener( + schema_id=JUPYTER_COLLABORATION_EVENTS_URI, + listener=self.connect_chat + ) + + async def connect_chat(self, logger: EventLogger, schema_id: str, data: dict) -> None: + self.log.warn(f"New DOC {data["room"]}") + if data["room"].startswith("text:chat:") \ + and data["action"] == "initialize"\ + and data["msg"] == "Room initialized": + + self.log.info(f"Collaborative chat server is listening for {data["room"]}") + chat = await self.get_chat(data["room"]) + self.log.warn(f"Chat {chat}") + callback = partial(self.on_change, chat) + chat.ymessages.observe(callback) + + async def get_chat(self, room_id: str) -> YChat: + if COLLAB_VERSION == 3: + collaboration = self.serverapp.web_app.settings["jupyter_server_ydoc"] + document = await collaboration.get_document( + room_id=room_id, + copy=False + ) + else: + collaboration = self.serverapp.web_app.settings["jupyter_collaboration"] + server = collaboration.ywebsocket_server + + room = await server.get_room(room_id) + document = room._document + return document + + def on_change(self, chat: YChat, events: ArrayEvent) -> None: + for change in events.delta: + if not "insert" in change.keys(): + continue + messages = change["insert"] + self.log.warn(f"New messages {messages}") + for message in messages: + self.log.warn(f"SENDER {message["sender"]}") + self.log.warn(f"BOT {BOT["username"]}") + + if message["sender"] == BOT["username"] or message["raw_time"]: + self.log.warn("HERE WE ARE") + continue + try: + chat_message = HumanChatMessage( 
+ id=message["id"], + time=time.time(), + body=message["body"], + prompt="", + selection=None, + client=None, + ) + except Exception as e: + self.log.error(e) + self.log.warn(f"BUILT HUMAN MESSAGE {chat_message}") + self.serverapp.io_loop.asyncio_loop.create_task(self._route(chat_message, chat)) + + async def _route(self, message: HumanChatMessage, chat: YChat): + """Method that routes an incoming message to the appropriate handler.""" + self.log.warn(f"ROUTING {message}") + chat_handlers = self.settings["jai_chat_handlers"] + default = chat_handlers["default"] + # Split on any whitespace, either spaces or newlines + maybe_command = message.body.split(None, 1)[0] + is_command = ( + message.body.startswith("/") + and maybe_command in chat_handlers.keys() + and maybe_command != "default" + ) + command = maybe_command if is_command else "default" + + start = time.time() + if is_command: + await chat_handlers[command].on_message(message, chat) + else: + await default.on_message(message, chat) + + latency_ms = round((time.time() - start) * 1000) + command_readable = "Default" if command == "default" else command + self.log.info(f"{command_readable} chat handler resolved in {latency_ms} ms.") + + def write_message(self, chat: YChat, body: str) -> None: + BOT["avatar_url"] = url_path_join( + self.settings.get("base_url", "/"), + "api/ai/static/jupyternaut.svg" + ) + bot = chat.get_user_by_name(BOT["name"]) + if not bot: + chat.set_user(BOT) + else: + BOT["username"] = bot["username"] + + chat.add_message({ + "type": "msg", + "body": body, + "id": str(uuid.uuid4()), + "time": time.time(), + "sender": BOT["username"], + "raw_time": False + }) + def initialize_settings(self): start = time.time() @@ -301,6 +430,7 @@ def initialize_settings(self): "preferred_dir": self.serverapp.contents_manager.preferred_dir, "help_message_template": self.help_message_template, "chat_handlers": chat_handlers, + "write_message": self.write_message } default_chat_handler = DefaultChatHandler(**chat_handler_kwargs) clear_chat_handler = ClearChatHandler(**chat_handler_kwargs) diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index f9098a12a..75638b6b8 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -114,7 +114,7 @@ class HumanChatMessage(BaseModel): """The prompt typed into the chat input by the user.""" selection: Optional[Selection] """The selection included with the prompt, if any.""" - client: ChatClient + client: Optional[ChatClient] class ClearMessage(BaseModel): diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index e8deeb133..2353343e3 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "typing_extensions>=4.5.0", "traitlets>=5.0", "deepmerge>=1.0", + "jupyterlab-collaborative-chat", ] dynamic = ["version", "description", "authors", "urls", "keywords"] From a51840b1e61c77abed093ce92e00efd71cd11a63 Mon Sep 17 00:00:00 2001 From: Nicolas Brichet Date: Tue, 10 Sep 2024 14:59:31 +0200 Subject: [PATCH 02/24] Allows both collaborative and regular chat to work with AI --- .../jupyter_ai/chat_handlers/ask.py | 8 +++- .../jupyter_ai/chat_handlers/base.py | 47 ++++++++++--------- .../jupyter_ai/chat_handlers/clear.py | 8 +++- .../jupyter_ai/chat_handlers/default.py | 31 +++++++----- .../jupyter_ai/chat_handlers/export.py | 8 +++- .../jupyter_ai/chat_handlers/fix.py | 8 +++- .../jupyter_ai/chat_handlers/generate.py | 
9 ++-- .../jupyter_ai/chat_handlers/help.py | 8 +++- .../jupyter_ai/chat_handlers/learn.py | 12 +++-- packages/jupyter-ai/pyproject.toml | 3 +- 10 files changed, 90 insertions(+), 52 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py index 5bda7cef2..6c68cbf15 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py @@ -1,6 +1,5 @@ import argparse from typing import Dict, Type -from jupyterlab_collaborative_chat.ychat import YChat from jupyter_ai.models import HumanChatMessage from jupyter_ai_magics.providers import BaseProvider @@ -8,6 +7,11 @@ from langchain.memory import ConversationBufferWindowMemory from langchain_core.prompts import PromptTemplate +try: + from jupyterlab_collaborative_chat.ychat import YChat +except: + from typing import Any as YChat + from .base import BaseChatHandler, SlashCommandRoutingType PROMPT_TEMPLATE = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question. @@ -60,7 +64,7 @@ def create_llm_chain( verbose=False, ) - async def process_message(self, message: HumanChatMessage, chat: YChat): + async def process_message(self, message: HumanChatMessage, chat: YChat | None): args = self.parse_args(message, chat) if args is None: return diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index 985ff264e..84daa3dfa 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -16,7 +16,6 @@ Union, ) from uuid import uuid4 -from jupyterlab_collaborative_chat.ychat import YChat from dask.distributed import Client as DaskClient from jupyter_ai.config_manager import ConfigManager, Logger @@ -32,6 +31,11 @@ from jupyter_ai_magics.providers import BaseProvider from langchain.pydantic_v1 import BaseModel +try: + from jupyterlab_collaborative_chat.ychat import YChat +except: + from typing import Any as YChat + if TYPE_CHECKING: from jupyter_ai.handlers import RootChatHandler from jupyter_ai.history import BoundedChatHistory @@ -134,7 +138,7 @@ def __init__( dask_client_future: Awaitable[DaskClient], help_message_template: str, chat_handlers: Dict[str, "BaseChatHandler"], - write_message: Callable[[YChat, str], None] + write_message: Callable[[YChat, str], None] | None = None ): self.log = log self.config_manager = config_manager @@ -162,7 +166,7 @@ def __init__( self.write_message = write_message - async def on_message(self, message: HumanChatMessage, chat: YChat): + async def on_message(self, message: HumanChatMessage, chat: YChat| None = None): """ Method which receives a human message, calls `self.get_llm_chain()`, and processes the message via `self.process_message()`, calling @@ -217,7 +221,7 @@ async def on_message(self, message: HumanChatMessage, chat: YChat): finally: BaseChatHandler._requests_count -= 1 - async def process_message(self, message: HumanChatMessage, chat: YChat): + async def process_message(self, message: HumanChatMessage, chat: YChat | None): """ Processes a human message routed to this chat handler. Chat handlers (subclasses) must implement this method. 
Don't forget to call @@ -228,7 +232,7 @@ async def process_message(self, message: HumanChatMessage, chat: YChat): """ raise NotImplementedError("Should be implemented by subclasses.") - async def handle_exc(self, e: Exception, message: HumanChatMessage, chat: YChat): + async def handle_exc(self, e: Exception, message: HumanChatMessage, chat: YChat | None): """ Handles an exception raised by `self.process_message()`. A default implementation is provided, however chat handlers (subclasses) should @@ -236,7 +240,7 @@ async def handle_exc(self, e: Exception, message: HumanChatMessage, chat: YChat) """ await self._default_handle_exc(e, message, chat) - async def _default_handle_exc(self, e: Exception, message: HumanChatMessage, chat: YChat): + async def _default_handle_exc(self, e: Exception, message: HumanChatMessage, chat: YChat | None): """ The default definition of `handle_exc()`. This is the default used when the `handle_exc()` excepts. @@ -254,26 +258,27 @@ async def _default_handle_exc(self, e: Exception, message: HumanChatMessage, cha ) self.reply(response, chat, message) - def reply(self, response: str, chat: YChat, human_msg: Optional[HumanChatMessage] = None): + def reply(self, response: str, chat: YChat | None, human_msg: Optional[HumanChatMessage] = None): """ Sends an agent message, usually in response to a received `HumanChatMessage`. """ - agent_msg = AgentChatMessage( - id=uuid4().hex, - time=time.time(), - body=response, - reply_to=human_msg.id if human_msg else "", - persona=self.persona, - ) - - self.write_message(chat, response) - # for handler in self._root_chat_handlers.values(): - # if not handler: - # continue + if chat is not None: + self.write_message(chat, response) + else: + agent_msg = AgentChatMessage( + id=uuid4().hex, + time=time.time(), + body=response, + reply_to=human_msg.id if human_msg else "", + persona=self.persona, + ) + for handler in self._root_chat_handlers.values(): + if not handler: + continue - # handler.broadcast_message(agent_msg) - # break + handler.broadcast_message(agent_msg) + break @property def persona(self): diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py index 29e7a0596..e163545bc 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py @@ -1,5 +1,9 @@ from jupyter_ai.models import ClearMessage -from jupyterlab_collaborative_chat.ychat import YChat + +try: + from jupyterlab_collaborative_chat.ychat import YChat +except: + from typing import Any as YChat from .base import BaseChatHandler, SlashCommandRoutingType @@ -17,7 +21,7 @@ class ClearChatHandler(BaseChatHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - async def process_message(self, _, chat: YChat): + async def process_message(self, _, chat: YChat | None): # Clear chat for handler in self._root_chat_handlers.values(): if not handler: diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 3cd1782fb..480fd4565 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -1,7 +1,6 @@ import time from typing import Dict, Type from uuid import uuid4 -from jupyterlab_collaborative_chat.ychat import YChat from jupyter_ai.models import ( AgentStreamChunkMessage, @@ -13,6 +12,11 @@ from langchain_core.runnables import ConfigurableFieldSpec from 
langchain_core.runnables.history import RunnableWithMessageHistory +try: + from jupyterlab_collaborative_chat.ychat import YChat +except: + from typing import Any as YChat + from ..models import HumanChatMessage from .base import BaseChatHandler, SlashCommandRoutingType @@ -82,24 +86,25 @@ def _start_stream(self, human_msg: HumanChatMessage) -> str: return stream_id - def _send_stream_chunk(self, stream_id: str, content: str, chat: YChat, complete: bool = False): + def _send_stream_chunk(self, stream_id: str, content: str, chat: YChat | None, complete: bool = False): """ Sends an `agent-stream-chunk` message containing content that should be appended to an existing `agent-stream` message with ID `stream_id`. """ - stream_chunk_msg = AgentStreamChunkMessage( - id=stream_id, content=content, stream_complete=complete - ) - self.write_message(chat, stream_chunk_msg.content) - # for handler in self._root_chat_handlers.values(): - # if not handler: - # continue + if chat is not None: + self.write_message(chat, content) + else: + stream_chunk_msg = AgentStreamChunkMessage( + id=stream_id, content=content, stream_complete=complete + ) + for handler in self._root_chat_handlers.values(): + if not handler: + continue - # handler.broadcast_message(stream_chunk_msg) - # break + handler.broadcast_message(stream_chunk_msg) + break - async def process_message(self, message: HumanChatMessage, chat: YChat): - self.log.warning("PROCESS IN DEFAULT HANDLER") + async def process_message(self, message: HumanChatMessage, chat: YChat | None): self.get_llm_chain() received_first_chunk = False diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py index cf84581fa..c296b62cc 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py @@ -2,10 +2,14 @@ import os from datetime import datetime from typing import List -from jupyterlab_collaborative_chat.ychat import YChat from jupyter_ai.models import AgentChatMessage, HumanChatMessage +try: + from jupyterlab_collaborative_chat.ychat import YChat +except: + from typing import Any as YChat + from .base import BaseChatHandler, SlashCommandRoutingType @@ -32,7 +36,7 @@ def chat_message_to_markdown(self, message): return "" # Write the chat history to a markdown file with a timestamp - async def process_message(self, message: HumanChatMessage, chat: YChat): + async def process_message(self, message: HumanChatMessage, chat: YChat | None): markdown_content = "\n\n".join( self.chat_message_to_markdown(msg) for msg in self._chat_history ) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py index 44210b14a..349f83987 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py @@ -1,11 +1,15 @@ from typing import Dict, Type -from jupyterlab_collaborative_chat.ychat import YChat from jupyter_ai.models import CellWithErrorSelection, HumanChatMessage from jupyter_ai_magics.providers import BaseProvider from langchain.chains import LLMChain from langchain.prompts import PromptTemplate +try: + from jupyterlab_collaborative_chat.ychat import YChat +except: + from typing import Any as YChat + from .base import BaseChatHandler, SlashCommandRoutingType FIX_STRING_TEMPLATE = """ @@ -78,7 +82,7 @@ def create_llm_chain( self.llm = llm self.llm_chain = LLMChain(llm=llm, prompt=FIX_PROMPT_TEMPLATE, verbose=True) - async def 
process_message(self, message: HumanChatMessage, chat: YChat): + async def process_message(self, message: HumanChatMessage, chat: YChat | None): if not (message.selection and message.selection.type == "cell-with-error"): self.reply( "`/fix` requires an active code cell with error output. Please click on a cell with error output and retry.", diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py index da26d0896..50d85c3ae 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py @@ -4,7 +4,6 @@ import traceback from pathlib import Path from typing import Dict, List, Optional, Type -from jupyterlab_collaborative_chat.ychat import YChat import nbformat from jupyter_ai.chat_handlers import BaseChatHandler, SlashCommandRoutingType @@ -17,6 +16,10 @@ from langchain.schema.output_parser import BaseOutputParser from langchain_core.prompts import PromptTemplate +try: + from jupyterlab_collaborative_chat.ychat import YChat +except: + from typing import Any as YChat class OutlineSection(BaseModel): title: str @@ -262,7 +265,7 @@ async def _generate_notebook(self, prompt: str): nbformat.write(notebook, final_path) return final_path - async def process_message(self, message: HumanChatMessage, chat: YChat): + async def process_message(self, message: HumanChatMessage, chat: YChat | None): self.get_llm_chain() # first send a verification message to user @@ -273,7 +276,7 @@ async def process_message(self, message: HumanChatMessage, chat: YChat): response = f"""🎉 I have created your notebook and saved it to the location {final_path}. I am still learning how to create notebooks, so please review all code before running it.""" self.reply(response, chat, message) - async def handle_exc(self, e: Exception, message: HumanChatMessage, chat: YChat): + async def handle_exc(self, e: Exception, message: HumanChatMessage, chat: YChat | None): timestamp = time.strftime("%Y-%m-%d-%H.%M.%S") default_log_dir = Path(self.output_dir) / "jupyter-ai-logs" log_dir = self.log_dir or default_log_dir diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py index 6c5a2fa9b..fcce75e09 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py @@ -1,5 +1,9 @@ from jupyter_ai.models import HumanChatMessage -from jupyterlab_collaborative_chat.ychat import YChat + +try: + from jupyterlab_collaborative_chat.ychat import YChat +except: + from typing import Any as YChat from .base import BaseChatHandler, SlashCommandRoutingType @@ -16,5 +20,5 @@ class HelpChatHandler(BaseChatHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - async def process_message(self, message: HumanChatMessage, chat: YChat): + async def process_message(self, message: HumanChatMessage, chat: YChat | None): self.send_help_message(message) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py index 1ed118c97..3b30e5f23 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py @@ -2,7 +2,6 @@ import json import os from typing import Any, Coroutine, List, Optional, Tuple -from jupyterlab_collaborative_chat.ychat import YChat from dask.distributed import Client as DaskClient from jupyter_ai.document_loaders.directory import ( @@ 
-30,6 +29,11 @@ ) from langchain_community.vectorstores import FAISS +try: + from jupyterlab_collaborative_chat.ychat import YChat +except: + from typing import Any as YChat + from .base import BaseChatHandler, SlashCommandRoutingType INDEX_SAVE_DIR = os.path.join(jupyter_data_dir(), "jupyter_ai", "indices") @@ -128,7 +132,7 @@ def _load(self): ) self.log.error(e) - async def process_message(self, message: HumanChatMessage, chat: YChat): + async def process_message(self, message: HumanChatMessage, chat: YChat | None): # If no embedding provider has been selected em_provider_cls, em_provider_args = self.get_embedding_provider() if not em_provider_cls: @@ -255,7 +259,7 @@ def _add_dir_to_metadata(self, path: str, chunk_size: int, chunk_overlap: int): ) self.metadata.dirs = dirs - async def delete_and_relearn(self, chat: YChat): + async def delete_and_relearn(self, chat: YChat | None): """Delete the vector store and relearn all indexed directories if necessary. If the embedding model is unchanged, this method does nothing.""" @@ -300,7 +304,7 @@ def delete(self): if os.path.isfile(path): os.remove(path) - async def relearn(self, metadata: IndexMetadata, chat: YChat): + async def relearn(self, metadata: IndexMetadata, chat: YChat | None): # Index all dirs in the metadata if not metadata.dirs: return diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index 2353343e3..9b1d3d996 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -33,7 +33,6 @@ dependencies = [ "typing_extensions>=4.5.0", "traitlets>=5.0", "deepmerge>=1.0", - "jupyterlab-collaborative-chat", ] dynamic = ["version", "description", "authors", "urls", "keywords"] @@ -57,6 +56,8 @@ dev = ["jupyter_ai_magics[dev]"] all = ["jupyter_ai_magics[all]", "pypdf", "arxiv"] +collaborative = ["jupyterlab-collaborative-chat"] + [tool.hatch.version] source = "nodejs" From ac00c973b58898d846ad82bc740d10a0237f03d9 Mon Sep 17 00:00:00 2001 From: Nicolas Brichet Date: Tue, 10 Sep 2024 15:24:36 +0200 Subject: [PATCH 03/24] handle the help message in the chat too --- .../jupyter_ai/chat_handlers/base.py | 19 ++++++++++++++----- .../jupyter_ai/chat_handlers/help.py | 2 +- packages/jupyter-ai/jupyter_ai/extension.py | 10 +--------- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index 84daa3dfa..77e45ea9f 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -173,7 +173,6 @@ async def on_message(self, message: HumanChatMessage, chat: YChat| None = None): `self.handle_exc()` when an exception is raised. This method is called by RootChatHandler when it routes a human message to this chat handler. 
""" - self.log.warn(f"MESSAGE SENT {message.body}") lm_provider_klass = self.config_manager.lm_provider # ensure the current slash command is supported @@ -424,7 +423,7 @@ def output_dir(self) -> str: else: return self.root_dir - def send_help_message(self, human_msg: Optional[HumanChatMessage] = None) -> None: + def send_help_message(self, chat: YChat | None, human_msg: Optional[HumanChatMessage] = None) -> None: """Sends a help message to all connected clients.""" lm_provider = self.config_manager.lm_provider unsupported_slash_commands = ( @@ -454,6 +453,16 @@ def send_help_message(self, human_msg: Optional[HumanChatMessage] = None) -> Non persona=self.persona, ) - self._chat_history.append(help_message) - for websocket in self._root_chat_handlers.values(): - websocket.write_message(help_message.json()) + if chat is not None: + self.write_message(chat, help_message_body) + else: + help_message = AgentChatMessage( + id=uuid4().hex, + time=time.time(), + body=help_message_body, + reply_to=human_msg.id if human_msg else "", + persona=self.persona, + ) + self._chat_history.append(help_message) + for websocket in self._root_chat_handlers.values(): + websocket.write_message(help_message.json()) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py index fcce75e09..0cfa60999 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py @@ -21,4 +21,4 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) async def process_message(self, message: HumanChatMessage, chat: YChat | None): - self.send_help_message(message) + self.send_help_message(chat, message) diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index ae9fa85e6..128263726 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -230,14 +230,12 @@ def initialize(self): ) async def connect_chat(self, logger: EventLogger, schema_id: str, data: dict) -> None: - self.log.warn(f"New DOC {data["room"]}") if data["room"].startswith("text:chat:") \ and data["action"] == "initialize"\ and data["msg"] == "Room initialized": self.log.info(f"Collaborative chat server is listening for {data["room"]}") chat = await self.get_chat(data["room"]) - self.log.warn(f"Chat {chat}") callback = partial(self.on_change, chat) chat.ymessages.observe(callback) @@ -261,13 +259,9 @@ def on_change(self, chat: YChat, events: ArrayEvent) -> None: if not "insert" in change.keys(): continue messages = change["insert"] - self.log.warn(f"New messages {messages}") for message in messages: - self.log.warn(f"SENDER {message["sender"]}") - self.log.warn(f"BOT {BOT["username"]}") if message["sender"] == BOT["username"] or message["raw_time"]: - self.log.warn("HERE WE ARE") continue try: chat_message = HumanChatMessage( @@ -280,12 +274,10 @@ def on_change(self, chat: YChat, events: ArrayEvent) -> None: ) except Exception as e: self.log.error(e) - self.log.warn(f"BUILT HUMAN MESSAGE {chat_message}") self.serverapp.io_loop.asyncio_loop.create_task(self._route(chat_message, chat)) async def _route(self, message: HumanChatMessage, chat: YChat): """Method that routes an incoming message to the appropriate handler.""" - self.log.warn(f"ROUTING {message}") chat_handlers = self.settings["jai_chat_handlers"] default = chat_handlers["default"] # Split on any whitespace, either spaces or newlines @@ -523,7 +515,7 @@ def 
_show_help_message(self): default_chat_handler: DefaultChatHandler = self.settings["jai_chat_handlers"][ "default" ] - default_chat_handler.send_help_message() + default_chat_handler.send_help_message(None) async def _get_dask_client(self): return DaskClient(processes=False, asynchronous=True) From 689b7b3668b7241b874b23169b931eebdb9152c0 Mon Sep 17 00:00:00 2001 From: Nicolas Brichet <32258950+brichet@users.noreply.github.com> Date: Tue, 10 Sep 2024 17:37:50 +0200 Subject: [PATCH 04/24] Autocompletion (#2) * Fix handler methods' parameters * Add slash commands (autocompletion) to the chat input --- .../slash_command.py | 7 +- .../jupyter_ai_test/test_slash_commands.py | 9 +- .../jupyter_ai/chat_handlers/base.py | 2 +- packages/jupyter-ai/package.json | 1 + packages/jupyter-ai/src/index.ts | 25 +- .../jupyter-ai/src/slash-autocompletion.tsx | 93 ++++++ yarn.lock | 286 +++++++++++++++++- 7 files changed, 415 insertions(+), 8 deletions(-) create mode 100644 packages/jupyter-ai/src/slash-autocompletion.tsx diff --git a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/{{cookiecutter.python_name}}/slash_command.py b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/{{cookiecutter.python_name}}/slash_command.py index f82bd5531..d5ea9b720 100644 --- a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/{{cookiecutter.python_name}}/slash_command.py +++ b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/{{cookiecutter.python_name}}/slash_command.py @@ -1,6 +1,9 @@ from jupyter_ai.chat_handlers.base import BaseChatHandler, SlashCommandRoutingType from jupyter_ai.models import HumanChatMessage - +try: + from jupyterlab_collaborative_chat.ychat import YChat +except: + from typing import Any as YChat class TestSlashCommand(BaseChatHandler): """ @@ -25,5 +28,5 @@ class TestSlashCommand(BaseChatHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - async def process_message(self, message: HumanChatMessage): + async def process_message(self, message: HumanChatMessage, chat: YChat): self.reply("This is the `/test` slash command.") diff --git a/packages/jupyter-ai-test/jupyter_ai_test/test_slash_commands.py b/packages/jupyter-ai-test/jupyter_ai_test/test_slash_commands.py index f82bd5531..14b97f2f0 100644 --- a/packages/jupyter-ai-test/jupyter_ai_test/test_slash_commands.py +++ b/packages/jupyter-ai-test/jupyter_ai_test/test_slash_commands.py @@ -1,6 +1,9 @@ from jupyter_ai.chat_handlers.base import BaseChatHandler, SlashCommandRoutingType from jupyter_ai.models import HumanChatMessage - +try: + from jupyterlab_collaborative_chat.ychat import YChat +except: + from typing import Any as YChat class TestSlashCommand(BaseChatHandler): """ @@ -25,5 +28,5 @@ class TestSlashCommand(BaseChatHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - async def process_message(self, message: HumanChatMessage): - self.reply("This is the `/test` slash command.") + async def process_message(self, message: HumanChatMessage, chat: YChat): + self.reply("This is the `/test` slash command.", chat) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index 77e45ea9f..5309e4bf3 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -216,7 +216,7 @@ async def on_message(self, message: HumanChatMessage, chat: YChat| None = None): # raises an exception by 
accident. await self.handle_exc(e, message, chat) except Exception as e: - await self._default_handle_exc(e, message) + await self._default_handle_exc(e, message, chat) finally: BaseChatHandler._requests_count -= 1 diff --git a/packages/jupyter-ai/package.json b/packages/jupyter-ai/package.json index bd9a20a71..a1f8029da 100644 --- a/packages/jupyter-ai/package.json +++ b/packages/jupyter-ai/package.json @@ -61,6 +61,7 @@ "dependencies": { "@emotion/react": "^11.10.5", "@emotion/styled": "^11.10.5", + "@jupyter/chat": "^0.3.1", "@jupyter/collaboration": "^1", "@jupyterlab/application": "^4.2.0", "@jupyterlab/apputils": "^4.2.0", diff --git a/packages/jupyter-ai/src/index.ts b/packages/jupyter-ai/src/index.ts index d6e52b576..a029c3237 100644 --- a/packages/jupyter-ai/src/index.ts +++ b/packages/jupyter-ai/src/index.ts @@ -1,3 +1,4 @@ +import { IAutocompletionRegistry } from '@jupyter/chat'; import { JupyterFrontEnd, JupyterFrontEndPlugin, @@ -21,6 +22,7 @@ import { statusItemPlugin } from './status'; import { IJaiCompletionProvider, IJaiCore, IJaiMessageFooter } from './tokens'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { ActiveCellManager } from './contexts/active-cell-context'; +import { autocompletion } from './slash-autocompletion'; import { Signal } from '@lumino/signaling'; import { menuPlugin } from './plugins/menu-plugin'; @@ -126,4 +128,25 @@ const plugin: JupyterFrontEndPlugin = { } }; -export default [plugin, statusItemPlugin, completionPlugin, menuPlugin]; +/** + * Add slash commands to collaborative chat. + */ +const collaborative_autocompletion: JupyterFrontEndPlugin = { + id: '@jupyter-ai/core:autocompletion', + autoStart: true, + requires: [IAutocompletionRegistry], + activate: async ( + app: JupyterFrontEnd, + autocompletionRegistry: IAutocompletionRegistry + ) => { + autocompletionRegistry.add('ai', autocompletion); + } +}; + +export default [ + plugin, + statusItemPlugin, + completionPlugin, + menuPlugin, + collaborative_autocompletion +]; diff --git a/packages/jupyter-ai/src/slash-autocompletion.tsx b/packages/jupyter-ai/src/slash-autocompletion.tsx new file mode 100644 index 000000000..50aad1a0c --- /dev/null +++ b/packages/jupyter-ai/src/slash-autocompletion.tsx @@ -0,0 +1,93 @@ +import { + AutocompleteCommand, + IAutocompletionCommandsProps +} from '@jupyter/chat'; +import { + Download, + FindInPage, + Help, + MoreHoriz, + MenuBook, + School, + HideSource, + AutoFixNormal +} from '@mui/icons-material'; +import { Box, Typography } from '@mui/material'; +import React from 'react'; +import { AiService } from './handler'; + +type SlashCommandOption = AutocompleteCommand & { + id: string; + description: string; +}; + +/** + * List of icons per slash command, shown in the autocomplete popup. + * + * This list of icons should eventually be made configurable. However, it is + * unclear whether custom icons should be defined within a Lumino plugin (in the + * frontend) or served from a static server route (in the backend). + */ +const DEFAULT_SLASH_COMMAND_ICONS: Record = { + ask: , + clear: , + export: , + fix: , + generate: , + help: , + learn: , + unknown: +}; + +/** + * Renders an option shown in the slash command autocomplete. + */ +function renderSlashCommandOption( + optionProps: React.HTMLAttributes, + option: SlashCommandOption +): JSX.Element { + const icon = + option.id in DEFAULT_SLASH_COMMAND_ICONS + ? DEFAULT_SLASH_COMMAND_ICONS[option.id] + : DEFAULT_SLASH_COMMAND_ICONS.unknown; + + return ( +
+    <li {...optionProps}>
+      <Box>{icon}</Box>
+      <Box>
+        <Typography component="span">{option.label}</Typography>
+        <Typography component="span">
+          {' — ' + option.description}
+        </Typography>
+      </Box>
+    </li>
  • + ); +} + +/** + * The autocompletion command properties to add to the registry. + */ +export const autocompletion: IAutocompletionCommandsProps = { + opener: '/', + commands: async () => { + const slashCommands = (await AiService.listSlashCommands()).slash_commands; + return slashCommands.map(slashCommand => ({ + id: slashCommand.slash_id, + label: '/' + slashCommand.slash_id + ' ', + description: slashCommand.description + })); + }, + props: { + renderOption: renderSlashCommandOption + } +}; diff --git a/yarn.lock b/yarn.lock index 76ddf97ca..49d9b3304 100644 --- a/yarn.lock +++ b/yarn.lock @@ -2220,6 +2220,7 @@ __metadata: "@babel/preset-env": ^7.0.0 "@emotion/react": ^11.10.5 "@emotion/styled": ^11.10.5 + "@jupyter/chat": ^0.3.1 "@jupyter/collaboration": ^1 "@jupyterlab/application": ^4.2.0 "@jupyterlab/apputils": ^4.2.0 @@ -2285,6 +2286,30 @@ __metadata: languageName: unknown linkType: soft +"@jupyter/chat@npm:^0.3.1": + version: 0.3.1 + resolution: "@jupyter/chat@npm:0.3.1" + dependencies: + "@emotion/react": ^11.10.5 + "@emotion/styled": ^11.10.5 + "@jupyter/react-components": ^0.15.2 + "@jupyterlab/application": ^4.2.0 + "@jupyterlab/apputils": ^4.3.0 + "@jupyterlab/notebook": ^4.2.0 + "@jupyterlab/rendermime": ^4.2.0 + "@jupyterlab/ui-components": ^4.2.0 + "@lumino/commands": ^2.0.0 + "@lumino/disposable": ^2.0.0 + "@lumino/signaling": ^2.0.0 + "@mui/icons-material": ^5.11.0 + "@mui/material": ^5.11.0 + clsx: ^2.1.0 + react: ^18.2.0 + react-dom: ^18.2.0 + checksum: 92d1f6d6d3083be2a8c2309fc37eb3bb39bc184e1c41d18f79d51f87e54d7badba1495c5943916e7a21fc65414d6100860b044a14554add8192b545fca2748dd + languageName: node + linkType: hard + "@jupyter/collaboration@npm:^1": version: 1.2.1 resolution: "@jupyter/collaboration@npm:1.2.1" @@ -2323,7 +2348,7 @@ __metadata: languageName: node linkType: hard -"@jupyter/react-components@npm:^0.15.3": +"@jupyter/react-components@npm:^0.15.2, @jupyter/react-components@npm:^0.15.3": version: 0.15.3 resolution: "@jupyter/react-components@npm:0.15.3" dependencies: @@ -2431,6 +2456,35 @@ __metadata: languageName: node linkType: hard +"@jupyterlab/apputils@npm:^4.3.0": + version: 4.3.5 + resolution: "@jupyterlab/apputils@npm:4.3.5" + dependencies: + "@jupyterlab/coreutils": ^6.2.5 + "@jupyterlab/observables": ^5.2.5 + "@jupyterlab/rendermime-interfaces": ^3.10.5 + "@jupyterlab/services": ^7.2.5 + "@jupyterlab/settingregistry": ^4.2.5 + "@jupyterlab/statedb": ^4.2.5 + "@jupyterlab/statusbar": ^4.2.5 + "@jupyterlab/translation": ^4.2.5 + "@jupyterlab/ui-components": ^4.2.5 + "@lumino/algorithm": ^2.0.1 + "@lumino/commands": ^2.3.0 + "@lumino/coreutils": ^2.1.2 + "@lumino/disposable": ^2.1.2 + "@lumino/domutils": ^2.0.1 + "@lumino/messaging": ^2.0.1 + "@lumino/signaling": ^2.1.2 + "@lumino/virtualdom": ^2.0.1 + "@lumino/widgets": ^2.3.2 + "@types/react": ^18.0.26 + react: ^18.2.0 + sanitize-html: ~2.12.1 + checksum: a2307657bfab1aff687eccfdb7a2c378a40989beea618ad6e5a811dbd250753588ea704a11250ddef42a551c8360717c1fe4c8827c5e2c3bfff1e84fc7fdc836 + languageName: node + linkType: hard + "@jupyterlab/attachments@npm:^4.2.2": version: 4.2.2 resolution: "@jupyterlab/attachments@npm:4.2.2" @@ -2630,6 +2684,20 @@ __metadata: languageName: node linkType: hard +"@jupyterlab/coreutils@npm:^6.2.5": + version: 6.2.5 + resolution: "@jupyterlab/coreutils@npm:6.2.5" + dependencies: + "@lumino/coreutils": ^2.1.2 + "@lumino/disposable": ^2.1.2 + "@lumino/signaling": ^2.1.2 + minimist: ~1.2.0 + path-browserify: ^1.0.0 + url-parse: ~1.5.4 + checksum: 
3b6a10b117ee82a437b6535801fe012bb5af7769a850be95c8ffa666ee2d6f7c29041ba546c9cfca0ab32b65f91c661570541f4f785f48af9022d08407c0a3e5 + languageName: node + linkType: hard + "@jupyterlab/docmanager@npm:^4.2.2": version: 4.2.2 resolution: "@jupyterlab/docmanager@npm:4.2.2" @@ -2786,6 +2854,15 @@ __metadata: languageName: node linkType: hard +"@jupyterlab/nbformat@npm:^4.2.5": + version: 4.2.5 + resolution: "@jupyterlab/nbformat@npm:4.2.5" + dependencies: + "@lumino/coreutils": ^2.1.2 + checksum: b3ad2026969bfa59f8cfb7b1a991419f96f7e6dc8c4acf4ac166c210d7ab99631350c785e9b04350095488965d2824492c8adbff24a2e26db615457545426b3c + languageName: node + linkType: hard + "@jupyterlab/notebook@npm:^4.2.0, @jupyterlab/notebook@npm:^4.2.2": version: 4.2.2 resolution: "@jupyterlab/notebook@npm:4.2.2" @@ -2837,6 +2914,19 @@ __metadata: languageName: node linkType: hard +"@jupyterlab/observables@npm:^5.2.5": + version: 5.2.5 + resolution: "@jupyterlab/observables@npm:5.2.5" + dependencies: + "@lumino/algorithm": ^2.0.1 + "@lumino/coreutils": ^2.1.2 + "@lumino/disposable": ^2.1.2 + "@lumino/messaging": ^2.0.1 + "@lumino/signaling": ^2.1.2 + checksum: 21fd2828463c08a770714692ff44aeca500f8ea8f3a743ad203a61fbf04cfa81921a47b432d8e65f4935fb45c08fce2b8858cb7e2198cc9bf0fa51f482ec37bd + languageName: node + linkType: hard + "@jupyterlab/outputarea@npm:^4.2.2": version: 4.2.2 resolution: "@jupyterlab/outputarea@npm:4.2.2" @@ -2869,6 +2959,16 @@ __metadata: languageName: node linkType: hard +"@jupyterlab/rendermime-interfaces@npm:^3.10.5": + version: 3.10.5 + resolution: "@jupyterlab/rendermime-interfaces@npm:3.10.5" + dependencies: + "@lumino/coreutils": ^1.11.0 || ^2.1.2 + "@lumino/widgets": ^1.37.2 || ^2.3.2 + checksum: acfb10315a3ed4d0b0ef664437b33f8938968c61993351fd4067b0eaf6cb6ccd4c5caf50ae050d184a34b35b88d844eee6689d00244e54a02b228c02eab544b4 + languageName: node + linkType: hard + "@jupyterlab/rendermime@npm:^4.2.0, @jupyterlab/rendermime@npm:^4.2.2": version: 4.2.2 resolution: "@jupyterlab/rendermime@npm:4.2.2" @@ -2908,6 +3008,25 @@ __metadata: languageName: node linkType: hard +"@jupyterlab/services@npm:^7.2.5": + version: 7.2.5 + resolution: "@jupyterlab/services@npm:7.2.5" + dependencies: + "@jupyter/ydoc": ^2.0.1 + "@jupyterlab/coreutils": ^6.2.5 + "@jupyterlab/nbformat": ^4.2.5 + "@jupyterlab/settingregistry": ^4.2.5 + "@jupyterlab/statedb": ^4.2.5 + "@lumino/coreutils": ^2.1.2 + "@lumino/disposable": ^2.1.2 + "@lumino/polling": ^2.1.2 + "@lumino/properties": ^2.0.1 + "@lumino/signaling": ^2.1.2 + ws: ^8.11.0 + checksum: 72d7578a86af1277b574095423fafb4176bc66373662fdc0e243a7d20e4baf8f291377b6c80300841dba6486767f16664f0e893174c2761658aedb74024e1db6 + languageName: node + linkType: hard + "@jupyterlab/settingregistry@npm:^4.2.0, @jupyterlab/settingregistry@npm:^4.2.2": version: 4.2.2 resolution: "@jupyterlab/settingregistry@npm:4.2.2" @@ -2927,6 +3046,25 @@ __metadata: languageName: node linkType: hard +"@jupyterlab/settingregistry@npm:^4.2.5": + version: 4.2.5 + resolution: "@jupyterlab/settingregistry@npm:4.2.5" + dependencies: + "@jupyterlab/nbformat": ^4.2.5 + "@jupyterlab/statedb": ^4.2.5 + "@lumino/commands": ^2.3.0 + "@lumino/coreutils": ^2.1.2 + "@lumino/disposable": ^2.1.2 + "@lumino/signaling": ^2.1.2 + "@rjsf/utils": ^5.13.4 + ajv: ^8.12.0 + json5: ^2.2.3 + peerDependencies: + react: ">=16" + checksum: 2403e3198f2937fb9e4c12f96121e8bfc4f2a9ed47a9ad64182c88c8c19d59fcdf7443d0bf7d04527e89ac06378ceb39d6b4196c7f575c2a21fea23283ad3892 + languageName: node + linkType: hard + 
"@jupyterlab/statedb@npm:^4.2.2": version: 4.2.2 resolution: "@jupyterlab/statedb@npm:4.2.2" @@ -2940,6 +3078,19 @@ __metadata: languageName: node linkType: hard +"@jupyterlab/statedb@npm:^4.2.5": + version: 4.2.5 + resolution: "@jupyterlab/statedb@npm:4.2.5" + dependencies: + "@lumino/commands": ^2.3.0 + "@lumino/coreutils": ^2.1.2 + "@lumino/disposable": ^2.1.2 + "@lumino/properties": ^2.0.1 + "@lumino/signaling": ^2.1.2 + checksum: 236e7628070971af167eb4fdeac96a0090b2256cfa14b6a75aee5ef23b156cd57a8b25518125fbdc58dea09490f8f473740bc4b454d8ad7c23949f64a61b757e + languageName: node + linkType: hard + "@jupyterlab/statusbar@npm:^4.2.2": version: 4.2.2 resolution: "@jupyterlab/statusbar@npm:4.2.2" @@ -2956,6 +3107,22 @@ __metadata: languageName: node linkType: hard +"@jupyterlab/statusbar@npm:^4.2.5": + version: 4.2.5 + resolution: "@jupyterlab/statusbar@npm:4.2.5" + dependencies: + "@jupyterlab/ui-components": ^4.2.5 + "@lumino/algorithm": ^2.0.1 + "@lumino/coreutils": ^2.1.2 + "@lumino/disposable": ^2.1.2 + "@lumino/messaging": ^2.0.1 + "@lumino/signaling": ^2.1.2 + "@lumino/widgets": ^2.3.2 + react: ^18.2.0 + checksum: fa429b88a5bcd6889b9ac32b5f2500cb10a968cc636ca8dede17972535cc47454cb7fc96518fc8def76935f826b66b071752d0fd26afdacba579f6f3785e97b2 + languageName: node + linkType: hard + "@jupyterlab/testing@npm:^4.2.2": version: 4.2.2 resolution: "@jupyterlab/testing@npm:4.2.2" @@ -3027,6 +3194,19 @@ __metadata: languageName: node linkType: hard +"@jupyterlab/translation@npm:^4.2.5": + version: 4.2.5 + resolution: "@jupyterlab/translation@npm:4.2.5" + dependencies: + "@jupyterlab/coreutils": ^6.2.5 + "@jupyterlab/rendermime-interfaces": ^3.10.5 + "@jupyterlab/services": ^7.2.5 + "@jupyterlab/statedb": ^4.2.5 + "@lumino/coreutils": ^2.1.2 + checksum: 8983efad2b0d54381cb94799a10eab30f284a87103f93e844bd87106e2df3c304e260b9c95540317819cc2b2520c74ad78cb724816c81e0c315fdb43d0bdaab3 + languageName: node + linkType: hard + "@jupyterlab/ui-components@npm:^4.0.0, @jupyterlab/ui-components@npm:^4.2.0, @jupyterlab/ui-components@npm:^4.2.2": version: 4.2.2 resolution: "@jupyterlab/ui-components@npm:4.2.2" @@ -3058,6 +3238,37 @@ __metadata: languageName: node linkType: hard +"@jupyterlab/ui-components@npm:^4.2.5": + version: 4.2.5 + resolution: "@jupyterlab/ui-components@npm:4.2.5" + dependencies: + "@jupyter/react-components": ^0.15.3 + "@jupyter/web-components": ^0.15.3 + "@jupyterlab/coreutils": ^6.2.5 + "@jupyterlab/observables": ^5.2.5 + "@jupyterlab/rendermime-interfaces": ^3.10.5 + "@jupyterlab/translation": ^4.2.5 + "@lumino/algorithm": ^2.0.1 + "@lumino/commands": ^2.3.0 + "@lumino/coreutils": ^2.1.2 + "@lumino/disposable": ^2.1.2 + "@lumino/messaging": ^2.0.1 + "@lumino/polling": ^2.1.2 + "@lumino/properties": ^2.0.1 + "@lumino/signaling": ^2.1.2 + "@lumino/virtualdom": ^2.0.1 + "@lumino/widgets": ^2.3.2 + "@rjsf/core": ^5.13.4 + "@rjsf/utils": ^5.13.4 + react: ^18.2.0 + react-dom: ^18.2.0 + typestyle: ^2.0.4 + peerDependencies: + react: ^18.2.0 + checksum: 9d2b887910a3b0d41645388c5ac6183d6fd2f3af4567de9b077b2492b1a9380f98c4598a4ae6d1c3186624ed4f956bedf8ba37adb5f772c96555761384a93e1e + languageName: node + linkType: hard + "@lerna/child-process@npm:6.6.2": version: 6.6.2 resolution: "@lerna/child-process@npm:6.6.2" @@ -3324,6 +3535,13 @@ __metadata: languageName: node linkType: hard +"@lumino/algorithm@npm:^2.0.2": + version: 2.0.2 + resolution: "@lumino/algorithm@npm:2.0.2" + checksum: 
34b25684b845f1bdbf78ca45ebd99a97b67b2992440c9643aafe5fc5a99fae1ddafa9e5890b246b233dc3a12d9f66aa84afe4a2aac44cf31071348ed217740db + languageName: node + linkType: hard + "@lumino/application@npm:^2.3.1": version: 2.3.1 resolution: "@lumino/application@npm:2.3.1" @@ -3344,6 +3562,21 @@ __metadata: languageName: node linkType: hard +"@lumino/commands@npm:^2.0.0": + version: 2.3.1 + resolution: "@lumino/commands@npm:2.3.1" + dependencies: + "@lumino/algorithm": ^2.0.2 + "@lumino/coreutils": ^2.2.0 + "@lumino/disposable": ^2.1.3 + "@lumino/domutils": ^2.0.2 + "@lumino/keyboard": ^2.0.2 + "@lumino/signaling": ^2.1.3 + "@lumino/virtualdom": ^2.0.2 + checksum: 83bc6d66de37e58582b00f70ce66e797c9fcf84e36041c6881631ed0d281305e2a49927f5b2fe6c5c965733f3cd6fb4a233c7b7967fc050497024a941659bd65 + languageName: node + linkType: hard + "@lumino/commands@npm:^2.3.0": version: 2.3.0 resolution: "@lumino/commands@npm:2.3.0" @@ -3366,6 +3599,15 @@ __metadata: languageName: node linkType: hard +"@lumino/coreutils@npm:^2.2.0": + version: 2.2.0 + resolution: "@lumino/coreutils@npm:2.2.0" + dependencies: + "@lumino/algorithm": ^2.0.2 + checksum: 345fcd5d7493d745831dd944edfbd8eda06cc59a117e71023fc97ce53badd697be2bd51671f071f5ff0064f75f104575d9695f116a07517bafbedd38e5c7a785 + languageName: node + linkType: hard + "@lumino/disposable@npm:^1.10.0 || ^2.0.0, @lumino/disposable@npm:^2.1.0, @lumino/disposable@npm:^2.1.2": version: 2.1.2 resolution: "@lumino/disposable@npm:2.1.2" @@ -3375,6 +3617,15 @@ __metadata: languageName: node linkType: hard +"@lumino/disposable@npm:^2.0.0, @lumino/disposable@npm:^2.1.3": + version: 2.1.3 + resolution: "@lumino/disposable@npm:2.1.3" + dependencies: + "@lumino/signaling": ^2.1.3 + checksum: b9a346fa2752b3cd1b053cb637ee173501d33082a73423429070e8acc508b034ea0babdae0549b923cbdd287ee1fc7f6159f0539c9fff7574393a214eef07c57 + languageName: node + linkType: hard + "@lumino/domutils@npm:^2.0.1": version: 2.0.1 resolution: "@lumino/domutils@npm:2.0.1" @@ -3382,6 +3633,13 @@ __metadata: languageName: node linkType: hard +"@lumino/domutils@npm:^2.0.2": + version: 2.0.2 + resolution: "@lumino/domutils@npm:2.0.2" + checksum: 037b8d0b62af43887fd7edd506fa551e2af104a4b46d62e6fef256e16754dba40d351513beb5083834d468b2c7806aae0fe205fd6aac8ef24759451ee998bbd9 + languageName: node + linkType: hard + "@lumino/dragdrop@npm:^2.1.4": version: 2.1.4 resolution: "@lumino/dragdrop@npm:2.1.4" @@ -3399,6 +3657,13 @@ __metadata: languageName: node linkType: hard +"@lumino/keyboard@npm:^2.0.2": + version: 2.0.2 + resolution: "@lumino/keyboard@npm:2.0.2" + checksum: 198e8c17825c9a831fa0770f58a71574b936acb0f0bbbe7f8feb73d89686dda7ff41fcb02d12b401f5d462b45fe0bba24f7f38befb7cefe0826576559f0bee6d + languageName: node + linkType: hard + "@lumino/messaging@npm:^2.0.1": version: 2.0.1 resolution: "@lumino/messaging@npm:2.0.1" @@ -3437,6 +3702,16 @@ __metadata: languageName: node linkType: hard +"@lumino/signaling@npm:^2.0.0, @lumino/signaling@npm:^2.1.3": + version: 2.1.3 + resolution: "@lumino/signaling@npm:2.1.3" + dependencies: + "@lumino/algorithm": ^2.0.2 + "@lumino/coreutils": ^2.2.0 + checksum: ce59383bd75fe30df5800e0442dbc4193cc6778e2530b9be0f484d159f1d8668be5c6ee92cee9df36d5a0c3dbd9126d0479a82581dee1df889d5c9f922d3328d + languageName: node + linkType: hard + "@lumino/virtualdom@npm:^2.0.0, @lumino/virtualdom@npm:^2.0.1": version: 2.0.1 resolution: "@lumino/virtualdom@npm:2.0.1" @@ -3446,6 +3721,15 @@ __metadata: languageName: node linkType: hard +"@lumino/virtualdom@npm:^2.0.2": + version: 2.0.2 + resolution: 
"@lumino/virtualdom@npm:2.0.2" + dependencies: + "@lumino/algorithm": ^2.0.2 + checksum: 0e1220d5b3b2441e7668f3542a6341e015bdbea0c8bd6d4be962009386c034336540732596d5dedcd54ca57fbde61c2942549129a3e1b0fccb1aa143685fcd15 + languageName: node + linkType: hard + "@lumino/widgets@npm:^1.37.2 || ^2.3.2, @lumino/widgets@npm:^2.1.0, @lumino/widgets@npm:^2.3.2": version: 2.3.2 resolution: "@lumino/widgets@npm:2.3.2" From b7e25f49df3410a688fca4478edc8283de215638 Mon Sep 17 00:00:00 2001 From: Nicolas Brichet <32258950+brichet@users.noreply.github.com> Date: Thu, 19 Sep 2024 16:07:03 +0200 Subject: [PATCH 05/24] Stream messages (#3) * Allow for stream messages * update jupyter collaborative chat dependency --- .../jupyter_ai/chat_handlers/default.py | 43 ++++++++++--------- packages/jupyter-ai/jupyter_ai/extension.py | 19 ++++++-- packages/jupyter-ai/package.json | 2 +- packages/jupyter-ai/pyproject.toml | 3 +- yarn.lock | 10 ++--- 5 files changed, 45 insertions(+), 32 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 480fd4565..7e90fa2f0 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -11,6 +11,7 @@ from langchain_core.messages import AIMessageChunk from langchain_core.runnables import ConfigurableFieldSpec from langchain_core.runnables.history import RunnableWithMessageHistory +from langchain_core.outputs.generation import GenerationChunk try: from jupyterlab_collaborative_chat.ychat import YChat @@ -62,27 +63,29 @@ def create_llm_chain( ) self.llm_chain = runnable - def _start_stream(self, human_msg: HumanChatMessage) -> str: + def _start_stream(self, human_msg: HumanChatMessage, chat: YChat | None) -> str: """ Sends an `agent-stream` message to indicate the start of a response stream. Returns the ID of the message, denoted as the `stream_id`. """ - stream_id = uuid4().hex - stream_msg = AgentStreamMessage( - id=stream_id, - time=time.time(), - body="", - reply_to=human_msg.id, - persona=self.persona, - complete=False, - ) - - for handler in self._root_chat_handlers.values(): - if not handler: - continue - - handler.broadcast_message(stream_msg) - break + if chat is not None: + stream_id = self.write_message(chat, "") + else: + stream_id = uuid4().hex + stream_msg = AgentStreamMessage( + id=stream_id, + time=time.time(), + body="", + reply_to=human_msg.id, + persona=self.persona, + complete=False, + ) + for handler in self._root_chat_handlers.values(): + if not handler: + continue + + handler.broadcast_message(stream_msg) + break return stream_id @@ -92,7 +95,7 @@ def _send_stream_chunk(self, stream_id: str, content: str, chat: YChat | None, c appended to an existing `agent-stream` message with ID `stream_id`. """ if chat is not None: - self.write_message(chat, content) + self.write_message(chat, content, stream_id) else: stream_chunk_msg = AgentStreamChunkMessage( id=stream_id, content=content, stream_complete=complete @@ -121,11 +124,11 @@ async def process_message(self, message: HumanChatMessage, chat: YChat | None): # when receiving the first chunk, close the pending message and # start the stream. 
self.close_pending(pending_message) - stream_id = self._start_stream(human_msg=message) + stream_id = self._start_stream(message, chat) received_first_chunk = True if isinstance(chunk, AIMessageChunk): - self._send_stream_chunk(stream_id, chunk.content) + self._send_stream_chunk(stream_id, chunk.content, chat) elif isinstance(chunk, str): self._send_stream_chunk(stream_id, chunk, chat) else: diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 128263726..201291019 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -229,6 +229,10 @@ def initialize(self): listener=self.connect_chat ) + # Keep the message indexes to avoid extra computation looking for a message when + # updating it. + self.messages_indexes = {} + async def connect_chat(self, logger: EventLogger, schema_id: str, data: dict) -> None: if data["room"].startswith("text:chat:") \ and data["action"] == "initialize"\ @@ -299,7 +303,7 @@ async def _route(self, message: HumanChatMessage, chat: YChat): command_readable = "Default" if command == "default" else command self.log.info(f"{command_readable} chat handler resolved in {latency_ms} ms.") - def write_message(self, chat: YChat, body: str) -> None: + def write_message(self, chat: YChat, body: str, id: str | None = None) -> str: BOT["avatar_url"] = url_path_join( self.settings.get("base_url", "/"), "api/ai/static/jupyternaut.svg" @@ -310,14 +314,21 @@ def write_message(self, chat: YChat, body: str) -> None: else: BOT["username"] = bot["username"] - chat.add_message({ + index = self.messages_indexes[id] if id else None + id = id if id else str(uuid.uuid4()) + new_index = chat.set_message({ "type": "msg", "body": body, - "id": str(uuid.uuid4()), + "id": id if id else str(uuid.uuid4()), "time": time.time(), "sender": BOT["username"], "raw_time": False - }) + }, index, True) + + if new_index != index: + self.messages_indexes[id] = new_index + + return id def initialize_settings(self): start = time.time() diff --git a/packages/jupyter-ai/package.json b/packages/jupyter-ai/package.json index a1f8029da..cc404aaa3 100644 --- a/packages/jupyter-ai/package.json +++ b/packages/jupyter-ai/package.json @@ -61,7 +61,7 @@ "dependencies": { "@emotion/react": "^11.10.5", "@emotion/styled": "^11.10.5", - "@jupyter/chat": "^0.3.1", + "@jupyter/chat": "^0.4.0", "@jupyter/collaboration": "^1", "@jupyterlab/application": "^4.2.0", "@jupyterlab/apputils": "^4.2.0", diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index 9b1d3d996..2353343e3 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "typing_extensions>=4.5.0", "traitlets>=5.0", "deepmerge>=1.0", + "jupyterlab-collaborative-chat", ] dynamic = ["version", "description", "authors", "urls", "keywords"] @@ -56,8 +57,6 @@ dev = ["jupyter_ai_magics[dev]"] all = ["jupyter_ai_magics[all]", "pypdf", "arxiv"] -collaborative = ["jupyterlab-collaborative-chat"] - [tool.hatch.version] source = "nodejs" diff --git a/yarn.lock b/yarn.lock index 49d9b3304..5d45cf685 100644 --- a/yarn.lock +++ b/yarn.lock @@ -2220,7 +2220,7 @@ __metadata: "@babel/preset-env": ^7.0.0 "@emotion/react": ^11.10.5 "@emotion/styled": ^11.10.5 - "@jupyter/chat": ^0.3.1 + "@jupyter/chat": ^0.4.0 "@jupyter/collaboration": ^1 "@jupyterlab/application": ^4.2.0 "@jupyterlab/apputils": ^4.2.0 @@ -2286,9 +2286,9 @@ __metadata: languageName: unknown linkType: soft 
-"@jupyter/chat@npm:^0.3.1": - version: 0.3.1 - resolution: "@jupyter/chat@npm:0.3.1" +"@jupyter/chat@npm:^0.4.0": + version: 0.4.0 + resolution: "@jupyter/chat@npm:0.4.0" dependencies: "@emotion/react": ^11.10.5 "@emotion/styled": ^11.10.5 @@ -2306,7 +2306,7 @@ __metadata: clsx: ^2.1.0 react: ^18.2.0 react-dom: ^18.2.0 - checksum: 92d1f6d6d3083be2a8c2309fc37eb3bb39bc184e1c41d18f79d51f87e54d7badba1495c5943916e7a21fc65414d6100860b044a14554add8192b545fca2748dd + checksum: 6e309c8e70cf480103eb26f3109da417c58d2e6844d5e56e63feabf71926f9dba6f9bc85caff765dfc574a8fd7ed803a8c03e5d812c28568dcf6ec918bbd2e66 languageName: node linkType: hard From 809569159f0f52811b1c7890f11771d1d139c6ca Mon Sep 17 00:00:00 2001 From: Nicolas Brichet <32258950+brichet@users.noreply.github.com> Date: Tue, 15 Oct 2024 09:25:05 +0200 Subject: [PATCH 06/24] AI settings (#4) * Add a menu option to open the AI settings * Remove the input option from the setting widget --- packages/jupyter-ai/schema/plugin.json | 22 ++++++ .../src/components/chat-settings.tsx | 69 +++++++++++-------- packages/jupyter-ai/src/index.ts | 50 +++++++++++++- .../src/widgets/settings-widget.tsx | 26 +++++++ 4 files changed, 135 insertions(+), 32 deletions(-) create mode 100644 packages/jupyter-ai/src/widgets/settings-widget.tsx diff --git a/packages/jupyter-ai/schema/plugin.json b/packages/jupyter-ai/schema/plugin.json index 78804b5c6..73dc1b60a 100644 --- a/packages/jupyter-ai/schema/plugin.json +++ b/packages/jupyter-ai/schema/plugin.json @@ -12,6 +12,28 @@ "preventDefault": false } ], + "jupyter.lab.menus": { + "main": [ + { + "id": "jp-mainmenu-settings", + "items": [ + { + "type": "separator", + "rank": 110 + + }, + { + "command": "jupyter-ai:open-settings", + "rank": 110 + }, + { + "type": "separator", + "rank": 110 + } + ] + } + ] + }, "additionalProperties": false, "type": "object" } diff --git a/packages/jupyter-ai/src/components/chat-settings.tsx b/packages/jupyter-ai/src/components/chat-settings.tsx index a1ad0a9b6..b9e9d8bd1 100644 --- a/packages/jupyter-ai/src/components/chat-settings.tsx +++ b/packages/jupyter-ai/src/components/chat-settings.tsx @@ -34,6 +34,9 @@ type ChatSettingsProps = { rmRegistry: IRenderMimeRegistry; completionProvider: IJaiCompletionProvider | null; openInlineCompleterSettings: () => void; + // The temporary input options, should be removed when the collaborative chat is + // the only chat. + inputOptions?: boolean; }; /** @@ -511,36 +514,42 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { onSuccess={server.refetchApiKeys} /> - {/* Input */} -

    Input

    - - - When writing a message, press Enter to: - - { - setSendWse(e.target.value === 'newline'); - }} - > - } - label="Send the message" - /> - } - label={ - <> - Start a new line (use Shift+Enter to send) - - } - /> - - + {/* Input - to remove when the collaborative chat is the only chat */} + {(props.inputOptions ?? true) && ( + <> +

    Input

    + + + When writing a message, press Enter to: + + { + setSendWse(e.target.value === 'newline'); + }} + > + } + label="Send the message" + /> + } + label={ + <> + Start a new line (use Shift+Enter to + send) + + } + /> + + + + )} +
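The `inputOptions` prop added in the hunk above lets an embedding application hide the Enter-to-send settings when message submission is managed elsewhere (the collaborative chat provides its own input box). A minimal consumer sketch follows; the widget factory name, import paths, and the no-op completer callback are assumptions for illustration and are not part of this patch:

import React from 'react';
import { ReactWidget } from '@jupyterlab/apputils';
import { IRenderMimeRegistry } from '@jupyterlab/rendermime';
import { ChatSettings } from '../components/chat-settings';

// Build a standalone AI settings widget that hides the Input section,
// since sending messages is handled by the collaborative chat itself.
export function createAiSettingsWidget(
  rmRegistry: IRenderMimeRegistry
): ReactWidget {
  return ReactWidget.create(
    <ChatSettings
      rmRegistry={rmRegistry}
      completionProvider={null}
      openInlineCompleterSettings={() => undefined}
      inputOptions={false}
    />
  );
}

Passing `inputOptions={false}` skips the radio group entirely, while omitting the prop keeps the previous behaviour thanks to the `?? true` default.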