
Commit

Merge pull request #38 from OpenBMB/UPD_0708
Upd 0708
a710128 authored Jul 8, 2022
2 parents 3ed3e3c + 05a4804 commit ed5e1b7
Showing 7 changed files with 46 additions and 20 deletions.
bmtrain/block_layer.py: 4 additions & 4 deletions

@@ -1,4 +1,4 @@
-from typing import Dict, Iterable, Iterator, Tuple, Union
+from typing import Dict, Iterable, Iterator, Union


 from .global_var import config
@@ -8,7 +8,6 @@
 from .parameter import DistributedParameter, OpAllGather
 from .checkpointing import ScopedTensorInspectorContext
 from . import debug
-from torch.nn.modules.module import _addindent
 import copy

 def round_up(x, d):
@@ -331,7 +330,8 @@ def __init__(self, inner_module : torch.nn.Module):

         # calc total number of parameters
         for name, param in ordered_parameters:
-            assert isinstance(param, DistributedParameter), "All parameters in checkpoint block must be DistributedParameter."
+            if not isinstance(param, DistributedParameter):
+                raise ValueError("All parameters in checkpoint block must be DistributedParameter.")

             storage_type = storage_type_cuda(param.storage_type())
             kw_name = _get_param_kw(param)
@@ -464,7 +464,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
         # gather here
         with torch.no_grad():
             with CheckpointBlockContext(self):
-                return self._module.state_dict(destination, prefix, keep_vars)
+                return self._module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
             missing_keys, unexpected_keys, error_msgs):
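The assert-to-ValueError change also matters under `python -O`, which strips assert statements entirely. A hypothetical illustration of what a caller sees after this change (the use of bmtrain as bmt, the plain torch.nn.Linear, and the tensor sizes are assumptions of the sketch, not part of the diff):

# Hypothetical sketch: wrapping a module whose parameters are plain nn.Parameter.
import torch
import bmtrain as bmt

bmt.init_distributed()                  # assumed to be launched via torchrun / env:// settings

plain = torch.nn.Linear(16, 16)         # parameters are ordinary torch.nn.Parameter
try:
    block = bmt.CheckpointBlock(plain)
except ValueError as err:
    print(err)                          # "All parameters in checkpoint block must be DistributedParameter."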
bmtrain/distributed/ops.py: 7 additions & 1 deletion

@@ -32,13 +32,16 @@ def all_gather(x : torch.Tensor):
     Returns:
         torch.Tensor: The gathered tensor of shape (world_size, ...).
     """
+    if not config["initialized"]:
+        raise RuntimeError("BMTrain is not initialized")
+
     assert x.is_cuda
     return OpAllGather.apply(x)

 class OpAllReduce(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input : torch.Tensor, op : str):
-        if not input.contiguous():
+        if not input.is_contiguous():
             input = input.contiguous()
         if input.storage_offset() != 0 or input.storage().size() != input.numel():
             input = input.clone()
@@ -82,6 +85,9 @@ def all_reduce(x : torch.Tensor, op : str = "sum"):
         torch.Tensor: The reduced tensor of shape (...).
     """
+    if not config["initialized"]:
+        raise RuntimeError("BMTrain is not initialized")
+
     assert x.is_cuda
     return OpAllReduce.apply(x, op)

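Two things change here: the public collectives now refuse to run before initialization, and the contiguity check in OpAllReduce.forward is fixed (Tensor.contiguous() returns a tensor, whereas the boolean Tensor.is_contiguous() is what the condition needs). A usage sketch for the guarded collectives, assuming the process was launched so that bmt.init_distributed() can read its env:// settings:

# Usage sketch; the all_reduce / all_gather names are taken from this diff.
import torch
import bmtrain as bmt
from bmtrain import distributed as dist

# Before bmt.init_distributed(), both calls now raise
# RuntimeError("BMTrain is not initialized").
bmt.init_distributed()

x = torch.ones(4, device="cuda")
summed   = dist.all_reduce(x, "sum")    # shape (4,), elementwise sum across ranks
gathered = dist.all_gather(x)           # shape (world_size, 4)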
bmtrain/global_var.py: 2 additions & 1 deletion

@@ -14,10 +14,11 @@ class ConfigMap(TypedDict):
     loss_scale_steps : int

     gradient_inspect : bool
+    initialized : bool

     comm : 'NCCLCommunicator'

-config = ConfigMap()
+config = ConfigMap(rank=0, local_rank=0, world_size=1, initialized=False)

 def rank():
     """
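Pre-populating the map is what makes the new guards possible: the "initialized" key (and single-process defaults) can be read before init_distributed() has ever run. A small sketch of the pre-init state implied by these defaults:

# Sketch of the defaults added above, read before any initialization.
from bmtrain.global_var import config

print(config["initialized"])    # False until init_distributed() sets it to True
print(config["rank"])           # 0
print(config["world_size"])     # 1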
bmtrain/init.py: 4 additions & 1 deletion

@@ -6,7 +6,6 @@
 from .utils import print_dict
 from .global_var import config
 from . import nccl
-import time
 from .synchronize import synchronize
 def init_distributed(
         init_method : str = "env://",
@@ -57,6 +56,7 @@ def init_distributed(
     store = dist.PrefixStore("bmtrain", store)
     torch.cuda.set_device(local_rank)

+    config["initialized"] = True
     config["local_rank"] = local_rank
     config["local_size"] = local_size
     config["rank"] = rank
@@ -110,3 +110,6 @@ def init_distributed(
         "cpus": cpus_this_worker
     })
     synchronize()
+
+def is_initialized() -> bool:
+    return config["initialized"]
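The new helper gives user code a supported way to test for initialization instead of reading the config map directly. A minimal sketch (this diff does not show a top-level re-export, so the function is imported from the bmtrain.init submodule here; that import path is an assumption of the sketch):

# Minimal sketch of guarding setup with the new helper.
import bmtrain as bmt
from bmtrain.init import is_initialized

if not is_initialized():
    bmt.init_distributed()      # sets config["initialized"] = True, as in the hunk above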
bmtrain/parameter.py: 3 additions & 0 deletions

@@ -33,6 +33,9 @@ def __new__(cls,
         init_method : Optional[Callable[['DistributedParameter'], None]] = None,
         group : Optional[str] = None
     ):
+        if not config["initialized"]:
+            raise RuntimeError("BMTrain is not initialized")
+
         num_of_elements = data.numel()

         cuda_tensor = torch.tensor([], dtype=data.dtype, device="cuda")
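With this guard, constructing a DistributedParameter before initialization fails fast with a clear message. A hypothetical sketch (the tensor size is arbitrary):

# Sketch: creating a parameter in a process that never called init_distributed().
import torch
import bmtrain as bmt

try:
    p = bmt.DistributedParameter(torch.empty(1024))
except RuntimeError as err:
    print(err)                  # "BMTrain is not initialized"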
bmtrain/store.py: 8 additions & 3 deletions

@@ -5,13 +5,14 @@
 from .block_layer import CheckpointBlock
 from . import nccl
 import io, pickle
+from typing import Mapping

 def _save_to_state_dict(model : torch.nn.Module, destination, prefix):
     if isinstance(model, CheckpointBlock):
         if config['rank'] != 0:
             destination = OrderedDict() # creates an temporary ordered dict
             destination._metadata = OrderedDict()
-        model.state_dict(destination, prefix, False)
+        model.state_dict(destination=destination, prefix=prefix, keep_vars=False)
     else:
         if config['rank'] != 0:
             destination = OrderedDict() # creates an temporary ordered dict
@@ -109,8 +110,8 @@ def broadcast_object(obj):
         obj = _unpickler(io.BytesIO(buf)).load()
     return obj

-
-class DistributedStateDictWrapper:
+# Must be a Mapping after pytorch 1.12.0
+class DistributedStateDictWrapper(Mapping):
     def __init__(self, state_dict : Dict) -> None:
         self._state_dict = state_dict
         self._metadata = broadcast_object(getattr(state_dict, "_metadata", None))
@@ -176,6 +177,10 @@ def __contains__(self, key : str):
     def keys(self):
         return broadcast_object(list(self._state_dict.keys()))

+    def __iter__(self):
+        # pytorch 1.12.0 updated the load_state_dict method, which needs the state_dict to be a `Mapping`.
+        return iter(self.keys())
+
 def load(model : torch.nn.Module, file_name : str, strict : bool = True):
     """Loads the model from the file.
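Subclassing Mapping (typing.Mapping is an alias of collections.abc.Mapping, so it works as a base class) is what torch >= 1.12's load_state_dict expects per the comments above: the ABC requires __getitem__, __len__ and __iter__, and then supplies keys(), items(), __contains__ and friends. A generic, BMTrain-independent sketch of that contract:

# Generic sketch of the Mapping contract relied on above (not BMTrain code).
from collections.abc import Mapping

class TinyStateDict(Mapping):
    def __init__(self, data: dict):
        self._data = dict(data)
    def __getitem__(self, key):
        return self._data[key]
    def __len__(self):
        return len(self._data)
    def __iter__(self):             # the method this commit adds to DistributedStateDictWrapper
        return iter(self._data)

sd = TinyStateDict({"weight": 1.0})
print("weight" in sd, list(sd.keys()), dict(sd))    # True ['weight'] {'weight': 1.0}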
bmtrain/synchronize.py: 18 additions & 10 deletions

@@ -1,17 +1,24 @@
 import torch
-from . import nccl
+from . import distributed, nccl
 from .global_var import config
+import warnings

 def synchronize():
     """
     Synchronize all the workers across all nodes. (both CPU and GPU are synchronized)
     """
+    if not config["initialized"]:
+        raise RuntimeError("BMTrain is not initialized")
+
     with torch.cuda.stream(config['barrier_stream']):
         barrier = torch.cuda.FloatTensor([1])
         nccl.allReduce(barrier.storage(), barrier.storage(), 'sum', config['comm'])
     config['barrier_stream'].synchronize()

 def wait_loader():
+    if not config["initialized"]:
+        raise RuntimeError("BMTrain is not initialized")
+
     # wait lastest loader event, and set a new one
     config['load_event'].synchronize()
     config['calc_stream'].record_event(config['load_event'])
@@ -23,22 +30,23 @@ def sum_loss(loss : torch.Tensor):
     This is a helper function to reduce the loss across all workers.
     """
-    ret = torch.empty_like(loss)
-    nccl.allReduce(
-        loss.storage(),
-        ret.storage(),
-        'avg',
-        config['comm']
-    )
-    return ret
+    warnings.warn("bmtrain.sum_loss is deprecated and will be removed in later version. Use bmtrain.distributed.all_reduce instead.", DeprecationWarning)
+    return distributed.all_reduce(loss, "avg")

 def gather_result(result: torch.Tensor):
+    warnings.warn("bmtrain.gather_result is deprecated and will be removed in later version. Use bmtrain.distributed.all_gather instead.", DeprecationWarning)
+
+    output_cuda = True
     if not result.is_cuda:
         result = result.cuda()
+        output_cuda = False
     ret = torch.empty((result.shape[0]*config['world_size'], *list(result.shape[1:])), device=result.device, dtype=result.dtype)
     nccl.allGather(
         result.storage(),
         ret.storage(),
         config['comm']
     )
-    return ret
+    if output_cuda:
+        return ret
+    else:
+        return ret.cpu()
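A migration sketch for the two deprecated helpers; the replacement calls are the ones named in the new warnings (the torchrun/env:// launch and the tensor shapes are assumptions of the sketch):

# Migration sketch away from bmt.sum_loss / bmt.gather_result.
import torch
import bmtrain as bmt
from bmtrain import distributed as dist

bmt.init_distributed()

loss = torch.ones(1, device="cuda")
# old: bmt.sum_loss(loss)                  (now emits a DeprecationWarning)
avg_loss = dist.all_reduce(loss, "avg")

result = torch.randn(2, 3, device="cuda")
# old: bmt.gather_result(result)           (now emits a DeprecationWarning)
gathered = dist.all_gather(result)          # shape (world_size, 2, 3)
# note: all_gather asserts its input is already on the GPU; gather_result still
# moves CPU inputs to CUDA (and back) for callers that relied on that.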