diff --git a/CHANGELOG.rst b/CHANGELOG.rst index ce85303..a7569b0 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,7 +1,8 @@ UNRELEASED ---------- -* `#132 `__: Add documentation for specifying custom data directories. +* `#132 `__: Added documentation for specifying custom data directories. +* `#177 `__: Added new ``round_digits`` to ``data_regression.check``, which when given will round all float values to the given number of dicts (recursively) before saving the data to disk. 2.5.0 (2023-08-31) ------------------ diff --git a/src/pytest_regressions/common.py b/src/pytest_regressions/common.py index 30f118d..4992f11 100644 --- a/src/pytest_regressions/common.py +++ b/src/pytest_regressions/common.py @@ -3,7 +3,11 @@ from pathlib import Path from typing import Callable from typing import List +from typing import MutableMapping +from typing import MutableSequence from typing import Optional +from typing import TypeVar +from typing import Union import pytest @@ -188,3 +192,33 @@ def make_location_message(banner: str, filename: Path, aux_files: List[str]) -> else: dump_aux_fn(Path(obtained_filename)) raise + + +T = TypeVar("T", bound=Union[MutableSequence, MutableMapping]) + + +def round_digits_in_data(data: T, digits: int) -> T: + """ + Recursively round the values of any float value in a collection to the given number of digits. The rounding is done in-place. + + :param data: + The collection to round. + + :param digits: + The number of digits to round to. + + :return: + The collection with all float values rounded to the given precision. + Note that the rounding is done in-place, so this return value only exists + because we use the function recursively. + """ + # Change the generator depending on the collection type. + generator = enumerate(data) if isinstance(data, MutableSequence) else data.items() + for k, v in generator: + if isinstance(v, (MutableSequence, MutableMapping)): + data[k] = round_digits_in_data(v, digits) + elif isinstance(v, float): + data[k] = round(v, digits) + else: + data[k] = v + return data diff --git a/src/pytest_regressions/data_regression.py b/src/pytest_regressions/data_regression.py index 70a1d09..b86298e 100644 --- a/src/pytest_regressions/data_regression.py +++ b/src/pytest_regressions/data_regression.py @@ -11,6 +11,7 @@ from .common import check_text_files from .common import perform_regression_check +from .common import round_digits_in_data class DataRegressionFixture: @@ -32,6 +33,7 @@ def check( data_dict: Dict[str, Any], basename: Optional[str] = None, fullpath: Optional["os.PathLike[str]"] = None, + round_digits: Optional[int] = None, ) -> None: """ Checks the given dict against a previously recorded version, or generate a new file. @@ -46,13 +48,18 @@ def check( will ignore ``datadir`` fixture when reading *expected* files but will still use it to write *obtained* files. Useful if a reference file is located in the session data dir for example. + :param round_digits: + If given, round all floats in the dict to the given number of digits. + ``basename`` and ``fullpath`` are exclusive. """ __tracebackhide__ = True + if round_digits is not None: + round_digits_in_data(data_dict, round_digits) + def dump(filename: Path) -> None: """Dump dict contents to the given filename""" - import yaml dumped_str = yaml.dump_all( [data_dict], diff --git a/tests/test_data_regression.py b/tests/test_data_regression.py index c36e003..1fc5dc1 100644 --- a/tests/test_data_regression.py +++ b/tests/test_data_regression.py @@ -1,6 +1,7 @@ import sys from textwrap import dedent +import pytest import yaml from pytest_regressions.testing import check_regression_fixture_workflow @@ -38,6 +39,24 @@ def dump_scalar(dumper, scalar): data_regression.check(contents) +def test_round_digits(data_regression): + """Example including float numbers and check rounding capabilities.""" + contents = { + "content": {"value1": "toto", "value": 1.123456789}, + "values": [1.12345, 2.34567], + "value": 1.23456789, + } + data_regression.check(contents, round_digits=2) + + with pytest.raises(AssertionError): + contents = { + "content": {"value1": "toto", "value": 1.2345678}, + "values": [1.13456, 2.45678], + "value": 1.23456789, + } + data_regression.check(contents, round_digits=2) + + def test_usage_workflow(pytester, monkeypatch): monkeypatch.setattr( sys, "testing_get_data", lambda: {"contents": "Foo", "value": 10}, raising=False diff --git a/tests/test_data_regression/test_round_digits.yml b/tests/test_data_regression/test_round_digits.yml new file mode 100644 index 0000000..54cc585 --- /dev/null +++ b/tests/test_data_regression/test_round_digits.yml @@ -0,0 +1,7 @@ +content: + value: 1.12 + value1: toto +value: 1.23 +values: +- 1.12 +- 2.35