Add support for dropping invalid rows for pyspark backend #1639

Open · wants to merge 6 commits into base: main
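The goal of the PR is to let the pyspark backend report per-row check results so that invalid rows can be identified and dropped, instead of each check returning only a single pass/fail boolean. A rough usage sketch of the intended end-user behavior is below; the `drop_invalid_rows` flag is an assumption carried over from the pandas backend and is not confirmed by this diff, which only shows the per-row check plumbing.

```python
# Hedged sketch of the intended end-user behavior. The `drop_invalid_rows`
# flag on the pyspark DataFrameSchema is assumed here, mirroring the pandas
# backend; it is not part of the hunks shown in this PR page.
import pandera.pyspark as pa
import pyspark.sql.types as T
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

schema = pa.DataFrameSchema(
    {"price": pa.Column(T.IntegerType(), pa.Check.gt(0))},
    drop_invalid_rows=True,  # assumed flag: keep only rows passing all checks
)

df = spark.createDataFrame([(5,), (-1,), (10,)], ["price"])
validated = schema.validate(df)  # the row with price == -1 would be dropped
```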
133 changes: 97 additions & 36 deletions pandera/backends/pyspark/builtin_checks.py
@@ -1,15 +1,22 @@
"""PySpark implementation of built-in checks"""

from typing import Any, Iterable, TypeVar
import re
from typing import Any, Iterable, TypeVar, Union

import pyspark.sql as ps
import pyspark.sql.types as pst
from pyspark.sql.functions import col
from pyspark.sql.functions import col, when

import pandera.strategies as st
from pandera.api.extensions import register_builtin_check
from pandera.api.pyspark.types import PysparkDataframeColumnObject
from pandera.backends.pyspark.decorators import register_input_datatypes
from pandera.backends.pyspark.utils import convert_to_list
from pandera.backends.pyspark.decorators import (
register_input_datatypes,
)
from pandera.backends.pyspark.utils import (
convert_to_list,
get_full_table_validation,
)

T = TypeVar("T")
ALL_NUMERIC_TYPE = [
@@ -22,6 +29,7 @@
pst.FloatType,
]
ALL_DATE_TYPE = [pst.DateType, pst.TimestampType]
# TODO: Fix the boolean typo in a new PR or in a different commit if that is acceptable
BOLEAN_TYPE = pst.BooleanType
BINARY_TYPE = pst.BinaryType
STRING_TYPE = pst.StringType
@@ -36,14 +44,20 @@
ALL_NUMERIC_TYPE, ALL_DATE_TYPE, STRING_TYPE, BINARY_TYPE, BOLEAN_TYPE
)
)
def equal_to(data: PysparkDataframeColumnObject, value: Any) -> bool:
def equal_to(
data: PysparkDataframeColumnObject,
value: Any,
) -> Union[bool, ps.Column]:
"""Ensure all elements of a data container equal a certain value.

:param data: PysparkDataframeColumnObject which contains the dataframe and the column name on which to run the check
:param value: values in this DataFrame data structure must be
equal to this value.
"""
cond = col(data.column_name) == value
should_validate_full_table = get_full_table_validation()
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return data.dataframe.filter(~cond).limit(1).count() == 0


@@ -57,14 +71,20 @@
ALL_NUMERIC_TYPE, ALL_DATE_TYPE, STRING_TYPE, BINARY_TYPE, BOLEAN_TYPE
)
)
def not_equal_to(data: PysparkDataframeColumnObject, value: Any) -> bool:
def not_equal_to(
data: PysparkDataframeColumnObject,
value: Any,
) -> Union[bool, ps.Column]:
"""Ensure no elements of a data container equals a certain value.

:param data: NamedTuple PysparkDataframeColumnObject containing the dataframe and column name for the check. The dataframe
is accessible via the "dataframe" key and the column name via "column_name".
:param value: This value must not occur in the checked
"""
cond = col(data.column_name) != value
should_validate_full_table = get_full_table_validation()
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return data.dataframe.filter(~cond).limit(1).count() == 0


@@ -75,7 +95,10 @@
@register_input_datatypes(
acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE)
)
def greater_than(data: PysparkDataframeColumnObject, min_value: Any) -> bool:
def greater_than(
data: PysparkDataframeColumnObject,
min_value: Any,
) -> Union[bool, ps.Column]:
"""
Ensure values of a data container are strictly greater than a minimum
value.
@@ -84,6 +107,9 @@
:param min_value: Lower bound to be exceeded.
"""
cond = col(data.column_name) > min_value
should_validate_full_table = get_full_table_validation()
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return data.dataframe.filter(~cond).limit(1).count() == 0


@@ -96,15 +122,19 @@
acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE)
)
def greater_than_or_equal_to(
data: PysparkDataframeColumnObject, min_value: Any
) -> bool:
data: PysparkDataframeColumnObject,
min_value: Any,
) -> Union[bool, ps.Column]:
"""Ensure all values are greater or equal a certain value.
:param data: NamedTuple PysparkDataframeColumnObject contains the dataframe and column name for the check. The keys
to access the dataframe is "dataframe" and column name using "column_name".
:param min_value: Allowed minimum value for values of a series. Must be
a type comparable to the dtype of the column datatype of pyspark
"""
cond = col(data.column_name) >= min_value
should_validate_full_table = get_full_table_validation()
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return data.dataframe.filter(~cond).limit(1).count() == 0


@@ -116,7 +146,10 @@
@register_input_datatypes(
acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE)
)
def less_than(data: PysparkDataframeColumnObject, max_value: Any) -> bool:
def less_than(
data: PysparkDataframeColumnObject,
max_value: Any,
) -> Union[bool, ps.Column]:
"""Ensure values of a series are strictly below a maximum value.

:param data: NamedTuple PysparkDataframeColumnObject contains the dataframe and column name for the check. The keys
@@ -128,6 +161,9 @@
if max_value is None: # pragma: no cover
raise ValueError("max_value must not be None")
cond = col(data.column_name) < max_value
should_validate_full_table = get_full_table_validation()
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return data.dataframe.filter(~cond).limit(1).count() == 0


@@ -140,8 +176,9 @@
acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE)
)
def less_than_or_equal_to(
data: PysparkDataframeColumnObject, max_value: Any
) -> bool:
data: PysparkDataframeColumnObject,
max_value: Any,
) -> Union[bool, ps.Column]:
"""Ensure values of a series are strictly below a maximum value.

:param data: NamedTuple PysparkDataframeColumnObject contains the dataframe and column name for the check. The keys
@@ -153,6 +190,9 @@
if max_value is None: # pragma: no cover
raise ValueError("max_value must not be None")
cond = col(data.column_name) <= max_value
should_validate_full_table = get_full_table_validation()
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return data.dataframe.filter(~cond).limit(1).count() == 0


@@ -170,7 +210,7 @@
max_value: T,
include_min: bool = True,
include_max: bool = True,
):
) -> Union[bool, ps.Column]:
"""Ensure all values of a column are within an interval.

Both endpoints must be a type comparable to the dtype of the
@@ -200,7 +240,11 @@
if include_max
else col(data.column_name) < max_value
)
return data.dataframe.filter(~(cond_right & cond_left)).limit(1).count() == 0 # type: ignore
cond = cond_right & cond_left
should_validate_full_table = get_full_table_validation()
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return data.dataframe.filter(~cond).limit(1).count() == 0 # type: ignore


@register_builtin_check(
@@ -212,7 +256,10 @@
ALL_NUMERIC_TYPE, ALL_DATE_TYPE, STRING_TYPE, BINARY_TYPE
)
)
def isin(data: PysparkDataframeColumnObject, allowed_values: Iterable) -> bool:
def isin(
data: PysparkDataframeColumnObject,
allowed_values: Iterable,
) -> Union[bool, ps.Column]:
"""Ensure only allowed values occur within a series.

Remember that this can be a compute-intensive check on large datasets, so use it with caution.
:param data: NamedTuple PysparkDataframeColumnObject containing the dataframe and column name for the check. The dataframe
is accessible via the "dataframe" key and the column name via "column_name".
:param allowed_values: The set of allowed values. May be any iterable.
"""
return (
data.dataframe.filter(
~col(data.column_name).isin(list(allowed_values))
)
.limit(1)
.count()
== 0
)
cond = col(data.column_name).isin(list(allowed_values))
should_validate_full_table = get_full_table_validation()
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return data.dataframe.filter(~cond).limit(1).count() == 0


@register_builtin_check(
@@ -248,8 +292,9 @@
)
)
def notin(
data: PysparkDataframeColumnObject, forbidden_values: Iterable
) -> bool:
data: PysparkDataframeColumnObject,
forbidden_values: Iterable,
) -> Union[bool, ps.Column]:
"""Ensure some defined values don't occur within a series.

Remember that this can be a compute-intensive check on large datasets, so use it with caution.
@@ -264,22 +309,22 @@
:param forbidden_values: The set of values which should not occur. May
be any iterable.
"""
return (
data.dataframe.filter(
col(data.column_name).isin(list(forbidden_values))
)
.limit(1)
.count()
== 0
)
cond = col(data.column_name).isin(list(forbidden_values))
should_validate_full_table = get_full_table_validation()
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return data.dataframe.filter(cond).limit(1).count() == 0


@register_builtin_check(
strategy=st.str_contains_strategy,
error="str_contains('{pattern}')",
)
@register_input_datatypes(acceptable_datatypes=convert_to_list(STRING_TYPE))
def str_contains(data: PysparkDataframeColumnObject, pattern: str) -> bool:
def str_contains(
data: PysparkDataframeColumnObject,
pattern: re.Pattern,
) -> Union[bool, ps.Column]:
"""Ensure that a pattern can be found within each row.

Remember that this can be a compute-intensive check on large datasets, so use it with caution.
@@ -288,6 +333,10 @@
to access the dataframe is "dataframe" and column name using "column_name".
:param pattern: Regular expression pattern to use for searching
"""
cond = col(data.column_name).rlike(pattern.pattern)
should_validate_full_table = get_full_table_validation()
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return (
data.dataframe.filter(~col(data.column_name).rlike(pattern))
.limit(1)
@@ -300,7 +349,10 @@
error="str_startswith('{string}')",
)
@register_input_datatypes(acceptable_datatypes=convert_to_list(STRING_TYPE))
def str_startswith(data: PysparkDataframeColumnObject, string: str) -> bool:
def str_startswith(
data: PysparkDataframeColumnObject,
string: str,
) -> bool:
"""Ensure that all values start with a certain string.

Remember that this can be a compute-intensive check on large datasets, so use it with caution.
@@ -310,14 +362,20 @@
:param string: String all values should start with
"""
cond = col(data.column_name).startswith(string)
should_validate_full_table = get_full_table_validation()
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return data.dataframe.filter(~cond).limit(1).count() == 0


@register_builtin_check(
strategy=st.str_endswith_strategy, error="str_endswith('{string}')"
)
@register_input_datatypes(acceptable_datatypes=convert_to_list(STRING_TYPE))
def str_endswith(data: PysparkDataframeColumnObject, string: str) -> bool:
def str_endswith(
data: PysparkDataframeColumnObject,
string: str,
) -> bool:
"""Ensure that all values end with a certain string.

Remember that this can be a compute-intensive check on large datasets, so use it with caution.
@@ -327,4 +385,7 @@
:param string: String all values should end with
"""
cond = col(data.column_name).endswith(string)
should_validate_full_table = get_full_table_validation()
if should_validate_full_table:
return data.dataframe.select(when(cond, True).otherwise(False))
return data.dataframe.filter(~cond).limit(1).count() == 0
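Every check above now follows the same two-branch pattern: when full-table validation is enabled it returns a per-row boolean column built with `when(cond, True).otherwise(False)`, and otherwise it keeps the original short-circuit behavior of looking for a single failing row. A minimal standalone sketch (not part of the PR) of what the two modes produce:

```python
# Minimal sketch, not part of this PR, showing the two validation modes the
# checks above switch between, using plain PySpark.
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1,), (2,), (-3,)], ["a"])
cond = col("a") > 0

# Full-table validation: a per-row boolean column that downstream code can
# use to identify (and eventually drop) the invalid rows.
df.select(when(cond, True).otherwise(False).alias("a_gt_0")).show()

# Default short-circuit validation: a single bool, True only if no row fails.
print(df.filter(~cond).limit(1).count() == 0)  # False, because of the -3 row
```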
22 changes: 22 additions & 0 deletions pandera/backends/pyspark/utils.py
@@ -1,5 +1,7 @@
"""pyspark backend utilities."""

from pandera.config import get_config_context, get_config_global


def convert_to_list(*args):
"""Converts arguments to a list"""
@@ -11,3 +13,23 @@
converted_list.append(arg)

return converted_list


def get_full_table_validation():
"""
Get the full table validation configuration.
- By default, full table validation is disabled for pyspark dataframes for performance reasons.
"""
config_global = get_config_global()
config_ctx = get_config_context(full_table_validation_default=None)

if config_ctx.full_table_validation is not None:
# use context configuration if specified
return config_ctx.full_table_validation

if config_global.full_table_validation is not None:
# use global configuration if specified
return config_global.full_table_validation

# full table validation is disabled by default for pyspark dataframes
return False
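For reference, a hedged sketch of how a caller could toggle this behavior through pandera's configuration; it assumes the PR also threads a `full_table_validation` option through `pandera.config.config_context`, which is not visible in this hunk.

```python
# Hedged sketch: toggling full-table validation via the config context.
# The `full_table_validation` keyword on config_context is assumed to be
# added elsewhere in this PR; only the reader side appears in this hunk.
from pandera.config import config_context
from pandera.backends.pyspark.utils import get_full_table_validation

print(get_full_table_validation())  # False: disabled by default for pyspark

with config_context(full_table_validation=True):
    # Context settings take precedence over the global config, so checks
    # executed here return per-row boolean columns instead of a single bool.
    print(get_full_table_validation())  # True
```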