diff --git a/torchft/_serialization.py b/torchft/_serialization.py new file mode 100644 index 0000000..44c13a9 --- /dev/null +++ b/torchft/_serialization.py @@ -0,0 +1,155 @@ +import pickle +from dataclasses import dataclass +from io import BufferedIOBase +from typing import Any, Dict, List, Tuple + +import torch +import torch._weights_only_unpickler as _weights_only_unpickler +from torch.serialization import _load, _save, DEFAULT_PROTOCOL, MAP_LOCATION + + +__all__: List[str] = [] + + +@dataclass +class _Entry: + key: str + is_storage: bool + length: int + + +_weights_only_unpickler._add_safe_globals([_Entry]) + + +class _PseudoZipFile: + def __init__(self) -> None: + self.records: Dict[str, Tuple[object, int]] = {} + + def write_record(self, key: str, data: object, length: int) -> None: + self.records[key] = (data, length) + + def write_to(self, f: BufferedIOBase) -> None: + entries = [] + for key, (data, length) in self.records.items(): + entries.append( + _Entry( + key=key, + is_storage=isinstance(data, torch.UntypedStorage), + length=length, + ) + ) + + pickle.dump(entries, f, protocol=DEFAULT_PROTOCOL) + + for key, (data, length) in self.records.items(): + if isinstance(data, bytes): + f.write(data) + elif isinstance(data, str): + f.write(data.encode("utf-8")) + elif isinstance(data, torch.UntypedStorage): + data._write_file(f, False, False, 1) + else: + raise TypeError(f"unknown type: {type(data)}") + + def read_from(self, f: BufferedIOBase) -> None: + entries = _weights_only_unpickler.load(f) + + for entry in entries: + data = f.read(entry.length) + if entry.is_storage: + storage = torch.frombuffer( + data, + dtype=torch.uint8, + ).untyped_storage() + + self.records[entry.key] = ( + storage, + entry.length, + ) + else: + self.records[entry.key] = (data, entry.length) + + def has_record(self, key: str) -> bool: + return key in self.records + + def get_record(self, key: str) -> object: + return self.records[key][0] + + def get_storage_from_record( + self, 
key: str, _length: int, _type: int + ) -> torch.Tensor: + return torch.tensor(self.records[key][0], dtype=torch.uint8) + + def serialization_id(self) -> str: + return "torchft" + + +def _streaming_save( + obj: object, + f: BufferedIOBase, + pickle_module: Any = pickle, + pickle_protocol: int = DEFAULT_PROTOCOL, +) -> None: + """ + Save the object to a file-like object in a streaming fashion compatible with + network sockets. + + This behaves similarly to :func:`torch.save` with a few notable differences: + + * A non-seekable file like object can be used when loading. + * No forwards/backwards compatiblity is provided for the serialization + format. This is only intended to be used with a single version of PyTorch + with transient storage (i.e. sockets or temp files). + * mmap is not supported + + See :func:`torch.save` for more details on specific arguments. + """ + + zip_file = _PseudoZipFile() + _save( + obj, + zip_file=zip_file, + pickle_module=pickle_module, + pickle_protocol=pickle_protocol, + _disable_byteorder_record=False, + ) + zip_file.write_to(f) + + +def _streaming_load( + f: BufferedIOBase, + map_location: MAP_LOCATION = None, + pickle_module: Any = None, + *, + weights_only: bool = True, + **pickle_load_args: Any, +) -> object: + """ + Load the object from a file-like object in a streaming fashion compatible with + network sockets. + + See :func:`_streaming_save` for more details about the streaming behavior. + + See :func:`torch.load` for more details on specific arguments. 
+ """ + if weights_only: + if pickle_module is not None: + raise RuntimeError( + "Can not safely load weights when explicit pickle_module is specified" + ) + pickle_module = _weights_only_unpickler + else: + if pickle_module is None: + pickle_module = pickle + + if "encoding" not in pickle_load_args.keys(): + pickle_load_args["encoding"] = "utf-8" + + zip_file = _PseudoZipFile() + zip_file.read_from(f) + return _load( + zip_file=zip_file, + map_location=map_location, + pickle_module=pickle_module, + **pickle_load_args, + ) diff --git a/torchft/checkpointing.py b/torchft/checkpointing.py index c3168b2..191ac01 100644 --- a/torchft/checkpointing.py +++ b/torchft/checkpointing.py @@ -13,24 +13,56 @@ import io import logging +import pickle import socket import threading +import time import urllib.request from abc import ABC, abstractmethod -from contextlib import contextmanager +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass from datetime import timedelta from http.server import BaseHTTPRequestHandler -from typing import Generator, Generic, List, Optional, TypeVar +from typing import Callable, Generator, Generic, List, Optional, Tuple, TypeVar import torch +from torch.distributed.tensor import DTensor +from torch.utils._pytree import tree_flatten, tree_unflatten from torchft.http import _IPv6HTTPServer +from torchft.process_group import ProcessGroup +from torchft.rwlock import RWLock logger: logging.Logger = logging.getLogger(__name__) T = TypeVar("T") +""" +def _save(obj: object, f: io.BufferedIOBase) -> None: + torch.save(obj, f) + + +def _load(f: io.BufferedIOBase) -> object: + data = f.read() + reader = io.BytesIO(data) + return torch.load(reader, weights_only=False) +""" + + +try: + from torch.distributed._serialization import _streaming_load, _streaming_save +except ImportError: + from torchft._serialization import _streaming_load, _streaming_save + +_save = _streaming_save 
def _load(f: io.BufferedIOBase) -> object:
    """Deserialize a checkpoint from ``f`` onto the CPU with the streaming loader.

    ``weights_only`` is disabled, so the stream is unpickled without
    restriction.
    """
    return _streaming_load(f, weights_only=False, map_location="cpu")


