diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 1550b8f2db2..13cbc56d15e 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -1097,6 +1097,12 @@ def _init_state_dict(self) -> dict: "type": self.__class__.__name__, } return self._state_dict + + def state_dict(self) -> dict: + return self.ex_iterable.state_dict() + + def load_state_dict(self, state: dict) -> None: + self.ex_iterable.load_state_dict(state) def __iter__(self): if self.formatting and self.formatting.is_table: