From 9c4f70d92e9d1980fc9e6bf2b3c45dc72504116a Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Thu, 2 Jun 2022 13:00:15 +0800 Subject: [PATCH 1/7] zero2 --- bmtrain/block_layer.py | 107 ++++++++++++++++++++++++----------------- bmtrain/global_var.py | 4 +- bmtrain/init.py | 10 ++-- example/train.py | 3 +- 4 files changed, 74 insertions(+), 50 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 63806290..7e48013b 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -34,8 +34,11 @@ def forward(ctx, placeholder, block : 'CheckpointBlock', preserve_rng_state, len ctx.nontensor_inputs = others ctx.len_args = len_args ctx.save_for_backward(*tensors) - - with torch.no_grad(), ScopedTensorInspectorContext() as inspector, CheckpointBlockContext(block): + if config['zero_level'] == 2: + flag = 1 + else: + flag = 0 + with torch.no_grad(), ScopedTensorInspectorContext() as inspector, CheckpointBlockContext(block,flag): inp_args = args[:len_args] inp_kwargs = {} for k, v in zip(args[len_args::2], args[len_args + 1::2]): @@ -71,7 +74,11 @@ def backward(ctx, *grad_outputs): with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=ctx.preserve_rng_state): if ctx.preserve_rng_state: torch.cuda.set_rng_state(ctx.cuda_rng_state) - with torch.enable_grad(), ScopedTensorInspectorContext() as inspector, CheckpointBlockContext(ctx.block): + if config['zero_level'] == 2: + flag = 2 + else: + flag = 0 + with torch.enable_grad(), ScopedTensorInspectorContext() as inspector, CheckpointBlockContext(ctx.block,flag): inp_args = all_inputs[:len_args] inp_kwargs = {} for k, v in zip(all_inputs[len_args::2], all_inputs[len_args + 1::2]): @@ -112,13 +119,13 @@ def backward(ctx, *grad_outputs): return (None, None, None, None) + tuple(grads) class CheckpointBlockContext: - def __init__(self, block : 'CheckpointBlock') -> None: + def __init__(self, block : 'CheckpointBlock',flag = 0) -> None: self.block = block self._param_buffer = {} self._grad_buffer = {} self._param_tensor = {} self._grad_tensor = {} - + self.flag = flag self._need_release = False def enter(self): @@ -132,7 +139,6 @@ def enter(self): wait_loader() requires_grad = torch.is_grad_enabled() - with torch.cuda.stream(config["load_stream"]): for kw, val in self.block._storage_info.items(): assert self.block._storage_params[kw].is_cuda @@ -141,47 +147,50 @@ def enter(self): local_param = self.block._storage_params[kw] storage_type = local_param.storage_type() - - self._param_buffer[kw] = storage_type(val["partition_size"] * config["world_size"]) - self._param_tensor[kw] = torch.tensor([], dtype=self._param_buffer[kw].dtype, device=self._param_buffer[kw].device).set_(self._param_buffer[kw]) + if self.flag != 2: + self._param_buffer[kw] = storage_type(val["partition_size"] * config["world_size"]) + self._param_tensor[kw] = torch.tensor([], dtype=self._param_buffer[kw].dtype, device=self._param_buffer[kw].device).set_(self._param_buffer[kw]) if requires_grad and local_param.requires_grad: self._grad_buffer[kw] = storage_type(val["partition_size"] * config["world_size"]) self._grad_tensor[kw] = torch.tensor([], dtype=self._grad_buffer[kw].dtype, device=self._grad_buffer[kw].device).set_(self._grad_buffer[kw]).zero_() - - nccl.groupStart() - for kw, val in self.block._storage_info.items(): - nccl.allGather( - self.block._storage_params[kw].storage(), - self._param_buffer[kw], - config["comm"] - ) - nccl.groupEnd() + if self.flag != 2: + nccl.groupStart() + for kw, val in 
self.block._storage_info.items(): + nccl.allGather( + self.block._storage_params[kw].storage(), + self._param_buffer[kw], + config["comm"] + ) + nccl.groupEnd() current_stream = torch.cuda.current_stream() current_stream.wait_stream(config["load_stream"]) # set wait stream for each storage - for kw in self._param_tensor.keys(): - self._param_tensor[kw].record_stream(current_stream) + for kw in self.block._storage_info.keys(): + if self.flag != 2: + self._param_tensor[kw].record_stream(current_stream) if requires_grad and kw in self._grad_tensor: self._grad_tensor[kw].record_stream(current_stream) # update parameters in block - for param in self.block._param_info: - kw_name = param["kw_name"] - dtype = self._param_buffer[kw_name].dtype - device = self._param_buffer[kw_name].device - offset = param["offset"] - shape = param["shape"] - param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self._param_buffer[kw_name], offset, shape) - if requires_grad and kw_name in self._grad_buffer: - param["parameter"].requires_grad_(True) - param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self._grad_buffer[kw_name], offset, shape) - else: - param["parameter"].requires_grad_(False) - - + for param in self.block._param_info: + kw_name = param["kw_name"] + offset = param["offset"] + shape = param["shape"] + if self.flag != 2: + dtype = self._param_buffer[kw_name].dtype + device = self._param_buffer[kw_name].device + param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self._param_buffer[kw_name], offset, shape) + else: + dtype = param["parameter"].data.dtype + device = param["parameter"].data.device + if requires_grad and kw_name in self._grad_buffer: + param["parameter"].requires_grad_(True) + param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self._grad_buffer[kw_name], offset, shape) + else: + param["parameter"].requires_grad_(False) def __enter__(self): self.enter() @@ -232,17 +241,17 @@ def exit(self): self._grad_tensor[kw].record_stream(config["load_stream"]) # Release all parameters in buffer - for param in self.block._param_info: - dtype = param["parameter"].dtype - device = param["parameter"].device - param["parameter"].data = torch.tensor([], dtype=dtype, device=device) - param["parameter"].grad = None + if self.flag !=1: + for param in self.block._param_info: + dtype = param["parameter"].dtype + device = param["parameter"].device + param["parameter"].data = torch.tensor([], dtype=dtype, device=device) + param["parameter"].grad = None self._grad_tensor = {} self._param_tensor = {} self._grad_buffer = {} self._param_buffer = {} - def __exit__(self, exc_type, exc_val, exc_tb): # reduce scatter gradients self.exit() @@ -720,7 +729,11 @@ def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, hidden_s if save_list[i][0] == i: layer_inputs.append(hidden_state) cuda_rng_state.append( torch.cuda.get_rng_state() ) - block_ctx = CheckpointBlockContext(self._modules[str(i)]) + if config['zero_level']==2: + flag = 1 + else: + flag = 0 + block_ctx = CheckpointBlockContext(self._modules[str(i)],flag = flag) # gather parameter on load stream block_ctx.enter() # call inner module directly @@ -785,7 +798,11 @@ def exit_prev(prev_ctx, prev_grad): st = ctx.save_list[i][0] for j in range(st, i): torch.cuda.set_rng_state(ctx.cuda_rng_state[j]) - block_ctx = CheckpointBlockContext(ctx.self._modules[str(j)]) + if config['zero_level'] == 2: + flag = 2 + else: + flag = 0 + block_ctx = 
CheckpointBlockContext(ctx.self._modules[str(j)],flag) block_ctx.enter() exit_prev(prev_ctx, prev_grad) output = ctx.self._modules[str(j)]._module._call_impl(layer_inputs[ctx.save_list[j][1]], *all_inputs) @@ -795,7 +812,11 @@ def exit_prev(prev_ctx, prev_grad): ctx.save_list[j+1][0] = j+1 torch.cuda.set_rng_state(ctx.cuda_rng_state[i]) ipt = layer_inputs[ctx.save_list[i][1]].detach().requires_grad_() - block_ctx = CheckpointBlockContext(ctx.self._modules[str(i)]) + if config['zero_level'] == 2: + flag = 2 + else: + flag = 0 + block_ctx = CheckpointBlockContext(ctx.self._modules[str(i)],flag) block_ctx.enter() exit_prev(prev_ctx, prev_grad) prev_ctx = block_ctx diff --git a/bmtrain/global_var.py b/bmtrain/global_var.py index 786546b0..d69bfa53 100644 --- a/bmtrain/global_var.py +++ b/bmtrain/global_var.py @@ -5,12 +5,12 @@ class ConfigMap(TypedDict): local_rank : int world_size : int local_size : int - + zero_level : int calc_stream : torch.cuda.Stream load_stream : torch.cuda.Stream load_event : torch.cuda.Event barrier_stream : torch.cuda.Stream - + # rank_graph : ParallelGraph loss_scale_factor : float loss_scale_steps : int diff --git a/bmtrain/init.py b/bmtrain/init.py index 6baca12d..0fd1ea2f 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -8,16 +8,18 @@ from . import nccl import time from .synchronize import synchronize - def init_distributed( init_method : str = "env://", seed : int = 0, loss_scale_factor : float = 2, - loss_scale_steps : int = 1024 + loss_scale_steps : int = 1024, + data_parallel_size: int = 1, + pipe_parallel_size: int =1, + zero_level: int = 3, ): """Initialize distributed training. This function will initialize the distributed training, set the random seed and global configurations. - It must be called before any other distributed functions. + It must be called before any other distributed functions.siz Args: seed (int): The random seed. @@ -64,7 +66,7 @@ def init_distributed( config["load_stream"] = torch.cuda.Stream(priority=-1) config['barrier_stream'] = torch.cuda.Stream() config["load_event"] = torch.cuda.Event() - + config["zero_level"] = zero_level config["loss_scale_factor"] = loss_scale_factor if loss_scale_factor > 1 else 1 / loss_scale_factor config["loss_scale_steps"] = loss_scale_steps diff --git a/example/train.py b/example/train.py index 6bdf3c4c..86a9f5bf 100644 --- a/example/train.py +++ b/example/train.py @@ -6,6 +6,7 @@ def main(): bmt.init_distributed( seed=0, + zero_level=2 ) model = GPT( @@ -121,4 +122,4 @@ def main(): bmt.save(model, "checkpoint.pt") if __name__ == '__main__': - main() \ No newline at end of file + main() From c7965758224c97ac0456ecc447153ff17e8acc15 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 10 Jun 2022 14:32:53 +0800 Subject: [PATCH 2/7] zero2 fixed and _parameters in checkpointblock can directly point to the parameter in _module --- bmtrain/block_layer.py | 238 +++++++++++++++-------------------------- bmtrain/param_init.py | 2 + 2 files changed, 87 insertions(+), 153 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 7e48013b..7f7a2279 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -1,5 +1,7 @@ from typing import Dict, Iterable, Iterator, Tuple, Union +from torch.nn import parameter +from collections import OrderedDict from .global_var import config import torch from . 
import nccl @@ -20,7 +22,6 @@ def forward(ctx, placeholder, block : 'CheckpointBlock', preserve_rng_state, len ctx.preserve_rng_state = preserve_rng_state ctx.cuda_rng_state = torch.cuda.get_rng_state() if preserve_rng_state else None - tensors = [] others = [] for arg in args: @@ -34,11 +35,12 @@ def forward(ctx, placeholder, block : 'CheckpointBlock', preserve_rng_state, len ctx.nontensor_inputs = others ctx.len_args = len_args ctx.save_for_backward(*tensors) + ctx.param_dict={} if config['zero_level'] == 2: flag = 1 else: flag = 0 - with torch.no_grad(), ScopedTensorInspectorContext() as inspector, CheckpointBlockContext(block,flag): + with torch.no_grad(), ScopedTensorInspectorContext() as inspector, CheckpointBlockContext(block,ctx.param_dict,flag): inp_args = args[:len_args] inp_kwargs = {} for k, v in zip(args[len_args::2], args[len_args + 1::2]): @@ -78,7 +80,7 @@ def backward(ctx, *grad_outputs): flag = 2 else: flag = 0 - with torch.enable_grad(), ScopedTensorInspectorContext() as inspector, CheckpointBlockContext(ctx.block,flag): + with torch.enable_grad(), ScopedTensorInspectorContext() as inspector, CheckpointBlockContext(ctx.block,ctx.param_dict,flag): inp_args = all_inputs[:len_args] inp_kwargs = {} for k, v in zip(all_inputs[len_args::2], all_inputs[len_args + 1::2]): @@ -119,8 +121,9 @@ def backward(ctx, *grad_outputs): return (None, None, None, None) + tuple(grads) class CheckpointBlockContext: - def __init__(self, block : 'CheckpointBlock',flag = 0) -> None: + def __init__(self ,block : 'CheckpointBlock',ctx_dict : dict = None, flag : int = 0) -> None: self.block = block + self.ctx_dict=ctx_dict self._param_buffer = {} self._grad_buffer = {} self._param_tensor = {} @@ -179,18 +182,22 @@ def enter(self): kw_name = param["kw_name"] offset = param["offset"] shape = param["shape"] + if self.flag != 2: dtype = self._param_buffer[kw_name].dtype device = self._param_buffer[kw_name].device - param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self._param_buffer[kw_name], offset, shape) + param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self._param_buffer[kw_name], offset, shape) else: dtype = param["parameter"].data.dtype device = param["parameter"].data.device + param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.ctx_dict[kw_name], offset, shape) + if requires_grad and kw_name in self._grad_buffer: param["parameter"].requires_grad_(True) param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self._grad_buffer[kw_name], offset, shape) else: param["parameter"].requires_grad_(False) + def __enter__(self): self.enter() @@ -241,13 +248,21 @@ def exit(self): self._grad_tensor[kw].record_stream(config["load_stream"]) # Release all parameters in buffer - if self.flag !=1: - for param in self.block._param_info: - dtype = param["parameter"].dtype - device = param["parameter"].device - param["parameter"].data = torch.tensor([], dtype=dtype, device=device) - param["parameter"].grad = None - + for param in self.block._param_info: + kw_name = param["kw_name"] + param["parameter"].grad = None + if "begin" not in param: + continue + begin = param["begin"] + end = param["end"] + dtype = self.block._storage_params[kw_name].dtype + device = self.block._storage_params[kw_name].device + param["parameter"].data=torch.tensor([],dtype=dtype,device=device).set_(self.block._storage_params[kw_name].storage(),begin,end) + if param["parameter"].requires_grad: + 
param["parameter"].grad=torch.tensor([],dtype=dtype,device=device).set_(self.block._storage_params[kw_name].grad.storage(),begin,end) + if self.flag==1: + for i in self._param_buffer: + self.ctx_dict[i] = self._param_buffer[i] self._grad_tensor = {} self._param_tensor = {} self._grad_buffer = {} @@ -302,7 +317,6 @@ class CheckpointBlock(torch.nn.Module): """ def __init__(self, inner_module : torch.nn.Module): super().__init__() - self._module = inner_module # build large parameter&grad here self._param_info = [] @@ -362,8 +376,8 @@ def __init__(self, inner_module : torch.nn.Module): storage_param.requires_grad_(False) # register parameter - self.register_parameter(kw, storage_param) - + # self.register_parameter(kw, storage_param) + self._storage_params[kw] = storage_param # initialize parameters in module @@ -385,44 +399,44 @@ def __init__(self, inner_module : torch.nn.Module): "kw_name": kw_name, }) - if isinstance(param, DistributedParameter) and param._init_method is not None: - # do not copy distributed parameters - pass - else: - # copy values to buffer for normal parameter - storage_st = self._storage_info[kw_name]["begin"] - storage_end = self._storage_info[kw_name]["end"] - - # make parameter contiguous in storage - with torch.no_grad(): - contiguous_param = OpAllGather.apply(param) - - if not (param_st >= storage_end or param_end <= storage_st): - # copy offset in parameter storage - offset_st = max(storage_st - param_st, 0) - offset_end = min(storage_end - param_st, contiguous_param.numel()) - assert offset_st < offset_end - - # copy to offset in buffer storage - to_offset_st = offset_st + param_st - storage_st - to_offset_end = offset_end + param_st - storage_st - - # copy to buffer - # PyTorch 1.11 changed the API of storage.__getitem__ - d_dtype = self._storage_params[kw_name].dtype - d_device = self._storage_params[kw_name].device - torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,))[:] = \ - torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:] - # self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(contiguous_param.storage()[offset_st: offset_end]) - del contiguous_param + # copy values to buffer for normal parameter + storage_st = self._storage_info[kw_name]["begin"] + storage_end = self._storage_info[kw_name]["end"] + # make parameter contiguous in storage + with torch.no_grad(): + contiguous_param = OpAllGather.apply(param) + + if not (param_st >= storage_end or param_end <= storage_st): + # copy offset in parameter storage + offset_st = max(storage_st - param_st, 0) + offset_end = min(storage_end - param_st, contiguous_param.numel()) + assert offset_st < offset_end + + # copy to offset in buffer storage + to_offset_st = offset_st + param_st - storage_st + to_offset_end = offset_end + param_st - storage_st + + # copy to buffer + # PyTorch 1.11 changed the API of storage.__getitem__ + d_dtype = self._storage_params[kw_name].dtype + d_device = self._storage_params[kw_name].device + param.data = torch.tensor([], dtype=param.dtype, device=param.device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,)) + self._param_info[-1]["begin"] = to_offset_st + self._param_info[-1]["end"] = (to_offset_end - to_offset_st,) + param.data[:] = \ + torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:] 
+ # self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(contiguous_param.storage()[offset_st: offset_end]) + del contiguous_param + else: + param.data = torch.tensor([], dtype=param.dtype, device=param.device) + # clear parameter data, but keep the dtype and device - param.data = torch.tensor([], dtype=param.dtype, device=param.device) setattr(param, "_in_checkpoint_block", True) for kw in offsets.keys(): assert offsets[kw] == self._storage_info[kw]["total"] - + def __call__(self, *args, **kwargs): # gather here placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) @@ -431,13 +445,18 @@ def __call__(self, *args, **kwargs): all_inputs.append(kw) all_inputs.append(val) return OpCheckpointBlock.apply(placeholder, self, True, len(args), *all_inputs) - + def __getattr__(self,name:str): + if name=="_module": + return self._module + return getattr(self._module, name) def __setattr__(self, name, value): object.__setattr__(self, name, value) - + def __getattribute__(self, name: str): + if name=="_parameters": + return self._module._parameters + return super().__getattribute__(name) def __delattr__(self, name): object.__delattr__(self, name) - def _save_to_state_dict(self, destination, prefix, keep_vars): raise RuntimeError("._save_to_state_dict() of CheckpointBlock should not be called") @@ -449,7 +468,6 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - all_keys = [] for it in self._param_info: key = prefix + it["name"] @@ -501,7 +519,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, for key in state_dict.keys(): if key.startswith(prefix) and key not in all_keys: unexpected_keys.append(key) - + def grouped_parameters(self): ret = {} for kw, val in self._storage_info.items(): @@ -522,7 +540,6 @@ def init_parameters(self): tmp_tensor = torch.empty(it["shape"], device=param.device, dtype=param.dtype) param._init_method(tmp_tensor) - param_st = it["offset"] param_end = it["offset"] + it["size"] kw_name = it["kw_name"] @@ -549,25 +566,16 @@ def init_parameters(self): # PyTorch 1.11 changed the API of storage.__getitem__ d_dtype = self._storage_params[kw_name].dtype d_device = self._storage_params[kw_name].device - torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,))[:] = \ + # param.data=torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,)) + param.data[:] = \ torch.tensor([], dtype=d_dtype, device=d_device).set_(tmp_tensor.storage(), offset_st, (offset_end - offset_st,))[:] # self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(tmp_tensor.storage()[offset_st: offset_end]) del tmp_tensor def _named_members(self, get_members_fn, prefix='', recurse=True): r"""Helper method for yielding various names + members of modules.""" - print("here in _named_members") - memo = set() - modules = torch.nn.Module.named_modules(self, prefix=prefix) if recurse else [(prefix, self)] - for module_prefix, module in modules: - members = get_members_fn(module) - for k, v in members: - if v is None or v in memo: - continue - memo.add(v) - name = module_prefix + ('.' 
if module_prefix else '') + k - yield name, v - + return self._module._named_members(get_members_fn, prefix, recurse) + def named_modules(self, memo = None, prefix: str = '', remove_duplicate: bool = True): r"""Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself. @@ -613,95 +621,17 @@ def named_modules(self, memo = None, prefix: str = '', remove_duplicate: bool = submodule_prefix = prefix + ('.' if prefix else '') + name for m in module.named_modules(memo, submodule_prefix, remove_duplicate): yield m - def named_children(self): - r"""Returns an iterator over immediate children modules, yielding both - the name of the module as well as the module itself. - - Yields: - (string, Module): Tuple containing a name and child module - - Example:: - - >>> for name, module in model.named_children(): - >>> if name in ['conv4', 'conv5']: - >>> print(module) - - """ - memo = set() - for name, module in self._module._modules.items(): - if module is not None and module not in memo: - memo.add(module) - yield name, module + return self._module.named_children() def train(self, mode: bool = True): - r"""Sets the module in training mode. - - This has any effect only on certain modules. See documentations of - particular modules for details of their behaviors in training/evaluation - mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, - etc. - - Args: - mode (bool): whether to set training mode (``True``) or evaluation - mode (``False``). Default: ``True``. - - Returns: - Module: self - """ - if not isinstance(mode, bool): - raise ValueError("training mode is expected to be boolean") - self.training = mode self._module.train(mode) - return self def eval(self): - r"""Sets the module in evaluation mode. - - This has any effect only on certain modules. See documentations of - particular modules for details of their behaviors in training/evaluation - mode, if they are affected, e.g. Dropout, BatchNorm, - etc. - - This is equivalent with `self.train(False)`. 
- - Returns: - Module: self - """ - return self.train(False) - - def __repr__(self): - # We treat the extra repr like the sub-module, one item per line - extra_lines = [] - extra_repr = self.extra_repr() - # empty string will be split into list [''] - if extra_repr: - extra_lines = extra_repr.split('\n') - child_lines = [] - for key, module in self._module._modules.items(): - mod_str = repr(module) - mod_str = _addindent(mod_str, 2) - child_lines.append('(' + key + '): ' + mod_str) - lines = extra_lines + child_lines - - main_str = self._get_name() + '(' - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += '\n ' + '\n '.join(lines) + '\n' - - main_str += ')' - return main_str + self._module.eval() - def __getattr__(self, attribute): - try: - return super().__getattr__(attribute) - except: - return getattr(self._module, attribute) - - + def __repr__(self): + return self._module.__repr__() class OpTransformerBlockList(torch.autograd.Function): @staticmethod @@ -720,7 +650,7 @@ def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, hidden_s ctx.self = self ctx.save_list = copy.deepcopy(save_list) ctx.num_save_needed = save_list[-1][1]+1 - + ctx.layers_dict=[{} for _ in range(len(self))] layer_inputs = [] layer_inspector = [] cuda_rng_state = [] @@ -733,7 +663,7 @@ def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, hidden_s flag = 1 else: flag = 0 - block_ctx = CheckpointBlockContext(self._modules[str(i)],flag = flag) + block_ctx = CheckpointBlockContext(self._modules[str(i)],ctx.layers_dict[i],flag = flag) # gather parameter on load stream block_ctx.enter() # call inner module directly @@ -770,7 +700,6 @@ def exit_prev(prev_ctx, prev_grad): "Checkpointing is not compatible with .grad() or when an `inputs` parameter" " is passed to .backward(). 
Please use .backward() and do not pass its `inputs`" " argument.") - all_inputs = [] input_requires_grad = [] @@ -802,7 +731,7 @@ def exit_prev(prev_ctx, prev_grad): flag = 2 else: flag = 0 - block_ctx = CheckpointBlockContext(ctx.self._modules[str(j)],flag) + block_ctx = CheckpointBlockContext(ctx.self._modules[str(j)],ctx.layers_dict[j],flag) block_ctx.enter() exit_prev(prev_ctx, prev_grad) output = ctx.self._modules[str(j)]._module._call_impl(layer_inputs[ctx.save_list[j][1]], *all_inputs) @@ -810,13 +739,14 @@ def exit_prev(prev_ctx, prev_grad): prev_grad = False layer_inputs[ctx.save_list[j+1][1]].copy_(output) ctx.save_list[j+1][0] = j+1 + torch.cuda.set_rng_state(ctx.cuda_rng_state[i]) ipt = layer_inputs[ctx.save_list[i][1]].detach().requires_grad_() if config['zero_level'] == 2: flag = 2 else: flag = 0 - block_ctx = CheckpointBlockContext(ctx.self._modules[str(i)],flag) + block_ctx = CheckpointBlockContext(ctx.self._modules[str(i)],ctx.layers_dict[i],flag) block_ctx.enter() exit_prev(prev_ctx, prev_grad) prev_ctx = block_ctx @@ -911,3 +841,5 @@ def __getitem__(self, index: Union[int, str]) -> CheckpointBlock: def forward(self, hidden_state, *args): placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) return OpTransformerBlockList.apply(placeholder, self, self.save_list, hidden_state, *args) + # def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True): + # return super().named_modules(memo, prefix, remove_duplicate) \ No newline at end of file diff --git a/bmtrain/param_init.py b/bmtrain/param_init.py index 71305f24..49f00346 100644 --- a/bmtrain/param_init.py +++ b/bmtrain/param_init.py @@ -29,6 +29,8 @@ def init_distributed_parameter(params : Iterable[torch.nn.Parameter]): def iterate_parameters(model : torch.nn.Module): for kw, val in model._parameters.items(): + if hasattr(val,"_in_checkpoint_block") and val._in_checkpoint_block: + return [] yield val def init_parameters(model : torch.nn.Module): From 6b9974b89a322ac977fb5d363eaf50c82b88b4f5 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 10 Jun 2022 15:54:39 +0800 Subject: [PATCH 3/7] delete wrong --- bmtrain/init.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/bmtrain/init.py b/bmtrain/init.py index 0fd1ea2f..805c1f4d 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -13,8 +13,6 @@ def init_distributed( seed : int = 0, loss_scale_factor : float = 2, loss_scale_steps : int = 1024, - data_parallel_size: int = 1, - pipe_parallel_size: int =1, zero_level: int = 3, ): """Initialize distributed training. 
@@ -48,7 +46,7 @@ def init_distributed( world_size = int(os.environ["WORLD_SIZE"]) local_size = int(os.environ["LOCAL_WORLD_SIZE"]) master = os.environ["MASTER_ADDR"] + ":" + os.environ["MASTER_PORT"] - timeout = datetime.timedelta(seconds=1800) + timeout = datetime.timedelta(seconds=1800) rendezvous_iterator = dist.rendezvous( init_method, rank, world_size, timeout=timeout ) From 72918610402c967a8845e6bfd970f52cff35494c Mon Sep 17 00:00:00 2001 From: Bojack <57244158+MayDomine@users.noreply.github.com> Date: Fri, 10 Jun 2022 16:07:33 +0800 Subject: [PATCH 4/7] Update block_layer.py --- bmtrain/block_layer.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 7f7a2279..c55ddab6 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -1,7 +1,6 @@ from typing import Dict, Iterable, Iterator, Tuple, Union -from torch.nn import parameter -from collections import OrderedDict + from .global_var import config import torch from . import nccl @@ -247,7 +246,7 @@ def exit(self): # grads can not be freed until reduce ops finish self._grad_tensor[kw].record_stream(config["load_stream"]) - # Release all parameters in buffer + # Release all parameters from buffer to block_storge for param in self.block._param_info: kw_name = param["kw_name"] param["parameter"].grad = None @@ -375,8 +374,6 @@ def __init__(self, inner_module : torch.nn.Module): else: storage_param.requires_grad_(False) - # register parameter - # self.register_parameter(kw, storage_param) self._storage_params[kw] = storage_param @@ -842,4 +839,4 @@ def forward(self, hidden_state, *args): placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) return OpTransformerBlockList.apply(placeholder, self, self.save_list, hidden_state, *args) # def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True): - # return super().named_modules(memo, prefix, remove_duplicate) \ No newline at end of file + # return super().named_modules(memo, prefix, remove_duplicate) From 80037db469eb197af91c568f05313abea2b58888 Mon Sep 17 00:00:00 2001 From: Bojack <57244158+MayDomine@users.noreply.github.com> Date: Fri, 10 Jun 2022 16:08:18 +0800 Subject: [PATCH 5/7] Update global_var.py --- bmtrain/global_var.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/bmtrain/global_var.py b/bmtrain/global_var.py index d69bfa53..909711f5 100644 --- a/bmtrain/global_var.py +++ b/bmtrain/global_var.py @@ -10,7 +10,6 @@ class ConfigMap(TypedDict): load_stream : torch.cuda.Stream load_event : torch.cuda.Event barrier_stream : torch.cuda.Stream - # rank_graph : ParallelGraph loss_scale_factor : float loss_scale_steps : int @@ -30,4 +29,4 @@ def world_size(): """ Returns the total number of workers across all nodes. 
""" - return config['world_size'] \ No newline at end of file + return config['world_size'] From 567317cd0daa6c267399851060dfc4ef907169b6 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 10 Jun 2022 16:37:55 +0800 Subject: [PATCH 6/7] 1 --- bmtrain/block_layer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 7f7a2279..1eabe39d 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -40,7 +40,7 @@ def forward(ctx, placeholder, block : 'CheckpointBlock', preserve_rng_state, len flag = 1 else: flag = 0 - with torch.no_grad(), ScopedTensorInspectorContext() as inspector, CheckpointBlockContext(block,ctx.param_dict,flag): + with torch.no_grad(), ScopedTensorInspectorContext() as inspector, CheckpointBlockContext(block, ctx.param_dict, flag): inp_args = args[:len_args] inp_kwargs = {} for k, v in zip(args[len_args::2], args[len_args + 1::2]): @@ -80,7 +80,7 @@ def backward(ctx, *grad_outputs): flag = 2 else: flag = 0 - with torch.enable_grad(), ScopedTensorInspectorContext() as inspector, CheckpointBlockContext(ctx.block,ctx.param_dict,flag): + with torch.enable_grad(), ScopedTensorInspectorContext() as inspector, CheckpointBlockContext(ctx.block, ctx.param_dict, flag): inp_args = all_inputs[:len_args] inp_kwargs = {} for k, v in zip(all_inputs[len_args::2], all_inputs[len_args + 1::2]): @@ -257,10 +257,10 @@ def exit(self): end = param["end"] dtype = self.block._storage_params[kw_name].dtype device = self.block._storage_params[kw_name].device - param["parameter"].data=torch.tensor([],dtype=dtype,device=device).set_(self.block._storage_params[kw_name].storage(),begin,end) + param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].storage(), begin, end) if param["parameter"].requires_grad: - param["parameter"].grad=torch.tensor([],dtype=dtype,device=device).set_(self.block._storage_params[kw_name].grad.storage(),begin,end) - if self.flag==1: + param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end) + if self.flag == 1: for i in self._param_buffer: self.ctx_dict[i] = self._param_buffer[i] self._grad_tensor = {} @@ -663,7 +663,7 @@ def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, hidden_s flag = 1 else: flag = 0 - block_ctx = CheckpointBlockContext(self._modules[str(i)],ctx.layers_dict[i],flag = flag) + block_ctx = CheckpointBlockContext(self._modules[str(i)], ctx.layers_dict[i], flag) # gather parameter on load stream block_ctx.enter() # call inner module directly @@ -731,7 +731,7 @@ def exit_prev(prev_ctx, prev_grad): flag = 2 else: flag = 0 - block_ctx = CheckpointBlockContext(ctx.self._modules[str(j)],ctx.layers_dict[j],flag) + block_ctx = CheckpointBlockContext(ctx.self._modules[str(j)], ctx.layers_dict[j], flag) block_ctx.enter() exit_prev(prev_ctx, prev_grad) output = ctx.self._modules[str(j)]._module._call_impl(layer_inputs[ctx.save_list[j][1]], *all_inputs) @@ -746,7 +746,7 @@ def exit_prev(prev_ctx, prev_grad): flag = 2 else: flag = 0 - block_ctx = CheckpointBlockContext(ctx.self._modules[str(i)],ctx.layers_dict[i],flag) + block_ctx = CheckpointBlockContext(ctx.self._modules[str(i)], ctx.layers_dict[i], flag) block_ctx.enter() exit_prev(prev_ctx, prev_grad) prev_ctx = block_ctx From 6861ce1f054763f347375e3ddd3d655972a30543 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 10 Jun 2022 16:56:57 +0800 Subject: [PATCH 7/7] 1 --- 
bmtrain/block_layer.py | 7 +------ bmtrain/init.py | 4 ++-- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 9c0e8dde..f7bbd935 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -14,6 +14,7 @@ def round_up(x, d): return (x + d - 1) // d * d +# the flag is used to control the zero level , 0 means normal zero3 , 1 means forward without release parameter ,2 means backward without gather parameter class OpCheckpointBlock(torch.autograd.Function): @staticmethod def forward(ctx, placeholder, block : 'CheckpointBlock', preserve_rng_state, len_args, *args): @@ -423,7 +424,6 @@ def __init__(self, inner_module : torch.nn.Module): self._param_info[-1]["end"] = (to_offset_end - to_offset_st,) param.data[:] = \ torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:] - # self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(contiguous_param.storage()[offset_st: offset_end]) del contiguous_param else: param.data = torch.tensor([], dtype=param.dtype, device=param.device) @@ -506,7 +506,6 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, d_device = self._storage_params[kw_name].device torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,))[:] = \ torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:] - # self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(contiguous_param.storage()[offset_st: offset_end]) del contiguous_param elif strict: missing_keys.append(key) @@ -563,10 +562,8 @@ def init_parameters(self): # PyTorch 1.11 changed the API of storage.__getitem__ d_dtype = self._storage_params[kw_name].dtype d_device = self._storage_params[kw_name].device - # param.data=torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,)) param.data[:] = \ torch.tensor([], dtype=d_dtype, device=d_device).set_(tmp_tensor.storage(), offset_st, (offset_end - offset_st,))[:] - # self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(tmp_tensor.storage()[offset_st: offset_end]) del tmp_tensor def _named_members(self, get_members_fn, prefix='', recurse=True): @@ -838,5 +835,3 @@ def __getitem__(self, index: Union[int, str]) -> CheckpointBlock: def forward(self, hidden_state, *args): placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) return OpTransformerBlockList.apply(placeholder, self, self.save_list, hidden_state, *args) - # def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True): - # return super().named_modules(memo, prefix, remove_duplicate) diff --git a/bmtrain/init.py b/bmtrain/init.py index 805c1f4d..67db7404 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -17,7 +17,7 @@ def init_distributed( ): """Initialize distributed training. This function will initialize the distributed training, set the random seed and global configurations. - It must be called before any other distributed functions.siz + It must be called before any other distributed functions. Args: seed (int): The random seed. 
@@ -46,7 +46,7 @@ def init_distributed( world_size = int(os.environ["WORLD_SIZE"]) local_size = int(os.environ["LOCAL_WORLD_SIZE"]) master = os.environ["MASTER_ADDR"] + ":" + os.environ["MASTER_PORT"] - timeout = datetime.timedelta(seconds=1800) + timeout = datetime.timedelta(seconds=1800) rendezvous_iterator = dist.rendezvous( init_method, rank, world_size, timeout=timeout )
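
Taken together, the series routes a single switch through the code: init_distributed() now accepts a zero_level argument (default 3) and stores it in the global config, and every CheckpointBlockContext is constructed with a flag derived from it — 0 for the unchanged gather/release behaviour, 1 for a forward pass that keeps the gathered parameters, 2 for a backward pass that reuses them instead of re-gathering (see the comment added in PATCH 7/7). On the user side, ZeRO-2 is enabled exactly as in the example/train.py change above: bmt.init_distributed(seed=0, zero_level=2). The short self-contained sketch below condenses the repeated if/else blocks into one hypothetical helper, context_flag(), which is not part of the patches and is only meant to make the flag mapping explicit:

    # Hypothetical helper (not in the patches): maps the configured zero_level
    # and the current pass onto the CheckpointBlockContext flag.
    #   0 -> plain ZeRO-3: all-gather on entry, release / reduce-scatter on exit
    #   1 -> forward under ZeRO-2: keep the gathered parameter buffers
    #   2 -> backward under ZeRO-2: reuse the kept buffers, skip the all-gather
    def context_flag(zero_level: int, backward: bool) -> int:
        if zero_level == 2:
            return 2 if backward else 1
        return 0

    assert context_flag(3, backward=False) == 0  # default: unchanged behaviour
    assert context_flag(2, backward=False) == 1  # forward keeps parameters
    assert context_flag(2, backward=True) == 2   # backward reuses them

The trade-off encoded by flags 1 and 2 is memory for communication: the forward pass holds on to one extra full copy of each block's gathered parameters (stashed in ctx.param_dict / ctx.layers_dict) so that the backward pass can skip its all-gather.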