Skip to content

Commit

Permalink
feat: validate dataframe with Pydantic schema (#522)
Browse files Browse the repository at this point in the history
* feat[DFValidator]: Add return types

* fix pre-commit errors

* add return value

* refactor: minor changes in validation

---------

Co-authored-by: arslan-agory <[email protected]>
Co-authored-by: Gabriele Venturi <[email protected]>
  • Loading branch information
3 people authored Sep 4, 2023
1 parent d93171c commit c986050
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 2 deletions.
126 changes: 126 additions & 0 deletions pandasai/helpers/df_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from typing import Dict, List, Optional

from pydantic import BaseModel
from pydantic import ValidationError

from pandasai.helpers.df_info import DataFrameType, df_type


class DfValidationResult:
    """
    Validation results for a dataframe.

    Instances are truthy when the validation passed and falsy otherwise, so
    they can be used directly in an ``if`` statement.

    Attributes:
        passed: Whether the validation passed or not (read-only property).
        errors: Method returning the list of errors collected so far.
    """

    _passed: bool
    _errors: List[Dict]

    def __init__(self, passed: bool = True, errors: Optional[List[Dict]] = None):
        """
        Args:
            passed: Whether the validation passed or not.
            errors: List of errors if the validation failed. The list is
                copied, so later calls to ``add_error`` never modify the
                list object owned by the caller.
        """
        self._passed = passed
        # Copy instead of aliasing: add_error() appends in place and must not
        # leak new entries into a list the caller still holds.
        self._errors = [] if errors is None else list(errors)

    @property
    def passed(self) -> bool:
        """Whether the validation passed."""
        return self._passed

    def errors(self) -> List[Dict]:
        """
        Return the errors recorded on this result.

        This is a regular method (not a property), so it must be called:
        ``result.errors()``.
        """
        return self._errors

    def add_error(self, error_message: str):
        """
        Add an error message to the validation results.

        Recording an error also marks the result as failed.

        Args:
            error_message: Error message to add.
        """
        self._passed = False
        self._errors.append(error_message)

    def __bool__(self) -> bool:
        """
        Define the truthiness of ValidationResults.
        """
        return self.passed


class DfValidator:
    """
    Check the rows of a dataframe against a Pydantic schema.

    Attributes:
        _df: the dataframe whose rows are validated
    """

    _df: DataFrameType

    def __init__(self, df: DataFrameType):
        """
        Args:
            df: dataframe to be validated
        """
        self._df = df

    def _validate_batch(self, schema, df_json: List[Dict]):
        """
        Validate a list of row dicts against the schema in a single pass.

        Args:
            schema: Pydantic schema describing one row
            df_json: dataframe rows, one dict per row
        Returns:
            list of errors reported by Pydantic, empty when all rows are valid
        """
        try:
            # Wrapper model: validating its ``df`` field validates every row
            # of the list against ``schema`` at once.
            class PdVal(BaseModel):
                df: List[schema]

            PdVal(df=df_json)
        except ValidationError as exc:
            return exc.errors()
        else:
            return []

    def _df_to_list_of_dict(self, df: DataFrameType, dataframe_type: str) -> List[Dict]:
        """
        Turn the rows of a dataframe into a list of dicts.

        Only "pandas" and "polars" are handled; any other type yields an
        empty list.

        Args:
            df: dataframe to be converted
            dataframe_type: type of dataframe
        Returns:
            list of dict of dataframe rows
        """
        if dataframe_type == "pandas":
            return df.to_dict(orient="records")
        if dataframe_type == "polars":
            return df.to_dicts()
        return []

    def validate(self, schema: BaseModel) -> DfValidationResult:
        """
        Validate every row of the wrapped dataframe against the schema.

        Args:
            schema: Pydantic schema to be validated for the dataframe row
        Returns:
            Validation results
        Raises:
            ValueError: if ``df_type`` returns None for the wrapped dataframe
        """
        kind = df_type(self._df)
        if kind is None:
            raise ValueError("Unsupported DataFrame")

        rows: List[Dict] = self._df_to_list_of_dict(self._df, kind)
        problems = self._validate_batch(schema, rows)

        # No reported problems means the validation passed.
        return DfValidationResult(not problems, problems)
13 changes: 13 additions & 0 deletions pandasai/smart_dataframe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
from io import StringIO

import pandas as pd
import pydantic

from pandasai.helpers.df_validator import DfValidator

from ..smart_datalake import SmartDatalake
from ..schemas.df_config import Config
Expand Down Expand Up @@ -235,6 +238,16 @@ def _get_head_csv(self):
self._sample_head = df_head.to_csv(index=False)
return self._sample_head

def validate(self, schema: pydantic.BaseModel):
    """
    Validate the rows of the dataframe held in ``self.original_import``
    against a Pydantic schema.

    Args:
        schema: Pydantic schema class

    Returns:
        the result of ``DfValidator.validate`` for that dataframe and schema
    """
    return DfValidator(self.original_import).validate(schema)

@property
def datalake(self):
return self._dl
Expand Down
103 changes: 101 additions & 2 deletions tests/test_smartdataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import pandas as pd
import polars as pl
from pydantic import BaseModel, Field
import pytest

from pandasai import SmartDataframe
Expand Down Expand Up @@ -667,7 +668,7 @@ def test_save_pandas_dataframe_duplicate_name(self, llm):
# Create a sample DataFrame
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})

