diff --git a/pandasai/prompts/generate_python_code.py b/pandasai/prompts/generate_python_code.py index bff66e142..902f2c841 100644 --- a/pandasai/prompts/generate_python_code.py +++ b/pandasai/prompts/generate_python_code.py @@ -51,7 +51,7 @@ class GeneratePythonCodePrompt(Prompt): # Analyze the data # 1. Prepare: Preprocessing and cleaning data if necessary # 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) -# 3. Analyze: Conducting the actual analysis (if the user asks to create a chart save it to an image in exports/charts/temp_chart.png and do not show the chart.) +# 3. Analyze: Conducting the actual analysis (if the user asks to create a chart save it to an image in {save_charts_path}/temp_chart.png and do not show the chart.) # 4. Output: return a dictionary of: # - type (possible values "text", "number", "dataframe", "plot") # - value (can be a string, a dataframe or the path of the plot, NOT a dictionary) diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index 8a6d612b5..bcc53e9b2 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -262,6 +262,7 @@ def chat(self, query: str): default_values = { # TODO: find a better way to determine the engine, "engine": self._dfs[0].engine, + "save_charts_path": self._config.save_charts_path.rstrip("/"), } generate_python_code_instruction = self._get_prompt( "generate_python_code", diff --git a/tests/prompts/test_generate_python_code_prompt.py b/tests/prompts/test_generate_python_code_prompt.py index e5597ba14..889723ddf 100644 --- a/tests/prompts/test_generate_python_code_prompt.py +++ b/tests/prompts/test_generate_python_code_prompt.py @@ -23,6 +23,7 @@ def test_str_with_args(self): prompt = GeneratePythonCodePrompt() prompt.set_var("dfs", dfs) prompt.set_var("conversation", "Question") + prompt.set_var("save_charts_path", "exports/charts") assert ( prompt.to_string() == """ @@ -51,6 +52,61 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: # Code goes here (do not add comments) +# Declare a result variable +result = analyze_data(dfs) +``` + +Using the provided dataframes (`dfs`), update the python code based on the last user question: +Question + +Updated code: +""" # noqa: E501 + ) + + def test_str_with_custom_save_charts_path(self): + """Test that the __str__ method is implemented""" + + llm = FakeLLM("plt.show()") + dfs = [ + SmartDataframe( + pd.DataFrame({"a": [1], "b": [4]}), + config={"llm": llm}, + ) + ] + + prompt = GeneratePythonCodePrompt() + prompt.set_var("dfs", dfs) + prompt.set_var("conversation", "Question") + prompt.set_var("save_charts_path", "custom_path") + + assert ( + prompt.to_string() + == """ +You are provided with the following pandas DataFrames with the following metadata: + +Dataframe dfs[0], with 1 rows and 2 columns. +This is the metadata of the dataframe dfs[0]: +a,b +1,4 + + +This is the initial python code to be updated: +```python +# TODO import all the dependencies required +import pandas as pd + +# Analyze the data +# 1. Prepare: Preprocessing and cleaning data if necessary +# 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) +# 3. Analyze: Conducting the actual analysis (if the user asks to create a chart save it to an image in custom_path/temp_chart.png and do not show the chart.) +# 4. Output: return a dictionary of: +# - type (possible values "text", "number", "dataframe", "plot") +# - value (can be a string, a dataframe or the path of the plot, NOT a dictionary) +# Example output: { "type": "text", "value": "The average loan amount is $15,000." } +def analyze_data(dfs: list[pd.DataFrame]) -> dict: + # Code goes here (do not add comments) + + # Declare a result variable result = analyze_data(dfs) ``` diff --git a/tests/test_smartdataframe.py b/tests/test_smartdataframe.py index 37fa88f88..e5bf6d823 100644 --- a/tests/test_smartdataframe.py +++ b/tests/test_smartdataframe.py @@ -240,6 +240,71 @@ def test_getters_are_accessible(self, smart_dataframe: SmartDataframe, llm): == "def analyze_data(dfs):\n return {'type': 'number', 'value': 1}" ) + def test_save_chart_non_default_dir( + self, smart_dataframe: SmartDataframe, llm, sample_df + ): + """ + Test chat with `SmartDataframe` with custom `save_charts_path`. + + Script: + 1) Ask `SmartDataframe` to build a chart and save it in + a custom directory; + 2) Check if substring representing the directory present in + `llm.last_prompt`. + 3) Check if the code has had a call of `plt.savefig()` passing + the custom directory. + + Notes: + 1) Mock `import_dependency()` util-function to avoid the + actual calls to `matplotlib.pyplot`. + 2) The `analyze_data()` function in the code fixture must have + `"type": None` in the result dict. Otherwise, if it had + `"type": "plot"` (like it has in practice), `_format_results()` + method from `SmartDatalake` object would try to read the image + with `matplotlib.image.imread()` and this test would fail. + Those calls to `matplotlib.image` are unmockable because of + imports inside the function scope, not in the top of a module. + @TODO: figure out if we can just move the imports beyond to + make it possible to mock out `matplotlib.image` + """ + llm._output = """ +import pandas as pd +import matplotlib.pyplot as plt +def analyze_data(dfs: list[pd.DataFrame]) -> dict: + df = dfs[0].nlargest(5, 'happiness_index') + + plt.figure(figsize=(8, 6)) + plt.pie(df['happiness_index'], labels=df['country'], autopct='%1.1f%%') + plt.title('Happiness Index for the 5 Happiest Countries') + plt.savefig('custom-dir/output_charts/temp_chart.png') + plt.close() + + return {"type": None, "value": "custom-dir/output_charts/temp_chart.png"} +result = analyze_data(dfs) +""" + with patch( + "pandasai.helpers.code_manager.import_dependency" + ) as import_dependency_mock: + smart_dataframe = SmartDataframe( + sample_df, + config={ + "llm": llm, + "enable_cache": False, + "save_charts": True, + "save_charts_path": "custom-dir/output_charts/", + }, + ) + + smart_dataframe.chat("Plot pie-chart the 5 happiest countries") + + assert "custom-dir/output_charts/temp_chart.png" in llm.last_prompt + plt_mock = getattr(import_dependency_mock.return_value, "matplotlib.pyplot") + assert plt_mock.savefig.called + assert ( + plt_mock.savefig.call_args.args[0] + == "custom-dir/output_charts/temp_chart.png" + ) + def test_add_middlewares(self, smart_dataframe: SmartDataframe, custom_middleware): middleware = custom_middleware() smart_dataframe.add_middlewares(middleware)