@dataclass
class _TensorMeta:
    """Everything needed to rebuild a tensor view from its raw storage bytes."""

    shape: torch.Size
    dtype: torch.dtype
    storage_offset: int
    stride: Tuple[int, ...]  # value of ``tensor.stride()``
    nbytes: int  # size of the whole underlying storage, not just the view


@dataclass
class _DTensorMeta:
    """Metadata for a DTensor: its local tensor layout plus the DTensor spec."""

    local: _TensorMeta
    spec: object


@dataclass
class _StateDictMeta:
    """Pickled header that is sent ahead of the tensor payloads."""

    step: int
    spec: object  # pytree spec returned by ``tree_flatten``
    non_tensors: List[object]  # flattened leaves, tensors replaced by metadata
    tensor_metas: List[_TensorMeta]


@contextmanager
def _timeit(name: str) -> Generator[None, None, None]:
    """Log how long the ``with`` body took; nothing is logged if it raises."""
    start = time.perf_counter()
    yield
    dur = time.perf_counter() - start
    logger.info(f"{name} took {dur}s")


def _prepare_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, _TensorMeta]:
    """Return the tensor's storage viewed as ``uint8`` plus its layout metadata."""
    return (
        _cast_tensor(tensor, torch.uint8),
        _TensorMeta(
            shape=tensor.shape,
            dtype=tensor.dtype,
            storage_offset=tensor.storage_offset(),
            stride=tensor.stride(),
            nbytes=tensor.untyped_storage().nbytes(),
        ),
    )


