From c98605056a6c811670146994c6f2494a80d3829e Mon Sep 17 00:00:00 2001 From: Arslan Saleem Date: Tue, 5 Sep 2023 04:00:15 +0500 Subject: [PATCH] feat: validate dataframe with Pydantic schema (#522) * feat[DFValidator]: Add return types * fix pre-commit errors * add return value * refactor: minor changes in validation --------- Co-authored-by: arslan-agory Co-authored-by: Gabriele Venturi --- pandasai/helpers/df_validator.py | 126 +++++++++++++++++++++++++++ pandasai/smart_dataframe/__init__.py | 13 +++ tests/test_smartdataframe.py | 103 +++++++++++++++++++++- 3 files changed, 240 insertions(+), 2 deletions(-) create mode 100644 pandasai/helpers/df_validator.py diff --git a/pandasai/helpers/df_validator.py b/pandasai/helpers/df_validator.py new file mode 100644 index 000000000..03c9bf999 --- /dev/null +++ b/pandasai/helpers/df_validator.py @@ -0,0 +1,126 @@ +from typing import List, Dict +from pydantic import ValidationError +from pydantic import BaseModel +from pandasai.helpers.df_info import DataFrameType, df_type + + +class DfValidationResult: + """ + Validation results for a dataframe. + + Attributes: + passed: Whether the validation passed or not. + errors: List of errors if the validation failed. + """ + + _passed: bool + _errors: List[Dict] + + def __init__(self, passed: bool = True, errors: List[Dict] = None): + """ + Args: + passed: Whether the validation passed or not. + errors: List of errors if the validation failed. + """ + if errors is None: + errors = [] + self._passed = passed + self._errors = errors + + @property + def passed(self): + return self._passed + + def errors(self) -> List[Dict]: + return self._errors + + def add_error(self, error_message: str): + """ + Add an error message to the validation results. + + Args: + error_message: Error message to add. + """ + self._passed = False + self._errors.append(error_message) + + def __bool__(self) -> bool: + """ + Define the truthiness of ValidationResults. + """ + return self.passed + + +class DfValidator: + """ + Validate a dataframe using a Pydantic schema. + + Attributes: + df: dataframe to be validated + """ + + _df: DataFrameType + + def __init__(self, df: DataFrameType): + """ + Args: + df: dataframe to be validated + """ + self._df = df + + def _validate_batch(self, schema, df_json: List[Dict]): + """ + Args: + schema: Pydantic schema + batch_df: dataframe batch + + Returns: + list of errors + """ + try: + # Create a Pydantic Validator to validate rows of dataframe + class PdVal(BaseModel): + df: List[schema] + + PdVal(df=df_json) + return [] + + except ValidationError as e: + return e.errors() + + def _df_to_list_of_dict(self, df: DataFrameType, dataframe_type: str) -> List[Dict]: + """ + Create list of dict of dataframe rows on basis of dataframe type + Supports only polars and pandas dataframe + + Args: + df: dataframe to be converted + dataframe_type: type of dataframe + + Returns: + list of dict of dataframe rows + """ + if dataframe_type == "pandas": + return df.to_dict(orient="records") + elif dataframe_type == "polars": + return df.to_dicts() + else: + return [] + + def validate(self, schema: BaseModel) -> DfValidationResult: + """ + Args: + schema: Pydantic schema to be validated for the dataframe row + + Returns: + Validation results + """ + dataframe_type = df_type(self._df) + if dataframe_type is None: + raise ValueError("Unsupported DataFrame") + + df_json: List[Dict] = self._df_to_list_of_dict(self._df, dataframe_type) + + errors = self._validate_batch(schema, df_json) + + return DfValidationResult(len(errors) == 0, errors) diff --git a/pandasai/smart_dataframe/__init__.py b/pandasai/smart_dataframe/__init__.py index 611a8a03e..13168ad18 100644 --- a/pandasai/smart_dataframe/__init__.py +++ b/pandasai/smart_dataframe/__init__.py @@ -22,6 +22,9 @@ from io import StringIO import pandas as pd +import pydantic + +from pandasai.helpers.df_validator import DfValidator from ..smart_datalake import SmartDatalake from ..schemas.df_config import Config @@ -235,6 +238,16 @@ def _get_head_csv(self): self._sample_head = df_head.to_csv(index=False) return self._sample_head + def validate(self, schema: pydantic.BaseModel): + """ + Validates Dataframe rows on the basis Pydantic schema input + (Args): + schema: Pydantic schema class + verbose: Print Errors + """ + df_validator = DfValidator(self.original_import) + return df_validator.validate(schema) + @property def datalake(self): return self._dl diff --git a/tests/test_smartdataframe.py b/tests/test_smartdataframe.py index e5bf6d823..ba4863bc8 100644 --- a/tests/test_smartdataframe.py +++ b/tests/test_smartdataframe.py @@ -8,6 +8,7 @@ import pandas as pd import polars as pl +from pydantic import BaseModel, Field import pytest from pandasai import SmartDataframe @@ -667,7 +668,7 @@ def test_save_pandas_dataframe_duplicate_name(self, llm): # Create a sample DataFrame df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) - # Create instances of YourDataFrameClass + # Create instances of SmartDataframe df_object1 = SmartDataframe( df, name="df_duplicate", @@ -699,7 +700,7 @@ def test_save_pandas_no_name(self, llm): # Create a sample DataFrame df = pd.DataFrame({"A": [1, 2, 3, 4], "B": [5, 6, 7, 8]}) - # Create an instance of YourDataFrameClass without a name + # Create an instance of SmartDataframe without a name df_object = SmartDataframe( df, description="No Name", config={"llm": llm, "enable_cache": False} ) @@ -723,3 +724,101 @@ def test_save_pandas_no_name(self, llm): # Recover file for next test case with open("pandasai.json", "w") as json_file: json_file.write(backup_pandasai) + + def test_pydantic_validate(self, llm): + # Create a sample DataFrame + df = pd.DataFrame({"A": [1, 2, 3, 4], "B": [5, 6, 7, 8]}) + + # Create an instance of SmartDataframe without a name + df_object = SmartDataframe( + df, description="Name", config={"llm": llm, "enable_cache": False} + ) + + # Pydantic Schema + class TestSchema(BaseModel): + A: int + B: int + + validation_result = df_object.validate(TestSchema) + + assert validation_result.passed is True + + def test_pydantic_validate_false(self, llm): + # Create a sample DataFrame + df = pd.DataFrame({"A": ["Test", "Test2", "Test3", "Test4"], "B": [5, 6, 7, 8]}) + + # Create an instance of SmartDataframe without a name + df_object = SmartDataframe( + df, description="Name", config={"llm": llm, "enable_cache": False} + ) + + # Pydantic Schema + class TestSchema(BaseModel): + A: int + B: int + + validation_result = df_object.validate(TestSchema) + + assert validation_result.passed is False + + def test_pydantic_validate_polars(self, llm): + # Create a sample DataFrame + df = pl.DataFrame({"A": [1, 2, 3, 4], "B": [5, 6, 7, 8]}) + + # Create an instance of SmartDataframe without a name + df_object = SmartDataframe( + df, description="Name", config={"llm": llm, "enable_cache": False} + ) + + # Pydantic Schema + class TestSchema(BaseModel): + A: int + B: int + + validation_result = df_object.validate(TestSchema) + assert validation_result.passed is True + + def test_pydantic_validate_false_one_record(self, llm): + # Create a sample DataFrame + df = pd.DataFrame({"A": [1, "test", 3, 4], "B": [5, 6, 7, 8]}) + + # Create an instance of SmartDataframe without a name + df_object = SmartDataframe( + df, description="Name", config={"llm": llm, "enable_cache": False} + ) + + # Pydantic Schema + class TestSchema(BaseModel): + A: int + B: int + + validation_result = df_object.validate(TestSchema) + assert ( + validation_result.passed is False and len(validation_result.errors()) == 1 + ) + + def test_pydantic_validate_complex_schema(self, llm): + # Create a sample DataFrame + df = pd.DataFrame({"A": [1, 2, 3, 4], "B": [5, 6, 7, 8]}) + + # Create an instance of SmartDataframe without a name + df_object = SmartDataframe( + df, description="Name", config={"llm": llm, "enable_cache": False} + ) + + # Pydantic Schema + class TestSchema(BaseModel): + A: int = Field(..., gt=5) + B: int + + validation_result = df_object.validate(TestSchema) + + assert validation_result.passed is False + + class TestSchema(BaseModel): + A: int = Field(..., lt=5) + B: int + + validation_result = df_object.validate(TestSchema) + + assert validation_result.passed is True