diff --git a/CHANGELOG.md b/CHANGELOG.md index 6be22892..b70a8afe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- +- Apply func non round-trippable seq ([#250](https://github.com/Lightning-AI/utilities/pull/250)) ### Changed diff --git a/src/lightning_utilities/core/apply_func.py b/src/lightning_utilities/core/apply_func.py index 5336e29f..ca8a08fa 100644 --- a/src/lightning_utilities/core/apply_func.py +++ b/src/lightning_utilities/core/apply_func.py @@ -20,6 +20,14 @@ def is_dataclass_instance(obj: object) -> bool: return dataclasses.is_dataclass(obj) and not isinstance(obj, type) +def can_roundtrip_sequence(obj: Sequence) -> bool: + """Check if sequence can be roundtripped.""" + try: + return obj == type(obj)(list(obj)) # type: ignore[call-arg] + except (TypeError, ValueError): + return False + + def apply_to_collection( data: Any, dtype: Union[type, Any, Tuple[Union[type, Any]]], @@ -118,7 +126,7 @@ def _apply_to_collection_slow( return elem_type(OrderedDict(out)) is_namedtuple_ = is_namedtuple(data) - is_sequence = isinstance(data, Sequence) and not isinstance(data, str) + is_sequence = isinstance(data, Sequence) and not isinstance(data, str) and can_roundtrip_sequence(data) if is_namedtuple_ or is_sequence: out = [] for d in data: diff --git a/tests/unittests/core/test_apply_func.py b/tests/unittests/core/test_apply_func.py index 8c7de355..a4d7e016 100644 --- a/tests/unittests/core/test_apply_func.py +++ b/tests/unittests/core/test_apply_func.py @@ -359,3 +359,13 @@ class Foo: foo = Foo(0) result = apply_to_collection(foo, int, lambda x: x + 1, allow_frozen=True) assert foo == result + + +def test_apply_to_collection_non_roundtrippable_sequence(): + class NonRoundtrippableSequence(list): + def __init__(self, x: int): + super().__init__(range(int(x))) + + val = NonRoundtrippableSequence(3) + result = apply_to_collection(val, int, lambda x: x + 1) + assert val == result