diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py
index 63806290..f7bbd935 100644
--- a/bmtrain/block_layer.py
+++ b/bmtrain/block_layer.py
@@ -1,5 +1,6 @@
 from typing import Dict, Iterable, Iterator, Tuple, Union
+
 from .global_var import config
 import torch
 from . import nccl
@@ -13,6 +14,7 @@ def round_up(x, d):
     return (x + d - 1) // d * d
 
 
+# The flag controls the ZeRO level: 0 means normal ZeRO-3, 1 means forward without releasing parameters, 2 means backward without gathering parameters.
 class OpCheckpointBlock(torch.autograd.Function):
     @staticmethod
     def forward(ctx, placeholder, block : 'CheckpointBlock', preserve_rng_state, len_args, *args):
@@ -20,7 +22,6 @@ def forward(ctx, placeholder, block : 'CheckpointBlock', preserve_rng_state, len
         ctx.preserve_rng_state = preserve_rng_state
         ctx.cuda_rng_state = torch.cuda.get_rng_state() if preserve_rng_state else None
-
         tensors = []
         others = []
         for arg in args:
@@ -34,8 +35,12 @@ def forward(ctx, placeholder, block : 'CheckpointBlock', preserve_rng_state, len
         ctx.nontensor_inputs = others
         ctx.len_args = len_args
         ctx.save_for_backward(*tensors)
-
-        with torch.no_grad(), ScopedTensorInspectorContext() as inspector, CheckpointBlockContext(block):
+        ctx.param_dict = {}
+        if config['zero_level'] == 2:
+            flag = 1
+        else:
+            flag = 0
+        with torch.no_grad(), ScopedTensorInspectorContext() as inspector, CheckpointBlockContext(block, ctx.param_dict, flag):
             inp_args = args[:len_args]
             inp_kwargs = {}
             for k, v in zip(args[len_args::2], args[len_args + 1::2]):
@@ -71,7 +76,11 @@ def backward(ctx, *grad_outputs):
         with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=ctx.preserve_rng_state):
             if ctx.preserve_rng_state:
                 torch.cuda.set_rng_state(ctx.cuda_rng_state)
-            with torch.enable_grad(), ScopedTensorInspectorContext() as inspector, CheckpointBlockContext(ctx.block):
+            if config['zero_level'] == 2:
+                flag = 2
+            else:
+                flag = 0
+            with torch.enable_grad(), ScopedTensorInspectorContext() as inspector, CheckpointBlockContext(ctx.block, ctx.param_dict, flag):
                 inp_args = all_inputs[:len_args]
                 inp_kwargs = {}
                 for k, v in zip(all_inputs[len_args::2], all_inputs[len_args + 1::2]):
@@ -112,13 +121,14 @@ def backward(ctx, *grad_outputs):
         return (None, None, None, None) + tuple(grads)
 
 class CheckpointBlockContext:
-    def __init__(self, block : 'CheckpointBlock') -> None:
+    def __init__(self, block : 'CheckpointBlock', ctx_dict : dict = None, flag : int = 0) -> None:
         self.block = block
+        self.ctx_dict = ctx_dict
         self._param_buffer = {}
         self._grad_buffer = {}
         self._param_tensor = {}
         self._grad_tensor = {}
-
+        self.flag = flag
         self._need_release = False
 
     def enter(self):
@@ -132,7 +142,6 @@ def enter(self):
         wait_loader()
         requires_grad = torch.is_grad_enabled()
-
         with torch.cuda.stream(config["load_stream"]):
             for kw, val in self.block._storage_info.items():
                 assert self.block._storage_params[kw].is_cuda
@@ -141,47 +150,54 @@
                 local_param = self.block._storage_params[kw]
 
                 storage_type = local_param.storage_type()
-
-                self._param_buffer[kw] = storage_type(val["partition_size"] * config["world_size"])
-                self._param_tensor[kw] = torch.tensor([], dtype=self._param_buffer[kw].dtype, device=self._param_buffer[kw].device).set_(self._param_buffer[kw])
+                if self.flag != 2:
+                    self._param_buffer[kw] = storage_type(val["partition_size"] * config["world_size"])
+                    self._param_tensor[kw] = torch.tensor([], dtype=self._param_buffer[kw].dtype, device=self._param_buffer[kw].device).set_(self._param_buffer[kw])
 
                 if requires_grad and local_param.requires_grad:
                     self._grad_buffer[kw] = storage_type(val["partition_size"] * config["world_size"])
                     self._grad_tensor[kw] = torch.tensor([], dtype=self._grad_buffer[kw].dtype, device=self._grad_buffer[kw].device).set_(self._grad_buffer[kw]).zero_()
-
-            nccl.groupStart()
-            for kw, val in self.block._storage_info.items():
-                nccl.allGather(
-                    self.block._storage_params[kw].storage(),
-                    self._param_buffer[kw],
-                    config["comm"]
-                )
-            nccl.groupEnd()
+            if self.flag != 2:
+                nccl.groupStart()
+                for kw, val in self.block._storage_info.items():
+                    nccl.allGather(
+                        self.block._storage_params[kw].storage(),
+                        self._param_buffer[kw],
+                        config["comm"]
+                    )
+                nccl.groupEnd()
 
         current_stream = torch.cuda.current_stream()
         current_stream.wait_stream(config["load_stream"])
 
         # set wait stream for each storage
-        for kw in self._param_tensor.keys():
-            self._param_tensor[kw].record_stream(current_stream)
+        for kw in self.block._storage_info.keys():
+            if self.flag != 2:
+                self._param_tensor[kw].record_stream(current_stream)
             if requires_grad and kw in self._grad_tensor:
                 self._grad_tensor[kw].record_stream(current_stream)
 
         # update parameters in block
-        for param in self.block._param_info:
-            kw_name = param["kw_name"]
-            dtype = self._param_buffer[kw_name].dtype
-            device = self._param_buffer[kw_name].device
-            offset = param["offset"]
-            shape = param["shape"]
-            param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self._param_buffer[kw_name], offset, shape)
-            if requires_grad and kw_name in self._grad_buffer:
-                param["parameter"].requires_grad_(True)
-                param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self._grad_buffer[kw_name], offset, shape)
-            else:
-                param["parameter"].requires_grad_(False)
-
-
+        for param in self.block._param_info:
+            kw_name = param["kw_name"]
+            offset = param["offset"]
+            shape = param["shape"]
+
+            if self.flag != 2:
+                dtype = self._param_buffer[kw_name].dtype
+                device = self._param_buffer[kw_name].device
+                param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self._param_buffer[kw_name], offset, shape)
+            else:
+                dtype = param["parameter"].data.dtype
+                device = param["parameter"].data.device
+                param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.ctx_dict[kw_name], offset, shape)
+
+            if requires_grad and kw_name in self._grad_buffer:
+                param["parameter"].requires_grad_(True)
+                param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self._grad_buffer[kw_name], offset, shape)
+            else:
+                param["parameter"].requires_grad_(False)
+
     def __enter__(self):
         self.enter()
@@ -231,18 +247,26 @@ def exit(self):
                     # grads can not be freed until reduce ops finish
                     self._grad_tensor[kw].record_stream(config["load_stream"])
 
-        # Release all parameters in buffer
+        # Release all parameters from the buffer back to the block's storage
         for param in self.block._param_info:
-            dtype = param["parameter"].dtype
-            device = param["parameter"].device
-            param["parameter"].data = torch.tensor([], dtype=dtype, device=device)
+            kw_name = param["kw_name"]
             param["parameter"].grad = None
-
+            if "begin" not in param:
+                continue
+            begin = param["begin"]
+            end = param["end"]
+            dtype = self.block._storage_params[kw_name].dtype
+            device = self.block._storage_params[kw_name].device
+            param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].storage(), begin, end)
+            if param["parameter"].requires_grad:
+                param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end)
+        if self.flag == 1:
+            for i in self._param_buffer:
+                self.ctx_dict[i] = self._param_buffer[i]
         self._grad_tensor = {}
         self._param_tensor = {}
         self._grad_buffer = {}
         self._param_buffer = {}
-
     def __exit__(self, exc_type, exc_val, exc_tb):
         # reduce scatter gradients
         self.exit()
@@ -293,7 +317,6 @@ class CheckpointBlock(torch.nn.Module):
     """
     def __init__(self, inner_module : torch.nn.Module):
         super().__init__()
-
         self._module = inner_module
         # build large parameter&grad here
         self._param_info = []
@@ -352,9 +375,7 @@ def __init__(self, inner_module : torch.nn.Module):
             else:
                 storage_param.requires_grad_(False)
 
-            # register parameter
-            self.register_parameter(kw, storage_param)
-
+            self._storage_params[kw] = storage_param
 
 
         # initialize parameters in module
@@ -376,44 +397,43 @@ def __init__(self, inner_module : torch.nn.Module):
                 "kw_name": kw_name,
             })
 
-            if isinstance(param, DistributedParameter) and param._init_method is not None:
-                # do not copy distributed parameters
-                pass
-            else:
-                # copy values to buffer for normal parameter
-                storage_st = self._storage_info[kw_name]["begin"]
-                storage_end = self._storage_info[kw_name]["end"]
-
-                # make parameter contiguous in storage
-                with torch.no_grad():
-                    contiguous_param = OpAllGather.apply(param)
-
-                if not (param_st >= storage_end or param_end <= storage_st):
-                    # copy offset in parameter storage
-                    offset_st = max(storage_st - param_st, 0)
-                    offset_end = min(storage_end - param_st, contiguous_param.numel())
-                    assert offset_st < offset_end
-
-                    # copy to offset in buffer storage
-                    to_offset_st = offset_st + param_st - storage_st
-                    to_offset_end = offset_end + param_st - storage_st
-
-                    # copy to buffer
-                    # PyTorch 1.11 changed the API of storage.__getitem__
-                    d_dtype = self._storage_params[kw_name].dtype
-                    d_device = self._storage_params[kw_name].device
-                    torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,))[:] = \
-                        torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:]
-                    # self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(contiguous_param.storage()[offset_st: offset_end])
-                del contiguous_param
+            # copy values to buffer for normal parameter
+            storage_st = self._storage_info[kw_name]["begin"]
+            storage_end = self._storage_info[kw_name]["end"]
+            # make parameter contiguous in storage
+            with torch.no_grad():
+                contiguous_param = OpAllGather.apply(param)
+
+            if not (param_st >= storage_end or param_end <= storage_st):
+                # copy offset in parameter storage
+                offset_st = max(storage_st - param_st, 0)
+                offset_end = min(storage_end - param_st, contiguous_param.numel())
+                assert offset_st < offset_end
+
+                # copy to offset in buffer storage
+                to_offset_st = offset_st + param_st - storage_st
+                to_offset_end = offset_end + param_st - storage_st
+
+                # copy to buffer
+                # PyTorch 1.11 changed the API of storage.__getitem__
+                d_dtype = self._storage_params[kw_name].dtype
+                d_device = self._storage_params[kw_name].device
+                param.data = torch.tensor([], dtype=param.dtype, device=param.device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,))
+                self._param_info[-1]["begin"] = to_offset_st
+                self._param_info[-1]["end"] = (to_offset_end - to_offset_st,)
+                param.data[:] = \
+                    torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:]
+                del contiguous_param
+            else:
+                param.data = torch.tensor([], dtype=param.dtype, device=param.device)
+
             # clear parameter data, but keep the dtype and device
-            param.data = torch.tensor([], dtype=param.dtype, device=param.device)
             setattr(param, "_in_checkpoint_block", True)
 
         for kw in offsets.keys():
             assert offsets[kw] == self._storage_info[kw]["total"]
-
+
     def __call__(self, *args, **kwargs):
         # gather here
         placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled())
@@ -422,13 +442,18 @@ def __call__(self, *args, **kwargs):
             all_inputs.append(kw)
             all_inputs.append(val)
         return OpCheckpointBlock.apply(placeholder, self, True, len(args), *all_inputs)
-
+    def __getattr__(self, name : str):
+        if name == "_module":
+            return self._module
+        return getattr(self._module, name)
     def __setattr__(self, name, value):
         object.__setattr__(self, name, value)
-
+    def __getattribute__(self, name: str):
+        if name == "_parameters":
+            return self._module._parameters
+        return super().__getattribute__(name)
     def __delattr__(self, name):
         object.__delattr__(self, name)
-
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         raise RuntimeError("._save_to_state_dict() of CheckpointBlock should not be called")
 
@@ -440,7 +465,6 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
 
     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                               missing_keys, unexpected_keys, error_msgs):
-
         all_keys = []
         for it in self._param_info:
             key = prefix + it["name"]
@@ -482,7 +506,6 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                     d_device = self._storage_params[kw_name].device
                     torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,))[:] = \
                         torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:]
-                    # self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(contiguous_param.storage()[offset_st: offset_end])
                 del contiguous_param
             elif strict:
                 missing_keys.append(key)
@@ -492,7 +515,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
         for key in state_dict.keys():
             if key.startswith(prefix) and key not in all_keys:
                 unexpected_keys.append(key)
-
+
     def grouped_parameters(self):
         ret = {}
         for kw, val in self._storage_info.items():
@@ -513,7 +536,6 @@ def init_parameters(self):
             tmp_tensor = torch.empty(it["shape"], device=param.device, dtype=param.dtype)
             param._init_method(tmp_tensor)
-
             param_st = it["offset"]
             param_end = it["offset"] + it["size"]
             kw_name = it["kw_name"]
@@ -540,25 +562,14 @@ def init_parameters(self):
                     # PyTorch 1.11 changed the API of storage.__getitem__
                     d_dtype = self._storage_params[kw_name].dtype
                     d_device = self._storage_params[kw_name].device
-                    torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,))[:] = \
+                    param.data[:] = \
                         torch.tensor([], dtype=d_dtype, device=d_device).set_(tmp_tensor.storage(), offset_st, (offset_end - offset_st,))[:]
-                    # self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(tmp_tensor.storage()[offset_st: offset_end])
                 del tmp_tensor
 
     def _named_members(self, get_members_fn, prefix='', recurse=True):
         r"""Helper method for yielding various names + members of modules."""
-        print("here in _named_members")
-        memo = set()
-        modules = torch.nn.Module.named_modules(self, prefix=prefix) if recurse else [(prefix, self)]
-        for module_prefix, module in modules:
-            members = get_members_fn(module)
-            for k, v in members:
-                if v is None or v in memo:
-                    continue
-                memo.add(v)
-                name = module_prefix + ('.' if module_prefix else '') + k
-                yield name, v
-
+        return self._module._named_members(get_members_fn, prefix, recurse)
+
     def named_modules(self, memo = None, prefix: str = '', remove_duplicate: bool = True):
         r"""Returns an iterator over all modules in the network, yielding both
         the name of the module as well as the module itself.
@@ -604,95 +615,17 @@ def named_modules(self, memo = None, prefix: str = '', remove_duplicate: bool =
                 submodule_prefix = prefix + ('.' if prefix else '') + name
                 for m in module.named_modules(memo, submodule_prefix, remove_duplicate):
                     yield m
-
     def named_children(self):
-        r"""Returns an iterator over immediate children modules, yielding both
-        the name of the module as well as the module itself.
-
-        Yields:
-            (string, Module): Tuple containing a name and child module
-
-        Example::
-
-            >>> for name, module in model.named_children():
-            >>>     if name in ['conv4', 'conv5']:
-            >>>         print(module)
-
-        """
-        memo = set()
-        for name, module in self._module._modules.items():
-            if module is not None and module not in memo:
-                memo.add(module)
-                yield name, module
+        return self._module.named_children()
 
     def train(self, mode: bool = True):
-        r"""Sets the module in training mode.
-
-        This has any effect only on certain modules. See documentations of
-        particular modules for details of their behaviors in training/evaluation
-        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
-        etc.
-
-        Args:
-            mode (bool): whether to set training mode (``True``) or evaluation
-                         mode (``False``). Default: ``True``.
-
-        Returns:
-            Module: self
-        """
-        if not isinstance(mode, bool):
-            raise ValueError("training mode is expected to be boolean")
-        self.training = mode
         self._module.train(mode)
-        return self
 
     def eval(self):
-        r"""Sets the module in evaluation mode.
-
-        This has any effect only on certain modules. See documentations of
-        particular modules for details of their behaviors in training/evaluation
-        mode, if they are affected, e.g. Dropout, BatchNorm,
-        etc.
-
-        This is equivalent with `self.train(False)`.
-
-        Returns:
-            Module: self
-        """
-        return self.train(False)
-
-    def __repr__(self):
-        # We treat the extra repr like the sub-module, one item per line
-        extra_lines = []
-        extra_repr = self.extra_repr()
-        # empty string will be split into list ['']
-        if extra_repr:
-            extra_lines = extra_repr.split('\n')
-        child_lines = []
-        for key, module in self._module._modules.items():
-            mod_str = repr(module)
-            mod_str = _addindent(mod_str, 2)
-            child_lines.append('(' + key + '): ' + mod_str)
-        lines = extra_lines + child_lines
-
-        main_str = self._get_name() + '('
-        if lines:
-            # simple one-liner info, which most builtin Modules will use
-            if len(extra_lines) == 1 and not child_lines:
-                main_str += extra_lines[0]
-            else:
-                main_str += '\n  ' + '\n  '.join(lines) + '\n'
-
-        main_str += ')'
-        return main_str
+        self._module.eval()
 
-    def __getattr__(self, attribute):
-        try:
-            return super().__getattr__(attribute)
-        except:
-            return getattr(self._module, attribute)
-
-
+    def __repr__(self):
+        return self._module.__repr__()
 
 class OpTransformerBlockList(torch.autograd.Function):
     @staticmethod
@@ -711,7 +644,7 @@ def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, hidden_s
         ctx.self = self
         ctx.save_list = copy.deepcopy(save_list)
         ctx.num_save_needed = save_list[-1][1]+1
-
+        ctx.layers_dict = [{} for _ in range(len(self))]
         layer_inputs = []
         layer_inspector = []
         cuda_rng_state = []
@@ -720,7 +653,11 @@ def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, hidden_s
             if save_list[i][0] == i:
                 layer_inputs.append(hidden_state)
             cuda_rng_state.append( torch.cuda.get_rng_state() )
-            block_ctx = CheckpointBlockContext(self._modules[str(i)])
+            if config['zero_level'] == 2:
+                flag = 1
+            else:
+                flag = 0
+            block_ctx = CheckpointBlockContext(self._modules[str(i)], ctx.layers_dict[i], flag)
             # gather parameter on load stream
             block_ctx.enter()
             # call inner module directly
@@ -757,7 +694,6 @@ def exit_prev(prev_ctx, prev_grad):
                 "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                 " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                 " argument.")
-
         all_inputs = []
         input_requires_grad = []
@@ -785,7 +721,11 @@ def exit_prev(prev_ctx, prev_grad):
             st = ctx.save_list[i][0]
             for j in range(st, i):
                 torch.cuda.set_rng_state(ctx.cuda_rng_state[j])
-                block_ctx = CheckpointBlockContext(ctx.self._modules[str(j)])
+                if config['zero_level'] == 2:
+                    flag = 2
+                else:
+                    flag = 0
+                block_ctx = CheckpointBlockContext(ctx.self._modules[str(j)], ctx.layers_dict[j], flag)
                 block_ctx.enter()
                 exit_prev(prev_ctx, prev_grad)
                 output = ctx.self._modules[str(j)]._module._call_impl(layer_inputs[ctx.save_list[j][1]], *all_inputs)
@@ -793,9 +733,14 @@ def exit_prev(prev_ctx, prev_grad):
                 prev_grad = False
                 layer_inputs[ctx.save_list[j+1][1]].copy_(output)
                 ctx.save_list[j+1][0] = j+1
+
             torch.cuda.set_rng_state(ctx.cuda_rng_state[i])
             ipt = layer_inputs[ctx.save_list[i][1]].detach().requires_grad_()
-            block_ctx = CheckpointBlockContext(ctx.self._modules[str(i)])
+            if config['zero_level'] == 2:
+                flag = 2
+            else:
+                flag = 0
+            block_ctx = CheckpointBlockContext(ctx.self._modules[str(i)], ctx.layers_dict[i], flag)
             block_ctx.enter()
             exit_prev(prev_ctx, prev_grad)
             prev_ctx = block_ctx
diff --git a/bmtrain/global_var.py b/bmtrain/global_var.py
index 786546b0..909711f5 100644
--- a/bmtrain/global_var.py
+++ b/bmtrain/global_var.py
@@ -5,12 +5,11 @@ class ConfigMap(TypedDict):
     local_rank : int
     world_size : int
     local_size : int
-
+    zero_level : int
     calc_stream : torch.cuda.Stream
     load_stream : torch.cuda.Stream
     load_event : torch.cuda.Event
     barrier_stream : torch.cuda.Stream
-
     loss_scale_factor : float
     loss_scale_steps : int
@@ -30,4 +29,4 @@ def world_size():
     """
     Returns the total number of workers across all nodes.
     """
-    return config['world_size']
\ No newline at end of file
+    return config['world_size']
diff --git a/bmtrain/init.py b/bmtrain/init.py
index 6baca12d..67db7404 100644
--- a/bmtrain/init.py
+++ b/bmtrain/init.py
@@ -8,12 +8,12 @@
 from . import nccl
 import time
 from .synchronize import synchronize
-
 def init_distributed(
         init_method : str = "env://",
         seed : int = 0,
         loss_scale_factor : float = 2,
-        loss_scale_steps : int = 1024
+        loss_scale_steps : int = 1024,
+        zero_level : int = 3,
     ):
     """Initialize distributed training.
     This function will initialize the distributed training, set the random seed and global configurations.
@@ -64,7 +64,7 @@ def init_distributed(
     config["load_stream"] = torch.cuda.Stream(priority=-1)
     config['barrier_stream'] = torch.cuda.Stream()
     config["load_event"] = torch.cuda.Event()
-
+    config["zero_level"] = zero_level
     config["loss_scale_factor"] = loss_scale_factor if loss_scale_factor > 1 else 1 / loss_scale_factor
     config["loss_scale_steps"] = loss_scale_steps
diff --git a/bmtrain/param_init.py b/bmtrain/param_init.py
index 71305f24..49f00346 100644
--- a/bmtrain/param_init.py
+++ b/bmtrain/param_init.py
@@ -29,6 +29,8 @@ def init_distributed_parameter(params : Iterable[torch.nn.Parameter]):
 
 def iterate_parameters(model : torch.nn.Module):
     for kw, val in model._parameters.items():
+        if hasattr(val, "_in_checkpoint_block") and val._in_checkpoint_block:
+            return []
         yield val
 
 def init_parameters(model : torch.nn.Module):
diff --git a/example/train.py b/example/train.py
index 6bdf3c4c..86a9f5bf 100644
--- a/example/train.py
+++ b/example/train.py
@@ -6,6 +6,7 @@ def main():
     bmt.init_distributed(
         seed=0,
+        zero_level=2
     )
 
     model = GPT(
@@ -121,4 +122,4 @@ def main():
     bmt.save(model, "checkpoint.pt")
 
 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()
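
Usage note (not part of the patch): the new behaviour is selected once, at initialization time. `zero_level=3`, the default, keeps the previous ZeRO-3 behaviour of gathering and releasing a block's parameters in both the forward and the backward pass, while `zero_level=2` makes each `CheckpointBlock` keep its gathered parameters alive between the forward pass and the backward recomputation, as the `example/train.py` hunk above shows. Assuming the usual `import bmtrain as bmt` alias used by the examples, a minimal sketch:

```python
import bmtrain as bmt

# zero_level=3 (default): gather parameters in forward and again in backward.
# zero_level=2: gather once in forward, cache the gathered buffers, and reuse
#               them during the backward recomputation.
bmt.init_distributed(seed=0, zero_level=2)
```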
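
For readers following the `CheckpointBlockContext` changes, the `flag` argument encodes a small protocol between the two passes of a checkpointed block: 0 gathers on enter and releases on exit (plain ZeRO-3), 1 is the forward pass under `zero_level=2` and stashes the gathered buffers in the shared `ctx_dict` on exit, and 2 is the backward pass under `zero_level=2` and reuses those stashed buffers instead of calling all-gather again. The sketch below is a single-process toy model of that protocol; every name in it (`ToyBlockContext`, `gather_from_all_ranks`, ...) is hypothetical and only mimics the control flow, it is not BMTrain code.

```python
# Toy, single-process model of the flag protocol used by CheckpointBlockContext.
#   flag == 0: gather on enter, drop the gathered copy on exit (plain ZeRO-3).
#   flag == 1: forward under zero_level=2 -- gather on enter, keep the copy in ctx_dict.
#   flag == 2: backward under zero_level=2 -- skip the gather, reuse the cached copy.

def gather_from_all_ranks(shard):
    # Stand-in for nccl.allGather: BMTrain would concatenate the shards of every rank.
    return list(shard)

class ToyBlockContext:
    def __init__(self, shards, ctx_dict=None, flag=0):
        self.shards = shards                        # name -> this rank's shard
        self.ctx_dict = {} if ctx_dict is None else ctx_dict
        self.flag = flag
        self.gathered = {}

    def __enter__(self):
        for name, shard in self.shards.items():
            if self.flag != 2:
                self.gathered[name] = gather_from_all_ranks(shard)
            else:
                self.gathered[name] = self.ctx_dict[name]   # reuse forward's buffer
        return self.gathered                                # full parameters

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.flag == 1:
            self.ctx_dict.update(self.gathered)             # keep buffers for backward
        self.gathered = {}                                  # otherwise let them be freed

shards = {"w": [1.0, 2.0]}   # this rank's slice of parameter "w"
saved = {}                   # plays the role of ctx.param_dict / ctx.layers_dict[i]

with ToyBlockContext(shards, saved, flag=1) as full:   # forward under zero_level=2
    assert full["w"] == [1.0, 2.0]
with ToyBlockContext(shards, saved, flag=2) as full:   # backward under zero_level=2
    assert full["w"] == [1.0, 2.0]                     # no second gather needed
```

The trade-off this mirrors is the one the patch makes: `zero_level=2` spends extra memory on the cached full parameters of each block in exchange for skipping the second all-gather during recomputation.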