diff --git a/src/penai/llm/conversation.py b/src/penai/llm/conversation.py index 1554d6d..293d0ed 100644 --- a/src/penai/llm/conversation.py +++ b/src/penai/llm/conversation.py @@ -1,11 +1,11 @@ from collections.abc import Callable from copy import copy, deepcopy -from typing import Generic, Self, TypeVar +from typing import Generic, Self, TypeAlias, TypeVar import markdown from bs4 import BeautifulSoup -from langchain.chains.conversation.base import ConversationChain from langchain.memory import ConversationBufferMemory +from langchain_core.messages import HumanMessage from penai.llm.llm_model import RegisteredLLM @@ -32,6 +32,7 @@ def get_code_in_sections(self, heading_level: int) -> dict[str, str]: TResponse = TypeVar("TResponse", bound=Response) +QueryType: TypeAlias = str | HumanMessage class Conversation(Generic[TResponse]): @@ -42,7 +43,7 @@ def __init__( response_factory: Callable[[str], TResponse] = Response, # type: ignore ): self.memory = ConversationBufferMemory() - self.chain = ConversationChain(llm=model.create_model(), memory=self.memory) + self.llm = model.create_model() self.verbose = verbose self.response_factory = response_factory @@ -54,25 +55,19 @@ def get_full_conversation_string( [message.pretty_repr() for message in self.memory.buffer_as_messages], ) - @property - def _input_key(self) -> str: - return self.chain.input_key - - @property - def _output_key(self) -> str: - return self.chain.output_key - - def query_text(self, query: str) -> str: - response_dict = self.chain.invoke({self._input_key: query}, return_only_outputs=True) - response = response_dict[self._output_key] + def query_text(self, query: QueryType) -> str: + self.memory.chat_memory.add_user_message(query) + ai_message = self.llm.invoke(self.memory.chat_memory.messages) + self.memory.chat_memory.add_ai_message(ai_message) + response_text = ai_message.content if self.verbose: - print(response) - return response + print(response_text) + return response_text - def query(self, query: str) -> None: + def query(self, query: QueryType) -> None: self.query_text(query) - def query_response(self, query: str) -> TResponse: + def query_response(self, query: QueryType) -> TResponse: return self.response_factory(self.query_text(query)) def clone(self) -> Self: