Skip to content

Commit

Permalink
feat: output_type parameter (#519)
Browse files Browse the repository at this point in the history
* (feat): update prompt template in `GeneratePythonCodePrompt`, add
  `output_type_hint` variable to be interpolated
* (feat): update `.chat()` method for `SmartDataframe` and
  `SmartDatalake`, add optional `output_type` parameter
* (feat): add `get_output_type_hint()` in `GeneratePythonCodePrompt`
  class
* (feat): add "output_type_hint" to `default_values` when forming
  prompt's template context
* (tests): update tests in `TestGeneratePythonCodePrompt`
* (tests): add tests for checking `output_type` interpotaion to a
  prompt
  • Loading branch information
nautics889 committed Sep 15, 2023
1 parent 0eca724 commit cb0e433
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 24 deletions.
8 changes: 4 additions & 4 deletions pandasai/llm/huggingface_text_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class HuggingFaceTextGen(LLM):
top_k: Optional[int] = None
top_p: Optional[float] = 0.8
typical_p: Optional[float] = 0.8
temperature: float = 1E-3 # must be strictly positive
temperature: float = 1e-3 # must be strictly positive
repetition_penalty: Optional[float] = None
truncate: Optional[int] = None
stop_sequences: List[str] = None
Expand All @@ -29,7 +29,7 @@ def __init__(self, inference_server_url: str, **kwargs):
try:
import text_generation

for (key, val) in kwargs.items():
for key, val in kwargs.items():
if key in self.__annotations__:
setattr(self, key, val)

Expand Down Expand Up @@ -76,8 +76,8 @@ def call(self, instruction: Prompt, suffix: str = "") -> str:
for stop_seq in self.stop_sequences:
if stop_seq in res.generated_text:
res.generated_text = res.generated_text[
:res.generated_text.index(stop_seq)
]
: res.generated_text.index(stop_seq)
]
return res.generated_text

