From e08475bf0c0db37ab7fc5a1138a979a712694307 Mon Sep 17 00:00:00 2001
From: Arjun Dinesh Jagdale <142811259+ArjunJagdale@users.noreply.github.com>
Date: Sun, 29 Jun 2025 13:18:58 +0530
Subject: [PATCH] fix(iterable): ensure MappedExamplesIterable supports state_dict for resume
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Fixes #7630

### Problem
When calling `.map()` on an `IterableDataset`, resuming from a checkpoint skips a large number of samples. This is because `MappedExamplesIterable` did not implement `state_dict()` or `load_state_dict()`, so checkpointing was not properly delegated to the underlying iterable.

### What This PR Does
This patch adds:
```python
def state_dict(self):
    return self.ex_iterable.state_dict()

def load_state_dict(self, state):
    self.ex_iterable.load_state_dict(state)
```
to MappedExamplesIterable, so the wrapped base iterable's state can be saved and restored as expected.

### Result
Using .map() no longer causes sample skipping after checkpoint resume.

Tested manually using the reproducer script in #7630. The resumed iterator now yields the correct samples.

Let me know if a dedicated test case is required — happy to add one!
---
 src/datasets/iterable_dataset.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index 1550b8f2db2..13cbc56d15e 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -1097,6 +1097,12 @@ def _init_state_dict(self) -> dict:
                 "type": self.__class__.__name__,
             }
         return self._state_dict
+
+    def state_dict(self) -> dict:
+        return self.ex_iterable.state_dict()
+
+    def load_state_dict(self, state: dict) -> None:
+        self.ex_iterable.load_state_dict(state)
 
     def __iter__(self):
         if self.formatting and self.formatting.is_table: