diff --git a/pandasai/ee/agents/semantic_agent/pipeline/code_generator.py b/pandasai/ee/agents/semantic_agent/pipeline/code_generator.py index 00e75764e..0b01e82fd 100644 --- a/pandasai/ee/agents/semantic_agent/pipeline/code_generator.py +++ b/pandasai/ee/agents/semantic_agent/pipeline/code_generator.py @@ -79,14 +79,15 @@ def execute(self, input_data: Any, **kwargs) -> Any: traceback_errors = traceback.format_exc() - input_data = self.on_failure(input, traceback_errors) + input_data = self.on_failure(input_data, traceback_errors) retry_count += 1 def _get_type(self, input: dict) -> bool: return ( "plot" - if input["type"] in ["bar", "line", "histogram", "pie", "scatter"] + if input["type"] + in ["bar", "line", "histogram", "pie", "scatter", "boxplot"] else input["type"] ) @@ -99,7 +100,7 @@ def _generate_code(self, type, query): """ elif type == "dataframe": return """ -result = {{"type": "dataframe","value": data}} +result = {"type": "dataframe","value": data} """ else: code = self.generate_matplotlib_code(query) @@ -119,8 +120,8 @@ def _generate_code_for_number(self, query: dict) -> str: def generate_matplotlib_code(self, query: dict) -> str: chart_type = query["type"] - x_label = query["options"].get("xLabel", None) - y_label = query["options"].get("yLabel", None) + x_label = query.get("options", {}).get("xLabel", None) + y_label = query.get("options", {}).get("yLabel", None) title = query["options"].get("title", None) legend_display = {"display": True} legend_position = "best" diff --git a/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_correction_pipeline.py b/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_correction_pipeline.py index 9d9c60189..65e34db94 100644 --- a/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_correction_pipeline.py +++ b/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_correction_pipeline.py @@ -58,6 +58,7 @@ def __init__( on_code_generation=on_code_generation, on_prompt_generation=on_prompt_generation, ) + self.query_exec_tracker = query_exec_tracker self._context = context self._logger = logger diff --git a/pandasai/ee/agents/semantic_agent/prompts/generate_df_schema.py b/pandasai/ee/agents/semantic_agent/prompts/generate_df_schema.py index 5f8f68254..80237c944 100644 --- a/pandasai/ee/agents/semantic_agent/prompts/generate_df_schema.py +++ b/pandasai/ee/agents/semantic_agent/prompts/generate_df_schema.py @@ -31,13 +31,16 @@ def __init__(self, **kwargs): def validate(self, output: str) -> bool: try: json_data = json.loads(output.replace("# SAMPLE SCHEMA", "")) + context = self.props["context"] if isinstance(json_data, dict): json_data = [json_data] if isinstance(json_data, list): for record in json_data: if not all(key in record for key in ("name", "table")): return False - return True + + return len(context.dfs) == len(json_data) + except json.JSONDecodeError: pass return False diff --git a/tests/unit_tests/ee/semantic_agent/test__semantic_code_generator.py b/tests/unit_tests/ee/semantic_agent/test__semantic_code_generator.py index c50d26752..fb1f261b6 100644 --- a/tests/unit_tests/ee/semantic_agent/test__semantic_code_generator.py +++ b/tests/unit_tests/ee/semantic_agent/test__semantic_code_generator.py @@ -307,11 +307,10 @@ def test_generate_matplolib_boxplot_chart_code( logic_unit = code_gen.execute(json_str, context=context, logger=logger) assert isinstance(logic_unit, LogicUnitOutput) - print(logic_unit.output) assert ( logic_unit.output == """ - +import matplotlib.pyplot as plt import pandas as pd sql_query="SELECT `orders`.`ship_country` AS ship_country, SUM(`orders`.`freight`) AS total_freight FROM `orders` GROUP BY ship_country"