Add support for dropping invalid rows for pyspark backend #1639

Open
wants to merge 6 commits into base: main
152 changes: 107 additions & 45 deletions pandera/backends/pyspark/builtin_checks.py
@@ -1,16 +1,21 @@
"""PySpark implementation of built-in checks"""

import re
from typing import Any, Iterable, TypeVar

from typing import Any, Iterable, TypeVar, Union
import pyspark.sql as ps
import pyspark.sql.types as pst
from pyspark.sql.functions import col
from pyspark.sql.functions import col, when

import pandera.strategies as st
from pandera.api.extensions import register_builtin_check
from pandera.api.pyspark.types import PysparkDataframeColumnObject
from pandera.backends.pyspark.decorators import register_input_datatypes
from pandera.backends.pyspark.utils import convert_to_list
from pandera.backends.pyspark.decorators import (
builtin_check_validation_mode,
register_input_datatypes,
)
from pandera.backends.pyspark.utils import (
convert_to_list,
)

T = TypeVar("T")
ALL_NUMERIC_TYPE = [
@@ -37,14 +42,21 @@
ALL_NUMERIC_TYPE, ALL_DATE_TYPE, STRING_TYPE, BINARY_TYPE, BOLEAN_TYPE
)
)
def equal_to(data: PysparkDataframeColumnObject, value: Any) -> bool:
@builtin_check_validation_mode()
def equal_to(
data: PysparkDataframeColumnObject,
value: Any,
should_validate_full_table: bool,
Collaborator:

Instead of passing this in as an argument, you can use pandera.config.get_config_context to get the full_table_validation configuration value. This is so that the API for each check is consistent across the different backends.

Author:

Thanks @cosmicBboy. Will make the recommended change.

Also, is there a way you can suggest to keep the PANDERA_FULL_TABLE_VALIDATION config value False when the backend is pyspark and True when the backend is pandas? I did not find a good way to do this, hence asking for a suggestion 😅.

Collaborator (cosmicBboy, May 18, 2024):

You can use the config_context context manager in the validate methods for each backend to control this behavior: https://github.com/unionai-oss/pandera/blob/main/pandera/config.py#L71

For example, this is used in the polars backend:

with config_context(validation_depth=get_validation_depth(check_obj)):
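A minimal sketch of that suggestion applied to the pyspark backend, assuming config_context accepts a full_table_validation keyword (as the config changes in this PR imply) and that _validate is a hypothetical internal helper:

from pandera.config import config_context, get_config_context


def validate(self, check_obj, schema, **kwargs):
    # Respect an explicit user setting; otherwise default pyspark to fail-fast mode.
    configured = get_config_context().full_table_validation
    with config_context(
        full_table_validation=False if configured is None else configured
    ):
        return self._validate(check_obj, schema, **kwargs)  # hypothetical helper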

) -> Union[bool, ps.Column]:
"""Ensure all elements of a data container equal a certain value.

:param data: PysparkDataframeColumnObject column object which contains the dataframe and the column name to run the check on
:param value: values in this DataFrame data structure must be
equal to this value.
"""
cond = col(data.column_name) == value
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return data.dataframe.filter(~cond).limit(1).count() == 0


@@ -58,14 +70,21 @@ def equal_to(data: PysparkDataframeColumnObject, value: Any) -> bool:
ALL_NUMERIC_TYPE, ALL_DATE_TYPE, STRING_TYPE, BINARY_TYPE, BOLEAN_TYPE
)
)
def not_equal_to(data: PysparkDataframeColumnObject, value: Any) -> bool:
@builtin_check_validation_mode()
def not_equal_to(
data: PysparkDataframeColumnObject,
value: Any,
should_validate_full_table: bool,
) -> Union[bool, ps.Column]:
"""Ensure no elements of a data container equals a certain value.

:param data: NamedTuple PysparkDataframeColumnObject contains the dataframe and column name for the check. The keys
to access the dataframe is "dataframe" and column name using "column_name".
:param value: This value must not occur in the checked
"""
cond = col(data.column_name) != value
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return data.dataframe.filter(~cond).limit(1).count() == 0


@@ -76,7 +95,12 @@ def not_equal_to(data: PysparkDataframeColumnObject, value: Any) -> bool:
@register_input_datatypes(
acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE)
)
def greater_than(data: PysparkDataframeColumnObject, min_value: Any) -> bool:
@builtin_check_validation_mode()
def greater_than(
data: PysparkDataframeColumnObject,
min_value: Any,
should_validate_full_table: bool,
) -> Union[bool, ps.Column]:
"""
Ensure values of a data container are strictly greater than a minimum
value.
Expand All @@ -85,6 +109,8 @@ def greater_than(data: PysparkDataframeColumnObject, min_value: Any) -> bool:
:param min_value: Lower bound to be exceeded.
"""
cond = col(data.column_name) > min_value
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return data.dataframe.filter(~cond).limit(1).count() == 0


@@ -96,16 +122,21 @@ def greater_than(data: PysparkDataframeColumnObject, min_value: Any) -> bool:
@register_input_datatypes(
acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE)
)
@builtin_check_validation_mode()
def greater_than_or_equal_to(
data: PysparkDataframeColumnObject, min_value: Any
) -> bool:
data: PysparkDataframeColumnObject,
min_value: Any,
should_validate_full_table: bool,
) -> Union[bool, ps.Column]:
"""Ensure all values are greater or equal a certain value.
:param data: NamedTuple PysparkDataframeColumnObject contains the dataframe and column name for the check. The keys
to access the dataframe is "dataframe" and column name using "column_name".
:param min_value: Allowed minimum value for values of a series. Must be
a type comparable to the dtype of the column datatype of pyspark
"""
cond = col(data.column_name) >= min_value
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return data.dataframe.filter(~cond).limit(1).count() == 0


@@ -117,7 +148,12 @@ def greater_than_or_equal_to(
@register_input_datatypes(
acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE)
)
def less_than(data: PysparkDataframeColumnObject, max_value: Any) -> bool:
@builtin_check_validation_mode()
def less_than(
data: PysparkDataframeColumnObject,
max_value: Any,
should_validate_full_table: bool,
) -> Union[bool, ps.Column]:
"""Ensure values of a series are strictly below a maximum value.

:param data: NamedTuple PysparkDataframeColumnObject contains the dataframe and column name for the check. The keys
@@ -129,6 +165,8 @@ def less_than(data: PysparkDataframeColumnObject, max_value: Any) -> bool:
if max_value is None: # pragma: no cover
raise ValueError("max_value must not be None")
cond = col(data.column_name) < max_value
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return data.dataframe.filter(~cond).limit(1).count() == 0


@@ -140,9 +178,12 @@ def less_than(data: PysparkDataframeColumnObject, max_value: Any) -> bool:
@register_input_datatypes(
acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE)
)
@builtin_check_validation_mode()
def less_than_or_equal_to(
data: PysparkDataframeColumnObject, max_value: Any
) -> bool:
data: PysparkDataframeColumnObject,
max_value: Any,
should_validate_full_table: bool,
) -> Union[bool, ps.Column]:
"""Ensure values of a series are strictly below a maximum value.

:param data: NamedTuple PysparkDataframeColumnObject contains the dataframe and column name for the check. The keys
@@ -154,6 +195,8 @@ def less_than_or_equal_to(
if max_value is None: # pragma: no cover
raise ValueError("max_value must not be None")
cond = col(data.column_name) <= max_value
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return data.dataframe.filter(~cond).limit(1).count() == 0


@@ -165,13 +208,15 @@ def less_than_or_equal_to(
@register_input_datatypes(
acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE)
)
@builtin_check_validation_mode()
def in_range(
data: PysparkDataframeColumnObject,
min_value: T,
max_value: T,
include_min: bool = True,
include_max: bool = True,
):
should_validate_full_table: bool = False,
) -> Union[bool, ps.Column]:
"""Ensure all values of a column are within an interval.

Both endpoints must be a type comparable to the dtype of the
@@ -201,7 +246,10 @@ def in_range(
if include_max
else col(data.column_name) < max_value
)
return data.dataframe.filter(~(cond_right & cond_left)).limit(1).count() == 0 # type: ignore
cond = cond_right & cond_left
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return data.dataframe.filter(~cond).limit(1).count() == 0 # type: ignore


@register_builtin_check(
@@ -213,7 +261,12 @@ def in_range(
ALL_NUMERIC_TYPE, ALL_DATE_TYPE, STRING_TYPE, BINARY_TYPE
)
)
def isin(data: PysparkDataframeColumnObject, allowed_values: Iterable) -> bool:
@builtin_check_validation_mode()
def isin(
data: PysparkDataframeColumnObject,
allowed_values: Iterable,
should_validate_full_table: bool,
) -> Union[bool, ps.Column]:
"""Ensure only allowed values occur within a series.

Remember that this can be a compute-intensive check on a large dataset, so use it with caution.
@@ -229,14 +282,10 @@ def isin(data: PysparkDataframeColumnObject, allowed_values: Iterable) -> bool:
to access the dataframe is "dataframe" and column name using "column_name".
:param allowed_values: The set of allowed values. May be any iterable.
"""
return (
data.dataframe.filter(
~col(data.column_name).isin(list(allowed_values))
)
.limit(1)
.count()
== 0
)
cond = col(data.column_name).isin(list(allowed_values))
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return data.dataframe.filter(~cond).limit(1).count() == 0


@register_builtin_check(
@@ -248,9 +297,12 @@ def isin(data: PysparkDataframeColumnObject, allowed_values: Iterable) -> bool:
ALL_NUMERIC_TYPE, ALL_DATE_TYPE, STRING_TYPE, BINARY_TYPE
)
)
@builtin_check_validation_mode()
def notin(
data: PysparkDataframeColumnObject, forbidden_values: Iterable
) -> bool:
data: PysparkDataframeColumnObject,
forbidden_values: Iterable,
should_validate_full_table: bool,
) -> Union[bool, ps.Column]:
"""Ensure some defined values don't occur within a series.

Remember that this can be a compute-intensive check on a large dataset, so use it with caution.
@@ -265,24 +317,23 @@ def notin(
:param forbidden_values: The set of values which should not occur. May
be any iterable.
"""
return (
data.dataframe.filter(
col(data.column_name).isin(list(forbidden_values))
)
.limit(1)
.count()
== 0
)
cond = col(data.column_name).isin(list(forbidden_values))
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return data.dataframe.filter(cond).limit(1).count() == 0


@register_builtin_check(
strategy=st.str_contains_strategy,
error="str_contains('{pattern}')",
)
@register_input_datatypes(acceptable_datatypes=convert_to_list(STRING_TYPE))
@builtin_check_validation_mode()
def str_contains(
data: PysparkDataframeColumnObject, pattern: re.Pattern
) -> bool:
data: PysparkDataframeColumnObject,
pattern: re.Pattern,
should_validate_full_table: bool,
) -> Union[bool, ps.Column]:
"""Ensure that a pattern can be found within each row.

Remember that this can be a compute-intensive check on a large dataset, so use it with caution.
@@ -291,20 +342,22 @@ def str_contains(
to access the dataframe is "dataframe" and column name using "column_name".
:param pattern: Regular expression pattern to use for searching
"""

return (
data.dataframe.filter(~col(data.column_name).rlike(pattern.pattern))
.limit(1)
.count()
== 0
)
cond = col(data.column_name).rlike(pattern.pattern)
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return data.dataframe.filter(~cond).limit(1).count() == 0


@register_builtin_check(
error="str_startswith('{string}')",
)
@register_input_datatypes(acceptable_datatypes=convert_to_list(STRING_TYPE))
def str_startswith(data: PysparkDataframeColumnObject, string: str) -> bool:
@builtin_check_validation_mode()
def str_startswith(
data: PysparkDataframeColumnObject,
string: str,
should_validate_full_table: bool,
) -> Union[bool, ps.Column]:
"""Ensure that all values start with a certain string.

Remember that this can be a compute-intensive check on a large dataset, so use it with caution.
@@ -314,14 +367,21 @@ def str_startswith(data: PysparkDataframeColumnObject, string: str) -> bool:
:param string: String all values should start with
"""
cond = col(data.column_name).startswith(string)
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return data.dataframe.filter(~cond).limit(1).count() == 0


@register_builtin_check(
strategy=st.str_endswith_strategy, error="str_endswith('{string}')"
)
@register_input_datatypes(acceptable_datatypes=convert_to_list(STRING_TYPE))
def str_endswith(data: PysparkDataframeColumnObject, string: str) -> bool:
@builtin_check_validation_mode()
def str_endswith(
data: PysparkDataframeColumnObject,
string: str,
should_validate_full_table: bool,
) -> Union[bool, ps.Column]:
"""Ensure that all values end with a certain string.

Remember that this can be a compute-intensive check on a large dataset, so use it with caution.
@@ -331,4 +391,6 @@ def str_endswith(data: PysparkDataframeColumnObject, string: str) -> bool:
:param string: String all values should end with
"""
cond = col(data.column_name).endswith(string)
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return data.dataframe.filter(~cond).limit(1).count() == 0
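To illustrate the pattern shared by all of the checks above, here is a standalone sketch of the two return modes using plain pyspark (no pandera internals); the DataFrame and column names are made up for illustration:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1,), (2,), (5,)], ["price"])
cond = col("price") > 1

# Fail-fast mode (the default): a single boolean that is True only if no row fails.
passed = df.filter(~cond).limit(1).count() == 0  # False here, because of the row with 1

# Full-table mode: one boolean per row, which is what lets the backend drop invalid rows.
row_mask = df.select(when(cond, True).otherwise(False).alias("price_check"))
# row_mask contains one row per input row: false, true, true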
22 changes: 22 additions & 0 deletions pandera/backends/pyspark/decorators.py
@@ -9,6 +9,7 @@
from pyspark.sql import DataFrame

from pandera.api.pyspark.types import PysparkDefaultTypes
from pandera.backends.pyspark.utils import get_full_table_validation
from pandera.config import ValidationDepth, get_config_context
from pandera.errors import SchemaError
from pandera.validation_depth import ValidationScope
@@ -192,3 +193,24 @@ def cached_check_obj():
return wrapper

return _wrapper


def builtin_check_validation_mode():
"""
Evaluates whether the full table validation is enabled or not for a builtin check and passes it to the function.
"""

def _wrapper(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# Skip if not enabled
should_validate_full_table = get_full_table_validation()
return func(
*args,
**kwargs,
should_validate_full_table=should_validate_full_table,
)

return wrapper

return _wrapper
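A brief sketch of how a built-in check uses this decorator: the wrapper injects should_validate_full_table from the config, so callers never pass it themselves (the positive check below is illustrative, not part of pandera):

from pyspark.sql.functions import col, when


@builtin_check_validation_mode()
def positive(data, should_validate_full_table: bool):
    cond = col(data.column_name) > 0
    if should_validate_full_table:
        # Per-row boolean mask, enabling invalid rows to be dropped later.
        return data.dataframe.select(when(cond, True).otherwise(False))
    # Fail-fast boolean: stop at the first violating row.
    return data.dataframe.filter(~cond).limit(1).count() == 0

# Called with only the data object, e.g.:
# positive(PysparkDataframeColumnObject(dataframe=df, column_name="amount"))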
13 changes: 13 additions & 0 deletions pandera/backends/pyspark/utils.py
@@ -1,5 +1,7 @@
"""pyspark backend utilities."""

from pandera.config import get_config_context


def convert_to_list(*args):
"""Converts arguments to a list"""
@@ -11,3 +13,14 @@ def convert_to_list(*args):
converted_list.append(arg)

return converted_list


def get_full_table_validation():
"""
Get the full table validation configuration.
- By default, full table validation is disabled for pyspark dataframes for performance reasons.
"""
config = get_config_context()
if config.full_table_validation is not None:
return config.full_table_validation
return False
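A quick usage sketch of this helper, assuming config_context exposes a full_table_validation keyword as the rest of this PR implies:

from pandera.config import config_context

print(get_full_table_validation())  # False unless PANDERA_FULL_TABLE_VALIDATION was set

with config_context(full_table_validation=True):
    print(get_full_table_validation())  # True: built-in checks return per-row masks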