diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 5fe62ca0..969b5db8 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -6,14 +6,15 @@ on: branches: - 'dev' - 'main' + push: + branches: + - 'dev' jobs: build-archive-wheel: uses: OpenBMB/BMTrain/.github/workflows/build_whl.yml@main - secrets: - DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }} - DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} + secrets: inherit publish: needs: build-archive-wheel diff --git a/bmtrain/__init__.py b/bmtrain/__init__.py index f4ac3642..fb114a84 100644 --- a/bmtrain/__init__.py +++ b/bmtrain/__init__.py @@ -1,4 +1,4 @@ -from .utils import print_block, print_dict, print_rank, see_memory, load_nccl_pypi +from .utils import print_block, print_dict, print_rank, print_rank_pp, see_memory, load_nccl_pypi try: from . import nccl except: @@ -10,11 +10,11 @@ from .layer import DistributedModule from .param_init import init_parameters, grouped_parameters from .synchronize import synchronize, sum_loss, wait_loader, gather_result -from .block_layer import Block, TransformerBlockList +from .block_layer import Block, TransformerBlockList, PipeDreamBlockList from .wrapper import BMTrainModelWrapper from .pipe_layer import PipelineTransformerBlockList from . import debug -from .store import save, load +from .store import save, load, clean from . import loss from . import distributed diff --git a/bmtrain/benchmark/all_reduce.py b/bmtrain/benchmark/all_reduce.py new file mode 100644 index 00000000..5a32db24 --- /dev/null +++ b/bmtrain/benchmark/all_reduce.py @@ -0,0 +1,27 @@ +from .. import nccl +from .shape import SHAPES +from ..global_var import config +from ..utils import round_up, print_rank +from .utils import format_size +import torch + +def all_reduce(): + current_stream = torch.cuda.current_stream() + for shape in SHAPES: + global_size = round_up(shape, config['world_size'] * 2) + + partition_tensor = torch.empty( global_size // 2, dtype=torch.half, device="cuda" ) + global_tensor = torch.empty( global_size // 2, dtype=torch.half, device="cuda" ) + + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + + current_stream.record_event(start_evt) + nccl.allReduce(partition_tensor.storage(), global_tensor.storage(),"sum", config['comm']) + current_stream.record_event(end_evt) + current_stream.synchronize() + time_usage = start_evt.elapsed_time(end_evt) + + bw = global_size / 1024 / 1024 / 1024 * 1000 / time_usage * 2 + print_rank("All reduce:\tsize {}\ttime: {:4.3f}\tbw: {:2.6f} GB/s".format(format_size(global_size), time_usage, bw)) + diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 98200465..afbc6ef3 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -9,6 +9,7 @@ from . import hook_func import inspect from torch.utils.checkpoint import checkpoint +from .distributed.ops import send_tensor_inplace, recv_tensor_inplace def storage_type_cuda(storage_type): STORAGE_MAP = { @@ -61,9 +62,10 @@ class Block(torch.nn.Module): >>> y2, ... 
= transformer_block(x) >>> assert torch.allclose(y1, y2) """ - def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_level=3, initialized=False, mode="BLOCK"): + def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_level=3, initialize_param=True, mode="BLOCK"): super().__init__() self._module = inner_module + self._module._in_block = True self._inputs = None self._layer_dict = {} self._forward_block_ctx = None @@ -84,7 +86,7 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_lev self.all_input_no_grad = False self.all_param_no_grad = False self._zero_level = zero_level - if not initialized: + if initialize_param: self.init_param_storage() def reference(self, block): @@ -106,13 +108,12 @@ def init_param_storage(self): storage_type = storage_type_cuda(param.storage_type()) kw_name = _get_param_kw(param) - if kw_name not in self._storage_info: - if self._mode == "PIPE" and param._tp_mode: + if (self._mode == "PIPE" or self._mode == "1F1B") and param._tp_mode: zero_comm = config["pp_tp_zero_comm"] - elif self._mode != "PIPE" and param._tp_mode: + elif (self._mode != "PIPE" and self._mode != "1F1B") and param._tp_mode: zero_comm = config["tp_zero_comm"] - elif self._mode == "PIPE" and not param._tp_mode: + elif (self._mode == "PIPE" or self._mode == "1F1B") and not param._tp_mode: zero_comm = config["pp_zero_comm"] else: zero_comm = config["zero_comm"] @@ -188,11 +189,10 @@ def init_param_storage(self): # copy values to buffer for normal parameter storage_st = self._storage_info[kw_name]["begin"] storage_end = self._storage_info[kw_name]["end"] - + comm = self._storage_info[kw_name]["zero_comm"] # make parameter contiguous in storage with torch.no_grad(): contiguous_param = OpAllGather.apply(param) - if not (param_st >= storage_end or param_end <= storage_st): # copy offset in parameter storage offset_st = max(storage_st - param_st, 0) @@ -207,13 +207,18 @@ def init_param_storage(self): # PyTorch 1.11 changed the API of storage.__getitem__ d_dtype = self._storage_params[kw_name].dtype d_device = self._storage_params[kw_name].device - param.data = torch.tensor([], dtype=param.dtype, device=param.device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,)) self._param_info[-1]["begin"] = to_offset_st self._param_info[-1]["end"] = (to_offset_end - to_offset_st,) setattr(param, "_start_partition", offset_st) setattr(param, "_end_partition", offset_end) - param.data[:] = \ - torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:] + if nccl.commCount(comm) != 1: + param.data = torch.tensor([], dtype=param.dtype, device=param.device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,)) + param.data[:] = \ + torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:] + else: + param.data = torch.tensor([], dtype=param.dtype, device=param.device).set_(self._storage_params[kw_name].storage(), to_offset_st, param_shape) + param.data[:] = \ + torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, param_shape)[:] del contiguous_param else: param.data = torch.tensor([], dtype=param.dtype, device=param.device) @@ -221,7 +226,6 @@ def init_param_storage(self): setattr(param, "_end_partition", 0) # clear parameter data, but keep the dtype and device setattr(param, "_in_block", True) - for kw 
in offsets.keys(): assert offsets[kw] == self._storage_info[kw]["total"] @@ -279,8 +283,8 @@ def post_hook(self, out): return post_out def forward(self, *args): + arg_list = self.pre_hook(*args) - if self.all_input_no_grad and not self.all_param_no_grad: placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) return hook_func.OneStepNoGradFunc.apply(self, placeholder, *arg_list) @@ -314,8 +318,12 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): def state_dict(self, destination=None, prefix='', keep_vars=False): # gather here with torch.no_grad(): - with ZeroContext(self): + if config['save_param_gather']: + with ZeroContext(self): + return self._module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) + else: return self._module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): @@ -330,8 +338,10 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, tp_mode = param._tp_mode if input_param.__class__.__name__ == "DistributedTensorWrapper": input_param = input_param.broadcast() - - verify_shape = torch.Size(it["shape"] if not tp_mode else param._tp_original_shape) + if config['load_param_gather']: + verify_shape = torch.Size(it["shape"] if not tp_mode else param._tp_original_shape) + else: + verify_shape = param.shape if input_param.shape != verify_shape: error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' 'the shape in current model is {}.' @@ -353,24 +363,28 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, # copy to buffer verify_size = verify_shape.numel() assert input_param.numel() == verify_size - contiguous_param = input_param.to(it["parameter"].dtype).cuda().contiguous() + d_dtype = self._storage_params[kw_name].dtype + d_device = self._storage_params[kw_name].device + offset_st = max(storage_st - param_st, 0) + to_offset_st = offset_st + param_st - storage_st + if not config['load_param_gather']: + partition_numel= contiguous_param.numel() + torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (partition_numel,))[:] = \ + torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), 0, (partition_numel,))[:] + continue tp_split_dim = param._tp_split_dim if tp_mode and tp_split_dim >= 0: contiguous_param = tp_split_tensor(contiguous_param, tp_split_dim) - offset_st = max(storage_st - param_st, 0) offset_end = min(storage_end - param_st, contiguous_param.numel()) + to_offset_end = offset_end + param_st - storage_st assert offset_st < offset_end - to_offset_st = offset_st + param_st - storage_st - to_offset_end = offset_end + param_st - storage_st # copy to buffer # PyTorch 1.11 changed the API of storage.__getitem__ - d_dtype = self._storage_params[kw_name].dtype - d_device = self._storage_params[kw_name].device torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,))[:] = \ torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:] del contiguous_param @@ -460,8 +474,12 @@ def init_parameters(self): # PyTorch 1.11 changed the API of storage.__getitem__ d_dtype = self._storage_params[kw_name].dtype d_device = self._storage_params[kw_name].device - param.data[:] = \ - 
torch.tensor([], dtype=d_dtype, device=d_device).set_(tmp_tensor.storage(), offset_st, (offset_end - offset_st,))[:] + if nccl.commCount(self._storage_info[kw_name]["zero_comm"]) == 1: + param.data[:] = \ + torch.tensor([], dtype=d_dtype, device=d_device).set_(tmp_tensor.storage(), offset_st, it["shape"])[:] + else: + param.data[:] = \ + torch.tensor([], dtype=d_dtype, device=d_device).set_(tmp_tensor.storage(), offset_st, (offset_end - offset_st,))[:] del tmp_tensor @@ -530,17 +548,18 @@ def eval(self): def __repr__(self): return self._module.__repr__() -def _block_wrapper(module, module_dict:dict, mode="BLOCK"): +def _block_wrapper(module, module_dict:dict, mode="BLOCK", **kwargs): if not isinstance(module, Block): - in_block = id(module) in module_dict - new_module = Block(module, initialized=in_block, mode=mode) - if in_block: - new_module.reference(module_dict[id(module)]) + if mode == "BLOCK": + in_block = id(module) in module_dict + new_module = Block(module, initialize_param=not in_block, mode=mode, **kwargs) + if in_block: + new_module.reference(module_dict[id(module)]) + elif mode == "PIPE" or mode == "1F1B": + new_module = Block(module, initialize_param=False, mode=mode, **kwargs) else: module_dict[id(module)] = new_module else: - if mode == "PIPE" and module._mode != "PIPE": - assert False, "You must be set mode=\"PIPE\" in bmt.Block when use PipelineTransformerBlockList!" if id(module._module) in module_dict: assert False, "Duplicate bmt.Block not supported in same block list!" else: @@ -569,25 +588,23 @@ class TransformerBlockList(torch.nn.Module): """ _modules: Dict[str, Block] - def __init__(self, modules: Iterable[Block], num_hidden=1) -> None: + def __init__(self, modules: Iterable[Block], num_hidden=1, mode="BLOCK") -> None: super().__init__() self._modules = {} pre_module = None module_dict = {} - module_dict = {} for i, module in enumerate(modules): - module = _block_wrapper(module, module_dict) + module = _block_wrapper(module, module_dict, mode=mode) module.set_pre_module(pre_module) pre_module = module module._is_first_layer = False module._is_last_layer = False self._modules[str(i)] = module self.add_module(str(i), module) - self._modules[str(0)]._is_first_layer = True self._modules[str(len(modules)-1)]._is_last_layer = True - + self.module_dict = module_dict self.num_hidden = num_hidden def __len__(self) -> int: @@ -621,3 +638,146 @@ def forward(self, *args, return_hidden_states = False): return outputs + tuple(hidden_states) else: return tuple(outputs[:self.num_hidden]) if self.num_hidden > 1 else outputs[0] + +def DummyForward(*args, **kwargs): + """ + Only useful for embedding and layernorm layer + """ + return args[0] + +class PipeDreamBlockList(TransformerBlockList): + + def __init__(self, modules: Iterable[Block], num_hidden=1, use_checkpoint=False) -> None: + module_dict = {} + mode = "1F1B" + if isinstance(use_checkpoint, bool): + use_checkpoint = [use_checkpoint for _ in range(len(modules))] + assert isinstance(use_checkpoint,Iterable) and len(use_checkpoint) == len(modules), "use_checkpoint should be a list of bool variable or a bool variable" + for idx in range(len(modules)): + modules[idx] = _block_wrapper(modules[idx], module_dict, mode=mode, zero_level=2, use_checkpoint=use_checkpoint[idx]) + s,e = self.partition(modules) + self.head_idx = s + self.tail_idx = e + partition_modules = [] + for idx,m in enumerate(modules): + if idx>=s and idx 1 else outputs[0] + + + def partition(self, modules): + pipe_size = config["topology"].pipe_size + pipe_rank = 
config["topology"].pipe_rank + part_lens = [0]+[len(modules) // pipe_size + (i < (len(modules) % pipe_size)) for i in range(pipe_rank+1)] + start = sum(part_lens[:pipe_rank+1]) + end = start + part_lens[pipe_rank+1] + return start,end + + def _add_head(self, module): + self.fisrt_module[0]._is_first_layer = False + module._is_first_layer = True + self.fisrt_module[0].set_pre_module(module) + self.fisrt_module = (module,) + + def add_head(self, module, use_checkpoint=False): + module = _block_wrapper(module, self.module_dict, mode="1F1B", zero_level=2, use_checkpoint=use_checkpoint) + if config['topology'].pipe_rank != 0: + for name, param in module.named_parameters(): + c = OpAllGather.apply(param) + del param + return DummyForward + else: + module.init_param_storage() + self._add_head(module) + return module + + def get_first_layer(self): + return self._modules['0'] + + def get_last_layer(self): + return self._modules[str(len(self)-1)] + + def add_head_tail(self, module, use_checkpoint=False): + module = _block_wrapper(module, self.module_dict, mode="1F1B", zero_level=2, use_checkpoint=use_checkpoint) + module.init_param_storage() + if config['topology'].pipe_rank != 0 and not config['topology'].is_last_rank(): + return DummyForward + else: + if config['topology'].pipe_rank == 0: + module._tied = "head" + self._add_head(module) + elif config['topology'].is_last_rank(): + module._tied = "tail" + self._add_tail(module) + self.tied_modules.append(module) + return module + + + def reduce_tied_module(self): + if config['topology'].pipe_rank != 0 and not config['topology'].is_last_rank(): + return + else: + for tied_m in self.tied_modules: + for name, param in tied_m.named_parameters(): + if config['topology'].pipe_rank == 0 and param.grad is not None: + with torch.no_grad(): + grad = torch.empty_like(param) + param.grad += recv_tensor_inplace(grad, 1, config["pipe_tied_comm"]) + send_tensor_inplace(param.grad, 1, config["pipe_tied_comm"]) + elif config['topology'].pipe_rank == 0 and param.grad is None: + grad = torch.empty_like(param) + param.grad = recv_tensor_inplace(grad, 1, config["pipe_tied_comm"]) + elif config['topology'].is_last_rank() and param.grad is not None: + send_tensor_inplace(param.grad, 0, config["pipe_tied_comm"]) + param.grad = recv_tensor_inplace(param.grad, 0, config["pipe_tied_comm"]) + + def _add_tail(self, module): + self.last_module[0]._is_last_layer = False + module._is_last_layer = True + module.set_pre_module(self.last_module[0]) + self.last_module = (module,) + + def add_tail(self, module, use_checkpoint=False): + module = _block_wrapper(module, self.module_dict, mode="1F1B", zero_level=2, use_checkpoint=use_checkpoint) + if config['topology'].pipe_rank != config['topology'].pipe_size - 1: + for name, param in module.named_parameters(): + c = OpAllGather.apply(param) + del param + return DummyForward + else: + module.init_param_storage() + self._add_tail(module) + return module diff --git a/bmtrain/distributed/__init__.py b/bmtrain/distributed/__init__.py index 84a4adf8..8049b351 100644 --- a/bmtrain/distributed/__init__.py +++ b/bmtrain/distributed/__init__.py @@ -1 +1 @@ -from .ops import all_gather, all_reduce, broadcast, recv_activations, send_activations, reduce_scatter +from .ops import all_gather, all_reduce, broadcast, recv_tensor, send_tensor, groupcall, send_object, recv_object, reduce_scatter diff --git a/bmtrain/distributed/dtype.py b/bmtrain/distributed/dtype.py new file mode 100644 index 00000000..a2b18db1 --- /dev/null +++ 
b/bmtrain/distributed/dtype.py @@ -0,0 +1,12 @@ +import torch +DTYPE_LIST = [ + torch.float64, + torch.float32, + torch.float16, + torch.int64, + torch.int32, + torch.int16, + torch.int8, + torch.bfloat16, + torch.bool +] \ No newline at end of file diff --git a/bmtrain/distributed/ops.py b/bmtrain/distributed/ops.py index d1b489e2..3690cc0e 100644 --- a/bmtrain/distributed/ops.py +++ b/bmtrain/distributed/ops.py @@ -1,48 +1,13 @@ import torch -from ..global_var import config -from ..nccl import allGather as ncclAllGather, recv +import bmtrain as bmt +from ..global_var import config, rank +from ..nccl import allGather as ncclAllGather from ..nccl import allReduce as ncclAllReduce from ..nccl import broadcast as ncclBroadcast from ..nccl import reduceScatter as ncclReduceScatter -from ..nccl import send as ncclSend -from ..nccl import recv as ncclRecv -from ..nccl import commCount,commRank,NCCLCommunicator -DTYPE_LIST = [ - torch.float64, - torch.float32, - torch.float16, - torch.int64, - torch.int32, - torch.int16, - torch.int8, - torch.bfloat16, - torch.bool -] -def send_activations(hidden_state, next_rank, comm): - send_meta(hidden_state, next_rank, comm) - ncclSend(hidden_state.storage(), next_rank, comm) - -def recv_activations(prev_rank, comm): - dtype, shape = recv_meta(prev_rank, comm) - hidden_state = torch.empty(shape, dtype=dtype, device="cuda") - ncclRecv(hidden_state.storage(), prev_rank, comm) - return hidden_state - -def send_meta(x, next_rank, comm): - meta_data = torch.tensor(data=[0]*50, device="cuda", dtype=torch.int) - meta_data[0] = len(x.size()) - meta_data[1] = DTYPE_LIST.index(x.dtype) - meta_data[2:len(x.size())+2] = torch.tensor(x.size(), device="cuda", dtype=torch.int) - meta_data = meta_data.contiguous() - ncclSend(meta_data.storage(), next_rank, comm) - -def recv_meta(prev_rank, comm): - meta_data = torch.tensor(data=[0]*50, device="cuda", dtype=torch.int) - ncclRecv(meta_data.storage(), prev_rank, comm) - n_dims = meta_data[0].item() - dtype = DTYPE_LIST[meta_data[1].item()] - shape = meta_data[2:n_dims+2].tolist() - return dtype,shape +from ..nccl import commCount, commRank, NCCLCommunicator, groupStart, groupEnd +from .p2p_ops import * + class OpBroadcast(torch.autograd.Function): diff --git a/bmtrain/distributed/p2p_ops.py b/bmtrain/distributed/p2p_ops.py new file mode 100644 index 00000000..35f9c6ff --- /dev/null +++ b/bmtrain/distributed/p2p_ops.py @@ -0,0 +1,159 @@ +import torch +from bmtrain import config +from ..nccl import reduceScatter as ncclReduceScatter +from ..nccl import send as ncclSend +from ..nccl import recv as ncclRecv +from ..nccl import groupStart,groupEnd +from .dtype import DTYPE_LIST +import pickle +import contextlib + +_p2p_stream = {} +_p2p_events = {} + +@contextlib.contextmanager +def groupcall(): + groupStart() + yield + groupEnd() +class handler: + def __init__(self, event): + self.event= event + + def wait(self): + torch.cuda.current_stream().wait_event(self.event) + +def send_object(obj, peer_rank, comm): + data_bytes: bytes = pickle.dumps(obj) + data_length: int = len(data_bytes) + + gpu_data_length = torch.tensor([data_length], device="cuda", dtype=torch.long) + ncclSend(gpu_data_length.storage(), peer_rank, comm) + byte_storage = torch.ByteStorage.from_buffer(data_bytes).cuda() + ncclSend(byte_storage, peer_rank, comm) + +def recv_object(peer_rank, comm): + data_length = torch.tensor([0], device="cuda", dtype=torch.long) + ncclRecv(data_length.storage(), peer_rank, comm) + data_bytes_stor = 
torch.cuda.ByteStorage(data_length.item()) + ncclRecv(data_bytes_stor, peer_rank, comm) + tensor = torch.ByteTensor(data_bytes_stor.cpu()) + data = pickle.loads(tensor.numpy().tobytes()) + return data + +def record_stream_helper(tensor_list, stream): + for t in tensor_list: + t.record_stream(stream) + +def send_tensors(tensor_list, peer_rank, comm): + handler = _send_tensors(tensor_list, peer_rank, comm) + handler.wait() + +def isend_tensor(tensor_list, peer_rank, comm): + return _send_tensors(tensor_list, peer_rank, comm) + +def _send_tensors(tensor_list, peer_rank, comm): + p2p_key = f"send {peer_rank}" + if p2p_key not in _p2p_stream: + _p2p_stream[p2p_key] = torch.cuda.Stream() + if p2p_key not in _p2p_events: + _p2p_events[p2p_key] = torch.cuda.Event() + stream = _p2p_stream[p2p_key] + event = _p2p_events[p2p_key] + event.record(torch.cuda.current_stream()) + stream.wait_event(event) + with torch.cuda.stream(stream): + length = torch.tensor(data=[len([h for h in tensor_list ])], device="cuda", dtype=torch.int) + flags = torch.tensor(data=[0 for _ in range(len(tensor_list))], device="cuda",dtype=torch.int) + for i in range(len(tensor_list)): + if tensor_list[i] is None: + flag = -1 + elif torch.is_tensor(tensor_list[i]): + flag = 0 + else: + flag = 1 + flags[i] = flag + ncclSend(length.storage(), peer_rank, comm) + ncclSend(flags.contiguous().storage(), peer_rank, comm) + for i in range(len(tensor_list)): + if flags[i] == 0: + tensor_list[i].record_stream(stream) + send_tensor(tensor_list[i], peer_rank, comm) + elif flags[i] == 1: + send_object(tensor_list[i], peer_rank, comm) + event.record(stream) + return handler(event) + +def recv_tensors(peer_rank, comm): + tensors, handle = _recv_tensors(peer_rank, comm) + handle.wait() + return tensors + +def irecv_tensors(peer_rank, comm): + tensors, handle = _recv_tensors(peer_rank, comm) + return tensors, handle + +def _recv_tensors(peer_rank, comm): + p2p_key = f"recv {peer_rank}" + if p2p_key not in _p2p_stream: + _p2p_stream[p2p_key] = torch.cuda.Stream() + if p2p_key not in _p2p_events: + _p2p_events[p2p_key] = torch.cuda.Event() + stream = _p2p_stream[p2p_key] + event = _p2p_events[p2p_key] + with torch.cuda.stream(stream): + length = torch.tensor(data=[0], device="cuda", dtype=torch.int) + tensor_list = [] + ncclRecv(length.storage(), peer_rank, comm) + flags = torch.tensor(data=[0 for _ in range(length)], device="cuda",dtype=torch.int) + ncclRecv(flags.storage(), peer_rank, comm) + for i in range(length[0].item()): + flag = flags[i].item() + if flag == -1: + tensor_list.append(None) + elif flag == 0: + recv = recv_tensor(peer_rank, comm) + tensor_list.append(recv) + elif flag == 1: + recv = recv_object(peer_rank, comm) + tensor_list.append(recv) + event.record(stream) + record_stream_helper([tensor_list[i] for i in range(length[0].item()) if flags[i].item() != -1], torch.cuda.current_stream()) + return tensor_list, handler(event) + +def send_tensor(hidden_state, peer_rank, comm): + hidden_state = hidden_state.contiguous() + send_meta(hidden_state, peer_rank, comm) + ncclSend(hidden_state.storage(), peer_rank, comm) + +def send_tensor_inplace(hidden_state, peer_rank, comm): + hidden_state = hidden_state.contiguous() + ncclSend(hidden_state.storage(), peer_rank, comm) + +def recv_tensor_inplace(hidden_state, peer_rank, comm): + hidden_state = hidden_state.contiguous() + ncclRecv(hidden_state.storage(), peer_rank, comm) + return hidden_state + +def recv_tensor(peer_rank, comm): + dtype, shape = recv_meta(peer_rank, comm) + hidden_state = 
torch.empty(shape, dtype=dtype, device="cuda") + ncclRecv(hidden_state.storage(), peer_rank, comm) + return hidden_state + +def send_meta(x, peer_rank, comm): + meta_data = torch.tensor(data=[0]*50, device="cuda", dtype=torch.int) + meta_data[0] = len(x.size()) + meta_data[1] = DTYPE_LIST.index(x.dtype) + meta_data[2:len(x.size())+2] = torch.tensor(x.size(), device="cuda", dtype=torch.int) + meta_data = meta_data.contiguous() + ncclSend(meta_data.storage(), peer_rank, comm) + +def recv_meta(peer_rank, comm): + meta_data = torch.tensor(data=[0]*50, device="cuda", dtype=torch.int) + ncclRecv(meta_data.storage(), peer_rank, comm) + n_dims = meta_data[0].item() + dtype = DTYPE_LIST[meta_data[1].item()] + shape = meta_data[2:n_dims+2].tolist() + + return dtype,shape \ No newline at end of file diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 2c6108b0..a69aaa6f 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -4,10 +4,13 @@ def zero_pre_forward(module, inputs): enter = True - pipe = False - if module._mode == "PIPE": - enter = module._micro_idx == 0 - pipe = True + if module._mode == "PIPE" or module._mode == "1F1B": + if not hasattr(module, "_micro_forward_idx") or module._micro_forward_idx == -1: + module._micro_forward_idx = 0 + enter = True + else: + enter = False + module._micro_forward_idx += 1 if enter: zero_level = module._zero_level forward_flag = 1 if zero_level == 2 else 0 @@ -15,40 +18,61 @@ def zero_pre_forward(module, inputs): forward_flag = 2 # repeating forward in same layer if module.all_param_no_grad: #only forward forward_flag = 0 - module._forward_block_ctx = ZeroContext(module, module._layer_dict, pipe=pipe) - module._forward_block_ctx.enter(forward_flag) + if module._mode == "1F1B": + module._block_ctx = ZeroContext(module, module._layer_dict) + module._block_ctx.enter(0, requires_grad=True) + else: + module._forward_block_ctx = ZeroContext(module, module._layer_dict) + module._forward_block_ctx.enter(forward_flag) def zero_post_forward(module, inputs, outputs): forward_flag = 1 if module._zero_level == 2 else 0 if module.all_param_no_grad: forward_flag = 0 exit = True - if module._mode == "PIPE": - exit = module._micro_idx == config['micros'] - 1 + if module._mode == "PIPE" or module._mode == "1F1B": + if module._micro_forward_idx == config["micros"] - 1: + module._micro_forward_idx = -1 + if module._mode == "1F1B": + exit = False + else: + exit = True + else: + exit = False if exit: module._forward_block_ctx.exit(forward_flag) def zero_pre_backward(module, grad_outputs): backward_flag = 2 if module._zero_level == 2 else 0 - if module._mode != "PIPE": + if module._mode != "PIPE" and module._mode != "1F1B": module._backward_block_ctx = ZeroContext(module, module._layer_dict) module._backward_block_ctx.enter(backward_flag, True) module.release_next_module(backward_flag) else: - if module._micro_idx == config['micros'] - 1: - module._backward_block_ctx = ZeroContext(module, module._layer_dict, pipe=True) - module._backward_block_ctx.enter(backward_flag, True) + if not hasattr(module, "_micro_backward_idx") or module._micro_backward_idx == -1: + if module._mode == "1F1B": + module._micro_backward_idx = 0 + else: + module._micro_backward_idx = 0 + module._backward_block_ctx = ZeroContext(module, module._layer_dict) + module._backward_block_ctx.enter(backward_flag,requires_grad=True) + else: + module._micro_backward_idx += 1 def zero_post_backward(module, grad_inputs, grad_outputs): backward_flag = 2 if module._zero_level == 2 else 0 - if 
module._mode != "PIPE": + if module._mode != "PIPE" and module._mode != "1F1B": if module._is_first_layer: module.release(backward_flag) else: - if module._micro_idx == 0: - module.release(backward_flag) - module._micro_idx -= 1 + if module._micro_backward_idx == config["micros"] - 1: + if module._mode == "1F1B": + module._block_ctx.exit(0, backward=True) + config['load_stream'].record_event(config['load_event']) + else: + module.release(backward_flag) + module._micro_backward_idx = -1 class OneStepNoGradFunc(torch.autograd.Function): """ diff --git a/bmtrain/init.py b/bmtrain/init.py index 69273c09..605eeff5 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -9,11 +9,10 @@ from . import nccl from .synchronize import synchronize - def init_distributed( init_method : str = "env://", seed : int = 0, - pipe_size: int = -1, + pipe_size: int = 1, num_micro_batches: int = None, tp_size : int = 1, ): @@ -66,7 +65,7 @@ def init_distributed( torch.cuda.set_device(local_rank) config["initialized"] = True config["pipe_size"] = pipe_size if pipe_size > 0 else 1 - config["pipe_enabled"] = pipe_size > 0 + config["pipe_enabled"] = pipe_size > 1 config["local_rank"] = local_rank config["local_size"] = local_size config["rank"] = rank @@ -79,10 +78,13 @@ def init_distributed( config["load_event"] = torch.cuda.Event() config["tp_size"] = tp_size if tp_size > 0 else 1 config["topology"] = topology(config) + config["pipe_rank"] = config['topology'].get_group_rank("pipe") config["zero_rank"] = config['topology'].get_group_rank("zero") config["tp_rank"] = config['topology'].get_group_rank("tp") config["tp_zero_rank"] = config['topology'].get_group_rank("tp_zero") config["save_param_to_cpu"] = True + config["save_param_gather"] = True + config["load_param_gather"] = True cpus_this_worker = None all_available_cpus = sorted(list(os.sched_getaffinity(0))) @@ -113,40 +115,45 @@ def init_distributed( config['comm'] = nccl.commInitRank(unique_id, world_size, rank) topo = config['topology'] + config["micros"] = num_micro_batches if num_micro_batches else config["pipe_size"] + if topo.pipe_rank == 0: + unique_id = nccl.getUniqueId() + store.set(f"PIPE_UNIQUE_ID{topo.pipe_idx}", unique_id.hex()) + unique_id = bytes.fromhex(store.get(f"PIPE_UNIQUE_ID{topo.pipe_idx}").decode()) + config ['pipe_comm'] = nccl.commInitRank(unique_id, pipe_size, topo.pipe_rank) if config['pipe_enabled']: - config["micros"] = num_micro_batches if num_micro_batches else config["pipe_size"] - if topo.stage_id == 0: - unique_id = nccl.getUniqueId() - store.set(f"PIPE_UNIQUE_ID{topo.pipe_idx}", unique_id.hex()) - unique_id = bytes.fromhex(store.get(f"PIPE_UNIQUE_ID{topo.pipe_idx}").decode()) - config ['pipe_comm'] = nccl.commInitRank(unique_id, pipe_size, topo.stage_id) - - if topo.pp_zero_id == 0: - unique_id = nccl.getUniqueId() - store.set(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}", unique_id.hex() ) - unique_id = bytes.fromhex(store.get(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}").decode()) - config['pp_zero_comm'] = nccl.commInitRank(unique_id, world_size//config['pipe_size'], topo.pp_zero_id) - - if config['tp_size'] > 1: - if topo.tp_id == 0: - unique_id = nccl.getUniqueId() - store.set(f"TP_UNIQUE_ID{topo.tp_idx}", unique_id.hex()) - unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode()) - config['tp_comm'] = nccl.commInitRank(unique_id, tp_size, topo.tp_id) - - if topo.tp_zero_id == 0: - unique_id = nccl.getUniqueId() - store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) - unique_id = 
bytes.fromhex(store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) - config['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//config['tp_size'], topo.tp_zero_id) - - - if config['pipe_size'] > 1 and config['tp_size'] > 1: - if topo.pp_tp_zero_id == 0: - unique_id = nccl.getUniqueId() - store.set(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}", unique_id.hex() ) - unique_id = bytes.fromhex(store.get(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}").decode()) - config['pp_tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.pp_tp_zero_id) + if topo.pipe_rank == topo.pipe_size - 1 or topo.pipe_rank == 0: + if topo.pipe_rank == 0: + unique_tied_id = nccl.getUniqueId() + store.set(f"PIPE_TIED_UNIQUE_ID{topo.pipe_idx}", unique_tied_id.hex()) + unique_tied_id = bytes.fromhex(store.get(f"PIPE_TIED_UNIQUE_ID{topo.pipe_idx}").decode()) + rank = 0 if topo.pipe_rank == 0 else 1 + config['pipe_tied_comm'] = nccl.commInitRank(unique_tied_id, 2, rank) + + if topo.pp_zero_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}").decode()) + config['pp_zero_comm'] = nccl.commInitRank(unique_id, world_size//config['pipe_size'], topo.pp_zero_id) + + if topo.tp_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"TP_UNIQUE_ID{topo.tp_idx}", unique_id.hex()) + unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode()) + config['tp_comm'] = nccl.commInitRank(unique_id, tp_size, topo.tp_id) + + if topo.tp_zero_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) + config['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//config['tp_size'], topo.tp_zero_id) + + + if topo.pp_tp_zero_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}").decode()) + config['pp_tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.pp_tp_zero_id) config ['zero_comm'] = config['comm'] @@ -175,28 +182,28 @@ def __init__(self,config): dp_size = world_size // (pp_size * tp_size) config['tp_zero_size'] = dp_size config['zero_size'] = world_size // pp_size - self.stages = config['pipe_size'] - + self.pipe_size = config['pipe_size'] + self.dp_size = dp_size + self.tp_size = tp_size stage_size = world_size // pp_size for i in range(world_size): self.pipe_idx = self.rank % stage_size - self.stage_id = self.rank // stage_size + self.pipe_rank = self.rank // stage_size self.tp_id = self.rank % tp_size self.tp_idx = self.rank // tp_size #pp->zero - self.pp_zero_idx = self.stage_id + self.pp_zero_idx = self.pipe_rank self.pp_zero_id = self.pipe_idx #tp->zero self.tp_zero_idx = self.tp_id self.tp_zero_id = self.tp_idx #pp->tp->zero - self.pp_tp_zero_idx = self.stage_id * tp_size + self.tp_id + self.pp_tp_zero_idx = self.pipe_rank * tp_size + self.tp_id self.pp_tp_zero_id = self.pipe_idx // tp_size #only zero self.zero_idx = 0 self.zero_id = self.rank - def get_group_id(self,group_name): if group_name == "pipe": return self.pipe_idx @@ -209,7 +216,7 @@ def get_group_id(self,group_name): def get_group_rank(self,group_name): if group_name == "pipe": - return self.stage_id + return 
self.pipe_rank elif group_name == "zero": return self.zero_id elif group_name == "tp_zero": @@ -217,6 +224,22 @@ def get_group_rank(self,group_name): elif group_name == "tp": return self.tp_id + def is_first_rank(self, group_name="pipe"): + if group_name == "pipe": + return self.pipe_rank == 0 + elif group_name == "zero": + return self.zero_id == 0 + elif group_name == "tp": + return self.tp_id == 0 + + def is_last_rank(self, group_name="pipe"): + if group_name == "pipe": + return self.pipe_rank == self.pipe_size - 1 + elif group_name == "zero": + return self.zero_id == self.dp_size - 1 + elif group_name == "tp": + return self.tp_id == self.tp_size - 1 + def is_initialized() -> bool: return config["initialized"] diff --git a/bmtrain/inspect/tensor.py b/bmtrain/inspect/tensor.py index 2c45fdac..0b5a9ae4 100644 --- a/bmtrain/inspect/tensor.py +++ b/bmtrain/inspect/tensor.py @@ -39,8 +39,8 @@ def _set_summary(self, summary): kw = f'{item["prefix"]}{item["name"]}' assert item["inside_pipe"] is not None - stage_id = item["inside_pipe"]["stage_id"] - stages = item["inside_pipe"]["stages"] + pipe_rank = item["inside_pipe"]["pipe_rank"] + pipe_size = item["inside_pipe"]["pipe_size"] st = item["inside_pipe"]["st"] ed = item["inside_pipe"]["ed"] @@ -52,8 +52,8 @@ def _set_summary(self, summary): if ed: break - for stage in range(stages): - if stage_id == stage: + for stage in range(pipe_size): + if pipe_rank == stage: broadcast_object(pipe_cnt, config["pipe_comm"], src = stage) for k in range(i, j): item = summary[k] @@ -76,7 +76,7 @@ def _set_summary(self, summary): "tensor": tensor, "grad": grad, "requires_grad": item["requires_grad"], - "inside_pipe": {"stage_id": stage}, + "inside_pipe": {"pipe_rank": stage}, }) kw_cnt[kw] += 1 else: @@ -99,7 +99,7 @@ def _set_summary(self, summary): "tensor": None, "grad": None, "requires_grad": None, - "inside_pipe": {"stage_id": stage}, + "inside_pipe": {"pipe_rank": stage}, }) kw_cnt[kw] += 1 @@ -114,23 +114,23 @@ def _set_summary(self, summary): "requires_grad": it["requires_grad"], "has_grad": has_grad, } - broadcast_object(info, config["pipe_comm"], src = it["inside_pipe"]["stage_id"]) + broadcast_object(info, config["pipe_comm"], src = it["inside_pipe"]["pipe_rank"]) tensor = it["tensor"] - tensor = broadcast(tensor, it["inside_pipe"]["stage_id"], config["pipe_comm"]) + tensor = broadcast(tensor, it["inside_pipe"]["pipe_rank"], config["pipe_comm"]) grad = it["grad"] else: - info = broadcast_object({}, config["pipe_comm"], src = it["inside_pipe"]["stage_id"]) + info = broadcast_object({}, config["pipe_comm"], src = it["inside_pipe"]["pipe_rank"]) has_grad = info.pop("has_grad") it.update(info) tensor = torch.empty(it["shape"]).cuda().requires_grad_() - tensor = broadcast(tensor, it["inside_pipe"]["stage_id"], config["pipe_comm"]) + tensor = broadcast(tensor, it["inside_pipe"]["pipe_rank"], config["pipe_comm"]) if has_grad: grad = torch.empty(it["shape"]).cuda() - tensor = tensor.chunk(stages, dim=0)[stage_id].clone() + tensor = tensor.chunk(pipe_size, dim=0)[pipe_rank].clone() it["tensor"] = tensor if has_grad: - grad = broadcast(grad, it["inside_pipe"]["stage_id"], config["pipe_comm"]) - grad = grad.chunk(stages, dim=0)[stage_id].clone() + grad = broadcast(grad, it["inside_pipe"]["pipe_rank"], config["pipe_comm"]) + grad = grad.chunk(pipe_size, dim=0)[pipe_rank].clone() tensor.grad = grad it["shape"] = (it["shape"][0]//config["pipe_size"],) + it["shape"][1:] diff --git a/bmtrain/layer.py b/bmtrain/layer.py index e071e01b..8dee7167 100644 --- 
a/bmtrain/layer.py +++ b/bmtrain/layer.py @@ -33,10 +33,13 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): for name, param in self._parameters.items(): if param is not None: if isinstance(param, DistributedParameter):#and not param._in_block: - if param._in_block: - destination[prefix + name] = param.tp_gather().detach() # sync operation + if config["save_param_gather"]: + if param._in_block: + destination[prefix + name] = param.tp_gather().detach() # sync operation + else: + destination[prefix + name] = param.gather_all().detach() # sync operation else: - destination[prefix + name] = param.gather_all().detach() # sync operation + destination[prefix + name] = param.clone().detach() # sync operation if config['save_param_to_cpu']: destination[prefix + name] = destination[prefix + name].cpu() else: @@ -110,14 +113,17 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 'the shape in current model is {}.' .format(key, input_param.shape, param.shape)) continue - verify_shape = torch.Size(param._original_shape if not tp_mode else param._tp_original_shape) + if config['load_param_gather']: + verify_shape = torch.Size(param._original_shape if not tp_mode else param._tp_original_shape) + else: + verify_shape = param.shape if not is_param_lazy and isinstance(param, DistributedParameter) and input_param.shape != verify_shape: error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' 'the shape in current model is {}.' .format(key, input_param.shape, verify_shape)) try: with torch.no_grad(): - if isinstance(param, DistributedParameter): + if isinstance(param, DistributedParameter) and config['load_param_gather']: tp_split_dim = param._tp_split_dim if tp_mode and tp_split_dim >= 0: input_param = tp_split_tensor(input_param, tp_split_dim) diff --git a/bmtrain/lr_scheduler/warmup.py b/bmtrain/lr_scheduler/warmup.py index 0f08a600..c690a154 100644 --- a/bmtrain/lr_scheduler/warmup.py +++ b/bmtrain/lr_scheduler/warmup.py @@ -16,7 +16,7 @@ def __init__(self, optimizer : torch.optim.Optimizer, start_lr, warmup_iter, end self.warmup_iter = warmup_iter self.end_iter = end_iter self.optimizer = optimizer - self.num_iter = num_iter + self.num_iter = num_iter self._current_lr = None self.step(self.num_iter) diff --git a/bmtrain/optim/optim_manager.py b/bmtrain/optim/optim_manager.py index 7aa1bb81..e45daa23 100644 --- a/bmtrain/optim/optim_manager.py +++ b/bmtrain/optim/optim_manager.py @@ -52,6 +52,7 @@ def __init__(self, loss_scale_steps : int = 1024, min_loss_scale = 1, max_loss_scale = float("inf"), + grad_scale : Optional[int] = None, ): if loss_scale is not None: self.loss_scale = loss_scale @@ -64,6 +65,9 @@ def __init__(self, self.loss_scale_steps = loss_scale_steps self.min_loss_scale = min_loss_scale self.max_loss_scale = max_loss_scale + if grad_scale is None: + grad_scale = config['zero_size'] + self.grad_scale = grad_scale self.optimizers = [] self.lr_schedulers = [] @@ -85,7 +89,7 @@ def add_optimizer( def scale_loss(self, loss : torch.Tensor) -> torch.Tensor: - return loss * (self.loss_scale / (config['world_size']//(config['tp_size']*config['pipe_size']))) # loss scale + return loss * ( self.loss_scale / self.grad_scale ) # loss scale def backward(self, loss : torch.Tensor): """ @@ -132,6 +136,12 @@ def step(self): self.zero_grad() return for optimizer, lr_scheduler in zip(self.optimizers, self.lr_schedulers): + try: + check_overflow(optimizer.param_groups) + except OverflowError: + has_overflow = True + 
print_rank("Gradient overflow, change scale from %lf to %lf" % (self.loss_scale, self.loss_scale / self.loss_scale_factor)) + break if hasattr(optimizer, "_bmtrain_optimizer") and optimizer._bmtrain_optimizer: optimizer.step(scale=self.loss_scale) else: diff --git a/bmtrain/param_init.py b/bmtrain/param_init.py index a46c7845..7d188a28 100644 --- a/bmtrain/param_init.py +++ b/bmtrain/param_init.py @@ -1,4 +1,4 @@ -from typing import Generator, Iterable, List, Tuple +from typing import Generator, Iterable, List, Tuple, Union import torch from .block_layer import Block from .parameter import DistributedParameter @@ -42,20 +42,22 @@ def iterate_parameters(model : torch.nn.Module): return [] yield val -def init_parameters(model : torch.nn.Module): +def init_parameters(models : Union[List[torch.nn.Module], torch.nn.Module]): """ Initialize the parameters of the model by calling the init_method of the distributed parameters. """ - - modules = model.named_modules() - for module_prefix, module in modules: - if isinstance(module, Block): - module.init_parameters() - else: - init_distributed_parameter( iterate_parameters(module) ) - - current_stream = torch.cuda.current_stream() - config['load_stream'].wait_stream(current_stream) + if not isinstance(models, list): + models = [models] + for model in models: + modules = model.named_modules() + for module_prefix, module in modules: + if isinstance(module, Block): + module.init_parameters() + else: + init_distributed_parameter( iterate_parameters(module) ) + + current_stream = torch.cuda.current_stream() + config['load_stream'].wait_stream(current_stream) def grouped_parameters(model : torch.nn.Module) -> Generator[Tuple[str, List[torch.nn.Parameter]], None, None]: """ diff --git a/bmtrain/pipe/__init__.py b/bmtrain/pipe/__init__.py new file mode 100644 index 00000000..410e3437 --- /dev/null +++ b/bmtrain/pipe/__init__.py @@ -0,0 +1,2 @@ +from .schedule import pipeline_forward_backward +from .store import load_model_pipe, save_model_pipe \ No newline at end of file diff --git a/bmtrain/pipe/comm.py b/bmtrain/pipe/comm.py new file mode 100644 index 00000000..d65f0f1d --- /dev/null +++ b/bmtrain/pipe/comm.py @@ -0,0 +1,114 @@ +import torch +from bmtrain.distributed.ops import groupcall,all_reduce +from bmtrain.distributed.p2p_ops import send_tensors, recv_tensors +from bmtrain.global_var import config +from collections.abc import Iterable +from bmtrain.synchronize import synchronize +class PipeCommander: + def __init__(self, topo, model, data_iter, num_micros, num_warmup, forward_only, interleaving_size) -> None: + self.topo = topo + self.comm = config['pipe_comm'] + self.input_generator = data_iter + self.num_micros = num_micros + self.num_warmup = num_warmup + self.forward_only = forward_only + self.interleaving_size = interleaving_size + self.model = model + + def is_first_stage(self): + if self.interleaving_size == 1: + return self.topo.is_first_rank("pipe") + else: + raise ValueError("Now only supoort interleaving_size == 1") + + def is_last_stage(self): + if self.interleaving_size == 1: + return self.topo.is_last_rank("pipe") + else: + raise ValueError("Now only supoort interleaving_size == 1") + + + def param_reduce(self, module): + for name, param in module.named_parameters(): + p = all_reduce(param, "sum", config["pipe_tied_comm"]) + param.data = p + + def get_data(self): + micro_batch = next(self.input_generator) + assert isinstance(micro_batch, Iterable) + return micro_batch + + def send_next(self, tensors): + if not self.is_last_stage(): + if 
not isinstance(tensors, Iterable): + tensors = [tensors] + elif not isinstance(tensors, list): + tensors = list(tensors) + send_tensors(tensors, self.topo.pipe_rank + 1, self.comm) + + def send_prev(self, tensors): + if not self.is_first_stage(): + if not isinstance(tensors, Iterable): + tensors = [tensors] + elif not isinstance(tensors, list): + tensors = list(tensors) + send_tensors(tensors, self.topo.pipe_rank - 1, self.comm) + + def wait(self): + torch.cuda.current_stream().wait_stream(config["pp_comm_stream"]) + + def recv_prev(self, need_data=False): + if not self.is_first_stage(): + res = recv_tensors(self.topo.pipe_rank - 1, self.comm) + for idx,tensor in enumerate(res): + if idx == 0: + tensor.requires_grad_() + data = self.get_data() + # return hidden state and data + return res, data + else: + if need_data: + # for first stage , only data + return [None], self.get_data() + else: + # empty load for first stage + return [None], [None] + + def recv_next(self): + if not self.is_last_stage(): + res = recv_tensors(self.topo.pipe_rank + 1, self.comm) + return res + else: + return [None] + + def allocate_tensor(self, shape, dtype): + return torch.empty(shape, dtype=dtype, device="cuda") + + def is_even_rank(self): + return self.topo.pipe_rank % 2 == 0 + + def send_forward_recv_backward(self, forward_state): + if not self.is_last_stage(): + if forward_state[0] is not None: + self.send_next(forward_state) + backward_grad = self.recv_next() + else: + backward_grad = [None] + return backward_grad + + def send_backward_recv_forward(self, backward_grad, need_data=False): + if not self.is_first_stage(): + forward_state, data = self.recv_prev() + if backward_grad[0] is not None: + self.send_prev(backward_grad) + else: + if need_data: + forward_state = [None] + data = self.get_data() + else: + forward_state = [None] + data = [None] + return forward_state, data + + + diff --git a/bmtrain/pipe/schedule.py b/bmtrain/pipe/schedule.py new file mode 100644 index 00000000..29299fb9 --- /dev/null +++ b/bmtrain/pipe/schedule.py @@ -0,0 +1,170 @@ +import sys +from bmtrain.global_var import config +import bmtrain as bmt +from .comm import PipeCommander +import torch +import logging +from typing import Iterable + + +def backward_func(inp, backward_step, output, grad_output, optim_manager=None): + """Backward step through passed-in output tensor. + + If last stage, output_tensor_grad is None, otherwise gradient of loss + with respect to stage's output tensor. 
+ + Returns gradient of loss with respect to input tensor (None if first + stage).""" + + if not isinstance(inp, list) : + inp = [inp] + for x in inp: + if x is not None and (torch.is_tensor(x) and x.requires_grad): + x.retain_grad() + if not isinstance(output, Iterable): + output = [output] + if not isinstance(grad_output, Iterable): + grad_output = [grad_output] + backward_step(output[0], grad_output[0]) + current_stream = torch.cuda.current_stream() + current_stream.wait_stream(bmt.config['load_stream']) + input_grad = [None] + if inp is not None: + input_grad = [] + for x in inp: + if x is None or (not torch.is_tensor(x)) or (not x.requires_grad): + input_grad.append(None) + else: + input_grad.append(x.grad) + + return input_grad + +def forward_func(model, forward_step, inp, data, micro_idx, is_last_micro=False): + output = forward_step(model, inp[0], data) + if not isinstance(output, list) and not isinstance(output, tuple): + output = [output] + return output + +def get_logger(rank, level, print_to_screen=False): + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + logger = logging.getLogger('pipeline') + logger.setLevel(level) + if print_to_screen: + if rank == 0: + ch = logging.StreamHandler() + ch.setLevel(level) + ch.setFormatter(formatter) + logger.addHandler(ch) + fh = logging.FileHandler(f'pipe_{rank}.log',mode="w") + fh.setLevel(level) + fh.setFormatter(formatter) + logger.addHandler(fh) + return logger + +def pipeline_forward_backward(model, data_iterator, forward_step, backward_step, micro_batch_size, num_micros, debug_log=False): + """Forward and backward the pipeline model. + + Args: + models (TransformerBlocklist): The list of models. + data_iterator (iterator): The iterator of the dataset. + forward_step(function): Describe how to forward the model and how to get loss + micro_batch_size (int): The micro batch size. + + Returns: + torch.Tensor: The loss of the model. 
+ """ + + # forwrad unpack + loss = None + if 'logger' not in config: + if debug_log: + config['logger'] = get_logger(bmt.config['pipe_rank'], level="INFO", print_to_screen=True) + else: + config['logger'] = logging.getLogger("dummy") + config['logger'].addHandler(logging.NullHandler()) + + micro_batch_size = micro_batch_size + num_micro_batches = num_micros + global_batch_size = micro_batch_size * num_micro_batches + assert (num_micro_batches) % config["pipe_size"] == 0, "The number of micro batches must be divisible by the pipeline size" + config["micros"] = num_micro_batches + topo = config["topology"] + logger = config['logger'] + logger.info("topo: {}".format(topo)) + logger.info("num_micro_batches: {}".format(num_micro_batches)) + logger.info("micro_batch_size: {}".format(micro_batch_size)) + logger.info("global_batch_size: {}".format(global_batch_size)) + # construct Pipe Commander + forward_only = False + logger.info("forward_only: {}".format(forward_only)) + if forward_only: + num_warmup = num_micro_batches + else: + num_warmup = topo.pipe_size - topo.pipe_rank - 1 + interleaving_size = 1 + commander = PipeCommander(topo,model=model, data_iter=data_iterator, num_micros=num_micro_batches,\ + num_warmup=num_warmup, forward_only=False, \ + interleaving_size=interleaving_size \ + ) + inps = [] + outputs = [] + logger.info("num_warmup: {}".format(num_warmup)) + for micro in range(num_warmup): + inp, data = commander.recv_prev(need_data=True) + logger.info("{} recv micro {}th from prev neighbour".format(bmt.config["topology"].pipe_rank, micro)) + output = forward_func(model, forward_step, inp, data, micro) + logger.info("{} micro forward".format(micro)) + # send activations + commander.send_next(output) + logger.info("{} send micro {}th to next neighbour".format(bmt.config["topology"].pipe_rank, micro)) + if not forward_only: + inps.append(inp) + outputs.append(output) + remain_batch = num_micro_batches - num_warmup + logger.info("remain_batch: {}".format(remain_batch)) + if remain_batch > 0: + inp, data = commander.recv_prev(need_data=True) + logger.info("recv micro from prev neighbour") + for micro in range(num_micro_batches - num_warmup): + is_last_micro = micro == num_micro_batches - num_warmup - 1 + output = forward_func(model, forward_step, inp, data, micro + num_warmup, is_last_micro) + if commander.is_last_stage(): + loss = output[0] + logger.info("{} micro forward".format(micro+num_warmup)) + grad_output = commander.send_forward_recv_backward(output) + + inps.append(inp) + outputs.append(output) + + logger.info("{} send micro hidden state {}th to next neighbour and recv micro grad {} from next neighbour".format(bmt.config["topology"].pipe_rank, micro + num_warmup, micro)) + + inp = inps.pop(0) + output = outputs.pop(0) + + inp_grad = backward_func(inp, backward_step, output, grad_output) + logger.info("{} micro backward".format(micro+num_warmup)) + if micro == remain_batch - 1: + inp = None + commander.send_prev(inp_grad) + logger.info("{} send micro grad {}th to prev neighbour".format(bmt.config["topology"].pipe_rank, micro + num_warmup)) + else: + logger.info("send backward and recv forward") + inp, data = commander.send_backward_recv_forward(inp_grad, need_data=True) + if not forward_only: + logger.info("cooling stage") + for i in range(num_warmup): + logger.info("{} recv micro grad {}th from next neighbour".format(bmt.config["topology"].pipe_rank, num_micro_batches - num_warmup + i)) + inp = inps.pop(0) + output = outputs.pop(0) + grad_output = commander.recv_next() + 
logger.info("{} micro backward".format(num_micro_batches - num_warmup + i)) + input_grad = backward_func( + inp, backward_step, output , grad_output, + ) + logger.info("{} send micro grad {}th to prev neighbour".format(bmt.config["topology"].pipe_rank, i)) + + commander.send_prev(input_grad) + blocklist = model.get_blocklist() + # blocklist.reduce_tied_module() + + diff --git a/bmtrain/pipe/store.py b/bmtrain/pipe/store.py new file mode 100644 index 00000000..bbfc6ab9 --- /dev/null +++ b/bmtrain/pipe/store.py @@ -0,0 +1,67 @@ +import bmtrain as bmt +import torch +import re +from collections import OrderedDict + +def partition(pipe_rank, pipe_size, len_modules): + part_lens = [0]+[(len_modules // pipe_size + (i < (len_modules % pipe_size))) for i in range(pipe_rank+1)] + start = sum(part_lens[:pipe_rank+1]) + end = start + part_lens[pipe_rank+1] + return start,end + +def key_process(key, pipe_size , rank, start, end): + res = re.search("\.(\d+)\.", key) + if res is not None: + layer_idx = int(res.group(1)) + else: + layer_idx = None + if layer_idx is None or (layer_idx >= start and layer_idx < end): + if layer_idx is not None: + return re.sub("\.(\d+)\.", "."+str(layer_idx - start)+".", key) + else: + return key + +def get_len_modules(state): + max_len = 0 + for key in state: + s = re.search("\.(\d+)\.", key) + if s is not None: + res = int(s.group(1)) + if res>max_len: + max_len = res + return max_len+1 + +def get_state_dict_pipe(path): + pipe_size = bmt.config["pipe_size"] + pipe_rank = bmt.config["pipe_rank"] + + if bmt.rank() == 0: + ds_state_dict = bmt.store.DistributedStateDictWrapper(torch.load(path)) + else: + ds_state_dict = bmt.store.DistributedStateDictWrapper({}) + + len_modules = get_len_modules(ds_state_dict) + s,e = partition(pipe_rank, pipe_size, len_modules) + state_dict = OrderedDict() + + for key in ds_state_dict: + param = ds_state_dict[key].broadcast() + k_p = key_process(key, pipe_size, pipe_rank, s, e) + if k_p is not None: + state_dict[k_p] = param + else: + del param + return state_dict + +def load_model_pipe(model, path, load_whole=False): + """ + load_whole: Boolean, if True, load from the whole model file, else load model from the pipeline/tensor parallel model file + """ + if load_whole: + state_dict = get_state_dict_pipe(path) + model.load_state_dict(state_dict, strict=False) + else: + bmt.load(model, path, load_gather=False) + +def save_model_pipe(model, path): + bmt.save(model, path, save_gather=False) diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 4d3b17ad..5954f9b4 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -5,7 +5,7 @@ from typing import Dict, Iterable, Iterator, Tuple, Union, List import torch -from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations +from .distributed import all_gather, broadcast, all_reduce, send_tensor, recv_tensor from .global_var import config from . 
import nccl from .zero_context import ( @@ -64,16 +64,16 @@ def backward(ctx, grads, arg_grads): if requires_grad: grad = torch.cat([ctx.args_list[m][idx].grad for m in range(num_micros)], dim=0) grad = all_reduce(grad, "sum", config["pipe_comm"]) - split_size = topo.stages if ctx.batch_related[idx] else num_micros + split_size = topo.pipe_size if ctx.batch_related[idx] else num_micros grad = grad.chunk(split_size) if ctx.batch_related[idx]: - arg_grads.append(grad[topo.stage_id]) + arg_grads.append(grad[topo.pipe_rank]) else: arg_grads.append(grad[0]) else: arg_grads.append(None) arg_grads.append(None) #for append(batch_related) - return grads.chunk(topo.stages, dim=0)[topo.stage_id], *arg_grads + return grads.chunk(topo.pipe_size, dim=0)[topo.pipe_rank], *arg_grads class PipePostFunction(torch.autograd.Function): @staticmethod @@ -81,24 +81,24 @@ def forward(ctx, last_hidden, hidden_states=None, forward_stage_ranges=None, bac topo = config['topology'] ctx.return_hidden_states = return_hidden_states last_hidden = broadcast(last_hidden, config["pipe_size"] - 1, config["pipe_comm"]) - last_hidden = last_hidden.chunk(topo.stages, dim=0) - output = last_hidden[topo.stage_id] + last_hidden = last_hidden.chunk(topo.pipe_size, dim=0) + output = last_hidden[topo.pipe_rank] output.requires_grad_() if return_hidden_states: - ctx.stage_id = topo.stage_id - ctx.stages = topo.stages + ctx.pipe_rank = topo.pipe_rank + ctx.pipe_size = topo.pipe_size ctx.backward_stage_ranges = backward_stage_ranges middle_hiddens = [] - for stage_id in range(ctx.stages): - if ctx.stage_id == stage_id: + for pipe_rank in range(ctx.pipe_size): + if ctx.pipe_rank == pipe_rank: middle_hidden = hidden_states else: - middle_shape = (forward_stage_ranges[stage_id],) + last_hidden_shape + middle_shape = (forward_stage_ranges[pipe_rank],) + last_hidden_shape middle_hidden = torch.zeros(middle_shape, device=hidden_states.device, dtype=hidden_states.dtype) - middle_hidden = broadcast(middle_hidden, stage_id, config["pipe_comm"]) - middle_hidden = middle_hidden.chunk(ctx.stages, dim=1) - middle_hidden = middle_hidden[ctx.stage_id].clone() + middle_hidden = broadcast(middle_hidden, pipe_rank, config["pipe_comm"]) + middle_hidden = middle_hidden.chunk(ctx.pipe_size, dim=1) + middle_hidden = middle_hidden[ctx.pipe_rank].clone() middle_hiddens.append(middle_hidden) middle_hiddens = torch.cat(middle_hiddens, dim=0) middle_hiddens.requires_grad_() @@ -112,12 +112,12 @@ def backward(ctx, grads, grad_middle=None): grad_list = grad_list.flatten(start_dim=0, end_dim=1) if ctx.return_hidden_states: - for stage_id in range(ctx.stages): - layer_range = ctx.backward_stage_ranges[stage_id] + for pipe_rank in range(ctx.pipe_size): + layer_range = ctx.backward_stage_ranges[pipe_rank] grad_middle_state = grad_middle[layer_range] grad_middle_state = all_gather(grad_middle_state.transpose(0,1), config["pipe_comm"]) grad_middle_state = grad_middle_state.flatten(start_dim=0, end_dim=1).transpose(0, 1) - if ctx.stage_id == stage_id: + if ctx.pipe_rank == pipe_rank: grad_hidden_state_list = grad_middle_state return grad_list, grad_hidden_state_list, None, None, None, None else: @@ -125,12 +125,12 @@ def backward(ctx, grads, grad_middle=None): class StagePreFunction(torch.autograd.Function): @staticmethod - def forward(ctx, input, stage_id): - ctx.stage_id = stage_id - ctx.is_first_stage = stage_id == 0 - ctx.is_last_stage = stage_id == config['pipe_size'] - 1 + def forward(ctx, input, pipe_rank): + ctx.pipe_rank = pipe_rank + ctx.is_first_stage = 
pipe_rank == 0 + ctx.is_last_stage = pipe_rank == config['pipe_size'] - 1 if not ctx.is_first_stage: - input = recv_activations(stage_id - 1, config['pipe_comm']) + input = recv_tensor(pipe_rank - 1, config['pipe_comm']) input.requires_grad_() return input return input @@ -143,28 +143,28 @@ def backward(ctx, grad_outputs): with torch.cuda.stream(config['pp_comm_stream']): config['pp_comm_stream'].wait_stream(current_stream) send_data.record_stream(config['pp_comm_stream']) - send_activations(send_data, ctx.stage_id - 1, config['pipe_comm']) + send_tensor(send_data, ctx.pipe_rank - 1, config['pipe_comm']) return grad_outputs, None class StagePostFunction(torch.autograd.Function): @staticmethod - def forward(ctx, outputs, stage_id): - ctx.stage_id = stage_id - ctx.is_first_stage = stage_id == 0 - ctx.is_last_stage = stage_id == config['pipe_size'] - 1 + def forward(ctx, outputs, pipe_rank): + ctx.pipe_rank = pipe_rank + ctx.is_first_stage = pipe_rank == 0 + ctx.is_last_stage = pipe_rank == config['pipe_size'] - 1 if not ctx.is_last_stage: send_data = outputs[0] if isinstance(outputs, tuple) else outputs current_stream = torch.cuda.current_stream() with torch.cuda.stream(config['pp_comm_stream']): config['pp_comm_stream'].wait_stream(current_stream) send_data.record_stream(config['pp_comm_stream']) - send_activations(send_data.detach(), stage_id + 1, config['pipe_comm']) + send_tensor(send_data.detach(), pipe_rank + 1, config['pipe_comm']) return outputs @staticmethod def backward(ctx, grad_outputs): if not ctx.is_last_stage: - pre_grad_inputs = recv_activations(ctx.stage_id + 1, config['pipe_comm']) + pre_grad_inputs = recv_tensor(ctx.pipe_rank + 1, config['pipe_comm']) return pre_grad_inputs, None return grad_outputs, None @@ -196,8 +196,8 @@ def __init__(self, modules: Iterable[torch.nn.Module], num_hidden=1) -> None: self._modules = {} self.layer_ids = [] topo = config["topology"] - self.stages = topo.stages - self.stage_id = topo.stage_id + self.pipe_size = topo.pipe_size + self.pipe_rank = topo.pipe_rank self.pipe_idx = topo.pipe_idx module_dict = {} for idx, module in enumerate(modules): @@ -205,7 +205,7 @@ def __init__(self, modules: Iterable[torch.nn.Module], num_hidden=1) -> None: module._zero_level = 2 #currently, only support ZeRO-2 in pipeline mode self._modules[str(idx)] = module - self.layer_ids = self.get_range_by_stage_id(self.stage_id) + self.layer_ids = self.get_range_by_pipe_rank(self.pipe_rank) pre_module = None for i,layer_id in enumerate(self.layer_ids): @@ -242,14 +242,14 @@ def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=Fa for micro_idx, (hidden_state, arg) in enumerate(zip(hidden_state_list, args_list)): micro_hidden_states = [] - hidden_state = StagePreFunction.apply(hidden_state, self.stage_id) + hidden_state = StagePreFunction.apply(hidden_state, self.pipe_rank) for idx,layer_id in enumerate(self.layer_ids): - self._modules[str(layer_id)]._micro_idx = micro_idx + # self._modules[str(layer_id)]._micro_idx = micro_idx if return_hidden_states: micro_hidden_states.append(hidden_state) hidden_state = self._modules[str(layer_id)](hidden_state, *arg) - hidden_state = StagePostFunction.apply(hidden_state, self.stage_id) + hidden_state = StagePostFunction.apply(hidden_state, self.pipe_rank) outputs.append(hidden_state) if return_hidden_states: @@ -262,27 +262,27 @@ def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=Fa hidden_states = torch.cat(hidden_states, dim=1) forward_stage_ranges = [] backward_stage_ranges = 
[] - for stage_id in range(self.stages): - forward_stage_ranges.append(self.get_part_len_by_stage_id(stage_id)) - backward_stage_ranges.append(self.get_range_by_stage_id(stage_id)) + for pipe_rank in range(self.pipe_size): + forward_stage_ranges.append(self.get_part_len_by_pipe_rank(pipe_rank)) + backward_stage_ranges.append(self.get_range_by_pipe_rank(pipe_rank)) outputs, hidden_states = PipePostFunction.apply(last_hidden, hidden_states, forward_stage_ranges, backward_stage_ranges, last_hidden_shape, return_hidden_states) return outputs, hidden_states else: outputs = PipePostFunction.apply(last_hidden) return outputs - def get_range_by_stage_id(self, stage_id : int) -> List[int]: - part_lens = [0]+[self.get_part_len_by_stage_id(i) for i in range(stage_id+1)] - start = sum(part_lens[:stage_id+1]) - end = start + part_lens[stage_id+1] + def get_range_by_pipe_rank(self, pipe_rank : int) -> List[int]: + part_lens = [0]+[self.get_part_len_by_pipe_rank(i) for i in range(pipe_rank+1)] + start = sum(part_lens[:pipe_rank+1]) + end = start + part_lens[pipe_rank+1] return range(start, end) - def get_part_len_by_stage_id(self, stage_id : int) -> int: - return len(self) // self.stages + (stage_id < (len(self) % self.stages)) + def get_part_len_by_pipe_rank(self, pipe_rank : int) -> int: + return len(self) // self.pipe_size + (pipe_rank < (len(self) % self.pipe_size)) def get_stage_by_layer_id(self, layer_id : int) -> int: - part_len = len(self) // self.stages - rest = len(self) % self.stages + part_len = len(self) // self.pipe_size + rest = len(self) % self.pipe_size if layer_id // (part_len + 1) < rest: return layer_id // (part_len + 1) else: @@ -307,8 +307,8 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): else: assert list(dst.keys()) == [name+n for n, parameter in module._module.named_parameters()] for key, tensor in dst.items(): - send_activations(tensor.cuda(), 0, config['pipe_comm']) + send_tensor(tensor.cuda(), 0, config['pipe_comm']) if config['rank'] == 0 and idx not in self.layer_ids: for n, parameter in module._module.named_parameters(): - destination[name+n] = recv_activations(self.get_stage_by_layer_id(idx), config['pipe_comm']).cpu() + destination[name+n] = recv_tensor(self.get_stage_by_layer_id(idx), config['pipe_comm']).cpu() diff --git a/bmtrain/store.py b/bmtrain/store.py index 2a3ee02c..488fb0b6 100644 --- a/bmtrain/store.py +++ b/bmtrain/store.py @@ -1,7 +1,6 @@ from collections import OrderedDict from typing import Dict import torch - from .pipe_layer import PipelineTransformerBlockList from .block_layer import TransformerBlockList from .global_var import config @@ -11,6 +10,7 @@ from typing import Mapping import threading import bmtrain as bmt +import os def _save_to_state_dict(model : torch.nn.Module, rank, destination, prefix): if isinstance(model, Block): @@ -24,6 +24,21 @@ def _save_to_state_dict(model : torch.nn.Module, rank, destination, prefix): destination._metadata = OrderedDict() model._save_to_state_dict(destination, prefix, False) +def _save_to_each_rank(model : torch.nn.Module, destination=None, prefix=''): + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict(version=model._version) + _save_to_state_dict(model, 0, destination, prefix) + for name, module in model._modules.items(): + if module is not None: + _save_to_each_rank(module, destination, prefix + name + '.') + for hook in model._state_dict_hooks.values(): + hook_result = hook(model, 
destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + def _save_to_local_rank0(model : torch.nn.Module, destination=None, prefix=''): if destination is None: destination = OrderedDict() @@ -88,7 +103,7 @@ def async_save_to_file(state_dict, file_path): config['finish_save'] = True print("finish save state_dict to ", file_path) -def save(model : torch.nn.Module, file_name : str, non_blocking : bool=False): +def save(model : torch.nn.Module, file_name : str, non_blocking : bool=False, save_gather : bool=True): """Saves the model to the file. Similar to torch.save, but it used for distributed modules. @@ -103,8 +118,16 @@ def save(model : torch.nn.Module, file_name : str, non_blocking : bool=False): >>> bmtrain.save(model, "model.pt") """ torch.cuda.synchronize() - state_dict = _save_to_rank0(model) - if config["rank"] == 0: + + if save_gather: + save_method = _save_to_rank0 + else: + save_method = _save_to_each_rank + file_name = f"{file_name}_rank_{bmt.rank()}" + tmp = bmt.config['save_param_gather'] + bmt.config['save_param_gather'] = save_gather + state_dict = save_method(model) + if config["rank"] == 0 or not save_gather: if non_blocking is False: torch.save(state_dict, file_name) else: @@ -118,6 +141,9 @@ def save(model : torch.nn.Module, file_name : str, non_blocking : bool=False): config['save_thread'] = threading.Thread(target=async_save_to_file, args=(state_dict, file_name)) config['save_thread'].start() bmt.synchronize() + bmt.config['save_param_gather'] = tmp + + DTYPE_LIST = [ torch.float64, @@ -299,7 +325,7 @@ def __iter__(self): # pytorch 1.12.0 updated the load_state_dict method, which needs the state_dict to be a `Mapping`. return iter(self.keys()) -def load(model : torch.nn.Module, file_name : str, strict : bool = True): +def load(model : torch.nn.Module, file_name : str, strict : bool = True, load_gather : bool = True): """Loads the model from the file. Similar to torch.load, but it uses less memory when loading large models. @@ -312,14 +338,39 @@ def load(model : torch.nn.Module, file_name : str, strict : bool = True): Example: >>> bmtrain.load(model, "model.pt", strict=True) """ - if config['rank'] == 0: - state_dict = DistributedStateDictWrapper(torch.load(file_name)) + tmp = config['load_param_gather'] + config['load_param_gather'] = load_gather + if load_gather: + if config['rank'] == 0: + state_dict = DistributedStateDictWrapper(torch.load(file_name)) + else: + state_dict = DistributedStateDictWrapper({}) else: - state_dict = DistributedStateDictWrapper({}) + if "rank" not in file_name: + file_name = f"{file_name}_rank_{bmt.rank()}" + state_dict = torch.load(file_name) ret = model.load_state_dict( state_dict, strict = strict ) + config['load_param_gather'] = tmp torch.cuda.synchronize() return ret + +def clean(file_name : str): + """Cleans the file. + + Args: + file_name (str): The file name of the checkpoint. + + Example: + >>> bmtrain.clean("model.pt") + """ + if bmt.rank() == 0: + parent = os.path.dirname(os.path.abspath(file_name)) + for f in os.listdir(parent): + if f.startswith(file_name): + os.remove(os.path.join(parent, f)) + + diff --git a/bmtrain/synchronize.py b/bmtrain/synchronize.py index d562cc21..9be6ccad 100644 --- a/bmtrain/synchronize.py +++ b/bmtrain/synchronize.py @@ -2,17 +2,18 @@ from . import distributed, nccl from .global_var import config import warnings +from typing import Optional -def synchronize(): +def synchronize(comm=None): """ Synchronize all the workers across all nodes. 
(both CPU and GPU are synchronized) """ if not config["initialized"]: raise RuntimeError("BMTrain is not initialized") - + comm = config['comm'] if comm is None else comm with torch.cuda.stream(config['barrier_stream']): barrier = torch.cuda.FloatTensor([1]) - nccl.allReduce(barrier.storage(), barrier.storage(), 'sum', config['comm']) + nccl.allReduce(barrier.storage(), barrier.storage(), 'sum', comm) config['barrier_stream'].synchronize() def wait_loader(): @@ -24,14 +25,17 @@ def wait_loader(): config['calc_stream'].record_event(config['load_event']) -def sum_loss(loss : torch.Tensor): +def sum_loss(loss : torch.Tensor, comm: Optional[nccl.NCCLCommunicator] = None): """ Sum the loss across all workers. This is a helper function to reduce the loss across all workers. """ + if comm is None: + comm = config['comm'] warnings.warn("bmtrain.sum_loss is deprecated and will be removed in later version. Use bmtrain.distributed.all_reduce instead.", DeprecationWarning) - return distributed.all_reduce(loss, "sum") / config['world_size'] + + return distributed.all_reduce(loss, "avg", comm) def gather_result(result: torch.Tensor): warnings.warn("bmtrain.gather_result is deprecated and will be removed in later version. Use bmtrain.distributed.all_gather instead.", DeprecationWarning) diff --git a/bmtrain/utils.py b/bmtrain/utils.py index 8cb87808..bca462e2 100644 --- a/bmtrain/utils.py +++ b/bmtrain/utils.py @@ -66,7 +66,11 @@ def print_block(title : str, content : Optional[str] = None, file=sys.stdout): print("=" * left_title + " " + title + " " + "=" * right_title, file=file) if content is not None: print(content, file=file) - + +def print_rank_pp(*args, pipe_rank=0, **kwargs): + if config['topology'].pipe_rank == pipe_rank: + print(*args, **kwargs) + def print_rank(*args, rank=0, **kwargs): """ Prints the message only on the `rank` of the process. 
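A minimal usage sketch of the non-gathered (per-rank) checkpoint helpers added in `bmtrain/store.py` above. Assumptions: `bmt.init_distributed()` has already been called, `model` is any BMTrain module, and the file names are placeholders.

```python
import bmtrain as bmt

# Per-rank checkpointing: every rank writes and reads its own shard.
bmt.save(model, "ckpt.pt", save_gather=False)   # each rank writes ckpt.pt_rank_<rank>
bmt.load(model, "ckpt.pt", load_gather=False)   # each rank loads its matching ckpt.pt_rank_<rank>
bmt.clean("ckpt.pt")                            # rank 0 removes every file whose name starts with ckpt.pt

# Gathered checkpointing (the default) is unchanged: rank 0 writes a single file.
bmt.save(model, "full.pt")
bmt.load(model, "full.pt")
```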
diff --git a/bmtrain/zero_context.py b/bmtrain/zero_context.py index 653f40fa..94c0112d 100644 --- a/bmtrain/zero_context.py +++ b/bmtrain/zero_context.py @@ -4,7 +4,7 @@ from .synchronize import wait_loader class ZeroContext: - def __init__(self, block : 'Block', ctx_dict : dict = None, pipe = False) -> None: + def __init__(self, block : 'Block', ctx_dict : dict = None) -> None: self.block = block self.ctx_dict = ctx_dict self._param_buffer = {} @@ -16,6 +16,10 @@ def __init__(self, block : 'Block', ctx_dict : dict = None) -> Non def enter(self, flag=0, requires_grad=False): """ gather parameters + flag = 0: normal mode + flag = 1: gather the params without releasing them, and save them in ctx_dict + flag = 2: do not gather; use the params saved in ctx_dict + """ if self.block._ready: return @@ -26,6 +30,8 @@ def enter(self, flag=0, requires_grad=False): with torch.cuda.stream(config["load_stream"]): for kw, val in self.block._storage_info.items(): assert self.block._storage_params[kw].is_cuda + if val["world_size"] == 1: + continue assert kw not in self._grad_buffer assert kw not in self._param_buffer local_param = self.block._storage_params[kw] @@ -41,6 +47,8 @@ def enter(self, flag=0, requires_grad=False): if flag != 2: nccl.groupStart() for kw, val in self.block._storage_info.items(): + if val["world_size"] == 1: + continue nccl.allGather( self.block._storage_params[kw].storage(), self._param_buffer[kw], @@ -53,6 +61,8 @@ def enter(self, flag=0, requires_grad=False): # set wait stream for each storage for kw in self.block._storage_info.keys(): + if self.block._storage_info[kw]['world_size'] == 1: + continue if flag != 2: self._param_tensor[kw].record_stream(current_stream) if requires_grad and kw in self._grad_tensor: @@ -64,6 +74,9 @@ def enter(self, flag=0, requires_grad=False): offset = param["offset"] shape = param["shape"] + if self.block._storage_info[kw_name]["world_size"] == 1: + continue + if flag != 2: dtype = self._param_buffer[kw_name].dtype device = self._param_buffer[kw_name].device @@ -90,8 +103,11 @@ def exit(self, flag=0, backward=False): self.block._ready = False if backward: for kw, val in self.block._storage_info.items(): - local_param = self.block._storage_params[kw] + if val['world_size'] == 1: + continue + + local_param = self.block._storage_params[kw] # accumulate previous gradient if local_param.requires_grad: if local_param.grad is None: @@ -106,8 +122,11 @@ def exit(self, flag=0, backward=False): with torch.cuda.stream(config["load_stream"]): nccl.groupStart() for kw, val in self.block._storage_info.items(): - local_param = self.block._storage_params[kw] + if val["world_size"] == 1: + continue + + local_param = self.block._storage_params[kw] # scatter gradient if local_param.requires_grad: nccl.reduceScatter( @@ -127,6 +146,8 @@ def exit(self, flag=0, backward=False): # Release all parameters from buffer to block_storge for param in self.block._param_info: kw_name = param["kw_name"] + if self.block._storage_info[kw_name]["world_size"] == 1: + continue dtype = self.block._storage_params[kw_name].dtype device = self.block._storage_params[kw_name].device if "begin" not in param: diff --git a/docs/UPDATE_0.2.0.md b/docs/UPDATE_0.2.0.md index 92819afd..5ee04639 100644 --- a/docs/UPDATE_0.2.0.md +++ b/docs/UPDATE_0.2.0.md @@ -70,7 +70,7 @@ layers = bmt.PipelineTransformerBlockList([ ``` Replacing TransformerBlockList with PipelineTransformerBlockList allows the parallel algorithm to switch from ZeRO to pipeline parallelism. 
-The number of stages in the pipeline can be set by passing the `pipe_size` parameter to bmtrain.init_distributed. +The number of pipeline stages (`pipe_size`) can be set by passing the `pipe_size` parameter to bmtrain.init_distributed. ### 3. Others diff --git a/example/init_test.py b/example/init_test.py new file mode 100644 index 00000000..67d62581 --- /dev/null +++ b/example/init_test.py @@ -0,0 +1,2 @@ +import bmtrain +bmtrain.init_distributed(sp_size=4,tp_size=2) diff --git a/example/layers/embedding.py b/example/layers/embedding.py index f62151c4..9cbc6715 100644 --- a/example/layers/embedding.py +++ b/example/layers/embedding.py @@ -77,13 +77,11 @@ def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, def forward(self, input: torch.Tensor, projection : bool = False) -> torch.Tensor: if not projection: - out = F.embedding( + return F.embedding( input, self.weight, self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.sparse) - return out + self.norm_type, self.scale_grad_by_freq, self.sparse) else: - out = F.linear(input, self.weight) - return out + return F.linear(input, self.weight) / math.sqrt(self.embedding_dim) def extra_repr(self) -> str: s = '{num_embeddings}, {embedding_dim}' @@ -99,4 +97,4 @@ def extra_repr(self) -> str: s += ', sparse=True' return s.format(**self.__dict__) - + \ No newline at end of file diff --git a/example/models/__init__.py b/example/models/__init__.py index e7d1dcc9..a17709b8 100644 --- a/example/models/__init__.py +++ b/example/models/__init__.py @@ -1 +1,2 @@ -from .gpt import GPT \ No newline at end of file +from .gpt import GPT +from .pipe_gpt import GPTPipe \ No newline at end of file diff --git a/example/models/gpt.py b/example/models/gpt.py index ed604382..feb2bd59 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py @@ -3,6 +3,7 @@ from layers import TransformerEncoder, Layernorm, Embedding, TransformerEncoder from bmtrain.global_var import config + class GPT(bmt.DistributedModule): def __init__(self, num_layers : int, vocab_size : int, diff --git a/example/models/pipe_gpt.py b/example/models/pipe_gpt.py new file mode 100644 index 00000000..acf7aa20 --- /dev/null +++ b/example/models/pipe_gpt.py @@ -0,0 +1,64 @@ +import torch +import bmtrain as bmt +from layers import TransformerEncoder, Layernorm, Embedding +from bmtrain.global_var import config + +class GPTPipe(bmt.DistributedModule): + def __init__(self, + num_layers : int, vocab_size : int, + dim_model : int, dim_head : int, num_heads : int, dim_ff : int, + max_distance : int, + bias : bool = True, dtype = None + ) -> None: + super().__init__() + + self.max_distance = max_distance + + if config['tp_size'] > 1: + word_emb = bmt.nn.VPEmbedding(vocab_size, dim_model, dtype=dtype) + else: + word_emb = Embedding(vocab_size, dim_model, dtype=dtype) + pos_emb = Embedding(max_distance, dim_model, dtype=dtype) + blocklist = [] + blocklist += [ + TransformerEncoder( + dim_model, dim_head, num_heads, dim_ff, bias, dtype + ) + for _ in range(num_layers)] + layernorm = Layernorm(dim_model, dtype=dtype) + self.transformers = bmt.PipeDreamBlockList( + blocklist, + ) + self.pos_emb = self.transformers.add_head(pos_emb) + self.layernorm = self.transformers.add_tail(layernorm) + self.word_emb = self.transformers.add_head_tail(word_emb) + + def get_blocklist(self): + return self.transformers + + def forward(self, + input : torch.LongTensor, # (batch, seq_len) + pos : torch.LongTensor, # (batch, seq_len) + mask : torch.BoolTensor, # (batch, 
seq_len) + ) -> torch.Tensor: + mask_2d = mask[:, None, :] & mask[:, :, None] # (batch, seq_len, seq_len) + mask_2d = mask_2d & (pos[:, None, :] >= pos[:, :, None]) + + + # for layer in self.transformers: + out = self.transformers(input, mask_2d, None) + if bmt.config['topology'].is_last_rank(): + out = self.layernorm(out) + out = self.word_emb(out, True) + return out + + def preprocess_func(self, inp): + if config['topology'].pipe_rank == 0: + inp_id = inp[0] + pos = inp[1] + out = self.pos_emb(pos) + self.word_emb(inp_id) + return out + else: + return None + + diff --git a/example/pipe.sh b/example/pipe.sh new file mode 100644 index 00000000..ea07e82a --- /dev/null +++ b/example/pipe.sh @@ -0,0 +1 @@ +torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=localhost pipe_train.py diff --git a/example/pipe_train.py b/example/pipe_train.py new file mode 100644 index 00000000..28c7ad1c --- /dev/null +++ b/example/pipe_train.py @@ -0,0 +1,131 @@ +import torch +import bmtrain as bmt +from models import GPTPipe +import time +from bmtrain import optim +from bmtrain.global_var import config +from bmtrain import inspect +from bmtrain.pipe import pipeline_forward_backward +from bmtrain.pipe import load_model_pipe, save_model_pipe +from typing import Iterable + +def main(): + bmt.init_distributed( + seed=0, + pipe_size=4, + tp_size=1, + ) + + model = GPTPipe( + num_layers=8, + vocab_size=10240, + dim_model=2560, + dim_head=80, + num_heads=32, + dim_ff=8192, + max_distance=1024, + bias=True, + dtype=torch.float16 + ) + bmt.init_parameters(model) + bmt.print_rank("Model memory") + bmt.print_rank(torch.cuda.memory_summary()) + bmt.synchronize() + # test save/load + save_model_pipe(model, "pipe.pt") + load_model_pipe(model, "pipe.pt") + + # data + # generate dummy data for each rank + torch.manual_seed(1234) + micro = 2 + num_micros = 16 + batch_size = micro * num_micros + seq_len = 512 + def data_loader(): + for i in range(1000): + micro = 2 + sent = torch.randint(0, 10240, (micro, seq_len + 1)) + enc_length = torch.randint(128, seq_len, (micro,)).long().cuda() + enc_input = sent[:, :-1].long().cuda() + targets = sent[:, 1:].long().cuda() + mask = torch.arange(seq_len).long().cuda()[None, :] < enc_length[:, None] + targets = torch.where( + mask, + targets, + torch.full_like(targets, -100, dtype=torch.long) + ) + pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) + yield enc_input, pos, mask, targets + + if config['tp_size'] > 1: + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) + else: + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100) + + optimizer = optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2) + lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0) + + optim_manager = optim.OptimManager(loss_scale=2**20) + optim_manager.add_optimizer(optimizer, lr_scheduler) + pipe_rank = bmt.config["topology"].pipe_rank + bmt.synchronize() + avg_time_recorder = bmt.utils.AverageRecorder() + avg_loss_recorder = bmt.utils.AverageRecorder() + + def forward_step(model, input, data): + enc_input, pos, mask, targets = data + input = model.preprocess_func((enc_input, pos)) if bmt.config["topology"].is_first_rank() else input + logits = model(input, pos, mask) + if bmt.config["topology"].is_last_rank(): + logits = logits.view(-1, logits.shape[-1]) + targets = targets.view(-1) + loss = loss_func(logits, targets) + nonlocal global_loss_items + global_loss = 
bmt.distributed.all_reduce(loss, comm=bmt.config["pp_tp_zero_comm"]).item() + global_loss_items.append(global_loss) + return loss, logits + else: + return logits + + def backward_step(output, grad_output): + if bmt.config['topology'].is_last_rank(): + output = optim_manager.scale_loss(output) + output = output / bmt.config['micros'] + torch.autograd.backward(output, grad_tensors=grad_output) + + + + for iteration in range(10): + # load data + global_loss_items = [] + st = time.time() + rank = bmt.config["topology"].pipe_rank + # global_loss, grad_norm = pipeline_forward_backward(model, data_loader(), micro , num_micros, optim_manager) + optim_manager.zero_grad() + pipeline_forward_backward(model, data_loader(), forward_step, backward_step, micro , num_micros) + grad_norm = optim_manager.clip_grad_norm(optim_manager.optimizers[0].param_groups, 1.0, norm_type=2) + optim_manager.step() + bmt.synchronize() + # record time and loss + iteration_time = time.time() - st + + if bmt.config["topology"].is_last_rank(): + global_loss = sum(list(global_loss_items))/len(global_loss_items) + avg_time_recorder.record(iteration_time) + avg_loss_recorder.record(global_loss) + print( + "| Iter: {:6d} | loss: {:.10f} average_loss: {:.4f} | lr: {:.4e} scale: {:10.4f} | time: {:.4f}".format( + iteration, + global_loss, + avg_loss_recorder.value, + lr_scheduler.current_lr, + optim_manager.loss_scale, + avg_time_recorder.value + ) + ) + + + +if __name__ == '__main__': + main() diff --git a/example/run.sh b/example/run.sh index 542e5252..8a66db20 100644 --- a/example/run.sh +++ b/example/run.sh @@ -1,3 +1 @@ -export NCCL_P2P_DISABLE=1 -export CUDA_LAUNCH_BLOCKING=1 -torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=localhost train.py +torchrun --nnodes=1 --nproc_per_node=1 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=localhost train.py diff --git a/example/train.py b/example/train.py index d5906a06..6e9694a6 100644 --- a/example/train.py +++ b/example/train.py @@ -9,7 +9,7 @@ def main(): bmt.init_distributed( seed=0, - tp_size=2, + tp_size=1, ) model = GPT( @@ -23,37 +23,18 @@ def main(): bias=True, dtype=torch.half ) - bmt.init_parameters(model) - bmt.print_rank("Model memory") bmt.print_rank(torch.cuda.memory_summary()) bmt.synchronize() - # data # generate dummy data for each rank torch.manual_seed(1234) - batch_size = 2 seq_len = 512 - world_size = bmt.config["world_size"] if bmt.config["tp_size"] == 1 else bmt.config["tp_zero_size"] - r = bmt.config["rank"] if bmt.config["tp_size"] == 1 else bmt.config["tp_zero_rank"] + batch = 2 + grad_accum = 1 - for i in range(world_size): - sent = torch.randint(0, 10240, (batch_size, seq_len + 1)) - enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda() - enc_input = sent[:, :-1].long().cuda() - targets = sent[:, 1:].long().cuda() - mask = torch.arange(seq_len).long().cuda()[None, :] < enc_length[:, None] - targets = torch.where( - mask, - targets, - torch.full_like(targets, -100, dtype=torch.long) - ) - - if i == r: - break - if config['tp_size'] > 1: loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) else: @@ -69,12 +50,23 @@ def main(): avg_time_recorder = bmt.utils.AverageRecorder() avg_loss_recorder = bmt.utils.AverageRecorder() - - for iteration in range(1000): + for iteration in range(10): # load data st = time.time() + sum_loss = 0 + for micro in range(grad_accum): + for i in range(bmt.world_size()): + sent = torch.randint(0, 10240, (batch_size, seq_len + 1)) + enc_length = 
torch.randint(128, seq_len, (batch_size,)).long().cuda() + enc_input = sent[:, :-1].long().cuda() + targets = sent[:, 1:].long().cuda() + mask = torch.arange(seq_len).long().cuda()[None, :] < enc_length[:, None] + targets = torch.where( + mask, + targets, + torch.full_like(targets, -100, dtype=torch.long) + ) - with inspect.inspect_tensor() as inspector: pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) logits = model( enc_input, @@ -87,40 +79,29 @@ def main(): loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) else: loss = loss_func(logits.float().view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) - - global_loss = bmt.sum_loss(loss).item() - - optim_manager.zero_grad() - + global_loss = loss.item() optim_manager.backward(loss) - + sum_loss += global_loss # print inspected tensors in the forward & backward pass # print parameters of the model if iteration % 100 == 0: - bmt.print_rank( - inspect.format_summary( - inspector.get_summary() - ) - ) bmt.print_rank( inspect.format_summary( inspect.inspect_model(model, "*") ) ) - optim_manager.step() - + optim_manager.zero_grad() # record time and loss iteration_time = time.time() - st avg_time_recorder.record(iteration_time) - avg_loss_recorder.record(global_loss) - + avg_loss_recorder.record(sum_loss / grad_accum) # print time and loss bmt.print_rank( - "| Iter: {:6d} | loss: {:.4f} average_loss: {:.4f} | lr: {:.4e} scale: {:10.4f} | time: {:.4f}".format( + "| Iter: {:6d} | loss: {:.10f} average_loss: {:.4f} | lr: {:.4e} scale: {:10.4f} | time: {:.4f}".format( iteration, - global_loss, + sum_loss / grad_accum, avg_loss_recorder.value, lr_scheduler.current_lr, optim_manager.loss_scale, diff --git a/tests/test_load_ckpt.py b/tests/test_load_ckpt.py index 0eb4f95f..7de6b590 100644 --- a/tests/test_load_ckpt.py +++ b/tests/test_load_ckpt.py @@ -3,6 +3,7 @@ import torch.nn.functional as F import bmtrain as bmt import os +from collections import OrderedDict class Linear_Normal(torch.nn.Module): def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None: @@ -36,25 +37,30 @@ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtyp def forward(self, input): return F.linear(input, self.weight, self.bias) +def test_save_load(m): + bmt.save(m, "test.pt", non_blocking=False) + bmt.load(m, "test.pt") + bmt.save(m, "test.pt", non_blocking=True) + bmt.load(m, "test.pt") + bmt.save(m, "test.pt", non_blocking=False, save_gather=True) + bmt.load(m, "test.pt", load_gather=True) + bmt.clean("test.pt") + def test_main(): - ckpt_path = "test_ckpt.pt" # Transformer BlockList m = Linear_Normal(256, 256).cuda() m2 = bmt.TransformerBlockList([bmt.Block(Linear_BMT(256, 256))]) - if bmt.rank() == 0: - torch.save(m.state_dict(), ckpt_path) - dic2 = m.state_dict() - dic2["0.weight"] = dic2.pop("weight") - dic2["0.bias"] = dic2.pop("bias") - m2.load_state_dict(dic2) + m2_state = m.state_dict().copy() + m2_state["0.weight"] = m2_state.pop("weight") + m2_state["0.bias"] = m2_state.pop("bias") + test_save_load(m2) + m2.load_state_dict(m2_state) for key in m.state_dict(): bmt_key = f"0.{key}" assert bmt_key in m2.state_dict(), "wrong key in bmtrain model" assert (m2.state_dict()[bmt_key].cuda() == m.state_dict()[key]).all() , "wrong param in bmtrain model" - if bmt.rank() == 0: - os.remove(ckpt_path) - print("Transformer Blocklist load_state_dict and state_dict test passed") + print("Transformer Blocklist load_state_dict 
, state_dict, and bmt.load/save test passed") # Block m3 = bmt.Block(Linear_BMT(256, 256)) @@ -62,7 +68,8 @@ def test_main(): for key in m.state_dict(): assert key in m3.state_dict(), "wrong key in bmtrain model" assert (m.state_dict()[key] == m3.state_dict()[key].cuda()).all(), "wrong param in bmtrain model" - print("Block load_state_dict and state_dict test passed") + test_save_load(m3) + print("Block load_state_dict, state_dict, and bmt.load/save test passed") # normal Distributed module m4 = Linear_BMT(256, 256) @@ -70,7 +77,8 @@ def test_main(): for key in m.state_dict(): assert key in m4.state_dict(), "wrong key in bmtrain model" assert (m.state_dict()[key] == m4.state_dict()[key].cuda()).all(), "wrong param in bmtrain model" - print("bmt.distributedmodule load_state_dict and state_dict test passed") + test_save_load(m4) + print("bmt.DistributedModule load_state_dict, state_dict, and bmt.load/save test passed") if __name__ == "__main__": bmt.init_distributed() diff --git a/tests/test_send_recv.py b/tests/test_send_recv.py index f933b0c2..009be56e 100644 --- a/tests/test_send_recv.py +++ b/tests/test_send_recv.py @@ -5,14 +5,14 @@ from bmtrain.global_var import config def test_send_recv(): - if config["topology"].stage_id == 0: + if config["topology"].pipe_rank == 0: a = torch.ones((2,1)) * (config["topology"].pp_zero_id+1) a = a.cuda() print(f"send {a}") - bmt.distributed.send_activations(a, 1, config["pipe_comm"]) + bmt.distributed.send_tensor(a, 1, config["pipe_comm"]) else: ref = torch.ones((2,1)) * (config["topology"].pp_zero_id+1) - a = bmt.distributed.recv_activations(0, config["pipe_comm"]) + a = bmt.distributed.recv_tensor(0, config["pipe_comm"]) print(f"recv {a}") assert_all_eq(a, ref.cuda())
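The renamed point-to-point helpers (`send_activations`/`recv_activations` becoming `send_tensor`/`recv_tensor`) follow the pattern exercised by the test above. A minimal sketch, assuming a topology initialized with `pipe_size >= 2`:

```python
import torch
import bmtrain as bmt
from bmtrain.global_var import config

topo = config["topology"]

if topo.pipe_rank == 0:
    x = torch.arange(4, dtype=torch.float32, device="cuda").reshape(2, 2)
    bmt.distributed.send_tensor(x, 1, config["pipe_comm"])    # send to pipeline rank 1
elif topo.pipe_rank == 1:
    x = bmt.distributed.recv_tensor(0, config["pipe_comm"])   # receive from pipeline rank 0
    bmt.print_rank_pp("received:", x, pipe_rank=1)            # only pipeline rank 1 prints
```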