Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Validate dataframe with Pydantic schema #522

Merged
merged 4 commits into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions pandasai/helpers/df_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from typing import Dict, List, Optional, Type

from pydantic import ValidationError
from pydantic import BaseModel

from pandasai.helpers.df_info import DataFrameType, df_type


class DFValidationResult:
    """
    Outcome of validating a dataframe against a schema.

    The truthiness of an instance mirrors ``passed``, so a result can be used
    directly in a condition.
    """

    def __init__(self, passed: bool = True, errors: Optional[List[Dict]] = None):
        """
        Args:
            passed: Whether the validation succeeded.
            errors: Errors collected so far. When omitted, every instance
                gets its own new empty list.
        """
        self._passed = passed
        # A literal ``[]`` default is created once and shared by every
        # instance built without ``errors``; allocate per instance instead.
        self._errors = [] if errors is None else errors

    @property
    def passed(self) -> bool:
        """Whether the validation succeeded (read-only)."""
        return self._passed

    def errors(self) -> List[Dict]:
        """Return the list of recorded validation errors."""
        return self._errors

    def add_error(self, error_message: str):
        """
        Add an error message to the validation results.

        Recording an error also marks the result as failed.
        """
        # ``passed`` is a getter-only property, so the backing field has to
        # be written directly; assigning to ``self.passed`` would raise
        # AttributeError.
        self._passed = False
        self._errors.append(error_message)

    def __bool__(self) -> bool:
        """
        Define the truthiness of ValidationResults.
        """
        return self.passed


class DFValidator:
    """
    Validate the rows of a pandas or polars dataframe against a Pydantic schema.
    """

    def __init__(self, df, verbose=False):
        """
        Args:
            df: Dataframe to validate.
            verbose: When True, print the Pydantic validation error.
        """
        self._df = df
        self._verbose = verbose

    def _validate_batch(self, schema, df_json: List[Dict]):
        """
        Validate a list of row dicts against the schema.

        Args:
            schema: Pydantic schema
            df_json: dataframe rows, one dict per row

        Returns:
            An empty list when validation succeeds, otherwise the value of
            ``ValidationError.errors()``.
        """

        # Create a Pydantic Validator to validate rows of dataframe
        class PdVal(BaseModel):
            df: List[schema]

        try:
            PdVal(df=df_json)
        except ValidationError as e:
            if self._verbose:
                print(e)
            return e.errors()
        return []

    def _df_to_list_of_dict(self, df: DataFrameType, dataframe_type: str) -> List[Dict]:
        """
        Create list of dict of dataframe rows on basis of dataframe type
        Supports only polars and pandas dataframe

        Any other ``dataframe_type`` yields an empty list.
        """
        if dataframe_type == "pandas":
            return df.to_dict(orient="records")
        if dataframe_type == "polars":
            return df.to_dicts()
        # The original fell through with a bare ``[]`` expression and so
        # returned None; return the empty list the signature promises.
        return []

    def validate(self, schema: Type[BaseModel]) -> DFValidationResult:
        """
        Args:
            schema: Pydantic schema to be validated for the dataframe row

        Returns:
            A DFValidationResult that passed when no errors were reported,
            carrying the reported errors otherwise.

        Raises:
            ValueError: If ``df_type`` does not recognise the dataframe.
        """
        dataframe_type = df_type(self._df)
        if dataframe_type is None:
            raise ValueError("UnSupported DataFrame")

        df_json: List[Dict] = self._df_to_list_of_dict(self._df, dataframe_type)
        errors = self._validate_batch(schema, df_json)

        # Always hand over an explicit list so the result never depends on
        # the constructor's default for ``errors``.
        return DFValidationResult(not errors, errors)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method can be simplified to a single return statement: instead of branching on the length of errors, compute passed from whether errors is empty and pass both values to the DFValidationResult constructor.

-        if len(errors) > 0:
-            return DFValidationResult(False, errors)
-        else:
-            return DFValidationResult(True)
+        return DFValidationResult(len(errors) == 0, errors)

14 changes: 14 additions & 0 deletions pandasai/smart_dataframe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
from io import StringIO

import pandas as pd
import pydantic

from pandasai.helpers.df_validator import DFValidator

from ..smart_datalake import SmartDatalake
from ..schemas.df_config import Config
Expand Down Expand Up @@ -235,6 +238,17 @@ def _get_head_csv(self):
self._sample_head = df_head.to_csv(index=False)
return self._sample_head

