Skip to content

Commit

Permalink
[v3-dev] Dedicate one LangChain history object per chat (#1151)
Browse files Browse the repository at this point in the history
* dedicate a separate LangChain history object per chat

* pre-commit

* fix mypy
  • Loading branch information
dlqqq authored Dec 11, 2024
1 parent e52aedf commit 3581a8f
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 4 deletions.
3 changes: 3 additions & 0 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,9 @@ def get_llm_chat_memory(
last_human_msg: HumanChatMessage,
**kwargs,
) -> "BaseChatMessageHistory":
if self.ychat:
return self.llm_chat_memory

return WrappedBoundedChatHistory(
history=self.llm_chat_memory,
last_human_msg=last_human_msg,
Expand Down
16 changes: 13 additions & 3 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
RootChatHandler,
SlashCommandsInfoHandler,
)
from .history import BoundedChatHistory
from .history import BoundedChatHistory, YChatHistory

from jupyter_collaboration import ( # type:ignore[import-untyped] # isort:skip
__version__ as jupyter_collaboration_version,
Expand Down Expand Up @@ -418,9 +418,13 @@ def initialize_settings(self):
# list of chat messages to broadcast to new clients
# this is only used to render the UI, and is not the conversational
# memory object used by the LM chain.
#
# TODO: remove this in v3. this list is only used by the REST API to get
# history in v2 chat.
self.settings["chat_history"] = []

# conversational memory object used by LM chain
# TODO: remove this in v3. this is the history implementation that
# provides memory to the chat model in v2.
self.settings["llm_chat_memory"] = BoundedChatHistory(
k=self.default_max_chat_history
)
Expand Down Expand Up @@ -515,13 +519,19 @@ def _init_chat_handlers(
eps = entry_points()
chat_handler_eps = eps.select(group="jupyter_ai.chat_handlers")
chat_handlers: Dict[str, BaseChatHandler] = {}

if ychat:
llm_chat_memory = YChatHistory(ychat, k=self.default_max_chat_history)
else:
llm_chat_memory = self.settings["llm_chat_memory"]

chat_handler_kwargs = {
"log": self.log,
"config_manager": self.settings["jai_config_manager"],
"model_parameters": self.settings["model_parameters"],
"root_chat_handlers": self.settings["jai_root_chat_handlers"],
"chat_history": self.settings["chat_history"],
"llm_chat_memory": self.settings["llm_chat_memory"],
"llm_chat_memory": llm_chat_memory,
"root_dir": self.serverapp.root_dir,
"dask_client_future": self.settings["dask_client_future"],
"preferred_dir": self.serverapp.contents_manager.preferred_dir,
Expand Down
48 changes: 47 additions & 1 deletion packages/jupyter-ai/jupyter_ai/history.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,61 @@
import time
from typing import List, Optional, Sequence, Set, Union

from jupyterlab_chat.ychat import YChat
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.pydantic_v1 import BaseModel, PrivateAttr

from .constants import BOT
from .models import HumanChatMessage

HUMAN_MSG_ID_KEY = "_jupyter_ai_human_msg_id"


class YChatHistory(BaseChatMessageHistory):
    """
    An implementation of `BaseChatMessageHistory` that returns the preceding `k`
    exchanges (`k * 2` messages) from the given YChat model.
    If `k` is set to `None`, then this class returns all preceding messages.
    """

    def __init__(self, ychat: YChat, k: Optional[int]):
        # The YChat shared document is the single source of truth for the
        # conversation; this object only reads from it.
        self.ychat = ychat
        self.k = k

    @property
    def messages(self) -> List[BaseMessage]:  # type:ignore[override]
        """
        Returns the last `2 * k` messages preceding the latest message. If
        `k` is set to `None`, return all preceding messages.
        """
        # TODO: consider bounding history based on message size (e.g. total
        # char/token count) instead of message count.
        history = self.ychat.get_messages()

        # Select the window to return: everything except the final entry,
        # which is the HumanChatMessage just submitted by a user, bounded to
        # the `2 * k` entries immediately before it when `k` is set.
        window = history[:-1] if self.k is None else history[-2 * self.k - 1 : -1]

        bot_username = BOT["username"]
        return [
            (
                AIMessage(content=m["body"])
                if m["sender"] == bot_username
                else HumanMessage(content=m["body"])
            )
            for m in window
        ]

    def add_message(self, message: BaseMessage) -> None:
        # Intentionally a no-op: other LangChain objects call this method, but
        # message history is maintained by the `YChat` shared document.
        return

    def clear(self):
        # Clearing is not supported for a YChat-backed history.
        raise NotImplementedError()


class BoundedChatHistory(BaseChatMessageHistory, BaseModel):
"""
An in-memory implementation of `BaseChatMessageHistory` that stores up to
Expand Down

0 comments on commit 3581a8f

Please sign in to comment.