Skip to content

Commit

Permalink
feat: Add Explain functionality in the agent
Browse files Browse the repository at this point in the history
  • Loading branch information
ArslanSaleem committed Sep 22, 2023
1 parent 70244c3 commit f715035
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 71 deletions.
21 changes: 17 additions & 4 deletions examples/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,24 @@
salaries_df = pd.DataFrame(salaries_data)


# Never commit a real API key to the repository. Replace the placeholder
# below with your own key locally (or load it from your environment).
llm = OpenAI("OPEN_API")
agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10)

# Chat with the agent
response = agent.chat("Who gets paid the most?")
print(response)
questions = agent.clarification_questions()
print(questions)
response = agent.chat("Which department he belongs to?")


# Get Clarification Questions
response = agent.clarification_questions()

if response:
for question in response.questions:
print(question)
else:
print(response.message)


# Explain how the chat response is generated
response = agent.explain()
print(response)
74 changes: 55 additions & 19 deletions pandasai/agent/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import json
from typing import Union, List, Optional
from pandasai.agent.response import ClarificationResponse
from pandasai.helpers.df_info import DataFrameType
from pandasai.helpers.logger import Logger
from pandasai.helpers.memory import Memory
from pandasai.prompts.base import Prompt
from pandasai.prompts.clarification_questions_prompt import ClarificationQuestionPrompt
from pandasai.prompts.explain_prompt import ExplainPrompt
from pandasai.schemas.df_config import Config

