fix multimethod bug in pyspark (#1260)
* Add tests for pyspark dataframemodel custom checks

Signed-off-by: mfkaptan <[email protected]>

* Add a second overloaded preprocess method

Making the `key` argument optional did not work because of an existing
bug in multimethod: coady/multimethod#90
Added a second overloaded `preprocess` method without the `key`
arg so the multimethod lib can dispatch correctly.
The same applies to the `apply` function.

Signed-off-by: mfkaptan <[email protected]>

* Fix isinstance call by using the type of BooleanType

Signed-off-by: mfkaptan <[email protected]>

---------

Signed-off-by: mfkaptan <[email protected]>
Co-authored-by: mfkaptan <[email protected]>
cosmicBboy and mfkaptan-motius authored Aug 11, 2023
1 parent 57d8269 commit ef4b5aa
Showing 3 changed files with 58 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pandera/api/pyspark/types.py
@@ -97,4 +97,4 @@ def is_table(obj):
 
 def is_bool(x):
     """Verifies whether an object is a boolean type."""
-    return isinstance(x, (bool, pst.BooleanType()))
+    return isinstance(x, (bool, type(pst.BooleanType())))
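Background on the one-line fix above: `isinstance` accepts only classes (or tuples of classes) as its second argument, while `pst.BooleanType()` is an instance. A minimal sketch of the difference, assuming only that pyspark is installed (illustrative, not part of the commit):

from pyspark.sql import types as pst

x = "not a bool"

# The old call raises TypeError here, because bool does not match and
# isinstance() then hits the BooleanType() *instance* in the tuple:
# isinstance(x, (bool, pst.BooleanType()))  # TypeError: arg 2 must be a type
# The fixed call recovers the class via type(), so it evaluates cleanly:
isinstance(x, (bool, type(pst.BooleanType())))  # False, as intended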
31 changes: 26 additions & 5 deletions pandera/backends/pyspark/checks.py
@@ -48,11 +48,22 @@ def _format_groupby_input(
     ) -> Dict[str, DataFrame]:  # pragma: no cover
         raise NotImplementedError
 
     @overload  # type: ignore [no-redef]
     def preprocess(
         self,
+        check_obj: DataFrame,
+        key: str,  # type: ignore [valid-type]
+    ) -> DataFrame:
+        return check_obj
+
+    # Workaround for multimethod not supporting Optional arguments
+    # such as `key: Optional[str]` (fails in multimethod)
+    # https://github.com/coady/multimethod/issues/90
+    # FIXME when the multimethod supports Optional args  # pylint: disable=fixme
+    @overload  # type: ignore [no-redef]
+    def preprocess(
+        self,
         check_obj: DataFrame,  # type: ignore [valid-type]
-        key: str,
     ) -> DataFrame:
         return check_obj

@@ -104,11 +115,21 @@ def __call__(
         check_obj: DataFrame,
         key: Optional[str] = None,
     ) -> CheckResult:
-        check_obj = self.preprocess(check_obj, key)
+        if key is None:
+            # pylint:disable=no-value-for-parameter
+            check_obj = self.preprocess(check_obj)
+        else:
+            check_obj = self.preprocess(check_obj, key)
 
         try:
-            check_output = self.apply(  # pylint:disable=too-many-function-args
-                check_obj, key, self.check._check_kwargs
-            )
+            if key is None:
+                check_output = self.apply(check_obj)
+            else:
+                check_output = (
+                    self.apply(  # pylint:disable=too-many-function-args
+                        check_obj, key, self.check._check_kwargs
+                    )
+                )
+
         except DispatchError as exc:  # pragma: no cover
             if exc.__cause__ is not None:
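The `key is None` branches above are the workaround described in the commit message: rather than one signature with `key: Optional[str] = None`, which multimethod could not dispatch on (coady/multimethod#90), the backend defines two fixed-arity overloads and selects one explicitly. A self-contained sketch of that dispatch pattern, using invented names rather than pandera's real classes:

from multimethod import multimethod

@multimethod
def preprocess(check_obj: list):
    # No-key overload: selected when called with a single argument.
    return check_obj

@multimethod
def preprocess(check_obj: list, key: str):
    # Keyed overload: selected when a string key is also passed.
    return check_obj

preprocess([1, 2])         # -> no-key overload
preprocess([1, 2], "col")  # -> keyed overload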
31 changes: 31 additions & 0 deletions tests/pyspark/test_pyspark_model.py
@@ -7,6 +7,7 @@
 import pytest
 
 import pandera
+import pandera.api.extensions as pax
 import pandera.pyspark as pa
 from pandera.config import PanderaConfig, ValidationDepth
 from pandera.pyspark import DataFrameModel, DataFrameSchema, Field
@@ -324,3 +325,33 @@ class Schema(DataFrameModel):  # pylint:disable=missing-class-docstring
         match="'a' can only be assigned a 'Field'",
     ):
         Schema.to_schema()
+
+
+def test_registered_dataframemodel_checks(spark) -> None:
+    """Check that custom registered checks work"""
+
+    @pax.register_check_method(
+        supported_types=DataFrame,
+    )
+    def always_true_check(df: DataFrame):
+        # pylint: disable=unused-argument
+        return True
+
+    class ExampleDFModel(
+        DataFrameModel
+    ):  # pylint:disable=missing-class-docstring
+        name: str
+        age: int
+
+        class Config:
+            coerce = True
+            always_true_check = ()
+
+    example_data_cols = ("name", "age")
+    example_data = [("foo", 42), ("bar", 24)]
+
+    df = spark.createDataFrame(example_data, example_data_cols)
+
+    out = ExampleDFModel.validate(df, lazy=False)
+
+    assert not out.pandera.errors
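The test registers a custom check with `pax.register_check_method` and enables it by naming it in the model's `Config`; the empty tuple appears to supply no arguments to the check. For comparison, a hypothetical object-based equivalent, assuming the registered check is also exposed on the `Check` namespace as with pandera's pandas extensions:

import pandera.pyspark as pa

# Hypothetical sketch: the always_true_check name comes from the test above,
# and its availability on pa.Check is an assumption, not shown in this commit.
schema = pa.DataFrameSchema(
    columns={"name": pa.Column(str), "age": pa.Column(int)},
    checks=pa.Check.always_true_check(),
    coerce=True,
)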
