Skip to content

Commit

Permalink
fix: directory for saving charts issue (#513) (#514)
Browse files Browse the repository at this point in the history
* fix: directory for saving charts issue (#513)

* (fix): update the prompt template in the `GeneratePythonCodePrompt`
 class to remove the hard-coded chart directory
* (fix): pass `save_charts_path` from the config object to
  `get_prompt()`

* tests: custom path for charts (#513)

* (tests): add a test method for the case where a custom path is passed
  via the `save_charts_path` parameter

* test: add test for the custom chart prompt

---------

Co-authored-by: Gabriele Venturi <[email protected]>
  • Loading branch information
nautics889 and gventuri authored Sep 4, 2023
1 parent 167a15a commit d93171c
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pandasai/prompts/generate_python_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class GeneratePythonCodePrompt(Prompt):
# Analyze the data
# 1. Prepare: Preprocessing and cleaning data if necessary
# 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.)
# 3. Analyze: Conducting the actual analysis (if the user asks to create a chart save it to an image in exports/charts/temp_chart.png and do not show the chart.)
# 3. Analyze: Conducting the actual analysis (if the user asks to create a chart save it to an image in {save_charts_path}/temp_chart.png and do not show the chart.)
# 4. Output: return a dictionary of:
# - type (possible values "text", "number", "dataframe", "plot")
# - value (can be a string, a dataframe or the path of the plot, NOT a dictionary)
Expand Down
1 change: 1 addition & 0 deletions pandasai/smart_datalake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def chat(self, query: str):
default_values = {
# TODO: find a better way to determine the engine,
"engine": self._dfs[0].engine,
"save_charts_path": self._config.save_charts_path.rstrip("/"),
}
generate_python_code_instruction = self._get_prompt(
"generate_python_code",
Expand Down
56 changes: 56 additions & 0 deletions tests/prompts/test_generate_python_code_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def test_str_with_args(self):
prompt = GeneratePythonCodePrompt()
prompt.set_var("dfs", dfs)
prompt.set_var("conversation", "Question")
prompt.set_var("save_charts_path", "exports/charts")
assert (
prompt.to_string()
== """
Expand Down Expand Up @@ -51,6 +52,61 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict:
# Code goes here (do not add comments)
# Declare a result variable
result = analyze_data(dfs)
```
Using the provided dataframes (`dfs`), update the python code based on the last user question:
Question
Updated code:
""" # noqa: E501
)

def test_str_with_custom_save_charts_path(self):
"""Test that the __str__ method is implemented"""

llm = FakeLLM("plt.show()")
dfs = [
SmartDataframe(
pd.DataFrame({"a": [1], "b": [4]}),
config={"llm": llm},
)
]

prompt = GeneratePythonCodePrompt()
prompt.set_var("dfs", dfs)
prompt.set_var("conversation", "Question")
prompt.set_var("save_charts_path", "custom_path")

assert (
prompt.to_string()
== """
You are provided with the following pandas DataFrames with the following metadata:
Dataframe dfs[0], with 1 rows and 2 columns.
This is the metadata of the dataframe dfs[0]:
a,b
1,4
This is the initial python code to be updated:
```python
# TODO import all the dependencies required
import pandas as pd
# Analyze the data
# 1. Prepare: Preprocessing and cleaning data if necessary
# 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.)
# 3. Analyze: Conducting the actual analysis (if the user asks to create a chart save it to an image in custom_path/temp_chart.png and do not show the chart.)
# 4. Output: return a dictionary of:
# - type (possible values "text", "number", "dataframe", "plot")
# - value (can be a string, a dataframe or the path of the plot, NOT a dictionary)
# Example output: { "type": "text", "value": "The average loan amount is $15,000." }
def analyze_data(dfs: list[pd.DataFrame]) -> dict:
# Code goes here (do not add comments)
# Declare a result variable
result = analyze_data(dfs)
```
Expand Down
65 changes: 65 additions & 0 deletions tests/test_smartdataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,71 @@ def test_getters_are_accessible(self, smart_dataframe: SmartDataframe, llm):
== "def analyze_data(dfs):\n return {'type': 'number', 'value': 1}"
)

def test_save_chart_non_default_dir(
    self, smart_dataframe: SmartDataframe, llm, sample_df
):
    """
    Test chat with `SmartDataframe` using a custom `save_charts_path`.

    Script:
        1) Ask `SmartDataframe` to build a chart and save it in
           a custom directory;
        2) Check that the substring representing the directory is present
           in `llm.last_prompt`;
        3) Check that the generated code called `plt.savefig()` with
           the custom directory.

    Notes:
        1) `import_dependency()` is mocked to avoid real calls to
           `matplotlib.pyplot`.
        2) The `analyze_data()` function in the code fixture must have
           `"type": None` in the result dict. Otherwise, if it had
           `"type": "plot"` (like it has in practice), the
           `_format_results()` method of the `SmartDatalake` object would
           try to read the image with `matplotlib.image.imread()` and this
           test would fail. Those calls to `matplotlib.image` are
           unmockable because the imports live inside the function scope,
           not at the top of a module.
           @TODO: figure out if the imports can be moved to module level
           to make it possible to mock out `matplotlib.image`
    """
    # Fake LLM response. The body of `analyze_data` must stay indented:
    # this text is executed as Python code by the code manager.
    llm._output = """
import pandas as pd
import matplotlib.pyplot as plt
def analyze_data(dfs: list[pd.DataFrame]) -> dict:
    df = dfs[0].nlargest(5, 'happiness_index')
    plt.figure(figsize=(8, 6))
    plt.pie(df['happiness_index'], labels=df['country'], autopct='%1.1f%%')
    plt.title('Happiness Index for the 5 Happiest Countries')
    plt.savefig('custom-dir/output_charts/temp_chart.png')
    plt.close()
    return {"type": None, "value": "custom-dir/output_charts/temp_chart.png"}
result = analyze_data(dfs)
"""
    with patch(
        "pandasai.helpers.code_manager.import_dependency"
    ) as import_dependency_mock:
        # The config path carries a trailing slash on purpose; the prompt
        # is expected to contain the path without a doubled separator.
        smart_dataframe = SmartDataframe(
            sample_df,
            config={
                "llm": llm,
                "enable_cache": False,
                "save_charts": True,
                "save_charts_path": "custom-dir/output_charts/",
            },
        )

        smart_dataframe.chat("Plot pie-chart the 5 happiest countries")

    # The prompt sent to the LLM must reference the custom directory.
    assert "custom-dir/output_charts/temp_chart.png" in llm.last_prompt

    # The mocked dependency exposes pyplot under the dotted attribute name,
    # hence `getattr` instead of plain attribute access.
    plt_mock = getattr(import_dependency_mock.return_value, "matplotlib.pyplot")
    assert plt_mock.savefig.called
    assert (
        plt_mock.savefig.call_args.args[0]
        == "custom-dir/output_charts/temp_chart.png"
    )

def test_add_middlewares(self, smart_dataframe: SmartDataframe, custom_middleware):
middleware = custom_middleware()
smart_dataframe.add_middlewares(middleware)
Expand Down

0 comments on commit d93171c

Please sign in to comment.