Bugfix/1446: Ensure Pydantic Models Can Be Created with typing.pyspark.DataFrame or typing.pyspark_sql.DataFrame Generic #1447

Open
wants to merge 8 commits into base: main
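In short, this change lets a typed pandera DataFrame be used directly as a pydantic model field. A minimal sketch of the intended usage, assuming pandera and pyspark are installed (the schema, class, and column names here are invented for illustration):

import pyspark.pandas as psp
from pydantic import BaseModel

from pandera import DataFrameModel
from pandera.typing.pyspark import DataFrame, Series


class Checkout(DataFrameModel):
    """Hypothetical schema, used only for illustration."""

    product: Series[str]
    price: Series[int]


class Container(BaseModel):
    data: DataFrame[Checkout]


# Valid input is validated on model construction; invalid input raises
# pydantic.ValidationError instead of silently passing through.
container = Container(data=psp.DataFrame({"product": ["Apples"], "price": [10]}))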
87 changes: 83 additions & 4 deletions pandera/typing/pyspark.py
@@ -1,7 +1,9 @@
-"""Pandera type annotations for Dask."""
-
-from typing import TYPE_CHECKING, Generic, TypeVar
+"""Pandera type annotations for Pyspark Pandas."""
+import functools
+from typing import TYPE_CHECKING, Generic, TypeVar, Any, get_args

+from pandera.engines import PYDANTIC_V2
+from pandera.errors import SchemaInitError, SchemaError
 from pandera.typing.common import (
     DataFrameBase,
     GenericDtype,
@@ -25,9 +27,86 @@
 T = DataFrameModel


+class _PydanticIntegrationMixIn:
+    """Mixin class for pydantic integration with pyspark DataFrames"""
+
+    @classmethod
+    def _get_schema_model(cls, field):
+        if not field.sub_fields:
+            raise TypeError(
+                "Expected a typed pandera.typing.DataFrame,"
+                " e.g. DataFrame[Schema]"
+            )
+        schema_model = field.sub_fields[0].type_
+        return schema_model
+
+    if PYDANTIC_V2:
+
+        # pylint: disable=import-outside-toplevel
+        from pydantic import GetCoreSchemaHandler
+        from pydantic_core import core_schema
+
+        @classmethod
+        def __get_pydantic_core_schema__(
+            cls, _source_type: Any, _handler: GetCoreSchemaHandler
+        ) -> core_schema.CoreSchema:
+            schema_model = get_args(_source_type)[0]
+            return core_schema.no_info_plain_validator_function(
+                functools.partial(
+                    cls.pydantic_validate,
+                    schema_model=schema_model,
+                ),
+            )
+
+    else:
+
+        @classmethod
+        def __get_validators__(cls):
+            yield cls._pydantic_validate
+
+    @classmethod
+    def pydantic_validate(cls, obj: Any, schema_model) -> ps.DataFrame:
+        """
+        Verify that the input can be converted into a pyspark dataframe
+        that meets all schema requirements.
+
+        Registered directly under pydantic >= v2; the pydantic v1 path
+        delegates here via ``_pydantic_validate``.
+        """
+        try:
+            schema = schema_model.to_schema()
+        except SchemaInitError as exc:
+            raise ValueError(
+                f"Cannot use {cls.__name__} as a pydantic type as its "
+                "DataFrameModel cannot be converted to a DataFrameSchema.\n"
+                "Please revisit the model to address the following errors:"
+                f"\n{exc}"
+            ) from exc
+
+        try:
+            valid_data = schema.validate(obj, lazy=False)
+        except SchemaError as exc:
+            raise ValueError(str(exc)) from exc
+
+        return valid_data
+
+    @classmethod
+    def _pydantic_validate(cls, obj: Any, field) -> ps.DataFrame:
+        """
+        Verify that the input can be converted into a pyspark dataframe
+        that meets all schema requirements.
+
+        This is for pydantic < v2.
+        """
+        schema_model = cls._get_schema_model(field)
+        return cls.pydantic_validate(obj, schema_model)
+
+
 if PYSPARK_INSTALLED:
     # pylint: disable=too-few-public-methods,arguments-renamed
-    class DataFrame(DataFrameBase, ps.DataFrame, Generic[T]):
+
+    class DataFrame(
+        DataFrameBase, _PydanticIntegrationMixIn, ps.DataFrame, Generic[T]
+    ):
         """
         Representation of dask.dataframe.DataFrame, only used for type
         annotation.
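The pydantic v2 branch above relies on the __get_pydantic_core_schema__ hook: pydantic asks the annotated type for a core schema, and the mixin answers with a single plain validator bound to the schema model via functools.partial. A toy sketch of the same mechanism, detached from pandera (all names here are invented):

from typing import Any

from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import core_schema


class EvenInt:
    """Toy type validated through the same hook the mixin uses."""

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        # Hand pydantic one plain validator function, just as the mixin
        # does with functools.partial(cls.pydantic_validate, ...).
        return core_schema.no_info_plain_validator_function(cls._validate)

    @staticmethod
    def _validate(value: Any) -> int:
        number = int(value)
        if number % 2:
            raise ValueError(f"{number} is not even")
        return number


class Model(BaseModel):
    n: EvenInt


Model(n=4)  # passes
# Model(n=5) would raise pydantic.ValidationError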
5 changes: 4 additions & 1 deletion pandera/typing/pyspark_sql.py
@@ -4,6 +4,7 @@

 from pandera.typing.common import DataFrameBase
 from pandera.typing.pandas import DataFrameModel, _GenericAlias
+from pandera.typing.pyspark import _PydanticIntegrationMixIn

 try:
     import pyspark.sql as ps
@@ -53,7 +54,9 @@

 if PYSPARK_SQL_INSTALLED:
     # pylint: disable=too-few-public-methods,arguments-renamed
-    class DataFrame(DataFrameBase, ps.DataFrame, Generic[T]):
+    class DataFrame(
+        DataFrameBase, _PydanticIntegrationMixIn, ps.DataFrame, Generic[T]
+    ):
         """
         Representation of dask.dataframe.DataFrame, only used for type
         annotation.
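The same mixin now also backs the pyspark.sql flavor. A hedged sketch of what that enables, mirroring the schema style used in the test fixtures below (data, names, and the Checkout model are illustrative):

import pyspark.sql.types as T
from pydantic import BaseModel
from pyspark.sql import SparkSession

from pandera.pyspark import DataFrameModel
from pandera.typing.pyspark_sql import DataFrame


class Checkout(DataFrameModel):
    """Hypothetical pyspark.sql schema, mirroring the test fixtures."""

    product: T.StringType()
    price: T.IntegerType()


class Container(BaseModel):
    data: DataFrame[Checkout]


spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("Apples", 10)], "product string, price int")
Container(data=df)  # validated against Checkout on construction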
2 changes: 1 addition & 1 deletion tests/pyspark/conftest.py
@@ -15,7 +15,7 @@ def spark() -> SparkSession:
     """
     creates spark session
     """
-    return SparkSession.builder.getOrCreate()
+    yield SparkSession.builder.appName("Pandera Pyspark Testing").getOrCreate()


 @pytest.fixture(scope="session")
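Switching the fixture from return to yield makes it a generator fixture, so teardown code can run after the yield once the tests finish. The PR stops at naming the session; a variant with explicit cleanup might look like this (the session scope is an assumption, not shown in the diff):

import pytest
from pyspark.sql import SparkSession


@pytest.fixture(scope="session")  # scope assumed for illustration
def spark() -> SparkSession:
    """Named Spark session, stopped when the test session ends."""
    session = SparkSession.builder.appName("Pandera Pyspark Testing").getOrCreate()
    yield session
    session.stop()  # teardown runs after all tests using the fixture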
156 changes: 156 additions & 0 deletions tests/pyspark/test_pyspark_pydantic_integration.py
@@ -0,0 +1,156 @@
"""Tests for integration of Pyspark DataFrames with Pydantic."""
# pylint:disable=redefined-outer-name,abstract-method

from enum import Enum

import pandas as pd
import pytest
from pydantic import BaseModel, ValidationError
from pyspark.testing.utils import assertDataFrameEqual
import pyspark.sql.types as T

from pandera.typing.pyspark_sql import DataFrame as PySparkSQLDataFrame
from pandera.typing.pyspark import DataFrame as PysparkPandasDataFrame, Series
from pandera.pyspark import DataFrameModel as PysparkSQLDataFrameModel
from pandera import DataFrameModel


class PysparkAPIs(Enum):
"""
Enum for the different Pyspark APIs.
"""

PANDAS = "pandas"
SQL = "SQL"


@pytest.fixture(
params=[PysparkAPIs.PANDAS, PysparkAPIs.SQL],
ids=["pyspark_pandas", "pyspark_sql"],
)
def pyspark_api(request):
"""
Fixture for the different Pyspark APIs.
"""
return request.param


@pytest.fixture
def sample_data_frame_model_class(pyspark_api):
"""
Fixture for a sample DataFrameModel class.
"""
if pyspark_api == PysparkAPIs.SQL:

class SampleSchema(PysparkSQLDataFrameModel):
"""
Sample schema model
"""

product: T.StringType()
price: T.IntegerType()

elif pyspark_api == PysparkAPIs.PANDAS:

class SampleSchema(DataFrameModel):
"""
Sample schema model
"""

product: Series[str]
price: Series[pd.Int32Dtype]

else:
raise ValueError(f"Unknown data frame library: {pyspark_api}")

return SampleSchema


@pytest.fixture
def pydantic_container(pyspark_api, sample_data_frame_model_class):
"""
Fixture for a Pydantic container with a DataFrameModel as a field.
"""
if pyspark_api == PysparkAPIs.PANDAS:

class PydanticContainer(BaseModel):
"""
Pydantic container with a DataFrameModel as a field.
"""

data: PysparkPandasDataFrame[sample_data_frame_model_class]

elif pyspark_api == PysparkAPIs.SQL:

class PydanticContainer(BaseModel):
"""
Pydantic container with a DataFrameModel as a field.
"""

data: PySparkSQLDataFrame[sample_data_frame_model_class]

else:
raise ValueError(f"Unknown data frame library: {pyspark_api}")

return PydanticContainer


@pytest.fixture
def correct_data(spark, sample_data, sample_spark_schema, pyspark_api):
"""
Correct data that should pass validation.
"""
df = spark.createDataFrame(sample_data, sample_spark_schema)
if pyspark_api == PysparkAPIs.PANDAS:
return df.pandas_api()
elif pyspark_api == PysparkAPIs.SQL:
return df
else:
raise ValueError(f"Unknown data frame library: {pyspark_api}")


@pytest.fixture
def incorrect_data(spark, pyspark_api):
"""
Incorrect data that should fail validation.
"""
data = [
(1, "Apples"),
(2, "Bananas"),
]
df = spark.createDataFrame(data, ["product", "price"])
if pyspark_api == PysparkAPIs.PANDAS:
return df.pandas_api()
elif pyspark_api == PysparkAPIs.SQL:
return df
else:
raise ValueError(f"Unknown data frame library: {pyspark_api}")


def test_pydantic_model_instantiates_with_correct_data(
correct_data, pydantic_container
):
"""
Test that a Pydantic model can be instantiated with a DataFrameModel when data is valid.
"""
my_container = pydantic_container(data=correct_data)
assertDataFrameEqual(my_container.data, correct_data)


def test_pydantic_model_throws_validation_error_with_incorrect_data(
incorrect_data, pydantic_container, pyspark_api
):
"""
Test that a Pydantic model throws a ValidationError when data is invalid.
"""
if pyspark_api == PysparkAPIs.PANDAS:
expected_error_substring = "expected series 'product' to have type str"
elif pyspark_api == PysparkAPIs.SQL:
expected_error_substring = (
"expected column 'product' to have type StringType()"
)
else:
raise ValueError(f"Unknown data frame library: {pyspark_api}")

with pytest.raises(ValidationError, match=expected_error_substring):
pydantic_container(data=incorrect_data)