from pandasai.smart_datalake import SmartDatalake
Expand All @@ -23,7 +26,7 @@ def __init__(
dfs: Union[DataFrameType, List[DataFrameType]],
config: Optional[Union[Config, dict]] = None,
logger: Logger = None,
memory_size=1,
memory_size: int = 1,
):
"""
Args:
Expand All @@ -36,6 +39,7 @@ def __init__(

self._lake = SmartDatalake(dfs, config, logger)
self.logger = self._lake.logger
# Each exchange stores two messages (query and reply), so multiply the memory size by 2
self._memory = Memory(memory_size * 2)

def _get_conversation(self):
Expand All @@ -51,17 +55,26 @@ def _get_conversation(self):
]
)

def chat(self, query: str):
def chat(self, query: str, output_type: Optional[str] = None):
"""
Simulate a chat interaction with the assistant on Dataframe.
"""
self._memory.add(query, True)
conversation = self._get_conversation()
result = self._lake.chat(query, start_conversation=conversation)
self._memory.add(result, False)
return result
try:
self._memory.add(query, True)
conversation = self._get_conversation()
result = self._lake.chat(
query, output_type=output_type, start_conversation=conversation
)
self._memory.add(result, False)
return result
except Exception as exception:
return (
"Unfortunately, I was not able to get your answers, "
"because of the following error:\n"
f"\n{exception}\n"
)

def _get_clarification_prompt(self):
def _get_clarification_prompt(self) -> Prompt:
"""
Create a clarification prompt with relevant variables.
"""
Expand All @@ -70,31 +83,54 @@ def _get_clarification_prompt(self):
prompt.set_var("conversation", self._get_conversation())
return prompt

def clarification_questions(self) -> ClarificationResponse:
    """
    Generate clarification questions based on the data.

    Builds the clarification prompt, sends it to the LLM and parses the
    reply as a JSON array of questions.

    Returns:
        ClarificationResponse: on success, holds at most three questions and
        the raw LLM reply as ``message``. On any failure (LLM error, reply
        that is not JSON, reply that is not a JSON array) ``success`` is
        False, ``questions`` is empty and ``message`` describes the error.
    """
    try:
        prompt = self._get_clarification_prompt()
        result = self._lake.llm.call(prompt)
        self.logger.log(f"""Clarification Questions: {result}\n""")
        questions: List[str] = json.loads(result)
        # json.loads accepts any JSON document. Without this check a JSON
        # string would be sliced to its first three characters and reported
        # as a successful list of questions.
        if not isinstance(questions, list):
            raise ValueError(
                "Expected a JSON array of questions, "
                f"got {type(questions).__name__}"
            )
        return ClarificationResponse(
            success=True, questions=questions[:3], message=result
        )
    except Exception as exception:
        # Best-effort API: report the failure in the response rather than
        # raising to the caller.
        return ClarificationResponse(
            False,
            [],
            "Unfortunately, I was not able to get your clarification questions, "
            "because of the following error:\n"
            f"\n{exception}\n",
        )

def start_new_conversation(self) -> bool:
    """
    Clear the previous conversation.

    Returns:
        bool: always True once the conversation memory has been cleared.
    """
    self._memory.clear()
    return True

def explain(self):
def explain(self) -> str:
"""
Returns the explanation of the code how it reached to the solution
"""
pass
try:
prompt = ExplainPrompt()
prompt.set_var("code", self._lake.last_code_executed)
response = self._lake.llm.call(prompt)
self.logger.log(
f"""Explaination: {response}
"""
)
return response
except Exception as exception:
return (
"Unfortunately, I was not able to explain, "
"because of the following error:\n"
f"\n{exception}\n"
)
38 changes: 38 additions & 0 deletions pandasai/agent/response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import List


class ClarificationResponse:
    """
    Result of asking the agent for clarification questions.

    The instance is truthy when the questions were generated successfully,
    so callers can write ``if response: ...``.
    """

    def __init__(
        self,
        success: bool = True,
        questions: Optional[List[str]] = None,
        message: str = "",
    ):
        """
        Args:
            success: Whether the response was generated or not.
            questions: List of questions; defaults to an empty list.
            message: Raw LLM reply on success, or an error description
                on failure.
        """
        self._success: bool = success
        # Normalise a missing list to [] so callers can always iterate
        # ``questions`` without a None check.
        self._questions: List[str] = [] if questions is None else list(questions)
        self._message: str = message

    @property
    def questions(self) -> List[str]:
        """The clarification questions (empty when none were produced)."""
        return self._questions

    @property
    def message(self) -> str:
        """The message supplied at construction."""
        return self._message

    @property
    def success(self) -> bool:
        """Whether the questions were generated successfully."""
        return self._success

    def __bool__(self) -> bool:
        """
        Define the success of response.
        """
        # __bool__ must return a real bool; coerce in case a truthy/falsy
        # non-bool was passed as ``success``.
        return bool(self._success)
4 changes: 3 additions & 1 deletion pandasai/helpers/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ class Memory:
_messages: list
_max_messages: int

def __init__(self, max_messages=sys.maxsize):
def __init__(self, max_messages: int = sys.maxsize):
self._messages = []
self._max_messages = max_messages

def add(self, message: str, is_user: bool):
self._messages.append({"message": message, "is_user": is_user})

# Drop the two oldest entries together once the limit is exceeded (presumably one user/assistant exchange)
if len(self._messages) > self._max_messages:
del self._messages[:2]

Expand Down
48 changes: 12 additions & 36 deletions pandasai/prompts/explain_prompt.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,23 @@
""" Prompt to get clarification questions
You are provided with the following pandas DataFrames:
<dataframe>
{dataframe}
</dataframe>
<conversation>
{conversation}
</conversation>
Based on the conversation, are there any clarification questions that a senior data scientist would ask? These are questions for non technical people, only ask for questions they could ask given low tech expertise and no knowledge about how the dataframes are structured.
Return the JSON array of the clarification questions. If there is no clarification question, return an empty array.
Json:
""" # noqa: E501
""" Prompt to explain solution generated
Based on the last conversation you generated the code.
Can you explain briefly for non technical person on how you came up with code
without explaining pandas library?
"""


from .base import Prompt


class ExplainPrompt(Prompt):
    """Prompt asking the LLM to explain the generated code to a non-technical user.

    The ``{code}`` placeholder is filled in by the caller via ``set_var("code", ...)``.
    """

    text: str = """
