update cache dataframe config args, fix tests (unionai-oss#1437)
This PR renames the pandera config arguments introduced in unionai-oss#1414, making the names
more generic. It also fixes tests that were broken by the config changes.

Signed-off-by: Niels Bantilan <[email protected]>
cosmicBboy authored Dec 5, 2023
1 parent 81bab7d commit 7a20c7a
Showing 5 changed files with 28 additions and 24 deletions.
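In practice the rename is a one-to-one mapping of the old pyspark-specific names to backend-agnostic ones. A minimal sketch of the new setter usage, with the old names from the diffs below shown in comments:

```python
from pandera.config import CONFIG

# Backend-agnostic option names introduced by this change:
CONFIG.cache_dataframe = True        # previously: CONFIG.pyspark_cache
CONFIG.keep_cached_dataframe = True  # previously: CONFIG.pyspark_keep_cache
```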
4 changes: 2 additions & 2 deletions pandera/backends/pyspark/decorators.py
@@ -156,7 +156,7 @@ def _wrapper(func):
         @functools.wraps(func)
         def wrapper(self, *args, **kwargs):
             # Skip if not enabled
-            if CONFIG.pyspark_cache is not True:
+            if CONFIG.cache_dataframe is not True:
                 return func(self, *args, **kwargs)
 
             check_obj: DataFrame = None
@@ -186,7 +186,7 @@ def cached_check_obj():
 
                 yield  # Execute the decorated function
 
-                if not CONFIG.pyspark_keep_cache:
+                if not CONFIG.keep_cached_dataframe:
                     # If not cached, `.unpersist()` does nothing
                     logger.debug("Unpersisting dataframe...")
                     check_obj.unpersist()
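The decorator patched above caches the DataFrame for the duration of validation and unpersists it afterwards unless `keep_cached_dataframe` is set. A simplified, standalone sketch of that pattern (not pandera's actual implementation; the `cache_for_validation` helper and its signature are hypothetical):

```python
import logging
from contextlib import contextmanager

from pyspark.sql import DataFrame

logger = logging.getLogger(__name__)


@contextmanager
def cache_for_validation(check_obj: DataFrame, keep_cached: bool):
    """Cache `check_obj` while the block runs, then release it unless
    the caller asked to keep the cached DataFrame around."""
    check_obj.cache()
    try:
        yield check_obj  # run validation while the data is cached
    finally:
        if not keep_cached:
            # `.unpersist()` is a no-op if the DataFrame was never cached
            logger.debug("Unpersisting dataframe...")
            check_obj.unpersist()
```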
8 changes: 4 additions & 4 deletions pandera/config.py
@@ -26,8 +26,8 @@ class PanderaConfig(BaseModel):
 
     validation_enabled: bool = True
     validation_depth: ValidationDepth = ValidationDepth.SCHEMA_AND_DATA
-    pyspark_cache: bool = False
-    pyspark_keep_cache: bool = False
+    cache_dataframe: bool = False
+    keep_cached_dataframe: bool = False
 
 
 # this config variable should be accessible globally
@@ -39,11 +39,11 @@ class PanderaConfig(BaseModel):
     validation_depth=os.environ.get(
         "PANDERA_VALIDATION_DEPTH", ValidationDepth.SCHEMA_AND_DATA
     ),
-    pyspark_cache=os.environ.get(
+    cache_dataframe=os.environ.get(
         "PANDERA_CACHE_DATAFRAME",
         False,
     ),
-    pyspark_keep_cache=os.environ.get(
+    keep_cached_dataframe=os.environ.get(
         "PANDERA_KEEP_CACHED_DATAFRAME",
         False,
     ),
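Since the module builds the global `CONFIG` object at import time, the environment-variable defaults only take effect if they are set before pandera is first imported. A hedged sketch, assuming pydantic's usual string-to-bool coercion of the env values and a fresh environment:

```python
import os

# Set before the first `import pandera`; pandera/config.py reads these
# when constructing the global CONFIG object, as shown above.
os.environ["PANDERA_CACHE_DATAFRAME"] = "True"
os.environ["PANDERA_KEEP_CACHED_DATAFRAME"] = "False"

from pandera.config import CONFIG

assert CONFIG.cache_dataframe is True
assert CONFIG.keep_cached_dataframe is False
```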
4 changes: 4 additions & 0 deletions tests/core/test_pandas_config.py
@@ -44,6 +44,8 @@ class TestSchema(DataFrameModel):
             price_val: int = pa.Field()
 
         expected = {
+            "cache_dataframe": False,
+            "keep_cached_dataframe": False,
             "validation_enabled": False,
             "validation_depth": ValidationDepth.SCHEMA_AND_DATA,
         }
@@ -61,6 +63,8 @@ class TestPandasSeriesConfig:
     def test_disable_validation(self, disable_validation):
         """This function validates that a none object is loaded if validation is disabled"""
         expected = {
+            "cache_dataframe": False,
+            "keep_cached_dataframe": False,
             "validation_enabled": False,
             "validation_depth": ValidationDepth.SCHEMA_AND_DATA,
         }
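With these additions, the pandas tests compare against a config dump carrying all four keys. A quick sketch of the expected round trip in a fresh environment with no PANDERA_* variables set (this assumes `ValidationDepth` is importable from `pandera.config`, as the config module above suggests):

```python
from pandera.config import CONFIG, ValidationDepth

assert CONFIG.dict() == {
    "validation_enabled": True,
    "validation_depth": ValidationDepth.SCHEMA_AND_DATA,
    "cache_dataframe": False,
    "keep_cached_dataframe": False,
}
```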
26 changes: 13 additions & 13 deletions tests/pyspark/test_pyspark_config.py
@@ -42,8 +42,8 @@ class TestSchema(DataFrameModel):
         expected = {
             "validation_enabled": False,
             "validation_depth": ValidationDepth.SCHEMA_AND_DATA,
-            "pyspark_cache": False,
-            "pyspark_keep_cache": False,
+            "cache_dataframe": False,
+            "keep_cached_dataframe": False,
         }
 
         assert CONFIG.dict() == expected
@@ -66,8 +66,8 @@ def test_schema_only(self, spark, sample_spark_schema):
         expected = {
             "validation_enabled": True,
             "validation_depth": ValidationDepth.SCHEMA_ONLY,
-            "pyspark_cache": False,
-            "pyspark_keep_cache": False,
+            "cache_dataframe": False,
+            "keep_cached_dataframe": False,
         }
         assert CONFIG.dict() == expected
 
@@ -146,8 +146,8 @@ def test_data_only(self, spark, sample_spark_schema):
         expected = {
             "validation_enabled": True,
             "validation_depth": ValidationDepth.DATA_ONLY,
-            "pyspark_cache": False,
-            "pyspark_keep_cache": False,
+            "cache_dataframe": False,
+            "keep_cached_dataframe": False,
         }
         assert CONFIG.dict() == expected
 
@@ -233,8 +233,8 @@ def test_schema_and_data(self, spark, sample_spark_schema):
         expected = {
             "validation_enabled": True,
             "validation_depth": ValidationDepth.SCHEMA_AND_DATA,
-            "pyspark_cache": False,
-            "pyspark_keep_cache": False,
+            "cache_dataframe": False,
+            "keep_cached_dataframe": False,
         }
         assert CONFIG.dict() == expected
 
@@ -339,21 +339,21 @@ class TestSchema(DataFrameModel):
     @pytest.mark.parametrize("cache_enabled", [True, False])
     @pytest.mark.parametrize("keep_cache_enabled", [True, False])
     # pylint:disable=too-many-locals
-    def test_pyspark_cache_settings(
+    def test_cache_dataframe_settings(
         self,
         cache_enabled,
         keep_cache_enabled,
     ):
         """This function validates setters and getters for cache/keep_cache options."""
         # Set expected properties in Config object
-        CONFIG.pyspark_cache = cache_enabled
-        CONFIG.pyspark_keep_cache = keep_cache_enabled
+        CONFIG.cache_dataframe = cache_enabled
+        CONFIG.keep_cached_dataframe = keep_cache_enabled
 
         # Evaluate expected Config
         expected = {
             "validation_enabled": True,
             "validation_depth": ValidationDepth.SCHEMA_AND_DATA,
-            "pyspark_cache": cache_enabled,
-            "pyspark_keep_cache": keep_cache_enabled,
+            "cache_dataframe": cache_enabled,
+            "keep_cached_dataframe": keep_cache_enabled,
         }
         assert CONFIG.dict() == expected
10 changes: 5 additions & 5 deletions tests/pyspark/test_pyspark_decorators.py
@@ -22,10 +22,10 @@ class TestPanderaDecorators:
 
     sample_data = [("Bread", 9), ("Cutter", 15)]
 
-    def test_pyspark_cache_requirements(self, spark, sample_spark_schema):
+    def test_cache_dataframe_requirements(self, spark, sample_spark_schema):
         """Validates if decorator can only be applied in a proper function."""
         # Set expected properties in Config object
-        CONFIG.pyspark_cache = True
+        CONFIG.cache_dataframe = True
         input_df = spark_df(spark, self.sample_data, sample_spark_schema)
 
         class FakeDataFrameSchemaBackend:
@@ -74,7 +74,7 @@ def func_wo_check_obj(self, message: str):
         )
 
     # pylint:disable=too-many-locals
-    def test_pyspark_cache_settings(
+    def test_cache_dataframe_settings(
         self,
         spark,
         sample_spark_schema,
@@ -86,8 +86,8 @@ def test_pyspark_cache_settings(
     ):
         """This function validates that caching/unpersisting works as expected."""
         # Set expected properties in Config object
-        CONFIG.pyspark_cache = cache_enabled
-        CONFIG.pyspark_keep_cache = keep_cache_enabled
+        CONFIG.cache_dataframe = cache_enabled
+        CONFIG.keep_cached_dataframe = keep_cache_enabled
 
         # Prepare test data
         input_df = spark_df(spark, self.sample_data, sample_spark_schema)
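An end-to-end illustration of what the decorator test above verifies: with `cache_dataframe` enabled and `keep_cached_dataframe` disabled, the input DataFrame should no longer be cached once validation returns. The schema and data here are invented for illustration, using pandera's documented pyspark API:

```python
import pyspark.sql.types as T
from pyspark.sql import SparkSession

import pandera.pyspark as pa
from pandera.config import CONFIG


class ProductSchema(pa.DataFrameModel):
    product: T.StringType() = pa.Field()
    price: T.IntegerType() = pa.Field()


spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("Bread", 9), ("Cutter", 15)], ["product", "price"])

CONFIG.cache_dataframe = True         # cache during validation
CONFIG.keep_cached_dataframe = False  # release the cache afterwards

ProductSchema.validate(df)
assert not df.is_cached  # unpersisted once validation finished
```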
