make checkpointing thread safe #245


Merged: 1 commit merged into pytorch:main on Aug 5, 2025

Conversation

@tushar00jain (Contributor) commented Jul 26, 2025

Summary:

  • checkpointing wasn't thread safe for the HTTP transport, so lock the model in the pre-step hook and when we want to transfer the checkpoint

@d4l3k (Member) left a comment

Is there some other way we can handle this?

Two main concerns:

  • cloning has significant memory impact
  • send_checkpoint may race with the first inner optimizer step, so cloning in _to_cpu may not even be safe

I thought with the DiLoCo implementation we had a second copy of the weights to compute the pseudo gradient. Can we not reuse those for the state_dict transfer?

out.append(v)
if isinstance(v, DTensor):
    clone = distribute_tensor(
        v.to_local().clone(), v.device_mesh, v.placements
    )
@d4l3k (Member):

Is v.clone() not sufficient? What does that do? Do we also need special logic for CUDA DTensors?

@tushar00jain (Contributor, author) replied Jul 29, 2025

> Is v.clone() not sufficient?

Tried it and it didn't work; it made an empty DTensor. Took a long time to figure out how to clone DTensors.
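
For context, a minimal sketch of deep-copying a tensor while preserving DTensor sharding. This is not the PR's exact code: the diff in this PR uses distribute_tensor (see the snippet above), whereas this sketch uses DTensor.from_local to rewrap a cloned local shard with the original mesh and placements; the helper name is illustrative only.

```python
import torch
from torch.distributed.tensor import DTensor


def clone_tensor(v: torch.Tensor) -> torch.Tensor:
    """Clone a (possibly distributed) tensor so later in-place updates
    to the original cannot leak into a checkpoint that is in flight."""
    if isinstance(v, DTensor):
        # Clone only the local shard, then rewrap it with the same
        # device mesh and placements so the copy is still a DTensor.
        return DTensor.from_local(
            v.to_local().clone(),
            v.device_mesh,
            v.placements,
        )
    return v.clone()
```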

@@ -278,7 +279,13 @@ def _to_cpu(values: List[T], pin_memory: bool) -> List[T]:
            else:
                out.append(v.cpu())
        else:
            out.append(v)
            if isinstance(v, DTensor):
@d4l3k (Member):

should we rename _to_cpu to _clone_cpu?

@tushar00jain (Contributor, author) commented Jul 29, 2025

@d4l3k

> I thought with the DiLoCo implementation we had a second copy of the weights to compute the pseudo gradient. Can we not reuse those for the state_dict transfer?

The model transfer isn't even controlled by the wrapper in general.

> Is there some other way we can handle this?

Thought about blocking everything until the checkpoint transfer is complete, but that's probably also complicated since the node may never need to transfer a checkpoint.

Or use locks: the inner step locks the model, and we lock while actually transferring the state dict.
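
A minimal sketch of that idea, using a plain threading.Lock for illustration (the class and method names here are hypothetical; the actual PR uses a reader/writer lock, as the r_lock() call further down shows):

```python
import threading
from typing import Dict

import torch
from torch import nn


class LockedCheckpointer:
    """Hypothetical sketch: serialize parameter updates and checkpoint
    reads with one lock so a transfer never sees a half-updated model."""

    def __init__(self, model: nn.Module) -> None:
        self._model = model
        self._lock = threading.Lock()

    def step(self, optimizer: torch.optim.Optimizer) -> None:
        # The inner training step holds the lock while it mutates the
        # parameters (the PR does this via a step hook).
        with self._lock:
            optimizer.step()
            optimizer.zero_grad()

    def state_dict_for_transfer(self) -> Dict[str, torch.Tensor]:
        # The checkpoint transfer holds the same lock while it copies,
        # so the snapshot corresponds to some fully completed step.
        with self._lock:
            return {
                k: v.detach().clone() for k, v in self._model.state_dict().items()
            }
```

A reader/writer lock improves on this by letting multiple checkpoint readers proceed concurrently while still excluding the writer (the optimizer step).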

@tushar00jain force-pushed the pr245 branch 3 times, most recently from 0b91490 to 49a161f on July 29, 2025 03:19
@tushar00jain (Contributor, author) commented Jul 29, 2025

@d4l3k for the GPU case, we're cloning the tensor into CPU memory anyway. The thing is, we don't control at which inner step the checkpoint will be sent, so the regression test ends up being non-deterministic: it changes which model parameters are used for syncing, even if we make it thread safe.

Also, on thread safety: the cloning seems to be a blocking call in the post-step hook, so it shouldn't race with the inner optimizer step? I added some locking anyway.

@tushar00jain force-pushed the pr245 branch 3 times, most recently from 636a86a to 2b7defd on July 29, 2025 03:37
@tushar00jain changed the title from "deep copy state dict for checkpoint" to "make checkpointing thread safe" on Jul 30, 2025
@tushar00jain force-pushed the pr245 branch 6 times, most recently from cc0a37a to 76f2f60 on July 30, 2025 22:05
@tushar00jain changed the title from "make checkpointing thread safe" to "make checkpointing thread safe and deterministic" on Jul 30, 2025
@tushar00jain force-pushed the pr245 branch 2 times, most recently from 2388879 to fc1fa08 on July 30, 2025 22:52
@tushar00jain force-pushed the pr245 branch 2 times, most recently from f99b8dd to acd9ede on July 31, 2025 02:47
"user": {key: value() for key, value in self._user_state_dicts.items()},
"torchft": self.state_dict(),
}
with self._state_dict_lock.r_lock():
@d4l3k (Member):

we already have allow_checkpoint and disallow_checkpoint in HTTPTransport -- can we reuse those instead?

@tushar00jain (Contributor, author):

  • that also requires copying over the state dict to the HTTP transport
  • or keeping some tracking of which step was transferred last in the HTTP transport

With a separate lock, we can decouple checkpoint-specific logic from training logic.

@tushar00jain force-pushed the pr245 branch 4 times, most recently from 0283856 to 1f5854d on August 1, 2025 19:17
@tushar00jain changed the title from "make checkpointing thread safe and deterministic" to "make checkpointing thread safe" on Aug 1, 2025
@tushar00jain force-pushed the pr245 branch 2 times, most recently from 800c48f to 577bacd on August 1, 2025 20:02
@d4l3k (Member) left a comment

Seems reasonable to me -- just want to add a unit test specifically on the lock

@@ -324,6 +328,21 @@ def __init__(
        # first step is 1
        self._participating_replica_rank: Optional[int] = None
        self._participating_replica_world_size: int = 0
        self._is_state_dict_read_allowed = True

    def allow_state_dict_read(self) -> None:
@d4l3k (Member):

Can we add specific tests for these methods? Lock logic is pretty risk-prone, so it would be nice to have unit test coverage for these.
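
A rough sketch of the kind of test being asked for, written against a stand-in gate class rather than torchft's Manager (the stand-in and its blocking behavior are assumptions for illustration; the actual methods in the PR may behave differently):

```python
import threading
import unittest


class _StateDictGate:
    """Stand-in for the allow/disallow behavior under test."""

    def __init__(self) -> None:
        self._cond = threading.Condition()
        self._allowed = True

    def allow_state_dict_read(self) -> None:
        with self._cond:
            self._allowed = True
            self._cond.notify_all()

    def disallow_state_dict_read(self) -> None:
        with self._cond:
            self._allowed = False

    def read_state_dict(self) -> dict:
        # Block until reads are allowed, then return a dummy state dict.
        with self._cond:
            self._cond.wait_for(lambda: self._allowed)
            return {"step": 1}


class StateDictGateTest(unittest.TestCase):
    def test_read_blocks_until_allowed(self) -> None:
        gate = _StateDictGate()
        gate.disallow_state_dict_read()

        result: list = []
        reader = threading.Thread(target=lambda: result.append(gate.read_state_dict()))
        reader.start()

        # While reads are disallowed, the reader thread stays blocked.
        reader.join(timeout=0.2)
        self.assertTrue(reader.is_alive())
        self.assertEqual(result, [])

        # Once reads are allowed again, the reader completes.
        gate.allow_state_dict_read()
        reader.join(timeout=2.0)
        self.assertFalse(reader.is_alive())
        self.assertEqual(result, [{"step": 1}])


if __name__ == "__main__":
    unittest.main()
```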

@tushar00jain force-pushed the pr245 branch 3 times, most recently from 495ab9a to 595e7e9 on August 5, 2025 18:32
Summary:
- checkpointing wasn't thread safe for the HTTP transport, so lock the model in the pre-step hook and when we want to transfer the checkpoint
@tushar00jain merged commit ee2b322 into pytorch:main on Aug 5, 2025
13 of 14 checks passed
@tushar00jain deleted the pr245 branch on August 5, 2025 22:39
Labels: CLA Signed