def _prepare_state_dict(
    state_dict: object,
    step: int,
    device: str,
) -> Tuple[_StateDictMeta, List[torch.Tensor]]:
    """Flatten ``state_dict`` into a picklable header and a list of byte tensors.

    Tensor leaves are replaced in the header by ``_TensorMeta`` (or
    ``_DTensorMeta`` for DTensors); every other leaf is kept as is.
    ``device`` is accepted but not used here.

    Returns:
        ``(meta, tensors)`` where ``tensors`` are ``uint8`` views in the order
        in which the tensor leaves were encountered.
    """
    start = time.perf_counter()
    values, spec = tree_flatten(state_dict)

    non_tensors = []
    tensors = []
    tensor_metas = []
    for v in values:
        if isinstance(v, DTensor):
            tensor, tensor_meta = _prepare_tensor(v._local_tensor)

            tensor_metas.append(tensor_meta)
            tensors.append(tensor)

            non_tensors.append(
                _DTensorMeta(
                    local=tensor_meta,
                    spec=v._spec,
                )
            )
        elif isinstance(v, torch.Tensor):
            tensor, tensor_meta = _prepare_tensor(v)
            tensors.append(tensor)
            non_tensors.append(tensor_meta)
            tensor_metas.append(tensor_meta)
        else:
            non_tensors.append(v)

    total_size = sum(t.nbytes for t in tensors)

    dur = time.perf_counter() - start
    logger.info(
        f"prepared state_dict {total_size=} {len(tensors)=} {len(non_tensors)=} in {dur}s"
    )

    return (
        _StateDictMeta(
            step=step,
            spec=spec,
            non_tensors=non_tensors,
            tensor_metas=tensor_metas,
        ),
        tensors,
    )