@property
Expand Down
19 changes: 17 additions & 2 deletions pandasai/prompts/generate_python_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ def analyze_data(dfs: list[{engine_df_name}]) -> dict:
2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.)
3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in {save_charts_path}/temp_chart.png and do not show the chart.)
4. Output: return a dictionary of:
- type (possible values "text", "number", "dataframe", "plot")
- value (can be a string, a dataframe or the path of the plot, NOT a dictionary)
{output_type_hint}
Example output: {{ "type": "text", "value": "The average loan amount is $15,000." }}
\"\"\"
```
Expand All @@ -70,10 +69,26 @@ def analyze_data(dfs: list[{engine_df_name}]) -> dict:
Updated code:
""" # noqa: E501
_output_type_map = {
"number": """- type (must be "number")
- value (must be a number)""",
"dataframe": """- type (must be "dataframe")
- value (must be a pandas dataframe)""",
"plot": """- type (must be "plot")
- value (must be a string containing the path of the plot image)""",
"string": """- type (must be "string")
- value (must be a conversational answer, as a string)""",
}
_output_type_default = """- type (possible values "text", "number", "dataframe", "plot")
- value (can be a string, a dataframe or the path of the plot, NOT a dictionary)""" # noqa E501

def __init__(self):
default_import = "import pandas as pd"
engine_df_name = "pd.DataFrame"

self.set_var("default_import", default_import)
self.set_var("engine_df_name", engine_df_name)

@classmethod
def get_output_type_hint(cls, output_type):
return cls._output_type_map.get(output_type, cls._output_type_default)
5 changes: 3 additions & 2 deletions pandasai/smart_dataframe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,17 +315,18 @@ def add_middlewares(self, *middlewares: Optional[Middleware]):
"""
self.lake.add_middlewares(*middlewares)

def chat(self, query: str):
def chat(self, query: str, output_type: Optional[str] = None):
"""
Run a query on the dataframe.
Args:
query (str): Query to run on the dataframe
output_type (Optional[str]):
Raises:
ValueError: If the query is empty
"""
return self.lake.chat(query)
return self.lake.chat(query, output_type)

def column_hash(self) -> str:
"""
Expand Down
5 changes: 4 additions & 1 deletion pandasai/smart_datalake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,13 @@ def _get_cache_key(self) -> str:

return cache_key

def chat(self, query: str):
def chat(self, query: str, output_type: Optional[str] = None):
"""
Run a query on the dataframe.
Args:
query (str): Query to run on the dataframe
output_type (Optional[str]):
Raises:
ValueError: If the query is empty
Expand All @@ -280,10 +281,12 @@ def chat(self, query: str):
self.logger.log("Using cached response")
code = self._cache.get(self._get_cache_key())
else:
prompt_cls = GeneratePythonCodePrompt
default_values = {
# TODO: find a better way to determine the engine,
"engine": self._dfs[0].engine,
"save_charts_path": self._config.save_charts_path.rstrip("/"),
"output_type_hint": prompt_cls.get_output_type_hint(output_type),
}
generate_python_code_instruction = self._get_prompt(
"generate_python_code",
Expand Down
43 changes: 28 additions & 15 deletions tests/prompts/test_generate_python_code_prompt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Unit tests for the generate python code prompt class"""

import sys

import pandas as pd
from pandasai import SmartDataframe
Expand All @@ -24,9 +24,12 @@ def test_str_with_args(self):
prompt.set_var("dfs", dfs)
prompt.set_var("conversation", "Question")
prompt.set_var("save_charts_path", "exports/charts")
assert (
prompt.to_string()
== """
output_type_hint = """- type (possible values "text", "number", "dataframe", "plot")
- value (can be a string, a dataframe or the path of the plot, NOT a dictionary)
""" # noqa E501
prompt.set_var("output_type_hint", output_type_hint)

expected_prompt_content = '''
You are provided with the following pandas DataFrames:
<dataframe>
Expand All @@ -46,23 +49,27 @@ def test_str_with_args(self):
import pandas as pd
def analyze_data(dfs: list[pd.DataFrame]) -> dict:
\"\"\"
"""
Analyze the data
1. Prepare: Preprocessing and cleaning data if necessary
2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.)
3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in exports/charts/temp_chart.png and do not show the chart.)
4. Output: return a dictionary of:
- type (possible values "text", "number", "dataframe", "plot")
- value (can be a string, a dataframe or the path of the plot, NOT a dictionary)
Example output: { "type": "text", "value": "The average loan amount is $15,000." }
\"\"\"
"""
```
Using the provided dataframes (`dfs`), update the python code based on the last question in the conversation.
Updated code:
""" # noqa: E501
)
''' # noqa E501
actual_prompt_content = prompt.to_string()
if sys.platform.startswith("win"):
actual_prompt_content = expected_prompt_content.replace("\r\n", "\n")
assert actual_prompt_content == expected_prompt_content

def test_str_with_custom_save_charts_path(self):
"""Test that the __str__ method is implemented"""
Expand All @@ -79,10 +86,13 @@ def test_str_with_custom_save_charts_path(self):
prompt.set_var("dfs", dfs)
prompt.set_var("conversation", "Question")
prompt.set_var("save_charts_path", "custom_path")
# noqa E501
output_type_hint = """- type (possible values "text", "number", "dataframe", "plot")
- value (can be a string, a dataframe or the path of the plot, NOT a dictionary)
""" # noqa E501
prompt.set_var("output_type_hint", output_type_hint)

assert (
prompt.to_string()
== """
expected_prompt_content = '''
You are provided with the following pandas DataFrames:
<dataframe>
Expand All @@ -102,7 +112,7 @@ def test_str_with_custom_save_charts_path(self):
import pandas as pd
def analyze_data(dfs: list[pd.DataFrame]) -> dict:
\"\"\"
"""
Analyze the data
1. Prepare: Preprocessing and cleaning data if necessary
2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.)
Expand All @@ -111,11 +121,14 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict:
- type (possible values "text", "number", "dataframe", "plot")
- value (can be a string, a dataframe or the path of the plot, NOT a dictionary)
Example output: { "type": "text", "value": "The average loan amount is $15,000." }
\"\"\"
"""
```
Using the provided dataframes (`dfs`), update the python code based on the last question in the conversation.
Updated code:
""" # noqa: E501
)
''' # noqa E501
actual_prompt_content = prompt.to_string()
if sys.platform.startswith("win"):
actual_prompt_content = expected_prompt_content.replace("\r\n", "\n")
assert actual_prompt_content == expected_prompt_content
47 changes: 47 additions & 0 deletions tests/test_smartdataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,53 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict:
last_prompt = df.last_prompt.replace("\r\n", "\n")
assert last_prompt == expected_prompt

def test_run_passing_output_type(self, llm):
df = pd.DataFrame({"country": []})
df = SmartDataframe(df, config={"llm": llm, "enable_cache": False})
df.enforce_privacy = True

expected_prompt = '''
You are provided with the following pandas DataFrames:
<dataframe>
Dataframe dfs[0], with 0 rows and 1 columns.
This is the metadata of the dataframe dfs[0]:
country
</dataframe>
<conversation>
User 1: How many countries are in the dataframe?
</conversation>
This is the initial python code to be updated:
```python
# TODO import all the dependencies required
import pandas as pd
def analyze_data(dfs: list[pd.DataFrame]) -> dict:
"""
Analyze the data
1. Prepare: Preprocessing and cleaning data if necessary
2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.)
3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in exports/charts/temp_chart.png and do not show the chart.)
4. Output: return a dictionary of:
- type (must be "number")
- value (must be a number)
Example output: { "type": "text", "value": "The average loan amount is $15,000." }
"""
```
Using the provided dataframes (`dfs`), update the python code based on the last question in the conversation.
Updated code:
''' # noqa: E501

df.chat("How many countries are in the dataframe?", output_type="number")
last_prompt = df.last_prompt
if sys.platform.startswith("win"):
last_prompt = df.last_prompt.replace("\r\n", "\n")
assert last_prompt == expected_prompt

def test_to_dict(self, smart_dataframe: SmartDataframe):
expected_keys = ("country", "gdp", "happiness_index")

Expand Down

0 comments on commit cb0e433

Please sign in to comment.