
Commit 76f2f60

make checkpointing thread safe and deterministic
Summary:
- The regression tests fail (on future changes) because they expect either no recovery to happen or recovery to happen at the first step. Since the parameters are validated at every step, validation breaks if recovery happens non-deterministically; to fix this, copy the state dict before transferring it.
- Checkpointing also wasn't thread safe for the HTTP transport, so lock the model in the pre-step hook and when transferring the checkpoint.
1 parent fef4abc commit 76f2f60
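
The pattern the summary describes, sketched with a plain mutex for clarity (names here are illustrative, not torchft's API; torchft uses a reader-writer lock, shown later):

```python
import threading

# Illustrative sketch, not torchft's API: the training thread holds the lock
# across the optimizer step, and the checkpoint-transfer thread takes it while
# reading the state dict, so a transfer never observes a half-applied update.
_lock = threading.Lock()

def step_pre_hook() -> None:
    _lock.acquire()  # block checkpoint transfers while parameters are mutated

def step_post_hook() -> None:
    _lock.release()  # parameters are consistent again; transfers may proceed

def transfer_checkpoint(state_dict: dict) -> dict:
    with _lock:
        # Cloning (assuming tensor values) decouples the snapshot from later
        # in-place updates, which is what makes recovery deterministic when
        # validating parameters at every step.
        return {k: v.clone() for k, v in state_dict.items()}
```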

File tree

4 files changed: +59 -6 lines changed

torchft/checkpointing/http_transport.py

Lines changed: 11 additions & 1 deletion
```diff
@@ -16,6 +16,7 @@
 from typing import Generator, List, Optional, TypeVar, cast

 import torch
+from torch.distributed.tensor import DTensor, distribute_tensor
 from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten

 from torchft.checkpointing._rwlock import RWLock
@@ -266,6 +267,15 @@ def recv_checkpoint(
     return tree_unflatten(values, spec)


+def _clone_cpu_tensor(tensor: torch.Tensor) -> torch.Tensor:
+    if isinstance(tensor, DTensor):
+        return distribute_tensor(
+            tensor.to_local().clone(), tensor.device_mesh, tensor.placements
+        )
+    else:
+        return tensor.clone()
+
+
 def _to_cpu(values: List[T], pin_memory: bool) -> List[T]:
     out = []
     for v in values:
@@ -278,7 +288,7 @@ def _to_cpu(values: List[T], pin_memory: bool) -> List[T]:
             else:
                 out.append(v.cpu())
         else:
-            out.append(v)
+            out.append(_clone_cpu_tensor(v))
     else:
         out.append(v)
     return out
```
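
A small illustration of why the clone matters: the changed branch previously appended the CPU tensor as-is, so the "snapshot" handed to the transport aliased the live parameter and drifted as training continued, which is exactly the non-determinism the summary describes. A minimal demonstration:

```python
import torch

param = torch.zeros(4)           # a parameter that already lives on CPU
snapshot_old = param             # old behavior: appended as-is, aliases the live tensor
snapshot_new = param.clone()     # new behavior via _clone_cpu_tensor

param.add_(1.0)                  # an optimizer step mutates the parameter in place

assert torch.equal(snapshot_old, param)           # the alias drifted with training
assert torch.equal(snapshot_new, torch.zeros(4))  # the clone stayed fixed
```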

torchft/local_sgd.py

Lines changed: 22 additions & 0 deletions
```diff
@@ -85,6 +85,9 @@ def __init__(
         self._hooks: List[RemovableHandle] = []

     def __enter__(self) -> "LocalSGD":
+        self._hooks.append(
+            self._local_optimizer.register_step_pre_hook(self._step_pre_hook)
+        )
         # Add optimizer hook which increments the local step counter and syncs if necessary
         self._hooks.append(
             self._local_optimizer.register_step_post_hook(self._step_post_hook)
@@ -105,12 +108,20 @@ def __exit__(

         return False  # Propagate exceptions

+    def _step_pre_hook(
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
+    ) -> None:
+        # The checkpoint may transfer model parameters, so we need to make access to them thread safe
+        self._manager.disallow_state_dict_updates()
+
     def _step_post_hook(
         self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
     ) -> None:
         """
         This hook is registered on the optimizer and is called after the optimizer step.
         """
+        self._manager.allow_state_dict_updates()
+
         self._local_step += 1
         if self._local_step >= self._sync_every:
             self.sync()
@@ -667,12 +678,21 @@ def _restore_parameters(self) -> None:
             fragment.restore_parameters()

     def __enter__(self) -> "DiLoCo":
+        self._hooks.append(
+            self._local_optimizer.register_step_pre_hook(self._step_pre_hook)
+        )
         # Add optimizer hook which increments the local step counter and syncs if necessary
         self._hooks.append(
             self._local_optimizer.register_step_post_hook(self._step_post_hook)
         )
         return self

+    def _step_pre_hook(
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
+    ) -> None:
+        # The checkpoint may transfer model parameters, so we need to make access to them thread safe
+        self._manager.disallow_state_dict_updates()
+
     def __exit__(
         self,
         exc_type: Optional[Type[BaseException]],
@@ -707,6 +727,8 @@ def _step_post_hook(
         """
         This hook is registered on the optimizer and is called after the optimizer step.
         """
+        self._manager.allow_state_dict_updates()
+
         # We need to make sure all nodes send the same fragments in order.
         # This is to avoid deadlocking e.g.
         #
```
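
For reference, `register_step_pre_hook` and `register_step_post_hook` are standard `torch.optim.Optimizer` APIs in recent PyTorch releases; a minimal standalone check of the bracketing behavior these hooks rely on:

```python
import torch
from torch import nn, optim

model = nn.Linear(2, 2)
opt = optim.SGD(model.parameters(), lr=0.1)

events = []
# Hooks receive (optimizer, args, kwargs), matching the signatures above.
opt.register_step_pre_hook(lambda _opt, _args, _kwargs: events.append("pre"))
opt.register_step_post_hook(lambda _opt, _args, _kwargs: events.append("post"))

model(torch.randn(1, 2)).sum().backward()
opt.step()
assert events == ["pre", "post"]  # every step is bracketed by the two hooks
```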

torchft/local_sgd_integ_test.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -36,6 +36,7 @@
 )

 logger: logging.Logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)


 def local_sgd_train_loop(
@@ -143,6 +144,7 @@ def assert_equal_global_state(
             rep1[step]["user"][f"StreamingDiLoCoFragment_{i}"],
             rep0[step]["user"][f"StreamingDiLoCoFragment_{i}"],
             check_device=False,
+            msg=f"{step=} {i=}",
         )
     # Check all outer optimizers
     for i in range(
@@ -574,3 +576,9 @@ def test_streaming_diloco_commit_failure(
         self.assertEqual(
             event_injector.count[EventInjectorEvent.AllreduceFailure], 1
         )
+
+
+if __name__ == "__main__":
+    import unittest
+
+    unittest.main()
```
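
`torch.testing.assert_close` accepts `msg` (a string, or a callable that builds the failure message), so a mismatch now reports which step and fragment diverged. A minimal illustration:

```python
import torch

# When the tensors differ, the custom message is used in the raised
# AssertionError, pinpointing the offending step/fragment in the test.
torch.testing.assert_close(
    torch.ones(2), torch.ones(2), check_device=False, msg="step=0 i=0"
)
```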

torchft/manager.py

Lines changed: 18 additions & 5 deletions
```diff
@@ -43,6 +43,7 @@

 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
+from torchft.checkpointing._rwlock import RWLock
 from torchft.futures import future_timeout

 if TYPE_CHECKING:
@@ -203,6 +204,9 @@ def __init__(
         self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {}
         self._user_state_dicts: Dict[str, Callable[[], object]] = {}

+        # Protects state dict
+        self._state_dict_lock = RWLock(timeout=timeout.total_seconds())
+
         if load_state_dict and state_dict:
             self.register_state_dict_fn("default", load_state_dict, state_dict)

@@ -312,6 +316,12 @@ def __init__(
         self._participating_replica_rank: Optional[int] = None
         self._participating_replica_world_size: int = 0

+    def allow_state_dict_updates(self) -> None:
+        self._state_dict_lock.w_release()
+
+    def disallow_state_dict_updates(self) -> None:
+        self._state_dict_lock.w_acquire()
+
     def register_state_dict_fn(
         self,
         key: str,
@@ -820,11 +830,14 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None:
         self._batches_committed = state_dict["batches_committed"]

     def _manager_state_dict(self) -> Dict[str, object]:
-        assert len(self._user_state_dicts) > 0, "user state_dict is not initialized."
-        return {
-            "user": {key: value() for key, value in self._user_state_dicts.items()},
-            "torchft": self.state_dict(),
-        }
+        with self._state_dict_lock.r_lock():
+            assert (
+                len(self._user_state_dicts) > 0
+            ), "user state_dict is not initialized."
+            return {
+                "user": {key: value() for key, value in self._user_state_dicts.items()},
+                "torchft": self.state_dict(),
+            }

     def state_dict(self) -> Dict[str, int]:
         """
```