def _cast_tensor(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Return a ``dtype`` tensor built on top of ``tensor``'s untyped storage."""
    storage = tensor.untyped_storage()
    ret = torch.tensor(storage, dtype=dtype, device=tensor.device)
    # The result is expected to alias the input rather than copy it.
    assert ret.untyped_storage() is storage, "storage should be the same"
    return ret


class PGTransport(CheckpointTransport[T]):
    """
    This is a checkpoint transport that uses the process group to transfer checkpoints.

    This allows for fast recovery of workers by fetching the current weights
    from an existing worker.

    Args:
        pg: the process group whose ``send``/``recv`` carry the checkpoint
        timeout: stored on the instance
        device: the device on which the transfer buffers are allocated
    """

    def __init__(
        self, pg: ProcessGroup, timeout: timedelta, device: torch.device
    ) -> None:
        self._work = []
        self._pg = pg
        self._timeout = timeout
        self._device = device

    def metadata(self) -> str:
        """Return an empty string."""
        return ""

    def disallow_checkpoint(self) -> None:
        """No-op for this transport."""
        pass

    def send_checkpoint(
        self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
    ) -> None:
        """Send ``state_dict`` to every rank in ``dst_ranks``.

        The pickled header goes out with tag 1 (length) and tag 2 (bytes);
        tensor ``i`` follows with tag ``3 + i``.

        NOTE(review): ``timeout`` is not applied to the waits below — confirm
        whether the process group's own timeout is meant to cover this.
        """
        from collections import deque

        meta, tensors = _prepare_state_dict(state_dict, step, device=self._device)

        # deque gives O(1) removal of the oldest outstanding send.
        work = deque()
        # allow 3 concurrent transfers at a time
        max_in_flight = 3 * len(dst_ranks)

        with _timeit("send pickle"):
            buf = pickle.dumps(meta)
            len_t = torch.tensor([len(buf)], dtype=torch.int64, device=self._device)
            buf_t = torch.frombuffer(buf, dtype=torch.uint8).to(self._device)
            for dst_rank in dst_ranks:
                work.append(self._pg.send([len_t], dst_rank, tag=1))
                work.append(self._pg.send([buf_t], dst_rank, tag=2))

        with _timeit("send tensors"):
            for i, t in enumerate(tensors):
                t = t.to(self._device)
                for dst_rank in dst_ranks:
                    work.append(self._pg.send([t], dst_rank, tag=3 + i))

                while len(work) > max_in_flight:
                    work.popleft().wait()

            # Drain whatever is still in flight before returning.
            for w in work:
                w.wait()

    def recv_checkpoint(
        self, src_rank: int, metadata: str, step: int, timeout: timedelta
    ) -> T:
        """Receive the checkpoint for ``step`` from ``src_rank`` and rebuild it.

        ``metadata`` and ``timeout`` are accepted but not used here. Received
        tensors are moved to the CPU before being reinterpreted.
        """
        len_t = torch.zeros(1, dtype=torch.int64, device=self._device)
        self._pg.recv([len_t], src_rank, tag=1).wait()
        length = len_t.item()

        assert length > 0, f"invalid metadata length {length=}"

        buf = torch.empty(length, dtype=torch.uint8, device=self._device)
        self._pg.recv([buf], src_rank, tag=2).wait()

        # NOTE(security): unrestricted unpickling of bytes supplied by the
        # peer; only safe when every rank in the process group is trusted.
        meta = pickle.loads(buf.cpu().numpy().tobytes())
        assert meta.step == step, f"expected {step=}, received {meta.step=}"

        # Tensors arrive in the order of their metadata entries, tagged from 3.
        next_tag = 3

        values = []
        for v in meta.non_tensors:
            if isinstance(v, _TensorMeta):
                values.append(self._recv_tensor(src_rank, v, next_tag))
                next_tag += 1
            elif isinstance(v, _DTensorMeta):
                local = self._recv_tensor(src_rank, v.local, next_tag)
                next_tag += 1
                values.append(DTensor(local, v.spec, requires_grad=False))
            else:
                values.append(v)

        return tree_unflatten(values, meta.spec)

    def _recv_tensor(
        self, src_rank: int, meta: _TensorMeta, tag: int
    ) -> torch.Tensor:
        """Receive one storage blob and view it according to ``meta``."""
        raw = torch.empty(meta.nbytes, dtype=torch.uint8, device=self._device)
        self._pg.recv([raw], src_rank, tag=tag).wait()
        raw = raw.cpu()

        return torch.as_strided(
            _cast_tensor(raw, meta.dtype),
            size=meta.shape,
            stride=meta.stride,
            storage_offset=meta.storage_offset,
        )
RWLock(timeout=timeout.total_seconds()) self._disallowed = False self._step = -1 self._timeout = timeout self._state_dict: Optional[T] = None + self._num_chunks = num_chunks + self._spec: Optional[object] = None + self._chunks: Optiona[List[object]] = None + self._stream: Optional[torch.cuda.Stream] = ( + torch.cuda.Stream() if torch.cuda.is_available() else None + ) # We don't allow checkpoints until the first send_checkpoint to avoid # serving the default step=-1 invalid checkpoint. @@ -141,12 +365,17 @@ def do_GET(self): # validate socket timeout is actually set assert self.connection.gettimeout() == self.timeout - with _timed_acquire( - ckpt_server._checkpoint_lock, ckpt_server._timeout - ): + sock = self.wfile._sock + sock.setsockopt( + socket.SOL_SOCKET, socket.SO_SNDBUF, 2097152 + ) # set send buffer size to 2MB + + with ckpt_server._checkpoint_lock.r_lock(): step = ckpt_server._step - if self.path != f"/checkpoint/{step}": + parts = self.path.split("/") + assert len(parts) == 4 + if parts[1] != "checkpoint": self.send_response(400) self.send_header("Content-type", "text/plain") self.end_headers() @@ -155,13 +384,33 @@ def do_GET(self): ) return - self.send_response(200) - self.send_header("Content-type", "application/octet-stream") - self.end_headers() + step = int(parts[2]) + + key = parts[3] + if key == "full": + self.send_response(200) + self.send_header("Content-type", "application/octet-stream") + self.end_headers() + + state_dict = ckpt_server._state_dict + + _save(state_dict, self.wfile) + return + + if key == "metadata": + self.send_response(200) + self.send_header("Content-type", "application/octet-stream") + self.end_headers() + + _save(ckpt_server._spec, self.wfile) + else: + self.send_response(200) + self.send_header("Content-type", "application/octet-stream") + self.end_headers() - state_dict = ckpt_server._state_dict + chunk = ckpt_server._chunks[int(key)] + _save(chunk, self.wfile) - torch.save(state_dict, self.wfile) except Exception as e: 
logger.exception( f"Exception in checkpoint server when handling {self.path=}: {e}", @@ -194,13 +443,20 @@ def load_from_address(cls, address: str, timeout: timedelta) -> T: """ logger.info(f"fetching checkpoint from {address}") + start = time.perf_counter() + with urllib.request.urlopen(address, timeout=timeout.total_seconds()) as f: - data = f.read() + sock = f.fp.raw._sock + sock.setsockopt( + socket.SOL_SOCKET, socket.SO_RCVBUF, 2097152 + ) # set receive buffer size to 2MB + data = _load(f) + + dur = time.perf_counter() - start - reader = io.BytesIO(data) - # We have to set weights_only to False as there are some non-tensor - # states like lr_scheduler. - return torch.load(reader, weights_only=False) + logger.info(f"done fetching checkpoint from {address} in {dur}s") + + return data def address(self) -> str: """ @@ -228,7 +484,7 @@ def disallow_checkpoint(self) -> None: """ if not self._disallowed: self._disallowed = True - self._checkpoint_lock.acquire() + self._checkpoint_lock.w_acquire() def allow_checkpoint(self, step: int) -> None: """ @@ -241,7 +497,7 @@ def allow_checkpoint(self, step: int) -> None: if self._disallowed: self._disallowed = False - self._checkpoint_lock.release() + self._checkpoint_lock.w_release() def shutdown(self, wait: bool = True) -> None: """ @@ -262,9 +518,113 @@ def send_checkpoint( self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta ) -> None: self._state_dict = state_dict + + from torch.utils._pytree import tree_flatten + + values, spec = tree_flatten(state_dict) + + with ( + torch.cuda.stream(self._stream) + if self._stream is not None + else nullcontext() + ): + logger.info("transferring to CPU") + start = time.perf_counter() + values = _to_cpu(values, pin_memory=False) + if self._stream is not None: + self._stream.synchronize() + logger.info(f"done transferring to CPU in {time.perf_counter() - start}s") + + self._spec = spec + self._chunks = _split_chunks(values, self._num_chunks) + 
self.allow_checkpoint(step) def recv_checkpoint( self, src_rank: int, metadata: str, step: int, timeout: timedelta ) -> T: - return self.load_from_address(f"{metadata}{step}", timeout) + base_url = f"{metadata}{step}" + if self._num_chunks == 0: + return self.load_from_address(f"{base_url}/full", timeout) + else: + urls = [f"{base_url}/metadata"] + [ + f"{base_url}/{i}" for i in range(self._num_chunks) + ] + + with ThreadPoolExecutor(max_workers=len(urls)) as executor: + futures = [ + executor.submit(self.load_from_address, url, timeout) + for url in urls + ] + + spec, *chunks = [future.result() for future in futures] + + values = _merge_chunks(chunks, self._num_chunks) + + from torch.utils._pytree import tree_flatten, tree_unflatten + + return tree_unflatten(values, spec) + + +def _to_cpu(values: List[object], pin_memory: bool) -> List[object]: + out = [] + for v in values: + if isinstance(v, torch.Tensor): + if v.device.type == "cuda": + if pin_memory: + cpu = torch.empty(*tuple(v.size()), dtype=v.dtype, pin_memory=True) + cpu.copy_(v, non_blocking=True) + out.append(cpu) + else: + out.append(v.cpu()) + else: + out.append(v) + else: + out.append(v) + return out + + +def _split_chunks(values: List[object], num_chunks: int) -> List[object]: + return [values[i::num_chunks] for i in range(num_chunks)] + + +def _merge_chunks(chunks: List[List[object]], num_chunks: int) -> List[object]: + max_len = max(len(lst) for lst in chunks) + output_list = [] + for i in range(max_len): + for lst in chunks: + if i < len(lst): + output_list.append(lst[i]) + return output_list + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + + transport = CheckpointServer(timedelta(seconds=60), num_chunks=0) + metadata = transport.metadata() + print(f"fetching from {metadata}") + + device = torch.device("cpu") + + state_dict = {} + CHUNK_SIZE = 64_000_000 # 64MB + TOTAL_SIZE = 5_000_000_000 # 1GB + for i in range(0, TOTAL_SIZE, CHUNK_SIZE): + state_dict[f"chunk/{i}"] = 
torch.zeros( + CHUNK_SIZE // 4, dtype=torch.float32, device=device + ) + + transport.send_checkpoint( + dst_ranks=[0], step=1, state_dict=state_dict, timeout=timedelta(seconds=60) + ) + + import time + + print("starting") + start = time.perf_counter() + transport.recv_checkpoint( + src_rank=1, metadata=metadata, step=1, timeout=timedelta(seconds=60) + ) + end = time.perf_counter() + print(f"took {end - start} seconds") diff --git a/torchft/checkpointing_test.py b/torchft/checkpointing_test.py index 31658b4..2a64487 100644 --- a/torchft/checkpointing_test.py +++ b/torchft/checkpointing_test.py @@ -6,11 +6,18 @@ import threading import urllib.error +from concurrent.futures import ThreadPoolExecutor from datetime import timedelta -from unittest import TestCase +from unittest import skipUnless, TestCase from unittest.mock import MagicMock -from torchft.checkpointing import CheckpointServer, _timed_acquire +import torch +import torch.distributed as dist +from torch.distributed import TCPStore +from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor + +from torchft.checkpointing import CheckpointServer, PGTransport +from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo class TestCheckpointing(TestCase): @@ -103,3 +110,100 @@ def test_timed_acquire(self) -> None: pass self.assertTrue(lock.locked()) + + def _test_pg_transport(self, backend: str, device: str) -> None: + dist.init_process_group( + backend=backend, rank=0, world_size=1, store=dist.HashStore() + ) + device_mesh = DeviceMesh("cpu", 1) + + store = TCPStore( + host_name="localhost", port=0, is_master=True, wait_for_workers=False + ) + store_addr = f"localhost:{store.port}/prefix" + + timeout = timedelta(seconds=5) + + def sender(device: str) -> object: + if backend == "gloo": + a = ProcessGroupGloo(timeout=timeout) + elif backend == "nccl": + device = f"{device}:0" + a = ProcessGroupBabyNCCL(timeout=timeout) + else: + raise ValueError(f"unknown backend: {backend}") + + 
a.configure(store_addr, 0, 2) + + print("send configured") + + tensor = torch.randn(4, 4) + dtensor = distribute_tensor(tensor, device_mesh, []) + + state_dict = { + "tensors": { + "float32": torch.tensor([1, 2, 3], dtype=torch.float32), + "strided": torch.rand(10, dtype=torch.float32)[1::2], + "uint16": torch.tensor([1, 2, 3], dtype=torch.uint16), + "dtensor": dtensor, + }, + "non_tensors": "blah", + } + + transport = PGTransport(a, timeout=timeout, device=device) + transport.send_checkpoint( + dst_ranks=[1], + step=123, + state_dict=state_dict, + timeout=timeout, + ) + transport.disallow_checkpoint() + + return state_dict + + def receiver(device: str) -> object: + if backend == "gloo": + a = ProcessGroupGloo(timeout=timeout) + elif backend == "nccl": + # torch.cuda.set_device(1) + device = f"{device}:1" + a = ProcessGroupBabyNCCL(timeout=timeout) + else: + raise ValueError(f"unknown backend: {backend}") + + a.configure(store_addr, 1, 2) + + print("recv configured") + + transport = PGTransport(a, timeout=timeout, device=device) + state_dict = transport.recv_checkpoint( + src_rank=0, metadata="blah", step=123, timeout=timeout + ) + + return state_dict + + with ThreadPoolExecutor(max_workers=2) as executor: + send_fut = executor.submit(sender, device) + recv_fut = executor.submit(receiver, device) + + send_state_dict = send_fut.result() + recv_state_dict = recv_fut.result() + + for k, a in send_state_dict["tensors"].items(): + b = recv_state_dict["tensors"][k] + + if isinstance(a, DTensor): + torch.testing.assert_close(b._local_tensor.cpu(), a._local_tensor.cpu()) + self.assertEqual(b._spec, a._spec) + else: + torch.testing.assert_close(b.cpu(), a.cpu()) + self.assertEqual(recv_state_dict["non_tensors"], send_state_dict["non_tensors"]) + + dist.destroy_process_group() + + def test_pg_transport_gloo(self) -> None: + self._test_pg_transport("gloo", "cpu") + + @skipUnless(torch.cuda.device_count() >= 2, "need two CUDA devices") + def test_pg_transport_baby_nccl(self) 
-> None: + self._test_pg_transport("nccl", "cuda") diff --git a/torchft/http.py b/torchft/http.py index a93a84b..e73c054 100644 --- a/torchft/http.py +++ b/torchft/http.py @@ -1,5 +1,6 @@ import socket from http.server import ThreadingHTTPServer +from urllib.request import build_opener, HTTPHandler class _IPv6HTTPServer(ThreadingHTTPServer): diff --git a/torchft/process_group.py b/torchft/process_group.py index 540633b..979d9d3 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -22,17 +22,17 @@ from dataclasses import dataclass from datetime import timedelta from typing import ( - TYPE_CHECKING, Any, Callable, + cast, Dict, Generator, List, Optional, Tuple, + TYPE_CHECKING, TypeVar, Union, - cast, ) import torch @@ -43,14 +43,14 @@ # pyre-fixme[21]: no attribute ProcessGroupGloo from torch.distributed import ( DeviceMesh, + get_rank, + init_device_mesh, PrefixStore, ProcessGroup as BaseProcessGroup, ProcessGroupGloo as BaseProcessGroupGloo, ProcessGroupNCCL as BaseProcessGroupNCCL, Store, TCPStore, - get_rank, - init_device_mesh, ) from torch.distributed.distributed_c10d import ( AllgatherOptions, @@ -143,6 +143,34 @@ def allgather( """ raise NotImplementedError("not implemented") + # pyre-fixme[14]: inconsistent override + def send( + self, + tensors: List[torch.Tensor], + dst_rank: int, + tag: int = 0, + ) -> Work: + """ + Sends the tensor to the given rank. + + See torch.distributed.send for more details. + """ + raise NotImplementedError("not implemented") + + # pyre-fixme[14]: inconsistent override + def recv( + self, + tensors: List[torch.Tensor], + src_rank: int, + tag: int = 0, + ) -> Work: + """ + Receives the tensor from the given rank. + + See torch.distributed.recv for more details. 
+ """ + raise NotImplementedError("not implemented") + # pyre-fixme[14]: inconsistent override def broadcast( self, tensor_list: List[torch.Tensor], opts: BroadcastOptions @@ -267,6 +295,22 @@ def allgather( def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work: return self.parent.broadcast(tensor_list, opts) + def send( + self, + tensors: List[torch.Tensor], + dst_rank: int, + tag: int = 0, + ) -> Work: + return self.parent.send(tensors, dst_rank, tag) + + def recv( + self, + tensors: List[torch.Tensor], + src_rank: int, + tag: int = 0, + ) -> Work: + return self.parent.recv(tensors, src_rank, tag) + def size(self) -> int: return self.parent.size() @@ -377,6 +421,26 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: self._work.append(res) return res + def send( + self, + tensors: List[torch.Tensor], + dst_rank: int, + tag: int = 0, + ) -> Work: + res = _DummyWork(tensors) + self._work.append(res) + return res + + def recv( + self, + tensors: List[torch.Tensor], + src_rank: int, + tag: int = 0, + ) -> Work: + res = _DummyWork(tensors) + self._work.append(res) + return res + def size(self) -> int: return self._world @@ -764,8 +828,9 @@ def _worker( args = _PickleSafeOptions.unsafe_args(args) fn = getattr(pg, func_name) + op_work = fn(*args, **kwargs) work[next_op_id] = _OpMetadata( - work=fn(*args, **kwargs), + work=op_work, stream=stream, ) tx.put(next_op_id) @@ -778,7 +843,7 @@ def _worker( with metadata.set_stream(): # With WorkNCCL this makes the stream wait not the CPU when # no timeout is passed. - metadata.work.wait() + metadata.work.wait(timedelta(seconds=60.0)) # Register event on the stream that we can pass to the main # process. 
# -*- coding: utf-8 -*-
""" rwlock.py

    Adapted from: https://github.com/tylerneylon/rwlock/blob/main/rwlock.py

    A class to implement read-write locks on top of the standard threading
    library.

    This is implemented with two mutexes (threading.Lock instances) as per this
    wikipedia pseudocode:

    https://en.wikipedia.org/wiki/Readers%E2%80%93writer_lock#Using_two_mutexes

    __________________________
    License info (MIT):

    *******

    Copyright 2023 Tyler Neylon and contributors

    Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
    documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
    rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
    persons to whom the Software is furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
    WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
    COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
    OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

    *******
"""


# _______________________________________________________________________
# Imports

from contextlib import contextmanager
from threading import Lock
from typing import Generator


# _______________________________________________________________________
# Class


class RWLock:
    """RWLock class; this is meant to allow an object to be read from by
    multiple threads, but only written to by a single thread at a time. See:
    https://en.wikipedia.org/wiki/Readers%E2%80%93writer_lock

    Lock acquisitions are timed and raise TimeoutError if the timeout is
    exceeded. The default ``timeout=-1`` is passed straight to
    ``threading.Lock.acquire``, which then waits without a time limit.

    Usage:

        from torchft.rwlock import RWLock

        my_obj_rwlock = RWLock(timeout=60.0)

        # When reading from my_obj:
        with my_obj_rwlock.r_lock():
            do_read_only_things_with(my_obj)

        # When writing to my_obj:
        with my_obj_rwlock.w_lock():
            mutate(my_obj)
    """

    def __init__(self, timeout: float = -1) -> None:
        self.timeout = timeout

        # Held by the single writer, or collectively by the active readers.
        self._w_lock = Lock()
        # Guards ``_num_r``.
        self._num_r_lock = Lock()
        # Number of readers currently holding the lock.
        self._num_r = 0

    # ___________________________________________________________________
    # Reading methods.

    def r_acquire(self) -> None:
        """Register the caller as a reader.

        The first reader also takes the write mutex so writers are kept out
        while any reader is active.

        Raises:
            TimeoutError: if either mutex cannot be taken within ``timeout``.
                The reader count is left unchanged in that case.
        """
        if not self._num_r_lock.acquire(timeout=self.timeout):
            raise TimeoutError(
                f"Timed out waiting for rlock after {self.timeout} seconds"
            )

        try:
            if self._num_r == 0 and not self._w_lock.acquire(timeout=self.timeout):
                raise TimeoutError(
                    f"Timed out waiting for wlock after {self.timeout} seconds"
                )
            self._num_r += 1
        finally:
            self._num_r_lock.release()

    def r_release(self) -> None:
        """Unregister a reader; the last reader releases the write mutex.

        Raises:
            RuntimeError: if there is no active reader to release.
        """
        # The counter check happens under the mutex so it cannot race with
        # other readers; the wait itself is deliberately untimed so a release
        # can never be abandoned half way.
        with self._num_r_lock:
            if self._num_r <= 0:
                raise RuntimeError("r_release called without a matching r_acquire")
            self._num_r -= 1
            if self._num_r == 0:
                self._w_lock.release()

    @contextmanager
    def r_lock(self) -> Generator[None, None, None]:
        """This method is designed to be used via the `with` statement."""
        self.r_acquire()
        try:
            yield
        finally:
            self.r_release()

    # ___________________________________________________________________
    # Writing methods.

    def w_acquire(self) -> None:
        """Take the lock exclusively.

        Raises:
            TimeoutError: if the write mutex cannot be taken within ``timeout``.
        """
        if not self._w_lock.acquire(timeout=self.timeout):
            raise TimeoutError(
                f"Timed out waiting for wlock after {self.timeout} seconds"
            )

    def w_release(self) -> None:
        """Release the exclusive lock taken by ``w_acquire``."""
        self._w_lock.release()

    @contextmanager
    def w_lock(self) -> Generator[None, None, None]:
        """This method is designed to be used via the `with` statement."""
        self.w_acquire()
        try:
            yield
        finally:
            self.w_release()

    def w_locked(self) -> bool:
        """Returns True if the write mutex is currently held.

        That is the case while a writer holds the lock and also while at
        least one reader is active.
        """
        return self._w_lock.locked()