# Create instances of YourDataFrameClass
# Create instances of SmartDataframe
df_object1 = SmartDataframe(
df,
name="df_duplicate",
Expand Down Expand Up @@ -699,7 +700,7 @@ def test_save_pandas_no_name(self, llm):
# Create a sample DataFrame
df = pd.DataFrame({"A": [1, 2, 3, 4], "B": [5, 6, 7, 8]})

# Create an instance of YourDataFrameClass without a name
# Create an instance of SmartDataframe without a name
df_object = SmartDataframe(
df, description="No Name", config={"llm": llm, "enable_cache": False}
)
Expand All @@ -723,3 +724,101 @@ def test_save_pandas_no_name(self, llm):
# Recover file for next test case
with open("pandasai.json", "w") as json_file:
json_file.write(backup_pandasai)

def test_pydantic_validate(self, llm):
    """Integer columns validate successfully against an all-int schema."""

    # Schema describing a single row
    class RowSchema(BaseModel):
        A: int
        B: int

    frame = pd.DataFrame({"A": [1, 2, 3, 4], "B": [5, 6, 7, 8]})
    smart_df = SmartDataframe(
        frame, description="Name", config={"llm": llm, "enable_cache": False}
    )

    result = smart_df.validate(RowSchema)

    assert result.passed is True

def test_pydantic_validate_false(self, llm):
    """A column of non-numeric strings fails validation against an int field."""

    # Schema describing a single row
    class RowSchema(BaseModel):
        A: int
        B: int

    frame = pd.DataFrame(
        {"A": ["Test", "Test2", "Test3", "Test4"], "B": [5, 6, 7, 8]}
    )
    smart_df = SmartDataframe(
        frame, description="Name", config={"llm": llm, "enable_cache": False}
    )

    result = smart_df.validate(RowSchema)

    assert result.passed is False

def test_pydantic_validate_polars(self, llm):
    """Validation also passes when the underlying dataframe is a polars one."""

    # Schema describing a single row
    class RowSchema(BaseModel):
        A: int
        B: int

    frame = pl.DataFrame({"A": [1, 2, 3, 4], "B": [5, 6, 7, 8]})
    smart_df = SmartDataframe(
        frame, description="Name", config={"llm": llm, "enable_cache": False}
    )

    result = smart_df.validate(RowSchema)

    assert result.passed is True

def test_pydantic_validate_false_one_record(self, llm):
    """A single bad value fails validation and yields exactly one error."""

    # Schema describing a single row
    class RowSchema(BaseModel):
        A: int
        B: int

    frame = pd.DataFrame({"A": [1, "test", 3, 4], "B": [5, 6, 7, 8]})
    smart_df = SmartDataframe(
        frame, description="Name", config={"llm": llm, "enable_cache": False}
    )

    result = smart_df.validate(RowSchema)

    # Checked in the same order as the original compound condition.
    assert result.passed is False
    assert len(result.errors()) == 1

def test_pydantic_validate_complex_schema(self, llm):
    """Field constraints are applied: gt=5 rejects the data, lt=5 accepts it."""
    frame = pd.DataFrame({"A": [1, 2, 3, 4], "B": [5, 6, 7, 8]})
    smart_df = SmartDataframe(
        frame, description="Name", config={"llm": llm, "enable_cache": False}
    )

    # Every value in column A is below 5, so this constraint cannot hold.
    class AboveFive(BaseModel):
        A: int = Field(..., gt=5)
        B: int

    # The opposite bound holds for every value in column A.
    class BelowFive(BaseModel):
        A: int = Field(..., lt=5)
        B: int

    rejected = smart_df.validate(AboveFive)
    assert rejected.passed is False

    accepted = smart_df.validate(BelowFive)
    assert accepted.passed is True

0 comments on commit c986050

Please sign in to comment.