Skip to content

Commit

Permalink
Base chat handler refactor for custom slash commands (jupyterlab#398)
Browse files Browse the repository at this point in the history
* Adds attributes, starts adding to subclasses

* Consistent syntax

* Help for all handlers

* Fix slash ID error

* Iterate through entry points

* Fix typo in call to select()

* Moves config to magics, modifies extensions to attempt to load classes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Moves config to proper location, improves error logging

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP: Updates per feedback, adds custom handler

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Removes redundant code, style fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Removes unnecessary custom message

* Instantiates class

* Validates slash ID

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Consistent arguments to chat handlers

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Refactors to avoid intentionally unused params

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Updates docs, removes custom handler from source and config

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Renames process_message to match base class

* Adds needed parameter that had been deleted

* Joins lines in contributor doc

* Removes natural language routing type, which is not yet used

* Update docs/source/developers/index.md

Co-authored-by: Piyush Jain <[email protected]>

* Update docs/source/developers/index.md

Co-authored-by: Piyush Jain <[email protected]>

* Update docs/source/developers/index.md

Co-authored-by: Piyush Jain <[email protected]>

* Revises per @3coins, avoids Latinism

* Removes Configurable, since we do not yet have configurable traits

* Uses Literal for validation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Piyush Jain <[email protected]>
  • Loading branch information
3 people authored and Marchlak committed Oct 28, 2024
1 parent 2680b0f commit 2104714
Show file tree
Hide file tree
Showing 11 changed files with 201 additions and 44 deletions.
42 changes: 42 additions & 0 deletions docs/source/developers/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,45 @@ class MyProvider(BaseProvider, FakeListLLM):
```

Please note that this will only work with Jupyter AI magics (the `%ai` and `%%ai` magic commands). Custom prompt templates are not used in the chat interface yet.

## Custom slash commands in the chat UI

You can add a custom slash command to the chat interface by
creating a new class that inherits from `BaseChatHandler`. Set
its `id`, `name`, `help` message for display in the user interface,
and `routing_type`. Each custom slash command must have a unique
slash command. Slash commands can only contain ASCII letters, numerals,
and underscores. Each slash command must be unique; custom slash
commands cannot replace built-in slash commands.

Add your custom handler in Python code:

```python
from jupyter_ai.chat_handlers.base import BaseChatHandler, SlashCommandRoutingType
from jupyter_ai.models import HumanChatMessage

class CustomChatHandler(BaseChatHandler):
id = "custom"
name = "Custom"
help = "A chat handler that does something custom"
routing_type = SlashCommandRoutingType(slash_id="custom")

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

async def process_message(self, message: HumanChatMessage):
# Put your custom logic here
self.reply("<your-response>", message)
```

Jupyter AI uses entry points to support custom slash commands.
In the `pyproject.toml` file, add your custom handler to the
`[project.entry-points."jupyter_ai.chat_handlers"]` section:

```
[project.entry-points."jupyter_ai.chat_handlers"]
custom = "custom_package:CustomChatHandler"
```

Then, install your package so that Jupyter AI adds custom chat handlers
to the existing chat handlers.
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .ask import AskChatHandler
from .base import BaseChatHandler
from .base import BaseChatHandler, SlashCommandRoutingType
from .clear import ClearChatHandler
from .default import DefaultChatHandler
from .generate import GenerateChatHandler
Expand Down
7 changes: 6 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import PromptTemplate

from .base import BaseChatHandler
from .base import BaseChatHandler, SlashCommandRoutingType

PROMPT_TEMPLATE = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
Expand All @@ -26,6 +26,11 @@ class AskChatHandler(BaseChatHandler):
to the LLM to generate the final reply.
"""

id = "ask"
name = "Ask with Local Data"
help = "Asks a question with retrieval augmented generation (RAG)"
routing_type = SlashCommandRoutingType(slash_id="ask")

def __init__(self, retriever, *args, **kwargs):
super().__init__(*args, **kwargs)

Expand Down
54 changes: 50 additions & 4 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,81 @@
import argparse
import os
import time
import traceback

# necessary to prevent circular import
from typing import TYPE_CHECKING, Any, Dict, Optional, Type
from typing import (
TYPE_CHECKING,
Awaitable,
ClassVar,
Dict,
List,
Literal,
Optional,
Type,
)
from uuid import uuid4

from dask.distributed import Client as DaskClient
from jupyter_ai.config_manager import ConfigManager, Logger
from jupyter_ai.models import AgentChatMessage, HumanChatMessage
from jupyter_ai.models import AgentChatMessage, ChatMessage, HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider

# necessary to prevent circular import
from pydantic import BaseModel

if TYPE_CHECKING:
from jupyter_ai.handlers import RootChatHandler


# Chat handler type, with specific attributes for each
class HandlerRoutingType(BaseModel):
routing_method: ClassVar[str] = Literal["slash_command"]
"""The routing method that sends commands to this handler."""


class SlashCommandRoutingType(HandlerRoutingType):
routing_method = "slash_command"

slash_id: Optional[str]
"""Slash ID for routing a chat command to this handler. Only one handler
may declare a particular slash ID. Must contain only alphanumerics and
underscores."""


class BaseChatHandler:
"""Base ChatHandler class containing shared methods and attributes used by
multiple chat handler classes."""

# Class attributes
id: ClassVar[str] = ...
"""ID for this chat handler; should be unique"""

name: ClassVar[str] = ...
"""User-facing name of this handler"""

help: ClassVar[str] = ...
"""What this chat handler does, which third-party models it contacts,
the data it returns to the user, and so on, for display in the UI."""

routing_type: HandlerRoutingType = ...

def __init__(
self,
log: Logger,
config_manager: ConfigManager,
root_chat_handlers: Dict[str, "RootChatHandler"],
model_parameters: Dict[str, Dict],
chat_history: List[ChatMessage],
root_dir: str,
dask_client_future: Awaitable[DaskClient],
):
self.log = log
self.config_manager = config_manager
self._root_chat_handlers = root_chat_handlers
self.model_parameters = model_parameters
self._chat_history = chat_history
self.parser = argparse.ArgumentParser()
self.root_dir = os.path.abspath(os.path.expanduser(root_dir))
self.dask_client_future = dask_client_future
self.llm = None
self.llm_params = None
self.llm_chain = None
Expand Down
10 changes: 7 additions & 3 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@

from jupyter_ai.models import ChatMessage, ClearMessage

from .base import BaseChatHandler
from .base import BaseChatHandler, SlashCommandRoutingType


class ClearChatHandler(BaseChatHandler):
def __init__(self, chat_history: List[ChatMessage], *args, **kwargs):
id = "clear"
name = "Clear chat messages"
help = "Clears the displayed chat message history only; does not clear the context sent to chat providers"
routing_type = SlashCommandRoutingType(slash_id="clear")

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._chat_history = chat_history

async def process_message(self, _):
self._chat_history.clear()
Expand Down
14 changes: 9 additions & 5 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
SystemMessagePromptTemplate,
)

from .base import BaseChatHandler
from .base import BaseChatHandler, SlashCommandRoutingType

SYSTEM_PROMPT = """
You are Jupyternaut, a conversational assistant living in JupyterLab to help users.
Expand All @@ -32,10 +32,14 @@


class DefaultChatHandler(BaseChatHandler):
def __init__(self, chat_history: List[ChatMessage], *args, **kwargs):
id = "default"
name = "Default"
help = "Responds to prompts that are not otherwise handled by a chat handler"
routing_type = SlashCommandRoutingType(slash_id=None)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.memory = ConversationBufferWindowMemory(return_messages=True, k=2)
self.chat_history = chat_history

def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
Expand Down Expand Up @@ -80,8 +84,8 @@ def clear_memory(self):
self.reply(reply_message)

# clear transcript for new chat clients
if self.chat_history:
self.chat_history.clear()
if self._chat_history:
self._chat_history.clear()

async def process_message(self, message: HumanChatMessage):
self.get_llm_chain()
Expand Down
10 changes: 6 additions & 4 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Dict, List, Optional, Type

import nbformat
from jupyter_ai.chat_handlers import BaseChatHandler
from jupyter_ai.chat_handlers import BaseChatHandler, SlashCommandRoutingType
from jupyter_ai.models import HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider
from langchain.chains import LLMChain
Expand Down Expand Up @@ -216,11 +216,13 @@ def create_notebook(outline):


class GenerateChatHandler(BaseChatHandler):
"""Generates a Jupyter notebook given a description."""
id = "generate"
name = "Generate Notebook"
help = "Generates a Jupyter notebook, including name, outline, and section contents"
routing_type = SlashCommandRoutingType(slash_id="generate")

def __init__(self, root_dir: str, *args, **kwargs):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.root_dir = os.path.abspath(os.path.expanduser(root_dir))
self.llm = None

def create_llm_chain(
Expand Down
7 changes: 6 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/help.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from jupyter_ai.models import AgentChatMessage, HumanChatMessage

from .base import BaseChatHandler
from .base import BaseChatHandler, SlashCommandRoutingType

HELP_MESSAGE = """Hi there! I'm Jupyternaut, your programming assistant.
You can ask me a question using the text box below. You can also use these commands:
Expand All @@ -29,6 +29,11 @@ def HelpMessage():


class HelpChatHandler(BaseChatHandler):
id = "help"
name = "Help"
help = "Displays a help message in the chat message area"
routing_type = SlashCommandRoutingType(slash_id="help")

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

Expand Down
13 changes: 7 additions & 6 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,20 @@
)
from langchain.vectorstores import FAISS

from .base import BaseChatHandler
from .base import BaseChatHandler, SlashCommandRoutingType

INDEX_SAVE_DIR = os.path.join(jupyter_data_dir(), "jupyter_ai", "indices")
METADATA_SAVE_PATH = os.path.join(INDEX_SAVE_DIR, "metadata.json")


class LearnChatHandler(BaseChatHandler):
def __init__(
self, root_dir: str, dask_client_future: Awaitable[DaskClient], *args, **kwargs
):
id = "learn"
name = "Learn Local Data"
help = "Pass a list of files and directories. Once converted to vector format, you can ask about them with /ask."
routing_type = SlashCommandRoutingType(slash_id="learn")

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.root_dir = root_dir
self.dask_client_future = dask_client_future
self.parser.prog = "/learn"
self.parser.add_argument("-a", "--all-files", action="store_true")
self.parser.add_argument("-v", "--verbose", action="store_true")
Expand Down
Loading

0 comments on commit 2104714

Please sign in to comment.