diff --git a/pandasai/connectors/sql.py b/pandasai/connectors/sql.py index e1494ba59..68638e8a2 100644 --- a/pandasai/connectors/sql.py +++ b/pandasai/connectors/sql.py @@ -441,7 +441,7 @@ def execute_direct_sql_query(self, sql_query): if not self._is_sql_query_safe(sql_query): raise MaliciousQueryError("Malicious query is generated in code") - return pd.read_sql(sql_query, self._connection) + return pd.read_sql(text(sql_query), self._connection) @property def cs_table_name(self): diff --git a/pandasai/helpers/output_validator.py b/pandasai/helpers/output_validator.py index e26bcf2ff..56a3a495d 100644 --- a/pandasai/helpers/output_validator.py +++ b/pandasai/helpers/output_validator.py @@ -56,7 +56,7 @@ def validate_value(self, expected_type: str) -> bool: elif expected_type == "string": return isinstance(self, str) elif expected_type == "dataframe": - return isinstance(self, (pd.DataFrame, pd.Series)) + return isinstance(self, (pd.DataFrame, pd.Series, dict)) elif expected_type == "plot": if not isinstance(self, (str, dict)): return False @@ -82,7 +82,7 @@ def validate_result(result: dict) -> bool: elif result["type"] == "string": return isinstance(result["value"], str) elif result["type"] == "dataframe": - return isinstance(result["value"], (pd.DataFrame, pd.Series)) + return isinstance(result["value"], (pd.DataFrame, pd.Series, dict)) elif result["type"] == "plot": if "plotly" in repr(type(result["value"])): return True diff --git a/pandasai/responses/response_parser.py b/pandasai/responses/response_parser.py index fd202784d..4254c77ec 100644 --- a/pandasai/responses/response_parser.py +++ b/pandasai/responses/response_parser.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Any +import pandas as pd from PIL import Image from pandasai.exceptions import MethodNotImplementedError @@ -51,9 +52,20 @@ def parse(self, result: dict) -> Any: if result["type"] == "plot": return self.format_plot(result) + elif result["type"] == "dataframe": + return self.format_dataframe(result) else: return result["value"] + def format_dataframe(self, result: dict) -> Any: + if isinstance(result["value"], dict): + print("Df conversiont") + df = pd.Dataframe(result["value"]) + print("Df conversiont Done") + result["value"] = df + + return result["value"] + def format_plot(self, result: dict) -> Any: """ Display matplotlib plot against a user query.