From e9b8178d46e3915a8e21db0714281af09d38ae6e Mon Sep 17 00:00:00 2001 From: Andreas Poehlmann Date: Fri, 29 Mar 2024 17:52:03 +0100 Subject: [PATCH 1/4] tests: add test failing for non roundtrippable sequence --- tests/unittests/core/test_apply_func.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/unittests/core/test_apply_func.py b/tests/unittests/core/test_apply_func.py index 8c7de355..a4d7e016 100644 --- a/tests/unittests/core/test_apply_func.py +++ b/tests/unittests/core/test_apply_func.py @@ -359,3 +359,13 @@ class Foo: foo = Foo(0) result = apply_to_collection(foo, int, lambda x: x + 1, allow_frozen=True) assert foo == result + + +def test_apply_to_collection_non_roundtrippable_sequence(): + class NonRoundtrippableSequence(list): + def __init__(self, x: int): + super().__init__(range(int(x))) + + val = NonRoundtrippableSequence(3) + result = apply_to_collection(val, int, lambda x: x + 1) + assert val == result From 42b0ea2f8800912eaf68f25ddce4c60634d4b307 Mon Sep 17 00:00:00 2001 From: Andreas Poehlmann Date: Fri, 29 Mar 2024 17:55:02 +0100 Subject: [PATCH 2/4] lightning_utilities.core.apply_func: exclude sequences if not roundtrippable --- src/lightning_utilities/core/apply_func.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/lightning_utilities/core/apply_func.py b/src/lightning_utilities/core/apply_func.py index 5336e29f..8169d311 100644 --- a/src/lightning_utilities/core/apply_func.py +++ b/src/lightning_utilities/core/apply_func.py @@ -20,6 +20,14 @@ def is_dataclass_instance(obj: object) -> bool: return dataclasses.is_dataclass(obj) and not isinstance(obj, type) +def can_roundtrip_sequence(obj: Sequence) -> bool: + """Check if sequence can be roundtripped.""" + try: + return obj == type(obj)(list(obj)) + except (TypeError, ValueError): + return False + + def apply_to_collection( data: Any, dtype: Union[type, Any, Tuple[Union[type, Any]]], @@ -118,7 +126,7 @@ def _apply_to_collection_slow( return elem_type(OrderedDict(out)) is_namedtuple_ = is_namedtuple(data) - is_sequence = isinstance(data, Sequence) and not isinstance(data, str) + is_sequence = isinstance(data, Sequence) and not isinstance(data, str) and can_roundtrip_sequence(data) if is_namedtuple_ or is_sequence: out = [] for d in data: From 4c14e623a057e78128139f1917074180d7aa1b6e Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 29 Mar 2024 23:20:49 +0100 Subject: [PATCH 3/4] chlog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6be22892..b70a8afe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- +- Apply func non round-trippable seq ([#250](https://github.com/Lightning-AI/utilities/pull/250)) ### Changed From 2d58e0e32782c3bb201736047059ef8b1d690323 Mon Sep 17 00:00:00 2001 From: Andreas Poehlmann Date: Fri, 29 Mar 2024 23:43:42 +0100 Subject: [PATCH 4/4] lightning_utilities.core.apply_func: fix typing issue --- src/lightning_utilities/core/apply_func.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_utilities/core/apply_func.py b/src/lightning_utilities/core/apply_func.py index 8169d311..ca8a08fa 100644 --- a/src/lightning_utilities/core/apply_func.py +++ b/src/lightning_utilities/core/apply_func.py @@ -23,7 +23,7 @@ def is_dataclass_instance(obj: object) -> bool: def can_roundtrip_sequence(obj: Sequence) -> bool: """Check if sequence can be roundtripped.""" try: - return obj == type(obj)(list(obj)) + return obj == type(obj)(list(obj)) # type: ignore[call-arg] except (TypeError, ValueError): return False