Backport PR #713 on branch 1.x (Improve support for custom providers) (#715)

* improve support for custom providers

adds 3 new provider class attributes:
1. `manages_history`
2. `unsupported_slash_commands`
3. `persona`

* pre-commit

* add comment about jupyternaut icon in frontend

* remove 'avatar_path' from 'Persona', drop 'PersonaDescription'

* pre-commit
dlqqq authored Apr 4, 2024
1 parent cedd4b7 commit 9f6e863
Showing 13 changed files with 195 additions and 22 deletions.
4 changes: 4 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
@@ -12,6 +12,10 @@
from .exception import store_exception
from .magics import AiMagics

# expose JupyternautPersona on the package root
# required by `jupyter-ai`.
from .models.persona import JupyternautPersona, Persona

# expose model providers on the package root
from .providers import (
AI21Provider,
26 changes: 26 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py
@@ -0,0 +1,26 @@
from langchain.pydantic_v1 import BaseModel


class Persona(BaseModel):
"""
Model of an **agent persona**, a struct that includes the name & avatar
shown on agent replies in the chat UI.
Each persona is specific to a single provider, set on the `persona` field.
"""

name: str = ...
"""
Name of the persona, e.g. "Jupyternaut". This is used to render the name
shown on agent replies in the chat UI.
"""

avatar_route: str = ...
"""
The server route that should be used to serve the avatar of this persona. This is
used to render the avatar shown on agent replies in the chat UI.
"""


JUPYTERNAUT_AVATAR_ROUTE = "api/ai/static/jupyternaut.svg"
JupyternautPersona = Persona(name="Jupyternaut", avatar_route=JUPYTERNAUT_AVATAR_ROUTE)
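
As a usage sketch (not part of this diff), a downstream Jupyter AI module might construct its own persona with this model; the name and route below are hypothetical and must match a route the module actually serves:

from jupyter_ai_magics import Persona

# Hypothetical persona; the avatar_route must be served by the module itself
# (see the StaticFileHandler registration in extension.py further down).
MyBotPersona = Persona(name="MyBot", avatar_route="api/ai/static/mybot.svg")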
25 changes: 25 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -46,6 +46,7 @@
except:
from pydantic.main import ModelMetaclass

from .models.persona import Persona

CHAT_SYSTEM_PROMPT = """
You are Jupyternaut, a conversational assistant living in JupyterLab to help users.
@@ -188,6 +189,30 @@ class Config:
"""User inputs expected by this provider when initializing it. Each `Field` `f`
should be passed in the constructor as a keyword argument, keyed by `f.key`."""

manages_history: ClassVar[bool] = False
"""Whether this provider manages its own conversation history upstream. If
set to `True`, Jupyter AI will not pass the chat history to this provider
when invoked."""

persona: ClassVar[Optional[Persona]] = None
"""
The **persona** of this provider, a struct that defines the name and avatar
shown on agent replies in the chat UI. When set to `None`, `jupyter-ai` will
choose a default persona when rendering agent messages by this provider.
Because this field is set to `None` by default, `jupyter-ai` will render a
default persona for all providers that are included natively with the
`jupyter-ai` package. This field is reserved for Jupyter AI modules that
serve a custom provider and want to distinguish it in the chat UI.
"""

unsupported_slash_commands: ClassVar[set] = set()
"""
A set of slash commands unsupported by this provider. Unsupported slash
commands are not shown in the help message, and cannot be used while this
provider is selected.
"""

#
# instance attrs
#
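
As a usage sketch (not part of this diff), a custom provider in a downstream module might set the three new class attributes roughly as follows; the provider, model IDs, and slash-command set are hypothetical, and the LangChain model wiring that a real `BaseProvider` subclass also needs is omitted:

from typing import ClassVar, Optional

from jupyter_ai_magics import Persona
from jupyter_ai_magics.providers import BaseProvider


class MyCustomProvider(BaseProvider):  # a real provider also inherits an LLM class
    id = "my-provider"
    name = "My Provider"
    models = ["my-model"]

    # The upstream service keeps its own conversation state, so Jupyter AI
    # should not replay the chat history on each request.
    manages_history: ClassVar[bool] = True

    # These commands are hidden from /help and rejected while this provider
    # is selected.
    unsupported_slash_commands: ClassVar[set] = {"/learn", "/ask", "/generate"}

    # Name and avatar rendered on this provider's replies in the chat UI.
    persona: ClassVar[Optional[Persona]] = Persona(
        name="MyBot", avatar_route="api/ai/static/mybot.svg"
    )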
21 changes: 20 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -17,6 +17,7 @@
from dask.distributed import Client as DaskClient
from jupyter_ai.config_manager import ConfigManager, Logger
from jupyter_ai.models import AgentChatMessage, ChatMessage, HumanChatMessage
from jupyter_ai_magics import Persona
from jupyter_ai_magics.providers import BaseProvider
from langchain.pydantic_v1 import BaseModel

@@ -94,10 +95,21 @@ async def on_message(self, message: HumanChatMessage):
`self.handle_exc()` when an exception is raised. This method is called
by RootChatHandler when it routes a human message to this chat handler.
"""
lm_provider_klass = self.config_manager.lm_provider

# ensure the current slash command is supported
if self.routing_type.routing_method == "slash_command":
slash_command = (
"/" + self.routing_type.slash_id if self.routing_type.slash_id else ""
)
if slash_command in lm_provider_klass.unsupported_slash_commands:
self.reply(
"Sorry, the selected language model does not support this slash command."
)
return

# check whether the configured LLM can support a request at this time.
if self.uses_llm and BaseChatHandler._requests_count > 0:
lm_provider_klass = self.config_manager.lm_provider
lm_provider_params = self.config_manager.lm_provider_params
lm_provider = lm_provider_klass(**lm_provider_params)

@@ -159,11 +171,18 @@ async def _default_handle_exc(self, e: Exception, message: HumanChatMessage):
self.reply(response, message)

def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None):
"""
Sends an agent message, usually in response to a received
`HumanChatMessage`.
"""
persona = self.config_manager.persona

agent_msg = AgentChatMessage(
id=uuid4().hex,
time=time.time(),
body=response,
reply_to=human_msg.id if human_msg else "",
persona=Persona(name=persona.name, avatar_route=persona.avatar_route),
)

for handler in self._root_chat_handlers.values():
14 changes: 9 additions & 5 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -2,7 +2,7 @@

from jupyter_ai.models import ChatMessage, ClearMessage, HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider
from langchain.chains import ConversationChain
from langchain.chains import ConversationChain, LLMChain
from langchain.memory import ConversationBufferWindowMemory

from .base import BaseChatHandler, SlashCommandRoutingType
@@ -30,14 +30,18 @@ def create_llm_chain(
llm = provider(**unified_parameters)

prompt_template = llm.get_chat_prompt_template()
self.llm = llm
self.memory = ConversationBufferWindowMemory(
return_messages=llm.is_chat_provider, k=2
)

self.llm = llm
self.llm_chain = ConversationChain(
llm=llm, prompt=prompt_template, verbose=True, memory=self.memory
)
if llm.manages_history:
self.llm_chain = LLMChain(llm=llm, prompt=prompt_template, verbose=True)

else:
self.llm_chain = ConversationChain(
llm=llm, prompt=prompt_template, verbose=True, memory=self.memory
)

def clear_memory(self):
# clear chain memory
24 changes: 19 additions & 5 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/help.py
@@ -3,10 +3,11 @@
from uuid import uuid4

from jupyter_ai.models import AgentChatMessage, HumanChatMessage
from jupyter_ai_magics import Persona

from .base import BaseChatHandler, SlashCommandRoutingType

HELP_MESSAGE = """Hi there! I'm Jupyternaut, your programming assistant.
HELP_MESSAGE = """Hi there! I'm {persona_name}, your programming assistant.
You can ask me a question using the text box below. You can also use these commands:
{commands}
@@ -15,23 +16,36 @@
"""


def _format_help_message(chat_handlers: Dict[str, BaseChatHandler]):
def _format_help_message(
chat_handlers: Dict[str, BaseChatHandler],
persona: Persona,
unsupported_slash_commands: set,
):
if unsupported_slash_commands:
keys = set(chat_handlers.keys()) - unsupported_slash_commands
chat_handlers = {key: chat_handlers[key] for key in keys}

commands = "\n".join(
[
f"* `{command_name}` — {handler.help}"
for command_name, handler in chat_handlers.items()
if command_name != "default"
]
)
return HELP_MESSAGE.format(commands=commands)
return HELP_MESSAGE.format(commands=commands, persona_name=persona.name)


def HelpMessage(chat_handlers: Dict[str, BaseChatHandler]):
def build_help_message(
chat_handlers: Dict[str, BaseChatHandler],
persona: Persona,
unsupported_slash_commands: set,
):
return AgentChatMessage(
id=uuid4().hex,
time=time.time(),
body=_format_help_message(chat_handlers),
body=_format_help_message(chat_handlers, persona, unsupported_slash_commands),
reply_to="",
persona=Persona(name=persona.name, avatar_route=persona.avatar_route),
)


12 changes: 12 additions & 0 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
@@ -8,6 +8,7 @@
from deepmerge import always_merger as Merger
from jsonschema import Draft202012Validator as Validator
from jupyter_ai.models import DescribeConfigResponse, GlobalConfig, UpdateConfigRequest
from jupyter_ai_magics import JupyternautPersona, Persona
from jupyter_ai_magics.utils import (
AnyProvider,
EmProvidersDict,
@@ -452,3 +453,14 @@ def em_provider_params(self):
"model_id": em_lid,
**authn_fields,
}

@property
def persona(self) -> Persona:
"""
The current agent persona, set by the selected LM provider. If the
selected LM provider is `None`, this property returns
`JupyternautPersona` by default.
"""
lm_provider = self.lm_provider
persona = getattr(lm_provider, "persona", None) or JupyternautPersona
return persona
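
As a usage sketch (not part of this diff), callers that hold a `ConfigManager` can read the effective persona without checking the provider themselves, since the property above falls back to `JupyternautPersona`:

# JupyternautPersona unless the selected provider defines its own persona.
persona = config_manager.persona
print(persona.name, persona.avatar_route)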
47 changes: 42 additions & 5 deletions packages/jupyter-ai/jupyter_ai/extension.py
@@ -1,12 +1,14 @@
import logging
import os
import re
import time

from dask.distributed import Client as DaskClient
from importlib_metadata import entry_points
from jupyter_ai.chat_handlers.learn import Retriever
from jupyter_ai_magics import JupyternautPersona
from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
from jupyter_server.extension.application import ExtensionApp
from tornado.web import StaticFileHandler
from traitlets import Dict, List, Unicode

from .chat_handlers import (
@@ -18,7 +20,7 @@
HelpChatHandler,
LearnChatHandler,
)
from .chat_handlers.help import HelpMessage
from .chat_handlers.help import build_help_message
from .config_manager import ConfigManager
from .handlers import (
ApiKeysHandler,
@@ -29,6 +31,11 @@
RootChatHandler,
)

JUPYTERNAUT_AVATAR_ROUTE = JupyternautPersona.avatar_route
JUPYTERNAUT_AVATAR_PATH = str(
os.path.join(os.path.dirname(__file__), "static", "jupyternaut.svg")
)


class AiExtension(ExtensionApp):
name = "jupyter_ai"
Expand All @@ -39,6 +46,14 @@ class AiExtension(ExtensionApp):
(r"api/ai/chats/history?", ChatHistoryHandler),
(r"api/ai/providers?", ModelProviderHandler),
(r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler),
# serve the default persona avatar at this path.
# the `()` at the end of the URL denotes an empty regex capture group,
# required by Tornado.
(
rf"{JUPYTERNAUT_AVATAR_ROUTE}()",
StaticFileHandler,
{"path": JUPYTERNAUT_AVATAR_PATH},
),
]

allowed_providers = List(
@@ -296,14 +311,36 @@ def initialize_settings(self):
# Make help always appear as the last command
jai_chat_handlers["/help"] = help_chat_handler

self.settings["chat_history"].append(
HelpMessage(chat_handlers=jai_chat_handlers)
)
# bind chat handlers to settings
self.settings["jai_chat_handlers"] = jai_chat_handlers

# show help message at server start
self._show_help_message()

latency_ms = round((time.time() - start) * 1000)
self.log.info(f"Initialized Jupyter AI server extension in {latency_ms} ms.")

def _show_help_message(self):
"""
Method that ensures a dynamically-generated help message is included in
the chat history shown to users.
"""
chat_handlers = self.settings["jai_chat_handlers"]
config_manager: ConfigManager = self.settings["jai_config_manager"]
lm_provider = config_manager.lm_provider

if not lm_provider:
return

persona = config_manager.persona
unsupported_slash_commands = (
lm_provider.unsupported_slash_commands if lm_provider else set()
)
help_message = build_help_message(
chat_handlers, persona, unsupported_slash_commands
)
self.settings["chat_history"].append(help_message)

async def _get_dask_client(self):
return DaskClient(processes=False, asynchronous=True)
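
The same serving pattern shown above could be reused by a downstream module that ships its own avatar; a hedged sketch, assuming a hypothetical package with a `static/mybot.svg` file and a server extension that lists Tornado handlers:

import os

from tornado.web import StaticFileHandler

# Hypothetical route and file path; the route should match the avatar_route
# set on the provider's Persona.
MYBOT_AVATAR_ROUTE = "api/ai/static/mybot.svg"
MYBOT_AVATAR_PATH = os.path.join(os.path.dirname(__file__), "static", "mybot.svg")

# The trailing `()` is the empty regex capture group Tornado requires for
# StaticFileHandler, mirroring the jupyternaut.svg handler above.
handlers = [
    (rf"{MYBOT_AVATAR_ROUTE}()", StaticFileHandler, {"path": MYBOT_AVATAR_PATH}),
]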

13 changes: 12 additions & 1 deletion packages/jupyter-ai/jupyter_ai/models.py
@@ -1,5 +1,6 @@
from typing import Any, Dict, List, Literal, Optional, Union

from jupyter_ai_magics import Persona
from jupyter_ai_magics.providers import AuthStrategy, Field
from langchain.pydantic_v1 import BaseModel, validator

@@ -34,8 +35,18 @@ class AgentChatMessage(BaseModel):
id: str
time: float
body: str
# message ID of the HumanChatMessage it is replying to

reply_to: str
"""
Message ID of the HumanChatMessage being replied to. This is set to an empty
string if not applicable.
"""

persona: Persona
"""
The persona of the selected provider. If the selected provider is `None`,
this defaults to `JupyternautPersona`.
"""


class HumanChatMessage(BaseModel):
9 changes: 9 additions & 0 deletions packages/jupyter-ai/jupyter_ai/static/jupyternaut.svg
(New SVG file: the default Jupyternaut avatar, served at api/ai/static/jupyternaut.svg; image content not shown.)