diff --git a/pandera/typing/pyspark.py b/pandera/typing/pyspark.py
index 5f0934cec..9c6cba103 100644
--- a/pandera/typing/pyspark.py
+++ b/pandera/typing/pyspark.py
@@ -1,7 +1,9 @@
-"""Pandera type annotations for Dask."""
-
-from typing import TYPE_CHECKING, Generic, TypeVar
+"""Pandera type annotations for Pyspark Pandas."""
+import functools
+from typing import TYPE_CHECKING, Generic, TypeVar, Any, get_args
+from pandera.engines import PYDANTIC_V2
+from pandera.errors import SchemaInitError, SchemaError
 
 from pandera.typing.common import (
     DataFrameBase,
     GenericDtype,
@@ -25,9 +27,86 @@
 T = DataFrameModel
 
 
+class _PydanticIntegrationMixIn:
+    """Mixin class for pydantic integration with pyspark DataFrames"""
+
+    @classmethod
+    def _get_schema_model(cls, field):
+        if not field.sub_fields:
+            raise TypeError(
+                "Expected a typed pandera.typing.DataFrame,"
+                " e.g. DataFrame[Schema]"
+            )
+        schema_model = field.sub_fields[0].type_
+        return schema_model
+
+    if PYDANTIC_V2:
+
+        # pylint: disable=import-outside-toplevel
+        from pydantic import GetCoreSchemaHandler
+        from pydantic_core import core_schema
+
+        @classmethod
+        def __get_pydantic_core_schema__(
+            cls, _source_type: Any, _handler: GetCoreSchemaHandler
+        ) -> core_schema.CoreSchema:
+            schema_model = get_args(_source_type)[0]
+            return core_schema.no_info_plain_validator_function(
+                functools.partial(
+                    cls.pydantic_validate,
+                    schema_model=schema_model,
+                ),
+            )
+
+    else:
+
+        @classmethod
+        def __get_validators__(cls):
+            yield cls._pydantic_validate
+
+    @classmethod
+    def pydantic_validate(cls, obj: Any, schema_model) -> ps.DataFrame:
+        """
+        Verify that the input can be converted into a pyspark dataframe that
+        meets all schema requirements.
+
+        This is for pydantic >= v2
+        """
+        try:
+            schema = schema_model.to_schema()
+        except SchemaInitError as exc:
+            raise ValueError(
+                f"Cannot use {cls.__name__} as a pydantic type as its "
+                "DataFrameModel cannot be converted to a DataFrameSchema.\n"
+                f"Please revisit the model to address the following errors:"
+                f"\n{exc}"
+            ) from exc
+
+        try:
+            valid_data = schema.validate(obj, lazy=False)
+        except SchemaError as exc:
+            raise ValueError(str(exc)) from exc
+
+        return valid_data
+
+    @classmethod
+    def _pydantic_validate(cls, obj: Any, field) -> ps.DataFrame:
+        """
+        Verify that the input can be converted into a pyspark dataframe that
+        meets all schema requirements.
+
+        This is for pydantic < v2
+        """
+        schema_model = cls._get_schema_model(field)
+        return cls.pydantic_validate(obj, schema_model)
+
+
 if PYSPARK_INSTALLED:
     # pylint: disable=too-few-public-methods,arguments-renamed
-    class DataFrame(DataFrameBase, ps.DataFrame, Generic[T]):
+
+    class DataFrame(
+        DataFrameBase, _PydanticIntegrationMixIn, ps.DataFrame, Generic[T]
+    ):
         """
         Representation of dask.dataframe.DataFrame, only used for type
         annotation.
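For context, a minimal sketch of the usage this mixin enables on the pyspark.pandas side. The `Schema` and `Payload` names and columns are illustrative, not part of this change:

```python
import pyspark.pandas as ps
from pydantic import BaseModel

from pandera import DataFrameModel
from pandera.typing.pyspark import DataFrame, Series


class Schema(DataFrameModel):
    """Illustrative pandera model for a pandas-on-Spark frame."""

    name: Series[str]
    value: Series[int]


class Payload(BaseModel):
    # Pydantic discovers the validator through the mixin's
    # __get_pydantic_core_schema__ (v2) or __get_validators__ (v1) hook.
    data: DataFrame[Schema]


# A conforming frame passes through schema.validate() unchanged;
# a non-conforming one surfaces as a pydantic ValidationError.
payload = Payload(data=ps.DataFrame({"name": ["a", "b"], "value": [1, 2]}))
```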
diff --git a/pandera/typing/pyspark_sql.py b/pandera/typing/pyspark_sql.py
index 91cbcea35..42dab388d 100644
--- a/pandera/typing/pyspark_sql.py
+++ b/pandera/typing/pyspark_sql.py
@@ -4,6 +4,7 @@
 
 from pandera.typing.common import DataFrameBase
 from pandera.typing.pandas import DataFrameModel, _GenericAlias
+from pandera.typing.pyspark import _PydanticIntegrationMixIn
 
 try:
     import pyspark.sql as ps
@@ -53,7 +54,9 @@
 
 if PYSPARK_SQL_INSTALLED:
     # pylint: disable=too-few-public-methods,arguments-renamed
-    class DataFrame(DataFrameBase, ps.DataFrame, Generic[T]):
+    class DataFrame(
+        DataFrameBase, _PydanticIntegrationMixIn, ps.DataFrame, Generic[T]
+    ):
         """
         Representation of dask.dataframe.DataFrame, only used for type
         annotation.
diff --git a/tests/pyspark/conftest.py b/tests/pyspark/conftest.py
index 28f0a1e5a..43f0ba330 100644
--- a/tests/pyspark/conftest.py
+++ b/tests/pyspark/conftest.py
@@ -15,7 +15,7 @@ def spark() -> SparkSession:
     """
     creates spark session
    """
-    return SparkSession.builder.getOrCreate()
+    yield SparkSession.builder.appName("Pandera Pyspark Testing").getOrCreate()
 
 
 @pytest.fixture(scope="session")
diff --git a/tests/pyspark/test_pyspark_pydantic_integration.py b/tests/pyspark/test_pyspark_pydantic_integration.py
new file mode 100644
index 000000000..0f6f6ab08
--- /dev/null
+++ b/tests/pyspark/test_pyspark_pydantic_integration.py
@@ -0,0 +1,156 @@
+"""Tests for integration of Pyspark DataFrames with Pydantic."""
+# pylint:disable=redefined-outer-name,abstract-method
+
+from enum import Enum
+
+import pandas as pd
+import pytest
+from pydantic import BaseModel, ValidationError
+from pyspark.testing.utils import assertDataFrameEqual
+import pyspark.sql.types as T
+
+from pandera.typing.pyspark_sql import DataFrame as PySparkSQLDataFrame
+from pandera.typing.pyspark import DataFrame as PysparkPandasDataFrame, Series
+from pandera.pyspark import DataFrameModel as PysparkSQLDataFrameModel
+from pandera import DataFrameModel
+
+
+class PysparkAPIs(Enum):
+    """
+    Enum for the different Pyspark APIs.
+    """
+
+    PANDAS = "pandas"
+    SQL = "SQL"
+
+
+@pytest.fixture(
+    params=[PysparkAPIs.PANDAS, PysparkAPIs.SQL],
+    ids=["pyspark_pandas", "pyspark_sql"],
+)
+def pyspark_api(request):
+    """
+    Fixture for the different Pyspark APIs.
+    """
+    return request.param
+
+
+@pytest.fixture
+def sample_data_frame_model_class(pyspark_api):
+    """
+    Fixture for a sample DataFrameModel class.
+    """
+    if pyspark_api == PysparkAPIs.SQL:
+
+        class SampleSchema(PysparkSQLDataFrameModel):
+            """
+            Sample schema model
+            """
+
+            product: T.StringType()
+            price: T.IntegerType()
+
+    elif pyspark_api == PysparkAPIs.PANDAS:
+
+        class SampleSchema(DataFrameModel):
+            """
+            Sample schema model
+            """
+
+            product: Series[str]
+            price: Series[pd.Int32Dtype]
+
+    else:
+        raise ValueError(f"Unknown data frame library: {pyspark_api}")
+
+    return SampleSchema
+
+
+@pytest.fixture
+def pydantic_container(pyspark_api, sample_data_frame_model_class):
+    """
+    Fixture for a Pydantic container with a DataFrameModel as a field.
+    """
+    if pyspark_api == PysparkAPIs.PANDAS:
+
+        class PydanticContainer(BaseModel):
+            """
+            Pydantic container with a DataFrameModel as a field.
+            """
+
+            data: PysparkPandasDataFrame[sample_data_frame_model_class]
+
+    elif pyspark_api == PysparkAPIs.SQL:
+
+        class PydanticContainer(BaseModel):
+            """
+            Pydantic container with a DataFrameModel as a field.
+ """ + + data: PySparkSQLDataFrame[sample_data_frame_model_class] + + else: + raise ValueError(f"Unknown data frame library: {pyspark_api}") + + return PydanticContainer + + +@pytest.fixture +def correct_data(spark, sample_data, sample_spark_schema, pyspark_api): + """ + Correct data that should pass validation. + """ + df = spark.createDataFrame(sample_data, sample_spark_schema) + if pyspark_api == PysparkAPIs.PANDAS: + return df.pandas_api() + elif pyspark_api == PysparkAPIs.SQL: + return df + else: + raise ValueError(f"Unknown data frame library: {pyspark_api}") + + +@pytest.fixture +def incorrect_data(spark, pyspark_api): + """ + Incorrect data that should fail validation. + """ + data = [ + (1, "Apples"), + (2, "Bananas"), + ] + df = spark.createDataFrame(data, ["product", "price"]) + if pyspark_api == PysparkAPIs.PANDAS: + return df.pandas_api() + elif pyspark_api == PysparkAPIs.SQL: + return df + else: + raise ValueError(f"Unknown data frame library: {pyspark_api}") + + +def test_pydantic_model_instantiates_with_correct_data( + correct_data, pydantic_container +): + """ + Test that a Pydantic model can be instantiated with a DataFrameModel when data is valid. + """ + my_container = pydantic_container(data=correct_data) + assertDataFrameEqual(my_container.data, correct_data) + + +def test_pydantic_model_throws_validation_error_with_incorrect_data( + incorrect_data, pydantic_container, pyspark_api +): + """ + Test that a Pydantic model throws a ValidationError when data is invalid. + """ + if pyspark_api == PysparkAPIs.PANDAS: + expected_error_substring = "expected series 'product' to have type str" + elif pyspark_api == PysparkAPIs.SQL: + expected_error_substring = ( + "expected column 'product' to have type StringType()" + ) + else: + raise ValueError(f"Unknown data frame library: {pyspark_api}") + + with pytest.raises(ValidationError, match=expected_error_substring): + pydantic_container(data=incorrect_data)