diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index 7894dcfa5..f5cfef982 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -9,12 +9,13 @@ from jupyter_ai.config_manager import ConfigManager, Logger from jupyter_ai.models import AgentChatMessage, HumanChatMessage from jupyter_ai_magics.providers import BaseProvider +from traitlets.config import Configurable if TYPE_CHECKING: from jupyter_ai.handlers import RootChatHandler -class BaseChatHandler: +class BaseChatHandler(Configurable): """Base ChatHandler class containing shared methods and attributes used by multiple chat handler classes.""" @@ -23,7 +24,10 @@ def __init__( log: Logger, config_manager: ConfigManager, root_chat_handlers: Dict[str, "RootChatHandler"], + *args, + **kwargs, ): + super().__init__(*args, **kwargs) self.log = log self.config_manager = config_manager self._root_chat_handlers = root_chat_handlers @@ -94,6 +98,28 @@ def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None): handler.broadcast_message(agent_msg) break + @property + def lm_id(self): + """Retrieves the language model ID from the config manager.""" + lm_provider = self.config_manager.lm_provider + lm_provider_params = self.config_manager.lm_provider_params + + if lm_provider: + return lm_provider.id + ":" + lm_provider_params["model_id"] + else: + return None + + @property + def em_id(self): + """Retrieves the embedding model ID from the config manager.""" + em_provider = self.config_manager.em_provider + em_provider_params = self.config_manager.em_provider_params + + if em_provider: + return em_provider.id + ":" + em_provider_params["model_id"] + else: + return None + def get_llm_chain(self): lm_provider = self.config_manager.lm_provider lm_provider_params = self.config_manager.lm_provider_params diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 3638b6cfb..5f798b5ef 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -1,6 +1,7 @@ from typing import Dict, List, Type from jupyter_ai.models import ChatMessage, ClearMessage, HumanChatMessage +from jupyter_ai.prompt_templates import ChatPromptTemplates from jupyter_ai_magics.providers import BaseProvider from langchain.chains import ConversationChain from langchain.memory import ConversationBufferWindowMemory @@ -14,22 +15,6 @@ from .base import BaseChatHandler -SYSTEM_PROMPT = """ -You are Jupyternaut, a conversational assistant living in JupyterLab to help users. -You are not a language model, but rather an application built on a foundation model from {provider_name} called {local_model_id}. -You are talkative and you provide lots of specific details from the foundation model's context. -You may use Markdown to format your response. -Code blocks must be formatted in Markdown. -Math should be rendered with inline TeX markup, surrounded by $. -If you do not know the answer to a question, answer truthfully by responding that you do not know. -The following is a friendly conversation between you and a human. -""".strip() - -DEFAULT_TEMPLATE = """Current conversation: -{history} -Human: {input} -AI:""" - class DefaultChatHandler(BaseChatHandler): def __init__(self, chat_history: List[ChatMessage], *args, **kwargs): @@ -37,6 +22,10 @@ def __init__(self, chat_history: List[ChatMessage], *args, **kwargs): self.memory = ConversationBufferWindowMemory(return_messages=True, k=2) self.chat_history = chat_history + @property + def templates(self): + return ChatPromptTemplates(self.lm_id, config=self.config) + def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] ): @@ -45,9 +34,9 @@ def create_llm_chain( if llm.is_chat_provider: prompt_template = ChatPromptTemplate.from_messages( [ - SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT).format( - provider_name=llm.name, local_model_id=llm.model_id - ), + SystemMessagePromptTemplate.from_template( + self.templates.system + ).format(provider_name=llm.name, local_model_id=llm.model_id), MessagesPlaceholder(variable_name="history"), HumanMessagePromptTemplate.from_template("{input}"), ] @@ -56,11 +45,11 @@ def create_llm_chain( else: prompt_template = PromptTemplate( input_variables=["history", "input"], - template=SYSTEM_PROMPT.format( + template=self.templates.system.format( provider_name=llm.name, local_model_id=llm.model_id ) + "\n\n" - + DEFAULT_TEMPLATE, + + self.templates.default, ) self.memory = ConversationBufferWindowMemory(k=2) diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index a2ecd5245..e1a232c0e 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -145,6 +145,7 @@ def initialize_settings(self): # initialize chat handlers chat_handler_kwargs = { "log": self.log, + "config": self.config, # traitlets config "config_manager": self.settings["jai_config_manager"], "root_chat_handlers": self.settings["jai_root_chat_handlers"], } diff --git a/packages/jupyter-ai/jupyter_ai/prompt_templates.py b/packages/jupyter-ai/jupyter_ai/prompt_templates.py new file mode 100644 index 000000000..0899373f9 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/prompt_templates.py @@ -0,0 +1,75 @@ +from traitlets import Dict, Unicode +from traitlets.config import Configurable + +SYSTEM_TEMPLATE = """ +You are Jupyternaut, a conversational assistant living in JupyterLab to help users. +You are not a language model, but rather an application built on a foundation model from {provider_name} called {local_model_id}. +You are talkative and you provide lots of specific details from the foundation model's context. +You may use Markdown to format your response. +Code blocks must be formatted in Markdown. +Math should be rendered with inline TeX markup, surrounded by $. +If you do not know the answer to a question, answer truthfully by responding that you do not know. +The following is a friendly conversation between you and a human. +""".strip() + +HISTORY_TEMPLATE = """ +Current conversation: +{history} +Human: {input} +AI: +""".strip() + + +class ChatPromptTemplates(Configurable): + system_template = Unicode( + default_value=SYSTEM_TEMPLATE, + help="The system prompt template.", + allow_none=False, + config=True, + ) + + system_overrides = Dict( + key_trait=Unicode(), + value_trait=Unicode(), + default_value={}, + help="Defines model-specific overrides of the system prompt template.", + allow_none=False, + config=True, + ) + + history_template = Unicode( + default_value=HISTORY_TEMPLATE, + help="The history prompt template.", + allow_none=False, + config=True, + ) + + history_overrides = Dict( + key_trait=Unicode(), + value_trait=Unicode(), + default_value={}, + help="Defines model-specific overrides of the history prompt template.", + allow_none=False, + config=True, + ) + + lm_id: str = None + + def __init__(self, lm_id, *args, **kwargs): + super().__init__(*args, **kwargs) + + @property + def system(self) -> str: + return self.system_overrides.get(self.lm_id, self.system_template) + + @property + def history(self) -> str: + return self.history_overrides.get(self.lm_id, self.history_template) + + +class AskPromptTemplates(Configurable): + ... + + +class GeneratePromptTemplates(Configurable): + ...