def validate(self, schema: pydantic.BaseModel, verbose: bool = False):
    """
    Validates Dataframe rows on the basis of the Pydantic schema input

    Args:
        schema: Pydantic schema class
        verbose: Print Errors

    Returns:
        The result object returned by ``DFValidator.validate``.
    """
    df_validator = DFValidator(self.original_import, verbose=verbose)
    return df_validator.validate(schema)

@property
def datalake(self):
    """Return the object held in ``self._dl``."""
    return self._dl
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The validate method has been added to the SmartDataframe class. This method uses the DFValidator class to validate the DataFrame rows against a given Pydantic schema. The method signature looks correct, and the use of the DFValidator class seems appropriate. However, the docstring mentions an argument n_jobs which is not present in the function definition. Please update the docstring to match the function signature.

-            n_jobs: Parallelism for larger dataframe

Expand Down
99 changes: 99 additions & 0 deletions tests/test_smartdataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import pandas as pd
import polars as pl
from pydantic import BaseModel, Field
import pytest

from pandasai import SmartDataframe
Expand Down Expand Up @@ -658,3 +659,101 @@ def test_save_pandas_no_name(self, llm):
# Recover file for next test case
with open("pandasai.json", "w") as json_file:
json_file.write(backup_pandasai)

def test_pydantic_validate(self, llm):
    """An all-integer pandas frame satisfies an all-integer schema."""

    # Row schema both columns must satisfy
    class TestSchema(BaseModel):
        A: int
        B: int

    frame = pd.DataFrame({"A": [1, 2, 3, 4], "B": [5, 6, 7, 8]})
    settings = {"llm": llm, "enable_cache": False}
    smart_df = SmartDataframe(frame, description="Name", config=settings)

    outcome = smart_df.validate(TestSchema)

    assert outcome.passed is True

def test_pydantic_validate_false(self, llm):
    """A string column validated as ``int`` makes the validation fail."""

    # Row schema: column A is declared int although the frame holds text
    class TestSchema(BaseModel):
        A: int
        B: int

    frame = pd.DataFrame({"A": ["Test", "Test2", "Test3", "Test4"], "B": [5, 6, 7, 8]})
    settings = {"llm": llm, "enable_cache": False}
    smart_df = SmartDataframe(frame, description="Name", config=settings)

    outcome = smart_df.validate(TestSchema)

    assert outcome.passed is False

def test_pydantic_validate_polars(self, llm):
    """An all-integer polars frame satisfies an all-integer schema."""

    # Row schema both columns must satisfy
    class TestSchema(BaseModel):
        A: int
        B: int

    frame = pl.DataFrame({"A": [1, 2, 3, 4], "B": [5, 6, 7, 8]})
    settings = {"llm": llm, "enable_cache": False}
    smart_df = SmartDataframe(frame, description="Name", config=settings)

    outcome = smart_df.validate(TestSchema)

    assert outcome.passed is True

def test_pydantic_validate_false_one_record(self, llm):
    """A single non-integer cell produces exactly one validation error."""

    # Row schema: only the second row of column A violates it
    class TestSchema(BaseModel):
        A: int
        B: int

    frame = pd.DataFrame({"A": [1, "test", 3, 4], "B": [5, 6, 7, 8]})
    settings = {"llm": llm, "enable_cache": False}
    smart_df = SmartDataframe(frame, description="Name", config=settings)

    outcome = smart_df.validate(TestSchema)

    # Checked separately so a failure shows which expectation broke
    assert outcome.passed is False
    assert len(outcome.errors()) == 1

def test_pydantic_validate_complex_Schema(self, llm):
    """Field constraints (``gt`` / ``lt``) are enforced on every row."""
    frame = pd.DataFrame({"A": [1, 2, 3, 4], "B": [5, 6, 7, 8]})
    settings = {"llm": llm, "enable_cache": False}
    smart_df = SmartDataframe(frame, description="Name", config=settings)

    # No value of A is greater than 5, so this schema must be rejected
    class TestSchema(BaseModel):
        A: int = Field(..., gt=5)
        B: int

    rejected = smart_df.validate(TestSchema)

    assert rejected.passed is False

    # Every value of A is less than 5, so this schema must be accepted
    class TestSchema(BaseModel):
        A: int = Field(..., lt=5)
        B: int

    accepted = smart_df.validate(TestSchema)

    assert accepted.passed is True