
make checkpointing thread safe #245


Merged
merged 1 commit on Aug 5, 2025
22 changes: 22 additions & 0 deletions torchft/local_sgd.py
@@ -86,6 +86,9 @@ def __init__(
self._hooks: List[RemovableHandle] = []

def __enter__(self) -> "LocalSGD":
self._hooks.append(
self._local_optimizer.register_step_pre_hook(self._step_pre_hook)
)
# Add optimizer hook which increments the local step counter and syncs if necessary
self._hooks.append(
self._local_optimizer.register_step_post_hook(self._step_post_hook)
@@ -106,12 +109,20 @@ def __exit__(

return False # Propagate exceptions

def _step_pre_hook(
self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
) -> None:
# The checkpoint may transfer model parameters, so we need to make access to it thread safe
self._manager.disallow_state_dict_read()

def _step_post_hook(
self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
) -> None:
"""
This hook is registered on the optimizer and is called after the optimizer step.
"""
self._manager.allow_state_dict_read()

self._local_step += 1
if self._local_step >= self._sync_every:
self.sync()
@@ -682,12 +693,21 @@ def _restore_parameters(self) -> None:
fragment.restore_parameters()

def __enter__(self) -> "DiLoCo":
self._hooks.append(
self._local_optimizer.register_step_pre_hook(self._step_pre_hook)
)
# Add optimizer hook which increments the local step counter and syncs if necessary
self._hooks.append(
self._local_optimizer.register_step_post_hook(self._step_post_hook)
)
return self

def _step_pre_hook(
self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
) -> None:
# The checkpoint may transfer model parameters, so we need to make access to it thread safe
self._manager.disallow_state_dict_read()

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
@@ -722,6 +742,8 @@ def _step_post_hook(
"""
This hook is registered on the optimizer and is called after the optimizer step.
"""
self._manager.allow_state_dict_read()

# We need to make sure all nodes send the same fragments in order.
# This is to avoid deadlocking e.g.
#
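For context, the pattern above brackets every `optimizer.step()` with a write lock: the pre-hook blocks checkpoint reads before parameters are mutated in place, and the post-hook releases them once the step is complete. Below is a minimal, self-contained sketch of that pattern using PyTorch's `register_step_pre_hook`/`register_step_post_hook`, with a plain `threading.Lock` standing in for the manager's `RWLock`; the real code routes this through `Manager.disallow_state_dict_read()` and `allow_state_dict_read()` instead.

```python
import threading

import torch
from torch import nn, optim

# Stand-in for the manager's RWLock; illustrative only.
state_dict_lock = threading.Lock()

model = nn.Linear(4, 4)
opt = optim.SGD(model.parameters(), lr=0.01)


def step_pre_hook(_optim, _args, _kwargs) -> None:
    # Block checkpoint reads while the optimizer mutates parameters in place.
    state_dict_lock.acquire()


def step_post_hook(_optim, _args, _kwargs) -> None:
    # Parameters are consistent again; checkpoint reads may resume.
    state_dict_lock.release()


opt.register_step_pre_hook(step_pre_hook)
opt.register_step_post_hook(step_post_hook)

# A checkpoint thread would snapshot the model roughly like this:
#   with state_dict_lock:
#       snapshot = {k: v.detach().clone() for k, v in model.state_dict().items()}

loss = model(torch.randn(2, 4)).sum()
loss.backward()
opt.step()  # pre-hook acquires the lock, post-hook releases it
```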
8 changes: 8 additions & 0 deletions torchft/local_sgd_integ_test.py
@@ -36,6 +36,7 @@
)

logger: logging.Logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def local_sgd_train_loop(
@@ -143,6 +144,7 @@ def assert_equal_global_state(
rep1[step]["user"][f"StreamingDiLoCoFragment_{i}"],
rep0[step]["user"][f"StreamingDiLoCoFragment_{i}"],
check_device=False,
msg=f"{step=} {i=}",
)
# Check all outer optimizers
for i in range(
@@ -574,3 +576,9 @@ def test_streaming_diloco_commit_failure(
self.assertEqual(
event_injector.count[EventInjectorEvent.AllreduceFailure], 1
)


if __name__ == "__main__":
import unittest

unittest.main()
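The `check_device=False` / `msg=` arguments added in `assert_equal_global_state` belong to a tensor-comparison call whose opening line falls outside the visible hunk; assuming it is `torch.testing.assert_close`, a minimal usage sketch:

```python
import torch

# Two replicas' copies of the same fragment, possibly on different devices.
rep0_fragment = torch.ones(3)
rep1_fragment = torch.ones(3)

torch.testing.assert_close(
    rep1_fragment,
    rep0_fragment,
    check_device=False,  # compare values even if the tensors live on different devices
    msg="step=5 i=0",    # extra context reported if the comparison fails
)
```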
32 changes: 27 additions & 5 deletions torchft/manager.py
@@ -55,6 +55,7 @@

from torchft._torchft import ManagerClient, ManagerServer
from torchft.checkpointing import CheckpointTransport, HTTPTransport
from torchft.checkpointing._rwlock import RWLock
from torchft.futures import future_timeout
from torchft.work import _DummyWork

@@ -216,6 +217,9 @@ def __init__(
self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {}
self._user_state_dicts: Dict[str, Callable[[], object]] = {}

# Protects state dict
self._state_dict_lock = RWLock(timeout=timeout.total_seconds())

if load_state_dict and state_dict:
self.register_state_dict_fn("default", load_state_dict, state_dict)

@@ -324,6 +328,21 @@ def __init__(
# first step is 1
self._participating_replica_rank: Optional[int] = None
self._participating_replica_world_size: int = 0
self._is_state_dict_read_allowed = True

def allow_state_dict_read(self) -> None:
Review comment (Member): Can we add specific tests for these methods? Lock logic is pretty risk prone so would be nice to have unit test coverage for these.

if self._is_state_dict_read_allowed:
return

self._is_state_dict_read_allowed = True
self._state_dict_lock.w_release()

def disallow_state_dict_read(self) -> None:
if not self._is_state_dict_read_allowed:
return

self._is_state_dict_read_allowed = False
self._state_dict_lock.w_acquire()

def register_state_dict_fn(
self,
Expand Down Expand Up @@ -806,11 +825,14 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None:
self._batches_committed = state_dict["batches_committed"]

def _manager_state_dict(self) -> Dict[str, object]:
assert len(self._user_state_dicts) > 0, "user state_dict is not initialized."
return {
"user": {key: value() for key, value in self._user_state_dicts.items()},
"torchft": self.state_dict(),
}
with self._state_dict_lock.r_lock():
Review comment (Member): we already have allow_checkpoint and disallow_checkpoint in HTTPTransport -- can we reuse those instead?

Reply (Contributor, Author):
  • that would also require copying the state dict over to the http transport
  • or keeping some tracking of which step was transferred last to the http transport

with a separate lock, we can decouple checkpoint-specific logic from training logic

assert (
len(self._user_state_dicts) > 0
), "user state_dict is not initialized."
return {
"user": {key: value() for key, value in self._user_state_dicts.items()},
"torchft": self.state_dict(),
}

def state_dict(self) -> Dict[str, int]:
"""
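The review thread above settles on a dedicated reader-writer lock so checkpoint serving stays decoupled from training. The actual implementation lives in `torchft/checkpointing/_rwlock.py` and is not part of this diff; the sketch below is a hypothetical `SimpleRWLock`, included only to illustrate the interface the `Manager` code relies on — `w_acquire()`, `w_release()`, `w_locked()`, and an `r_lock()` context manager, all bounded by a timeout.

```python
import threading
from contextlib import contextmanager
from typing import Iterator


class SimpleRWLock:
    """Illustrative reader-writer lock; not the real torchft RWLock."""

    def __init__(self, timeout: float = 60.0) -> None:
        self._timeout = timeout
        self._cond = threading.Condition()
        self._readers = 0
        self._writer = False

    def w_acquire(self) -> None:
        with self._cond:
            # Wait until no readers and no writer hold the lock.
            if not self._cond.wait_for(
                lambda: self._readers == 0 and not self._writer,
                timeout=self._timeout,
            ):
                raise TimeoutError("timed out waiting for write lock")
            self._writer = True

    def w_release(self) -> None:
        with self._cond:
            self._writer = False
            self._cond.notify_all()

    def w_locked(self) -> bool:
        with self._cond:
            return self._writer

    @contextmanager
    def r_lock(self) -> Iterator[None]:
        with self._cond:
            # Readers only wait out the writer; multiple readers may overlap.
            if not self._cond.wait_for(lambda: not self._writer, timeout=self._timeout):
                raise TimeoutError("timed out waiting for read lock")
            self._readers += 1
        try:
            yield
        finally:
            with self._cond:
                self._readers -= 1
                self._cond.notify_all()
```

In this scheme the training thread holds the write side between the optimizer pre- and post-step hooks, while the checkpoint transport thread takes `r_lock()` inside `_manager_state_dict()`, so a checkpoint can never observe parameters mid-update.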
115 changes: 115 additions & 0 deletions torchft/manager_test.py
@@ -5,6 +5,8 @@
# LICENSE file in the root directory of this source tree.

import concurrent
import threading
import time
from datetime import timedelta
from typing import Optional
from unittest import TestCase
@@ -14,6 +16,7 @@
from torch.distributed import TCPStore

from torchft._torchft import QuorumResult
from torchft.checkpointing._rwlock import RWLock
from torchft.checkpointing.transport import CheckpointTransport
from torchft.manager import MANAGER_ADDR_KEY, REPLICA_ID_KEY, Manager, WorldSizeMode
from torchft.process_group import ProcessGroup
@@ -778,3 +781,115 @@ def test_max_retries(self, client_mock: MagicMock) -> None:
# This should succeed and reset the counter
self.assertTrue(manager.should_commit())
self.assertEqual(manager._commit_failures, 0)

@patch("torchft.manager.ManagerClient", autospec=True)
def test_state_dict_lock_allow_disallow(self, client_mock: MagicMock) -> None:
"""Test that allow_state_dict_read and disallow_state_dict_read methods work correctly."""
manager = self._create_manager()

# Initially, state dict read should be allowed
self.assertTrue(manager._is_state_dict_read_allowed)

# Test disallow_state_dict_read
manager.disallow_state_dict_read()
self.assertFalse(manager._is_state_dict_read_allowed)
self.assertTrue(manager._state_dict_lock.w_locked())

# Calling disallow_state_dict_read again should be a no-op
manager.disallow_state_dict_read()
self.assertFalse(manager._is_state_dict_read_allowed)
self.assertTrue(manager._state_dict_lock.w_locked())

# Test allow_state_dict_read
manager.allow_state_dict_read()
self.assertTrue(manager._is_state_dict_read_allowed)
self.assertFalse(manager._state_dict_lock.w_locked())

# Calling allow_state_dict_read again should be a no-op
manager.allow_state_dict_read()
self.assertTrue(manager._is_state_dict_read_allowed)
self.assertFalse(manager._state_dict_lock.w_locked())

@patch("torchft.manager.ManagerClient", autospec=True)
def test_state_dict_lock_concurrent_access(self, client_mock: MagicMock) -> None:
"""Test that _state_dict_lock properly protects concurrent access to the state dictionary."""
manager: Manager = self._create_manager()

# Create flags for thread synchronization
access_attempted: threading.Event = threading.Event()
can_proceed: threading.Event = threading.Event()
access_result: dict[str, bool] = {"succeeded": False}

def try_access_state_dict() -> None:
# Wait until the main thread signals it's ready
nonlocal access_attempted, can_proceed, access_result, manager
access_attempted.set()
can_proceed.wait(timeout=1.0)

# Try to access the state dict
if manager._is_state_dict_read_allowed:
access_result["succeeded"] = True

# Start a thread that will try to access the state dict
thread = threading.Thread(target=try_access_state_dict)
thread.daemon = True
thread.start()

# Disallow state dict read
manager.disallow_state_dict_read()
self.assertFalse(manager._is_state_dict_read_allowed)

# Wait for the thread to be ready
access_attempted.wait(timeout=1.0)

# Signal the thread to proceed while state dict read is disallowed
can_proceed.set()
thread.join(timeout=1.0)

# The thread should not have been able to access the state dict
self.assertFalse(access_result["succeeded"])

# Reset for the second part of the test
access_attempted.clear()
can_proceed.clear()

# Start another thread
thread = threading.Thread(target=try_access_state_dict)
thread.daemon = True
thread.start()

# Allow state dict read
manager.allow_state_dict_read()
self.assertTrue(manager._is_state_dict_read_allowed)

# Wait for the thread to be ready
access_attempted.wait(timeout=1.0)

# Signal the thread to proceed while state dict read is allowed
can_proceed.set()
thread.join(timeout=1.0)

# The thread should now have been able to access the state dict
self.assertTrue(access_result["succeeded"])

@patch("torchft.manager.ManagerClient", autospec=True)
def test_manager_state_dict_with_lock(self, client_mock: MagicMock) -> None:
"""Test that _manager_state_dict properly uses the read lock."""
manager = self._create_manager()

# Replace the real RWLock with a mock to track lock acquisition
original_lock = manager._state_dict_lock
mock_lock = create_autospec(RWLock)
mock_context = MagicMock()
mock_lock.r_lock.return_value.__enter__ = lambda _: mock_context
mock_lock.r_lock.return_value.__exit__ = lambda *args: None
manager._state_dict_lock = mock_lock

# Call _manager_state_dict
result = manager._manager_state_dict()

# Verify that r_lock was called
mock_lock.r_lock.assert_called_once()

# Restore the original lock
manager._state_dict_lock = original_lock