
Commit 2873068: Fix strings tests

pprados authored and cosmicBboy committed Jan 27, 2023
1 parent ae2853b
Showing 7 changed files with 97 additions and 75 deletions.
12 changes: 12 additions & 0 deletions pandera/core/pandas/checks.py
@@ -300,6 +300,9 @@ def str_matches(
:param pattern: Regular expression pattern to use for matching
:param kwargs: key-word arguments passed into the `Check` initializer.
"""
if data.__module__.startswith("cudf"):
# This should be in its own backend implementation
return data.str.match(cast(str, pattern))
return data.str.match(cast(str, pattern), na=False)


@@ -317,6 +320,9 @@ def str_contains(
:param pattern: Regular expression pattern to use for searching
:param kwargs: key-word arguments passed into the `Check` initializer.
"""
if data.__module__.startswith("cudf"):
# This should be in its own backend implementation
return data.str.contains(cast(str, pattern))
return data.str.contains(cast(str, pattern), na=False)


@@ -330,6 +336,9 @@ def str_startswith(data: PandasData, string: str) -> PandasData:
:param string: String all values should start with
:param kwargs: key-word arguments passed into the `Check` initializer.
"""
if data.__module__.startswith("cudf"):
# This should be in its own backend implementation
return data.str.startswith(string)
return data.str.startswith(string, na=False)


@@ -342,6 +351,9 @@ def str_endswith(data: PandasData, string: str) -> PandasData:
:param string: String all values should end with
:param kwargs: key-word arguments passed into the `Check` initializer.
"""
if data.__module__.startswith("cudf"):
# This should be in its own backend implementation
return data.str.endswith(string)
return data.str.endswith(string, na=False)


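The cudf branches added above drop the na= keyword, which cudf's string methods (as of 22.08.00) do not accept. A minimal pandas-only sketch, using made-up data, of what na=False buys on the pandas side:

import pandas as pd

s = pd.Series(["abc", None, "abd"])
# Without na=False the missing value propagates as NaN, so the result is
# object-dtyped and not directly usable as a boolean mask.
print(s.str.match(r"ab"))            # [True, NaN, True]
# With na=False, missing values count as non-matches and the result is bool.
print(s.str.match(r"ab", na=False))  # [True, False, True]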
3 changes: 1 addition & 2 deletions pandera/errors.py
@@ -290,8 +290,7 @@ def _parse_schema_errors(schema_errors: List[Dict[str, Any]]):
]

elif any(
type(x).__module__.startswith("cudf")
for x in check_failure_cases
type(x).__module__.startswith("cudf") for x in check_failure_cases
):
# pylint: disable=import-outside-toplevel
# The current version of cudf does not support sort_values() on string columns.
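The condition collapsed onto one line above uses module-name inspection so that cudf never has to be imported where it is not installed; a minimal sketch of the pattern (the helper name is hypothetical, not part of the diff):

def _is_cudf_backed(obj) -> bool:
    # Inspect the defining module of the object's type instead of doing an
    # isinstance check, which would require importing cudf.
    return type(obj).__module__.startswith("cudf")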
11 changes: 5 additions & 6 deletions pandera/typing/cudf.py
@@ -23,15 +23,13 @@
except ImportError:
ModelField = Any # type: ignore


# pylint:disable=too-few-public-methods
class Index(IndexBase, cudf.Index, Generic[GenericDtype]):
"""Representation of pandas.Index, only used for type annotation.
*new in 0.5.0*
"""


# pylint:disable=too-few-public-methods
class Series(SeriesBase, cudf.Series, Generic[GenericDtype]): # type: ignore
"""Representation of pandas.Series, only used for type annotation.
Expand All @@ -40,21 +38,20 @@ class Series(SeriesBase, cudf.Series, Generic[GenericDtype]): # type: ignore
"""

if hasattr(pd.Series, "__class_getitem__") and _GenericAlias:

def __class_getitem__(cls, item):
"""Define this to override the patch that pyspark.pandas performs on pandas.
https://github.com/apache/spark/blob/master/python/pyspark/pandas/__init__.py#L124-L144
"""
_type_check(item, "Parameters to generic types must be types.")
return _GenericAlias(cls, item)


# pylint:disable=invalid-name
if TYPE_CHECKING:
T = TypeVar("T") # pragma: no cover
else:
T = Schema


# pylint:disable=too-few-public-methods
class DataFrame(DataFrameBase, cudf.DataFrame, Generic[T]):
"""
Expand All @@ -64,6 +61,7 @@ class DataFrame(DataFrameBase, cudf.DataFrame, Generic[T]):
"""

if hasattr(pd.DataFrame, "__class_getitem__") and _GenericAlias:

def __class_getitem__(cls, item):
"""Define this to override the patch that pyspark.pandas performs on pandas.
https://github.com/apache/spark/blob/master/python/pyspark/pandas/__init__.py#L124-L144
@@ -162,7 +160,9 @@ def _get_schema(cls, field: ModelField):
return schema_model, schema

@classmethod
def pydantic_validate(cls, obj: Any, field: ModelField) -> pd.DataFrame:
def pydantic_validate(
cls, obj: Any, field: ModelField
) -> pd.DataFrame:
"""
Verify that the input can be converted into a pandas dataframe that
meets all schema requirements.
Expand All @@ -177,7 +177,6 @@ def pydantic_validate(cls, obj: Any, field: ModelField) -> pd.DataFrame:

return cls.to_format(valid_data, schema_model.__config__)


CUDF_INSTALLED = True
except ImportError:
CUDF_INSTALLED = False
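For context, a rough sketch of how the annotation-only classes in pandera/typing/cudf.py are meant to be used, mirroring the SchemaModel test further down (the transform function is illustrative only):

import pandera as pa
from pandera.typing.cudf import DataFrame, Series

class Schema(pa.SchemaModel):
    int_field: Series[int] = pa.Field(ge=0)

@pa.check_types
def transform(df: DataFrame[Schema]) -> DataFrame[Schema]:
    # The decorator validates the input and output against Schema.
    return df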
3 changes: 3 additions & 0 deletions tests/conftest.py
@@ -2,6 +2,9 @@

import os

# pylint: disable=unused-import
from tests.core.checks_fixtures import custom_check_teardown

try:
# pylint: disable=unused-import
import hypothesis # noqa F401
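The fixture imported above is shared through conftest.py so the cudf tests can request it by name. A hypothetical sketch of what such a teardown fixture typically looks like (not the actual body of tests/core/checks_fixtures.py):

import pytest
import pandera as pa

@pytest.fixture
def custom_check_teardown():
    # Run the test first, then strip any check methods it registered on
    # pa.Check so subsequent tests start from a clean registry.
    yield
    for name in [attr for attr in vars(pa.Check) if attr.startswith("cudf_")]:
        delattr(pa.Check, name)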
9 changes: 0 additions & 9 deletions tests/cudf/conftest.py

This file was deleted.

132 changes: 75 additions & 57 deletions tests/cudf/test_schemas_on_cudf.py
@@ -10,11 +10,8 @@
import pandera as pa
from pandera import extensions
from pandera.engines import numpy_engine, pandas_engine
from pandera.typing.modin import DataFrame, Index, Series, modin_version
from pandera.typing.modin import DataFrame, Index, Series
from tests.strategies.test_strategies import NULLABLE_DTYPES
from tests.strategies.test_strategies import (
SUPPORTED_DTYPES as SUPPORTED_STRATEGY_DTYPES,
)
from tests.strategies.test_strategies import (
UNSUPPORTED_DTYPE_CLS as UNSUPPORTED_STRATEGY_DTYPE_CLS,
)
@@ -26,26 +23,10 @@
hypothesis = MagicMock()
st = MagicMock()


UNSUPPORTED_STRATEGY_DTYPE_CLS = set(UNSUPPORTED_STRATEGY_DTYPE_CLS)
UNSUPPORTED_STRATEGY_DTYPE_CLS.add(numpy_engine.Object)

TEST_DTYPES_ON_CUDF = []
# pylint: disable=redefined-outer-name
# for dtype_cls in pandas_engine.Engine.get_registered_dtypes():
# if (
# dtype_cls in UNSUPPORTED_STRATEGY_DTYPE_CLS
# or (
# pandas_engine.Engine.dtype(dtype_cls)
# not in SUPPORTED_STRATEGY_DTYPES
# )
# or not (
# pandas_engine.GEOPANDAS_INSTALLED
# and dtype_cls == pandas_engine.Geometry
# )
# ):
# continue
# TEST_DTYPES_ON_CUDF.append(pandas_engine.Engine.dtype(dtype_cls))
TEST_DTYPES_ON_CUDF: typing.List[str] = []


@pytest.mark.parametrize("coerce", [True, False])
@@ -55,15 +36,16 @@ def test_dataframe_schema_case(coerce):
{
"int_column": pa.Column(int, pa.Check.ge(0)),
"float_column": pa.Column(float, pa.Check.le(0)),
# cudf not implemented "str_column": pa.Column(str, pa.Check.isin(list("abcde"))),
# not implemented in cudf 22.08.00
# "str_column": pa.Column(str, pa.Check.isin(list("abcde"))),
},
coerce=coerce,
)
cdf = cudf.DataFrame(
{
"int_column": range(10),
"float_column": [float(-x) for x in range(10)],
# cudf not implemented "str_column": list("aabbcceedd"),
# "str_column": list("aabbcceedd"), # not implemented in cudf 22.08.00
}
)
assert isinstance(schema.validate(cdf), cudf.DataFrame)
@@ -126,12 +108,14 @@ def test_field_schema_dtypes(
int,
float,
bool,
# str,
# pandas_engine.DateTime,
# str, # not implemented in cudf 22.08.00
# pandas_engine.DateTime, # not implemented in cudf 22.08.00
],
)
@pytest.mark.parametrize("coerce", [True, False])
@pytest.mark.parametrize("schema_cls", [pa.Index])
@pytest.mark.parametrize(
"schema_cls", [pa.Index]
) # MultiIndex not implemented in cudf 22.08.00
@hypothesis.given(st.data())
def test_index_dtypes(
dtype: pandas_engine.DataType,
@@ -197,6 +181,39 @@ def test_nullable(
nonnullable_schema(ks_null_sample)


# def test_unique(): # `df.duplicated()` not implemented in cudf 22.08.00
# """Test uniqueness checks on cudf dataframes."""
# schema = pa.DataFrameSchema({"field": pa.Column(int)}, unique=["field"])
# column_schema = pa.Column(int, unique=True, name="field")
# series_schema = pa.SeriesSchema(int, unique=True, name="field")
#
# data_unique = cudf.DataFrame({"field": [1, 2, 3]})
# data_non_unique = cudf.DataFrame({"field": [1, 1, 1]})
#
# assert isinstance(schema(data_unique), cudf.DataFrame)
# assert isinstance(column_schema(data_unique), cudf.DataFrame)
# assert isinstance(series_schema(data_unique["field"]), cudf.Series)
#
# with pytest.raises(pa.errors.SchemaError, match="columns .+ not unique"):
# schema(data_non_unique)
# with pytest.raises(
# pa.errors.SchemaError, match="series .+ contains duplicate values"
# ):
# column_schema(data_non_unique)
# with pytest.raises(
# pa.errors.SchemaError, match="series .+ contains duplicate values"
# ):
# series_schema(data_non_unique["field"])
#
# schema.unique = None
# column_schema.unique = False
# series_schema.unique = False
#
# assert isinstance(schema(data_non_unique), cudf.DataFrame)
# assert isinstance(column_schema(data_non_unique), cudf.DataFrame)
# assert isinstance(series_schema(data_non_unique["field"]), cudf.Series)


def test_required_column():
"""Test the required column raises error."""
required_schema = pa.DataFrameSchema(
Expand All @@ -214,7 +231,9 @@ def test_required_column():
schema(cudf.DataFrame({"another_field": [1, 2, 3]}))


@pytest.mark.parametrize("from_dtype", [bool, float, int])
@pytest.mark.parametrize(
"from_dtype", [bool, float, int]
) # str not implemented in cudf 22.08.00
@pytest.mark.parametrize("to_dtype", [float, int, str, bool])
@hypothesis.given(st.data())
def test_dtype_coercion(from_dtype, to_dtype, data):
@@ -265,26 +284,26 @@ def test_strict_schema():
def test_custom_checks(custom_check_teardown):
"""Test that custom checks can be executed."""

# @extensions.register_check_method(statistics=["value"])
# def cudf_eq(cudf_obj, *, value): # PPR
# return cudf_obj == value
#
# custom_schema = pa.DataFrameSchema(
# {"field": pa.Column(checks=pa.Check(lambda s: s == 0, name="custom"))}
# )
#
# custom_registered_schema = pa.DataFrameSchema(
# {"field": pa.Column(checks=pa.Check.cudf_eq(0))}
# )
#
# for schema in (custom_schema, custom_registered_schema):
# schema(cudf.DataFrame({"field": [0] * 100}))
#
# try:
# schema(cudf.DataFrame({"field": [-1] * 100}))
# except pa.errors.SchemaError as err:
# assert (err.failure_cases["failure_case"] == -1).all()
pass
@extensions.register_check_method(statistics=["value"])
def cudf_eq(cudf_obj, *, value):
return cudf_obj == value

custom_schema = pa.DataFrameSchema(
{"field": pa.Column(checks=pa.Check(lambda s: s == 0, name="custom"))}
)

custom_registered_schema = pa.DataFrameSchema(
{"field": pa.Column(checks=pa.Check.cudf_eq(0))}
)

for schema in (custom_schema, custom_registered_schema):
schema(cudf.DataFrame({"field": [0] * 100}))

try:
schema(cudf.DataFrame({"field": [-1] * 100}))
except pa.errors.SchemaError as err:
assert (err.failure_cases["failure_case"] == -1).all()


def test_schema_model():
# pylint: disable=missing-class-docstring
Expand All @@ -300,14 +319,14 @@ class Schema(pa.SchemaModel):
{
"int_field": [1, 2, 3],
"float_field": [-1.1, -2.1, -3.1],
# "in_field": [1, 2, 3],
# "str_field": ["a", "b", "c"], # not implemented in cudf 22.08.00
}
)
invalid_df = cudf.DataFrame(
{
"int_field": [-1],
"field_field": [1.0],
# "in_field": [4],
# "str_field": ["d"], # not implemented in cudf 22.08.00
}
)

@@ -332,14 +351,13 @@ class Schema(pa.SchemaModel):
[pa.Check.lt(0), -1, 0],
[pa.Check.le(0), 0, 1],
[pa.Check.in_range(0, 10), 5, -1],
# FIXME: to be validated
# [pa.Check.isin(["a"]), "a", "b"],
# [pa.Check.notin(["a"]), "b", "a"],
# [pa.Check.str_matches("^a$"), "a", "b"],
# [pa.Check.str_contains("a"), "faa", "foo"],
# [pa.Check.str_startswith("a"), "ab", "ba"],
# [pa.Check.str_endswith("a"), "ba", "ab"],
# [pa.Check.str_length(1, 2), "a", ""],
# [pa.Check.isin(["a"]), "a", "b"], # Not implemented in cudf 22.08.00
# [pa.Check.notin(["a"]), "b", "a"], # Not implemented in cudf 22.08.00
[pa.Check.str_matches("^a$"), "a", "b"],
[pa.Check.str_contains("a"), "faa", "foo"],
[pa.Check.str_startswith("a"), "ab", "ba"],
[pa.Check.str_endswith("a"), "ba", "ab"],
[pa.Check.str_length(1, 2), "a", ""],
],
)
def test_check_comparison_operators(check, valid, invalid):
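The string checks re-enabled in the parametrization above exercise the cudf branches added in pandera/core/pandas/checks.py. A minimal sketch, with assumed data, of one such case:

import cudf
import pandera as pa

# "ab" starts with "a" and should validate; "ba" should raise a SchemaError.
schema = pa.SeriesSchema(checks=pa.Check.str_startswith("a"))
schema.validate(cudf.Series(["ab"]))
try:
    schema.validate(cudf.Series(["ba"]))
except pa.errors.SchemaError:
    pass  # expected failure case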
2 changes: 1 addition & 1 deletion tests/modin/test_schemas_on_modin.py
@@ -251,7 +251,7 @@ def test_required_column():
schema(mpd.DataFrame({"another_field": [1, 2, 3]}))


@pytest.mark.parametrize("from_dtype", [str])
@pytest.mark.parametrize("from_dtype", [bool, float, int, str])
@pytest.mark.parametrize("to_dtype", [float, int, str, bool])
@hypothesis.given(st.data())
def test_dtype_coercion(from_dtype, to_dtype, data):
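The widened parametrization above runs the modin coercion test from all four source dtypes instead of str only. A rough sketch, with assumed values, of the coercion the test exercises:

import modin.pandas as mpd
import pandera as pa

# With coerce=True the schema casts the series to the target dtype during
# validation instead of rejecting it.
schema = pa.SeriesSchema(float, coerce=True)
coerced = schema.validate(mpd.Series([1, 2, 3]))  # int -> float
assert str(coerced.dtype) == "float64"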
