make checkpointing thread safe #245
Conversation
Force-pushed 5d2f3e4 to dd0f5fc
Is there some other way we can handle this?
Two main concerns:
- cloning has significant memory impact
- send_checkpoint may race with the first inner optimizer step, so cloning in _to_cpu may not even be safe
I thought with the DiLoCo implementation we had a second copy of the weights to compute the pseudo gradient. Can we not reuse those for the state_dict transfer?
out.append(v)
if isinstance(v, DTensor):
    clone = distribute_tensor(
        v.to_local().clone(), v.device_mesh, v.placements
Is v.clone() not sufficient? What does that do? Do we also need special logic for cuda DTensors?
Is v.clone() not sufficient?
tried it and it didn't work: it made an empty DTensor. took a long time to figure out how to clone DTensors
@@ -278,7 +279,13 @@ def _to_cpu(values: List[T], pin_memory: bool) -> List[T]:
        else:
            out.append(v.cpu())
    else:
        out.append(v)
        if isinstance(v, DTensor):
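For reference, a minimal self-contained sketch of the DTensor handling discussed in this thread, mirroring the change in the diff above. The helper name is illustrative (not the torchft API), and the DTensor import path may differ across PyTorch versions (torch.distributed._tensor in older releases):

import torch
from torch.distributed.tensor import DTensor, distribute_tensor

def _clone_dtensor_or_tensor(v: torch.Tensor) -> torch.Tensor:
    if isinstance(v, DTensor):
        # v.clone() on a DTensor reportedly produced an empty DTensor here, so
        # clone the local shard and re-wrap it with the original mesh/placements.
        return distribute_tensor(
            v.to_local().clone(), v.device_mesh, v.placements
        )
    # regular tensors can be cloned directly
    return v.clone()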
should we rename _to_cpu to _clone_cpu?
the model transfer isn't even controlled by the wrapper in general
thought about blocking everything until checkpoint transfer is complete, but that's probably also complicated since the node may never want to transfer a checkpoint. Or use locks: the inner step locks the model and we lock while actually transferring the state dict
Force-pushed 0b91490 to 49a161f
@d4l3k for the gpu case, we're cloning the tensor into cpu memory anyway. thing is we don't control at what inner step the checkpoint will be sent, so the regression test ends up being non-deterministic: it changes what model parameters are used for syncing, even if we make it thread safe. also on thread safety, the cloning seems to be a blocking call in the post step hook, so it shouldn't race with the inner optimizer step? i added some locking anyway
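A minimal sketch of the locking scheme described in this thread, assuming a simple reader-writer lock: the inner optimizer step takes the write side while it mutates the model, and the checkpoint transfer takes the read side while it reads the state dict. The RWLock class and the inner_step / send_checkpoint helpers are illustrative, not the actual torchft code (the real change exposes _state_dict_lock.r_lock(), shown later in the diff):

import threading
from contextlib import contextmanager

class RWLock:
    # readers-preference reader/writer lock: many readers, one writer
    def __init__(self) -> None:
        self._readers = 0
        self._readers_lock = threading.Lock()  # protects the reader count
        self._writer_lock = threading.Lock()   # held while readers or a writer are active

    @contextmanager
    def r_lock(self):
        with self._readers_lock:
            self._readers += 1
            if self._readers == 1:
                self._writer_lock.acquire()  # first reader blocks out writers
        try:
            yield
        finally:
            with self._readers_lock:
                self._readers -= 1
                if self._readers == 0:
                    self._writer_lock.release()

    @contextmanager
    def w_lock(self):
        with self._writer_lock:
            yield

_state_dict_lock = RWLock()

def inner_step(optimizer) -> None:
    # the inner step locks the model while parameters are being mutated
    with _state_dict_lock.w_lock():
        optimizer.step()

def send_checkpoint(state_dict_fn):
    # the checkpoint path locks only while actually reading the state dict
    with _state_dict_lock.r_lock():
        return state_dict_fn()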
Force-pushed 636a86a to 2b7defd
Force-pushed cc0a37a to 76f2f60
Force-pushed 2388879 to fc1fa08
Force-pushed f99b8dd to acd9ede
"user": {key: value() for key, value in self._user_state_dicts.items()}, | ||
"torchft": self.state_dict(), | ||
} | ||
with self._state_dict_lock.r_lock(): |
we already have allow_checkpoint and disallow_checkpoint in HTTPTransport -- can we reuse those instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- that also requires copying over the state dict to the http transport
- or keeping some tracking of which step was transferred last to the http transport
with a separate lock, we can decouple checkpoint-specific logic from training logic
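A minimal sketch of that decoupling, based on the diff above: the manager assembles the state dict under the read lock, so the HTTP transport needs no checkpoint-gating state of its own. The method name and typing are illustrative; only the dict layout and _state_dict_lock.r_lock() come from the diff:

from typing import Any, Dict

def _manager_state_dict(self) -> Dict[str, Any]:
    # readers hold the lock only while the state dict is assembled; the
    # training loop takes the write side before mutating the model
    with self._state_dict_lock.r_lock():
        return {
            "user": {key: value() for key, value in self._user_state_dicts.items()},
            "torchft": self.state_dict(),
        }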
Force-pushed 0283856 to 1f5854d
Force-pushed 800c48f to 577bacd
Seems reasonable to me -- just want to add a unit test specifically on the lock
@@ -324,6 +328,21 @@ def __init__(
        # first step is 1
        self._participating_replica_rank: Optional[int] = None
        self._participating_replica_world_size: int = 0
        self._is_state_dict_read_allowed = True

    def allow_state_dict_read(self) -> None:
Can we add specific tests for these methods? Lock logic is pretty risk prone so would be nice to have unit test coverage for these
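A hedged sketch of what such a unit test could look like. The allow_state_dict_read method comes from the diff; disallow_state_dict_read, the _manager_state_dict callback, and the threading/timeout harness are assumptions for illustration:

import threading

def test_state_dict_read_blocks_while_disallowed(manager) -> None:
    manager.disallow_state_dict_read()

    finished = threading.Event()

    def reader() -> None:
        # assumed to block until reads are allowed again
        manager._manager_state_dict()
        finished.set()

    t = threading.Thread(target=reader, daemon=True)
    t.start()

    # the read should not complete while reads are disallowed
    assert not finished.wait(timeout=0.5)

    manager.allow_state_dict_read()
    assert finished.wait(timeout=5.0)
    t.join()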
Force-pushed 495ab9a to 595e7e9
Summary: checkpointing wasn't thread safe for the HTTP transport, so we lock the model in the pre-step hook and while transferring the checkpoint.