Commit

refactor: memory to return conversation according to size
ArslanSaleem committed Sep 22, 2023
1 parent d1b8e61 commit 49d8720
Showing 3 changed files with 15 additions and 21 deletions.
pandasai/agent/__init__.py: 16 changes (5 additions, 11 deletions)
@@ -2,6 +2,7 @@
 from typing import Union, List, Optional
 from pandasai.helpers.df_info import DataFrameType
 from pandasai.helpers.logger import Logger
+from pandasai.helpers.memory import Memory
 from pandasai.prompts.clarification_questions_prompt import ClarificationQuestionPrompt
 from pandasai.prompts.explain_prompt import ExplainPrompt
 from pandasai.schemas.df_config import Config
@@ -33,22 +34,15 @@ def __init__(
         if not isinstance(dfs, list):
             dfs = [dfs]
 
-        self._lake = SmartDatalake(dfs, config, logger)
+        self._lake = SmartDatalake(dfs, config, logger, memory=Memory(memory_size))
         self._logger = self._lake.logger
-        self._memory_size = memory_size
 
     def chat(self, query: str, output_type: Optional[str] = None):
         """
         Simulate a chat interaction with the assistant on Dataframe.
         """
         try:
-            result = self._lake.chat(
-                query,
-                output_type=output_type,
-                start_conversation=self._lake._memory.get_conversation(
-                    self._memory_size
-                ),
-            )
+            result = self._lake.chat(query, output_type=output_type)
             return result
         except Exception as exception:
             return (
@@ -63,7 +57,7 @@ def clarification_questions(self) -> List[str]:
         """
         try:
             prompt = ClarificationQuestionPrompt(
-                self._lake.dfs, self._lake._memory.get_conversation(self._memory_size)
+                self._lake.dfs, self._lake._memory.get_conversation()
             )
 
             result = self._lake.llm.call(prompt)
@@ -89,7 +83,7 @@ def explain(self) -> str:
         """
         try:
             prompt = ExplainPrompt(
-                self._lake._memory.get_conversation(self._memory_size),
+                self._lake._memory.get_conversation(),
                 self._lake.last_code_executed,
             )
             response = self._lake.llm.call(prompt)
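With this change, the Agent stops tracking the conversation window itself: memory_size is handed to a Memory instance owned by the SmartDatalake, and chat, clarification_questions, and explain all rely on that Memory's default. A rough usage sketch follows (the exact Agent constructor signature, the toy dataframe, and the llm placeholder are assumptions for illustration):

    import pandas as pd
    from pandasai.agent import Agent

    sales_df = pd.DataFrame({"region": ["EU", "US"], "sales": [120, 95]})  # toy data for the sketch
    agent = Agent([sales_df], config={"llm": llm}, memory_size=10)  # assumed signature; `llm` stands in for any configured LLM
    agent.chat("Which region had the highest sales?")
    agent.chat("And the lowest?")  # follow-up turn; prior turns now come from the lake's Memory,
                                   # so nothing like start_conversation is passed per call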
pandasai/helpers/memory.py: 11 changes (9 additions, 2 deletions)
@@ -5,9 +5,11 @@ class Memory:
     """Memory class to store the conversations"""
 
     _messages: list
+    _memory_size: int
 
-    def __init__(self):
+    def __init__(self, memory_size: int = 1):
         self._messages = []
+        self._memory_size = memory_size
 
     def add(self, message: str, is_user: bool):
         self._messages.append({"message": message, "is_user": is_user})
@@ -21,7 +23,12 @@ def all(self) -> list:
     def last(self) -> dict:
         return self._messages[-1]
 
-    def get_conversation(self, limit: int = 1) -> str:
+    def get_conversation(self, limit: int = None) -> str:
+        """
+        Returns the conversation messages based on limit parameter
+        or default memory size
+        """
+        limit = self._memory_size if limit is None else limit
         return "\n".join(
             [
                 f"{f'User {i+1}' if message['is_user'] else f'Assistant {i}'}: "
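A minimal sketch of the new Memory behavior, using only methods visible in this diff (the sample messages are made up):

    from pandasai.helpers.memory import Memory

    memory = Memory(memory_size=2)  # default window covers the two most recent messages
    memory.add("What is the total revenue?", is_user=True)
    memory.add("The total revenue is 215.", is_user=False)
    memory.add("And per region?", is_user=True)

    print(memory.get_conversation())         # no argument: falls back to self._memory_size (2)
    print(memory.get_conversation(limit=3))  # an explicit limit still overrides the default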
pandasai/smart_datalake/__init__.py: 9 changes (1 addition, 8 deletions)
@@ -255,12 +255,7 @@ def _get_cache_key(self) -> str:
 
         return cache_key
 
-    def chat(
-        self,
-        query: str,
-        output_type: Optional[str] = None,
-        start_conversation: Optional[str] = None,
-    ):
+    def chat(self, query: str, output_type: Optional[str] = None):
         """
         Run a query on the dataframe.
@@ -310,8 +305,6 @@ def chat(
             "save_charts_path": self._config.save_charts_path.rstrip("/"),
             "output_type_hint": output_type_helper.template_hint,
         }
-        if start_conversation is not None:
-            default_values["conversation"] = start_conversation
 
         generate_python_code_instruction = self._get_prompt(
             "generate_python_code",
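Because chat no longer accepts start_conversation, conversation context has to come from the Memory attached to the datalake, as the Agent change above already does. A hedged sketch of the resulting call pattern (only the memory= keyword and the new chat signature are taken from this commit; the constructor arguments, toy dataframe, and llm placeholder are assumptions):

    import pandas as pd
    from pandasai.smart_datalake import SmartDatalake
    from pandasai.helpers.memory import Memory

    employees_df = pd.DataFrame({"name": ["Ana", "Bo"], "salary": [70, 90]})  # toy data for the sketch
    lake = SmartDatalake([employees_df], config={"llm": llm}, memory=Memory(memory_size=5))  # `llm` is a stand-in
    lake.chat("Who earns the most?")  # conversation context is read from the lake's Memory,
                                      # not passed in through a start_conversation argument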
