Skip to content

Commit

Permalink
Drop use of depcreated ConversationChain, manage conversation memory …
Browse files Browse the repository at this point in the history
…ourselves

such that we can use HumanMessage as input, enabling convenient multi-modal inputs
  • Loading branch information
opcode81 committed Jun 10, 2024
1 parent cf4237c commit d51dbf6
Showing 1 changed file with 13 additions and 18 deletions.
31 changes: 13 additions & 18 deletions src/penai/llm/conversation.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from collections.abc import Callable
from copy import copy, deepcopy
from typing import Generic, Self, TypeVar
from typing import Generic, Self, TypeAlias, TypeVar

import markdown
from bs4 import BeautifulSoup
from langchain.chains.conversation.base import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain_core.messages import HumanMessage

from penai.llm.llm_model import RegisteredLLM

Expand All @@ -32,6 +32,7 @@ def get_code_in_sections(self, heading_level: int) -> dict[str, str]:


TResponse = TypeVar("TResponse", bound=Response)
QueryType: TypeAlias = str | HumanMessage


class Conversation(Generic[TResponse]):
Expand All @@ -42,7 +43,7 @@ def __init__(
response_factory: Callable[[str], TResponse] = Response, # type: ignore
):
self.memory = ConversationBufferMemory()
self.chain = ConversationChain(llm=model.create_model(), memory=self.memory)
self.llm = model.create_model()
self.verbose = verbose
self.response_factory = response_factory

Expand All @@ -54,25 +55,19 @@ def get_full_conversation_string(
[message.pretty_repr() for message in self.memory.buffer_as_messages],
)

@property
def _input_key(self) -> str:
return self.chain.input_key

@property
def _output_key(self) -> str:
return self.chain.output_key

def query_text(self, query: str) -> str:
response_dict = self.chain.invoke({self._input_key: query}, return_only_outputs=True)
response = response_dict[self._output_key]
def query_text(self, query: QueryType) -> str:
self.memory.chat_memory.add_user_message(query)
ai_message = self.llm.invoke(self.memory.chat_memory.messages)
self.memory.chat_memory.add_ai_message(ai_message)
response_text = ai_message.content
if self.verbose:
print(response)
return response
print(response_text)
return response_text

def query(self, query: str) -> None:
def query(self, query: QueryType) -> None:
self.query_text(query)

def query_response(self, query: str) -> TResponse:
def query_response(self, query: QueryType) -> TResponse:
return self.response_factory(self.query_text(query))

def clone(self) -> Self:
Expand Down

0 comments on commit d51dbf6

Please sign in to comment.