Based on the last conversation you generated the code.
<Code>
{code}
</Code>
Can you explain briefly for non technical person on how you came up with code
without explaining pandas library?
"""
2 changes: 1 addition & 1 deletion pandasai/smart_datalake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ def last_code_generated(self):

@last_code_generated.setter
def last_code_generated(self, last_code_generated: str):
self._code_manager._last_code_generated = last_code_generated
self._last_code_generated = last_code_generated

@property
def last_code_executed(self):
Expand Down
75 changes: 65 additions & 10 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pandasai.agent import Agent
import pandas as pd
import pytest
from pandasai.agent.response import ClarificationResponse
from pandasai.llm.fake import FakeLLM

from pandasai.smart_datalake import SmartDatalake
Expand Down Expand Up @@ -171,27 +172,81 @@ def test_start_new_conversation(self, sample_df, config):

def test_clarification_questions(self, sample_df, config):
agent = Agent(sample_df, config, memory_size=10)
agent._lake.llm.generate_code = Mock()
agent._lake.llm.call = Mock()
clarification_response = (
'["What is happiest index for you?", "What is unit of measure for gdp?"]'
)
agent._lake.llm.generate_code.return_value = clarification_response
agent._lake.llm.call.return_value = clarification_response

questions = agent.clarification_questions()
assert len(questions) == 2
assert questions[0] == "What is happiest index for you?"
assert questions[1] == "What is unit of measure for gdp?"
response = agent.clarification_questions()
assert len(response.questions) == 2
assert response.questions[0] == "What is happiest index for you?"
assert response.questions[1] == "What is unit of measure for gdp?"

def test_clarification_questions_failure(self, sample_df, config):
    """A failing LLM call must yield an unsuccessful response, not raise."""
    agent = Agent(sample_df, config, memory_size=10)
    # side_effect makes the mock actually raise. return_value would only hand
    # back an Exception object, so the test would pass for the wrong reason.
    agent._lake.llm.call = Mock(side_effect=Exception("This is a mock exception"))

    response = agent.clarification_questions()
    assert response.success is False
    assert response.questions == []
    assert "This is a mock exception" in response.message

def test_clarification_questions_fail_non_json(self, sample_df, config):
agent = Agent(sample_df, config, memory_size=10)
agent._lake.llm.call = Mock()

agent._lake.llm.call.return_value = "This is not json response"

response = agent.clarification_questions()
assert response.success is False

def test_clarification_questions_max_3(self, sample_df, config):
agent = Agent(sample_df, config, memory_size=10)
agent._lake.llm.generate_code = Mock()
agent._lake.llm.call = Mock()
clarification_response = (
'["What is happiest index for you", '
'"What is unit of measure for gdp", '
'"How many countries are involved in the survey", '
'"How do you want this data to be represented"]'
)
agent._lake.llm.generate_code.return_value = clarification_response
agent._lake.llm.call.return_value = clarification_response

questions = agent.clarification_questions()
assert len(questions) == 3
response = agent.clarification_questions()
assert isinstance(response, ClarificationResponse)
assert response.success is True
assert len(response.questions) == 3
assert response.message == (
'["What is happiest index for you", '
'"What is unit of measure for gdp", '
'"How many countries are involved in the survey", '
'"How do you want this data to be represented"]'
)

def test_explain(self, sample_df, config):
agent = Agent(sample_df, config, memory_size=10)
agent._lake.llm.call = Mock()
clarification_response = """
Combine the Data: To find out who gets paid the most,
I needed to match the names of people with the amounts of money they earn.
It's like making sure the right names are next to the right amounts.
I used a method to do this, like connecting pieces of a puzzle.
Find the Top Earner: After combining the data, I looked through it to find
the person with the most money.
It's like finding the person who has the most marbles in a game
"""
agent._lake.llm.call.return_value = clarification_response

response = agent.explain()

assert response == (
"""
Combine the Data: To find out who gets paid the most,
I needed to match the names of people with the amounts of money they earn.
It's like making sure the right names are next to the right amounts.
I used a method to do this, like connecting pieces of a puzzle.
Find the Top Earner: After combining the data, I looked through it to find
the person with the most money.
It's like finding the person who has the most marbles in a game
"""
)

0 comments on commit f715035

Please sign in to comment.