-
Notifications
You must be signed in to change notification settings - Fork 39
make checkpointing thread safe #245
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -55,6 +55,7 @@ | |
|
||
from torchft._torchft import ManagerClient, ManagerServer | ||
from torchft.checkpointing import CheckpointTransport, HTTPTransport | ||
from torchft.checkpointing._rwlock import RWLock | ||
from torchft.futures import future_timeout | ||
from torchft.work import _DummyWork | ||
|
||
|
@@ -216,6 +217,9 @@ def __init__( | |
self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {} | ||
self._user_state_dicts: Dict[str, Callable[[], object]] = {} | ||
|
||
# Protects state dict | ||
self._state_dict_lock = RWLock(timeout=timeout.total_seconds()) | ||
|
||
if load_state_dict and state_dict: | ||
self.register_state_dict_fn("default", load_state_dict, state_dict) | ||
|
||
|
@@ -324,6 +328,21 @@ def __init__( | |
# first step is 1 | ||
self._participating_replica_rank: Optional[int] = None | ||
self._participating_replica_world_size: int = 0 | ||
self._is_state_dict_read_allowed = True | ||
|
||
def allow_state_dict_read(self) -> None: | ||
if self._is_state_dict_read_allowed: | ||
return | ||
|
||
self._is_state_dict_read_allowed = True | ||
self._state_dict_lock.w_release() | ||
|
||
def disallow_state_dict_read(self) -> None: | ||
if not self._is_state_dict_read_allowed: | ||
return | ||
|
||
self._is_state_dict_read_allowed = False | ||
self._state_dict_lock.w_acquire() | ||
|
||
def register_state_dict_fn( | ||
self, | ||
|
@@ -806,11 +825,14 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None: | |
self._batches_committed = state_dict["batches_committed"] | ||
|
||
def _manager_state_dict(self) -> Dict[str, object]: | ||
assert len(self._user_state_dicts) > 0, "user state_dict is not initialized." | ||
return { | ||
"user": {key: value() for key, value in self._user_state_dicts.items()}, | ||
"torchft": self.state_dict(), | ||
} | ||
with self._state_dict_lock.r_lock(): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. we already have There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
with a separate lock, we can decouple checkpoint-specific logic from training logic |
||
assert ( | ||
len(self._user_state_dicts) > 0 | ||
), "user state_dict is not initialized." | ||
return { | ||
"user": {key: value() for key, value in self._user_state_dicts.items()}, | ||
"torchft": self.state_dict(), | ||
} | ||
|
||
def state_dict(self) -> Dict[str, int]: | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add specific tests for these methods? Lock logic is pretty risk-prone, so it would be nice to have unit test coverage for these.