diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py
index 69f7130..e92d4bd 100644
--- a/torchft/local_sgd.py
+++ b/torchft/local_sgd.py
@@ -86,6 +86,9 @@ def __init__(
         self._hooks: List[RemovableHandle] = []
 
     def __enter__(self) -> "LocalSGD":
+        self._hooks.append(
+            self._local_optimizer.register_step_pre_hook(self._step_pre_hook)
+        )
         # Add optimizer hook which increments the local step counter and syncs if necessary
         self._hooks.append(
             self._local_optimizer.register_step_post_hook(self._step_post_hook)
@@ -106,12 +109,20 @@ def __exit__(
 
         return False  # Propagate exceptions
 
+    def _step_pre_hook(
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
+    ) -> None:
+        # The checkpoint may transfer model parameters, so we need to make access to it thread safe
+        self._manager.disallow_state_dict_read()
+
     def _step_post_hook(
        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
     ) -> None:
         """
         This hook is registered on the optimizer and is called after the optimizer step.
         """
+        self._manager.allow_state_dict_read()
+
         self._local_step += 1
         if self._local_step >= self._sync_every:
             self.sync()
@@ -682,12 +693,21 @@ def _restore_parameters(self) -> None:
             fragment.restore_parameters()
 
     def __enter__(self) -> "DiLoCo":
+        self._hooks.append(
+            self._local_optimizer.register_step_pre_hook(self._step_pre_hook)
+        )
         # Add optimizer hook which increments the local step counter and syncs if necessary
         self._hooks.append(
             self._local_optimizer.register_step_post_hook(self._step_post_hook)
         )
         return self
 
+    def _step_pre_hook(
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
+    ) -> None:
+        # The checkpoint may transfer model parameters, so we need to make access to it thread safe
+        self._manager.disallow_state_dict_read()
+
     def __exit__(
         self,
         exc_type: Optional[Type[BaseException]],
@@ -722,6 +742,8 @@ def _step_post_hook(
         """
         This hook is registered on the optimizer and is called after the optimizer step.
         """
+        self._manager.allow_state_dict_read()
+
         # We need to make sure all nodes send the same fragments in order.
         # This is to avoid deadlocking e.g.
         #
diff --git a/torchft/local_sgd_integ_test.py b/torchft/local_sgd_integ_test.py
index 3f3bbff..88bd88c 100644
--- a/torchft/local_sgd_integ_test.py
+++ b/torchft/local_sgd_integ_test.py
@@ -36,6 +36,7 @@
 )
 
 logger: logging.Logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
 
 
 def local_sgd_train_loop(
@@ -143,6 +144,7 @@ def assert_equal_global_state(
                 rep1[step]["user"][f"StreamingDiLoCoFragment_{i}"],
                 rep0[step]["user"][f"StreamingDiLoCoFragment_{i}"],
                 check_device=False,
+                msg=f"{step=} {i=}",
             )
         # Check all outer optimizers
         for i in range(
@@ -574,3 +576,9 @@ def test_streaming_diloco_commit_failure(
         self.assertEqual(
             event_injector.count[EventInjectorEvent.AllreduceFailure], 1
         )
+
+
+if __name__ == "__main__":
+    import unittest
+
+    unittest.main()
diff --git a/torchft/manager.py b/torchft/manager.py
index c49f839..ad9a056 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -55,6 +55,7 @@
 
 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
+from torchft.checkpointing._rwlock import RWLock
 from torchft.futures import future_timeout
 from torchft.work import _DummyWork
 
@@ -216,6 +217,9 @@ def __init__(
         self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {}
         self._user_state_dicts: Dict[str, Callable[[], object]] = {}
 
+        # Protects state dict
+        self._state_dict_lock = RWLock(timeout=timeout.total_seconds())
+
         if load_state_dict and state_dict:
             self.register_state_dict_fn("default", load_state_dict, state_dict)
 
@@ -324,6 +328,21 @@ def __init__(
         # first step is 1
         self._participating_replica_rank: Optional[int] = None
         self._participating_replica_world_size: int = 0
+        self._is_state_dict_read_allowed = True
+
+    def allow_state_dict_read(self) -> None:
+        if self._is_state_dict_read_allowed:
+            return
+
+        self._is_state_dict_read_allowed = True
+        self._state_dict_lock.w_release()
+
+    def disallow_state_dict_read(self) -> None:
+        if not self._is_state_dict_read_allowed:
+            return
+
+        self._is_state_dict_read_allowed = False
+        self._state_dict_lock.w_acquire()
 
     def register_state_dict_fn(
         self,
@@ -806,11 +825,14 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None:
         self._batches_committed = state_dict["batches_committed"]
 
     def _manager_state_dict(self) -> Dict[str, object]:
-        assert len(self._user_state_dicts) > 0, "user state_dict is not initialized."
-        return {
-            "user": {key: value() for key, value in self._user_state_dicts.items()},
-            "torchft": self.state_dict(),
-        }
+        with self._state_dict_lock.r_lock():
+            assert (
+                len(self._user_state_dicts) > 0
+            ), "user state_dict is not initialized."
+            return {
+                "user": {key: value() for key, value in self._user_state_dicts.items()},
+                "torchft": self.state_dict(),
+            }
 
     def state_dict(self) -> Dict[str, int]:
         """
diff --git a/torchft/manager_test.py b/torchft/manager_test.py
index 2a6ec29..6960abc 100644
--- a/torchft/manager_test.py
+++ b/torchft/manager_test.py
@@ -5,6 +5,8 @@
 # LICENSE file in the root directory of this source tree.
 
 import concurrent
+import threading
+import time
 from datetime import timedelta
 from typing import Optional
 from unittest import TestCase
@@ -14,6 +16,7 @@
 from torch.distributed import TCPStore
 
 from torchft._torchft import QuorumResult
+from torchft.checkpointing._rwlock import RWLock
 from torchft.checkpointing.transport import CheckpointTransport
 from torchft.manager import MANAGER_ADDR_KEY, REPLICA_ID_KEY, Manager, WorldSizeMode
 from torchft.process_group import ProcessGroup
@@ -778,3 +781,115 @@ def test_max_retries(self, client_mock: MagicMock) -> None:
         # This should succeed and reset the counter
         self.assertTrue(manager.should_commit())
         self.assertEqual(manager._commit_failures, 0)
+
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_state_dict_lock_allow_disallow(self, client_mock: MagicMock) -> None:
+        """Test that allow_state_dict_read and disallow_state_dict_read methods work correctly."""
+        manager = self._create_manager()
+
+        # Initially, state dict read should be allowed
+        self.assertTrue(manager._is_state_dict_read_allowed)
+
+        # Test disallow_state_dict_read
+        manager.disallow_state_dict_read()
+        self.assertFalse(manager._is_state_dict_read_allowed)
+        self.assertTrue(manager._state_dict_lock.w_locked())
+
+        # Calling disallow_state_dict_read again should be a no-op
+        manager.disallow_state_dict_read()
+        self.assertFalse(manager._is_state_dict_read_allowed)
+        self.assertTrue(manager._state_dict_lock.w_locked())
+
+        # Test allow_state_dict_read
+        manager.allow_state_dict_read()
+        self.assertTrue(manager._is_state_dict_read_allowed)
+        self.assertFalse(manager._state_dict_lock.w_locked())
+
+        # Calling allow_state_dict_read again should be a no-op
+        manager.allow_state_dict_read()
+        self.assertTrue(manager._is_state_dict_read_allowed)
+        self.assertFalse(manager._state_dict_lock.w_locked())
+
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_state_dict_lock_concurrent_access(self, client_mock: MagicMock) -> None:
+        """Test that _state_dict_lock properly protects concurrent access to the state dictionary."""
+        manager: Manager = self._create_manager()
+
+        # Create flags for thread synchronization
+        access_attempted: threading.Event = threading.Event()
+        can_proceed: threading.Event = threading.Event()
+        access_result: dict[str, bool] = {"succeeded": False}
+
+        def try_access_state_dict() -> None:
+            # Wait until the main thread signals it's ready
+            nonlocal access_attempted, can_proceed, access_result, manager
+            access_attempted.set()
+            can_proceed.wait(timeout=1.0)
+
+            # Try to access the state dict
+            if manager._is_state_dict_read_allowed:
+                access_result["succeeded"] = True
+
+        # Start a thread that will try to access the state dict
+        thread = threading.Thread(target=try_access_state_dict)
+        thread.daemon = True
+        thread.start()
+
+        # Disallow state dict read
+        manager.disallow_state_dict_read()
+        self.assertFalse(manager._is_state_dict_read_allowed)
+
+        # Wait for the thread to be ready
+        access_attempted.wait(timeout=1.0)
+
+        # Signal the thread to proceed while state dict read is disallowed
+        can_proceed.set()
+        thread.join(timeout=1.0)
+
+        # The thread should not have been able to access the state dict
+        self.assertFalse(access_result["succeeded"])
+
+        # Reset for the second part of the test
+        access_attempted.clear()
+        can_proceed.clear()
+
+        # Start another thread
+        thread = threading.Thread(target=try_access_state_dict)
+        thread.daemon = True
+        thread.start()
+
+        # Allow state dict read
+        manager.allow_state_dict_read()
+        self.assertTrue(manager._is_state_dict_read_allowed)
+
+        # Wait for the thread to be ready
+        access_attempted.wait(timeout=1.0)
+
+        # Signal the thread to proceed while state dict read is allowed
+        can_proceed.set()
+        thread.join(timeout=1.0)
+
+        # The thread should now have been able to access the state dict
+        self.assertTrue(access_result["succeeded"])
+
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_manager_state_dict_with_lock(self, client_mock: MagicMock) -> None:
+        """Test that _manager_state_dict properly uses the read lock."""
+        manager = self._create_manager()
+
+        # Replace the real RWLock with a mock to track lock acquisition
+        original_lock = manager._state_dict_lock
+        mock_lock = create_autospec(RWLock)
+        mock_context = MagicMock()
+        mock_lock.r_lock.return_value.__enter__ = lambda _: mock_context
+        mock_lock.r_lock.return_value.__exit__ = lambda *args: None
+        manager._state_dict_lock = mock_lock
+
+        # Call _manager_state_dict
+        result = manager._manager_state_dict()
+
+        # Verify that r_lock was called
+        mock_lock.r_lock.assert_called_once()
+
+        # Restore the original lock
+        manager._state_dict_lock = original_lock
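
For reference, the synchronization pattern the patch relies on is a reader-writer lock bracketing the optimizer step: the training thread takes the write side in the step pre-hook and releases it in the step post-hook, so a checkpoint transport reading the state dict under the read side blocks until the step has finished. The sketch below is a hypothetical, simplified stand-in, assuming only the w_acquire / w_release / w_locked / r_lock interface used above; SimpleRWLock, step_pre_hook, step_post_hook, checkpoint_state_dict, lock, and params are illustrative names, not part of torchft.

import threading
from contextlib import contextmanager
from typing import Dict, Iterator


class SimpleRWLock:
    """Hypothetical reader-writer lock with the interface the patch assumes."""

    def __init__(self, timeout: float = 60.0) -> None:
        self._timeout = timeout
        self._cond = threading.Condition()
        self._readers = 0
        self._writer = False

    def w_acquire(self) -> None:
        # Block new readers, then wait for in-flight readers to drain.
        with self._cond:
            self._writer = True
            if not self._cond.wait_for(lambda: self._readers == 0, timeout=self._timeout):
                self._writer = False
                raise TimeoutError("readers did not release the state dict in time")

    def w_release(self) -> None:
        with self._cond:
            self._writer = False
            self._cond.notify_all()

    def w_locked(self) -> bool:
        return self._writer

    @contextmanager
    def r_lock(self) -> Iterator[None]:
        # Readers wait while the optimizer step holds the write side.
        with self._cond:
            if not self._cond.wait_for(lambda: not self._writer, timeout=self._timeout):
                raise TimeoutError("optimizer step did not finish in time")
            self._readers += 1
        try:
            yield
        finally:
            with self._cond:
                self._readers -= 1
                self._cond.notify_all()


lock = SimpleRWLock(timeout=5.0)
params: Dict[str, float] = {"w": 0.0}


def step_pre_hook() -> None:
    # ~ Manager.disallow_state_dict_read(): taken before the optimizer step runs
    lock.w_acquire()


def step_post_hook() -> None:
    # ~ Manager.allow_state_dict_read(): released once the step has finished
    lock.w_release()


def checkpoint_state_dict() -> Dict[str, float]:
    # ~ Manager._manager_state_dict(): blocks while a step is in flight
    with lock.r_lock():
        return dict(params)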