diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py index a3de23e70..6706f4ed6 100644 --- a/pandasai/agent/__init__.py +++ b/pandasai/agent/__init__.py @@ -2,6 +2,7 @@ from typing import Union, List, Optional from pandasai.helpers.df_info import DataFrameType from pandasai.helpers.logger import Logger +from pandasai.helpers.memory import Memory from pandasai.prompts.clarification_questions_prompt import ClarificationQuestionPrompt from pandasai.prompts.explain_prompt import ExplainPrompt from pandasai.schemas.df_config import Config @@ -33,22 +34,15 @@ def __init__( if not isinstance(dfs, list): dfs = [dfs] - self._lake = SmartDatalake(dfs, config, logger) + self._lake = SmartDatalake(dfs, config, logger, memory=Memory(memory_size)) self._logger = self._lake.logger - self._memory_size = memory_size def chat(self, query: str, output_type: Optional[str] = None): """ Simulate a chat interaction with the assistant on Dataframe. """ try: - result = self._lake.chat( - query, - output_type=output_type, - start_conversation=self._lake._memory.get_conversation( - self._memory_size - ), - ) + result = self._lake.chat(query, output_type=output_type) return result except Exception as exception: return ( @@ -63,7 +57,7 @@ def clarification_questions(self) -> List[str]: """ try: prompt = ClarificationQuestionPrompt( - self._lake.dfs, self._lake._memory.get_conversation(self._memory_size) + self._lake.dfs, self._lake._memory.get_conversation() ) result = self._lake.llm.call(prompt) @@ -89,7 +83,7 @@ def explain(self) -> str: """ try: prompt = ExplainPrompt( - self._lake._memory.get_conversation(self._memory_size), + self._lake._memory.get_conversation(), self._lake.last_code_executed, ) response = self._lake.llm.call(prompt) diff --git a/pandasai/helpers/memory.py b/pandasai/helpers/memory.py index 5c7e01c8e..072542d3e 100644 --- a/pandasai/helpers/memory.py +++ b/pandasai/helpers/memory.py @@ -5,9 +5,11 @@ class Memory: """Memory class to store the conversations""" _messages: list + _memory_size: int - def __init__(self): + def __init__(self, memory_size: int = 1): self._messages = [] + self._memory_size = memory_size def add(self, message: str, is_user: bool): self._messages.append({"message": message, "is_user": is_user}) @@ -21,7 +23,12 @@ def all(self) -> list: def last(self) -> dict: return self._messages[-1] - def get_conversation(self, limit: int = 1) -> str: + def get_conversation(self, limit: int = None) -> str: + """ + Returns the conversation messages based on limit parameter + or default memory size + """ + limit = self._memory_size if limit is None else limit return "\n".join( [ f"{f'User {i+1}' if message['is_user'] else f'Assistant {i}'}: " diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index ce74d7809..7b5040d0a 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -255,12 +255,7 @@ def _get_cache_key(self) -> str: return cache_key - def chat( - self, - query: str, - output_type: Optional[str] = None, - start_conversation: Optional[str] = None, - ): + def chat(self, query: str, output_type: Optional[str] = None): """ Run a query on the dataframe. @@ -310,8 +305,6 @@ def chat( "save_charts_path": self._config.save_charts_path.rstrip("/"), "output_type_hint": output_type_helper.template_hint, } - if start_conversation is not None: - default_values["conversation"] = start_conversation generate_python_code_instruction = self._get_prompt( "generate_python_code",