diff --git a/metasyn/testutils.py b/metasyn/testutils.py index e07fe779..4a7e2dad 100644 --- a/metasyn/testutils.py +++ b/metasyn/testutils.py @@ -46,7 +46,7 @@ def check_distribution_provider(provider_name: str): def check_distribution(distribution: type[BaseDistribution], privacy: BasePrivacy, - provenance: str): + provenance: str, test_empty: bool=True): """Check whether the distributions in the package can be validated positively. Arguments @@ -57,6 +57,9 @@ def check_distribution(distribution: type[BaseDistribution], privacy: BasePrivac Level/type of privacy the distribution adheres to. provenance: Which provider/plugin/package provides the distribution. + test_empty: + If this is set to true, this will also check empty series and if the distribution + can fit them. Otherwise, ignore testing the distribution on empty series. """ # Check the schema of the distribution. schema = distribution.schema() @@ -86,9 +89,10 @@ def check_distribution(distribution: type[BaseDistribution], privacy: BasePrivac assert isinstance(new_dist, distribution) assert set(list(new_dist.to_dict())) >= set( ("implements", "provenance", "class_name", "parameters")) - empty_series = pl.Series([], dtype=series.dtype) - new_dist = distribution.fit(empty_series, **privacy.fit_kwargs) - assert isinstance(new_dist, distribution) + if test_empty: + empty_series = pl.Series([], dtype=series.dtype) + new_dist = distribution.fit(empty_series, **privacy.fit_kwargs) + assert isinstance(new_dist, distribution)