From bcea03505136d46b8b8e59ac08ec9d0dac57911a Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 24 Jul 2023 13:50:17 +0800 Subject: [PATCH 001/122] using hooks to implement ZeRO and Checkpoint --- bmtrain/block_layer.py | 250 ++++++++++------------------------------- bmtrain/init.py | 2 + 2 files changed, 64 insertions(+), 188 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 8647b7ff..d2dd5cca 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -253,6 +253,7 @@ def exit(self): # grads can not be freed until reduce ops finish self._grad_tensor[kw].record_stream(config["load_stream"]) + # Release all parameters from buffer to block_storge for param in self.block._param_info: kw_name = param["kw_name"] @@ -274,6 +275,7 @@ def exit(self): self._param_tensor = {} self._grad_buffer = {} self._param_buffer = {} + def __exit__(self, exc_type, exc_val, exc_tb): # reduce scatter gradients self.exit() @@ -329,6 +331,8 @@ class CheckpointBlock(torch.nn.Module): def __init__(self, inner_module : torch.nn.Module): super().__init__() self._module = inner_module + self._inputs = None + self._layer_dict = {} # build large parameter&grad here self._param_info = [] self._storage_params : Dict[str, torch.nn.Parameter] = {} @@ -440,7 +444,6 @@ def __init__(self, inner_module : torch.nn.Module): del contiguous_param else: param.data = torch.tensor([], dtype=param.dtype, device=param.device) - # clear parameter data, but keep the dtype and device setattr(param, "_in_checkpoint_block", True) @@ -458,6 +461,15 @@ def __call__(self, *args, **kwargs): len_output = outputs[0] return outputs[1:1+len_output] if len_output > 0 else outputs[1] + def forward(self, *args): + if config["use_checkpoint"]: + with torch.no_grad(): + out = self._module(*args) + out.requires_grad_() + return out + else: + return self._module(*args) + def __getattr__(self,name:str): if name=="_module": return self._module @@ -619,6 +631,7 @@ def init_parameters(self): param.data[:] = \ torch.tensor([], dtype=d_dtype, device=d_device).set_(tmp_tensor.storage(), offset_st, (offset_end - offset_st,))[:] del tmp_tensor + def _named_members(self, get_members_fn, prefix='', recurse=True, **kwargs): r"""Helper method for yielding various names + members of modules.""" @@ -685,192 +698,36 @@ def eval(self): def __repr__(self): return self._module.__repr__() -class OpTransformerBlockList(torch.autograd.Function): - @staticmethod - def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, num_hidden, *args): - tensors = [] - others = [] - for arg in args[num_hidden:]: - if torch.is_tensor(arg): - tensors.append(arg) - others.append(None) - else: - tensors.append(None) - others.append(arg) - hidden_states = args[:num_hidden] +def checkpoint_pre_forward(module, inputs): + module._inputs = inputs + +def checkpoint_pre_backward(module, grad_outputs): + with torch.enable_grad(): + out = module._module(*module._inputs) + torch.autograd.backward(out, *grad_outputs) + + if config["zero_level"] != 0: + module._backward_block_ctx.exit() + +def zero_pre_forward(module, inputs): + forward_flag = 1 if config['zero_level'] == 2 else 0 + module._forward_block_ctx = CheckpointBlockContext(module, module._layer_dict, forward_flag) + module._forward_block_ctx.enter() + +def zero_post_forward(module, inputs, outputs): + module._forward_block_ctx.exit() + +def zero_pre_backward(module, grad_outputs): + backward_flag = 2 if config['zero_level'] == 2 else 0 + with torch.enable_grad(): + module._backward_block_ctx 
= CheckpointBlockContext(module, module._layer_dict, backward_flag) + module._backward_block_ctx.enter() + +def zero_post_backward(module, grad_inputs, grad_outputs): + with torch.enable_grad(): + module._backward_block_ctx.exit() - ctx.nontensor_inputs = others - ctx.self = self - ctx.save_list = copy.deepcopy(save_list) - ctx.num_save_needed = save_list[-1][1]+1 - ctx.layers_dict = [{} for _ in range(len(self))] - layer_inputs = [] - layer_inspector = [] - cuda_rng_state = [] - for i in range(len(self)): - with torch.no_grad(): - if save_list[i][0] == i: - layer_inputs += [hidden_state.detach() for hidden_state in hidden_states] - cuda_rng_state.append( torch.cuda.get_rng_state() ) - if config['zero_level']==2: - flag = 1 - else: - flag = 0 - block_ctx = CheckpointBlockContext(self._modules[str(i)], ctx.layers_dict[i], flag) - # gather parameter on load stream - block_ctx.enter() - # call inner module directly - with ScopedTensorInspectorContext() as inspector: - hidden_states = self._modules[str(i)]._module._call_impl(*hidden_states, *args[num_hidden:]) - if not isinstance(hidden_states, tuple): - hidden_states = (hidden_states,) - block_ctx.exit() - for it in inspector.hidden_states: - debug.append("_inspect_hidden_states", it) - layer_inspector.append(inspector.hidden_states) - - ctx.layer_inspector = layer_inspector - ctx.cuda_rng_state = cuda_rng_state - ctx.num_hidden = num_hidden - - ctx.save_for_backward(*layer_inputs, *tensors) - - if self.return_hidden_states: - middle_hiddens = layer_inputs - for mid in middle_hiddens: - mid.requires_grad_() - middle_hiddens = [ - torch.stack(middle_hiddens[i::num_hidden], dim=0) - for i in range(num_hidden) - ] - else: - middle_hiddens = [None] * num_hidden - return tuple(list(hidden_states) + middle_hiddens + [it["tensor"] for inspector_hiddens in ctx.layer_inspector for it in inspector_hiddens]) - - - @staticmethod - def backward(ctx, *grads): - grad_hidden_states = grads[:ctx.num_hidden] - grad_middles = grads[ctx.num_hidden:2*ctx.num_hidden] - grad_inspectors = grads[2*ctx.num_hidden:] - def exit_prev(prev_ctx, prev_grad): - if prev_ctx is not None: - if prev_grad: - with torch.enable_grad(): - prev_ctx.exit() - config["load_stream"].record_event(config["load_event"]) - else: - with torch.no_grad(): - prev_ctx.exit() - config["load_stream"].record_event(config["load_event"]) - - if not torch.autograd._is_checkpoint_valid(): - raise RuntimeError( - "Checkpointing is not compatible with .grad() or when an `inputs` parameter" - " is passed to .backward(). 
Please use .backward() and do not pass its `inputs`" - " argument.") - all_inputs = [] - input_requires_grad = [] - - layer_inputs = ctx.saved_tensors[:ctx.num_save_needed * ctx.num_hidden] - save_args = ctx.saved_tensors[ctx.num_save_needed * ctx.num_hidden:] - for tensor, other in zip(save_args, ctx.nontensor_inputs): - if tensor is None: - all_inputs.append(other) - input_requires_grad.append(False) - else: - # detach for tensor inputs - input_requires_grad.append( tensor.requires_grad ) - nw_tensor = tensor.detach() - nw_tensor.requires_grad = tensor.requires_grad - all_inputs.append(nw_tensor) - - with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True): - with torch.enable_grad(): - # overlap load and scatter here - prev_ctx = None - prev_grad = False - for i in reversed(range(len(ctx.self))): - if ctx.save_list[i][0] != i: - with torch.no_grad(): - st = ctx.save_list[i][0] - for j in range(st, i): - torch.cuda.set_rng_state(ctx.cuda_rng_state[j]) - if config['zero_level'] == 2: - flag = 2 - else: - flag = 0 - block_ctx = CheckpointBlockContext(ctx.self._modules[str(j)], ctx.layers_dict[j], flag) - block_ctx.enter() - exit_prev(prev_ctx, prev_grad) - outputs = ctx.self._modules[str(j)]._module._call_impl( - layer_inputs[ctx.save_list[j][1]*ctx.num_hidden: ctx.save_list[j][1]*ctx.num_hidden+ctx.num_hidden], - *all_inputs - ) - if not isinstance(outputs, tuple): - outputs = (outputs,) - prev_ctx = block_ctx - prev_grad = False - for k, output in enumerate(outputs): - layer_inputs[ctx.save_list[j+1][1]*ctx.num_hidden + k].copy_(output) - ctx.save_list[j+1][0] = j+1 - - torch.cuda.set_rng_state(ctx.cuda_rng_state[i]) - ipts = [ - layer_inputs[ctx.save_list[i][1]*ctx.num_hidden + k].detach().requires_grad_() - for k in range(ctx.num_hidden) - ] - if config['zero_level'] == 2: - flag = 2 - else: - flag = 0 - block_ctx = CheckpointBlockContext(ctx.self._modules[str(i)], ctx.layers_dict[i], flag) - block_ctx.enter() - exit_prev(prev_ctx, prev_grad) - prev_ctx = block_ctx - prev_grad = True - - with ScopedTensorInspectorContext() as inspector: - outputs = ctx.self._modules[str(i)]._module._call_impl(*ipts, *all_inputs) - if not isinstance(outputs, tuple): - outputs = (outputs,) - - assert len(ctx.layer_inspector[i]) == len(inspector.hidden_states), "Backward step changed" - for j, it in enumerate(inspector.hidden_states): - assert it["name"] == ctx.layer_inspector[i][j]["name"], "Backward step changed" - assert it["shape"] == ctx.layer_inspector[i][j]["shape"], "Backward step changed" - assert it["group"] == ctx.layer_inspector[i][j]["group"], "Backward step changed" - - # change the tensor in placeholder - ctx.layer_inspector[i][j]["tensor"] = it["tensor"] - ctx.layer_inspector[i][j]["requires_grad"] = it["requires_grad"] - if len(inspector.hidden_states) > 0: - torch.autograd.backward( - list(outputs) + [hidden_state["tensor"] for hidden_state in inspector.hidden_states], - grad_hidden_states + grad_inspectors[-len(inspector.hidden_states):], - ) - grad_inspectors = grad_inspectors[:-len(inspector.hidden_states)] - else: - torch.autograd.backward( - outputs, - grad_hidden_states, - ) - grad_hidden_states = [ipt.grad for ipt in ipts] - for k in range(ctx.num_hidden): - if grad_middles[k] is not None: - grad_hidden_states[k] = grad_hidden_states[k] + grad_middles[k][i] - grad_hidden_states = tuple(grad_hidden_states) - - exit_prev(prev_ctx, prev_grad) - grads = [] - for inp, requires_grad in zip(all_inputs, input_requires_grad): - if requires_grad: - 
grads.append(inp.grad) - else: - grads.append(None) - return (None, None, None, None) + tuple(grad_hidden_states) + tuple(grads) - class TransformerBlockList(torch.nn.Module): r""" TransformerBlockList is a list of CheckpointBlocks. @@ -899,6 +756,19 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) for i, module in enumerate(modules): if not isinstance(module, CheckpointBlock): module = CheckpointBlock(module) + + if config["zero_level"] > 0: + module.register_forward_pre_hook(zero_pre_forward) + module.register_forward_hook(zero_post_forward) + module.register_full_backward_pre_hook(zero_pre_backward) + + if config["use_checkpoint"]: + module.register_forward_pre_hook(checkpoint_pre_forward) + module.register_full_backward_pre_hook(checkpoint_pre_backward) + + if config["zero_level"] > 0 and not config["use_checkpoint"]: + module.register_full_backward_hook(zero_post_backward) + self._modules[str(i)] = module self.add_module(str(i), module) @@ -935,9 +805,13 @@ def __getitem__(self, index: Union[int, str]) -> CheckpointBlock: def forward(self, *args, return_hidden_states = False): self.return_hidden_states = return_hidden_states - placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) - outputs = OpTransformerBlockList.apply(placeholder, self, self.save_list, self.num_hidden, *args) + outputs = args[:self.num_hidden] + others = args[self.num_hidden:] + for i in range(len(self._modules)): + outputs = self._modules[str(i)](*outputs, *others) + outputs = (outputs,) + if return_hidden_states: return tuple(outputs[:2*self.num_hidden]) else: - return tuple(outputs[:self.num_hidden]) if self.num_hidden > 1 else outputs[0] \ No newline at end of file + return tuple(outputs[:self.num_hidden]) if self.num_hidden > 1 else outputs[0] diff --git a/bmtrain/init.py b/bmtrain/init.py index 5c3006d2..4cbdafd2 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -13,6 +13,7 @@ def init_distributed( zero_level: int = 3, pipe_size: int = -1, num_micro_batches: int = None, + use_checkpoint: bool = True, ): """Initialize distributed training. This function will initialize the distributed training, set the random seed and global configurations. 
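A minimal calling sketch for the new flag, assuming the usual `import bmtrain as bmt` alias and leaving every argument not shown in this diff at its default; the second hunk below threads the flag into the global config so the block hooks can read it:

    import bmtrain as bmt

    # Hypothetical setup: ZeRO-3 parameter partitioning with the new
    # hook-based activation checkpointing enabled (the default).
    bmt.init_distributed(zero_level=3, use_checkpoint=True)

    # Passing use_checkpoint=False would keep the ZeRO gather/scatter hooks
    # but skip the recomputation-based checkpointing in backward.
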
@@ -70,6 +71,7 @@ def init_distributed( config['barrier_stream'] = torch.cuda.Stream() config["load_event"] = torch.cuda.Event() config["zero_level"] = zero_level + config["use_checkpoint"] = use_checkpoint config["topology"] = topology(config) config["zero_rank"] = config["topology"].get_group_rank("zero") if pipe_size > 1 else config['rank'] cpus_this_worker = None From 7b080e7169443d3ff9713c20a85aa28a03445da2 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Tue, 25 Jul 2023 16:40:21 +0800 Subject: [PATCH 002/122] async backward --- bmtrain/block_layer.py | 57 ++++++++++++++++++++++++------------------ 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index d2dd5cca..eb46574a 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -333,6 +333,7 @@ def __init__(self, inner_module : torch.nn.Module): self._module = inner_module self._inputs = None self._layer_dict = {} + self._backward_block_ctx = None # build large parameter&grad here self._param_info = [] self._storage_params : Dict[str, torch.nn.Parameter] = {} @@ -450,16 +451,16 @@ def __init__(self, inner_module : torch.nn.Module): for kw in offsets.keys(): assert offsets[kw] == self._storage_info[kw]["total"] - def __call__(self, *args, **kwargs): - # gather here - placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) - all_inputs = list(args) - for kw, val in kwargs.items(): - all_inputs.append(kw) - all_inputs.append(val) - outputs = OpCheckpointBlock.apply(placeholder, self, True, len(args), *all_inputs) - len_output = outputs[0] - return outputs[1:1+len_output] if len_output > 0 else outputs[1] +# def __call__(self, *args, **kwargs): +# # gather here +# placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) +# all_inputs = list(args) +# for kw, val in kwargs.items(): +# all_inputs.append(kw) +# all_inputs.append(val) +# outputs = OpCheckpointBlock.apply(placeholder, self, True, len(args), *all_inputs) +# len_output = outputs[0] +# return outputs[1:1+len_output] if len_output > 0 else outputs[1] def forward(self, *args): if config["use_checkpoint"]: @@ -698,16 +699,6 @@ def eval(self): def __repr__(self): return self._module.__repr__() -def checkpoint_pre_forward(module, inputs): - module._inputs = inputs - -def checkpoint_pre_backward(module, grad_outputs): - with torch.enable_grad(): - out = module._module(*module._inputs) - torch.autograd.backward(out, *grad_outputs) - - if config["zero_level"] != 0: - module._backward_block_ctx.exit() def zero_pre_forward(module, inputs): forward_flag = 1 if config['zero_level'] == 2 else 0 @@ -720,12 +711,26 @@ def zero_post_forward(module, inputs, outputs): def zero_pre_backward(module, grad_outputs): backward_flag = 2 if config['zero_level'] == 2 else 0 with torch.enable_grad(): - module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict, backward_flag) - module._backward_block_ctx.enter() + module._backward_block_ctxs[module._layer_id] = CheckpointBlockContext(module, module._layer_dict, backward_flag) + module._backward_block_ctxs[module._layer_id].enter() def zero_post_backward(module, grad_inputs, grad_outputs): with torch.enable_grad(): - module._backward_block_ctx.exit() + if not module._is_last_layer: + module._backward_block_ctxs[module._layer_id + 1].exit() + if module._layer_id == 0: + module._backward_block_ctxs[0].exit() + +def checkpoint_pre_forward(module, inputs): + module._inputs = inputs + +def checkpoint_pre_backward(module, grad_outputs): + with 
torch.enable_grad(): + out = module._module(*module._inputs) + torch.autograd.backward(out, *grad_outputs) + + if config["zero_level"] != 0: + zero_post_backward(module, None, None) class TransformerBlockList(torch.nn.Module): @@ -753,6 +758,7 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) super().__init__() self._modules = {} + self._backward_block_ctxs = [None for _ in range(len(modules))] for i, module in enumerate(modules): if not isinstance(module, CheckpointBlock): module = CheckpointBlock(module) @@ -768,10 +774,13 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) if config["zero_level"] > 0 and not config["use_checkpoint"]: module.register_full_backward_hook(zero_post_backward) + module._backward_block_ctxs = self._backward_block_ctxs + module._layer_id = i + module._is_last_layer = True if i == len(modules) -1 else False self._modules[str(i)] = module self.add_module(str(i), module) - + self.num_hidden = num_hidden if sqrt: From be5f9d72c898cd9bbf114d274b33c4abdae4930d Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Tue, 25 Jul 2023 17:05:39 +0800 Subject: [PATCH 003/122] async forward --- bmtrain/block_layer.py | 62 ++++++++++++++++++++++++------------------ 1 file changed, 35 insertions(+), 27 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index d2dd5cca..663e7cf1 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -208,7 +208,7 @@ def enter(self): def __enter__(self): self.enter() - def exit(self): + def exit(self, backward=False): """ Reduce scatter gradients """ @@ -217,8 +217,7 @@ def exit(self): return self._need_release = False self.block._ready = False - requires_grad = torch.is_grad_enabled() - if requires_grad: + if backward: for kw, val in self.block._storage_info.items(): local_param = self.block._storage_params[kw] @@ -333,6 +332,7 @@ def __init__(self, inner_module : torch.nn.Module): self._module = inner_module self._inputs = None self._layer_dict = {} + self._backward_block_ctx = None # build large parameter&grad here self._param_info = [] self._storage_params : Dict[str, torch.nn.Parameter] = {} @@ -450,16 +450,16 @@ def __init__(self, inner_module : torch.nn.Module): for kw in offsets.keys(): assert offsets[kw] == self._storage_info[kw]["total"] - def __call__(self, *args, **kwargs): - # gather here - placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) - all_inputs = list(args) - for kw, val in kwargs.items(): - all_inputs.append(kw) - all_inputs.append(val) - outputs = OpCheckpointBlock.apply(placeholder, self, True, len(args), *all_inputs) - len_output = outputs[0] - return outputs[1:1+len_output] if len_output > 0 else outputs[1] +# def __call__(self, *args, **kwargs): +# # gather here +# placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) +# all_inputs = list(args) +# for kw, val in kwargs.items(): +# all_inputs.append(kw) +# all_inputs.append(val) +# outputs = OpCheckpointBlock.apply(placeholder, self, True, len(args), *all_inputs) +# len_output = outputs[0] +# return outputs[1:1+len_output] if len_output > 0 else outputs[1] def forward(self, *args): if config["use_checkpoint"]: @@ -698,16 +698,6 @@ def eval(self): def __repr__(self): return self._module.__repr__() -def checkpoint_pre_forward(module, inputs): - module._inputs = inputs - -def checkpoint_pre_backward(module, grad_outputs): - with torch.enable_grad(): - out = module._module(*module._inputs) - torch.autograd.backward(out, *grad_outputs) - - if 
config["zero_level"] != 0: - module._backward_block_ctx.exit() def zero_pre_forward(module, inputs): forward_flag = 1 if config['zero_level'] == 2 else 0 @@ -720,12 +710,26 @@ def zero_post_forward(module, inputs, outputs): def zero_pre_backward(module, grad_outputs): backward_flag = 2 if config['zero_level'] == 2 else 0 with torch.enable_grad(): - module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict, backward_flag) - module._backward_block_ctx.enter() + module._backward_block_ctxs[module._layer_id] = CheckpointBlockContext(module, module._layer_dict, backward_flag) + module._backward_block_ctxs[module._layer_id].enter() def zero_post_backward(module, grad_inputs, grad_outputs): with torch.enable_grad(): - module._backward_block_ctx.exit() + if not module._is_last_layer: + module._backward_block_ctxs[module._layer_id + 1].exit(True) + if module._layer_id == 0: + module._backward_block_ctxs[0].exit(True) + +def checkpoint_pre_forward(module, inputs): + module._inputs = inputs + +def checkpoint_pre_backward(module, grad_outputs): + with torch.enable_grad(): + out = module._module(*module._inputs) + torch.autograd.backward(out, *grad_outputs) + + if config["zero_level"] != 0: + zero_post_backward(module, None, None) class TransformerBlockList(torch.nn.Module): @@ -753,6 +757,7 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) super().__init__() self._modules = {} + self._backward_block_ctxs = [None for _ in range(len(modules))] for i, module in enumerate(modules): if not isinstance(module, CheckpointBlock): module = CheckpointBlock(module) @@ -768,10 +773,13 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) if config["zero_level"] > 0 and not config["use_checkpoint"]: module.register_full_backward_hook(zero_post_backward) + module._backward_block_ctxs = self._backward_block_ctxs + module._layer_id = i + module._is_last_layer = True if i == len(modules) -1 else False self._modules[str(i)] = module self.add_module(str(i), module) - + self.num_hidden = num_hidden if sqrt: From 05bc5530ebf1411733115dbb4f401b2190f61af3 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Tue, 25 Jul 2023 17:17:31 +0800 Subject: [PATCH 004/122] fix --- bmtrain/pipe_layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 69c299bc..032fcfa9 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -111,7 +111,7 @@ def exit_prev(prev_ctx, prev_grad): ipt = layer_inputs[ctx.save_list[idx][1]].requires_grad_() if ctx.micro_idx == 0: ctx.block_ctx_list[idx] = CheckpointBlockContext(ctx.self._modules[str(layer_id)], ctx.layers_dict[idx], 2, pipe=True) - ctx.block_ctx_list[idx].enter() + ctx.block_ctx_list[idx].enter(True) if ctx.micro_idx == config["micros"]-1: exit_prev(prev_ctx, prev_grad) prev_ctx = ctx.block_ctx_list[idx] @@ -514,4 +514,4 @@ def exit(self): def __enter__(self): return self.enter() def __exit__(self, exc_type, exc_val, exc_tb): - self.exit() \ No newline at end of file + self.exit() From bdf7087032363a58aa61be5d46081c7cac5c2a8a Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 26 Jul 2023 17:03:58 +0800 Subject: [PATCH 005/122] save cuda_rng_state --- bmtrain/block_layer.py | 2 ++ bmtrain/pipe_layer.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 663e7cf1..02dd5fb2 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -722,9 +722,11 @@ 
def zero_post_backward(module, grad_inputs, grad_outputs): def checkpoint_pre_forward(module, inputs): module._inputs = inputs + module._cuda_rng_state = torch.cuda.get_rng_state() def checkpoint_pre_backward(module, grad_outputs): with torch.enable_grad(): + torch.cuda.set_rng_state(module._cuda_rng_state) out = module._module(*module._inputs) torch.autograd.backward(out, *grad_outputs) diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 032fcfa9..26d42565 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -73,7 +73,7 @@ def exit_prev(prev_ctx, prev_grad): if prev_ctx is not None: if prev_grad: with torch.enable_grad(): - prev_ctx.exit() + prev_ctx.exit(True) config["load_stream"].record_event(config["load_event"]) else: with torch.no_grad(): @@ -111,7 +111,7 @@ def exit_prev(prev_ctx, prev_grad): ipt = layer_inputs[ctx.save_list[idx][1]].requires_grad_() if ctx.micro_idx == 0: ctx.block_ctx_list[idx] = CheckpointBlockContext(ctx.self._modules[str(layer_id)], ctx.layers_dict[idx], 2, pipe=True) - ctx.block_ctx_list[idx].enter(True) + ctx.block_ctx_list[idx].enter() if ctx.micro_idx == config["micros"]-1: exit_prev(prev_ctx, prev_grad) prev_ctx = ctx.block_ctx_list[idx] From 6a366e323bdac5a4983e2f23e5d7b629c4fea62c Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 27 Jul 2023 12:19:23 +0800 Subject: [PATCH 006/122] fix --- bmtrain/block_layer.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 02dd5fb2..3df5e48b 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -141,7 +141,7 @@ def __init__(self, block : 'CheckpointBlock', ctx_dict : dict = None, flag : int self.comm = config["zero_comm"] else: self.comm = config["comm"] - def enter(self): + def enter(self, requires_grad=False): """ gather parameters """ @@ -151,7 +151,7 @@ def enter(self): self._need_release = True wait_loader() - requires_grad = torch.is_grad_enabled() +#requires_grad = torch.is_grad_enabled() with torch.cuda.stream(config["load_stream"]): for kw, val in self.block._storage_info.items(): assert self.block._storage_params[kw].is_cuda @@ -711,14 +711,16 @@ def zero_pre_backward(module, grad_outputs): backward_flag = 2 if config['zero_level'] == 2 else 0 with torch.enable_grad(): module._backward_block_ctxs[module._layer_id] = CheckpointBlockContext(module, module._layer_dict, backward_flag) - module._backward_block_ctxs[module._layer_id].enter() - -def zero_post_backward(module, grad_inputs, grad_outputs): - with torch.enable_grad(): + module._backward_block_ctxs[module._layer_id].enter(True) if not module._is_last_layer: module._backward_block_ctxs[module._layer_id + 1].exit(True) - if module._layer_id == 0: - module._backward_block_ctxs[0].exit(True) + module._backward_block_ctxs[module._layer_id + 1] = None + + +def zero_post_backward(module, grad_inputs, grad_outputs): + if module._layer_id == 0: + module._backward_block_ctxs[0].exit(True) + module._backward_block_ctxs[0] = None def checkpoint_pre_forward(module, inputs): module._inputs = inputs @@ -730,8 +732,6 @@ def checkpoint_pre_backward(module, grad_outputs): out = module._module(*module._inputs) torch.autograd.backward(out, *grad_outputs) - if config["zero_level"] != 0: - zero_post_backward(module, None, None) class TransformerBlockList(torch.nn.Module): From 25ef84fed39e1eaf559074a33580e5fb1f9b8edf Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 27 Jul 2023 15:17:28 +0800 Subject: [PATCH 007/122] fix --- 
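The recomputation in checkpoint_pre_backward below now runs under torch.random.fork_rng and restores the CUDA RNG state captured at forward time (patch 005), so dropout masks are replayed exactly without disturbing RNG state for the rest of the backward pass; the next layer's backward context is also released in zero_pre_backward instead of zero_post_backward, presumably to overlap its reduce-scatter with this layer's backward. A standalone sketch of the replay pattern, with layer, saved_inputs, saved_rng_state and grad_outputs as hypothetical stand-ins for what the hook reads off the module:

    import torch

    def recompute_and_backward(layer, saved_inputs, saved_rng_state, grad_outputs):
        # fork_rng confines the temporary RNG-state change to this block.
        with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True):
            with torch.enable_grad():
                # Replay the forward pass under the RNG state saved by the
                # forward-pre hook, then backpropagate through the replay.
                # Assumes the layer returns a single tensor matching grad_outputs[0].
                torch.cuda.set_rng_state(saved_rng_state)
                out = layer(*saved_inputs)
                torch.autograd.backward(out, *grad_outputs)
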
bmtrain/block_layer.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 3df5e48b..bfba3d65 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -727,10 +727,15 @@ def checkpoint_pre_forward(module, inputs): module._cuda_rng_state = torch.cuda.get_rng_state() def checkpoint_pre_backward(module, grad_outputs): - with torch.enable_grad(): - torch.cuda.set_rng_state(module._cuda_rng_state) - out = module._module(*module._inputs) - torch.autograd.backward(out, *grad_outputs) + with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True): + with torch.enable_grad(): + torch.cuda.set_rng_state(module._cuda_rng_state) + out = module._module(*module._inputs) + torch.autograd.backward(out, *grad_outputs) + + if module._layer_id == 0: + module._backward_block_ctxs[0].exit(True) + module._backward_block_ctxs[0] = None @@ -775,6 +780,7 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) if config["zero_level"] > 0 and not config["use_checkpoint"]: module.register_full_backward_hook(zero_post_backward) + module._backward_block_ctxs = self._backward_block_ctxs module._layer_id = i module._is_last_layer = True if i == len(modules) -1 else False From 768f2099bf58ef72f1698cd49e2d6ac942bed08a Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 27 Jul 2023 16:42:32 +0800 Subject: [PATCH 008/122] fix --- bmtrain/block_layer.py | 20 ++++++++++---------- bmtrain/pipe_layer.py | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index bfba3d65..6dc67817 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -450,16 +450,16 @@ def __init__(self, inner_module : torch.nn.Module): for kw in offsets.keys(): assert offsets[kw] == self._storage_info[kw]["total"] -# def __call__(self, *args, **kwargs): -# # gather here -# placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) -# all_inputs = list(args) -# for kw, val in kwargs.items(): -# all_inputs.append(kw) -# all_inputs.append(val) -# outputs = OpCheckpointBlock.apply(placeholder, self, True, len(args), *all_inputs) -# len_output = outputs[0] -# return outputs[1:1+len_output] if len_output > 0 else outputs[1] + def __call__(self, *args, **kwargs): + # gather here + placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) + all_inputs = list(args) + for kw, val in kwargs.items(): + all_inputs.append(kw) + all_inputs.append(val) + outputs = OpCheckpointBlock.apply(placeholder, self, True, len(args), *all_inputs) + len_output = outputs[0] + return outputs[1:1+len_output] if len_output > 0 else outputs[1] def forward(self, *args): if config["use_checkpoint"]: diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 26d42565..8ee8ccaa 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -111,7 +111,7 @@ def exit_prev(prev_ctx, prev_grad): ipt = layer_inputs[ctx.save_list[idx][1]].requires_grad_() if ctx.micro_idx == 0: ctx.block_ctx_list[idx] = CheckpointBlockContext(ctx.self._modules[str(layer_id)], ctx.layers_dict[idx], 2, pipe=True) - ctx.block_ctx_list[idx].enter() + ctx.block_ctx_list[idx].enter(True) if ctx.micro_idx == config["micros"]-1: exit_prev(prev_ctx, prev_grad) prev_ctx = ctx.block_ctx_list[idx] From 324e0ddcbfbd94ce776d9c22634882d769c4d934 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 31 Jul 2023 10:45:19 +0800 Subject: [PATCH 009/122] remove __call__ --- 
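Removing CheckpointBlock.__call__ (restored briefly in patch 008) matters because a custom __call__ bypasses nn.Module._call_impl, which is where the forward/backward hooks registered in patch 001 are actually invoked; with the override gone, the standard dispatch runs the hooks and then CheckpointBlock.forward. A toy illustration of the pitfall, not bmtrain code:

    import torch

    class WithCustomCall(torch.nn.Linear):
        def __call__(self, *args, **kwargs):
            # Bypasses nn.Module._call_impl, so registered hooks never run.
            return self.forward(*args, **kwargs)

    m = WithCustomCall(2, 2)
    m.register_forward_pre_hook(lambda mod, inp: print("pre-forward hook fired"))
    m(torch.randn(1, 2))  # prints nothing: the override skips the hook machinery
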
bmtrain/block_layer.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 6dc67817..bfba3d65 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -450,16 +450,16 @@ def __init__(self, inner_module : torch.nn.Module): for kw in offsets.keys(): assert offsets[kw] == self._storage_info[kw]["total"] - def __call__(self, *args, **kwargs): - # gather here - placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) - all_inputs = list(args) - for kw, val in kwargs.items(): - all_inputs.append(kw) - all_inputs.append(val) - outputs = OpCheckpointBlock.apply(placeholder, self, True, len(args), *all_inputs) - len_output = outputs[0] - return outputs[1:1+len_output] if len_output > 0 else outputs[1] +# def __call__(self, *args, **kwargs): +# # gather here +# placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) +# all_inputs = list(args) +# for kw, val in kwargs.items(): +# all_inputs.append(kw) +# all_inputs.append(val) +# outputs = OpCheckpointBlock.apply(placeholder, self, True, len(args), *all_inputs) +# len_output = outputs[0] +# return outputs[1:1+len_output] if len_output > 0 else outputs[1] def forward(self, *args): if config["use_checkpoint"]: From 0f4ddb597f75e6a49d3c5ca5c2ff2613d5d2dc8b Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 31 Jul 2023 10:54:49 +0800 Subject: [PATCH 010/122] refactor code structure --- bmtrain/block_layer.py | 223 +++------------------------------------ bmtrain/checkpointing.py | 155 +++++++++++++++++++++++++++ bmtrain/hook_func.py | 43 ++++++++ bmtrain/pipe_layer.py | 187 +++++++++++++++++++++++++++++--- 4 files changed, 387 insertions(+), 221 deletions(-) create mode 100644 bmtrain/hook_func.py diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index bfba3d65..fd3d9f9a 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -6,8 +6,15 @@ from . import nccl from .synchronize import wait_loader from .parameter import DistributedParameter, OpAllGather -from .checkpointing import ScopedTensorInspectorContext +from .checkpointing import ( + ScopedTensorInspectorContext, + CheckpointBlockContext +) + from . import debug + +from . 
import hook_func + import copy import inspect @@ -127,157 +134,6 @@ def backward(ctx, _, *grads): grads.append(None) return (None, None, None, None) + tuple(grads) -class CheckpointBlockContext: - def __init__(self, block : 'CheckpointBlock', ctx_dict : dict = None, flag : int = 0, pipe = False) -> None: - self.block = block - self.ctx_dict = ctx_dict - self._param_buffer = {} - self._grad_buffer = {} - self._param_tensor = {} - self._grad_tensor = {} - self.flag = flag - self._need_release = False - if pipe: - self.comm = config["zero_comm"] - else: - self.comm = config["comm"] - def enter(self, requires_grad=False): - """ - gather parameters - """ - if self.block._ready: - return - self.block._ready = True - self._need_release = True - - wait_loader() -#requires_grad = torch.is_grad_enabled() - with torch.cuda.stream(config["load_stream"]): - for kw, val in self.block._storage_info.items(): - assert self.block._storage_params[kw].is_cuda - assert kw not in self._grad_buffer - assert kw not in self._param_buffer - local_param = self.block._storage_params[kw] - - storage_type = local_param.storage_type() - if self.flag != 2: - self._param_buffer[kw] = storage_type(val["partition_size"] * val["world_size"]) - self._param_tensor[kw] = torch.tensor([], dtype=self._param_buffer[kw].dtype, device=self._param_buffer[kw].device).set_(self._param_buffer[kw]) - - if requires_grad and local_param.requires_grad: - self._grad_buffer[kw] = storage_type(val["partition_size"] * val["world_size"]) - self._grad_tensor[kw] = torch.tensor([], dtype=self._grad_buffer[kw].dtype, device=self._grad_buffer[kw].device).set_(self._grad_buffer[kw]).zero_() - if self.flag != 2: - nccl.groupStart() - for kw, val in self.block._storage_info.items(): - nccl.allGather( - self.block._storage_params[kw].storage(), - self._param_buffer[kw], - self.comm - ) - nccl.groupEnd() - - current_stream = torch.cuda.current_stream() - current_stream.wait_stream(config["load_stream"]) - - # set wait stream for each storage - for kw in self.block._storage_info.keys(): - if self.flag != 2: - self._param_tensor[kw].record_stream(current_stream) - if requires_grad and kw in self._grad_tensor: - self._grad_tensor[kw].record_stream(current_stream) - - # update parameters in block - for param in self.block._param_info: - kw_name = param["kw_name"] - offset = param["offset"] - shape = param["shape"] - - if self.flag != 2: - dtype = self._param_buffer[kw_name].dtype - device = self._param_buffer[kw_name].device - param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self._param_buffer[kw_name], offset, shape) - else: - dtype = param["parameter"].data.dtype - device = param["parameter"].data.device - param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.ctx_dict[kw_name], offset, shape) - - if requires_grad and kw_name in self._grad_buffer and param["parameter"].requires_grad: - param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self._grad_buffer[kw_name], offset, shape) - - def __enter__(self): - self.enter() - - def exit(self, backward=False): - """ - Reduce scatter gradients - """ - - if not self._need_release: - return - self._need_release = False - self.block._ready = False - if backward: - for kw, val in self.block._storage_info.items(): - local_param = self.block._storage_params[kw] - - # accumulate previous gradient - if local_param.requires_grad: - if local_param.grad is None: - grad_storage = val["storage_type"](val["partition_size"]) # initialize gradient if not 
exist - local_param.grad = torch.tensor([], dtype=grad_storage.dtype, device=grad_storage.device).set_(grad_storage).zero_() - else: - self._grad_tensor[kw][val["begin"]:val["end"]] += local_param.grad - - current_stream = torch.cuda.current_stream() - config["load_stream"].wait_stream(current_stream) # wait for backward - - with torch.cuda.stream(config["load_stream"]): - nccl.groupStart() - for kw, val in self.block._storage_info.items(): - local_param = self.block._storage_params[kw] - - # scatter gradient - if local_param.requires_grad: - nccl.reduceScatter( - self._grad_buffer[kw], - local_param.grad.storage(), - "sum", - self.comm - ) - nccl.groupEnd() - - # set wait stream for each storage - for kw in self._grad_tensor.keys(): - # grads can not be freed until reduce ops finish - self._grad_tensor[kw].record_stream(config["load_stream"]) - - - # Release all parameters from buffer to block_storge - for param in self.block._param_info: - kw_name = param["kw_name"] - dtype = self.block._storage_params[kw_name].dtype - device = self.block._storage_params[kw_name].device - if "begin" not in param: - param["parameter"].data = torch.tensor([], dtype=dtype, device=device) - param["parameter"].grad = None - continue - begin = param["begin"] - end = param["end"] - param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].storage(), begin, end) - if param["parameter"].requires_grad and self.block._storage_params[kw_name].grad is not None: - param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end) - if self.flag == 1: - for i in self._param_buffer: - self.ctx_dict[i] = self._param_buffer[i] - self._grad_tensor = {} - self._param_tensor = {} - self._grad_buffer = {} - self._param_buffer = {} - - def __exit__(self, exc_type, exc_val, exc_tb): - # reduce scatter gradients - self.exit() def storage_type_cuda(storage_type): STORAGE_MAP = { @@ -450,17 +306,6 @@ def __init__(self, inner_module : torch.nn.Module): for kw in offsets.keys(): assert offsets[kw] == self._storage_info[kw]["total"] -# def __call__(self, *args, **kwargs): -# # gather here -# placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) -# all_inputs = list(args) -# for kw, val in kwargs.items(): -# all_inputs.append(kw) -# all_inputs.append(val) -# outputs = OpCheckpointBlock.apply(placeholder, self, True, len(args), *all_inputs) -# len_output = outputs[0] -# return outputs[1:1+len_output] if len_output > 0 else outputs[1] - def forward(self, *args): if config["use_checkpoint"]: with torch.no_grad(): @@ -699,46 +544,6 @@ def __repr__(self): return self._module.__repr__() -def zero_pre_forward(module, inputs): - forward_flag = 1 if config['zero_level'] == 2 else 0 - module._forward_block_ctx = CheckpointBlockContext(module, module._layer_dict, forward_flag) - module._forward_block_ctx.enter() - -def zero_post_forward(module, inputs, outputs): - module._forward_block_ctx.exit() - -def zero_pre_backward(module, grad_outputs): - backward_flag = 2 if config['zero_level'] == 2 else 0 - with torch.enable_grad(): - module._backward_block_ctxs[module._layer_id] = CheckpointBlockContext(module, module._layer_dict, backward_flag) - module._backward_block_ctxs[module._layer_id].enter(True) - if not module._is_last_layer: - module._backward_block_ctxs[module._layer_id + 1].exit(True) - module._backward_block_ctxs[module._layer_id + 1] = None - - -def zero_post_backward(module, grad_inputs, 
grad_outputs): - if module._layer_id == 0: - module._backward_block_ctxs[0].exit(True) - module._backward_block_ctxs[0] = None - -def checkpoint_pre_forward(module, inputs): - module._inputs = inputs - module._cuda_rng_state = torch.cuda.get_rng_state() - -def checkpoint_pre_backward(module, grad_outputs): - with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True): - with torch.enable_grad(): - torch.cuda.set_rng_state(module._cuda_rng_state) - out = module._module(*module._inputs) - torch.autograd.backward(out, *grad_outputs) - - if module._layer_id == 0: - module._backward_block_ctxs[0].exit(True) - module._backward_block_ctxs[0] = None - - - class TransformerBlockList(torch.nn.Module): r""" TransformerBlockList is a list of CheckpointBlocks. @@ -770,16 +575,16 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) module = CheckpointBlock(module) if config["zero_level"] > 0: - module.register_forward_pre_hook(zero_pre_forward) - module.register_forward_hook(zero_post_forward) - module.register_full_backward_pre_hook(zero_pre_backward) + module.register_forward_pre_hook(hook_func.zero_pre_forward) + module.register_forward_hook(hook_func.zero_post_forward) + module.register_full_backward_pre_hook(hook_func.zero_pre_backward) if config["use_checkpoint"]: - module.register_forward_pre_hook(checkpoint_pre_forward) - module.register_full_backward_pre_hook(checkpoint_pre_backward) + module.register_forward_pre_hook(hook_func.checkpoint_pre_forward) + module.register_full_backward_pre_hook(hook_func.checkpoint_pre_backward) if config["zero_level"] > 0 and not config["use_checkpoint"]: - module.register_full_backward_hook(zero_post_backward) + module.register_full_backward_hook(hook_func.zero_post_backward) module._backward_block_ctxs = self._backward_block_ctxs module._layer_id = i diff --git a/bmtrain/checkpointing.py b/bmtrain/checkpointing.py index ac6a8d4f..b0c936f3 100644 --- a/bmtrain/checkpointing.py +++ b/bmtrain/checkpointing.py @@ -2,6 +2,9 @@ from typing import Callable, TypeVar from functools import wraps from . import debug +from . 
import nccl +from .global_var import config +from .synchronize import wait_loader class ScopedDebugTensorList: def __init__(self) -> None: @@ -28,3 +31,155 @@ def __exit__(self, *args): self._local_list._set_hidden_states(debug.get("_inspect_hidden_states", [])) debug.set("_inspect_hidden_states", self.prev_hidden) self.prev_hidden = None + +class CheckpointBlockContext: + def __init__(self, block : 'CheckpointBlock', ctx_dict : dict = None, flag : int = 0, pipe = False) -> None: + self.block = block + self.ctx_dict = ctx_dict + self._param_buffer = {} + self._grad_buffer = {} + self._param_tensor = {} + self._grad_tensor = {} + self.flag = flag + self._need_release = False + if pipe: + self.comm = config["zero_comm"] + else: + self.comm = config["comm"] + def enter(self, requires_grad=False): + """ + gather parameters + """ + if self.block._ready: + return + self.block._ready = True + self._need_release = True + + wait_loader() +#requires_grad = torch.is_grad_enabled() + with torch.cuda.stream(config["load_stream"]): + for kw, val in self.block._storage_info.items(): + assert self.block._storage_params[kw].is_cuda + assert kw not in self._grad_buffer + assert kw not in self._param_buffer + local_param = self.block._storage_params[kw] + + storage_type = local_param.storage_type() + if self.flag != 2: + self._param_buffer[kw] = storage_type(val["partition_size"] * val["world_size"]) + self._param_tensor[kw] = torch.tensor([], dtype=self._param_buffer[kw].dtype, device=self._param_buffer[kw].device).set_(self._param_buffer[kw]) + + if requires_grad and local_param.requires_grad: + self._grad_buffer[kw] = storage_type(val["partition_size"] * val["world_size"]) + self._grad_tensor[kw] = torch.tensor([], dtype=self._grad_buffer[kw].dtype, device=self._grad_buffer[kw].device).set_(self._grad_buffer[kw]).zero_() + if self.flag != 2: + nccl.groupStart() + for kw, val in self.block._storage_info.items(): + nccl.allGather( + self.block._storage_params[kw].storage(), + self._param_buffer[kw], + self.comm + ) + nccl.groupEnd() + + current_stream = torch.cuda.current_stream() + current_stream.wait_stream(config["load_stream"]) + + # set wait stream for each storage + for kw in self.block._storage_info.keys(): + if self.flag != 2: + self._param_tensor[kw].record_stream(current_stream) + if requires_grad and kw in self._grad_tensor: + self._grad_tensor[kw].record_stream(current_stream) + + # update parameters in block + for param in self.block._param_info: + kw_name = param["kw_name"] + offset = param["offset"] + shape = param["shape"] + + if self.flag != 2: + dtype = self._param_buffer[kw_name].dtype + device = self._param_buffer[kw_name].device + param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self._param_buffer[kw_name], offset, shape) + else: + dtype = param["parameter"].data.dtype + device = param["parameter"].data.device + param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.ctx_dict[kw_name], offset, shape) + + if requires_grad and kw_name in self._grad_buffer and param["parameter"].requires_grad: + param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self._grad_buffer[kw_name], offset, shape) + + def __enter__(self): + self.enter() + + def exit(self, backward=False): + """ + Reduce scatter gradients + """ + + if not self._need_release: + return + self._need_release = False + self.block._ready = False + if backward: + for kw, val in self.block._storage_info.items(): + local_param = self.block._storage_params[kw] + + # 
accumulate previous gradient + if local_param.requires_grad: + if local_param.grad is None: + grad_storage = val["storage_type"](val["partition_size"]) # initialize gradient if not exist + local_param.grad = torch.tensor([], dtype=grad_storage.dtype, device=grad_storage.device).set_(grad_storage).zero_() + else: + self._grad_tensor[kw][val["begin"]:val["end"]] += local_param.grad + + current_stream = torch.cuda.current_stream() + config["load_stream"].wait_stream(current_stream) # wait for backward + + with torch.cuda.stream(config["load_stream"]): + nccl.groupStart() + for kw, val in self.block._storage_info.items(): + local_param = self.block._storage_params[kw] + + # scatter gradient + if local_param.requires_grad: + nccl.reduceScatter( + self._grad_buffer[kw], + local_param.grad.storage(), + "sum", + self.comm + ) + nccl.groupEnd() + + # set wait stream for each storage + for kw in self._grad_tensor.keys(): + # grads can not be freed until reduce ops finish + self._grad_tensor[kw].record_stream(config["load_stream"]) + + + # Release all parameters from buffer to block_storge + for param in self.block._param_info: + kw_name = param["kw_name"] + dtype = self.block._storage_params[kw_name].dtype + device = self.block._storage_params[kw_name].device + if "begin" not in param: + param["parameter"].data = torch.tensor([], dtype=dtype, device=device) + param["parameter"].grad = None + continue + begin = param["begin"] + end = param["end"] + param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].storage(), begin, end) + if param["parameter"].requires_grad and self.block._storage_params[kw_name].grad is not None: + param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end) + if self.flag == 1: + for i in self._param_buffer: + self.ctx_dict[i] = self._param_buffer[i] + self._grad_tensor = {} + self._param_tensor = {} + self._grad_buffer = {} + self._param_buffer = {} + + def __exit__(self, exc_type, exc_val, exc_tb): + # reduce scatter gradients + self.exit() diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py new file mode 100644 index 00000000..dbb6e504 --- /dev/null +++ b/bmtrain/hook_func.py @@ -0,0 +1,43 @@ +import torch +from .global_var import config +from .checkpointing import CheckpointBlockContext + +def zero_pre_forward(module, inputs): + forward_flag = 1 if config['zero_level'] == 2 else 0 + module._forward_block_ctx = CheckpointBlockContext(module, module._layer_dict, forward_flag) + module._forward_block_ctx.enter() + +def zero_post_forward(module, inputs, outputs): + module._forward_block_ctx.exit() + +def zero_pre_backward(module, grad_outputs): + backward_flag = 2 if config['zero_level'] == 2 else 0 + with torch.enable_grad(): + module._backward_block_ctxs[module._layer_id] = CheckpointBlockContext(module, module._layer_dict, backward_flag) + module._backward_block_ctxs[module._layer_id].enter(True) + if not module._is_last_layer: + module._backward_block_ctxs[module._layer_id + 1].exit(True) + module._backward_block_ctxs[module._layer_id + 1] = None + + +def zero_post_backward(module, grad_inputs, grad_outputs): + if module._layer_id == 0: + module._backward_block_ctxs[0].exit(True) + module._backward_block_ctxs[0] = None + +def checkpoint_pre_forward(module, inputs): + module._inputs = inputs + module._cuda_rng_state = torch.cuda.get_rng_state() + +def checkpoint_pre_backward(module, grad_outputs): + with 
torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True): + with torch.enable_grad(): + torch.cuda.set_rng_state(module._cuda_rng_state) + out = module._module(*module._inputs) + torch.autograd.backward(out, *grad_outputs) + + if module._layer_id == 0: + module._backward_block_ctxs[0].exit(True) + module._backward_block_ctxs[0] = None + + diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 8ee8ccaa..fe7fd378 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -5,12 +5,15 @@ from typing import Dict, Iterable, Iterator, Tuple, Union, List import torch -from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations +from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations from .global_var import config from . import nccl -from .checkpointing import ScopedTensorInspectorContext +from .checkpointing import ( + ScopedTensorInspectorContext, + CheckpointBlockContext +) from . import debug -from .block_layer import CheckpointBlockContext, CheckpointBlock, round_up, _get_param_kw +from .block_layer import CheckpointBlock, round_up, _get_param_kw class OpMicroForward(torch.autograd.Function): @staticmethod @@ -141,6 +144,7 @@ def exit_prev(prev_ctx, prev_grad): [grad_hidden_state], ) grad_hidden_state = ipt.grad + if grad_middle is not None: grad_hidden_state = grad_hidden_state + grad_middle[idx] if ctx.micro_idx == config["micros"]-1: @@ -310,6 +314,98 @@ def backward(ctx, grad_hidden_state : torch.Tensor, grad_middle : torch.Tensor, return (None, None, None, grad) + tuple(grads) + (None,) +def zero_pre_forward(module, inputs): + if module._micro_idx == 0: + forward_flag = 1 if config['zero_level'] == 2 else 0 + module._forward_block_ctx = CheckpointBlockContext(module, module._layer_dict, forward_flag, pipe=True) + module._forward_block_ctx.enter() + +def zero_post_forward(module, inputs, outputs): + if module._micro_idx == config['micros'] - 1: + module._forward_block_ctx.exit() + + +def pipe_pre_forward(module, inputs): + if not module._is_first_stage: + if module._is_first_layer: + pre_inputs = recv_activations(module.stage_id - 1, config['pipe_comm']) + pre_inputs.requires_grad_() + return (pre_inputs, ) + inputs[1:] + +def pipe_post_forward(module, inputs, outputs): + if not module._is_last_stage: + if module._is_last_layer: + send_data = outputs[0] if isinstance(outputs, tuple) else outputs + send_activations(send_data, module.stage_id + 1, config['pipe_comm']) + +def zero_pre_backward(module, grad_inputs): + if module._micro_idx == config['micros'] - 1: + backward_flag = 2 if config['zero_level'] == 2 else 0 + with torch.enable_grad(): + module._backward_block_ctxs[module._layer_id] = CheckpointBlockContext(module, module._layer_dict, backward_flag, pipe=True) + module._backward_block_ctxs[module._layer_id].enter(True) +# if module._micro_idx == 0: +# if not module._is_last_layer: +# module._backward_block_ctxs[module._layer_id + 1].exit(True) + +def zero_post_backward(module, grad_inputs, grad_outputs): + if module._micro_idx == 0: + module._backward_block_ctxs[module._layer_id].exit(True) +# if module._is_first_layer: +# module._backward_block_ctxs[module._layer_id].exit(True) + +def pipe_pre_backward(module, grad_inputs): + if module._is_last_layer: + module._grad_list = all_gather(grad_inputs[0], config["pipe_comm"]) + module._grad_list = module._grad_list.flatten(start_dim=0, end_dim=1).chunk(module.stages, dim=0) + + if module._is_last_layer and module._is_last_stage: + 
return (module._grad_list[module._micro_idx], ) + + if not module._is_last_stage: + if module._is_last_layer: + pre_grad_inputs = recv_activations(module.stage_id + 1, config['pipe_comm']) + return (pre_grad_inputs, ) + grad_inputs[1:] + + +def pipe_post_backward(module, grad_inputs, grad_outputs): + if not module._is_first_stage: + if module._is_first_layer: + send_data = grad_inputs[0] if isinstance(grad_inputs, tuple) else grad_inputs + if send_data is not None: + send_activations(send_data, module.stage_id - 1, config['pipe_comm']) + +# if module._is_first_layer: +# if module._micro_idx == config['micros'] -1: +# module._all_grads = [] +# grad = grad_inputs[0] +# module._all_grads.append(grad) +# if module._micro_idx == 0: +# grads = torch.cat(module._all_grads, dim=0) +# grad = broadcast(grads, 0, config['pipe_comm']) +# grad = grad.chunk(module.stages, dim=0) +# return (grad[module.stage_id], ) + grad_inputs[1:] + + module._micro_idx -= 1 + +def checkpoint_pre_forward(module, inputs): + if module._micro_idx == 0: + module._inputs = [inputs] + module._cuda_rng_state = [torch.cuda.get_rng_state()] + else: + module._inputs.append(inputs) + module._cuda_rng_state.append(torch.cuda.get_rng_state()) + +def checkpoint_pre_backward(module, grad_outputs): + with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True): + with torch.enable_grad(): + torch.cuda.set_rng_state(module._cuda_rng_state[module._micro_idx]) + out = module._module(*module._inputs[module._micro_idx]) + torch.autograd.backward(out, *grad_outputs) + + zero_post_backward(module, None, grad_outputs) + pipe_post_backward(module, module._inputs[module._micro_idx][0].grad, None) + class PipelineTransformerBlockList(torch.nn.Module): r""" TransformerBlockList is a list of CheckpointBlocks. 
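The hook functions above compose because of how PyTorch's module-hook API is specified: a forward-pre hook may return a replacement input tuple (how pipe_pre_forward injects activations received from the previous stage), a backward-pre hook may return replacement grad_outputs, and the post-forward/post-backward hooks run after the respective pass. A toy sketch of that composition outside bmtrain, assuming PyTorch >= 2.0 for register_full_backward_pre_hook:

    import torch

    layer = torch.nn.Linear(4, 4)

    def pre_forward(module, inputs):
        # Returning a tuple replaces the inputs the module will see
        # (the slot used to substitute received activations).
        return tuple(x.contiguous() for x in inputs)

    def post_forward(module, inputs, outputs):
        # Runs after forward; a natural place to release gathered parameters.
        return outputs

    def pre_backward(module, grad_outputs):
        # Runs before this module's backward; used to re-gather parameters
        # or swap in gradients received from the next pipeline stage.
        return grad_outputs

    layer.register_forward_pre_hook(pre_forward)
    layer.register_forward_hook(post_forward)
    layer.register_full_backward_pre_hook(pre_backward)  # torch >= 2.0

    y = layer(torch.randn(2, 4, requires_grad=True))
    y.sum().backward()
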
@@ -331,9 +427,9 @@ class PipelineTransformerBlockList(torch.nn.Module): """ _modules: Dict[str, CheckpointBlock] - def __init__(self, modules: Iterable[CheckpointBlock]) -> None: + def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: super().__init__() - + self.num_hidden = num_hidden self._modules = {} rank = config['rank'] topo = config['topology'] @@ -345,9 +441,39 @@ def __init__(self, modules: Iterable[CheckpointBlock]) -> None: for idx, module in enumerate(modules): if not isinstance(module, CheckpointBlock): module = CheckpointBlock(module) + + module.register_forward_pre_hook(pipe_pre_forward) + module.register_forward_pre_hook(zero_pre_forward) + if config['use_checkpoint']: + module.register_forward_pre_hook(checkpoint_pre_forward) + + module.register_forward_hook(zero_post_forward) + module.register_forward_hook(pipe_post_forward) + + module.register_full_backward_pre_hook(pipe_pre_backward) + module.register_full_backward_pre_hook(zero_pre_backward) + if config['use_checkpoint']: + module.register_full_backward_pre_hook(checkpoint_pre_backward) + + module.register_full_backward_hook(zero_post_backward) + module.register_full_backward_hook(pipe_post_backward) + + module.stage_id = self.stage_id + module.stages = self.stages + self._modules[str(idx)] = module self.layer_ids = self.get_range_by_stage_id(self.stage_id) + + self._backward_block_ctxs = [None for _ in range(len(modules))] + for i,layer_id in enumerate(self.layer_ids): + self._modules[str(layer_id)]._layer_id = layer_id + self._modules[str(layer_id)]._backward_block_ctxs = self._backward_block_ctxs + self._modules[str(layer_id)]._is_first_stage = True if self.stage_id == 0 else False + self._modules[str(layer_id)]._is_last_stage = True if self.stage_id == self.stages-1 else False + self._modules[str(layer_id)]._is_first_layer = True if i == 0 else False + self._modules[str(layer_id)]._is_last_layer = True if i == len(self.layer_ids)-1 else False + self.partition_modules(self.layer_ids) self.next_rank = pipe_group[self.pipe_idx, self.stage_id + 1] if self.stage_id < config['pipe_size'] - 1 else -1 self.prev_rank = pipe_group[self.pipe_idx, self.stage_id - 1] if self.stage_id > 0 else -1 @@ -366,15 +492,52 @@ def __getitem__(self, index: Union[int, str]) -> CheckpointBlock: def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=False): self.return_hidden_states = return_hidden_states - placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) - args = list(args) - args.append(batch_related) - outputs = OpPipeTransformerBlockList.apply(placeholder, self, self.save_list, hidden_state, *args) - hidden_state, middle_states = outputs[:2] + if False: + placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) + args = list(args) + args.append(batch_related) + outputs = OpPipeTransformerBlockList.apply(placeholder, self, self.save_list, hidden_state, *args) + hidden_state, middle_states = outputs[:2] + if return_hidden_states: + return hidden_state, middle_states + else: + return hidden_state + else: + batch_size = hidden_state.shape[0] + num_micros = config["micros"] + hidden_state_list = all_gather(hidden_state, config["pipe_comm"]).flatten(0, 1).detach().requires_grad_() + + args_list = [[] for _ in range(num_micros)] + for arg in args: + if torch.is_tensor(arg): + arg_all = all_gather(arg, config['pipe_comm']) + if arg.shape[0] == batch_size: + arg_all = arg_all.flatten(0, 1).chunk(num_micros, dim=0) + arg_all = 
[tensor.detach().requires_grad_(arg.requires_grad) for tensor in arg_all] + else: + arg_all = [arg_all[0].detach().requires_grad_(arg.requires_grad) for i in range(num_micros)] + else: + arg_all = [arg for _ in range(num_micros)] + for i in range(num_micros): + args_list[i].append(arg_all[i]) + + outputs = [] + hidden_state_list = hidden_state_list.chunk(num_micros, dim=0) + for micro_idx, (hidden_state, arg) in enumerate(zip(hidden_state_list, args_list)): + for idx,layer_id in enumerate(self.layer_ids): + self._modules[str(layer_id)]._micro_idx = micro_idx + hidden_state = self._modules[str(layer_id)](hidden_state, *arg) + outputs.append(hidden_state) + + last_hidden = torch.cat(outputs, dim=0) + last_hidden = broadcast(last_hidden, config["pipe_size"] - 1, config["pipe_comm"]) + last_hidden = last_hidden.chunk(self.stages, dim=0) + outputs = last_hidden[self.stage_id] + if return_hidden_states: - return hidden_state, middle_states + return outputs[:2*self.num_hidden] else: - return hidden_state + return outputs[:self.num_hidden] if self.num_hidden > 1 else outputs def get_range_by_stage_id(self, stage_id : int) -> List[int]: part_lens = [0]+[self.get_part_len_by_stage_id(i) for i in range(stage_id+1)] From 76c5c265a099d6af82a7171192b41d86a648e107 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 31 Jul 2023 13:02:21 +0800 Subject: [PATCH 011/122] pipeline --- bmtrain/hook_func.py | 117 +++++++++++++++++++++++++++++++++++------- bmtrain/pipe_layer.py | 21 ++++---- 2 files changed, 109 insertions(+), 29 deletions(-) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index dbb6e504..77daf425 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -1,43 +1,122 @@ import torch from .global_var import config from .checkpointing import CheckpointBlockContext +from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations def zero_pre_forward(module, inputs): - forward_flag = 1 if config['zero_level'] == 2 else 0 - module._forward_block_ctx = CheckpointBlockContext(module, module._layer_dict, forward_flag) - module._forward_block_ctx.enter() + enter = True + pipe = False + if config['pipe_enabled']: + enter = module._micro_idx == 0 + pipe = True + if enter: + forward_flag = 1 if config['zero_level'] == 2 else 0 + module._forward_block_ctx = CheckpointBlockContext(module, module._layer_dict, forward_flag, pipe=pipe) + module._forward_block_ctx.enter() def zero_post_forward(module, inputs, outputs): - module._forward_block_ctx.exit() + exit = True + if config['pipe_enabled']: + exit = module._micro_idx == config['micros'] - 1 + + if exit: + module._forward_block_ctx.exit() def zero_pre_backward(module, grad_outputs): backward_flag = 2 if config['zero_level'] == 2 else 0 - with torch.enable_grad(): + if not config['pipe_enabled']: module._backward_block_ctxs[module._layer_id] = CheckpointBlockContext(module, module._layer_dict, backward_flag) module._backward_block_ctxs[module._layer_id].enter(True) - if not module._is_last_layer: - module._backward_block_ctxs[module._layer_id + 1].exit(True) - module._backward_block_ctxs[module._layer_id + 1] = None + if exit: + if not module._is_last_layer: + module._backward_block_ctxs[module._layer_id + 1].exit(True) + else: + if module._micro_idx == config['micros'] - 1: + module._backward_block_ctxs[module._layer_id] = CheckpointBlockContext(module, module._layer_dict, backward_flag, pipe=True) + module._backward_block_ctxs[module._layer_id].enter(True) def zero_post_backward(module, grad_inputs, 
grad_outputs): - if module._layer_id == 0: - module._backward_block_ctxs[0].exit(True) - module._backward_block_ctxs[0] = None + if not config['pipe_enabled']: + if module._layer_id == 0: + module._backward_block_ctxs[0].exit(True) + else: + if module._micro_idx == 0: + module._backward_block_ctxs[module._layer_id].exit(True) + +def pipe_pre_forward(module, inputs): + if not module._is_first_stage: + if module._is_first_layer: + pre_inputs = recv_activations(module.stage_id - 1, config['pipe_comm']) + pre_inputs.requires_grad_() + return (pre_inputs, ) + inputs[1:] + +def pipe_post_forward(module, inputs, outputs): + if not module._is_last_stage: + if module._is_last_layer: + send_data = outputs[0] if isinstance(outputs, tuple) else outputs + send_activations(send_data, module.stage_id + 1, config['pipe_comm']) + +def pipe_pre_backward(module, grad_inputs): + if module._is_last_layer: + module._grad_list = all_gather(grad_inputs[0], config["pipe_comm"]) + module._grad_list = module._grad_list.flatten(start_dim=0, end_dim=1).chunk(module.stages, dim=0) + + if module._is_last_layer and module._is_last_stage: + return (module._grad_list[module._micro_idx], ) + + if not module._is_last_stage: + if module._is_last_layer: + pre_grad_inputs = recv_activations(module.stage_id + 1, config['pipe_comm']) + return (pre_grad_inputs, ) + grad_inputs[1:] + + +def pipe_post_backward(module, grad_inputs, grad_outputs): + if not module._is_first_stage: + if module._is_first_layer: + send_data = grad_inputs[0] if isinstance(grad_inputs, tuple) else grad_inputs + if send_data is not None: + send_activations(send_data, module.stage_id - 1, config['pipe_comm']) + +# if module._is_first_layer: +# if module._micro_idx == config['micros'] -1: +# module._all_grads = [] +# grad = grad_inputs[0] +# module._all_grads.append(grad) +# if module._micro_idx == 0: +# grads = torch.cat(module._all_grads, dim=0) +# grad = broadcast(grads, 0, config['pipe_comm']) +# grad = grad.chunk(module.stages, dim=0) +# return (grad[module.stage_id], ) + grad_inputs[1:] + + module._micro_idx -= 1 def checkpoint_pre_forward(module, inputs): - module._inputs = inputs - module._cuda_rng_state = torch.cuda.get_rng_state() + if not config['pipe_enabled']: + module._inputs = inputs + module._cuda_rng_state = torch.cuda.get_rng_state() + else: + if module._micro_idx == 0: + module._inputs = [inputs] + module._cuda_rng_state = [torch.cuda.get_rng_state()] + else: + module._inputs.append(inputs) + module._cuda_rng_state.append(torch.cuda.get_rng_state()) def checkpoint_pre_backward(module, grad_outputs): + inputs = module._inputs if not config['pipe_enabled'] else module._inputs[module._micro_idx] + cuda_rng_state = module._cuda_rng_state if not config['pipe_enabled'] else module._cuda_rng_state[module._micro_idx] with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True): with torch.enable_grad(): - torch.cuda.set_rng_state(module._cuda_rng_state) - out = module._module(*module._inputs) + torch.cuda.set_rng_state(cuda_rng_state) + out = module._module(*inputs) torch.autograd.backward(out, *grad_outputs) - if module._layer_id == 0: - module._backward_block_ctxs[0].exit(True) - module._backward_block_ctxs[0] = None + if not config['pipe_enabled']: + if module._layer_id == 0: + module._backward_block_ctxs[0].exit(True) + module._backward_block_ctxs[0] = None + else: + zero_post_backward(module, None, grad_outputs) + pipe_post_backward(module, module._inputs[module._micro_idx][0].grad, None) - diff --git a/bmtrain/pipe_layer.py 
b/bmtrain/pipe_layer.py index fe7fd378..155f0d5a 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -14,6 +14,7 @@ ) from . import debug from .block_layer import CheckpointBlock, round_up, _get_param_kw +from . import hook_func class OpMicroForward(torch.autograd.Function): @staticmethod @@ -442,21 +443,21 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: if not isinstance(module, CheckpointBlock): module = CheckpointBlock(module) - module.register_forward_pre_hook(pipe_pre_forward) - module.register_forward_pre_hook(zero_pre_forward) + module.register_forward_pre_hook(hook_func.pipe_pre_forward) + module.register_forward_pre_hook(hook_func.zero_pre_forward) if config['use_checkpoint']: - module.register_forward_pre_hook(checkpoint_pre_forward) + module.register_forward_pre_hook(hook_func.checkpoint_pre_forward) - module.register_forward_hook(zero_post_forward) - module.register_forward_hook(pipe_post_forward) + module.register_forward_hook(hook_func.zero_post_forward) + module.register_forward_hook(hook_func.pipe_post_forward) - module.register_full_backward_pre_hook(pipe_pre_backward) - module.register_full_backward_pre_hook(zero_pre_backward) + module.register_full_backward_pre_hook(hook_func.pipe_pre_backward) + module.register_full_backward_pre_hook(hook_func.zero_pre_backward) if config['use_checkpoint']: - module.register_full_backward_pre_hook(checkpoint_pre_backward) + module.register_full_backward_pre_hook(hook_func.checkpoint_pre_backward) - module.register_full_backward_hook(zero_post_backward) - module.register_full_backward_hook(pipe_post_backward) + module.register_full_backward_hook(hook_func.zero_post_backward) + module.register_full_backward_hook(hook_func.pipe_post_backward) module.stage_id = self.stage_id module.stages = self.stages From 16c092243940f96dc7643ba57aa32cea767bdd9a Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 31 Jul 2023 19:22:29 +0800 Subject: [PATCH 012/122] for low version --- bmtrain/block_layer.py | 34 ++++++++++++++++++++++----------- bmtrain/hook_func.py | 43 ++++++++++++++++++++++++++++++++++++++---- 2 files changed, 62 insertions(+), 15 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index fd3d9f9a..50ceb1d8 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -18,6 +18,7 @@ import copy import inspect +torch_version = hook_func.torch_version # the flag is used to control the zero level , 0 means normal zero3 , 1 means forward without release parameter ,2 means backward without gather parameter class OpCheckpointBlock(torch.autograd.Function): @@ -574,17 +575,20 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) if not isinstance(module, CheckpointBlock): module = CheckpointBlock(module) - if config["zero_level"] > 0: - module.register_forward_pre_hook(hook_func.zero_pre_forward) - module.register_forward_hook(hook_func.zero_post_forward) - module.register_full_backward_pre_hook(hook_func.zero_pre_backward) + if torch_version >= '2.0.1' or i < len(modules): + if config["zero_level"] > 0: + module.register_forward_pre_hook(hook_func.zero_pre_forward) + module.register_forward_hook(hook_func.zero_post_forward) + if torch_version >= '2.0.1': + module.register_full_backward_pre_hook(hook_func.zero_pre_backward) - if config["use_checkpoint"]: - module.register_forward_pre_hook(hook_func.checkpoint_pre_forward) - module.register_full_backward_pre_hook(hook_func.checkpoint_pre_backward) + if config["use_checkpoint"]: + if 
torch_version >= '2.0.1': + module.register_forward_pre_hook(hook_func.checkpoint_pre_forward) + module.register_full_backward_pre_hook(hook_func.checkpoint_pre_backward) - if config["zero_level"] > 0 and not config["use_checkpoint"]: - module.register_full_backward_hook(hook_func.zero_post_backward) + if config["zero_level"] > 0 and not config["use_checkpoint"]: + module.register_full_backward_hook(hook_func.zero_post_backward) module._backward_block_ctxs = self._backward_block_ctxs module._layer_id = i @@ -592,8 +596,13 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) self._modules[str(i)] = module self.add_module(str(i), module) + if i > 0: + module._pre_module = self._modules[str(i-1)] self.num_hidden = num_hidden + self.identity = hook_func.IdentityLayer() + self.identity.register_full_backward_hook(hook_func.identity_post_backward) + self.identity._pre_module = self._modules[str(len(modules)-1)] if sqrt: length = len(self) @@ -618,7 +627,7 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) self.save_list = [(i, i) for i in range(len(self))] def __len__(self) -> int: - return len(self._modules) + return len(self._modules) - 1 def __iter__(self) -> Iterator[CheckpointBlock]: return iter(self._modules.values()) def __getitem__(self, index: Union[int, str]) -> CheckpointBlock: @@ -628,10 +637,13 @@ def forward(self, *args, return_hidden_states = False): self.return_hidden_states = return_hidden_states outputs = args[:self.num_hidden] others = args[self.num_hidden:] - for i in range(len(self._modules)): + for i in range(len(self)): outputs = self._modules[str(i)](*outputs, *others) outputs = (outputs,) + if torch_version < '2.0.1': + outputs = self.identity(outputs) + if return_hidden_states: return tuple(outputs[:2*self.num_hidden]) else: diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 77daf425..d8a14e9a 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -3,6 +3,9 @@ from .checkpointing import CheckpointBlockContext from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations +torch_version = torch.__version__ +#torch_version = '1.9.0' + def zero_pre_forward(module, inputs): enter = True pipe = False @@ -27,15 +30,13 @@ def zero_pre_backward(module, grad_outputs): if not config['pipe_enabled']: module._backward_block_ctxs[module._layer_id] = CheckpointBlockContext(module, module._layer_dict, backward_flag) module._backward_block_ctxs[module._layer_id].enter(True) - if exit: - if not module._is_last_layer: - module._backward_block_ctxs[module._layer_id + 1].exit(True) + if not module._is_last_layer: + module._backward_block_ctxs[module._layer_id + 1].exit(True) else: if module._micro_idx == config['micros'] - 1: module._backward_block_ctxs[module._layer_id] = CheckpointBlockContext(module, module._layer_dict, backward_flag, pipe=True) module._backward_block_ctxs[module._layer_id].enter(True) - def zero_post_backward(module, grad_inputs, grad_outputs): if not config['pipe_enabled']: if module._layer_id == 0: @@ -44,6 +45,10 @@ def zero_post_backward(module, grad_inputs, grad_outputs): if module._micro_idx == 0: module._backward_block_ctxs[module._layer_id].exit(True) + if torch_version < '2.0.1': + if module._layer_id != 0: + zero_pre_backward(module._pre_module, grad_inputs) + def pipe_pre_forward(module, inputs): if not module._is_first_stage: if module._is_first_layer: @@ -120,3 +125,33 @@ def checkpoint_pre_backward(module, grad_outputs): 
zero_post_backward(module, None, grad_outputs) pipe_post_backward(module, module._inputs[module._micro_idx][0].grad, None) +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, module, *args): + inputs = args[0].detach() + ctx.module = module + with torch.no_grad(): + zero_pre_forward(module, args) + checkpoint_pre_forward(module, args) + outputs = module._module(inputs, *args[1:]) + outputs.requires_grad_() + zero_post_forward(module, args, outputs) + return outputs + + @staticmethod + def backward(ctx, grads): + with torch.enable_grad(): + zero_pre_backward(ctx.module, grads) + checkpoint_pre_backward(ctx.module, grads) + return None, ctx.module._inputs[0].grad, None, None + +def identity_post_backward(module, grad_inputs, grad_outputs): + zero_pre_backward(module._pre_module, grad_inputs) + +class IdentityLayer(torch.nn.Module): + def __init__(self): + super(IdentityLayer, self).__init__() + + def forward(self, x): + return x + From 2d35ba0b7607c89f16d3556f8c5cb25de67b626d Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 31 Jul 2023 19:35:17 +0800 Subject: [PATCH 013/122] for low torch version --- bmtrain/block_layer.py | 34 ++++++++++++++++++++++----------- bmtrain/hook_func.py | 43 ++++++++++++++++++++++++++++++++++++++---- 2 files changed, 62 insertions(+), 15 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index fd3d9f9a..50ceb1d8 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -18,6 +18,7 @@ import copy import inspect +torch_version = hook_func.torch_version # the flag is used to control the zero level , 0 means normal zero3 , 1 means forward without release parameter ,2 means backward without gather parameter class OpCheckpointBlock(torch.autograd.Function): @@ -574,17 +575,20 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) if not isinstance(module, CheckpointBlock): module = CheckpointBlock(module) - if config["zero_level"] > 0: - module.register_forward_pre_hook(hook_func.zero_pre_forward) - module.register_forward_hook(hook_func.zero_post_forward) - module.register_full_backward_pre_hook(hook_func.zero_pre_backward) + if torch_version >= '2.0.1' or i < len(modules): + if config["zero_level"] > 0: + module.register_forward_pre_hook(hook_func.zero_pre_forward) + module.register_forward_hook(hook_func.zero_post_forward) + if torch_version >= '2.0.1': + module.register_full_backward_pre_hook(hook_func.zero_pre_backward) - if config["use_checkpoint"]: - module.register_forward_pre_hook(hook_func.checkpoint_pre_forward) - module.register_full_backward_pre_hook(hook_func.checkpoint_pre_backward) + if config["use_checkpoint"]: + if torch_version >= '2.0.1': + module.register_forward_pre_hook(hook_func.checkpoint_pre_forward) + module.register_full_backward_pre_hook(hook_func.checkpoint_pre_backward) - if config["zero_level"] > 0 and not config["use_checkpoint"]: - module.register_full_backward_hook(hook_func.zero_post_backward) + if config["zero_level"] > 0 and not config["use_checkpoint"]: + module.register_full_backward_hook(hook_func.zero_post_backward) module._backward_block_ctxs = self._backward_block_ctxs module._layer_id = i @@ -592,8 +596,13 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) self._modules[str(i)] = module self.add_module(str(i), module) + if i > 0: + module._pre_module = self._modules[str(i-1)] self.num_hidden = num_hidden + self.identity = hook_func.IdentityLayer() + 
self.identity.register_full_backward_hook(hook_func.identity_post_backward) + self.identity._pre_module = self._modules[str(len(modules)-1)] if sqrt: length = len(self) @@ -618,7 +627,7 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) self.save_list = [(i, i) for i in range(len(self))] def __len__(self) -> int: - return len(self._modules) + return len(self._modules) - 1 def __iter__(self) -> Iterator[CheckpointBlock]: return iter(self._modules.values()) def __getitem__(self, index: Union[int, str]) -> CheckpointBlock: @@ -628,10 +637,13 @@ def forward(self, *args, return_hidden_states = False): self.return_hidden_states = return_hidden_states outputs = args[:self.num_hidden] others = args[self.num_hidden:] - for i in range(len(self._modules)): + for i in range(len(self)): outputs = self._modules[str(i)](*outputs, *others) outputs = (outputs,) + if torch_version < '2.0.1': + outputs = self.identity(outputs) + if return_hidden_states: return tuple(outputs[:2*self.num_hidden]) else: diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 77daf425..d8a14e9a 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -3,6 +3,9 @@ from .checkpointing import CheckpointBlockContext from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations +torch_version = torch.__version__ +#torch_version = '1.9.0' + def zero_pre_forward(module, inputs): enter = True pipe = False @@ -27,15 +30,13 @@ def zero_pre_backward(module, grad_outputs): if not config['pipe_enabled']: module._backward_block_ctxs[module._layer_id] = CheckpointBlockContext(module, module._layer_dict, backward_flag) module._backward_block_ctxs[module._layer_id].enter(True) - if exit: - if not module._is_last_layer: - module._backward_block_ctxs[module._layer_id + 1].exit(True) + if not module._is_last_layer: + module._backward_block_ctxs[module._layer_id + 1].exit(True) else: if module._micro_idx == config['micros'] - 1: module._backward_block_ctxs[module._layer_id] = CheckpointBlockContext(module, module._layer_dict, backward_flag, pipe=True) module._backward_block_ctxs[module._layer_id].enter(True) - def zero_post_backward(module, grad_inputs, grad_outputs): if not config['pipe_enabled']: if module._layer_id == 0: @@ -44,6 +45,10 @@ def zero_post_backward(module, grad_inputs, grad_outputs): if module._micro_idx == 0: module._backward_block_ctxs[module._layer_id].exit(True) + if torch_version < '2.0.1': + if module._layer_id != 0: + zero_pre_backward(module._pre_module, grad_inputs) + def pipe_pre_forward(module, inputs): if not module._is_first_stage: if module._is_first_layer: @@ -120,3 +125,33 @@ def checkpoint_pre_backward(module, grad_outputs): zero_post_backward(module, None, grad_outputs) pipe_post_backward(module, module._inputs[module._micro_idx][0].grad, None) +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, module, *args): + inputs = args[0].detach() + ctx.module = module + with torch.no_grad(): + zero_pre_forward(module, args) + checkpoint_pre_forward(module, args) + outputs = module._module(inputs, *args[1:]) + outputs.requires_grad_() + zero_post_forward(module, args, outputs) + return outputs + + @staticmethod + def backward(ctx, grads): + with torch.enable_grad(): + zero_pre_backward(ctx.module, grads) + checkpoint_pre_backward(ctx.module, grads) + return None, ctx.module._inputs[0].grad, None, None + +def identity_post_backward(module, grad_inputs, grad_outputs): + zero_pre_backward(module._pre_module, 
grad_inputs) + +class IdentityLayer(torch.nn.Module): + def __init__(self): + super(IdentityLayer, self).__init__() + + def forward(self, x): + return x + From bc48d83c72fddd8799788185d9c4b2e7e93c13ec Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 31 Jul 2023 20:46:46 +0800 Subject: [PATCH 014/122] for checkpoint --- bmtrain/hook_func.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index d8a14e9a..6ecbc524 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -118,9 +118,7 @@ def checkpoint_pre_backward(module, grad_outputs): torch.autograd.backward(out, *grad_outputs) if not config['pipe_enabled']: - if module._layer_id == 0: - module._backward_block_ctxs[0].exit(True) - module._backward_block_ctxs[0] = None + zero_post_backward(module, None, grad_outputs) else: zero_post_backward(module, None, grad_outputs) pipe_post_backward(module, module._inputs[module._micro_idx][0].grad, None) @@ -147,6 +145,8 @@ def backward(ctx, grads): def identity_post_backward(module, grad_inputs, grad_outputs): zero_pre_backward(module._pre_module, grad_inputs) + if config['use_checkpoint']: + checkpoint_pre_backward(module.Pre_module, grad_inputs) class IdentityLayer(torch.nn.Module): def __init__(self): From bd61071e7646ca0e69799e7664127964db0799d8 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 31 Jul 2023 20:51:17 +0800 Subject: [PATCH 015/122] remove unused code --- bmtrain/block_layer.py | 2 +- bmtrain/hook_func.py | 29 ----------------------------- 2 files changed, 1 insertion(+), 30 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 50ceb1d8..3eb200d2 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -600,7 +600,7 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) module._pre_module = self._modules[str(i-1)] self.num_hidden = num_hidden - self.identity = hook_func.IdentityLayer() + self.identity = torch.nn.Identity() self.identity.register_full_backward_hook(hook_func.identity_post_backward) self.identity._pre_module = self._modules[str(len(modules)-1)] diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 6ecbc524..2b7383b1 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -4,7 +4,6 @@ from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations torch_version = torch.__version__ -#torch_version = '1.9.0' def zero_pre_forward(module, inputs): enter = True @@ -123,35 +122,7 @@ def checkpoint_pre_backward(module, grad_outputs): zero_post_backward(module, None, grad_outputs) pipe_post_backward(module, module._inputs[module._micro_idx][0].grad, None) -class CheckpointFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, module, *args): - inputs = args[0].detach() - ctx.module = module - with torch.no_grad(): - zero_pre_forward(module, args) - checkpoint_pre_forward(module, args) - outputs = module._module(inputs, *args[1:]) - outputs.requires_grad_() - zero_post_forward(module, args, outputs) - return outputs - - @staticmethod - def backward(ctx, grads): - with torch.enable_grad(): - zero_pre_backward(ctx.module, grads) - checkpoint_pre_backward(ctx.module, grads) - return None, ctx.module._inputs[0].grad, None, None - def identity_post_backward(module, grad_inputs, grad_outputs): zero_pre_backward(module._pre_module, grad_inputs) if config['use_checkpoint']: checkpoint_pre_backward(module.Pre_module, grad_inputs) - -class IdentityLayer(torch.nn.Module): - def 
__init__(self): - super(IdentityLayer, self).__init__() - - def forward(self, x): - return x - From de25455d4e7d3beae687a46e33906b8967797da0 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 31 Jul 2023 20:56:13 +0800 Subject: [PATCH 016/122] remove duplicate code --- bmtrain/checkpointing.py | 1 - bmtrain/pipe_layer.py | 92 ---------------------------------------- 2 files changed, 93 deletions(-) diff --git a/bmtrain/checkpointing.py b/bmtrain/checkpointing.py index b0c936f3..1775b95d 100644 --- a/bmtrain/checkpointing.py +++ b/bmtrain/checkpointing.py @@ -56,7 +56,6 @@ def enter(self, requires_grad=False): self._need_release = True wait_loader() -#requires_grad = torch.is_grad_enabled() with torch.cuda.stream(config["load_stream"]): for kw, val in self.block._storage_info.items(): assert self.block._storage_params[kw].is_cuda diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 155f0d5a..2d194843 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -315,98 +315,6 @@ def backward(ctx, grad_hidden_state : torch.Tensor, grad_middle : torch.Tensor, return (None, None, None, grad) + tuple(grads) + (None,) -def zero_pre_forward(module, inputs): - if module._micro_idx == 0: - forward_flag = 1 if config['zero_level'] == 2 else 0 - module._forward_block_ctx = CheckpointBlockContext(module, module._layer_dict, forward_flag, pipe=True) - module._forward_block_ctx.enter() - -def zero_post_forward(module, inputs, outputs): - if module._micro_idx == config['micros'] - 1: - module._forward_block_ctx.exit() - - -def pipe_pre_forward(module, inputs): - if not module._is_first_stage: - if module._is_first_layer: - pre_inputs = recv_activations(module.stage_id - 1, config['pipe_comm']) - pre_inputs.requires_grad_() - return (pre_inputs, ) + inputs[1:] - -def pipe_post_forward(module, inputs, outputs): - if not module._is_last_stage: - if module._is_last_layer: - send_data = outputs[0] if isinstance(outputs, tuple) else outputs - send_activations(send_data, module.stage_id + 1, config['pipe_comm']) - -def zero_pre_backward(module, grad_inputs): - if module._micro_idx == config['micros'] - 1: - backward_flag = 2 if config['zero_level'] == 2 else 0 - with torch.enable_grad(): - module._backward_block_ctxs[module._layer_id] = CheckpointBlockContext(module, module._layer_dict, backward_flag, pipe=True) - module._backward_block_ctxs[module._layer_id].enter(True) -# if module._micro_idx == 0: -# if not module._is_last_layer: -# module._backward_block_ctxs[module._layer_id + 1].exit(True) - -def zero_post_backward(module, grad_inputs, grad_outputs): - if module._micro_idx == 0: - module._backward_block_ctxs[module._layer_id].exit(True) -# if module._is_first_layer: -# module._backward_block_ctxs[module._layer_id].exit(True) - -def pipe_pre_backward(module, grad_inputs): - if module._is_last_layer: - module._grad_list = all_gather(grad_inputs[0], config["pipe_comm"]) - module._grad_list = module._grad_list.flatten(start_dim=0, end_dim=1).chunk(module.stages, dim=0) - - if module._is_last_layer and module._is_last_stage: - return (module._grad_list[module._micro_idx], ) - - if not module._is_last_stage: - if module._is_last_layer: - pre_grad_inputs = recv_activations(module.stage_id + 1, config['pipe_comm']) - return (pre_grad_inputs, ) + grad_inputs[1:] - - -def pipe_post_backward(module, grad_inputs, grad_outputs): - if not module._is_first_stage: - if module._is_first_layer: - send_data = grad_inputs[0] if isinstance(grad_inputs, tuple) else grad_inputs - if send_data is not None: - 
send_activations(send_data, module.stage_id - 1, config['pipe_comm']) - -# if module._is_first_layer: -# if module._micro_idx == config['micros'] -1: -# module._all_grads = [] -# grad = grad_inputs[0] -# module._all_grads.append(grad) -# if module._micro_idx == 0: -# grads = torch.cat(module._all_grads, dim=0) -# grad = broadcast(grads, 0, config['pipe_comm']) -# grad = grad.chunk(module.stages, dim=0) -# return (grad[module.stage_id], ) + grad_inputs[1:] - - module._micro_idx -= 1 - -def checkpoint_pre_forward(module, inputs): - if module._micro_idx == 0: - module._inputs = [inputs] - module._cuda_rng_state = [torch.cuda.get_rng_state()] - else: - module._inputs.append(inputs) - module._cuda_rng_state.append(torch.cuda.get_rng_state()) - -def checkpoint_pre_backward(module, grad_outputs): - with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True): - with torch.enable_grad(): - torch.cuda.set_rng_state(module._cuda_rng_state[module._micro_idx]) - out = module._module(*module._inputs[module._micro_idx]) - torch.autograd.backward(out, *grad_outputs) - - zero_post_backward(module, None, grad_outputs) - pipe_post_backward(module, module._inputs[module._micro_idx][0].grad, None) - class PipelineTransformerBlockList(torch.nn.Module): r""" TransformerBlockList is a list of CheckpointBlocks. From fde122f4e1efb862371e35e6b44a1123b7901afd Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Tue, 1 Aug 2023 16:21:36 +0800 Subject: [PATCH 017/122] fix pipeline; checkpoint support low version --- bmtrain/block_layer.py | 37 +++++++++------- bmtrain/hook_func.py | 95 +++++++++++++++++++++++++++--------------- bmtrain/pipe_layer.py | 37 ++++++++++------ 3 files changed, 108 insertions(+), 61 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 3eb200d2..d6ae4f7d 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -575,24 +575,24 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) if not isinstance(module, CheckpointBlock): module = CheckpointBlock(module) - if torch_version >= '2.0.1' or i < len(modules): - if config["zero_level"] > 0: - module.register_forward_pre_hook(hook_func.zero_pre_forward) - module.register_forward_hook(hook_func.zero_post_forward) - if torch_version >= '2.0.1': - module.register_full_backward_pre_hook(hook_func.zero_pre_backward) + if config["zero_level"] > 0: + module.register_forward_pre_hook(hook_func.zero_pre_forward) + module.register_forward_hook(hook_func.zero_post_forward) + if torch_version >= '2.0.1': + module.register_full_backward_pre_hook(hook_func.zero_pre_backward) - if config["use_checkpoint"]: - if torch_version >= '2.0.1': - module.register_forward_pre_hook(hook_func.checkpoint_pre_forward) - module.register_full_backward_pre_hook(hook_func.checkpoint_pre_backward) + if config["use_checkpoint"]: + module.register_forward_pre_hook(hook_func.checkpoint_pre_forward) + if torch_version >= '2.0.1': + module.register_full_backward_pre_hook(hook_func.checkpoint_pre_backward) - if config["zero_level"] > 0 and not config["use_checkpoint"]: - module.register_full_backward_hook(hook_func.zero_post_backward) + if config["zero_level"] > 0 and not config["use_checkpoint"]: + module.register_full_backward_hook(hook_func.zero_post_backward) module._backward_block_ctxs = self._backward_block_ctxs module._layer_id = i module._is_last_layer = True if i == len(modules) -1 else False + module._is_first_layer = True if i == 0 else False self._modules[str(i)] = module self.add_module(str(i), 
module) @@ -600,9 +600,10 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) module._pre_module = self._modules[str(i-1)] self.num_hidden = num_hidden - self.identity = torch.nn.Identity() - self.identity.register_full_backward_hook(hook_func.identity_post_backward) - self.identity._pre_module = self._modules[str(len(modules)-1)] + if torch_version < '2.0.1': + self.identity = torch.nn.Identity() + self.identity.register_full_backward_hook(hook_func.identity_post_backward) + self.identity._pre_module = self._modules[str(len(modules)-1)] if sqrt: length = len(self) @@ -627,7 +628,11 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) self.save_list = [(i, i) for i in range(len(self))] def __len__(self) -> int: - return len(self._modules) - 1 + if torch_version < '2.0.1': + return len(self._modules) - 1 + else: + return len(self._modules) + def __iter__(self) -> Iterator[CheckpointBlock]: return iter(self._modules.values()) def __getitem__(self, index: Union[int, str]) -> CheckpointBlock: diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 2b7383b1..0460c0a1 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -3,7 +3,8 @@ from .checkpointing import CheckpointBlockContext from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations -torch_version = torch.__version__ +#torch_version = torch.__version__ +torch_version = '1.9.0' def zero_pre_forward(module, inputs): enter = True @@ -31,6 +32,7 @@ def zero_pre_backward(module, grad_outputs): module._backward_block_ctxs[module._layer_id].enter(True) if not module._is_last_layer: module._backward_block_ctxs[module._layer_id + 1].exit(True) + config['load_stream'].record_event(config['load_event']) else: if module._micro_idx == config['micros'] - 1: module._backward_block_ctxs[module._layer_id] = CheckpointBlockContext(module, module._layer_dict, backward_flag, pipe=True) @@ -40,61 +42,56 @@ def zero_post_backward(module, grad_inputs, grad_outputs): if not config['pipe_enabled']: if module._layer_id == 0: module._backward_block_ctxs[0].exit(True) + config['load_stream'].record_event(config['load_event']) else: if module._micro_idx == 0: module._backward_block_ctxs[module._layer_id].exit(True) + config['load_stream'].record_event(config['load_event']) - if torch_version < '2.0.1': - if module._layer_id != 0: - zero_pre_backward(module._pre_module, grad_inputs) + if torch_version < '2.0.1' and not config['pipe_enabled']: + if not module._is_first_layer: + identity_post_backward(module, grad_inputs, grad_outputs) + +class PipePreFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, inputs, stage_id): + pre_inputs = recv_activations(stage_id - 1, config['pipe_comm']) + pre_inputs.requires_grad_() + return pre_inputs + + @staticmethod + def backward(ctx, grads): + return grads, None def pipe_pre_forward(module, inputs): if not module._is_first_stage: if module._is_first_layer: - pre_inputs = recv_activations(module.stage_id - 1, config['pipe_comm']) - pre_inputs.requires_grad_() - return (pre_inputs, ) + inputs[1:] + return (PipePreFunction.apply(inputs[0], module.stage_id), ) + inputs[1:] def pipe_post_forward(module, inputs, outputs): if not module._is_last_stage: if module._is_last_layer: send_data = outputs[0] if isinstance(outputs, tuple) else outputs - send_activations(send_data, module.stage_id + 1, config['pipe_comm']) + send_activations(send_data.detach(), module.stage_id + 1, config['pipe_comm']) def 
pipe_pre_backward(module, grad_inputs): - if module._is_last_layer: - module._grad_list = all_gather(grad_inputs[0], config["pipe_comm"]) - module._grad_list = module._grad_list.flatten(start_dim=0, end_dim=1).chunk(module.stages, dim=0) - - if module._is_last_layer and module._is_last_stage: - return (module._grad_list[module._micro_idx], ) - if not module._is_last_stage: if module._is_last_layer: pre_grad_inputs = recv_activations(module.stage_id + 1, config['pipe_comm']) return (pre_grad_inputs, ) + grad_inputs[1:] - def pipe_post_backward(module, grad_inputs, grad_outputs): if not module._is_first_stage: if module._is_first_layer: send_data = grad_inputs[0] if isinstance(grad_inputs, tuple) else grad_inputs - if send_data is not None: - send_activations(send_data, module.stage_id - 1, config['pipe_comm']) - -# if module._is_first_layer: -# if module._micro_idx == config['micros'] -1: -# module._all_grads = [] -# grad = grad_inputs[0] -# module._all_grads.append(grad) -# if module._micro_idx == 0: -# grads = torch.cat(module._all_grads, dim=0) -# grad = broadcast(grads, 0, config['pipe_comm']) -# grad = grad.chunk(module.stages, dim=0) -# return (grad[module.stage_id], ) + grad_inputs[1:] + send_activations(send_data, module.stage_id - 1, config['pipe_comm']) module._micro_idx -= 1 + if torch_version < '2.0.1': + if not module._is_first_layer: + identity_post_backward(module, grad_inputs, grad_outputs) + def checkpoint_pre_forward(module, inputs): if not config['pipe_enabled']: module._inputs = inputs @@ -117,12 +114,44 @@ def checkpoint_pre_backward(module, grad_outputs): torch.autograd.backward(out, *grad_outputs) if not config['pipe_enabled']: - zero_post_backward(module, None, grad_outputs) + zero_post_backward(module, None, (inputs[0].grad,) ) else: - zero_post_backward(module, None, grad_outputs) + zero_post_backward(module, inputs[0].grad, None) pipe_post_backward(module, module._inputs[module._micro_idx][0].grad, None) def identity_post_backward(module, grad_inputs, grad_outputs): - zero_pre_backward(module._pre_module, grad_inputs) + if config['pipe_enabled']: + pipe_grad = pipe_pre_backward(module._pre_module, grad_inputs) + grad_inputs = pipe_grad if pipe_grad is not None else grad_inputs + zero_pre_backward(module._pre_module, grad_outputs) if config['use_checkpoint']: - checkpoint_pre_backward(module.Pre_module, grad_inputs) + checkpoint_pre_backward(module._pre_module, grad_outputs) + +class PipeAllGatherFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, hidden_state): + hidden_state_list = all_gather(hidden_state.clone(), config["pipe_comm"]) + hidden_state_list.requires_grad_() + return hidden_state_list + + @staticmethod + def backward(ctx, grads): + grads = broadcast(grads, 0, config['pipe_comm']) + topo = config['topology'] + return grads.chunk(topo.stages, dim=0)[topo.stage_id] + +class PipePostFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, last_hidden): + last_hidden = broadcast(last_hidden, config["pipe_size"] - 1, config["pipe_comm"]) + last_hidden = last_hidden.chunk(config['topology'].stages, dim=0) + outputs = last_hidden[config['topology'].stage_id] + outputs.requires_grad_() + return outputs + + @staticmethod + def backward(ctx, grads): + grad_list = all_gather(grads, config["pipe_comm"]) + grad_list = grad_list.flatten(start_dim=0, end_dim=1) + return grad_list + diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 2d194843..8e8605ff 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -16,6 
+16,8 @@ from .block_layer import CheckpointBlock, round_up, _get_param_kw from . import hook_func +torch_version = hook_func.torch_version + class OpMicroForward(torch.autograd.Function): @staticmethod def forward(ctx, placeholder, self : 'PipelineTransformerBlockList', micro_idx, block_ctx_list, layers_dict, save_list, hidden_state, *args): @@ -359,10 +361,12 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: module.register_forward_hook(hook_func.zero_post_forward) module.register_forward_hook(hook_func.pipe_post_forward) - module.register_full_backward_pre_hook(hook_func.pipe_pre_backward) - module.register_full_backward_pre_hook(hook_func.zero_pre_backward) + if torch_version >= '2.0.1': + module.register_full_backward_pre_hook(hook_func.pipe_pre_backward) + module.register_full_backward_pre_hook(hook_func.zero_pre_backward) if config['use_checkpoint']: - module.register_full_backward_pre_hook(hook_func.checkpoint_pre_backward) + if torch_version >= '2.0.1': + module.register_full_backward_pre_hook(hook_func.checkpoint_pre_backward) module.register_full_backward_hook(hook_func.zero_post_backward) module.register_full_backward_hook(hook_func.pipe_post_backward) @@ -371,6 +375,8 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: module.stages = self.stages self._modules[str(idx)] = module + if idx > 0: + module._pre_module = self._modules[str(idx-1)] self.layer_ids = self.get_range_by_stage_id(self.stage_id) @@ -387,6 +393,11 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: self.next_rank = pipe_group[self.pipe_idx, self.stage_id + 1] if self.stage_id < config['pipe_size'] - 1 else -1 self.prev_rank = pipe_group[self.pipe_idx, self.stage_id - 1] if self.stage_id > 0 else -1 # self.micro_batches = config['num_micro_batches'] + + if torch_version < '2.0.1': + self.identity = torch.nn.Identity() + self.identity.register_full_backward_hook(hook_func.identity_post_backward) + self.identity._pre_module = self._modules[str(self.layer_ids[-1])] self.save_list = [(i, i) for i in range(len(self.layer_ids))] @@ -414,7 +425,7 @@ def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=Fa else: batch_size = hidden_state.shape[0] num_micros = config["micros"] - hidden_state_list = all_gather(hidden_state, config["pipe_comm"]).flatten(0, 1).detach().requires_grad_() + hidden_state_list = hook_func.PipeAllGatherFunction.apply(hidden_state) args_list = [[] for _ in range(num_micros)] for arg in args: @@ -430,23 +441,25 @@ def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=Fa for i in range(num_micros): args_list[i].append(arg_all[i]) + hidden_state_list = hidden_state_list.flatten(0, 1).chunk(num_micros, dim=0) outputs = [] - hidden_state_list = hidden_state_list.chunk(num_micros, dim=0) + for micro_idx, (hidden_state, arg) in enumerate(zip(hidden_state_list, args_list)): for idx,layer_id in enumerate(self.layer_ids): self._modules[str(layer_id)]._micro_idx = micro_idx hidden_state = self._modules[str(layer_id)](hidden_state, *arg) outputs.append(hidden_state) + if torch_version < '2.0.1': + outputs[-1] = self.identity(outputs[-1]) + last_hidden = torch.cat(outputs, dim=0) - last_hidden = broadcast(last_hidden, config["pipe_size"] - 1, config["pipe_comm"]) - last_hidden = last_hidden.chunk(self.stages, dim=0) - outputs = last_hidden[self.stage_id] + outputs = hook_func.PipePostFunction.apply(last_hidden) - if return_hidden_states: - return 
outputs[:2*self.num_hidden] - else: - return outputs[:self.num_hidden] if self.num_hidden > 1 else outputs + if return_hidden_states: + return outputs[:2*self.num_hidden] + else: + return outputs[:self.num_hidden] if self.num_hidden > 1 else outputs def get_range_by_stage_id(self, stage_id : int) -> List[int]: part_lens = [0]+[self.get_part_len_by_stage_id(i) for i in range(stage_id+1)] From a897ad463ddb46eacb1aa5690752f3a58b329c7a Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Tue, 1 Aug 2023 16:40:37 +0800 Subject: [PATCH 018/122] fix pipeline; checkpoint support low version --- bmtrain/block_layer.py | 37 +++++++----- bmtrain/hook_func.py | 111 ++++++++++++++++++----------------- bmtrain/pipe_layer.py | 129 ++++++++--------------------------------- 3 files changed, 101 insertions(+), 176 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 50ceb1d8..d6ae4f7d 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -575,24 +575,24 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) if not isinstance(module, CheckpointBlock): module = CheckpointBlock(module) - if torch_version >= '2.0.1' or i < len(modules): - if config["zero_level"] > 0: - module.register_forward_pre_hook(hook_func.zero_pre_forward) - module.register_forward_hook(hook_func.zero_post_forward) - if torch_version >= '2.0.1': - module.register_full_backward_pre_hook(hook_func.zero_pre_backward) + if config["zero_level"] > 0: + module.register_forward_pre_hook(hook_func.zero_pre_forward) + module.register_forward_hook(hook_func.zero_post_forward) + if torch_version >= '2.0.1': + module.register_full_backward_pre_hook(hook_func.zero_pre_backward) - if config["use_checkpoint"]: - if torch_version >= '2.0.1': - module.register_forward_pre_hook(hook_func.checkpoint_pre_forward) - module.register_full_backward_pre_hook(hook_func.checkpoint_pre_backward) + if config["use_checkpoint"]: + module.register_forward_pre_hook(hook_func.checkpoint_pre_forward) + if torch_version >= '2.0.1': + module.register_full_backward_pre_hook(hook_func.checkpoint_pre_backward) - if config["zero_level"] > 0 and not config["use_checkpoint"]: - module.register_full_backward_hook(hook_func.zero_post_backward) + if config["zero_level"] > 0 and not config["use_checkpoint"]: + module.register_full_backward_hook(hook_func.zero_post_backward) module._backward_block_ctxs = self._backward_block_ctxs module._layer_id = i module._is_last_layer = True if i == len(modules) -1 else False + module._is_first_layer = True if i == 0 else False self._modules[str(i)] = module self.add_module(str(i), module) @@ -600,9 +600,10 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) module._pre_module = self._modules[str(i-1)] self.num_hidden = num_hidden - self.identity = hook_func.IdentityLayer() - self.identity.register_full_backward_hook(hook_func.identity_post_backward) - self.identity._pre_module = self._modules[str(len(modules)-1)] + if torch_version < '2.0.1': + self.identity = torch.nn.Identity() + self.identity.register_full_backward_hook(hook_func.identity_post_backward) + self.identity._pre_module = self._modules[str(len(modules)-1)] if sqrt: length = len(self) @@ -627,7 +628,11 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) self.save_list = [(i, i) for i in range(len(self))] def __len__(self) -> int: - return len(self._modules) - 1 + if torch_version < '2.0.1': + return len(self._modules) - 1 + else: + return len(self._modules) 
+ def __iter__(self) -> Iterator[CheckpointBlock]: return iter(self._modules.values()) def __getitem__(self, index: Union[int, str]) -> CheckpointBlock: diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index d8a14e9a..b34d93f4 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -4,7 +4,6 @@ from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations torch_version = torch.__version__ -#torch_version = '1.9.0' def zero_pre_forward(module, inputs): enter = True @@ -32,6 +31,7 @@ def zero_pre_backward(module, grad_outputs): module._backward_block_ctxs[module._layer_id].enter(True) if not module._is_last_layer: module._backward_block_ctxs[module._layer_id + 1].exit(True) + config['load_stream'].record_event(config['load_event']) else: if module._micro_idx == config['micros'] - 1: module._backward_block_ctxs[module._layer_id] = CheckpointBlockContext(module, module._layer_dict, backward_flag, pipe=True) @@ -41,61 +41,56 @@ def zero_post_backward(module, grad_inputs, grad_outputs): if not config['pipe_enabled']: if module._layer_id == 0: module._backward_block_ctxs[0].exit(True) + config['load_stream'].record_event(config['load_event']) else: if module._micro_idx == 0: module._backward_block_ctxs[module._layer_id].exit(True) + config['load_stream'].record_event(config['load_event']) - if torch_version < '2.0.1': - if module._layer_id != 0: - zero_pre_backward(module._pre_module, grad_inputs) + if torch_version < '2.0.1' and not config['pipe_enabled']: + if not module._is_first_layer: + identity_post_backward(module, grad_inputs, grad_outputs) + +class PipePreFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, inputs, stage_id): + pre_inputs = recv_activations(stage_id - 1, config['pipe_comm']) + pre_inputs.requires_grad_() + return pre_inputs + + @staticmethod + def backward(ctx, grads): + return grads, None def pipe_pre_forward(module, inputs): if not module._is_first_stage: if module._is_first_layer: - pre_inputs = recv_activations(module.stage_id - 1, config['pipe_comm']) - pre_inputs.requires_grad_() - return (pre_inputs, ) + inputs[1:] + return (PipePreFunction.apply(inputs[0], module.stage_id), ) + inputs[1:] def pipe_post_forward(module, inputs, outputs): if not module._is_last_stage: if module._is_last_layer: send_data = outputs[0] if isinstance(outputs, tuple) else outputs - send_activations(send_data, module.stage_id + 1, config['pipe_comm']) + send_activations(send_data.detach(), module.stage_id + 1, config['pipe_comm']) def pipe_pre_backward(module, grad_inputs): - if module._is_last_layer: - module._grad_list = all_gather(grad_inputs[0], config["pipe_comm"]) - module._grad_list = module._grad_list.flatten(start_dim=0, end_dim=1).chunk(module.stages, dim=0) - - if module._is_last_layer and module._is_last_stage: - return (module._grad_list[module._micro_idx], ) - if not module._is_last_stage: if module._is_last_layer: pre_grad_inputs = recv_activations(module.stage_id + 1, config['pipe_comm']) return (pre_grad_inputs, ) + grad_inputs[1:] - def pipe_post_backward(module, grad_inputs, grad_outputs): if not module._is_first_stage: if module._is_first_layer: send_data = grad_inputs[0] if isinstance(grad_inputs, tuple) else grad_inputs - if send_data is not None: - send_activations(send_data, module.stage_id - 1, config['pipe_comm']) - -# if module._is_first_layer: -# if module._micro_idx == config['micros'] -1: -# module._all_grads = [] -# grad = grad_inputs[0] -# module._all_grads.append(grad) -# if 
module._micro_idx == 0: -# grads = torch.cat(module._all_grads, dim=0) -# grad = broadcast(grads, 0, config['pipe_comm']) -# grad = grad.chunk(module.stages, dim=0) -# return (grad[module.stage_id], ) + grad_inputs[1:] + send_activations(send_data, module.stage_id - 1, config['pipe_comm']) module._micro_idx -= 1 + if torch_version < '2.0.1': + if not module._is_first_layer: + identity_post_backward(module, grad_inputs, grad_outputs) + def checkpoint_pre_forward(module, inputs): if not config['pipe_enabled']: module._inputs = inputs @@ -118,40 +113,44 @@ def checkpoint_pre_backward(module, grad_outputs): torch.autograd.backward(out, *grad_outputs) if not config['pipe_enabled']: - if module._layer_id == 0: - module._backward_block_ctxs[0].exit(True) - module._backward_block_ctxs[0] = None + zero_post_backward(module, None, (inputs[0].grad,) ) else: - zero_post_backward(module, None, grad_outputs) + zero_post_backward(module, inputs[0].grad, None) pipe_post_backward(module, module._inputs[module._micro_idx][0].grad, None) -class CheckpointFunction(torch.autograd.Function): +def identity_post_backward(module, grad_inputs, grad_outputs): + if config['pipe_enabled']: + pipe_grad = pipe_pre_backward(module._pre_module, grad_inputs) + grad_inputs = pipe_grad if pipe_grad is not None else grad_inputs + zero_pre_backward(module._pre_module, grad_outputs) + if config['use_checkpoint']: + checkpoint_pre_backward(module._pre_module, grad_outputs) + +class PipeAllGatherFunction(torch.autograd.Function): @staticmethod - def forward(ctx, module, *args): - inputs = args[0].detach() - ctx.module = module - with torch.no_grad(): - zero_pre_forward(module, args) - checkpoint_pre_forward(module, args) - outputs = module._module(inputs, *args[1:]) - outputs.requires_grad_() - zero_post_forward(module, args, outputs) - return outputs + def forward(ctx, hidden_state): + hidden_state_list = all_gather(hidden_state.clone(), config["pipe_comm"]) + hidden_state_list.requires_grad_() + return hidden_state_list @staticmethod def backward(ctx, grads): - with torch.enable_grad(): - zero_pre_backward(ctx.module, grads) - checkpoint_pre_backward(ctx.module, grads) - return None, ctx.module._inputs[0].grad, None, None + grads = broadcast(grads, 0, config['pipe_comm']) + topo = config['topology'] + return grads.chunk(topo.stages, dim=0)[topo.stage_id] -def identity_post_backward(module, grad_inputs, grad_outputs): - zero_pre_backward(module._pre_module, grad_inputs) - -class IdentityLayer(torch.nn.Module): - def __init__(self): - super(IdentityLayer, self).__init__() +class PipePostFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, last_hidden): + last_hidden = broadcast(last_hidden, config["pipe_size"] - 1, config["pipe_comm"]) + last_hidden = last_hidden.chunk(config['topology'].stages, dim=0) + outputs = last_hidden[config['topology'].stage_id] + outputs.requires_grad_() + return outputs - def forward(self, x): - return x + @staticmethod + def backward(ctx, grads): + grad_list = all_gather(grads, config["pipe_comm"]) + grad_list = grad_list.flatten(start_dim=0, end_dim=1) + return grad_list diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 155f0d5a..8e8605ff 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -16,6 +16,8 @@ from .block_layer import CheckpointBlock, round_up, _get_param_kw from . 
import hook_func +torch_version = hook_func.torch_version + class OpMicroForward(torch.autograd.Function): @staticmethod def forward(ctx, placeholder, self : 'PipelineTransformerBlockList', micro_idx, block_ctx_list, layers_dict, save_list, hidden_state, *args): @@ -315,98 +317,6 @@ def backward(ctx, grad_hidden_state : torch.Tensor, grad_middle : torch.Tensor, return (None, None, None, grad) + tuple(grads) + (None,) -def zero_pre_forward(module, inputs): - if module._micro_idx == 0: - forward_flag = 1 if config['zero_level'] == 2 else 0 - module._forward_block_ctx = CheckpointBlockContext(module, module._layer_dict, forward_flag, pipe=True) - module._forward_block_ctx.enter() - -def zero_post_forward(module, inputs, outputs): - if module._micro_idx == config['micros'] - 1: - module._forward_block_ctx.exit() - - -def pipe_pre_forward(module, inputs): - if not module._is_first_stage: - if module._is_first_layer: - pre_inputs = recv_activations(module.stage_id - 1, config['pipe_comm']) - pre_inputs.requires_grad_() - return (pre_inputs, ) + inputs[1:] - -def pipe_post_forward(module, inputs, outputs): - if not module._is_last_stage: - if module._is_last_layer: - send_data = outputs[0] if isinstance(outputs, tuple) else outputs - send_activations(send_data, module.stage_id + 1, config['pipe_comm']) - -def zero_pre_backward(module, grad_inputs): - if module._micro_idx == config['micros'] - 1: - backward_flag = 2 if config['zero_level'] == 2 else 0 - with torch.enable_grad(): - module._backward_block_ctxs[module._layer_id] = CheckpointBlockContext(module, module._layer_dict, backward_flag, pipe=True) - module._backward_block_ctxs[module._layer_id].enter(True) -# if module._micro_idx == 0: -# if not module._is_last_layer: -# module._backward_block_ctxs[module._layer_id + 1].exit(True) - -def zero_post_backward(module, grad_inputs, grad_outputs): - if module._micro_idx == 0: - module._backward_block_ctxs[module._layer_id].exit(True) -# if module._is_first_layer: -# module._backward_block_ctxs[module._layer_id].exit(True) - -def pipe_pre_backward(module, grad_inputs): - if module._is_last_layer: - module._grad_list = all_gather(grad_inputs[0], config["pipe_comm"]) - module._grad_list = module._grad_list.flatten(start_dim=0, end_dim=1).chunk(module.stages, dim=0) - - if module._is_last_layer and module._is_last_stage: - return (module._grad_list[module._micro_idx], ) - - if not module._is_last_stage: - if module._is_last_layer: - pre_grad_inputs = recv_activations(module.stage_id + 1, config['pipe_comm']) - return (pre_grad_inputs, ) + grad_inputs[1:] - - -def pipe_post_backward(module, grad_inputs, grad_outputs): - if not module._is_first_stage: - if module._is_first_layer: - send_data = grad_inputs[0] if isinstance(grad_inputs, tuple) else grad_inputs - if send_data is not None: - send_activations(send_data, module.stage_id - 1, config['pipe_comm']) - -# if module._is_first_layer: -# if module._micro_idx == config['micros'] -1: -# module._all_grads = [] -# grad = grad_inputs[0] -# module._all_grads.append(grad) -# if module._micro_idx == 0: -# grads = torch.cat(module._all_grads, dim=0) -# grad = broadcast(grads, 0, config['pipe_comm']) -# grad = grad.chunk(module.stages, dim=0) -# return (grad[module.stage_id], ) + grad_inputs[1:] - - module._micro_idx -= 1 - -def checkpoint_pre_forward(module, inputs): - if module._micro_idx == 0: - module._inputs = [inputs] - module._cuda_rng_state = [torch.cuda.get_rng_state()] - else: - module._inputs.append(inputs) - 
module._cuda_rng_state.append(torch.cuda.get_rng_state()) - -def checkpoint_pre_backward(module, grad_outputs): - with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True): - with torch.enable_grad(): - torch.cuda.set_rng_state(module._cuda_rng_state[module._micro_idx]) - out = module._module(*module._inputs[module._micro_idx]) - torch.autograd.backward(out, *grad_outputs) - - zero_post_backward(module, None, grad_outputs) - pipe_post_backward(module, module._inputs[module._micro_idx][0].grad, None) - class PipelineTransformerBlockList(torch.nn.Module): r""" TransformerBlockList is a list of CheckpointBlocks. @@ -451,10 +361,12 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: module.register_forward_hook(hook_func.zero_post_forward) module.register_forward_hook(hook_func.pipe_post_forward) - module.register_full_backward_pre_hook(hook_func.pipe_pre_backward) - module.register_full_backward_pre_hook(hook_func.zero_pre_backward) + if torch_version >= '2.0.1': + module.register_full_backward_pre_hook(hook_func.pipe_pre_backward) + module.register_full_backward_pre_hook(hook_func.zero_pre_backward) if config['use_checkpoint']: - module.register_full_backward_pre_hook(hook_func.checkpoint_pre_backward) + if torch_version >= '2.0.1': + module.register_full_backward_pre_hook(hook_func.checkpoint_pre_backward) module.register_full_backward_hook(hook_func.zero_post_backward) module.register_full_backward_hook(hook_func.pipe_post_backward) @@ -463,6 +375,8 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: module.stages = self.stages self._modules[str(idx)] = module + if idx > 0: + module._pre_module = self._modules[str(idx-1)] self.layer_ids = self.get_range_by_stage_id(self.stage_id) @@ -479,6 +393,11 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: self.next_rank = pipe_group[self.pipe_idx, self.stage_id + 1] if self.stage_id < config['pipe_size'] - 1 else -1 self.prev_rank = pipe_group[self.pipe_idx, self.stage_id - 1] if self.stage_id > 0 else -1 # self.micro_batches = config['num_micro_batches'] + + if torch_version < '2.0.1': + self.identity = torch.nn.Identity() + self.identity.register_full_backward_hook(hook_func.identity_post_backward) + self.identity._pre_module = self._modules[str(self.layer_ids[-1])] self.save_list = [(i, i) for i in range(len(self.layer_ids))] @@ -506,7 +425,7 @@ def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=Fa else: batch_size = hidden_state.shape[0] num_micros = config["micros"] - hidden_state_list = all_gather(hidden_state, config["pipe_comm"]).flatten(0, 1).detach().requires_grad_() + hidden_state_list = hook_func.PipeAllGatherFunction.apply(hidden_state) args_list = [[] for _ in range(num_micros)] for arg in args: @@ -522,23 +441,25 @@ def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=Fa for i in range(num_micros): args_list[i].append(arg_all[i]) + hidden_state_list = hidden_state_list.flatten(0, 1).chunk(num_micros, dim=0) outputs = [] - hidden_state_list = hidden_state_list.chunk(num_micros, dim=0) + for micro_idx, (hidden_state, arg) in enumerate(zip(hidden_state_list, args_list)): for idx,layer_id in enumerate(self.layer_ids): self._modules[str(layer_id)]._micro_idx = micro_idx hidden_state = self._modules[str(layer_id)](hidden_state, *arg) outputs.append(hidden_state) + if torch_version < '2.0.1': + outputs[-1] = self.identity(outputs[-1]) + last_hidden = torch.cat(outputs, dim=0) - 
last_hidden = broadcast(last_hidden, config["pipe_size"] - 1, config["pipe_comm"]) - last_hidden = last_hidden.chunk(self.stages, dim=0) - outputs = last_hidden[self.stage_id] + outputs = hook_func.PipePostFunction.apply(last_hidden) - if return_hidden_states: - return outputs[:2*self.num_hidden] - else: - return outputs[:self.num_hidden] if self.num_hidden > 1 else outputs + if return_hidden_states: + return outputs[:2*self.num_hidden] + else: + return outputs[:self.num_hidden] if self.num_hidden > 1 else outputs def get_range_by_stage_id(self, stage_id : int) -> List[int]: part_lens = [0]+[self.get_part_len_by_stage_id(i) for i in range(stage_id+1)] From ec8385b2c8a0a78a4ce3593710d934e74071acb1 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Tue, 1 Aug 2023 20:28:44 +0800 Subject: [PATCH 019/122] fix indent --- bmtrain/block_layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index e7601f2f..d6ae4f7d 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -586,8 +586,8 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) if torch_version >= '2.0.1': module.register_full_backward_pre_hook(hook_func.checkpoint_pre_backward) - if config["zero_level"] > 0 and not config["use_checkpoint"]: - module.register_full_backward_hook(hook_func.zero_post_backward) + if config["zero_level"] > 0 and not config["use_checkpoint"]: + module.register_full_backward_hook(hook_func.zero_post_backward) module._backward_block_ctxs = self._backward_block_ctxs module._layer_id = i From 9877a8159c0eab56dce606c0cab2486ccf226a57 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 2 Aug 2023 15:39:29 +0800 Subject: [PATCH 020/122] pipe support low version --- bmtrain/hook_func.py | 19 ++++++++++--------- bmtrain/pipe_layer.py | 14 +++++++------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 0460c0a1..6d79b8a3 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -43,14 +43,14 @@ def zero_post_backward(module, grad_inputs, grad_outputs): if module._layer_id == 0: module._backward_block_ctxs[0].exit(True) config['load_stream'].record_event(config['load_event']) + if torch_version < '2.0.1': + if not module._is_first_layer: + identity_post_backward(module, grad_inputs, grad_outputs) else: if module._micro_idx == 0: module._backward_block_ctxs[module._layer_id].exit(True) config['load_stream'].record_event(config['load_event']) - if torch_version < '2.0.1' and not config['pipe_enabled']: - if not module._is_first_layer: - identity_post_backward(module, grad_inputs, grad_outputs) class PipePreFunction(torch.autograd.Function): @staticmethod @@ -90,7 +90,7 @@ def pipe_post_backward(module, grad_inputs, grad_outputs): if torch_version < '2.0.1': if not module._is_first_layer: - identity_post_backward(module, grad_inputs, grad_outputs) + identity_post_backward(module, grad_outputs, grad_inputs) def checkpoint_pre_forward(module, inputs): if not config['pipe_enabled']: @@ -110,19 +110,20 @@ def checkpoint_pre_backward(module, grad_outputs): with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True): with torch.enable_grad(): torch.cuda.set_rng_state(cuda_rng_state) + #inputs[0].retain_grad() out = module._module(*inputs) torch.autograd.backward(out, *grad_outputs) if not config['pipe_enabled']: - zero_post_backward(module, None, (inputs[0].grad,) ) + zero_post_backward(module, grad_outputs, (inputs[0].grad,) ) 
else: - zero_post_backward(module, inputs[0].grad, None) - pipe_post_backward(module, module._inputs[module._micro_idx][0].grad, None) + zero_post_backward(module, grad_outputs, (inputs[0].grad,) ) + pipe_post_backward(module, (inputs[0].grad,), grad_outputs) def identity_post_backward(module, grad_inputs, grad_outputs): if config['pipe_enabled']: - pipe_grad = pipe_pre_backward(module._pre_module, grad_inputs) - grad_inputs = pipe_grad if pipe_grad is not None else grad_inputs + pipe_grad = pipe_pre_backward(module._pre_module, grad_outputs) + grad_outputs = pipe_grad if pipe_grad is not None else grad_outputs zero_pre_backward(module._pre_module, grad_outputs) if config['use_checkpoint']: checkpoint_pre_backward(module._pre_module, grad_outputs) diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 8e8605ff..1436082b 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -367,9 +367,9 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: if config['use_checkpoint']: if torch_version >= '2.0.1': module.register_full_backward_pre_hook(hook_func.checkpoint_pre_backward) - - module.register_full_backward_hook(hook_func.zero_post_backward) - module.register_full_backward_hook(hook_func.pipe_post_backward) + else: + module.register_full_backward_hook(hook_func.zero_post_backward) + module.register_full_backward_hook(hook_func.pipe_post_backward) module.stage_id = self.stage_id module.stages = self.stages @@ -402,7 +402,7 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: self.save_list = [(i, i) for i in range(len(self.layer_ids))] def __len__(self) -> int: - return len(self._modules) + return len(self._modules) def __iter__(self) -> Iterator[CheckpointBlock]: return iter(self._modules.values()) @@ -448,11 +448,11 @@ def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=Fa for idx,layer_id in enumerate(self.layer_ids): self._modules[str(layer_id)]._micro_idx = micro_idx hidden_state = self._modules[str(layer_id)](hidden_state, *arg) + if torch_version < '2.0.1': + if idx == len(self.layer_ids) - 1: + hidden_state = self.identity(hidden_state) outputs.append(hidden_state) - if torch_version < '2.0.1': - outputs[-1] = self.identity(outputs[-1]) - last_hidden = torch.cat(outputs, dim=0) outputs = hook_func.PipePostFunction.apply(last_hidden) From 28993b5bcde452488b7c3b162f5b5c91996607d4 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 2 Aug 2023 20:52:18 +0800 Subject: [PATCH 021/122] custom linear for zero3 --- example/layers/linear.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/example/layers/linear.py b/example/layers/linear.py index 0aa0ab00..faf0770e 100644 --- a/example/layers/linear.py +++ b/example/layers/linear.py @@ -2,6 +2,26 @@ import torch.nn.functional as F import bmtrain as bmt +class CustomLinear(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias=None): + ctx.save_for_backward(x, weight, bias) + return F.linear(x, weight, bias) + + @staticmethod + def backward(ctx, grad_output): + x, weight, bias = ctx.saved_tensors + grad_x = grad_weight = grad_bias = None + if x.requires_grad: + grad_x = grad_output.matmul(weight) + if weight.requires_grad: + dim = grad_output.dim() + grad_weight = grad_output.reshape(-1, + grad_output.shape[-1]).t().matmul(x.reshape(-1, x.shape[-1])) + if bias is not None and bias.requires_grad: + grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0) + return grad_x, 
grad_weight, grad_bias + class Linear(bmt.DistributedModule): def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None: super().__init__() @@ -15,9 +35,9 @@ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtyp self.register_parameter('bias', None) def forward(self, input): - return F.linear(input, self.weight, self.bias) + return CustomLinear.apply(input, self.weight, self.bias) def extra_repr(self) -> str: return 'in_features={}, out_features={}, bias={}'.format( self.in_features, self.out_features, self.bias is not None - ) \ No newline at end of file + ) From e4eaebfa327378c667e3a761eb63b4d4a4c19158 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 3 Aug 2023 13:33:36 +0800 Subject: [PATCH 022/122] resolve conflict --- bmtrain/hook_func.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 36286900..92f55eff 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -65,7 +65,6 @@ def forward(ctx, inputs, stage_id): @staticmethod def backward(ctx, grads): return grads, None ->>>>>>> 28993b5bcde452488b7c3b162f5b5c91996607d4 class PipePreFunction(torch.autograd.Function): From cba7c55a0c4f0d35b889472bce1c3c5b075b969f Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 3 Aug 2023 13:38:45 +0800 Subject: [PATCH 023/122] resolve conflict --- bmtrain/hook_func.py | 15 --------------- bmtrain/pipe_layer.py | 3 --- 2 files changed, 18 deletions(-) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 92f55eff..6d79b8a3 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -51,21 +51,6 @@ def zero_post_backward(module, grad_inputs, grad_outputs): module._backward_block_ctxs[module._layer_id].exit(True) config['load_stream'].record_event(config['load_event']) - if torch_version < '2.0.1' and not config['pipe_enabled']: - if not module._is_first_layer: - identity_post_backward(module, grad_inputs, grad_outputs) - -class PipePreFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, inputs, stage_id): - pre_inputs = recv_activations(stage_id - 1, config['pipe_comm']) - pre_inputs.requires_grad_() - return pre_inputs - - @staticmethod - def backward(ctx, grads): - return grads, None - class PipePreFunction(torch.autograd.Function): @staticmethod diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 81a35d73..1436082b 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -453,9 +453,6 @@ def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=Fa hidden_state = self.identity(hidden_state) outputs.append(hidden_state) - if torch_version < '2.0.1': - outputs[-1] = self.identity(outputs[-1]) - last_hidden = torch.cat(outputs, dim=0) outputs = hook_func.PipePostFunction.apply(last_hidden) From 839a97630a56ebdf51d79b5b84353a896dd92a72 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 3 Aug 2023 17:08:38 +0800 Subject: [PATCH 024/122] use torch.utils.checkpoint.checkpoint --- bmtrain/block_layer.py | 14 +++----------- bmtrain/hook_func.py | 43 ++++++------------------------------------ bmtrain/pipe_layer.py | 10 ++-------- 3 files changed, 11 insertions(+), 56 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index d6ae4f7d..5c0b6062 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -17,6 +17,7 @@ import copy import inspect +from torch.utils.checkpoint import checkpoint torch_version = hook_func.torch_version @@ -309,10 +310,7 @@ def __init__(self, 
inner_module : torch.nn.Module): def forward(self, *args): if config["use_checkpoint"]: - with torch.no_grad(): - out = self._module(*args) - out.requires_grad_() - return out + return checkpoint(self._module, *args) else: return self._module(*args) @@ -580,13 +578,7 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) module.register_forward_hook(hook_func.zero_post_forward) if torch_version >= '2.0.1': module.register_full_backward_pre_hook(hook_func.zero_pre_backward) - - if config["use_checkpoint"]: - module.register_forward_pre_hook(hook_func.checkpoint_pre_forward) - if torch_version >= '2.0.1': - module.register_full_backward_pre_hook(hook_func.checkpoint_pre_backward) - - if config["zero_level"] > 0 and not config["use_checkpoint"]: + if config["zero_level"] > 0: module.register_full_backward_hook(hook_func.zero_post_backward) module._backward_block_ctxs = self._backward_block_ctxs diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 6d79b8a3..3a4b6879 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -3,8 +3,7 @@ from .checkpointing import CheckpointBlockContext from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations -#torch_version = torch.__version__ -torch_version = '1.9.0' +torch_version = torch.__version__ def zero_pre_forward(module, inputs): enter = True @@ -51,7 +50,6 @@ def zero_post_backward(module, grad_inputs, grad_outputs): module._backward_block_ctxs[module._layer_id].exit(True) config['load_stream'].record_event(config['load_event']) - class PipePreFunction(torch.autograd.Function): @staticmethod def forward(ctx, inputs, stage_id): @@ -90,43 +88,14 @@ def pipe_post_backward(module, grad_inputs, grad_outputs): if torch_version < '2.0.1': if not module._is_first_layer: - identity_post_backward(module, grad_outputs, grad_inputs) - -def checkpoint_pre_forward(module, inputs): - if not config['pipe_enabled']: - module._inputs = inputs - module._cuda_rng_state = torch.cuda.get_rng_state() - else: - if module._micro_idx == 0: - module._inputs = [inputs] - module._cuda_rng_state = [torch.cuda.get_rng_state()] - else: - module._inputs.append(inputs) - module._cuda_rng_state.append(torch.cuda.get_rng_state()) - -def checkpoint_pre_backward(module, grad_outputs): - inputs = module._inputs if not config['pipe_enabled'] else module._inputs[module._micro_idx] - cuda_rng_state = module._cuda_rng_state if not config['pipe_enabled'] else module._cuda_rng_state[module._micro_idx] - with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True): - with torch.enable_grad(): - torch.cuda.set_rng_state(cuda_rng_state) - #inputs[0].retain_grad() - out = module._module(*inputs) - torch.autograd.backward(out, *grad_outputs) - - if not config['pipe_enabled']: - zero_post_backward(module, grad_outputs, (inputs[0].grad,) ) - else: - zero_post_backward(module, grad_outputs, (inputs[0].grad,) ) - pipe_post_backward(module, (inputs[0].grad,), grad_outputs) + identity_post_backward(module, grad_inputs, grad_outputs) def identity_post_backward(module, grad_inputs, grad_outputs): if config['pipe_enabled']: - pipe_grad = pipe_pre_backward(module._pre_module, grad_outputs) - grad_outputs = pipe_grad if pipe_grad is not None else grad_outputs - zero_pre_backward(module._pre_module, grad_outputs) - if config['use_checkpoint']: - checkpoint_pre_backward(module._pre_module, grad_outputs) + pipe_grad = pipe_pre_backward(module._pre_module, grad_inputs) + grad_inputs = pipe_grad if pipe_grad is not 
None else grad_inputs + zero_pre_backward(module._pre_module, grad_inputs) + return grad_inputs class PipeAllGatherFunction(torch.autograd.Function): @staticmethod diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 1436082b..fddbe142 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -355,8 +355,6 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: module.register_forward_pre_hook(hook_func.pipe_pre_forward) module.register_forward_pre_hook(hook_func.zero_pre_forward) - if config['use_checkpoint']: - module.register_forward_pre_hook(hook_func.checkpoint_pre_forward) module.register_forward_hook(hook_func.zero_post_forward) module.register_forward_hook(hook_func.pipe_post_forward) @@ -364,12 +362,8 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: if torch_version >= '2.0.1': module.register_full_backward_pre_hook(hook_func.pipe_pre_backward) module.register_full_backward_pre_hook(hook_func.zero_pre_backward) - if config['use_checkpoint']: - if torch_version >= '2.0.1': - module.register_full_backward_pre_hook(hook_func.checkpoint_pre_backward) - else: - module.register_full_backward_hook(hook_func.zero_post_backward) - module.register_full_backward_hook(hook_func.pipe_post_backward) + module.register_full_backward_hook(hook_func.zero_post_backward) + module.register_full_backward_hook(hook_func.pipe_post_backward) module.stage_id = self.stage_id module.stages = self.stages From d5bbf1a86083589c83e82dcfbf52369c66d2aaf0 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Fri, 4 Aug 2023 12:22:25 +0800 Subject: [PATCH 025/122] custom hook --- bmtrain/block_layer.py | 158 +++++------------------------------------ bmtrain/hook_func.py | 48 +++++++++---- bmtrain/pipe_layer.py | 24 ------- 3 files changed, 50 insertions(+), 180 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 5c0b6062..33fdab61 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -19,124 +19,6 @@ import inspect from torch.utils.checkpoint import checkpoint -torch_version = hook_func.torch_version - -# the flag is used to control the zero level , 0 means normal zero3 , 1 means forward without release parameter ,2 means backward without gather parameter -class OpCheckpointBlock(torch.autograd.Function): - @staticmethod - def forward(ctx, placeholder, block : 'CheckpointBlock', preserve_rng_state, len_args, *args): - ctx.block = block - ctx.preserve_rng_state = preserve_rng_state - - ctx.cuda_rng_state = torch.cuda.get_rng_state() if preserve_rng_state else None - tensors = [] - others = [] - for arg in args: - if torch.is_tensor(arg): - tensors.append(arg) - others.append(None) - else: - tensors.append(None) - others.append(arg) - - ctx.nontensor_inputs = others - ctx.len_args = len_args - ctx.save_for_backward(*tensors) - ctx.param_dict={} - if config['zero_level'] == 2: - flag = 1 - else: - flag = 0 - with torch.no_grad(), ScopedTensorInspectorContext() as inspector, CheckpointBlockContext(block, ctx.param_dict, flag): - inp_args = args[:len_args] - inp_kwargs = {} - for k, v in zip(args[len_args::2], args[len_args + 1::2]): - inp_kwargs[k] = v - outputs = ctx.block._module._call_impl(*inp_args, **inp_kwargs) - for it in inspector.hidden_states: - debug.append("_inspect_hidden_states", it) - ctx.inspect_list = inspector.hidden_states - - if not isinstance(outputs, list) and not isinstance(outputs, tuple): - outputs = [outputs] - len_outputs = 0 - else: - outputs = list(outputs) - len_outputs = 
len(outputs) - return tuple([len_outputs] + outputs + [hidden_state["tensor"] for hidden_state in inspector.hidden_states]) - - @staticmethod - def backward(ctx, _, *grads): - if not torch.autograd._is_checkpoint_valid(): - raise RuntimeError( - "Checkpointing is not compatible with .grad() or when an `inputs` parameter" - " is passed to .backward(). Please use .backward() and do not pass its `inputs`" - " argument.") - - all_inputs = [] - input_reqires_grad = [] - len_args = ctx.len_args - for tensor, other in zip(ctx.saved_tensors, ctx.nontensor_inputs): - if tensor is None: - all_inputs.append(other) - input_reqires_grad.append(False) - else: - input_reqires_grad.append( tensor.requires_grad ) - nw_tensor = tensor.detach() - nw_tensor.requires_grad = tensor.requires_grad - all_inputs.append(nw_tensor) - - - with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=ctx.preserve_rng_state): - if ctx.preserve_rng_state: - torch.cuda.set_rng_state(ctx.cuda_rng_state) - if config['zero_level'] == 2: - flag = 2 - else: - flag = 0 - with torch.enable_grad(), CheckpointBlockContext(ctx.block, ctx.param_dict, flag): - inp_args = all_inputs[:len_args] - inp_kwargs = {} - for k, v in zip(all_inputs[len_args::2], all_inputs[len_args + 1::2]): - inp_kwargs[k] = v - with ScopedTensorInspectorContext() as inspector: - outputs = ctx.block._module._call_impl(*inp_args, **inp_kwargs) - if not isinstance(outputs, tuple): - outputs = (outputs,) - - assert len(outputs) + len(inspector.hidden_states) == len(grads) - - outputs_with_grad = [] - grad_of_output = [] - for i, output in enumerate(outputs): - if torch.is_tensor(output) and output.requires_grad: - outputs_with_grad.append(output) - grad_of_output.append(grads[i]) - - # calculate gradients for inputs, also for parameters - torch.autograd.backward( - outputs_with_grad + [hidden_state["tensor"] for hidden_state in inspector.hidden_states], - grad_of_output + list(grads[len(outputs):]), - ) - assert len(ctx.inspect_list) == len(inspector.hidden_states), "Backward step changed" - for i, it in enumerate(inspector.hidden_states): - assert it["name"] == ctx.inspect_list[i]["name"], "Backward step changed" - assert it["shape"] == ctx.inspect_list[i]["shape"], "Backward step changed" - assert it["group"] == ctx.inspect_list[i]["group"], "Backward step changed" - - # change the tensor in placeholder - ctx.inspect_list[i]["tensor"] = it["tensor"] - ctx.inspect_list[i]["requires_grad"] = it["requires_grad"] - - grads = [] - for inp, requires_grad in zip(all_inputs, input_reqires_grad): - if requires_grad: - grads.append(inp.grad) - else: - grads.append(None) - return (None, None, None, None) + tuple(grads) - - def storage_type_cuda(storage_type): STORAGE_MAP = { torch.FloatStorage: torch.cuda.FloatStorage, @@ -307,12 +189,18 @@ def __init__(self, inner_module : torch.nn.Module): for kw in offsets.keys(): assert offsets[kw] == self._storage_info[kw]["total"] + + self._is_first_layer = True + self._is_last_layer = True def forward(self, *args): + pre_out = hook_func.PreHookFunc.apply(self, args[0]) if config["use_checkpoint"]: - return checkpoint(self._module, *args) + out = checkpoint(self._module, pre_out, *args[1:]) else: - return self._module(*args) + out = self._module(pre_out, *args[1:]) + post_out = hook_func.PostHookFunc.apply(self, out) + return post_out def __getattr__(self,name:str): if name=="_module": @@ -573,14 +461,6 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) if not isinstance(module, 
CheckpointBlock): module = CheckpointBlock(module) - if config["zero_level"] > 0: - module.register_forward_pre_hook(hook_func.zero_pre_forward) - module.register_forward_hook(hook_func.zero_post_forward) - if torch_version >= '2.0.1': - module.register_full_backward_pre_hook(hook_func.zero_pre_backward) - if config["zero_level"] > 0: - module.register_full_backward_hook(hook_func.zero_post_backward) - module._backward_block_ctxs = self._backward_block_ctxs module._layer_id = i module._is_last_layer = True if i == len(modules) -1 else False @@ -588,14 +468,8 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) self._modules[str(i)] = module self.add_module(str(i), module) - if i > 0: - module._pre_module = self._modules[str(i-1)] self.num_hidden = num_hidden - if torch_version < '2.0.1': - self.identity = torch.nn.Identity() - self.identity.register_full_backward_hook(hook_func.identity_post_backward) - self.identity._pre_module = self._modules[str(len(modules)-1)] if sqrt: length = len(self) @@ -620,10 +494,7 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) self.save_list = [(i, i) for i in range(len(self))] def __len__(self) -> int: - if torch_version < '2.0.1': - return len(self._modules) - 1 - else: - return len(self._modules) + return len(self._modules) def __iter__(self) -> Iterator[CheckpointBlock]: return iter(self._modules.values()) @@ -634,14 +505,19 @@ def forward(self, *args, return_hidden_states = False): self.return_hidden_states = return_hidden_states outputs = args[:self.num_hidden] others = args[self.num_hidden:] + hidden_states = [] for i in range(len(self)): + hidden_states.append(outputs[0]) outputs = self._modules[str(i)](*outputs, *others) outputs = (outputs,) - if torch_version < '2.0.1': - outputs = self.identity(outputs) + if return_hidden_states: + hidden_states = [ + torch.stack(hidden_states[i::self.num_hidden], dim=0) + for i in range(self.num_hidden) + ] if return_hidden_states: - return tuple(outputs[:2*self.num_hidden]) + return outputs + tuple(hidden_states[:self.num_hidden]) else: return tuple(outputs[:self.num_hidden]) if self.num_hidden > 1 else outputs[0] diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 3a4b6879..70f2f143 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -3,8 +3,6 @@ from .checkpointing import CheckpointBlockContext from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations -torch_version = torch.__version__ - def zero_pre_forward(module, inputs): enter = True pipe = False @@ -42,9 +40,6 @@ def zero_post_backward(module, grad_inputs, grad_outputs): if module._layer_id == 0: module._backward_block_ctxs[0].exit(True) config['load_stream'].record_event(config['load_event']) - if torch_version < '2.0.1': - if not module._is_first_layer: - identity_post_backward(module, grad_inputs, grad_outputs) else: if module._micro_idx == 0: module._backward_block_ctxs[module._layer_id].exit(True) @@ -86,16 +81,6 @@ def pipe_post_backward(module, grad_inputs, grad_outputs): module._micro_idx -= 1 - if torch_version < '2.0.1': - if not module._is_first_layer: - identity_post_backward(module, grad_inputs, grad_outputs) - -def identity_post_backward(module, grad_inputs, grad_outputs): - if config['pipe_enabled']: - pipe_grad = pipe_pre_backward(module._pre_module, grad_inputs) - grad_inputs = pipe_grad if pipe_grad is not None else grad_inputs - zero_pre_backward(module._pre_module, grad_inputs) - return grad_inputs class 
PipeAllGatherFunction(torch.autograd.Function): @staticmethod @@ -125,3 +110,36 @@ def backward(ctx, grads): grad_list = grad_list.flatten(start_dim=0, end_dim=1) return grad_list +class PreHookFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, module, x): + ctx.module = module + if config['pipe_enabled']: + pipe_out = pipe_pre_forward(module, (x,)) + x = pipe_out[0] if pipe_out is not None else x + zero_pre_forward(module, x) + return x + + @staticmethod + def backward(ctx, grads): + zero_post_backward(ctx.module, grads, None) + if config['pipe_enabled']: + pipe_post_backward(ctx.module, grads, None) + return None, grads + +class PostHookFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, module, out): + ctx.module = module + zero_post_forward(module, None, out) + if config['pipe_enabled']: + pipe_post_forward(module, None, out) + return out + + @staticmethod + def backward(ctx, grads): + zero_pre_backward(ctx.module, grads) + if config['pipe_enabled']: + pipe_grads = pipe_pre_backward(ctx.module, (grads, )) + grads = pipe_grads[0] if pipe_grads is not None else grads + return None, grads diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index fddbe142..1189f52b 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -16,8 +16,6 @@ from .block_layer import CheckpointBlock, round_up, _get_param_kw from . import hook_func -torch_version = hook_func.torch_version - class OpMicroForward(torch.autograd.Function): @staticmethod def forward(ctx, placeholder, self : 'PipelineTransformerBlockList', micro_idx, block_ctx_list, layers_dict, save_list, hidden_state, *args): @@ -353,24 +351,10 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: if not isinstance(module, CheckpointBlock): module = CheckpointBlock(module) - module.register_forward_pre_hook(hook_func.pipe_pre_forward) - module.register_forward_pre_hook(hook_func.zero_pre_forward) - - module.register_forward_hook(hook_func.zero_post_forward) - module.register_forward_hook(hook_func.pipe_post_forward) - - if torch_version >= '2.0.1': - module.register_full_backward_pre_hook(hook_func.pipe_pre_backward) - module.register_full_backward_pre_hook(hook_func.zero_pre_backward) - module.register_full_backward_hook(hook_func.zero_post_backward) - module.register_full_backward_hook(hook_func.pipe_post_backward) - module.stage_id = self.stage_id module.stages = self.stages self._modules[str(idx)] = module - if idx > 0: - module._pre_module = self._modules[str(idx-1)] self.layer_ids = self.get_range_by_stage_id(self.stage_id) @@ -388,11 +372,6 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: self.prev_rank = pipe_group[self.pipe_idx, self.stage_id - 1] if self.stage_id > 0 else -1 # self.micro_batches = config['num_micro_batches'] - if torch_version < '2.0.1': - self.identity = torch.nn.Identity() - self.identity.register_full_backward_hook(hook_func.identity_post_backward) - self.identity._pre_module = self._modules[str(self.layer_ids[-1])] - self.save_list = [(i, i) for i in range(len(self.layer_ids))] def __len__(self) -> int: @@ -442,9 +421,6 @@ def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=Fa for idx,layer_id in enumerate(self.layer_ids): self._modules[str(layer_id)]._micro_idx = micro_idx hidden_state = self._modules[str(layer_id)](hidden_state, *arg) - if torch_version < '2.0.1': - if idx == len(self.layer_ids) - 1: - hidden_state = self.identity(hidden_state) outputs.append(hidden_state) last_hidden = 
torch.cat(outputs, dim=0) From e92d0efdd0bdef17f6aed2a3310d85a0c58f422b Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Fri, 4 Aug 2023 13:44:14 +0800 Subject: [PATCH 026/122] optimize code structure --- bmtrain/block_layer.py | 12 ++++++++---- bmtrain/hook_func.py | 16 ++++++++-------- bmtrain/pipe_layer.py | 5 ++--- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 33fdab61..03a02914 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -192,6 +192,12 @@ def __init__(self, inner_module : torch.nn.Module): self._is_first_layer = True self._is_last_layer = True + self._pre_module = None + self._next_module = None + + def set_pre_module(self, module): + self._pre_module = module + module._next_module = self def forward(self, *args): pre_out = hook_func.PreHookFunc.apply(self, args[0]) @@ -430,7 +436,6 @@ def eval(self): def __repr__(self): return self._module.__repr__() - class TransformerBlockList(torch.nn.Module): r""" TransformerBlockList is a list of CheckpointBlocks. @@ -456,18 +461,17 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) super().__init__() self._modules = {} - self._backward_block_ctxs = [None for _ in range(len(modules))] for i, module in enumerate(modules): if not isinstance(module, CheckpointBlock): module = CheckpointBlock(module) - module._backward_block_ctxs = self._backward_block_ctxs - module._layer_id = i module._is_last_layer = True if i == len(modules) -1 else False module._is_first_layer = True if i == 0 else False self._modules[str(i)] = module self.add_module(str(i), module) + if i > 0: + self._modules[str(i)].set_pre_module(self._modules[str(i-1)]) self.num_hidden = num_hidden diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 70f2f143..c121ad87 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -25,24 +25,24 @@ def zero_post_forward(module, inputs, outputs): def zero_pre_backward(module, grad_outputs): backward_flag = 2 if config['zero_level'] == 2 else 0 if not config['pipe_enabled']: - module._backward_block_ctxs[module._layer_id] = CheckpointBlockContext(module, module._layer_dict, backward_flag) - module._backward_block_ctxs[module._layer_id].enter(True) + module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict, backward_flag) + module._backward_block_ctx.enter(True) if not module._is_last_layer: - module._backward_block_ctxs[module._layer_id + 1].exit(True) + module._next_module._backward_block_ctx.exit(True) config['load_stream'].record_event(config['load_event']) else: if module._micro_idx == config['micros'] - 1: - module._backward_block_ctxs[module._layer_id] = CheckpointBlockContext(module, module._layer_dict, backward_flag, pipe=True) - module._backward_block_ctxs[module._layer_id].enter(True) + module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict, backward_flag, pipe=True) + module._backward_block_ctx.enter(True) def zero_post_backward(module, grad_inputs, grad_outputs): if not config['pipe_enabled']: - if module._layer_id == 0: - module._backward_block_ctxs[0].exit(True) + if module._is_first_layer: + module._backward_block_ctx.exit(True) config['load_stream'].record_event(config['load_event']) else: if module._micro_idx == 0: - module._backward_block_ctxs[module._layer_id].exit(True) + module._backward_block_ctx.exit(True) config['load_stream'].record_event(config['load_event']) class PipePreFunction(torch.autograd.Function): diff --git 
a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 1189f52b..faa92e91 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -358,14 +358,13 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: self.layer_ids = self.get_range_by_stage_id(self.stage_id) - self._backward_block_ctxs = [None for _ in range(len(modules))] for i,layer_id in enumerate(self.layer_ids): - self._modules[str(layer_id)]._layer_id = layer_id - self._modules[str(layer_id)]._backward_block_ctxs = self._backward_block_ctxs self._modules[str(layer_id)]._is_first_stage = True if self.stage_id == 0 else False self._modules[str(layer_id)]._is_last_stage = True if self.stage_id == self.stages-1 else False self._modules[str(layer_id)]._is_first_layer = True if i == 0 else False self._modules[str(layer_id)]._is_last_layer = True if i == len(self.layer_ids)-1 else False + if i > 0: + self._modules[str(layer_id)].set_pre_module(self._modules[str(layer_id-1)]) self.partition_modules(self.layer_ids) self.next_rank = pipe_group[self.pipe_idx, self.stage_id + 1] if self.stage_id < config['pipe_size'] - 1 else -1 From 6ba753e90aff07e39ff04e12a3cb0fd19a270ae1 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Fri, 4 Aug 2023 18:29:42 +0800 Subject: [PATCH 027/122] for hidden_state --- bmtrain/block_layer.py | 9 +++++++-- bmtrain/hook_func.py | 19 +++++++++---------- bmtrain/pipe_layer.py | 3 ++- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 03a02914..355b6370 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -59,6 +59,7 @@ class CheckpointBlock(torch.nn.Module): Args: model (torch.nn.Module): The model to be checkpointed. All kinds of modules are supported. + mode (str): run in BLOCK or ZERO or PIPE Examples: >>> transformer_block = TransformerBlock(...) @@ -67,7 +68,7 @@ class CheckpointBlock(torch.nn.Module): >>> y2, ... 
= transformer_block(x) >>> assert torch.allclose(y1, y2) """ - def __init__(self, inner_module : torch.nn.Module): + def __init__(self, inner_module : torch.nn.Module, mode="BLOCK"): super().__init__() self._module = inner_module self._inputs = None @@ -194,6 +195,7 @@ def __init__(self, inner_module : torch.nn.Module): self._is_last_layer = True self._pre_module = None self._next_module = None + self._mode = mode #BLOCK or ZERO or PIPE def set_pre_module(self, module): self._pre_module = module @@ -206,6 +208,8 @@ def forward(self, *args): else: out = self._module(pre_out, *args[1:]) post_out = hook_func.PostHookFunc.apply(self, out) + if isinstance(post_out, list): + return tuple(post_out) return post_out def __getattr__(self,name:str): @@ -463,8 +467,9 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) self._modules = {} for i, module in enumerate(modules): if not isinstance(module, CheckpointBlock): - module = CheckpointBlock(module) + module = CheckpointBlock(module, "ZERO") + module._mode = "ZERO" module._is_last_layer = True if i == len(modules) -1 else False module._is_first_layer = True if i == 0 else False diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index c121ad87..80a5511f 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -6,7 +6,7 @@ def zero_pre_forward(module, inputs): enter = True pipe = False - if config['pipe_enabled']: + if module._mode is "PIPE": enter = module._micro_idx == 0 pipe = True if enter: @@ -16,7 +16,7 @@ def zero_pre_forward(module, inputs): def zero_post_forward(module, inputs, outputs): exit = True - if config['pipe_enabled']: + if module._mode is "PIPE": exit = module._micro_idx == config['micros'] - 1 if exit: @@ -24,10 +24,10 @@ def zero_post_forward(module, inputs, outputs): def zero_pre_backward(module, grad_outputs): backward_flag = 2 if config['zero_level'] == 2 else 0 - if not config['pipe_enabled']: + if module._mode is not "PIPE": module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict, backward_flag) module._backward_block_ctx.enter(True) - if not module._is_last_layer: + if not module._is_last_layer and module._next_module is not None and module._next_module._backward_block_ctx is not None: module._next_module._backward_block_ctx.exit(True) config['load_stream'].record_event(config['load_event']) else: @@ -36,7 +36,7 @@ def zero_pre_backward(module, grad_outputs): module._backward_block_ctx.enter(True) def zero_post_backward(module, grad_inputs, grad_outputs): - if not config['pipe_enabled']: + if module._mode is not "PIPE": if module._is_first_layer: module._backward_block_ctx.exit(True) config['load_stream'].record_event(config['load_event']) @@ -81,7 +81,6 @@ def pipe_post_backward(module, grad_inputs, grad_outputs): module._micro_idx -= 1 - class PipeAllGatherFunction(torch.autograd.Function): @staticmethod def forward(ctx, hidden_state): @@ -114,7 +113,7 @@ class PreHookFunc(torch.autograd.Function): @staticmethod def forward(ctx, module, x): ctx.module = module - if config['pipe_enabled']: + if module._mode is "PIPE": pipe_out = pipe_pre_forward(module, (x,)) x = pipe_out[0] if pipe_out is not None else x zero_pre_forward(module, x) @@ -123,7 +122,7 @@ def forward(ctx, module, x): @staticmethod def backward(ctx, grads): zero_post_backward(ctx.module, grads, None) - if config['pipe_enabled']: + if ctx.module._mode is "PIPE": pipe_post_backward(ctx.module, grads, None) return None, grads @@ -132,14 +131,14 @@ class PostHookFunc(torch.autograd.Function): def 
forward(ctx, module, out): ctx.module = module zero_post_forward(module, None, out) - if config['pipe_enabled']: + if module._mode is "PIPE": pipe_post_forward(module, None, out) return out @staticmethod def backward(ctx, grads): zero_pre_backward(ctx.module, grads) - if config['pipe_enabled']: + if ctx.module._mode is "PIPE": pipe_grads = pipe_pre_backward(ctx.module, (grads, )) grads = pipe_grads[0] if pipe_grads is not None else grads return None, grads diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index faa92e91..c4898e2f 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -349,8 +349,9 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: self.pipe_idx = topo.pipe_idx for idx, module in enumerate(modules): if not isinstance(module, CheckpointBlock): - module = CheckpointBlock(module) + module = CheckpointBlock(module, "PIPE") + module._mode = "PIPE" module.stage_id = self.stage_id module.stages = self.stages From b0a0da9592ca77d7bf20efd10376b1b5ab6a4a24 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Fri, 4 Aug 2023 19:06:28 +0800 Subject: [PATCH 028/122] for input.requires_grad is False --- bmtrain/block_layer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 355b6370..56a12606 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -202,6 +202,9 @@ def set_pre_module(self, module): module._next_module = self def forward(self, *args): + #input must be requires_grad, otherwise autograd.backward will make an error + self.input_requires_grad = args[0].requires_grad + args[0].requires_grad_() pre_out = hook_func.PreHookFunc.apply(self, args[0]) if config["use_checkpoint"]: out = checkpoint(self._module, pre_out, *args[1:]) From f4a0e0bf84bb9611909b4960f01c5f11b394b74c Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Sat, 5 Aug 2023 13:02:04 +0800 Subject: [PATCH 029/122] fix --- bmtrain/hook_func.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 80a5511f..a36728eb 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -6,7 +6,7 @@ def zero_pre_forward(module, inputs): enter = True pipe = False - if module._mode is "PIPE": + if module._mode == "PIPE": enter = module._micro_idx == 0 pipe = True if enter: @@ -16,7 +16,7 @@ def zero_pre_forward(module, inputs): def zero_post_forward(module, inputs, outputs): exit = True - if module._mode is "PIPE": + if module._mode == "PIPE": exit = module._micro_idx == config['micros'] - 1 if exit: @@ -24,7 +24,7 @@ def zero_post_forward(module, inputs, outputs): def zero_pre_backward(module, grad_outputs): backward_flag = 2 if config['zero_level'] == 2 else 0 - if module._mode is not "PIPE": + if module._mode != "PIPE": module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict, backward_flag) module._backward_block_ctx.enter(True) if not module._is_last_layer and module._next_module is not None and module._next_module._backward_block_ctx is not None: @@ -36,7 +36,7 @@ def zero_pre_backward(module, grad_outputs): module._backward_block_ctx.enter(True) def zero_post_backward(module, grad_inputs, grad_outputs): - if module._mode is not "PIPE": + if module._mode != "PIPE": if module._is_first_layer: module._backward_block_ctx.exit(True) config['load_stream'].record_event(config['load_event']) @@ -113,7 +113,7 @@ class PreHookFunc(torch.autograd.Function): @staticmethod def forward(ctx, module, x): ctx.module = module - 
if module._mode is "PIPE": + if module._mode == "PIPE": pipe_out = pipe_pre_forward(module, (x,)) x = pipe_out[0] if pipe_out is not None else x zero_pre_forward(module, x) @@ -122,7 +122,7 @@ def forward(ctx, module, x): @staticmethod def backward(ctx, grads): zero_post_backward(ctx.module, grads, None) - if ctx.module._mode is "PIPE": + if ctx.module._mode == "PIPE": pipe_post_backward(ctx.module, grads, None) return None, grads @@ -131,14 +131,14 @@ class PostHookFunc(torch.autograd.Function): def forward(ctx, module, out): ctx.module = module zero_post_forward(module, None, out) - if module._mode is "PIPE": + if module._mode == "PIPE": pipe_post_forward(module, None, out) return out @staticmethod def backward(ctx, grads): zero_pre_backward(ctx.module, grads) - if ctx.module._mode is "PIPE": + if ctx.module._mode == "PIPE": pipe_grads = pipe_pre_backward(ctx.module, (grads, )) grads = pipe_grads[0] if pipe_grads is not None else grads return None, grads From 8faff0fc545eb92e8df6b50bdd4c7de8b00bb2bf Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Sun, 6 Aug 2023 15:53:09 +0800 Subject: [PATCH 030/122] pipeline support return hidden_state --- bmtrain/block_layer.py | 15 +- bmtrain/hook_func.py | 35 +--- bmtrain/pipe_layer.py | 453 +++++++++-------------------------------- 3 files changed, 107 insertions(+), 396 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 56a12606..ec4dc456 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -201,15 +201,15 @@ def set_pre_module(self, module): self._pre_module = module module._next_module = self - def forward(self, *args): + def forward(self, hidden_state, return_hidden_state=False, hidden_states=[], *args): #input must be requires_grad, otherwise autograd.backward will make an error - self.input_requires_grad = args[0].requires_grad - args[0].requires_grad_() - pre_out = hook_func.PreHookFunc.apply(self, args[0]) + self.input_requires_grad = hidden_state.requires_grad + hidden_state.requires_grad_() + pre_out = hook_func.PreHookFunc.apply(self, hidden_state, return_hidden_state, hidden_states) if config["use_checkpoint"]: - out = checkpoint(self._module, pre_out, *args[1:]) + out = checkpoint(self._module, pre_out, *args) else: - out = self._module(pre_out, *args[1:]) + out = self._module(pre_out, *args) post_out = hook_func.PostHookFunc.apply(self, out) if isinstance(post_out, list): return tuple(post_out) @@ -519,8 +519,7 @@ def forward(self, *args, return_hidden_states = False): others = args[self.num_hidden:] hidden_states = [] for i in range(len(self)): - hidden_states.append(outputs[0]) - outputs = self._modules[str(i)](*outputs, *others) + outputs = self._modules[str(i)](*outputs, return_hidden_states, hidden_states, *others) outputs = (outputs,) if return_hidden_states: diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index a36728eb..c0925f28 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -81,41 +81,16 @@ def pipe_post_backward(module, grad_inputs, grad_outputs): module._micro_idx -= 1 -class PipeAllGatherFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, hidden_state): - hidden_state_list = all_gather(hidden_state.clone(), config["pipe_comm"]) - hidden_state_list.requires_grad_() - return hidden_state_list - - @staticmethod - def backward(ctx, grads): - grads = broadcast(grads, 0, config['pipe_comm']) - topo = config['topology'] - return grads.chunk(topo.stages, dim=0)[topo.stage_id] - -class PipePostFunction(torch.autograd.Function): - @staticmethod 
- def forward(ctx, last_hidden): - last_hidden = broadcast(last_hidden, config["pipe_size"] - 1, config["pipe_comm"]) - last_hidden = last_hidden.chunk(config['topology'].stages, dim=0) - outputs = last_hidden[config['topology'].stage_id] - outputs.requires_grad_() - return outputs - - @staticmethod - def backward(ctx, grads): - grad_list = all_gather(grads, config["pipe_comm"]) - grad_list = grad_list.flatten(start_dim=0, end_dim=1) - return grad_list - class PreHookFunc(torch.autograd.Function): @staticmethod - def forward(ctx, module, x): + def forward(ctx, module, x, return_hidden_states=False, hidden_states=[]): ctx.module = module if module._mode == "PIPE": pipe_out = pipe_pre_forward(module, (x,)) x = pipe_out[0] if pipe_out is not None else x + + if return_hidden_states: + hidden_states.append(x) zero_pre_forward(module, x) return x @@ -124,7 +99,7 @@ def backward(ctx, grads): zero_post_backward(ctx.module, grads, None) if ctx.module._mode == "PIPE": pipe_post_backward(ctx.module, grads, None) - return None, grads + return None, grads, None, None class PostHookFunc(torch.autograd.Function): @staticmethod diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index c4898e2f..32b8e733 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -16,304 +16,67 @@ from .block_layer import CheckpointBlock, round_up, _get_param_kw from . import hook_func -class OpMicroForward(torch.autograd.Function): +class PipeAllGatherFunction(torch.autograd.Function): @staticmethod - def forward(ctx, placeholder, self : 'PipelineTransformerBlockList', micro_idx, block_ctx_list, layers_dict, save_list, hidden_state, *args): - with PipeContext(self, hidden_state) as pipe_input: - hidden_state = pipe_input[0].detach() - tensors = [arg if torch.is_tensor(arg) else None for arg in args] - others = [arg if not torch.is_tensor(arg) else None for arg in args] - ctx.nontensor_inputs = others - ctx.self = self - ctx.micro_idx = micro_idx - ctx.block_ctx_list = block_ctx_list - ctx.layers_dict = layers_dict - ctx.save_list = copy.deepcopy(save_list) - ctx.num_save_needed = save_list[-1][1]+1 - layer_inputs = [] - layer_inspector = [] - cuda_rng_state = [] - for idx,layer_id in enumerate(self.layer_ids): - with torch.no_grad(): - if save_list[idx][0] == idx: - layer_inputs.append(hidden_state.detach()) - cuda_rng_state.append( torch.cuda.get_rng_state() ) - # gather parameter on load stream - if ctx.micro_idx == 0: - block_ctx_list[idx] = CheckpointBlockContext(self._modules[str(layer_id)], ctx.layers_dict[idx], 1, pipe=True) - block_ctx_list[idx].enter() - # call inner module directly - with ScopedTensorInspectorContext() as inspector: - hidden_state = self._modules[str(layer_id)]._module._call_impl(hidden_state, *args) - if ctx.micro_idx == config["micros"]-1: - block_ctx_list[idx].exit() - for ith, it in enumerate(inspector.hidden_states): - it["inside_pipe"] = { - "stage_id": self.stage_id, - "stages": self.stages, - "st": (layer_id==self.layer_ids[0] and ith==0), - "ed": (layer_id==self.layer_ids[-1] and ith==len(inspector.hidden_states)-1), - } - debug.append("_inspect_hidden_states", it) - layer_inspector.append(inspector.hidden_states) - - ctx.layer_inspector = layer_inspector - ctx.cuda_rng_state = cuda_rng_state - - ctx.save_for_backward(*layer_inputs, *tensors) - pipe_input[0] = hidden_state - if self.return_hidden_states: - middle_hiddens = layer_inputs - for mid in middle_hiddens: - mid.requires_grad_() - middle_hiddens = torch.stack(middle_hiddens, dim=0) - else: - middle_hiddens = None - 
return tuple([pipe_input[0], middle_hiddens] + [hidden_state["tensor"] for hidden_states in ctx.layer_inspector for hidden_state in hidden_states]) + def forward(ctx, hidden_state): + hidden_state_list = all_gather(hidden_state.clone(), config["pipe_comm"]) + hidden_state_list.requires_grad_() + return hidden_state_list @staticmethod - def backward(ctx, grad_hidden_state : torch.Tensor, grad_middle : torch.Tensor, *grad_inspector): - def exit_prev(prev_ctx, prev_grad): - if prev_ctx is not None: - if prev_grad: - with torch.enable_grad(): - prev_ctx.exit(True) - config["load_stream"].record_event(config["load_event"]) - else: - with torch.no_grad(): - prev_ctx.exit() - config["load_stream"].record_event(config["load_event"]) - if not torch.autograd._is_checkpoint_valid(): - raise RuntimeError( - "Checkpointing is not compatible with .grad() or when an `inputs` parameter" - " is passed to .backward(). Please use .backward() and do not pass its `inputs`" - " argument.") - all_inputs = [] - input_requires_grad = [] - - layer_inputs = ctx.saved_tensors[:ctx.num_save_needed] - save_args = ctx.saved_tensors[ctx.num_save_needed:] - for tensor, other in zip(save_args, ctx.nontensor_inputs): - if tensor is None: - all_inputs.append(other) - input_requires_grad.append(False) - else: - # detach for tensor inputs - input_requires_grad.append( tensor.requires_grad ) - nw_tensor = tensor.detach() - nw_tensor.requires_grad = tensor.requires_grad - all_inputs.append(nw_tensor) - with PipeContext(ctx.self, grad_hidden_state, backward=True) as pipe_input: - grad_hidden_state = pipe_input[0] - with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True): - with torch.enable_grad(): - # overlap load and scatter here - prev_ctx = None - prev_grad = False - for idx, layer_id in list(enumerate(ctx.self.layer_ids))[::-1]: - torch.cuda.set_rng_state(ctx.cuda_rng_state[idx]) - ipt = layer_inputs[ctx.save_list[idx][1]].requires_grad_() - if ctx.micro_idx == 0: - ctx.block_ctx_list[idx] = CheckpointBlockContext(ctx.self._modules[str(layer_id)], ctx.layers_dict[idx], 2, pipe=True) - ctx.block_ctx_list[idx].enter(True) - if ctx.micro_idx == config["micros"]-1: - exit_prev(prev_ctx, prev_grad) - prev_ctx = ctx.block_ctx_list[idx] - prev_grad = True - - with ScopedTensorInspectorContext() as inspector: - output = ctx.self._modules[str(layer_id)]._module._call_impl(ipt, *all_inputs) - - assert len(ctx.layer_inspector[idx]) == len(inspector.hidden_states), "Backward step changed" - for j, it in enumerate(inspector.hidden_states): - assert it["name"] == ctx.layer_inspector[idx][j]["name"], "Backward step changed" - assert it["shape"] == ctx.layer_inspector[idx][j]["shape"], "Backward step changed" - assert it["group"] == ctx.layer_inspector[idx][j]["group"], "Backward step changed" - - # change the tensor in placeholder - ctx.layer_inspector[idx][j]["tensor"] = it["tensor"] - ctx.layer_inspector[idx][j]["requires_grad"] = it["requires_grad"] - if len(inspector.hidden_states) > 0: - torch.autograd.backward( - [output] + [hidden_state["tensor"] for hidden_state in inspector.hidden_states], - [grad_hidden_state] + list(grad_inspector[-len(inspector.hidden_states):]), - ) - grad_inspector = grad_inspector[:-len(inspector.hidden_states)] - else: - torch.autograd.backward( - [output], - [grad_hidden_state], - ) - grad_hidden_state = ipt.grad - - if grad_middle is not None: - grad_hidden_state = grad_hidden_state + grad_middle[idx] - if ctx.micro_idx == config["micros"]-1: - exit_prev(prev_ctx, prev_grad) - for 
inspector_hiddens in ctx.layer_inspector: - for it in inspector_hiddens: - debug.append("_inspect_hidden_states", it) - - pipe_input[0] = grad_hidden_state - grads = [] - for inp, requires_grad in zip(all_inputs, input_requires_grad): - if requires_grad: - grads.append(inp.grad) - else: - grads.append(None) - return (None, None, None, None, None, None, pipe_input[0]) + tuple(grads) + def backward(ctx, grads): + grads = broadcast(grads, 0, config['pipe_comm']) + topo = config['topology'] + return grads.chunk(topo.stages, dim=0)[topo.stage_id] -class OpPipeTransformerBlockList(torch.autograd.Function): +class PipePostFunction(torch.autograd.Function): @staticmethod - def forward(ctx, placeholder, self : 'PipelineTransformerBlockList', save_list, hidden_state, *args): - num_micros = config["micros"] - ctx.self = self - ctx.num_micros = num_micros - block_ctx = [None for _ in range(len(self))] - layers_dict = [{} for _ in range(len(self))] - args_list = [[] for _ in range(num_micros)] - batch_related = args[-1] - batch_related_origin = [True if i in args[-1] else False for i in range(len(args[:-1]))] - batch_related_rule = [] - args = args[:-1] - batch_size = hidden_state.shape[0] - assert (batch_size * config["pipe_size"]) % num_micros == 0, f'The batch size {(batch_size * config["pipe_size"])} must be divisible by the number of micro_batch {num_micros}' - input_requires_grad = [] - inspector_hiddens = [] - ctx.inspector_hiddens_sep = [0] - ctx.micro_inspector = [] - with torch.enable_grad(): - for arg in args: - if torch.is_tensor(arg): - arg_all = all_gather(arg, config['pipe_comm']) - if arg.shape[0] == batch_size: - batch_related_rule.append(True) - arg_all = arg_all.flatten(0, 1).chunk(num_micros, dim=0) - arg_all = [tensor.detach().requires_grad_(arg.requires_grad) for tensor in arg_all] - else: - batch_related_rule.append(False) - # assert num_micros % self.stages == 0, "batch unrelated only support num_micros % stages == 0" - # arg_all = [arg_all[i // (num_micros // self.stages)].detach().requires_grad_(arg.requires_grad) for i in range(num_micros)] - arg_all = [arg_all[0].detach().requires_grad_(arg.requires_grad) for i in range(num_micros)] - input_requires_grad.append(arg.requires_grad) - else: - batch_related_rule.append(False) - arg_all = [arg for _ in range(num_micros)] - input_requires_grad.append(False) - for i in range(num_micros): - args_list[i].append(arg_all[i]) - outputs = [] - if self.return_hidden_states: - middles = [] - hidden_state_list = all_gather(hidden_state, config["pipe_comm"]).flatten(0, 1).detach().requires_grad_() - ctx.hidden_state_list = hidden_state_list - hidden_state_list = hidden_state_list.chunk(num_micros, dim=0) - for micro_idx, (hidden_state, arg) in enumerate(zip(hidden_state_list, args_list)): - placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) - with ScopedTensorInspectorContext() as inspector: - micro_outputs = OpMicroForward.apply(placeholder, self, micro_idx, block_ctx, layers_dict, save_list, hidden_state, *arg) - output, middle = micro_outputs[:2] - outputs.append(output) - if self.return_hidden_states: - middles.append(middle) - for it in inspector.hidden_states: - inspector_hiddens.append(it["tensor"]) - it["tensor"] = it["tensor"].clone() - debug.append("_inspect_hidden_states", it) - ctx.inspector_hiddens_sep.append(len(inspector_hiddens)) - ctx.micro_inspector.append(inspector.hidden_states) - if len(batch_related) == 0: - ctx.batch_related = batch_related_rule - else: - ctx.batch_related = batch_related_origin - 
ctx.args_list = args_list - ctx.input_requires_grad = input_requires_grad - ctx.output_list = outputs - if self.return_hidden_states: - ctx.middle_list = middles - - with torch.enable_grad(): - last_hidden = torch.cat(outputs, dim=0) - last_hidden_shape = last_hidden.shape - last_hidden = broadcast(last_hidden, config["pipe_size"] - 1, config["pipe_comm"]) - last_hidden = last_hidden.chunk(self.stages, dim=0) - last_hidden = last_hidden[self.stage_id].clone() - - if self.return_hidden_states: + def forward(ctx, last_hidden, hidden_states=None, forward_stage_ranges=None, backward_stage_ranges=None, last_hidden_shape=None, return_hidden_states=False): + topo = config['topology'] + ctx.return_hidden_states = return_hidden_states + last_hidden = broadcast(last_hidden, config["pipe_size"] - 1, config["pipe_comm"]) + last_hidden = last_hidden.chunk(topo.stages, dim=0) + output = last_hidden[topo.stage_id] + output.requires_grad_() + + if return_hidden_states: + ctx.stage_id = topo.stage_id + ctx.stages = topo.stages + ctx.backward_stage_ranges = backward_stage_ranges middle_hiddens = [] - with torch.enable_grad(): - for stage_id in range(self.stages): - if self.stage_id == stage_id: - middle_hidden = torch.cat(middles, dim=1) # [(layers, micro_batch, ...), ] -> (layers, full_batch, ...) - else: - middle_shape = (self.get_part_len_by_stage_id(stage_id),)+last_hidden_shape - middle_hidden = torch.zeros(middle_shape, device=last_hidden.device, dtype=last_hidden.dtype) - middle_hidden = broadcast(middle_hidden, stage_id, config["pipe_comm"]) - middle_hidden = middle_hidden.chunk(self.stages, dim=1) - middle_hidden = middle_hidden[self.stage_id].clone() - middle_hiddens.append(middle_hidden) - middle_hiddens = torch.cat(middle_hiddens, dim=0) + for stage_id in range(ctx.stages): + if ctx.stage_id == stage_id: + middle_hidden = hidden_states + else: + middle_shape = (forward_stage_ranges[stage_id],) + last_hidden_shape + middle_hidden = torch.zeros(middle_shape, device=hidden_states.device, dtype=hidden_states.dtype) + middle_hidden = broadcast(middle_hidden, stage_id, config["pipe_comm"]) + middle_hidden = middle_hidden.chunk(ctx.stages, dim=1) + middle_hidden = middle_hidden[ctx.stage_id].clone() + middle_hiddens.append(middle_hidden) + middle_hiddens = torch.cat(middle_hiddens, dim=0) + middle_hiddens.requires_grad_() + return output, middle_hiddens else: - middle_hiddens = None - - ctx.save_for_backward(*inspector_hiddens) - return tuple([last_hidden, middle_hiddens] + [it["tensor"] for inspector_hiddens in ctx.micro_inspector for it in inspector_hiddens]) - + return output @staticmethod - def backward(ctx, grad_hidden_state : torch.Tensor, grad_middle : torch.Tensor, *grad_inspectors): - inspector_hiddens = ctx.saved_tensors - ipt = ctx.hidden_state_list - args_list = ctx.args_list - input_requires_grad = ctx.input_requires_grad - grad_hidden_state_list = all_gather(grad_hidden_state, config["pipe_comm"]).flatten(start_dim=0, end_dim=1).chunk(ctx.num_micros, dim=0) - if ctx.self.return_hidden_states: - for stage_id in range(ctx.self.stages): - layer_range = ctx.self.get_range_by_stage_id(stage_id) + def backward(ctx, grads, grad_middle=None): + grad_list = all_gather(grads, config["pipe_comm"]) + grad_list = grad_list.flatten(start_dim=0, end_dim=1) + + if ctx.return_hidden_states: + for stage_id in range(ctx.stages): + layer_range = ctx.backward_stage_ranges[stage_id] grad_middle_state = grad_middle[layer_range] - grad_middle_state = all_gather(grad_middle_state.transpose(0,1), 
config["pipe_comm"]).flatten(start_dim=0, end_dim=1).transpose(0, 1).chunk(ctx.num_micros, dim=1) # (layer, micro_batch, ...) - if ctx.self.stage_id == stage_id: - grad_middle_state_list = grad_middle_state - - for m in range(ctx.num_micros): - outputs = [ctx.output_list[m]] - grad_outputs = [grad_hidden_state_list[m]] - if ctx.self.return_hidden_states: - outputs.append(ctx.middle_list[m]) - grad_outputs.append(grad_middle_state_list[m]) - outputs += list(inspector_hiddens[ctx.inspector_hiddens_sep[m]:ctx.inspector_hiddens_sep[m+1]]) - grad_outputs += list(grad_inspectors[ctx.inspector_hiddens_sep[m]:ctx.inspector_hiddens_sep[m+1]]) - with ScopedTensorInspectorContext() as inspector: - torch.autograd.backward( - outputs, - grad_outputs, - ) - for j, it in enumerate(inspector.hidden_states): - assert it["name"] == ctx.micro_inspector[m][j]["name"], "Backward step changed" - assert it["shape"] == ctx.micro_inspector[m][j]["shape"], "Backward step changed" - assert it["group"] == ctx.micro_inspector[m][j]["group"], "Backward step changed" - - # change the tensor in placeholder - ctx.micro_inspector[m][j]["tensor"] = it["tensor"] - ctx.micro_inspector[m][j]["requires_grad"] = it["requires_grad"] - - grads = [] - for idx,requires_grad in enumerate(input_requires_grad): - if requires_grad: - grad = torch.cat([args_list[m][idx].grad for m in range(ctx.num_micros)], dim=0) - grad = all_reduce(grad, "sum", config["pipe_comm"]) - split_size = ctx.self.stages if ctx.batch_related[idx] else ctx.num_micros - grad = grad.chunk(split_size) - if ctx.batch_related[idx]: - grads.append(grad[ctx.self.stage_id]) - else: - grads.append(grad[0]) - else: - grads.append(None) - grad = broadcast(ipt.grad, 0, config["pipe_comm"]).chunk(ctx.self.stages) - grad = grad[ctx.self.stage_id] + grad_middle_state = all_gather(grad_middle_state.transpose(0,1), config["pipe_comm"]) + grad_middle_state = grad_middle_state.flatten(start_dim=0, end_dim=1).transpose(0, 1) + if ctx.stage_id == stage_id: + grad_hidden_state_list = grad_middle_state + return grad_list, grad_hidden_state_list, None, None, None, None + else: + return grad_list - return (None, None, None, grad) + tuple(grads) + (None,) class PipelineTransformerBlockList(torch.nn.Module): r""" @@ -360,6 +123,7 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: self.layer_ids = self.get_range_by_stage_id(self.stage_id) for i,layer_id in enumerate(self.layer_ids): + self._modules[str(layer_id)].layer_id = layer_id self._modules[str(layer_id)]._is_first_stage = True if self.stage_id == 0 else False self._modules[str(layer_id)]._is_last_stage = True if self.stage_id == self.stages-1 else False self._modules[str(layer_id)]._is_first_layer = True if i == 0 else False @@ -385,51 +149,53 @@ def __getitem__(self, index: Union[int, str]) -> CheckpointBlock: def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=False): self.return_hidden_states = return_hidden_states - if False: - placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) - args = list(args) - args.append(batch_related) - outputs = OpPipeTransformerBlockList.apply(placeholder, self, self.save_list, hidden_state, *args) - hidden_state, middle_states = outputs[:2] - if return_hidden_states: - return hidden_state, middle_states - else: - return hidden_state - else: - batch_size = hidden_state.shape[0] - num_micros = config["micros"] - hidden_state_list = hook_func.PipeAllGatherFunction.apply(hidden_state) - - args_list = [[] for _ in range(num_micros)] 
- for arg in args: - if torch.is_tensor(arg): - arg_all = all_gather(arg, config['pipe_comm']) - if arg.shape[0] == batch_size: - arg_all = arg_all.flatten(0, 1).chunk(num_micros, dim=0) - arg_all = [tensor.detach().requires_grad_(arg.requires_grad) for tensor in arg_all] - else: - arg_all = [arg_all[0].detach().requires_grad_(arg.requires_grad) for i in range(num_micros)] - else: - arg_all = [arg for _ in range(num_micros)] - for i in range(num_micros): - args_list[i].append(arg_all[i]) - - hidden_state_list = hidden_state_list.flatten(0, 1).chunk(num_micros, dim=0) - outputs = [] - - for micro_idx, (hidden_state, arg) in enumerate(zip(hidden_state_list, args_list)): - for idx,layer_id in enumerate(self.layer_ids): - self._modules[str(layer_id)]._micro_idx = micro_idx - hidden_state = self._modules[str(layer_id)](hidden_state, *arg) - outputs.append(hidden_state) + batch_size = hidden_state.shape[0] + num_micros = config["micros"] + hidden_state_list = PipeAllGatherFunction.apply(hidden_state) - last_hidden = torch.cat(outputs, dim=0) - outputs = hook_func.PipePostFunction.apply(last_hidden) + args_list = [[] for _ in range(num_micros)] + for arg in args: + if torch.is_tensor(arg): + arg_all = all_gather(arg, config['pipe_comm']) + if arg.shape[0] == batch_size: + arg_all = arg_all.flatten(0, 1).chunk(num_micros, dim=0) + arg_all = [tensor.detach().requires_grad_(arg.requires_grad) for tensor in arg_all] + else: + arg_all = [arg_all[0].detach().requires_grad_(arg.requires_grad) for i in range(num_micros)] + else: + arg_all = [arg for _ in range(num_micros)] + for i in range(num_micros): + args_list[i].append(arg_all[i]) + + hidden_state_list = hidden_state_list.flatten(0, 1).chunk(num_micros, dim=0) + outputs = [] + hidden_states = [] + for micro_idx, (hidden_state, arg) in enumerate(zip(hidden_state_list, args_list)): + micro_hidden_states = [] + for idx,layer_id in enumerate(self.layer_ids): + self._modules[str(layer_id)]._micro_idx = micro_idx + hidden_state = self._modules[str(layer_id)](hidden_state, return_hidden_states, micro_hidden_states, *arg) + outputs.append(hidden_state) if return_hidden_states: - return outputs[:2*self.num_hidden] - else: - return outputs[:self.num_hidden] if self.num_hidden > 1 else outputs + hidden_states.append(torch.stack(micro_hidden_states, dim=0)) + + last_hidden = torch.cat(outputs, dim=0) + last_hidden_shape = last_hidden.shape + + if return_hidden_states: + hidden_states = torch.cat(hidden_states, dim=1) + forward_stage_ranges = [] + backward_stage_ranges = [] + for stage_id in range(self.stages): + forward_stage_ranges.append(self.get_part_len_by_stage_id(stage_id)) + backward_stage_ranges.append(self.get_range_by_stage_id(stage_id)) + outputs, hidden_states = PipePostFunction.apply(last_hidden, hidden_states, forward_stage_ranges, backward_stage_ranges, last_hidden_shape, return_hidden_states) + return outputs, hidden_states + else: + outputs = PipePostFunction.apply(last_hidden) + return outputs + def get_range_by_stage_id(self, stage_id : int) -> List[int]: part_lens = [0]+[self.get_part_len_by_stage_id(i) for i in range(stage_id+1)] @@ -541,32 +307,3 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): for n, parameter in module._module.named_parameters(): destination[name+n] = recv_activations(self.get_stage_by_layer_id(idx), config['pipe_comm']) -class PipeContext: - def __init__(self, module, hidden_state, backward=False): - self.module = module - self.stage_id = module.stage_id - self.stages = module.stages - self.next_rank = 
module.next_rank - self.prev_rank = module.prev_rank - self.hidden_state = [hidden_state] - self.backward = backward - self.send_buffer = {} - def enter(self): - if self.backward: - if self.stage_id != self.stages -1: - self.hidden_state[0] = recv_activations(self.stage_id + 1, config['pipe_comm']) - else: - if self.stage_id != 0: - self.hidden_state[0] = recv_activations(self.stage_id - 1, config['pipe_comm']) - return self.hidden_state - def exit(self): - if self.backward: - if self.stage_id != 0: - send_activations(self.hidden_state[0], self.stage_id - 1, config['pipe_comm']) - else: - if self.stage_id != self.stages - 1: - send_activations(self.hidden_state[0], self.stage_id + 1, config['pipe_comm']) - def __enter__(self): - return self.enter() - def __exit__(self, exc_type, exc_val, exc_tb): - self.exit() From 26c8c942d18bc5a5f8fa750f3fd08394dd2269fb Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 7 Aug 2023 11:43:54 +0800 Subject: [PATCH 031/122] fix args --- bmtrain/block_layer.py | 23 +++++++++++++++-------- bmtrain/hook_func.py | 4 ++-- bmtrain/pipe_layer.py | 5 ++++- 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index ec4dc456..736c4feb 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -196,21 +196,26 @@ def __init__(self, inner_module : torch.nn.Module, mode="BLOCK"): self._pre_module = None self._next_module = None self._mode = mode #BLOCK or ZERO or PIPE + self.return_hidden_states = False + self.hidden_states = [] def set_pre_module(self, module): self._pre_module = module module._next_module = self - def forward(self, hidden_state, return_hidden_state=False, hidden_states=[], *args): + def forward(self, *args): #input must be requires_grad, otherwise autograd.backward will make an error + hidden_state = args[0] self.input_requires_grad = hidden_state.requires_grad hidden_state.requires_grad_() - pre_out = hook_func.PreHookFunc.apply(self, hidden_state, return_hidden_state, hidden_states) + + pre_out = hook_func.PreHookFunc.apply(self, hidden_state) if config["use_checkpoint"]: - out = checkpoint(self._module, pre_out, *args) + out = checkpoint(self._module, pre_out, *args[1:]) else: - out = self._module(pre_out, *args) + out = self._module(pre_out, *args[1:]) post_out = hook_func.PostHookFunc.apply(self, out) + if isinstance(post_out, list): return tuple(post_out) return post_out @@ -515,12 +520,14 @@ def __getitem__(self, index: Union[int, str]) -> CheckpointBlock: def forward(self, *args, return_hidden_states = False): self.return_hidden_states = return_hidden_states - outputs = args[:self.num_hidden] - others = args[self.num_hidden:] hidden_states = [] for i in range(len(self)): - outputs = self._modules[str(i)](*outputs, return_hidden_states, hidden_states, *others) - outputs = (outputs,) + if return_hidden_states: + self._modules[str(i)].return_hidden_states = return_hidden_states + self._modules[str(i)].hidden_states = hidden_states + outputs = self._modules[str(i)]._call_impl(*args) + outputs = (outputs, ) + args = outputs + args[1:] if return_hidden_states: hidden_states = [ diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index c0925f28..e86a4f62 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -89,8 +89,8 @@ def forward(ctx, module, x, return_hidden_states=False, hidden_states=[]): pipe_out = pipe_pre_forward(module, (x,)) x = pipe_out[0] if pipe_out is not None else x - if return_hidden_states: - hidden_states.append(x) + if module.return_hidden_states: 
+ module.hidden_states.append(x) zero_pre_forward(module, x) return x diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 32b8e733..2dbd7e00 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -175,7 +175,10 @@ def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=Fa micro_hidden_states = [] for idx,layer_id in enumerate(self.layer_ids): self._modules[str(layer_id)]._micro_idx = micro_idx - hidden_state = self._modules[str(layer_id)](hidden_state, return_hidden_states, micro_hidden_states, *arg) + if return_hidden_states: + self._modules[str(layer_id)].return_hidden_states = return_hidden_states + self._modules[str(layer_id)].hidden_states = micro_hidden_states + hidden_state = self._modules[str(layer_id)](hidden_state, *arg) outputs.append(hidden_state) if return_hidden_states: hidden_states.append(torch.stack(micro_hidden_states, dim=0)) From b7d1c8c9609bcbc4948a83749ea41477704f2800 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 7 Aug 2023 20:58:41 +0800 Subject: [PATCH 032/122] fix test --- bmtrain/block_layer.py | 23 +++++++++++++---------- bmtrain/hook_func.py | 22 ++++++++++++---------- tests/test_has_inf_nan.py | 6 +++--- tests/test_model_wrapper.py | 4 ++-- 4 files changed, 30 insertions(+), 25 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 736c4feb..9a1730f5 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -205,19 +205,21 @@ def set_pre_module(self, module): def forward(self, *args): #input must be requires_grad, otherwise autograd.backward will make an error - hidden_state = args[0] - self.input_requires_grad = hidden_state.requires_grad - hidden_state.requires_grad_() - - pre_out = hook_func.PreHookFunc.apply(self, hidden_state) + args[0].requires_grad_() + pre_out = hook_func.PreHookFunc.apply(self, *args) if config["use_checkpoint"]: - out = checkpoint(self._module, pre_out, *args[1:]) + out = checkpoint(self._module, *pre_out) else: - out = self._module(pre_out, *args[1:]) - post_out = hook_func.PostHookFunc.apply(self, out) + out = self._module(*pre_out) + tuple_out = (out, ) if isinstance(out, torch.Tensor) else out + post_out = hook_func.PostHookFunc.apply(self, *tuple_out) + if isinstance(out, torch.Tensor): + return post_out[0] if isinstance(post_out, list): return tuple(post_out) +# if isinstance(post_out, tuple) and len(post_out) == 1: +# return post_out[0] return post_out def __getattr__(self,name:str): @@ -526,8 +528,9 @@ def forward(self, *args, return_hidden_states = False): self._modules[str(i)].return_hidden_states = return_hidden_states self._modules[str(i)].hidden_states = hidden_states outputs = self._modules[str(i)]._call_impl(*args) - outputs = (outputs, ) - args = outputs + args[1:] + if not isinstance(outputs, tuple): + outputs = (outputs, ) + args = outputs + args[self.num_hidden:] if return_hidden_states: hidden_states = [ diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index e86a4f62..e17a86a4 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -83,27 +83,27 @@ def pipe_post_backward(module, grad_inputs, grad_outputs): class PreHookFunc(torch.autograd.Function): @staticmethod - def forward(ctx, module, x, return_hidden_states=False, hidden_states=[]): + def forward(ctx, module, *x): ctx.module = module if module._mode == "PIPE": - pipe_out = pipe_pre_forward(module, (x,)) - x = pipe_out[0] if pipe_out is not None else x + pipe_out = pipe_pre_forward(module, x) + x = pipe_out if pipe_out is not None else x if 
module.return_hidden_states: - module.hidden_states.append(x) + module.hidden_states.append(x[0]) zero_pre_forward(module, x) return x @staticmethod - def backward(ctx, grads): + def backward(ctx, *grads): zero_post_backward(ctx.module, grads, None) if ctx.module._mode == "PIPE": pipe_post_backward(ctx.module, grads, None) - return None, grads, None, None + return None, *grads class PostHookFunc(torch.autograd.Function): @staticmethod - def forward(ctx, module, out): + def forward(ctx, module, *out): ctx.module = module zero_post_forward(module, None, out) if module._mode == "PIPE": @@ -111,9 +111,11 @@ def forward(ctx, module, out): return out @staticmethod - def backward(ctx, grads): + def backward(ctx, *grads): zero_pre_backward(ctx.module, grads) if ctx.module._mode == "PIPE": - pipe_grads = pipe_pre_backward(ctx.module, (grads, )) + pipe_grads = pipe_pre_backward(ctx.module, grads) grads = pipe_grads[0] if pipe_grads is not None else grads - return None, grads + if not isinstance(grads, tuple): + return None, grads + return None, *grads diff --git a/tests/test_has_inf_nan.py b/tests/test_has_inf_nan.py index b1b9b4a9..fda85515 100644 --- a/tests/test_has_inf_nan.py +++ b/tests/test_has_inf_nan.py @@ -1,12 +1,12 @@ from utils import * import torch -import bmtrain.optim._cuda as G +import bmtrain.loss._function as F import random def check(x, v): out = torch.zeros(1, dtype=torch.uint8, device="cuda")[0] - G.f_has_inf_nan(x, out) + F.has_inf_nan(x, out) assert_eq(out.item(), v) def test_main(): @@ -29,4 +29,4 @@ def test_main(): check(x, 1) if __name__ == "__main__": - test_main() \ No newline at end of file + test_main() diff --git a/tests/test_model_wrapper.py b/tests/test_model_wrapper.py index 409107e3..6f913d3c 100644 --- a/tests/test_model_wrapper.py +++ b/tests/test_model_wrapper.py @@ -164,7 +164,7 @@ def forward(self, out = input_emb for layer in self.transformers: - out = layer(out, position_bias=None, mask=mask_2d) + out = layer(out, mask_2d) out = self.layernorm(out) logits = F.linear(out, self.word_emb.weight) / math.sqrt(self.dim_model) @@ -218,4 +218,4 @@ def test_main(): if __name__ == '__main__': bmt.init_distributed(seed=0) - test_main() \ No newline at end of file + test_main() From 4303575dbf9929914d2c1f1fde2e05d0469b8d5c Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Tue, 8 Aug 2023 16:49:29 +0800 Subject: [PATCH 033/122] CheckpointBlock -> BMTBlock --- bmtrain/__init__.py | 4 ++- bmtrain/block_layer.py | 35 +++++++++++++------------ bmtrain/pipe_layer.py | 58 +++++++++++++++++++++--------------------- 3 files changed, 49 insertions(+), 48 deletions(-) diff --git a/bmtrain/__init__.py b/bmtrain/__init__.py index 7c7d6c2c..71d3c13a 100644 --- a/bmtrain/__init__.py +++ b/bmtrain/__init__.py @@ -6,7 +6,7 @@ from .param_init import init_parameters, grouped_parameters from .utils import print_block, print_dict, print_rank, see_memory from .synchronize import synchronize, sum_loss, wait_loader, gather_result -from .block_layer import CheckpointBlock, TransformerBlockList +from .block_layer import ZeroBlock, TransformerBlockList from .wrapper import BMTrainModelWrapper from .pipe_layer import PipelineTransformerBlockList from . import debug @@ -18,3 +18,5 @@ from . import lr_scheduler from . import loss from . 
import distributed + +CheckpointBlock = BMTBlock diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 9a1730f5..f64f70a1 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -50,8 +50,8 @@ def _get_param_kw(param : DistributedParameter): group_name = "_g_" + param.group return type_name + grad_name + group_name -class CheckpointBlock(torch.nn.Module): - """ Checkpoint a model or part of the model. +class BMTBlock(torch.nn.Module): + """ A bmtrain block containing two memory-saving methods of ZeRO-2/3 and checkpoint. Checkpoint block is used to save the occupation of GPU memory in training. @@ -59,16 +59,16 @@ class CheckpointBlock(torch.nn.Module): Args: model (torch.nn.Module): The model to be checkpointed. All kinds of modules are supported. - mode (str): run in BLOCK or ZERO or PIPE + use_checkpoint (boolean): use checkpoint or not. Default True. Examples: >>> transformer_block = TransformerBlock(...) - >>> checkpoint_block = CheckpointBlock(transformer_block) + >>> checkpoint_block = BMTBlock(transformer_block) >>> y1, ... = checkpoint_block(x) >>> y2, ... = transformer_block(x) >>> assert torch.allclose(y1, y2) """ - def __init__(self, inner_module : torch.nn.Module, mode="BLOCK"): + def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True): super().__init__() self._module = inner_module self._inputs = None @@ -191,11 +191,12 @@ def __init__(self, inner_module : torch.nn.Module, mode="BLOCK"): for kw in offsets.keys(): assert offsets[kw] == self._storage_info[kw]["total"] + self.use_checkpoint = use_checkpoint self._is_first_layer = True self._is_last_layer = True self._pre_module = None self._next_module = None - self._mode = mode #BLOCK or ZERO or PIPE + self._mode = "BLOCK" #BLOCK or ZERO or PIPE self.return_hidden_states = False self.hidden_states = [] @@ -207,19 +208,17 @@ def forward(self, *args): #input must be requires_grad, otherwise autograd.backward will make an error args[0].requires_grad_() pre_out = hook_func.PreHookFunc.apply(self, *args) - if config["use_checkpoint"]: + if self.use_checkpoint: out = checkpoint(self._module, *pre_out) else: out = self._module(*pre_out) tuple_out = (out, ) if isinstance(out, torch.Tensor) else out post_out = hook_func.PostHookFunc.apply(self, *tuple_out) - if isinstance(out, torch.Tensor): + if isinstance(out, torch.Tensor) and isinstance(post_out, tuple): return post_out[0] if isinstance(post_out, list): return tuple(post_out) -# if isinstance(post_out, tuple) and len(post_out) == 1: -# return post_out[0] return post_out def __getattr__(self,name:str): @@ -239,7 +238,7 @@ def __delattr__(self, name): object.__delattr__(self, name) def _save_to_state_dict(self, destination, prefix, keep_vars): - raise RuntimeError("._save_to_state_dict() of CheckpointBlock should not be called") + raise RuntimeError("._save_to_state_dict() of BMTBlock should not be called") def state_dict(self, destination=None, prefix='', keep_vars=False): # gather here @@ -452,7 +451,7 @@ def __repr__(self): class TransformerBlockList(torch.nn.Module): r""" - TransformerBlockList is a list of CheckpointBlocks. + TransformerBlockList is a list of BMTBlocks. This is designed to reduce the communication overhead by overlapping the computation and reduce_scatter operation during backward pass. @@ -469,15 +468,15 @@ class TransformerBlockList(torch.nn.Module): >>> hidden_state = transformer_module_list(hidden_state, ...) 
""" - _modules: Dict[str, CheckpointBlock] + _modules: Dict[str, BMTBlock] - def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) -> None: + def __init__(self, modules: Iterable[BMTBlock], num_hidden=1, sqrt=False) -> None: super().__init__() self._modules = {} for i, module in enumerate(modules): - if not isinstance(module, CheckpointBlock): - module = CheckpointBlock(module, "ZERO") + if not isinstance(module, BMTBlock): + module = BMTBlock(module) module._mode = "ZERO" module._is_last_layer = True if i == len(modules) -1 else False @@ -515,9 +514,9 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) def __len__(self) -> int: return len(self._modules) - def __iter__(self) -> Iterator[CheckpointBlock]: + def __iter__(self) -> Iterator[BMTBlock]: return iter(self._modules.values()) - def __getitem__(self, index: Union[int, str]) -> CheckpointBlock: + def __getitem__(self, index: Union[int, str]) -> BMTBlock: return self._modules[str(index)] def forward(self, *args, return_hidden_states = False): diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 2dbd7e00..f357b066 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -13,21 +13,37 @@ CheckpointBlockContext ) from . import debug -from .block_layer import CheckpointBlock, round_up, _get_param_kw -from . import hook_func +from .block_layer import BMTBlock, round_up, _get_param_kw -class PipeAllGatherFunction(torch.autograd.Function): +class PipePreFunction(torch.autograd.Function): @staticmethod - def forward(ctx, hidden_state): + def forward(ctx, hidden_state, *args): hidden_state_list = all_gather(hidden_state.clone(), config["pipe_comm"]) hidden_state_list.requires_grad_() - return hidden_state_list + + batch_size = hidden_state.shape[0] + num_micros = config["micros"] + args_list = [[] for _ in range(num_micros)] + for arg in args: + if torch.is_tensor(arg): + arg_all = all_gather(arg, config['pipe_comm']) + if arg.shape[0] == batch_size: + arg_all = arg_all.flatten(0, 1).chunk(num_micros, dim=0) + arg_all = [tensor.detach().requires_grad_(arg.requires_grad) for tensor in arg_all] + else: + arg_all = [arg_all[0].detach().requires_grad_(arg.requires_grad) for i in range(num_micros)] + else: + arg_all = [arg for _ in range(num_micros)] + for i in range(num_micros): + args_list[i].append(arg_all[i]) + + return hidden_state_list, args_list @staticmethod def backward(ctx, grads): grads = broadcast(grads, 0, config['pipe_comm']) topo = config['topology'] - return grads.chunk(topo.stages, dim=0)[topo.stage_id] + return grads.chunk(topo.stages, dim=0)[topo.stage_id], None class PipePostFunction(torch.autograd.Function): @staticmethod @@ -77,7 +93,6 @@ def backward(ctx, grads, grad_middle=None): else: return grad_list - class PipelineTransformerBlockList(torch.nn.Module): r""" TransformerBlockList is a list of CheckpointBlocks. @@ -97,9 +112,9 @@ class PipelineTransformerBlockList(torch.nn.Module): >>> hidden_state = transformer_module_list(hidden_state, ...) 
""" - _modules: Dict[str, CheckpointBlock] + _modules: Dict[str, BMTBlock] - def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: + def __init__(self, modules: Iterable[BMTBlock], num_hidden=1) -> None: super().__init__() self.num_hidden = num_hidden self._modules = {} @@ -111,8 +126,8 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: self.stage_id = topo.stage_id self.pipe_idx = topo.pipe_idx for idx, module in enumerate(modules): - if not isinstance(module, CheckpointBlock): - module = CheckpointBlock(module, "PIPE") + if not isinstance(module, BMTBlock): + module = BMTBlock(module) module._mode = "PIPE" module.stage_id = self.stage_id @@ -141,32 +156,18 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: def __len__(self) -> int: return len(self._modules) - def __iter__(self) -> Iterator[CheckpointBlock]: + def __iter__(self) -> Iterator[BMTBlock]: return iter(self._modules.values()) - def __getitem__(self, index: Union[int, str]) -> CheckpointBlock: + def __getitem__(self, index: Union[int, str]) -> BMTBlock: return self._modules[str(index)] def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=False): self.return_hidden_states = return_hidden_states batch_size = hidden_state.shape[0] num_micros = config["micros"] - hidden_state_list = PipeAllGatherFunction.apply(hidden_state) + hidden_state_list, args_list = PipePreFunction.apply(hidden_state) - args_list = [[] for _ in range(num_micros)] - for arg in args: - if torch.is_tensor(arg): - arg_all = all_gather(arg, config['pipe_comm']) - if arg.shape[0] == batch_size: - arg_all = arg_all.flatten(0, 1).chunk(num_micros, dim=0) - arg_all = [tensor.detach().requires_grad_(arg.requires_grad) for tensor in arg_all] - else: - arg_all = [arg_all[0].detach().requires_grad_(arg.requires_grad) for i in range(num_micros)] - else: - arg_all = [arg for _ in range(num_micros)] - for i in range(num_micros): - args_list[i].append(arg_all[i]) - hidden_state_list = hidden_state_list.flatten(0, 1).chunk(num_micros, dim=0) outputs = [] hidden_states = [] @@ -199,7 +200,6 @@ def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=Fa outputs = PipePostFunction.apply(last_hidden) return outputs - def get_range_by_stage_id(self, stage_id : int) -> List[int]: part_lens = [0]+[self.get_part_len_by_stage_id(i) for i in range(stage_id+1)] start = sum(part_lens[:stage_id+1]) From 8061b668868ea3eaf312c1a77d312146123c5362 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Tue, 8 Aug 2023 20:24:00 +0800 Subject: [PATCH 034/122] reset block name --- bmtrain/__init__.py | 3 +-- bmtrain/block_layer.py | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/bmtrain/__init__.py b/bmtrain/__init__.py index 71d3c13a..35078b7e 100644 --- a/bmtrain/__init__.py +++ b/bmtrain/__init__.py @@ -6,7 +6,7 @@ from .param_init import init_parameters, grouped_parameters from .utils import print_block, print_dict, print_rank, see_memory from .synchronize import synchronize, sum_loss, wait_loader, gather_result -from .block_layer import ZeroBlock, TransformerBlockList +from .block_layer import CheckpointBlock, TransformerBlockList from .wrapper import BMTrainModelWrapper from .pipe_layer import PipelineTransformerBlockList from . import debug @@ -19,4 +19,3 @@ from . import loss from . 
import distributed -CheckpointBlock = BMTBlock diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index f64f70a1..6d58c65c 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -50,7 +50,7 @@ def _get_param_kw(param : DistributedParameter): group_name = "_g_" + param.group return type_name + grad_name + group_name -class BMTBlock(torch.nn.Module): +class CheckpointBlock(torch.nn.Module): """ A bmtrain block containing two memory-saving methods of ZeRO-2/3 and checkpoint. Checkpoint block is used to save the occupation of GPU memory in training. @@ -63,7 +63,7 @@ class BMTBlock(torch.nn.Module): Examples: >>> transformer_block = TransformerBlock(...) - >>> checkpoint_block = BMTBlock(transformer_block) + >>> checkpoint_block = CheckpointBlock(transformer_block) >>> y1, ... = checkpoint_block(x) >>> y2, ... = transformer_block(x) >>> assert torch.allclose(y1, y2) @@ -238,7 +238,7 @@ def __delattr__(self, name): object.__delattr__(self, name) def _save_to_state_dict(self, destination, prefix, keep_vars): - raise RuntimeError("._save_to_state_dict() of BMTBlock should not be called") + raise RuntimeError("._save_to_state_dict() of CheckpointBlock should not be called") def state_dict(self, destination=None, prefix='', keep_vars=False): # gather here @@ -451,7 +451,7 @@ def __repr__(self): class TransformerBlockList(torch.nn.Module): r""" - TransformerBlockList is a list of BMTBlocks. + TransformerBlockList is a list of CheckpointBlocks. This is designed to reduce the communication overhead by overlapping the computation and reduce_scatter operation during backward pass. @@ -468,15 +468,15 @@ class TransformerBlockList(torch.nn.Module): >>> hidden_state = transformer_module_list(hidden_state, ...) """ - _modules: Dict[str, BMTBlock] + _modules: Dict[str, CheckpointBlock] - def __init__(self, modules: Iterable[BMTBlock], num_hidden=1, sqrt=False) -> None: + def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) -> None: super().__init__() self._modules = {} for i, module in enumerate(modules): - if not isinstance(module, BMTBlock): - module = BMTBlock(module) + if not isinstance(module, CheckpointBlock): + module = CheckpointBlock(module) module._mode = "ZERO" module._is_last_layer = True if i == len(modules) -1 else False @@ -514,9 +514,9 @@ def __init__(self, modules: Iterable[BMTBlock], num_hidden=1, sqrt=False) -> Non def __len__(self) -> int: return len(self._modules) - def __iter__(self) -> Iterator[BMTBlock]: + def __iter__(self) -> Iterator[CheckpointBlock]: return iter(self._modules.values()) - def __getitem__(self, index: Union[int, str]) -> BMTBlock: + def __getitem__(self, index: Union[int, str]) -> CheckpointBlock: return self._modules[str(index)] def forward(self, *args, return_hidden_states = False): From 845f210ecc1a54d29a693c09a10da924f1750dcf Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Tue, 8 Aug 2023 20:24:19 +0800 Subject: [PATCH 035/122] pipeline support batch_related --- bmtrain/pipe_layer.py | 60 +++++++++++++++++++++++++++++++++---------- 1 file changed, 46 insertions(+), 14 deletions(-) diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index f357b066..2759f662 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -13,7 +13,7 @@ CheckpointBlockContext ) from . 
import debug -from .block_layer import BMTBlock, round_up, _get_param_kw +from .block_layer import CheckpointBlock, round_up, _get_param_kw class PipePreFunction(torch.autograd.Function): @staticmethod @@ -21,29 +21,60 @@ def forward(ctx, hidden_state, *args): hidden_state_list = all_gather(hidden_state.clone(), config["pipe_comm"]) hidden_state_list.requires_grad_() + batch_related = args[-1] + batch_related_origin = [True if i in args[-1] else False for i in range(len(args[:-1]))] + batch_related_rule = [] + args = args[:-1] + batch_size = hidden_state.shape[0] num_micros = config["micros"] args_list = [[] for _ in range(num_micros)] + input_requires_grad = [] for arg in args: if torch.is_tensor(arg): arg_all = all_gather(arg, config['pipe_comm']) - if arg.shape[0] == batch_size: + if arg.dim() == hidden_state.dim() and arg.shape[0] == batch_size: + batch_related_rule.append(True) arg_all = arg_all.flatten(0, 1).chunk(num_micros, dim=0) - arg_all = [tensor.detach().requires_grad_(arg.requires_grad) for tensor in arg_all] + arg_all = [tensor.requires_grad_(arg.requires_grad) for tensor in arg_all] else: - arg_all = [arg_all[0].detach().requires_grad_(arg.requires_grad) for i in range(num_micros)] + batch_related_rule.append(False) + arg_all = [arg_all[0].requires_grad_(arg.requires_grad) for i in range(num_micros)] + input_requires_grad.append(arg.requires_grad) else: + batch_related_rule.append(False) arg_all = [arg for _ in range(num_micros)] + input_requires_grad.append(False) for i in range(num_micros): args_list[i].append(arg_all[i]) - + ctx.input_requires_grad = input_requires_grad + ctx.args_list = args_list + if len(batch_related) == 0: + ctx.batch_related = batch_related_rule + else: + ctx.batch_related = batch_related_origin return hidden_state_list, args_list @staticmethod - def backward(ctx, grads): + def backward(ctx, grads, arg_grads): grads = broadcast(grads, 0, config['pipe_comm']) topo = config['topology'] - return grads.chunk(topo.stages, dim=0)[topo.stage_id], None + arg_grads = [] + num_micros = config['micros'] + for idx,requires_grad in enumerate(ctx.input_requires_grad): + if requires_grad: + grad = torch.cat([ctx.args_list[m][idx].grad for m in range(num_micros)], dim=0) + grad = all_reduce(grad, "sum", config["pipe_comm"]) + split_size = topo.stages if ctx.batch_related[idx] else num_micros + grad = grad.chunk(split_size) + if ctx.batch_related[idx]: + arg_grads.append(grad[topo.stage_id]) + else: + arg_grads.append(grad[0]) + else: + arg_grads.append(None) + arg_grads.append(None) #for append(batch_related) + return grads.chunk(topo.stages, dim=0)[topo.stage_id], *arg_grads class PipePostFunction(torch.autograd.Function): @staticmethod @@ -112,9 +143,9 @@ class PipelineTransformerBlockList(torch.nn.Module): >>> hidden_state = transformer_module_list(hidden_state, ...) 
""" - _modules: Dict[str, BMTBlock] + _modules: Dict[str, CheckpointBlock] - def __init__(self, modules: Iterable[BMTBlock], num_hidden=1) -> None: + def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: super().__init__() self.num_hidden = num_hidden self._modules = {} @@ -126,8 +157,8 @@ def __init__(self, modules: Iterable[BMTBlock], num_hidden=1) -> None: self.stage_id = topo.stage_id self.pipe_idx = topo.pipe_idx for idx, module in enumerate(modules): - if not isinstance(module, BMTBlock): - module = BMTBlock(module) + if not isinstance(module, CheckpointBlock): + module = CheckpointBlock(module) module._mode = "PIPE" module.stage_id = self.stage_id @@ -156,17 +187,18 @@ def __init__(self, modules: Iterable[BMTBlock], num_hidden=1) -> None: def __len__(self) -> int: return len(self._modules) - def __iter__(self) -> Iterator[BMTBlock]: + def __iter__(self) -> Iterator[CheckpointBlock]: return iter(self._modules.values()) - def __getitem__(self, index: Union[int, str]) -> BMTBlock: + def __getitem__(self, index: Union[int, str]) -> CheckpointBlock: return self._modules[str(index)] def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=False): self.return_hidden_states = return_hidden_states batch_size = hidden_state.shape[0] num_micros = config["micros"] - hidden_state_list, args_list = PipePreFunction.apply(hidden_state) + args = args + (batch_related, ) + hidden_state_list, args_list = PipePreFunction.apply(hidden_state, *args) hidden_state_list = hidden_state_list.flatten(0, 1).chunk(num_micros, dim=0) outputs = [] From 0b14fe5c8f946f821d79a8aa00cfd1e453de3bf3 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 9 Aug 2023 11:19:34 +0800 Subject: [PATCH 036/122] remove use_checkpoint from init_distributed --- bmtrain/init.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/bmtrain/init.py b/bmtrain/init.py index 4cbdafd2..5c3006d2 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -13,7 +13,6 @@ def init_distributed( zero_level: int = 3, pipe_size: int = -1, num_micro_batches: int = None, - use_checkpoint: bool = True, ): """Initialize distributed training. This function will initialize the distributed training, set the random seed and global configurations. 
@@ -71,7 +70,6 @@ def init_distributed( config['barrier_stream'] = torch.cuda.Stream() config["load_event"] = torch.cuda.Event() config["zero_level"] = zero_level - config["use_checkpoint"] = use_checkpoint config["topology"] = topology(config) config["zero_rank"] = config["topology"].get_group_rank("zero") if pipe_size > 1 else config['rank'] cpus_this_worker = None From 12e51e1a183da643cfdbcfed7c81a0c83bf75dd7 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 10 Aug 2023 10:22:25 +0800 Subject: [PATCH 037/122] test --- example/layers/attention.py | 240 +++++++--- example/layers/feedforward.py | 6 +- example/layers/flash_triton.py | 830 +++++++++++++++++++++++++++++++++ example/layers/test_attn.py | 45 ++ example/layers/test_linear.py | 152 ++++++ 5 files changed, 1199 insertions(+), 74 deletions(-) create mode 100644 example/layers/flash_triton.py create mode 100644 example/layers/test_attn.py create mode 100644 example/layers/test_linear.py diff --git a/example/layers/attention.py b/example/layers/attention.py index 4a0eec11..c5485575 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -5,82 +5,178 @@ import math class Attention(bmt.DistributedModule): - def __init__(self, - dim_model : int, dim_head : int, - num_heads : int, bias : bool = True, - dtype = None - ) -> None: + def __init__( + self, + dim_model: int, + num_heads: int, + dim_head: int, + dtype: torch.dtype = torch.half, + dropout_p: Optional[float] = None, + scale: bool = True, + use_flash_attn: bool = False, + ) -> None: super().__init__() - self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - - self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) - - self.softmax = torch.nn.Softmax(dim=-1) + self.dim_model = dim_model self.num_heads = num_heads + self.num_kv_heads = num_heads + self.head_groups = num_heads // num_kv_heads self.dim_head = dim_head - self.dim_model = dim_model - - def forward(self, - hidden_q : torch.Tensor, # (batch_size, seq_q, dim_model) - hidden_kv : torch.Tensor, # (batch_size, seq_kv, dim_model) - mask : torch.BoolTensor, # (batch_size, seq_q, seq_kv) - position_bias : Optional[torch.Tensor] = None, # (batch, num_heads, seq_q, seq_kv) - ) -> torch.Tensor: - batch_size, seq_q, dim_model = hidden_q.size() - seq_kv = hidden_kv.size(1) - - h_q : torch.Tensor = self.project_q(hidden_q) - h_k : torch.Tensor = self.project_k(hidden_kv) - h_v : torch.Tensor = self.project_v(hidden_kv) - - h_q = h_q.view(batch_size, seq_q, self.num_heads, self.dim_head) - h_k = h_k.view(batch_size, seq_kv, self.num_heads, self.dim_head) - h_v = h_v.view(batch_size, seq_kv, self.num_heads, self.dim_head) - - h_q = h_q.permute(0, 2, 1, 3).contiguous() - h_k = h_k.permute(0, 2, 1, 3).contiguous() - h_v = h_v.permute(0, 2, 1, 3).contiguous() - - h_q = h_q.view(batch_size * self.num_heads, seq_q, self.dim_head) - h_k = h_k.view(batch_size * self.num_heads, seq_kv, self.dim_head) - h_v = h_v.view(batch_size * self.num_heads, seq_kv, self.dim_head) - - score = torch.bmm( - h_q, h_k.transpose(1, 2) - ) - score = score / math.sqrt(self.dim_head) - - score = score.view(batch_size, self.num_heads, seq_q, seq_kv) - - if position_bias is not None: - score = score + position_bias.view(batch_size, self.num_heads, seq_q, seq_kv) - - score = torch.where( - mask.view(batch_size, 1, seq_q, 
seq_kv), - score, - torch.scalar_tensor(float('-inf'), device=score.device, dtype=score.dtype) - ) - - score = torch.where( - mask.view(batch_size, 1, seq_q, seq_kv), - self.softmax(score), - torch.scalar_tensor(0, device=score.device, dtype=score.dtype) - ) - - score = score.view(batch_size * self.num_heads, seq_q, seq_kv) - - h_out = torch.bmm( - score, h_v - ) - h_out = h_out.view(batch_size, self.num_heads, seq_q, self.dim_head) - h_out = h_out.permute(0, 2, 1, 3).contiguous() - h_out = h_out.view(batch_size, seq_q, self.num_heads * self.dim_head) - - attn_out = self.project_out(h_out) - return attn_out + + self.project_q = Linear(self.dim_model, self.num_heads * self.dim_head, dtype=dtype, scale=scale) + self.project_k = Linear(self.dim_model, self.num_kv_heads * self.dim_head, dtype=dtype, scale=scale) + self.project_v = Linear(self.dim_model, self.num_kv_heads * self.dim_head, dtype=dtype, scale=scale) + + self.attention_out = Linear(self.num_heads * self.dim_head, self.dim_model, dtype=dtype, scale=scale) + + + if dropout_p is not None: + self.dropout = torch.nn.Dropout(p=dropout_p) + self.dropout_p = dropout_p + else: + self.dropout = None + + # if use_flash_attn: + # self.core_attention_flash = FlashSelfAttention(causal=False, attention_dropout=0.0) + self.use_flash_attn = use_flash_attn + + def forward( + self, + hidden_q: torch.Tensor, + hidden_kv: torch.Tensor, + attention_mask: torch.BoolTensor, + position_bias: torch.Tensor, + ): + """This model inherits from bmt.DistributedModule. + Args: + hidden_q (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): Indices of input sequence tokens. It will be embedded by model's internal embedding lookup matrix. + hidden_kv (:obj:`torch.Tensor` of shape ``(batch, len_k, dim_model)``): Length of input sequence before padding. + attention_mask (:obj:`torch.Tensor` of shape ``(batch, len_q, len_k)``): Used to avoid performing attention on padding token indices. + position_bias(:obj:`torch.Tensor` of shape ``(num_heads, len_q, len_k)`` or ``(1, num_heads, len_k, len_q)``): Provide positional information about tensor `key_value` and `query`. + Return: + out (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): The attention output. 
+ """ # noqa: E501 + + batch_size = hidden_q.size(0) + len_q = hidden_q.size(1) + len_k = hidden_kv.size(1) + + h_q = self.project_q(hidden_q) + h_k = self.project_k(hidden_kv) + h_v = self.project_v(hidden_kv) + + # + # if self.head_groups != 1: + # h_k = h_k[:, :, :, None, :].expand(batch_size, len_k, self.num_kv_heads, self.head_groups, self.dim_head).reshape(batch_size, len_k, self.num_heads, self.dim_head) + # h_v = h_v[:, :, :, None, :].expand(batch_size, len_k, self.num_kv_heads, self.head_groups, self.dim_head).reshape(batch_size, len_k, self.num_heads, self.dim_head) + + + # h_q = h_q.permute(0, 2, 1, 3).contiguous() + # h_k = h_k.permute(0, 2, 1, 3).contiguous() + # h_v = h_v.permute(0, 2, 1, 3).contiguous() + + # B, S, H, D + # score = self.core_attention_flash( + # h_q, h_k, h_v, attention_mask=attention_mask, length_mask=length_mask, context_mask=context_mask + # ) + if attention_mask is not None: + assert pos_bias_type == "rotary" + h_q, h_k = position_bias(h_q, h_k, -3) + h_q = h_q.view(batch_size, len_q, self.num_heads, self.dim_head) # .permute(0, 2, 1, 3) + h_k = h_k.view(batch_size, len_k, self.num_kv_heads, self.dim_head) # .permute(0, 2, 1, 3) + h_v = h_v.view(batch_size, len_k, self.num_kv_heads, self.dim_head) # .permute(0, 2, 1, 3) + mask = attention_mask.unsqueeze(dim=1).contiguous() + mask_bias = torch.zeros_like(mask, device="cuda", dtype=torch.float16) # 创建与mask形状相同的全零张量 + mask_bias[mask == False] -= torch.inf + score = FlashAttnFunc.apply(h_q, h_k, h_v, mask_bias, False, None) + + score = score.view(batch_size, len_q, self.num_heads * self.dim_head) + + + score = self.attention_out(score) + + if use_cache: + return score, (h_k, h_v) + else: + return score + +#class Attention(bmt.DistributedModule): +# def __init__(self, +# dim_model : int, dim_head : int, +# num_heads : int, bias : bool = True, +# dtype = None +# ) -> None: +# super().__init__() +# +# self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) +# self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) +# self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) +# +# self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) +# +# self.softmax = torch.nn.Softmax(dim=-1) +# self.num_heads = num_heads +# self.dim_head = dim_head +# self.dim_model = dim_model +# +# def forward(self, +# hidden_q : torch.Tensor, # (batch_size, seq_q, dim_model) +# hidden_kv : torch.Tensor, # (batch_size, seq_kv, dim_model) +# mask : torch.BoolTensor, # (batch_size, seq_q, seq_kv) +# position_bias : Optional[torch.Tensor] = None, # (batch, num_heads, seq_q, seq_kv) +# ) -> torch.Tensor: +# batch_size, seq_q, dim_model = hidden_q.size() +# seq_kv = hidden_kv.size(1) +# +# h_q : torch.Tensor = self.project_q(hidden_q) +# h_k : torch.Tensor = self.project_k(hidden_kv) +# h_v : torch.Tensor = self.project_v(hidden_kv) +# +# h_q = h_q.view(batch_size, seq_q, self.num_heads, self.dim_head) +# h_k = h_k.view(batch_size, seq_kv, self.num_heads, self.dim_head) +# h_v = h_v.view(batch_size, seq_kv, self.num_heads, self.dim_head) +# +# h_q = h_q.permute(0, 2, 1, 3).contiguous() +# h_k = h_k.permute(0, 2, 1, 3).contiguous() +# h_v = h_v.permute(0, 2, 1, 3).contiguous() +# +# h_q = h_q.view(batch_size * self.num_heads, seq_q, self.dim_head) +# h_k = h_k.view(batch_size * self.num_heads, seq_kv, self.dim_head) +# h_v = h_v.view(batch_size * self.num_heads, seq_kv, self.dim_head) +# +# score = torch.bmm( +# h_q, h_k.transpose(1, 2) +# ) +# 
score = score / math.sqrt(self.dim_head) +# +# score = score.view(batch_size, self.num_heads, seq_q, seq_kv) +# +# if position_bias is not None: +# score = score + position_bias.view(batch_size, self.num_heads, seq_q, seq_kv) +# +# score = torch.where( +# mask.view(batch_size, 1, seq_q, seq_kv), +# score, +# torch.scalar_tensor(float('-inf'), device=score.device, dtype=score.dtype) +# ) +# +# score = torch.where( +# mask.view(batch_size, 1, seq_q, seq_kv), +# self.softmax(score), +# torch.scalar_tensor(0, device=score.device, dtype=score.dtype) +# ) +# +# score = score.view(batch_size * self.num_heads, seq_q, seq_kv) +# +# h_out = torch.bmm( +# score, h_v +# ) +# h_out = h_out.view(batch_size, self.num_heads, seq_q, self.dim_head) +# h_out = h_out.permute(0, 2, 1, 3).contiguous() +# h_out = h_out.view(batch_size, seq_q, self.num_heads * self.dim_head) +# +# attn_out = self.project_out(h_out) +# return attn_out diff --git a/example/layers/feedforward.py b/example/layers/feedforward.py index 3fe935bf..b0bfd94c 100644 --- a/example/layers/feedforward.py +++ b/example/layers/feedforward.py @@ -8,9 +8,11 @@ def __init__(self, dim_model : int, dim_ff : int, bias : bool = True, dtype = No self.w_in = Linear(dim_model, dim_ff, bias = bias, dtype=dtype) self.w_out = Linear(dim_ff, dim_model, bias = bias, dtype=dtype) + self.gate = Linear(dim_model, dim_ff, bias = bias, dtype=dtype) self.relu = torch.nn.ReLU() def forward(self, input : torch.Tensor) -> torch.Tensor: - - return self.w_out(self.relu(self.w_in(input))) +#return self.w_out(self.relu(self.w_in(input))) + gate_out = self.relu(self.gate(input)) + return self.w_out(self.w_in(input) * gate_out) diff --git a/example/layers/flash_triton.py b/example/layers/flash_triton.py new file mode 100644 index 00000000..1d687378 --- /dev/null +++ b/example/layers/flash_triton.py @@ -0,0 +1,830 @@ +""" +*Experimental* implementation of FlashAttention in Triton. + +We use the FlashAttention implementation from Phil Tillet a starting point. +https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py + +Changes: +- Implement both causal and non-causal attention. +- Implement both self-attention and cross-attention. +- Support arbitrary seqlens (not just multiples of 128), for both forward and backward. +- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward. +- Support attention bias. +- Speed up the forward pass a bit, and only store the LSE instead of m and l. +- Make the backward for d=128 much faster by reducing register spilling. +- Optionally parallelize the backward pass across seqlen_k, to deal with the case of +small batch size * nheads. + +Caution: +- This is an *experimental* implementation. The forward pass should be quite robust but +I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler). +- This implementation has only been tested on A100. +- If you plan to use headdim other than 64 and 128, you should test for race conditions +(due to the Triton compiler), as done in tests/test_flash_attn.py +"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions +for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident +that there are none left for other head dimensions. + +Differences between this Triton version and the CUDA version: +- Triton version doesn't support dropout. 
+- Triton forward is generally faster than CUDA forward, while Triton backward is +generally slower than CUDA backward. Overall Triton forward + backward is slightly slower +than CUDA forward + backward. +- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor). +- Triton version supports attention bias, while CUDA version doesn't. +""" + +import math + +import torch + +import triton +import triton.language as tl + + +# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128 +# @triton.autotune( +# configs=[ +# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1), +# # This config has a race condition when EVEN_M == False, disabling it for now. +# # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1), +# ], +# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'] +# ) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _fwd_kernel( + Q, K, V, Bias, Out, + Lse, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + softmax_scale, + stride_qb, stride_qh, stride_qm, + stride_kb, stride_kh, stride_kn, + stride_vb, stride_vh, stride_vn, + stride_bb, stride_bh, stride_bm, + stride_ob, stride_oh, stride_om, + nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, + CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # off_b = tl.program_id(1) + # off_h = tl.program_id(2) + # off_hb = off_b * nheads + off_h + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # Initialize pointers to Q, K, V + # Adding parenthesis around indexing might use int32 math instead of int64 math? + # https://github.com/openai/triton/issues/741 + # I'm seeing a tiny bit of difference (5-7us) + q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) + k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) + v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) + if BIAS_TYPE == 'vector': + b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n + elif BIAS_TYPE == 'matrix': + b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :]) + # initialize pointer to m and l + t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m + lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) + # load q: it will stay in SRAM throughout + # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call + # tl.load(q_ptrs), we get the wrong output! 
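    # The four load branches below differ only in boundary masking: when seqlen_q,
    # seqlen_k and headdim are exact multiples of BLOCK_M / BLOCK_N / BLOCK_HEADDIM
    # (EVEN_M / EVEN_N / EVEN_HEADDIM), the plain tl.load is used; otherwise
    # out-of-range rows (offs_m >= seqlen_q) and columns (offs_d >= headdim) are
    # zero-filled so partial edge tiles never read invalid memory.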
+ if EVEN_M & EVEN_N: + if EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) + else: + q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0) + # loop over k, v and update accumulator + end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) + for start_n in range(0, end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + k = tl.load(k_ptrs + start_n * stride_kn) + else: + k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0) + else: + k = tl.load(k_ptrs + start_n * stride_kn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, trans_b=True) + # Trying to combine the two masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) + if IS_CAUSAL: + qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) + if BIAS_TYPE != 'none': + if BIAS_TYPE == 'vector': + if EVEN_N: + bias = tl.load(b_ptrs + start_n).to(tl.float32) + else: + bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32) + bias = bias[None, :] + elif BIAS_TYPE == 'matrix': + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs + start_n).to(tl.float32) + else: + bias = tl.load(b_ptrs + start_n, + mask=(offs_m[:, None] < seqlen_q) + & ((start_n + offs_n)[None, :] < seqlen_k), + other=0.0).to(tl.float32) + # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler + # can then fuse the mult and add into an fma instruction. But if we have bias we need to + # to multiply with softmax_scale here. 
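            # Online-softmax bookkeeping for this K/V block: m_ij is the running row
            # maximum (folding in the previous lse_i), p holds the unnormalized
            # exp-scores, acc_o_scale rescales the output accumulator for the new
            # maximum, and lse_i accumulates log(sum(exp(score))) so a single
            # exp(m_i - lse_i) normalizes the output after the last block.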
+ qk = qk * softmax_scale + bias + m_ij = tl.maximum(tl.max(qk, 1), lse_i) + + m_ij = tl.where(m_ij==float("-inf"),0,m_ij) + p = tl.exp(qk - m_ij[:, None]) + else: + m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i) + p = tl.exp(qk * softmax_scale - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # p = tl.where(p==float("-inf"), 0, p) + # l_ij = tl.maximum(tl.sum(p, 1),-1e16) + # scale acc_o + acc_o_scale = tl.exp(m_i - m_ij) + # mask_sum = tl.sum(bias == float("-inf"), axis=1) == BLOCK_M + # acc_o_scale = tl.where(mask_sum, 0, acc_o_scale) + # # -- update output accumulator -- + # BUG: have to store and immediately load + tl.store(t_ptrs, acc_o_scale) + acc_o_scale = tl.load(t_ptrs) + acc_o = acc_o * acc_o_scale[:, None] + # update acc_o + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + v = tl.load(v_ptrs + start_n * stride_vn) + else: + v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0) + else: + v = tl.load(v_ptrs + start_n * stride_vn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0) + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + + # -- update statistics + m_i = m_ij + l_i_new = tl.exp(lse_i - m_ij) + l_ij + lse_i = m_ij + tl.log(l_i_new) + lse_i = tl.where(lse_i == float("-inf"), 0, lse_i) + o_scale = tl.exp(m_i - lse_i) + # BUG: have to store and immediately load + tl.store(t_ptrs, o_scale) + o_scale = tl.load(t_ptrs) + acc_o = acc_o * o_scale[:, None] + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m + tl.store(lse_ptrs, lse_i) + # initialize pointers to output + offs_d = tl.arange(0, BLOCK_HEADDIM) + out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :]) + if EVEN_M: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o) + else: + tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) + else: + tl.store(out_ptrs, acc_o, + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)) + + +@triton.jit +def _bwd_preprocess_do_o_dot( + Out, DO, Delta, + stride_ob, stride_oh, stride_om, + stride_dob, stride_doh, stride_dom, + nheads, seqlen_q, seqlen_q_rounded, headdim, + BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # load + o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32) + do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32) + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) + + +@triton.jit +def _bwd_store_dk_dv( + dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, + EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, +): + # 
[2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False, + # if we just call tl.store(dv_ptrs), there's a race condition + if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + else: + tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) + tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) + tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) + else: + tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + + +@triton.jit +def _bwd_kernel_one_col_block( + start_n, + Q, K, V, Bias, + DO, DQ, DK, DV, + LSE, D, + softmax_scale, + stride_qm, stride_kn, stride_vn, stride_bm, + stride_dom, stride_dqm, stride_dkn, stride_dvn, + seqlen_q, seqlen_k, headdim, + ATOMIC_ADD: tl.constexpr, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, +): + # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N) + begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M + # initialize row/col offsets + offs_qm = begin_m + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :]) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) + do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) + dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :]) + if BIAS_TYPE == 'vector': + b_ptrs = Bias + offs_n + elif BIAS_TYPE == 'matrix': + b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :]) + # initialize dv and dk + dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + # There seems to be some problem with Triton pipelining that makes results wrong for + # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop + # may have zero step, and pipelining with the bias matrix could screw it up. + # So we just exit early. + if begin_m >= seqlen_q: + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, + EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM) + return + # k and v stay in SRAM throughout + # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False, + # if we just call tl.load(k_ptrs), we get the wrong output! 
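Earlier in this function, begin_m determines where the row loop may start under a causal mask, before the k/v loads below. A pure-Python sketch of that index computation (illustrative only, not part of the patch):

def causal_begin_m(start_n: int, BLOCK_M: int, BLOCK_N: int, is_causal: bool) -> int:
    # For a causal mask, query rows m < start_n * BLOCK_N are entirely masked for
    # this key/value column block (causal keeps m >= n), so the row loop can skip
    # them, rounded down to a BLOCK_M-aligned boundary.
    if not is_causal:
        return 0
    return ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M

# e.g. BLOCK_M=128, BLOCK_N=64: the column block starting at key 192 (start_n=3)
# only needs query rows from 128 onwards.
assert causal_begin_m(3, 128, 64, True) == 128

When begin_m >= seqlen_q, the code above writes zero dk/dv and returns immediately, which also sidesteps the zero-step pipelining issue mentioned in the comments.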
+ if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + else: + k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) + v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) + else: + k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0) + v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0) + # loop over rows + num_block_m = tl.cdiv(seqlen_q, BLOCK_M) + for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M): + start_m = tl.multiple_of(start_m, BLOCK_M) + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117) + if EVEN_M & EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) + else: + q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) + & (offs_d[None, :] < headdim), other=0.0) + # recompute p = softmax(qk, dim=-1).T + qk = tl.dot(q, k, trans_b=True) + # Trying to combine the two masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf")) + if IS_CAUSAL: + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + if BIAS_TYPE != 'none': + tl.debug_barrier() # Race condition otherwise + if BIAS_TYPE == 'vector': + if EVEN_N: + bias = tl.load(b_ptrs).to(tl.float32) + else: + bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32) + bias = bias[None, :] + elif BIAS_TYPE == 'matrix': + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs).to(tl.float32) + else: + bias = tl.load(b_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) + & (offs_n[None, :] < seqlen_k), + other=0.0).to(tl.float32) + qk = qk * softmax_scale + bias + # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong. + # Also wrong for headdim=64. + if not (EVEN_M & EVEN_HEADDIM): + tl.debug_barrier() + lse_i = tl.load(LSE + offs_m_curr) + if BIAS_TYPE == 'none': + p = tl.exp(qk * softmax_scale - lse_i[:, None]) + else: + p = tl.exp(qk - lse_i[:, None]) + # compute dv + # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call + # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs + # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512, + # the output is correct. + if EVEN_M & EVEN_HEADDIM: + do = tl.load(do_ptrs) + else: + # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask. + do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) + & (offs_d[None, :] < headdim), other=0.0) + # if EVEN_M: + # if EVEN_HEADDIM: + # do = tl.load(do_ptrs) + # else: + # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + # else: + # if EVEN_HEADDIM: + # do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) + # else: + # do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) + # & (offs_d[None, :] < headdim), other=0.0) + dv += tl.dot(p.to(do.dtype), do, trans_a=True) + # compute dp = dot(v, do) + # There seems to be a race condition when headdim=48/96, and dq, dk are wrong. 
+ # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True + # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False + if not (EVEN_M & EVEN_HEADDIM): + tl.debug_barrier() + dp = tl.dot(do, v, trans_b=True) + # There's a race condition for headdim=48 + if not EVEN_HEADDIM: + tl.debug_barrier() + # compute ds = p * (dp - delta[:, None]) + # Putting the subtraction after the dp matmul (instead of before) is slightly faster + Di = tl.load(D + offs_m_curr) + # Converting ds to q.dtype here reduces register pressure and makes it much faster + # for BLOCK_HEADDIM=128 + ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) + # compute dk = dot(ds.T, q) + dk += tl.dot(ds, q, trans_a=True) + # compute dq + if not (EVEN_M & EVEN_HEADDIM): # Otherewise there's a race condition when BIAS_TYPE='matrix' + tl.debug_barrier() + if not ATOMIC_ADD: + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + dq = tl.load(dq_ptrs, eviction_policy="evict_last") + dq += tl.dot(ds, k) + tl.store(dq_ptrs, dq, eviction_policy="evict_last") + else: + if EVEN_HEADDIM: + dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0, + eviction_policy="evict_last") + dq += tl.dot(ds, k) + tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q, + eviction_policy="evict_last") + else: + dq = tl.load(dq_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, eviction_policy="evict_last") + dq += tl.dot(ds, k) + tl.store(dq_ptrs, dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + eviction_policy="evict_last") + else: # If we're parallelizing across the seqlen_k dimension + dq = tl.dot(ds, k) + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + tl.atomic_add(dq_ptrs, dq) + else: + if EVEN_HEADDIM: + tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) + else: + tl.atomic_add(dq_ptrs, dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim)) + # increment pointers + dq_ptrs += BLOCK_M * stride_dqm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_dom + if BIAS_TYPE == 'matrix': + b_ptrs += BLOCK_M * stride_bm + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, + EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM) + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now + # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4* + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, 
pre_hook=init_to_zero('DQ')), + ], + key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'], +) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _bwd_kernel( + Q, K, V, Bias, + DO, DQ, DK, DV, + LSE, D, + softmax_scale, + stride_qb, stride_qh, stride_qm, + stride_kb, stride_kh, stride_kn, + stride_vb, stride_vh, stride_vn, + stride_bb, stride_bh, stride_bm, + stride_dob, stride_doh, stride_dom, + stride_dqb, stride_dqh, stride_dqm, + stride_dkb, stride_dkh, stride_dkn, + stride_dvb, stride_dvh, stride_dvn, + nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, + CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + SEQUENCE_PARALLEL: tl.constexpr, + EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, +): + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # offset pointers for batch/head + Q += off_b * stride_qb + off_h * stride_qh + K += off_b * stride_kb + off_h * stride_kh + V += off_b * stride_vb + off_h * stride_vh + DO += off_b * stride_dob + off_h * stride_doh + DQ += off_b * stride_dqb + off_h * stride_dqh + DK += off_b * stride_dkb + off_h * stride_dkh + DV += off_b * stride_dvb + off_h * stride_dvh + if BIAS_TYPE != 'none': + Bias += off_b * stride_bb + off_h * stride_bh + # pointer to row-wise quantities in value-like data + D += off_hb * seqlen_q_rounded + LSE += off_hb * seqlen_q_rounded + if not SEQUENCE_PARALLEL: + num_block_n = tl.cdiv(seqlen_k, BLOCK_N) + for start_n in range(0, num_block_n): + _bwd_kernel_one_col_block( + start_n, + Q, K, V, Bias, + DO, DQ, DK, DV, + LSE, D, + softmax_scale, + stride_qm, stride_kn, stride_vn, stride_bm, + stride_dom, stride_dqm, stride_dkn, stride_dvn, + seqlen_q, seqlen_k, headdim, + ATOMIC_ADD=False, + BIAS_TYPE=BIAS_TYPE, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + else: + start_n = tl.program_id(0) + _bwd_kernel_one_col_block( + start_n, + Q, K, V, Bias, + DO, DQ, DK, DV, + LSE, D, + softmax_scale, + stride_qm, stride_kn, stride_vn, stride_bm, + stride_dom, stride_dqm, stride_dkn, stride_dvn, + seqlen_q, seqlen_k, headdim, + ATOMIC_ADD=True, + BIAS_TYPE=BIAS_TYPE, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + + +def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None): + # shape constraints + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, _, _ = k.shape + assert k.shape == (batch, seqlen_k, nheads, d) + assert v.shape == (batch, seqlen_k, nheads, d) + assert d <= 128, 'FlashAttention only support head dimensions up to 128' + assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type' + assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16' + assert q.is_cuda and k.is_cuda and v.is_cuda + softmax_scale = softmax_scale or 1.0 / math.sqrt(d) + + has_bias = bias is not None + bias_type = 'none' + if has_bias: + assert bias.dtype in [q.dtype, torch.float] + assert bias.is_cuda + assert bias.dim() == 4 + if bias.stride(-1) != 1: + bias = 
bias.contiguous() + if bias.shape[2:] == (1, seqlen_k): + bias_type = 'vector' + elif bias.shape[2:] == (seqlen_q, seqlen_k): + bias_type = 'matrix' + else: + raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)' + ' or (seqlen_q, seqlen_k)') + bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) + bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) + + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + o = torch.empty_like(q) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + BLOCK = 128 + num_warps = 4 if d <= 64 else 8 + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _fwd_kernel[grid]( + q, k, v, bias, o, + lse, tmp, + softmax_scale, + q.stride(0), q.stride(2), q.stride(1), + k.stride(0), k.stride(2), k.stride(1), + v.stride(0), v.stride(2), v.stride(1), + *bias_strides, + o.stride(0), o.stride(2), o.stride(1), + nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, + seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + bias_type, causal, BLOCK_HEADDIM, + BLOCK_M=BLOCK, BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return o, lse, softmax_scale # softmax_scale could have been updated + + +def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None): + # Make sure that the last dimension is contiguous + if do.stride(-1) != 1: + do = do.contiguous() + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, _, _ = k.shape + # assert d in {16, 32, 64, 128} + assert d <= 128 + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + assert lse.shape == (batch, nheads, seqlen_q_rounded) + assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1 + assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1 + softmax_scale = softmax_scale or 1.0 / math.sqrt(d) + # dq_accum = torch.zeros_like(q, dtype=torch.float32) + dq_accum = torch.empty_like(q, dtype=torch.float32) + delta = torch.empty_like(lse) + # delta = torch.zeros_like(lse) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _bwd_preprocess_do_o_dot[grid]( + o, do, delta, + o.stride(0), o.stride(2), o.stride(1), + do.stride(0), do.stride(2), do.stride(1), + nheads, seqlen_q, seqlen_q_rounded, d, + BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM, + ) + + has_bias = bias is not None + bias_type = 'none' + if has_bias: + assert bias.dtype in [q.dtype, torch.float] + assert bias.is_cuda + assert bias.dim() == 4 + assert bias.stride(-1) == 1 + if bias.shape[2:] == (1, seqlen_k): + bias_type = 'vector' + elif bias.shape[2:] == (seqlen_q, seqlen_k): + bias_type = 'matrix' + else: + raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)' + ' or (seqlen_q, seqlen_k)') + bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) + bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) + + # BLOCK_M = 128 + # BLOCK_N = 64 + # num_warps = 4 + grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, + batch * nheads) + _bwd_kernel[grid]( + q, k, v, bias, + do, dq_accum, dk, dv, + lse, delta, + 
softmax_scale, + q.stride(0), q.stride(2), q.stride(1), + k.stride(0), k.stride(2), k.stride(1), + v.stride(0), v.stride(2), v.stride(1), + *bias_strides, + do.stride(0), do.stride(2), do.stride(1), + dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1), + dk.stride(0), dk.stride(2), dk.stride(1), + dv.stride(0), dv.stride(2), dv.stride(1), + nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, + seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + bias_type, causal, BLOCK_HEADDIM, + # SEQUENCE_PARALLEL=False, + # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + # num_warps=num_warps, + # num_stages=1, + ) + dq.copy_(dq_accum) + + +class FlashAttnQKVPackedFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None): + """ + qkv: (batch, seqlen, 3, nheads, headdim) + bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen). + For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen). + ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen) + """ + # Make sure that the last dimension is contiguous + if qkv.stride(-1) != 1: + qkv = qkv.contiguous() + o, lse, ctx.softmax_scale = _flash_attn_forward( + qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal, + softmax_scale=softmax_scale + ) + ctx.save_for_backward(qkv, o, lse, bias) + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + qkv, o, lse, bias = ctx.saved_tensors + assert not ctx.needs_input_grad[1], 'FlashAttention does not support bias gradient yet' + # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd + # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. + with torch.inference_mode(): + dqkv = torch.empty_like(qkv) + _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse, + dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2], + bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale) + return dqkv, None, None, None + + +flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply + + +class FlashAttnKVPackedFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None): + """ + q: (batch, seqlen_q, nheads, headdim) + kv: (batch, seqlen_k, 2, nheads, headdim) + bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k). + For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k). + ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) + """ + # Make sure that the last dimension is contiguous + q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]] + o, lse, ctx.softmax_scale = _flash_attn_forward( + q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale + ) + ctx.save_for_backward(q, kv, o, lse, bias) + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + q, kv, o, lse, bias = ctx.saved_tensors + assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet' + # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd + # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. 
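The same inference_mode workaround is used by each of these autograd wrappers. A usage sketch for the packed entry point defined above (shapes are illustrative; torch.autograd.Function.apply only takes positional arguments):

import torch

batch, seqlen, nheads, headdim = 2, 1024, 16, 64
qkv = torch.randn(batch, seqlen, 3, nheads, headdim,
                  device="cuda", dtype=torch.float16, requires_grad=True)
# arguments: qkv, bias, causal, softmax_scale (None -> 1/sqrt(headdim))
out = flash_attn_qkvpacked_func(qkv, None, True, None)
out.float().sum().backward()  # the packed gradient lands in qkv.grad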
+ with torch.inference_mode(): + dq = torch.empty_like(q) + dkv = torch.empty_like(kv) + _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse, + dq, dkv[:, :, 0], dkv[:, :, 1], + bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale) + return dq, dkv, None, None, None + + +flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply + + +class FlashAttnFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None): + """ + q: (batch_size, seqlen_q, nheads, headdim) + k, v: (batch_size, seqlen_k, nheads, headdim) + bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k). + For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k). + ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) + """ + # Make sure that the last dimension is contiguous + q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]] + o, lse, ctx.softmax_scale = _flash_attn_forward( + q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale + ) + ctx.save_for_backward(q, k, v, o, lse, bias) + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, lse, bias = ctx.saved_tensors + assert not ctx.needs_input_grad[3], 'FlashAttention does not support bias gradient yet' + # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd + # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. + with torch.inference_mode(): + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, + bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale) + return dq, dk, dv, None, None, None + + +flash_attn_func = FlashAttnFunc.apply diff --git a/example/layers/test_attn.py b/example/layers/test_attn.py new file mode 100644 index 00000000..a185e793 --- /dev/null +++ b/example/layers/test_attn.py @@ -0,0 +1,45 @@ +import torch +import torch.nn.functional as F +import bmtrain as bmt +from bmtrain.global_var import config +from . import Attention + + +gb = 1024.0 * 1024.0 * 1024.0 + +bmt.init_distributed(zero_level=2) + +linears = [] +for i in range(10), : + linears.append(bmt.CheckpointBlock(Attention( + dim_model=8192, + dim_head=128, + num_head=64 + dropout_p=0.0, + use_flash_attn=True, + dtype=torch.half + ), + use_checkpoint=False) + ) + +linears = bmt.TransformerBlockList(linears) + +device = torch.device('cuda') +bmt.synchronize() +if config['rank'] == 0: + print('before forward', torch.cuda.memory_allocated(device) / gb) + +x = torch.randn(4096, 8192, dtype=torch.float16, device=device).requires_grad_() +bmt.synchronize() +if config['rank'] == 0: + print('init input', torch.cuda.memory_allocated(device) / gb) + +y = linears(x) +bmt.synchronize() +if config['rank'] == 0: + print('after forward', torch.cuda.memory_allocated(device) / gb) + +y.sum().backward() +bmt.synchronize() +if config['rank'] == 0: + print('after backward', torch.cuda.memory_allocated(device) / gb) diff --git a/example/layers/test_linear.py b/example/layers/test_linear.py new file mode 100644 index 00000000..27568e12 --- /dev/null +++ b/example/layers/test_linear.py @@ -0,0 +1,152 @@ +import torch +import torch.nn.functional as F +import bmtrain as bmt +from bmtrain.global_var import config +from . 
import TransformerEncoder + + +gb = 1024.0 * 1024.0 * 1024.0 + +class CustomLinear(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias=None): + ctx.save_for_backward(x, weight, bias) + return F.linear(x, weight, bias) + + @staticmethod + def backward(ctx, grad_output): + x, weight, bias = ctx.saved_tensors + grad_x = grad_weight = grad_bias = None + if x.requires_grad: + grad_x = grad_output.matmul(weight) + if weight.requires_grad: + dim = grad_output.dim() + grad_weight = grad_output.reshape(-1, + grad_output.shape[-1]).t().matmul(x.reshape(-1, x.shape[-1])) + if bias is not None and bias.requires_grad: + grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0) + return grad_x, grad_weight, grad_bias + + +class LinearFunctionForZeroStage3(torch.autograd.Function): + # Note that both forward and backward are @staticmethods + @staticmethod + #@autocast_custom_fwd + # bias is an optional argument + def forward(ctx, input, weight, bias=None): + + ctx.save_for_backward(input, weight, bias) + + if input.dim() == 2 and bias is not None: + # fused op is marginally faster + ret = torch.addmm(bias, input, weight.t()) + else: + output = input.matmul(weight.t()) + if bias is not None: + output += bias + ret = output + + return ret + + # This function has only a single output, so it gets only one gradient + @staticmethod + #@autocast_custom_bwd + def backward(ctx, grad_output): + # This is a pattern that is very convenient - at the top of backward + # unpack saved_tensors and initialize all gradients w.r.t. inputs to + # None. Thanks to the fact that additional trailing Nones are + # ignored, the return statement is simple even when the function has + # optional inputs. + input, weight, bias = ctx.saved_tensors + + grad_input = grad_weight = grad_bias = None + + #print(f"backward shaped grad_output {grad_output.shape}, input {input.shape}, weight {weight.shape} and bias {bias.shape if bias is not None else None}") + # These needs_input_grad checks are optional and there only to + # improve efficiency. If you want to make your code simpler, you can + # skip them. Returning gradients for inputs that don't require it is + # not an error. 
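The comments above describe a generic autograd pattern; the hunk continues below with the actual gradient computation. A minimal self-contained illustration of that pattern (not part of the patch):

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.save_for_backward(x, alpha)
        return x * alpha

    @staticmethod
    def backward(ctx, grad_output):
        x, alpha = ctx.saved_tensors
        grad_x = grad_alpha = None            # default every grad to None
        if ctx.needs_input_grad[0]:
            grad_x = grad_output * alpha
        if ctx.needs_input_grad[1]:
            grad_alpha = (grad_output * x).sum()
        return grad_x, grad_alpha

x = torch.randn(4, requires_grad=True)
alpha = torch.tensor(2.0, requires_grad=True)
Scale.apply(x, alpha).sum().backward()        # both grads are populated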
+ if ctx.needs_input_grad[0]: + #print(f"Computing grad input weight {weight.shape} grad_output {grad_output.shape}") + grad_input = grad_output.matmul(weight) + #print(f"Computed grad input {grad_input.shape}") + if ctx.needs_input_grad[1]: + #print("Computing grad weight") + dim = grad_output.dim() + if dim > 2: + grad_weight = grad_output.reshape(-1, + grad_output.shape[-1]).t().matmul(input.reshape(-1, input.shape[-1])) + else: + grad_weight = grad_output.t().matmul(input) + #print(f"Computed grad weight grad_weight {grad_weight.shape}") + if bias is not None and ctx.needs_input_grad[2]: + #print("Computing grad bias") + grad_bias = grad_output.sum(0) + #print("Done computing grad bias") + #print("needs bias") + #print(f"backward shaped grad_input {grad_input.shape}, grad_weight {grad_weight.shape}, grad_bias {grad_bias.shape if grad_bias is not None else None}") + return grad_input, grad_weight, grad_bias + + +class Linear(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, bias: bool = False, dtype = torch.float16) -> None: + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=dtype, device="cuda"), init_method=torch.nn.init.xavier_normal_) + if bias: + self.bias = bmt.DistributedParameter(torch.empty((1, out_features), dtype=dtype, device="cuda"), init_method=torch.nn.init.zeros_) + else: + self.register_parameter('bias', None) + + def forward(self, input): + #return CustomLinear.apply(input, self.weight, self.bias) + return LinearFunctionForZeroStage3.apply(input, self.weight, self.bias) + + def extra_repr(self) -> str: + return 'in_features={}, out_features={}, bias={}'.format( + self.in_features, self.out_features, self.bias is not None + ) + +class Feedforward(bmt.DistributedModule): + def __init__(self, dim_model : int, dim_ff : int, bias : bool = False, dtype = torch.float16) -> None: + super().__init__() + + self.w_in = Linear(dim_model, dim_ff, bias = bias, dtype=dtype) + self.w_out = Linear(dim_ff, dim_model, bias = bias, dtype=dtype) + self.gate = Linear(dim_model, dim_ff, bias = bias, dtype=dtype) + + self.relu = torch.nn.ReLU() + + def forward(self, input : torch.Tensor) -> torch.Tensor: + gate_out = self.relu(self.gate(input)) + return self.w_out(self.w_in(input) * gate_out) + +bmt.init_distributed(zero_level=2) + +linears = [] +for i in range(10): + linears.append(bmt.CheckpointBlock(TransformerEncoder(8192, 20480), use_checkpoint=False)) + +linears = bmt.TransformerBlockList(linears) + +device = torch.device('cuda') +bmt.synchronize() +if config['rank'] == 0: + print('before forward', torch.cuda.memory_allocated(device) / gb) + +x = torch.randn(4096, 8192, dtype=torch.float16, device=device).requires_grad_() +bmt.synchronize() +if config['rank'] == 0: + print('init input', torch.cuda.memory_allocated(device) / gb) + +y = linears(x) +bmt.synchronize() +if config['rank'] == 0: + print('after forward', torch.cuda.memory_allocated(device) / gb) + +y.sum().backward() +bmt.synchronize() +if config['rank'] == 0: + print('after backward', torch.cuda.memory_allocated(device) / gb) From 726aa2f5eda6d25ea51170c6467186e9f285dc5d Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Thu, 10 Aug 2023 10:54:49 +0800 Subject: [PATCH 038/122] test for transformer and attn --- example/layers/attention.py | 29 +++++++------------ example/layers/feedforward.py | 1 - example/layers/transformer.py | 2 +- example/run.sh | 4 
+-- example/{layers => }/test_attn.py | 18 +++++++----- example/test_block.py | 48 +++++++++++++++++++++++++++++++ 6 files changed, 72 insertions(+), 30 deletions(-) rename example/{layers => }/test_attn.py (69%) create mode 100644 example/test_block.py diff --git a/example/layers/attention.py b/example/layers/attention.py index c5485575..062798aa 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -3,31 +3,30 @@ import bmtrain as bmt from layers import Linear import math - +from .flash_triton import FlashAttnFunc class Attention(bmt.DistributedModule): def __init__( self, dim_model: int, - num_heads: int, dim_head: int, + num_heads: int, + bias: bool = False, dtype: torch.dtype = torch.half, dropout_p: Optional[float] = None, - scale: bool = True, - use_flash_attn: bool = False, + use_flash_attn: bool = True, ) -> None: super().__init__() self.dim_model = dim_model self.num_heads = num_heads self.num_kv_heads = num_heads - self.head_groups = num_heads // num_kv_heads self.dim_head = dim_head - self.project_q = Linear(self.dim_model, self.num_heads * self.dim_head, dtype=dtype, scale=scale) - self.project_k = Linear(self.dim_model, self.num_kv_heads * self.dim_head, dtype=dtype, scale=scale) - self.project_v = Linear(self.dim_model, self.num_kv_heads * self.dim_head, dtype=dtype, scale=scale) + self.project_q = Linear(self.dim_model, self.num_heads * self.dim_head, dtype=dtype, ) + self.project_k = Linear(self.dim_model, self.num_kv_heads * self.dim_head, dtype=dtype, ) + self.project_v = Linear(self.dim_model, self.num_kv_heads * self.dim_head, dtype=dtype, ) - self.attention_out = Linear(self.num_heads * self.dim_head, self.dim_model, dtype=dtype, scale=scale) + self.attention_out = Linear(self.num_heads * self.dim_head, self.dim_model, dtype=dtype, ) if dropout_p is not None: @@ -45,7 +44,6 @@ def forward( hidden_q: torch.Tensor, hidden_kv: torch.Tensor, attention_mask: torch.BoolTensor, - position_bias: torch.Tensor, ): """This model inherits from bmt.DistributedModule. 
        Args:
@@ -80,13 +78,11 @@ def forward(
         # h_q, h_k, h_v, attention_mask=attention_mask, length_mask=length_mask, context_mask=context_mask
         # )
         if attention_mask is not None:
-            assert pos_bias_type == "rotary"
-            h_q, h_k = position_bias(h_q, h_k, -3)
             h_q = h_q.view(batch_size, len_q, self.num_heads, self.dim_head) # .permute(0, 2, 1, 3)
             h_k = h_k.view(batch_size, len_k, self.num_kv_heads, self.dim_head) # .permute(0, 2, 1, 3)
             h_v = h_v.view(batch_size, len_k, self.num_kv_heads, self.dim_head) # .permute(0, 2, 1, 3)
-            mask = attention_mask.unsqueeze(dim=1).contiguous()
-            mask_bias = torch.zeros_like(mask, device="cuda", dtype=torch.float16) # create an all-zero tensor with the same shape as mask
+            mask = attention_mask
+            mask_bias = torch.zeros_like(attention_mask, device="cuda", dtype=torch.float16) # create an all-zero tensor with the same shape as mask
             mask_bias[mask == False] -= torch.inf
             score = FlashAttnFunc.apply(h_q, h_k, h_v, mask_bias, False, None)
@@ -95,10 +91,7 @@
         score = self.attention_out(score)
 
-        if use_cache:
-            return score, (h_k, h_v)
-        else:
-            return score
+        return score
 
 #class Attention(bmt.DistributedModule):
 #    def __init__(self,
diff --git a/example/layers/feedforward.py b/example/layers/feedforward.py
index b0bfd94c..74bae8d3 100644
--- a/example/layers/feedforward.py
+++ b/example/layers/feedforward.py
@@ -13,6 +13,5 @@ def __init__(self, dim_model : int, dim_ff : int, bias : bool = True, dtype = No
         self.relu = torch.nn.ReLU()
 
     def forward(self, input : torch.Tensor) -> torch.Tensor:
-#return self.w_out(self.relu(self.w_in(input)))
         gate_out = self.relu(self.gate(input))
         return self.w_out(self.w_in(input) * gate_out)
diff --git a/example/layers/transformer.py b/example/layers/transformer.py
index 7cda1bb9..1e1cffe0 100644
--- a/example/layers/transformer.py
+++ b/example/layers/transformer.py
@@ -23,7 +23,7 @@ def forward(self,
     ):
         bmt.inspect.record_tensor(hidden, "hidden")
         x = self.ln_attn(hidden)
-        x = self.attn(x, x, mask, position_bias)
+        x = self.attn(x, x, mask)
         hidden = hidden + x
 
         x = self.ln_ff(hidden)
diff --git a/example/run.sh b/example/run.sh
index 542e5252..1beb71bd 100644
--- a/example/run.sh
+++ b/example/run.sh
@@ -1,3 +1 @@
-export NCCL_P2P_DISABLE=1
-export CUDA_LAUNCH_BLOCKING=1
-torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=localhost train.py
+torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=localhost $1
diff --git a/example/layers/test_attn.py b/example/test_attn.py
similarity index 69%
rename from example/layers/test_attn.py
rename to example/test_attn.py
index a185e793..642f2a8b 100644
--- a/example/layers/test_attn.py
+++ b/example/test_attn.py
@@ -2,19 +2,19 @@
 import torch.nn.functional as F
 import bmtrain as bmt
 from bmtrain.global_var import config
-from . 
import Attention +from layers import Attention gb = 1024.0 * 1024.0 * 1024.0 -bmt.init_distributed(zero_level=2) +bmt.init_distributed(zero_level=3) linears = [] for i in range(10), : linears.append(bmt.CheckpointBlock(Attention( dim_model=8192, dim_head=128, - num_head=64 + num_heads=64, dropout_p=0.0, use_flash_attn=True, dtype=torch.half @@ -28,13 +28,17 @@ bmt.synchronize() if config['rank'] == 0: print('before forward', torch.cuda.memory_allocated(device) / gb) - -x = torch.randn(4096, 8192, dtype=torch.float16, device=device).requires_grad_() +batch_size=1 +seq_len=4096 +x = torch.randn(batch_size, seq_len, 8192, dtype=torch.float16, device=device).requires_grad_() bmt.synchronize() if config['rank'] == 0: print('init input', torch.cuda.memory_allocated(device) / gb) - -y = linears(x) +enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda() +mask = torch.arange(seq_len).unsqueeze(0) <= torch.arange(seq_len).unsqueeze(1) +mask = mask.unsqueeze(0).unsqueeze(0) +print(mask.shape) +y = linears(x,x,mask) bmt.synchronize() if config['rank'] == 0: print('after forward', torch.cuda.memory_allocated(device) / gb) diff --git a/example/test_block.py b/example/test_block.py new file mode 100644 index 00000000..4730a0a9 --- /dev/null +++ b/example/test_block.py @@ -0,0 +1,48 @@ +import torch +import torch.nn.functional as F +import bmtrain as bmt +from bmtrain.global_var import config +from layers import TransformerEncoder + + +gb = 1024.0 * 1024.0 * 1024.0 + +bmt.init_distributed(zero_level=3) + +linears = [] +for i in range(10), : + linears.append(bmt.CheckpointBlock(TransformerEncoder( + dim_model=8192, + dim_head=128, + num_heads=64, + dim_ff=20480, + bias=False, + dtype=torch.half + ), + use_checkpoint=False) + ) + +linears = bmt.TransformerBlockList(linears) + +device = torch.device('cuda') +bmt.synchronize() +if config['rank'] == 0: + print('before forward', torch.cuda.memory_allocated(device) / gb) +batch_size=1 +seq_len=4096 +x = torch.randn(batch_size, seq_len, 8192, dtype=torch.float16, device=device).requires_grad_() +bmt.synchronize() +if config['rank'] == 0: + print('init input', torch.cuda.memory_allocated(device) / gb) +enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda() +mask = torch.arange(seq_len).unsqueeze(0) <= torch.arange(seq_len).unsqueeze(1) +mask = mask.unsqueeze(0).unsqueeze(0) +y = linears(x,mask) +bmt.synchronize() +if config['rank'] == 0: + print('after forward', torch.cuda.memory_allocated(device) / gb) + +y.sum().backward() +bmt.synchronize() +if config['rank'] == 0: + print('after backward', torch.cuda.memory_allocated(device) / gb) From ae56de809b28e0b6b3b9f2bb93281ccb84b5775d Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 10 Aug 2023 15:46:20 +0800 Subject: [PATCH 039/122] for requires_grad --- bmtrain/block_layer.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 6d58c65c..4c3c218d 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -205,13 +205,24 @@ def set_pre_module(self, module): module._next_module = self def forward(self, *args): - #input must be requires_grad, otherwise autograd.backward will make an error - args[0].requires_grad_() - pre_out = hook_func.PreHookFunc.apply(self, *args) + grad_tensors = [] + grad_index = [] + arg_list = list(args) + for i, arg in enumerate(args): + if arg is not None and arg.requires_grad: + grad_tensors.append(arg) + grad_index.append(i) + grad_tensors = tuple(grad_tensors) 
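The remainder of this hunk (below) routes only these tensors through the pre-hook and splices the hook outputs back by index. A standalone sketch of that filter-and-splice pattern (identity_hook is a stand-in for hook_func.PreHookFunc.apply, not a real API):

import torch

def identity_hook(*tensors):
    return tensors  # placeholder for the real pre-forward hook

def splice_through_hook(*args):
    grad_tensors, grad_index = [], []
    for i, arg in enumerate(args):
        if isinstance(arg, torch.Tensor) and arg.requires_grad:
            grad_tensors.append(arg)
            grad_index.append(i)
    hooked = identity_hook(*grad_tensors)
    arg_list = list(args)
    for pos, out in zip(grad_index, hooked):
        arg_list[pos] = out   # non-tensor and no-grad arguments pass through untouched
    return arg_list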
+ + pre_out = hook_func.PreHookFunc.apply(self, *grad_tensors) + for i in range(len(grad_index)): + arg_list[grad_index[i]] = grad_tensors[i] + if self.use_checkpoint: - out = checkpoint(self._module, *pre_out) + out = checkpoint(self._module, *arg_list) else: - out = self._module(*pre_out) + out = self._module(*arg_list) + tuple_out = (out, ) if isinstance(out, torch.Tensor) else out post_out = hook_func.PostHookFunc.apply(self, *tuple_out) if isinstance(out, torch.Tensor) and isinstance(post_out, tuple): From 27ae2b719b8ff35097283decf18140f51620a2ad Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 10 Aug 2023 15:52:22 +0800 Subject: [PATCH 040/122] for requires_grad --- bmtrain/block_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 4c3c218d..38117b0e 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -216,7 +216,7 @@ def forward(self, *args): pre_out = hook_func.PreHookFunc.apply(self, *grad_tensors) for i in range(len(grad_index)): - arg_list[grad_index[i]] = grad_tensors[i] + arg_list[grad_index[i]] = pre_out[i] if self.use_checkpoint: out = checkpoint(self._module, *arg_list) From fdc823139ca8e1916e57c470157107883187d7d8 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 10 Aug 2023 16:54:20 +0800 Subject: [PATCH 041/122] fix for arg is not tensor --- bmtrain/block_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 38117b0e..81c1b068 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -209,7 +209,7 @@ def forward(self, *args): grad_index = [] arg_list = list(args) for i, arg in enumerate(args): - if arg is not None and arg.requires_grad: + if arg is not None and isinstance(arg, torch.Tensor) and arg.requires_grad: grad_tensors.append(arg) grad_index.append(i) grad_tensors = tuple(grad_tensors) From b0f71547f014b5bf7f3ac6c58a939a24097e3bdd Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 10 Aug 2023 16:59:04 +0800 Subject: [PATCH 042/122] fix for arg is not a tensor --- bmtrain/block_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 38117b0e..81c1b068 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -209,7 +209,7 @@ def forward(self, *args): grad_index = [] arg_list = list(args) for i, arg in enumerate(args): - if arg is not None and arg.requires_grad: + if arg is not None and isinstance(arg, torch.Tensor) and arg.requires_grad: grad_tensors.append(arg) grad_index.append(i) grad_tensors = tuple(grad_tensors) From 420b626ffe1da33c6c5f865422083e7e1215dfb6 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 10 Aug 2023 20:03:07 +0800 Subject: [PATCH 043/122] add test --- tests/test_requires_grad.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/test_requires_grad.py b/tests/test_requires_grad.py index 83fe8d17..41cabbe0 100644 --- a/tests/test_requires_grad.py +++ b/tests/test_requires_grad.py @@ -25,19 +25,23 @@ def __init__(self, in_features : int, out_features: int, init_weight = None, ini else: self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.zeros_) - def forward(self, input): + def forward(self, input, other_bias): ret = F.linear(input, self.weight, self.bias) + ret += other_bias return ret def run(m, a, b): inp = torch.rand((1, 10, 256)).cuda()*100 - logits = m(inp) + 
inp.requires_grad_() + bias = torch.rand(256).cuda()*100 + logits = m(inp, bias) loss = logits.sum() loss.backward() bmt.synchronize() sm = bmt.inspect.format_summary( bmt.inspect.inspect_model(m, '*') ) + assert_eq(bias.requires_grad, False) return a.weight.grad is None, a.bias.grad is None, sm def test_main(): @@ -100,4 +104,4 @@ def test_main_pipe(): bmt.init_distributed(pipe_size=1) test_main() - test_main_pipe() \ No newline at end of file + test_main_pipe() From ebc269f513c41acafb72e7e5e1191d2d433af22a Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Fri, 11 Aug 2023 16:34:10 +0800 Subject: [PATCH 044/122] merge enhance_ckp --- bmtrain/block_layer.py | 50 ++++++++++++++++++++++++++++++++++------ bmtrain/checkpointing.py | 17 +++++++------- bmtrain/hook_func.py | 40 +++++++++++++++++++++----------- bmtrain/init.py | 5 ++++ bmtrain/pipe_layer.py | 4 ++-- 5 files changed, 84 insertions(+), 32 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 81c1b068..27636e18 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -50,6 +50,28 @@ def _get_param_kw(param : DistributedParameter): group_name = "_g_" + param.group return type_name + grad_name + group_name +class BMTBlockContext: + def __init__(self): + self._pre_module = None + self._first = True + + def link_module(self, module): + if not self._first and module._ref_count == -1: + self._pre_module = module + module._ref_count = 1 + return + + if self._pre_module is None: + module._ref_count = 1 + module._is_first_layer = True + else: + if module._ref_count == 0: + module._is_first_layer = False + self._pre_module.set_next_module(module) + self._pre_module._is_last_layer = False + self._pre_module = module + self._first = False + class CheckpointBlock(torch.nn.Module): """ A bmtrain block containing two memory-saving methods of ZeRO-2/3 and checkpoint. @@ -68,12 +90,14 @@ class CheckpointBlock(torch.nn.Module): >>> y2, ... 
= transformer_block(x) >>> assert torch.allclose(y1, y2) """ - def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True): + def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, block_context=None): super().__init__() self._module = inner_module self._inputs = None self._layer_dict = {} + self._forward_block_ctx = None self._backward_block_ctx = None + self._forward_enter_count = 0 # build large parameter&grad here self._param_info = [] self._storage_params : Dict[str, torch.nn.Parameter] = {} @@ -194,17 +218,31 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True): self.use_checkpoint = use_checkpoint self._is_first_layer = True self._is_last_layer = True - self._pre_module = None - self._next_module = None + self._pre_module = [] + self._next_module = [] + self._ref_count = 0 self._mode = "BLOCK" #BLOCK or ZERO or PIPE self.return_hidden_states = False self.hidden_states = [] + self.block_context = block_context + if block_context is None: + self.block_context = config['block_context'][config['rank']] def set_pre_module(self, module): - self._pre_module = module - module._next_module = self + self._ref_count += 1 + if module is not None: + module._next_module.append(self) + self._is_first_layer = False + module._is_last_layer = False + + def set_next_module(self, module): + self._next_module.append(module) + module._ref_count += 1 def forward(self, *args): + if self._mode != "PIPE": + self.block_context.link_module(self) + grad_tensors = [] grad_index = [] arg_list = list(args) @@ -495,8 +533,6 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) self._modules[str(i)] = module self.add_module(str(i), module) - if i > 0: - self._modules[str(i)].set_pre_module(self._modules[str(i-1)]) self.num_hidden = num_hidden diff --git a/bmtrain/checkpointing.py b/bmtrain/checkpointing.py index 1775b95d..c13f0963 100644 --- a/bmtrain/checkpointing.py +++ b/bmtrain/checkpointing.py @@ -33,20 +33,19 @@ def __exit__(self, *args): self.prev_hidden = None class CheckpointBlockContext: - def __init__(self, block : 'CheckpointBlock', ctx_dict : dict = None, flag : int = 0, pipe = False) -> None: + def __init__(self, block : 'CheckpointBlock', ctx_dict : dict = None, pipe = False) -> None: self.block = block self.ctx_dict = ctx_dict self._param_buffer = {} self._grad_buffer = {} self._param_tensor = {} self._grad_tensor = {} - self.flag = flag self._need_release = False if pipe: self.comm = config["zero_comm"] else: self.comm = config["comm"] - def enter(self, requires_grad=False): + def enter(self, flag=0, requires_grad=False): """ gather parameters """ @@ -64,14 +63,14 @@ def enter(self, requires_grad=False): local_param = self.block._storage_params[kw] storage_type = local_param.storage_type() - if self.flag != 2: + if flag != 2: self._param_buffer[kw] = storage_type(val["partition_size"] * val["world_size"]) self._param_tensor[kw] = torch.tensor([], dtype=self._param_buffer[kw].dtype, device=self._param_buffer[kw].device).set_(self._param_buffer[kw]) if requires_grad and local_param.requires_grad: self._grad_buffer[kw] = storage_type(val["partition_size"] * val["world_size"]) self._grad_tensor[kw] = torch.tensor([], dtype=self._grad_buffer[kw].dtype, device=self._grad_buffer[kw].device).set_(self._grad_buffer[kw]).zero_() - if self.flag != 2: + if flag != 2: nccl.groupStart() for kw, val in self.block._storage_info.items(): nccl.allGather( @@ -86,7 +85,7 @@ def enter(self, requires_grad=False): # set wait stream for each 
storage for kw in self.block._storage_info.keys(): - if self.flag != 2: + if flag != 2: self._param_tensor[kw].record_stream(current_stream) if requires_grad and kw in self._grad_tensor: self._grad_tensor[kw].record_stream(current_stream) @@ -97,7 +96,7 @@ def enter(self, requires_grad=False): offset = param["offset"] shape = param["shape"] - if self.flag != 2: + if flag != 2: dtype = self._param_buffer[kw_name].dtype device = self._param_buffer[kw_name].device param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self._param_buffer[kw_name], offset, shape) @@ -112,7 +111,7 @@ def enter(self, requires_grad=False): def __enter__(self): self.enter() - def exit(self, backward=False): + def exit(self, flag=0, backward=False): """ Reduce scatter gradients """ @@ -171,7 +170,7 @@ def exit(self, backward=False): param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].storage(), begin, end) if param["parameter"].requires_grad and self.block._storage_params[kw_name].grad is not None: param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end) - if self.flag == 1: + if flag == 1: for i in self._param_buffer: self.ctx_dict[i] = self._param_buffer[i] self._grad_tensor = {} diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index e17a86a4..b26c357b 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -10,39 +10,51 @@ def zero_pre_forward(module, inputs): enter = module._micro_idx == 0 pipe = True if enter: - forward_flag = 1 if config['zero_level'] == 2 else 0 - module._forward_block_ctx = CheckpointBlockContext(module, module._layer_dict, forward_flag, pipe=pipe) - module._forward_block_ctx.enter() + zero_level = config['zero_level'] + forward_flag = 1 if zero_level == 2 else 0 + if zero_level == 2 and module._ref_count > 1: + forward_flag = 2 # repeating forward in same layer + module._forward_block_ctx = CheckpointBlockContext(module, module._layer_dict, pipe=pipe) + module._forward_block_ctx.enter(forward_flag) def zero_post_forward(module, inputs, outputs): + forward_flag = 1 if config['zero_level'] == 2 else 0 exit = True if module._mode == "PIPE": exit = module._micro_idx == config['micros'] - 1 if exit: - module._forward_block_ctx.exit() + module._forward_block_ctx.exit(forward_flag) def zero_pre_backward(module, grad_outputs): backward_flag = 2 if config['zero_level'] == 2 else 0 if module._mode != "PIPE": - module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict, backward_flag) - module._backward_block_ctx.enter(True) - if not module._is_last_layer and module._next_module is not None and module._next_module._backward_block_ctx is not None: - module._next_module._backward_block_ctx.exit(True) - config['load_stream'].record_event(config['load_event']) + module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict) + module._backward_block_ctx.enter(backward_flag, True) + if not module._is_last_layer and len(module._next_module) > 0 and module._next_module[-1]._backward_block_ctx is not None: + if module._next_module[-1]._ref_count == 1: + module._next_module[-1]._ref_count = 0 + module._next_module.pop()._backward_block_ctx.exit(backward_flag, True) + config['load_stream'].record_event(config['load_event']) + else: + module._next_module[-1]._ref_count -= 1 + else: if module._micro_idx == config['micros'] - 1: - module._backward_block_ctx = CheckpointBlockContext(module, 
module._layer_dict, backward_flag, pipe=True) - module._backward_block_ctx.enter(True) + module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict, pipe=True) + module._backward_block_ctx.enter(backward_flag, True) def zero_post_backward(module, grad_inputs, grad_outputs): + backward_flag = 2 if config['zero_level'] == 2 else 0 if module._mode != "PIPE": - if module._is_first_layer: - module._backward_block_ctx.exit(True) + if module._is_first_layer and module._ref_count == 1: + module._backward_block_ctx.exit(backward_flag, True) + module._ref_count = -1 config['load_stream'].record_event(config['load_event']) else: if module._micro_idx == 0: - module._backward_block_ctx.exit(True) + module._ref_count = -1 if module._is_first_layer else 0 + module._backward_block_ctx.exit(backward_flag, True) config['load_stream'].record_event(config['load_event']) class PipePreFunction(torch.autograd.Function): diff --git a/bmtrain/init.py b/bmtrain/init.py index 5c3006d2..5772f963 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -7,6 +7,8 @@ from .global_var import config from . import nccl from .synchronize import synchronize +from .block_layer import BMTBlockContext + def init_distributed( init_method : str = "env://", seed : int = 0, @@ -72,6 +74,9 @@ def init_distributed( config["zero_level"] = zero_level config["topology"] = topology(config) config["zero_rank"] = config["topology"].get_group_rank("zero") if pipe_size > 1 else config['rank'] + config["block_context"] = [] + for i in range(world_size): + config["block_context"].append(BMTBlockContext()) cpus_this_worker = None all_available_cpus = sorted(list(os.sched_getaffinity(0))) diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 2759f662..af62aecc 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -174,8 +174,8 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: self._modules[str(layer_id)]._is_last_stage = True if self.stage_id == self.stages-1 else False self._modules[str(layer_id)]._is_first_layer = True if i == 0 else False self._modules[str(layer_id)]._is_last_layer = True if i == len(self.layer_ids)-1 else False - if i > 0: - self._modules[str(layer_id)].set_pre_module(self._modules[str(layer_id-1)]) +#if i > 0: +#self._modules[str(layer_id)].set_pre_module(self._modules[str(layer_id-1)]) self.partition_modules(self.layer_ids) self.next_rank = pipe_group[self.pipe_idx, self.stage_id + 1] if self.stage_id < config['pipe_size'] - 1 else -1 From 2f1e766a7d16e7986004e02f9d848aeaa3447a36 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Fri, 11 Aug 2023 18:15:53 +0800 Subject: [PATCH 045/122] enhance ckp --- bmtrain/block_layer.py | 4 ++++ tests/test_training.py | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 27636e18..de176da2 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -72,6 +72,10 @@ def link_module(self, module): self._pre_module = module self._first = False + def clear(self): + self._pre_module = None + self._first = True + class CheckpointBlock(torch.nn.Module): """ A bmtrain block containing two memory-saving methods of ZeRO-2/3 and checkpoint. 
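Together with the hooks above, clear() resets the per-rank chain of linked blocks; the tests/test_training.py hunk that follows calls it between independently built models. A usage sketch under the config layout introduced in this patch series:

from bmtrain.global_var import config

def run_once(model, x):
    # one forward/backward over a freshly built model on this rank
    loss = model(x).sum()
    loss.backward()
    # drop the first/last-layer links recorded for this graph so the next model
    # starts from an empty BMTBlockContext
    config['block_context'][config['rank']].clear()
    return loss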
diff --git a/tests/test_training.py b/tests/test_training.py index 7342fe6c..b58701d4 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -396,6 +396,7 @@ def pipe_model(): def add_to_check_list(m, l, o): key, value = train((m, models[m]), (l, loss_funcs[l]), (o, optimizers[o])) ret[key] = value + config['block_context'][config['rank']].clear() if test_fp16: kwargs["dtype"] = torch.half @@ -442,4 +443,4 @@ def check_param(info1, info2): if __name__ == '__main__': bmt.init_distributed(pipe_size=2) - test_main(test_fp16=True, test_fp32=True) \ No newline at end of file + test_main(test_fp16=True, test_fp32=True) From 683707d02ae3beb91e4a5be9028ddba358eafebc Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Fri, 11 Aug 2023 18:41:49 +0800 Subject: [PATCH 046/122] test --- example/test_block.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/example/test_block.py b/example/test_block.py index 4730a0a9..03bc258b 100644 --- a/example/test_block.py +++ b/example/test_block.py @@ -6,6 +6,15 @@ gb = 1024.0 * 1024.0 * 1024.0 +device = torch.device('cuda') + +def reserved(device): + return torch.cuda.memory_reserved(device) / gb +def allocated(device): + return torch.cuda.memory_allocated(device) / gb +def max_allocated(device): + return torch.cuda.max_memory_allocated(device) / gb + bmt.init_distributed(zero_level=3) @@ -24,25 +33,24 @@ linears = bmt.TransformerBlockList(linears) -device = torch.device('cuda') bmt.synchronize() if config['rank'] == 0: - print('before forward', torch.cuda.memory_allocated(device) / gb) + print('before forward', allocated(device), reserved(device), max_allocated(device)) batch_size=1 seq_len=4096 x = torch.randn(batch_size, seq_len, 8192, dtype=torch.float16, device=device).requires_grad_() bmt.synchronize() if config['rank'] == 0: - print('init input', torch.cuda.memory_allocated(device) / gb) + print('init input', allocated(device), reserved(device), max_allocated(device)) enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda() mask = torch.arange(seq_len).unsqueeze(0) <= torch.arange(seq_len).unsqueeze(1) -mask = mask.unsqueeze(0).unsqueeze(0) +mask = mask.unsqueeze(0).unsqueeze(0).to(device) y = linears(x,mask) bmt.synchronize() if config['rank'] == 0: - print('after forward', torch.cuda.memory_allocated(device) / gb) + print('after forward', allocated(device), reserved(device), max_allocated(device)) y.sum().backward() bmt.synchronize() if config['rank'] == 0: - print('after backward', torch.cuda.memory_allocated(device) / gb) + print('after backward', allocated(device), reserved(device), max_allocated(device)) From 4013502c0aa5f971010b8ecdf4284afb329d3524 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Fri, 11 Aug 2023 18:58:24 +0800 Subject: [PATCH 047/122] test --- example/layers/transformer.py | 16 ++++++--- example/test_block.py | 62 ++++++++++++++++++++++------------- 2 files changed, 50 insertions(+), 28 deletions(-) diff --git a/example/layers/transformer.py b/example/layers/transformer.py index 1e1cffe0..3390327d 100644 --- a/example/layers/transformer.py +++ b/example/layers/transformer.py @@ -10,11 +10,17 @@ def __init__(self, ) -> None: super().__init__() - self.ln_attn = Layernorm(dim_model, dtype=dtype) - self.attn = Attention(dim_model, dim_head, num_heads, bias=bias, dtype=dtype) + self.ln_attn = bmt.CheckpointBlock(Layernorm(dim_model, dtype=dtype), use_checkpoint=False) + self.attn = bmt.CheckpointBlock( + Attention(dim_model, dim_head, num_heads, bias=bias, dtype=dtype), + 
use_checkpoint=False + ) - self.ln_ff = Layernorm(dim_model, dtype=dtype) - self.ff = Feedforward(dim_model, dim_ff, bias=bias, dtype=dtype) + self.ln_ff = bmt.CheckpointBlock(Layernorm(dim_model, dtype=dtype), use_checkpoint=False) + self.ff = bmt.CheckpointBlock( + Feedforward(dim_model, dim_ff, bias=bias, dtype=dtype), + use_checkpoint=True + ) def forward(self, hidden : torch.Tensor, # (batch, seq_len, dim_model) @@ -23,7 +29,7 @@ def forward(self, ): bmt.inspect.record_tensor(hidden, "hidden") x = self.ln_attn(hidden) - x = self.attn(x, x, mask) + x = self.attn(x, x, mask, position_bias) hidden = hidden + x x = self.ln_ff(hidden) diff --git a/example/test_block.py b/example/test_block.py index 03bc258b..6493e1d7 100644 --- a/example/test_block.py +++ b/example/test_block.py @@ -6,8 +6,6 @@ gb = 1024.0 * 1024.0 * 1024.0 -device = torch.device('cuda') - def reserved(device): return torch.cuda.memory_reserved(device) / gb def allocated(device): @@ -15,42 +13,60 @@ def allocated(device): def max_allocated(device): return torch.cuda.max_memory_allocated(device) / gb - bmt.init_distributed(zero_level=3) linears = [] for i in range(10), : - linears.append(bmt.CheckpointBlock(TransformerEncoder( + linears.append(TransformerEncoder( dim_model=8192, dim_head=128, num_heads=64, dim_ff=20480, bias=False, dtype=torch.half - ), - use_checkpoint=False) + ) ) -linears = bmt.TransformerBlockList(linears) +#linears = bmt.TransformerBlockList(linears) +linears = torch.nn.ModuleList(linears) + +optimizer = bmt.optim.AdamOffloadOptimizer(linears.parameters(), weight_decay=1e-2) +lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0) +optim_manager = bmt.optim.OptimManager(loss_scale=2**20) +optim_manager.add_optimizer(optimizer, lr_scheduler) + +bmt.synchronize() + +device = torch.device('cuda') bmt.synchronize() if config['rank'] == 0: - print('before forward', allocated(device), reserved(device), max_allocated(device)) + print('before init input', allocated(device), reserved(device)) batch_size=1 seq_len=4096 -x = torch.randn(batch_size, seq_len, 8192, dtype=torch.float16, device=device).requires_grad_() -bmt.synchronize() -if config['rank'] == 0: - print('init input', allocated(device), reserved(device), max_allocated(device)) -enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda() -mask = torch.arange(seq_len).unsqueeze(0) <= torch.arange(seq_len).unsqueeze(1) -mask = mask.unsqueeze(0).unsqueeze(0).to(device) -y = linears(x,mask) -bmt.synchronize() -if config['rank'] == 0: - print('after forward', allocated(device), reserved(device), max_allocated(device)) -y.sum().backward() -bmt.synchronize() -if config['rank'] == 0: - print('after backward', allocated(device), reserved(device), max_allocated(device)) +for i in range(4): + x = torch.randn(batch_size, seq_len, 8192, dtype=torch.float16, device=device).requires_grad_() + bmt.synchronize() + if config['rank'] == 0: + print('init input', allocated(device), reserved(device)) + enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda() + mask = torch.arange(seq_len).unsqueeze(0) <= torch.arange(seq_len).unsqueeze(1) + mask = mask.unsqueeze(0).unsqueeze(0).to(device) +#y = linears(x,mask) + y = x + for encoder in linears: + y = encoder(y, mask) + bmt.synchronize() + if config['rank'] == 0: + print('after forward', allocated(device), reserved(device), max_allocated(device)) + + y.sum().backward() + bmt.synchronize() + if config['rank'] == 0: + print('after backward', 
allocated(device), reserved(device), max_allocated(device)) + optim_manager.step() + if config['rank'] == 0: + print('after optimizer', allocated(device), reserved(device)) +#torch.cuda.empty_cache() + optim_manager.zero_grad() From 1c532d4f4a3194189a6e79d88f679818a61d75d2 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Sat, 12 Aug 2023 11:16:24 +0800 Subject: [PATCH 048/122] test --- example/layers/transformer.py | 47 ++++++++++++++++++++++++++--------- example/models/gpt.py | 14 +++++------ 2 files changed, 42 insertions(+), 19 deletions(-) diff --git a/example/layers/transformer.py b/example/layers/transformer.py index 3390327d..2fb9614e 100644 --- a/example/layers/transformer.py +++ b/example/layers/transformer.py @@ -3,25 +3,17 @@ import bmtrain as bmt from layers import Layernorm, Feedforward, Attention -class TransformerEncoder(bmt.DistributedModule): +class SubBlock(bmt.DistributedModule): def __init__(self, dim_model : int, dim_head : int, num_heads : int, dim_ff : int, bias : bool = True, dtype = None ) -> None: super().__init__() + self.ln_attn = Layernorm(dim_model, dtype=dtype) + self.attn = Attention(dim_model, dim_head, num_heads, bias=bias, dtype=dtype) - self.ln_attn = bmt.CheckpointBlock(Layernorm(dim_model, dtype=dtype), use_checkpoint=False) - self.attn = bmt.CheckpointBlock( - Attention(dim_model, dim_head, num_heads, bias=bias, dtype=dtype), - use_checkpoint=False - ) + self.ln_ff = Layernorm(dim_model, dtype=dtype) - self.ln_ff = bmt.CheckpointBlock(Layernorm(dim_model, dtype=dtype), use_checkpoint=False) - self.ff = bmt.CheckpointBlock( - Feedforward(dim_model, dim_ff, bias=bias, dtype=dtype), - use_checkpoint=True - ) - def forward(self, hidden : torch.Tensor, # (batch, seq_len, dim_model) mask : torch.BoolTensor, # (batch, seq_len, dim_model) @@ -33,6 +25,37 @@ def forward(self, hidden = hidden + x x = self.ln_ff(hidden) +#x = self.ff(x) +#hidden = hidden + x + + return hidden, x + + +class TransformerEncoder(bmt.DistributedModule): + def __init__(self, + dim_model : int, dim_head : int, num_heads : int, dim_ff : int, + bias : bool = True, dtype = None + ) -> None: + super().__init__() + + self.attn = bmt.CheckpointBlock(SubBlock(dim_model, dim_head, num_heads, dim_ff, bias, dtype), use_checkpoint=True) + self.ff = bmt.CheckpointBlock( + Feedforward(dim_model, dim_ff, bias=bias, dtype=dtype), + use_checkpoint=True + ) + + def forward(self, + hidden : torch.Tensor, # (batch, seq_len, dim_model) + mask : torch.BoolTensor, # (batch, seq_len, dim_model) + position_bias : Optional[torch.Tensor] = None, # (batch, num_head, seq_len, seq_len) + ): + bmt.inspect.record_tensor(hidden, "hidden") +# x = self.ln_attn(hidden) +# x = self.attn(x, x, mask, position_bias) +# hidden = hidden + x +# +# x = self.ln_ff(hidden) + hidden, x = self.attn(hidden, mask, position_bias) x = self.ff(x) hidden = hidden + x diff --git a/example/models/gpt.py b/example/models/gpt.py index 78d77a7d..042ae91d 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py @@ -16,11 +16,9 @@ def __init__(self, self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype) self.pos_emb = Embedding(max_distance, dim_model, dtype=dtype) - self.transformers = bmt.TransformerBlockList([ - bmt.CheckpointBlock( - TransformerEncoder( - dim_model, dim_head, num_heads, dim_ff, bias, dtype - ) + self.transformers = torch.nn.ModuleList([ + TransformerEncoder( + dim_model, dim_head, num_heads, dim_ff, bias, dtype ) for _ in range(num_layers) ]) @@ -39,10 +37,12 @@ def forward(self, out = self.pos_emb(pos) + 
self.word_emb(input) # for layer in self.transformers: - out = self.transformers(out, mask_2d, None) +#out = self.transformers(out, mask_2d, None) + for trans in self.transformers: + out = trans(out, mask_2d) out = self.layernorm(out) logits = self.word_emb(out, projection=True) bmt.inspect.record_tensor(logits, "logits") - return logits \ No newline at end of file + return logits From 1e993c6a9590ce982db302efa2eb334201c499b7 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Sat, 12 Aug 2023 14:23:49 +0800 Subject: [PATCH 049/122] refactor code --- bmtrain/block_layer.py | 22 ++++++++++++++-------- bmtrain/pipe_layer.py | 2 -- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index de176da2..60919938 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -242,11 +242,10 @@ def set_pre_module(self, module): def set_next_module(self, module): self._next_module.append(module) module._ref_count += 1 - - def forward(self, *args): + + def pre_hook(self, *args): if self._mode != "PIPE": self.block_context.link_module(self) - grad_tensors = [] grad_index = [] arg_list = list(args) @@ -259,12 +258,9 @@ def forward(self, *args): pre_out = hook_func.PreHookFunc.apply(self, *grad_tensors) for i in range(len(grad_index)): arg_list[grad_index[i]] = pre_out[i] + return arg_list - if self.use_checkpoint: - out = checkpoint(self._module, *arg_list) - else: - out = self._module(*arg_list) - + def post_hook(self, out): tuple_out = (out, ) if isinstance(out, torch.Tensor) else out post_out = hook_func.PostHookFunc.apply(self, *tuple_out) if isinstance(out, torch.Tensor) and isinstance(post_out, tuple): @@ -274,6 +270,16 @@ def forward(self, *args): return tuple(post_out) return post_out + def forward(self, *args): + arg_list = self.pre_hook(*args) + + if self.use_checkpoint: + out = checkpoint(self._module, *arg_list) + else: + out = self._module(*arg_list) + + return self.post_hook(out) + def __getattr__(self,name:str): if name=="_module": return self._module diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index af62aecc..99b60750 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -174,8 +174,6 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: self._modules[str(layer_id)]._is_last_stage = True if self.stage_id == self.stages-1 else False self._modules[str(layer_id)]._is_first_layer = True if i == 0 else False self._modules[str(layer_id)]._is_last_layer = True if i == len(self.layer_ids)-1 else False -#if i > 0: -#self._modules[str(layer_id)].set_pre_module(self._modules[str(layer_id-1)]) self.partition_modules(self.layer_ids) self.next_rank = pipe_group[self.pipe_idx, self.stage_id + 1] if self.stage_id < config['pipe_size'] - 1 else -1 From 24d0f5999052561d114de74d121e07720244f1f4 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Sat, 12 Aug 2023 14:45:45 +0800 Subject: [PATCH 050/122] mv linear to bmt.nn.linear --- bmtrain/__init__.py | 2 +- bmtrain/nn/__init__.py | 1 + {example/layers => bmtrain/nn}/linear.py | 0 example/layers/__init__.py | 3 +-- example/layers/attention.py | 2 +- example/layers/feedforward.py | 2 +- 6 files changed, 5 insertions(+), 5 deletions(-) create mode 100644 bmtrain/nn/__init__.py rename {example/layers => bmtrain/nn}/linear.py (100%) diff --git a/bmtrain/__init__.py b/bmtrain/__init__.py index 35078b7e..3c158f92 100644 --- a/bmtrain/__init__.py +++ b/bmtrain/__init__.py @@ -18,4 +18,4 @@ from . import lr_scheduler from . import loss from . 
import distributed - +from . import nn diff --git a/bmtrain/nn/__init__.py b/bmtrain/nn/__init__.py new file mode 100644 index 00000000..67f9fdee --- /dev/null +++ b/bmtrain/nn/__init__.py @@ -0,0 +1 @@ +from .linear import Linear diff --git a/example/layers/linear.py b/bmtrain/nn/linear.py similarity index 100% rename from example/layers/linear.py rename to bmtrain/nn/linear.py diff --git a/example/layers/__init__.py b/example/layers/__init__.py index 425d0a1b..ef4617c0 100644 --- a/example/layers/__init__.py +++ b/example/layers/__init__.py @@ -1,6 +1,5 @@ -from .linear import Linear from .embedding import Embedding from .feedforward import Feedforward from .layernorm import Layernorm from .attention import Attention -from .transformer import TransformerEncoder \ No newline at end of file +from .transformer import TransformerEncoder diff --git a/example/layers/attention.py b/example/layers/attention.py index 4a0eec11..243df3ea 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -1,7 +1,7 @@ from typing import Optional import torch import bmtrain as bmt -from layers import Linear +from bmtrain.nn import Linear import math class Attention(bmt.DistributedModule): diff --git a/example/layers/feedforward.py b/example/layers/feedforward.py index 3fe935bf..99d2dc3b 100644 --- a/example/layers/feedforward.py +++ b/example/layers/feedforward.py @@ -1,6 +1,6 @@ import torch import bmtrain as bmt -from layers import Linear +from bmtrain.nn import Linear class Feedforward(bmt.DistributedModule): def __init__(self, dim_model : int, dim_ff : int, bias : bool = True, dtype = None) -> None: From ff72e665f922eefe690a489a6ddd9b096af1bb51 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Sat, 12 Aug 2023 16:27:10 +0800 Subject: [PATCH 051/122] for enhance_ckp --- tests/test_middle_hidden.py | 5 ++++- tests/test_requires_grad_multi_gpu.py | 3 ++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_middle_hidden.py b/tests/test_middle_hidden.py index f0d5c559..86b9e552 100644 --- a/tests/test_middle_hidden.py +++ b/tests/test_middle_hidden.py @@ -146,6 +146,7 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len, only_last=False, only_mid ret += bmt.inspect.format_summary( bmt.inspect.inspect_model(m, '*') ) + config['block_context'][config['rank']].clear() if only_middle: logits, hidden_states = m(inp, return_hidden_states=True) loss = sum([ @@ -157,6 +158,7 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len, only_last=False, only_mid ret += bmt.inspect.format_summary( bmt.inspect.inspect_model(m, '*') ) + config['block_context'][config['rank']].clear() if mix_test: logits, hidden_states = m(inp, return_hidden_states=True) loss = sum([ @@ -168,6 +170,7 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len, only_last=False, only_mid ret += bmt.inspect.format_summary( bmt.inspect.inspect_model(m, '*') ) + config['block_context'][config['rank']].clear() return ret + "\n" # replace for matching None grad with zero_grad def run(name, cls, num_layer=4, dim=4096, batch=32, seq_len=256): @@ -209,4 +212,4 @@ def test_main(): if __name__ == "__main__": bmt.init_distributed(pipe_size=4) - test_main() \ No newline at end of file + test_main() diff --git a/tests/test_requires_grad_multi_gpu.py b/tests/test_requires_grad_multi_gpu.py index ebea096e..3fe7ba2e 100644 --- a/tests/test_requires_grad_multi_gpu.py +++ b/tests/test_requires_grad_multi_gpu.py @@ -31,6 +31,7 @@ def forward(self, input): def run(m, a, b): inp = torch.rand((1, 10, 256)).cuda()*100 + 
inp.requires_grad_() logits = m(inp) loss = logits.sum() loss.backward() @@ -93,4 +94,4 @@ def test_main_pipe(): bmt.init_distributed(pipe_size=2) test_main() - test_main_pipe() \ No newline at end of file + test_main_pipe() From 1fbf3b29f28e67725cf6702610e7f4f2a811de24 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 14 Aug 2023 12:20:31 +0800 Subject: [PATCH 052/122] fix for all input not grad --- bmtrain/block_layer.py | 10 +++++++++- bmtrain/checkpointing.py | 2 ++ tests/test_requires_grad.py | 1 - 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 60919938..080633f5 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -231,6 +231,7 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, block_co self.block_context = block_context if block_context is None: self.block_context = config['block_context'][config['rank']] + self.all_input_no_grad = False def set_pre_module(self, module): self._ref_count += 1 @@ -258,6 +259,13 @@ def pre_hook(self, *args): pre_out = hook_func.PreHookFunc.apply(self, *grad_tensors) for i in range(len(grad_index)): arg_list[grad_index[i]] = pre_out[i] + + if len(grad_tensors) == 0: + for param in self.parameters(): + if param.requires_grad: + param.register_hook(lambda grad: hook_func.zero_post_backward(self, grad, None)) + break + self.all_input_no_grad = True return arg_list def post_hook(self, out): @@ -274,7 +282,7 @@ def forward(self, *args): arg_list = self.pre_hook(*args) if self.use_checkpoint: - out = checkpoint(self._module, *arg_list) + out = checkpoint(self._module, *arg_list, use_reentrant=False) else: out = self._module(*arg_list) diff --git a/bmtrain/checkpointing.py b/bmtrain/checkpointing.py index c13f0963..456a4c5c 100644 --- a/bmtrain/checkpointing.py +++ b/bmtrain/checkpointing.py @@ -170,6 +170,8 @@ def exit(self, flag=0, backward=False): param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].storage(), begin, end) if param["parameter"].requires_grad and self.block._storage_params[kw_name].grad is not None: param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end) + if self.block.all_input_no_grad: + param['parameter'].grad.data = param['parameter'].grad.data.view(param['shape']) if flag == 1: for i in self._param_buffer: self.ctx_dict[i] = self._param_buffer[i] diff --git a/tests/test_requires_grad.py b/tests/test_requires_grad.py index 41cabbe0..104b8125 100644 --- a/tests/test_requires_grad.py +++ b/tests/test_requires_grad.py @@ -32,7 +32,6 @@ def forward(self, input, other_bias): def run(m, a, b): inp = torch.rand((1, 10, 256)).cuda()*100 - inp.requires_grad_() bias = torch.rand(256).cuda()*100 logits = m(inp, bias) loss = logits.sum() From ace5216f5f8683f412a1f5c95ec4be1e280c743b Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 14 Aug 2023 13:00:51 +0800 Subject: [PATCH 053/122] fix pre_module --- bmtrain/block_layer.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 080633f5..88c07de1 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -233,15 +233,9 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, block_co self.block_context = config['block_context'][config['rank']] self.all_input_no_grad = False - def set_pre_module(self, module): - self._ref_count += 1 - if module is 
not None: - module._next_module.append(self) - self._is_first_layer = False - module._is_last_layer = False - def set_next_module(self, module): self._next_module.append(module) + module._pre_module.append(self) module._ref_count += 1 def pre_hook(self, *args): From 52cd4e2b0eecce9e3204957f6bac7f99582fe5ee Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 14 Aug 2023 18:46:44 +0800 Subject: [PATCH 054/122] fix pre_module --- bmtrain/hook_func.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index b26c357b..78738d40 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -51,6 +51,8 @@ def zero_post_backward(module, grad_inputs, grad_outputs): module._backward_block_ctx.exit(backward_flag, True) module._ref_count = -1 config['load_stream'].record_event(config['load_event']) + if not module._is_first_layer and len(module._pre_module) > 0: + module._pre_module.pop() else: if module._micro_idx == 0: module._ref_count = -1 if module._is_first_layer else 0 From 0b0bd0b013e14b31c78abfe8913a41a6786f16cb Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 14 Aug 2023 19:00:47 +0800 Subject: [PATCH 055/122] fix for all input no grad --- bmtrain/block_layer.py | 10 ++++++---- bmtrain/checkpointing.py | 12 ++++++++---- bmtrain/pipe_layer.py | 1 + tests/test_requires_grad_multi_gpu.py | 2 +- 4 files changed, 16 insertions(+), 9 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 88c07de1..dc5748a9 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -254,12 +254,14 @@ def pre_hook(self, *args): for i in range(len(grad_index)): arg_list[grad_index[i]] = pre_out[i] - if len(grad_tensors) == 0: - for param in self.parameters(): - if param.requires_grad: - param.register_hook(lambda grad: hook_func.zero_post_backward(self, grad, None)) + if self._mode != "PIPE" and len(grad_tensors) == 0: + for param in self._param_info: + if param['parameter'].requires_grad: + param['parameter'].register_hook(lambda grad: hook_func.zero_post_backward(self, grad, None)) break self.all_input_no_grad = True + else: + self.all_input_no_grad = False return arg_list def post_hook(self, out): diff --git a/bmtrain/checkpointing.py b/bmtrain/checkpointing.py index 456a4c5c..3ea50e03 100644 --- a/bmtrain/checkpointing.py +++ b/bmtrain/checkpointing.py @@ -163,15 +163,18 @@ def exit(self, flag=0, backward=False): device = self.block._storage_params[kw_name].device if "begin" not in param: param["parameter"].data = torch.tensor([], dtype=dtype, device=device) - param["parameter"].grad = None + if not param['parameter'].requires_grad: + param["parameter"].grad = None continue begin = param["begin"] end = param["end"] param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].storage(), begin, end) if param["parameter"].requires_grad and self.block._storage_params[kw_name].grad is not None: - param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end) - if self.block.all_input_no_grad: - param['parameter'].grad.data = param['parameter'].grad.data.view(param['shape']) + if config['world_size'] > 1 and not self.block.all_input_no_grad: + param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end) + if config['world_size'] == 1 and self.block.all_input_no_grad: + param["parameter"].grad = torch.tensor([], dtype=dtype, 
device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end) + param["parameter"].grad.data = param['parameter'].grad.data.view(param['shape']) if flag == 1: for i in self._param_buffer: self.ctx_dict[i] = self._param_buffer[i] @@ -179,6 +182,7 @@ def exit(self, flag=0, backward=False): self._param_tensor = {} self._grad_buffer = {} self._param_buffer = {} + def __exit__(self, exc_type, exc_val, exc_tb): # reduce scatter gradients diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 99b60750..01b7160f 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -196,6 +196,7 @@ def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=Fa batch_size = hidden_state.shape[0] num_micros = config["micros"] args = args + (batch_related, ) + hidden_state.requires_grad_() hidden_state_list, args_list = PipePreFunction.apply(hidden_state, *args) hidden_state_list = hidden_state_list.flatten(0, 1).chunk(num_micros, dim=0) diff --git a/tests/test_requires_grad_multi_gpu.py b/tests/test_requires_grad_multi_gpu.py index 3fe7ba2e..eb5de5e4 100644 --- a/tests/test_requires_grad_multi_gpu.py +++ b/tests/test_requires_grad_multi_gpu.py @@ -31,10 +31,10 @@ def forward(self, input): def run(m, a, b): inp = torch.rand((1, 10, 256)).cuda()*100 - inp.requires_grad_() logits = m(inp) loss = logits.sum() loss.backward() + config['block_context'][config['rank']].clear() sm = bmt.inspect.format_summary( bmt.inspect.inspect_model(m, '*') From 05b49f80c9bd24fd5ca3374d605e330f958aca7f Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 14 Aug 2023 19:04:14 +0800 Subject: [PATCH 056/122] fix for all input no grad --- bmtrain/checkpointing.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/bmtrain/checkpointing.py b/bmtrain/checkpointing.py index 3ea50e03..66249cd3 100644 --- a/bmtrain/checkpointing.py +++ b/bmtrain/checkpointing.py @@ -163,8 +163,7 @@ def exit(self, flag=0, backward=False): device = self.block._storage_params[kw_name].device if "begin" not in param: param["parameter"].data = torch.tensor([], dtype=dtype, device=device) - if not param['parameter'].requires_grad: - param["parameter"].grad = None + param["parameter"].grad = None continue begin = param["begin"] end = param["end"] From 98d5b32d1872878e26e0ebc01057af7440a94289 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 15 Aug 2023 09:25:59 +0800 Subject: [PATCH 057/122] activation offloading --- bmtrain/block_layer.py | 18 +++--- bmtrain/hook_func.py | 112 ++++++++++++++++++++++++++++++++- bmtrain/init.py | 1 + example/layers/attention.py | 1 + example/layers/flash_triton.py | 6 +- example/layers/linear.py | 6 +- example/layers/transformer.py | 41 ++---------- example/models/gpt.py | 17 ++--- example/test_block.py | 4 +- example/train.py | 18 +++--- 10 files changed, 153 insertions(+), 71 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index de176da2..4712e884 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -94,7 +94,7 @@ class CheckpointBlock(torch.nn.Module): >>> y2, ... 
= transformer_block(x) >>> assert torch.allclose(y1, y2) """ - def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, block_context=None): + def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, block_context=None, use_offload=False): super().__init__() self._module = inner_module self._inputs = None @@ -107,9 +107,9 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, block_co self._storage_params : Dict[str, torch.nn.Parameter] = {} self._storage_info = {} self._ready = False - # sort parameters by name + # sort parameters by nam_next_modulee ordered_parameters = list(self._module.named_parameters()) - + assert not (use_checkpoint and use_offload) # calc total number of parameters for name, param in ordered_parameters: if not isinstance(param, DistributedParameter): @@ -226,22 +226,20 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, block_co self._next_module = [] self._ref_count = 0 self._mode = "BLOCK" #BLOCK or ZERO or PIPE + if use_offload: + self._mode = "OFFLOAD" + self._on_device = False self.return_hidden_states = False self.hidden_states = [] self.block_context = block_context if block_context is None: self.block_context = config['block_context'][config['rank']] - def set_pre_module(self, module): - self._ref_count += 1 - if module is not None: - module._next_module.append(self) - self._is_first_layer = False - module._is_last_layer = False def set_next_module(self, module): self._next_module.append(module) module._ref_count += 1 + module._pre_module.append(self) def forward(self, *args): if self._mode != "PIPE": @@ -531,10 +529,8 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) if not isinstance(module, CheckpointBlock): module = CheckpointBlock(module) - module._mode = "ZERO" module._is_last_layer = True if i == len(modules) -1 else False module._is_first_layer = True if i == 0 else False - self._modules[str(i)] = module self.add_module(str(i), module) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index b26c357b..65fbb739 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -2,10 +2,87 @@ from .global_var import config from .checkpointing import CheckpointBlockContext from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations +from collections import deque + +def wrapper(module_name,act_cuda_dict): + def fn(m,inps,out): + inp = inps[0] + act_cuda_dict[module_name] = {} + act_cuda_dict[module_name]['shape'] = tuple(inp.shape) + act_cuda_dict[module_name]['numel'] = inp.numel() + act_cuda_dict[module_name]['inp'] = inp + act_cuda_dict[module_name]['dtype'] = inp.dtype + return fn + +def nearest_offload_module(module): + if module._mode == "OFFLOAD": + return [module] + queue = deque([(module, 0)]) # 使用队列来进行广度优先搜索 + nearest_modules = [] + nearest_depth = float('inf') + + while queue: + curr_module, curr_depth = queue.popleft() + + if curr_depth > nearest_depth: + break + + for m in curr_module._pre_module: + if m._mode == "OFFLOAD": + if curr_depth < nearest_depth: + nearest_modules = [m] + nearest_depth = curr_depth + elif curr_depth == nearest_depth: + nearest_modules.append(m) + else: + queue.append((m, curr_depth + 1)) + + return nearest_modules + +def make_cpu_storage(_act_cuda_dict, _offload_dict): + fp16_total = sum([v['numel'] for v in _act_cuda_dict.values() if v['dtype'] == torch.float16]) + fp32_total = sum([v['numel'] for v in _act_cuda_dict.values() if v['dtype'] == torch.float32]) + 
fp16_storage = torch.HalfStorage(fp16_total).pin_memory() + fp32_storage = torch.FloatStorage(fp32_total).pin_memory() + fp16_offset = 0 + fp32_offset = 0 + for key,val in _act_cuda_dict.items(): + if val['dtype'] == torch.float16: + _offload_dict[key] = {} + _offload_dict[key]['inp'] = torch.tensor([], dtype=torch.float16, device="cpu") \ + .set_(fp16_storage, fp16_offset, val['shape']) + + fp16_offset += _act_cuda_dict[key]['numel'] + elif val['dtype'] == torch.float32: + _offload_dict[key]['inp'] = torch.tensor([], dtype=torch.float32, device="cpu") \ + .set_(fp32_storage, fp32_offset, val['shape']) + + fp32_offset += _act_cuda_dict[key]['numel'] +def d2h_memcpy(_act_cuda_dict, _offload_dict): + for key,val in _act_cuda_dict.items(): + shape, inp = val['shape'],val['inp'] + cpu_inp = _offload_dict[key]['inp'] + _offload_dict[key]['inp'] = cpu_inp.copy_(inp, non_blocking=True) + +def h2d_memcpy(_act_cuda_dict, _offload_dict): + for key,val in _act_cuda_dict.items(): + shape, cuda_inp = val['shape'],val['inp'] + cpu_inp = _offload_dict[key]['inp'] + cuda_stor = cuda_inp.storage_type()(val['numel']) + cuda_inp.set_(cuda_stor, 0, shape) + cuda_inp.copy_(cpu_inp, non_blocking=True) def zero_pre_forward(module, inputs): enter = True pipe = False + + if module._mode == "OFFLOAD": + act_dict = {} + for name, sub_module in module.named_modules(): + if sub_module.__class__.__name__ == "Linear": + fn = wrapper(name, act_dict) + sub_module.register_forward_hook(fn) + module._act_cuda_dict = act_dict if module._mode == "PIPE": enter = module._micro_idx == 0 pipe = True @@ -22,13 +99,35 @@ def zero_post_forward(module, inputs, outputs): exit = True if module._mode == "PIPE": exit = module._micro_idx == config['micros'] - 1 - + elif module._mode == "OFFLOAD": + current_stream = torch.cuda.current_stream() + current_stream.wait_stream(config["calc_stream"]) + if not hasattr(module, "_offload_dict"): + module._offload_dict = {} + make_cpu_storage(module._act_cuda_dict, module._offload_dict) + with torch.cuda.stream(config["offload_stream"]): + d2h_memcpy(module._act_cuda_dict, module._offload_dict) + current_stream = torch.cuda.current_stream() + current_stream.wait_stream(config["offload_stream"]) + for key,val in module._act_cuda_dict.items(): + module._act_cuda_dict[key]['inp'].data = torch.tensor([],dtype=val['inp'].dtype,device=val['inp'].device) if exit: module._forward_block_ctx.exit(forward_flag) def zero_pre_backward(module, grad_outputs): backward_flag = 2 if config['zero_level'] == 2 else 0 if module._mode != "PIPE": + if module._mode != "OFFLOAD": + count = len([m for m in module._pre_module if m._mode=="OFFLOAD"]) + if module._is_last_layer or module._next_module[0]._mode == "OFFLOAD": + for pre_module in nearest_offload_module(module): + if pre_module._mode == "OFFLOAD": + pre_module._on_device = True + with torch.cuda.stream(config["offload_stream"]): + h2d_memcpy(pre_module._act_cuda_dict, pre_module._offload_dict) + else: + current_stream = torch.cuda.current_stream() + current_stream.wait_stream(config["offload_stream"]) module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict) module._backward_block_ctx.enter(backward_flag, True) if not module._is_last_layer and len(module._next_module) > 0 and module._next_module[-1]._backward_block_ctx is not None: @@ -38,7 +137,6 @@ def zero_pre_backward(module, grad_outputs): config['load_stream'].record_event(config['load_event']) else: module._next_module[-1]._ref_count -= 1 - else: if module._micro_idx == config['micros'] 
- 1: module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict, pipe=True) @@ -47,6 +145,16 @@ def zero_pre_backward(module, grad_outputs): def zero_post_backward(module, grad_inputs, grad_outputs): backward_flag = 2 if config['zero_level'] == 2 else 0 if module._mode != "PIPE": + if not module._is_first_layer and len(module._pre_module) > 0: + module._pre_module.pop() + if module._mode == "OFFLOAD": + module._on_device = False + current_stream = torch.cuda.current_stream() + current_stream.wait_stream(config["calc_stream"]) + with torch.cuda.stream(config["offload_stream"]): + for key,val in module._act_cuda_dict.items(): + inp = val['inp'] + inp.data = torch.tensor([], dtype=inp.dtype, device=inp.device) if module._is_first_layer and module._ref_count == 1: module._backward_block_ctx.exit(backward_flag, True) module._ref_count = -1 diff --git a/bmtrain/init.py b/bmtrain/init.py index 5772f963..40acbb93 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -69,6 +69,7 @@ def init_distributed( config["world_size"] = world_size config["calc_stream"] = torch.cuda.current_stream() config["load_stream"] = torch.cuda.Stream(priority=-1) + config["offload_stream"] = torch.cuda.Stream() config['barrier_stream'] = torch.cuda.Stream() config["load_event"] = torch.cuda.Event() config["zero_level"] = zero_level diff --git a/example/layers/attention.py b/example/layers/attention.py index 062798aa..32c36578 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -84,6 +84,7 @@ def forward( mask = attention_mask mask_bias = torch.zeros_like(attention_mask, device="cuda", dtype=torch.float16) # 创建与mask形状相同的全零张量 mask_bias[mask == False] -= torch.inf + mask_bias = mask_bias.unsqueeze(1) score = FlashAttnFunc.apply(h_q, h_k, h_v, mask_bias, False, None) score = score.view(batch_size, len_q, self.num_heads * self.dim_head) diff --git a/example/layers/flash_triton.py b/example/layers/flash_triton.py index 1d687378..bba3ecee 100644 --- a/example/layers/flash_triton.py +++ b/example/layers/flash_triton.py @@ -808,13 +808,15 @@ def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None): o, lse, ctx.softmax_scale = _flash_attn_forward( q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale ) - ctx.save_for_backward(q, k, v, o, lse, bias) + ctx.output = {"output":o} + ctx.save_for_backward(q, k, v, lse, bias) ctx.causal = causal return o @staticmethod def backward(ctx, do): - q, k, v, o, lse, bias = ctx.saved_tensors + q, k, v, lse, bias = ctx.saved_tensors + o = ctx.output["output"] assert not ctx.needs_input_grad[3], 'FlashAttention does not support bias gradient yet' # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. 
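A rough sketch of the copy pattern behind offload_stream, d2h_memcpy / h2d_memcpy and the record_stream calls in this patch: activations are copied into pinned host memory on a side stream so the compute stream is not blocked, and copied back (with an explicit stream wait) before backward needs them. The helper names and buffers below are illustrative, not the BMTrain implementation.

import torch

offload_stream = torch.cuda.Stream()

def offload_to_cpu(act: torch.Tensor) -> torch.Tensor:
    # Async device-to-host copy of `act` into pinned memory on the offload stream.
    cpu_buf = torch.empty(act.shape, dtype=act.dtype, device="cpu", pin_memory=True)
    offload_stream.wait_stream(torch.cuda.current_stream())   # `act` must be fully produced
    with torch.cuda.stream(offload_stream):
        act.record_stream(offload_stream)                     # keep its memory alive for the copy
        cpu_buf.copy_(act, non_blocking=True)
    return cpu_buf

def reload_to_gpu(cpu_buf: torch.Tensor) -> torch.Tensor:
    # Async host-to-device copy; the caller's stream waits before using the result.
    with torch.cuda.stream(offload_stream):
        gpu_act = cpu_buf.to("cuda", non_blocking=True)
    torch.cuda.current_stream().wait_stream(offload_stream)
    gpu_act.record_stream(torch.cuda.current_stream())        # it was allocated on the side stream
    return gpu_act

Pinned (page-locked) host buffers are what make the non_blocking copies genuinely asynchronous; copying into ordinary pageable memory falls back to a blocking transfer, which is why make_cpu_storage pins its fp16/fp32 storages.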
diff --git a/example/layers/linear.py b/example/layers/linear.py index faf0770e..0afd2f40 100644 --- a/example/layers/linear.py +++ b/example/layers/linear.py @@ -5,12 +5,14 @@ class CustomLinear(torch.autograd.Function): @staticmethod def forward(ctx, x, weight, bias=None): - ctx.save_for_backward(x, weight, bias) + ctx.input = {"input":x} + ctx.save_for_backward( weight, bias) return F.linear(x, weight, bias) @staticmethod def backward(ctx, grad_output): - x, weight, bias = ctx.saved_tensors + x = ctx.input["input"] + weight, bias = ctx.saved_tensors grad_x = grad_weight = grad_bias = None if x.requires_grad: grad_x = grad_output.matmul(weight) diff --git a/example/layers/transformer.py b/example/layers/transformer.py index 2fb9614e..1e1cffe0 100644 --- a/example/layers/transformer.py +++ b/example/layers/transformer.py @@ -3,17 +3,19 @@ import bmtrain as bmt from layers import Layernorm, Feedforward, Attention -class SubBlock(bmt.DistributedModule): +class TransformerEncoder(bmt.DistributedModule): def __init__(self, dim_model : int, dim_head : int, num_heads : int, dim_ff : int, bias : bool = True, dtype = None ) -> None: super().__init__() + self.ln_attn = Layernorm(dim_model, dtype=dtype) self.attn = Attention(dim_model, dim_head, num_heads, bias=bias, dtype=dtype) - self.ln_ff = Layernorm(dim_model, dtype=dtype) - + self.ln_ff = Layernorm(dim_model, dtype=dtype) + self.ff = Feedforward(dim_model, dim_ff, bias=bias, dtype=dtype) + def forward(self, hidden : torch.Tensor, # (batch, seq_len, dim_model) mask : torch.BoolTensor, # (batch, seq_len, dim_model) @@ -21,41 +23,10 @@ def forward(self, ): bmt.inspect.record_tensor(hidden, "hidden") x = self.ln_attn(hidden) - x = self.attn(x, x, mask, position_bias) + x = self.attn(x, x, mask) hidden = hidden + x x = self.ln_ff(hidden) -#x = self.ff(x) -#hidden = hidden + x - - return hidden, x - - -class TransformerEncoder(bmt.DistributedModule): - def __init__(self, - dim_model : int, dim_head : int, num_heads : int, dim_ff : int, - bias : bool = True, dtype = None - ) -> None: - super().__init__() - - self.attn = bmt.CheckpointBlock(SubBlock(dim_model, dim_head, num_heads, dim_ff, bias, dtype), use_checkpoint=True) - self.ff = bmt.CheckpointBlock( - Feedforward(dim_model, dim_ff, bias=bias, dtype=dtype), - use_checkpoint=True - ) - - def forward(self, - hidden : torch.Tensor, # (batch, seq_len, dim_model) - mask : torch.BoolTensor, # (batch, seq_len, dim_model) - position_bias : Optional[torch.Tensor] = None, # (batch, num_head, seq_len, seq_len) - ): - bmt.inspect.record_tensor(hidden, "hidden") -# x = self.ln_attn(hidden) -# x = self.attn(x, x, mask, position_bias) -# hidden = hidden + x -# -# x = self.ln_ff(hidden) - hidden, x = self.attn(hidden, mask, position_bias) x = self.ff(x) hidden = hidden + x diff --git a/example/models/gpt.py b/example/models/gpt.py index 042ae91d..89b32fc3 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py @@ -15,10 +15,13 @@ def __init__(self, self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype) self.pos_emb = Embedding(max_distance, dim_model, dtype=dtype) - - self.transformers = torch.nn.ModuleList([ - TransformerEncoder( - dim_model, dim_head, num_heads, dim_ff, bias, dtype + ckpt_mask = [True for i in range(num_layers)] + offload_mask = [False for i in range(num_layers)] + self.transformers = bmt.TransformerBlockList([ + bmt.CheckpointBlock( + TransformerEncoder( + dim_model, dim_head, num_heads, dim_ff, bias, dtype + ),use_checkpoint=True, ) for _ in range(num_layers) ]) @@ -37,12 
+40,10 @@ def forward(self, out = self.pos_emb(pos) + self.word_emb(input) # for layer in self.transformers: -#out = self.transformers(out, mask_2d, None) - for trans in self.transformers: - out = trans(out, mask_2d) + out = self.transformers(out, mask_2d, None) out = self.layernorm(out) logits = self.word_emb(out, projection=True) bmt.inspect.record_tensor(logits, "logits") - return logits + return logits \ No newline at end of file diff --git a/example/test_block.py b/example/test_block.py index 6493e1d7..90f57182 100644 --- a/example/test_block.py +++ b/example/test_block.py @@ -27,8 +27,8 @@ def max_allocated(device): ) ) -#linears = bmt.TransformerBlockList(linears) -linears = torch.nn.ModuleList(linears) +linears = bmt.TransformerBlockList(linears) +# linears = torch.nn.ModuleList(linears) optimizer = bmt.optim.AdamOffloadOptimizer(linears.parameters(), weight_decay=1e-2) lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0) diff --git a/example/train.py b/example/train.py index 7bc92400..8b071445 100644 --- a/example/train.py +++ b/example/train.py @@ -6,18 +6,18 @@ def main(): bmt.init_distributed( seed=0, - zero_level=2, + zero_level=3, ) model = GPT( - num_layers=8, - vocab_size=10240, - dim_model=2560, - dim_head=80, + num_layers=48, + vocab_size=80000, + dim_model=4096, + dim_head=128, num_heads=32, - dim_ff=8192, + dim_ff=10240, max_distance=1024, - bias=True, + bias=False, dtype=torch.half ) @@ -51,7 +51,7 @@ def main(): break loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) - optimizer = bmt.optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2) + optimizer = bmt.optim.AdamOptimizer(model.parameters(), weight_decay=1e-2) lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0) optim_manager = bmt.optim.OptimManager(loss_scale=2**20) @@ -62,7 +62,7 @@ def main(): avg_time_recorder = bmt.utils.AverageRecorder() avg_loss_recorder = bmt.utils.AverageRecorder() - for iteration in range(1000): + for iteration in range(30): # load data st = time.time() From c16127a274d0ef75a32d398fe6b0305c4f755610 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Wed, 16 Aug 2023 10:53:06 +0800 Subject: [PATCH 058/122] offload new version --- bmtrain/block_layer.py | 2 ++ bmtrain/hook_func.py | 67 +++++++++++++++++++++++++++++++----------- 2 files changed, 52 insertions(+), 17 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 4712e884..35704110 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -19,6 +19,8 @@ import inspect from torch.utils.checkpoint import checkpoint + + def storage_type_cuda(storage_type): STORAGE_MAP = { torch.FloatStorage: torch.cuda.FloatStorage, diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 65fbb739..4bee2ab4 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -2,18 +2,35 @@ from .global_var import config from .checkpointing import CheckpointBlockContext from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations -from collections import deque +from collections import deque,OrderedDict +class Offload_Dict: + def __init__(self): + self._offload_dict = OrderedDict() + self.offset = 0 + + def add(self, tensor): + self._offload_dict[id(tensor)] = {} + self._offload_dict[id(tensor)]["offset"] = self.offset + self._offload_dict[id(tensor)]["numel"] = tensor.numel() + self._offload_dict[module_name]['dtype'] = inp.dtype + + def 
make_cpu_storage(self, _cuda_dict): + with torch.cuda.stream(config["offload_stream"]): + def wrapper(module_name,act_cuda_dict): def fn(m,inps,out): + if module_name not in act_cuda_dict: + act_cuda_dict[module_name] = m._inp_dict inp = inps[0] - act_cuda_dict[module_name] = {} act_cuda_dict[module_name]['shape'] = tuple(inp.shape) act_cuda_dict[module_name]['numel'] = inp.numel() act_cuda_dict[module_name]['inp'] = inp - act_cuda_dict[module_name]['dtype'] = inp.dtype + act_cuda_dict[module_name]['dtype'] = inp.dtype + m.riginp_dict = act_cuda_dict[module_name] return fn - + + def nearest_offload_module(module): if module._mode == "OFFLOAD": return [module] @@ -28,7 +45,7 @@ def nearest_offload_module(module): break for m in curr_module._pre_module: - if m._mode == "OFFLOAD": + if m._mode == "OFFLOAD" and not m._on_device: if curr_depth < nearest_depth: nearest_modules = [m] nearest_depth = curr_depth @@ -40,6 +57,10 @@ def nearest_offload_module(module): return nearest_modules def make_cpu_storage(_act_cuda_dict, _offload_dict): + for key,val in _act_cuda_dict.items(): + if "dtype" not in val: + print(key) + print(val) fp16_total = sum([v['numel'] for v in _act_cuda_dict.values() if v['dtype'] == torch.float16]) fp32_total = sum([v['numel'] for v in _act_cuda_dict.values() if v['dtype'] == torch.float32]) fp16_storage = torch.HalfStorage(fp16_total).pin_memory() @@ -69,20 +90,27 @@ def h2d_memcpy(_act_cuda_dict, _offload_dict): shape, cuda_inp = val['shape'],val['inp'] cpu_inp = _offload_dict[key]['inp'] cuda_stor = cuda_inp.storage_type()(val['numel']) + cuda_inp.record_stream(config["offload_stream"]) cuda_inp.set_(cuda_stor, 0, shape) cuda_inp.copy_(cpu_inp, non_blocking=True) +def pack_hook(tensor): + _offload_tensor(id_tensor) = {} def zero_pre_forward(module, inputs): enter = True pipe = False if module._mode == "OFFLOAD": - act_dict = {} - for name, sub_module in module.named_modules(): - if sub_module.__class__.__name__ == "Linear": - fn = wrapper(name, act_dict) - sub_module.register_forward_hook(fn) - module._act_cuda_dict = act_dict + if not hasattr(module,"_act_cuda_dict"): + torch._C._autograd._push_saved_tensors_default_hooks( + pack_hook, unpack_hook + ) + module._act_cuda_dict = {} + for name, sub_module in module.named_modules(): + if sub_module.__class__.__name__ == "Linear": + sub_module.offload = True + fn = wrapper(name, module._act_cuda_dict) + sub_module.register_forward_hook(fn) if module._mode == "PIPE": enter = module._micro_idx == 0 pipe = True @@ -100,17 +128,20 @@ def zero_post_forward(module, inputs, outputs): if module._mode == "PIPE": exit = module._micro_idx == config['micros'] - 1 elif module._mode == "OFFLOAD": + torch._C._autograd._pop_saved_tensors_default_hooks() current_stream = torch.cuda.current_stream() current_stream.wait_stream(config["calc_stream"]) - if not hasattr(module, "_offload_dict"): - module._offload_dict = {} - make_cpu_storage(module._act_cuda_dict, module._offload_dict) with torch.cuda.stream(config["offload_stream"]): - d2h_memcpy(module._act_cuda_dict, module._offload_dict) + if not hasattr(module, "_offload_dict"): + module._offload_dict = {} + make_cpu_storage(module._act_cuda_dict, module._offload_dict) + d2h_memcpy(module._act_cuda_dict, module._offload_dict) current_stream = torch.cuda.current_stream() current_stream.wait_stream(config["offload_stream"]) + cuda_stor = torch.UntypedStorage(1).cuda() for key,val in module._act_cuda_dict.items(): - module._act_cuda_dict[key]['inp'].data = 
torch.tensor([],dtype=val['inp'].dtype,device=val['inp'].device) + module._act_cuda_dict[key]['inp'].set_(cuda_stor, 0, (1,)) + if exit: module._forward_block_ctx.exit(forward_flag) @@ -151,10 +182,12 @@ def zero_post_backward(module, grad_inputs, grad_outputs): module._on_device = False current_stream = torch.cuda.current_stream() current_stream.wait_stream(config["calc_stream"]) + cuda_stor = torch.UntypedStorage(1).cuda() with torch.cuda.stream(config["offload_stream"]): for key,val in module._act_cuda_dict.items(): inp = val['inp'] - inp.data = torch.tensor([], dtype=inp.dtype, device=inp.device) + inp.record_stream(config["offload_stream"]) + inp.set_(cuda_stor, 0, (1,)) if module._is_first_layer and module._ref_count == 1: module._backward_block_ctx.exit(backward_flag, True) module._ref_count = -1 From 4861ec85fddeb1a06e3761b3d7389addc20334bd Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Wed, 16 Aug 2023 13:30:24 +0800 Subject: [PATCH 059/122] save_for_backward hook --- bmtrain/hook_func.py | 170 +++++++++++++++++++++---------------------- 1 file changed, 85 insertions(+), 85 deletions(-) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 4bee2ab4..fa895b17 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -7,29 +7,62 @@ class Offload_Dict: def __init__(self): self._offload_dict = OrderedDict() - self.offset = 0 def add(self, tensor): - self._offload_dict[id(tensor)] = {} - self._offload_dict[id(tensor)]["offset"] = self.offset - self._offload_dict[id(tensor)]["numel"] = tensor.numel() - self._offload_dict[module_name]['dtype'] = inp.dtype + tensor_id = id(tensor) + self._offload_dict[tensor_id] = {} + self._offload_dict[tensor_id]["numel"] = tensor.numel() + self._offload_dict[tensor_id]['dtype'] = tensor.dtype + self._offload_dict[tensor_id]["tensor"] = tensor + self._offload_dict[tensor_id]["shape"] = tensor.shape + self._device = "cuda" + return tensor_id - def make_cpu_storage(self, _cuda_dict): - with torch.cuda.stream(config["offload_stream"]): - -def wrapper(module_name,act_cuda_dict): - def fn(m,inps,out): - if module_name not in act_cuda_dict: - act_cuda_dict[module_name] = m._inp_dict - inp = inps[0] - act_cuda_dict[module_name]['shape'] = tuple(inp.shape) - act_cuda_dict[module_name]['numel'] = inp.numel() - act_cuda_dict[module_name]['inp'] = inp - act_cuda_dict[module_name]['dtype'] = inp.dtype - m.riginp_dict = act_cuda_dict[module_name] - return fn + def make_cpu_storage(self): + fp16_total = sum([v['numel'] for v in self._offload_dict.values() if v['dtype'] == torch.float16]) + fp32_total = sum([v['numel'] for v in self._offload_dict.values() if v['dtype'] == torch.float32]) + fp16_storage = torch.HalfStorage(fp16_total).pin_memory() + fp32_storage = torch.FloatStorage(fp32_total).pin_memory() + self.fp16_storage = fp16_storage + self.fp32_storage = fp32_storage + self.fp16_total = fp16_total + self.fp32_total = fp32_total + + def get(self, key): + return self._offload_dict[key]["tensor"] + + def pop_all(self): + self._offload_dict = OrderedDict() + + def h2d_memcpy(self): + for key,val in self._offload_dict.items(): + self._offload_dict[key]['tensor'] = self._offload_dict[key]['tensor'].cuda(non_blocking=True) + def record_stream(self, stream): + for key, val in self._offload_dict.items(): + self._offload_dict[key]['tensor'].record_stream(stream) + + def d2h_memcpy(self): + fp16_offset = 0 + fp32_offset = 0 + fp16_total = sum([v['numel'] for v in self._offload_dict.values() if v['dtype'] == torch.float16]) + fp32_total 
= sum([v['numel'] for v in self._offload_dict.values() if v['dtype'] == torch.float32]) + assert fp16_total <= self.fp16_total + assert fp32_total <= self.fp32_total + fp16_storage = self.fp16_storage + fp32_storage = self.fp32_storage + for key,val in self._offload_dict.items(): + assert val['dtype'] in [torch.float16, torch.float32] + storage = fp16_storage if val['dtype'] == torch.float16 else fp32_storage + offset = fp16_offset if val['dtype'] == torch.float16 else fp32_offset + cpu_tensor = torch.tensor([], dtype=val['dtype'], device="cpu") \ + .set_(storage, offset, val['shape']) + self._offload_dict[key]['tensor'].record_stream(config['offload_stream']) + self._offload_dict[key]['tensor'] = cpu_tensor.copy_(self._offload_dict[key]['tensor'], non_blocking=True) + if val['dtype'] == torch.float16: + fp16_offset += self._offload_dict[key]['numel'] + else: + fp32_offset += self._offload_dict[key]['numel'] def nearest_offload_module(module): if module._mode == "OFFLOAD": @@ -56,61 +89,38 @@ def nearest_offload_module(module): return nearest_modules -def make_cpu_storage(_act_cuda_dict, _offload_dict): - for key,val in _act_cuda_dict.items(): - if "dtype" not in val: - print(key) - print(val) - fp16_total = sum([v['numel'] for v in _act_cuda_dict.values() if v['dtype'] == torch.float16]) - fp32_total = sum([v['numel'] for v in _act_cuda_dict.values() if v['dtype'] == torch.float32]) - fp16_storage = torch.HalfStorage(fp16_total).pin_memory() - fp32_storage = torch.FloatStorage(fp32_total).pin_memory() - fp16_offset = 0 - fp32_offset = 0 - for key,val in _act_cuda_dict.items(): - if val['dtype'] == torch.float16: - _offload_dict[key] = {} - _offload_dict[key]['inp'] = torch.tensor([], dtype=torch.float16, device="cpu") \ - .set_(fp16_storage, fp16_offset, val['shape']) - fp16_offset += _act_cuda_dict[key]['numel'] - elif val['dtype'] == torch.float32: - _offload_dict[key]['inp'] = torch.tensor([], dtype=torch.float32, device="cpu") \ - .set_(fp32_storage, fp32_offset, val['shape']) - - fp32_offset += _act_cuda_dict[key]['numel'] -def d2h_memcpy(_act_cuda_dict, _offload_dict): - for key,val in _act_cuda_dict.items(): - shape, inp = val['shape'],val['inp'] - cpu_inp = _offload_dict[key]['inp'] - _offload_dict[key]['inp'] = cpu_inp.copy_(inp, non_blocking=True) - -def h2d_memcpy(_act_cuda_dict, _offload_dict): - for key,val in _act_cuda_dict.items(): - shape, cuda_inp = val['shape'],val['inp'] - cpu_inp = _offload_dict[key]['inp'] - cuda_stor = cuda_inp.storage_type()(val['numel']) - cuda_inp.record_stream(config["offload_stream"]) - cuda_inp.set_(cuda_stor, 0, shape) - cuda_inp.copy_(cpu_inp, non_blocking=True) +def offload_wrapper(offload_dict): + def pack_hook(tensor): + if isinstance(tensor, torch.nn.Parameter): + return (tensor,) + else: + key = offload_dict.add(tensor) + return (tensor.device, key) + def unpack_hook(packed): + if len(packed) == 2: + device, key = packed + tensor = offload_dict.get(key) + assert tensor.device == device + return tensor + else: + tensor, = packed + return tensor + return pack_hook, unpack_hook -def pack_hook(tensor): - _offload_tensor(id_tensor) = {} def zero_pre_forward(module, inputs): enter = True pipe = False - if module._mode == "OFFLOAD": - if not hasattr(module,"_act_cuda_dict"): - torch._C._autograd._push_saved_tensors_default_hooks( - pack_hook, unpack_hook - ) - module._act_cuda_dict = {} - for name, sub_module in module.named_modules(): - if sub_module.__class__.__name__ == "Linear": - sub_module.offload = True - fn = wrapper(name, 
module._act_cuda_dict) - sub_module.register_forward_hook(fn) + module._offload_dict = Offload_Dict() + pack_hook, unpack_hook = offload_wrapper(module._offload_dict) + for n, m in module.named_modules(): + if m.__class__.__name__ == "Linear": + m._offload_hook = (pack_hook, unpack_hook) + # torch._C._autograd._push_saved_tensors_default_hooks( + # pack_hook, unpack_hook + # ) + if module._mode == "PIPE": enter = module._micro_idx == 0 pipe = True @@ -128,20 +138,13 @@ def zero_post_forward(module, inputs, outputs): if module._mode == "PIPE": exit = module._micro_idx == config['micros'] - 1 elif module._mode == "OFFLOAD": - torch._C._autograd._pop_saved_tensors_default_hooks() + # torch._C._autograd._pop_saved_tensors_default_hooks() current_stream = torch.cuda.current_stream() current_stream.wait_stream(config["calc_stream"]) with torch.cuda.stream(config["offload_stream"]): - if not hasattr(module, "_offload_dict"): - module._offload_dict = {} - make_cpu_storage(module._act_cuda_dict, module._offload_dict) - d2h_memcpy(module._act_cuda_dict, module._offload_dict) - current_stream = torch.cuda.current_stream() - current_stream.wait_stream(config["offload_stream"]) - cuda_stor = torch.UntypedStorage(1).cuda() - for key,val in module._act_cuda_dict.items(): - module._act_cuda_dict[key]['inp'].set_(cuda_stor, 0, (1,)) - + if not hasattr(module._offload_dict, "fp16_storage"): + module._offload_dict.make_cpu_storage() + module._offload_dict.d2h_memcpy() if exit: module._forward_block_ctx.exit(forward_flag) @@ -155,10 +158,11 @@ def zero_pre_backward(module, grad_outputs): if pre_module._mode == "OFFLOAD": pre_module._on_device = True with torch.cuda.stream(config["offload_stream"]): - h2d_memcpy(pre_module._act_cuda_dict, pre_module._offload_dict) + pre_module._offload_dict.h2d_memcpy() else: current_stream = torch.cuda.current_stream() current_stream.wait_stream(config["offload_stream"]) + module._offload_dict.record_stream(config["calc_stream"]) module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict) module._backward_block_ctx.enter(backward_flag, True) if not module._is_last_layer and len(module._next_module) > 0 and module._next_module[-1]._backward_block_ctx is not None: @@ -182,12 +186,8 @@ def zero_post_backward(module, grad_inputs, grad_outputs): module._on_device = False current_stream = torch.cuda.current_stream() current_stream.wait_stream(config["calc_stream"]) - cuda_stor = torch.UntypedStorage(1).cuda() with torch.cuda.stream(config["offload_stream"]): - for key,val in module._act_cuda_dict.items(): - inp = val['inp'] - inp.record_stream(config["offload_stream"]) - inp.set_(cuda_stor, 0, (1,)) + module._offload_dict.pop_all() if module._is_first_layer and module._ref_count == 1: module._backward_block_ctx.exit(backward_flag, True) module._ref_count = -1 From fc819713edd7a9df4078610619f23911d243a7b6 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Thu, 17 Aug 2023 18:36:54 +0800 Subject: [PATCH 060/122] offloading bug fix --- bmtrain/hook_func.py | 26 ++++++-------- example/layers/attention.py | 8 ++++- example/layers/flash_triton.py | 6 ++-- example/layers/linear.py | 17 +++++---- example/layers/transformer.py | 4 +-- example/models/gpt.py | 16 +++++---- example/train.py | 65 +++++++++++++++++----------------- 7 files changed, 75 insertions(+), 67 deletions(-) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index fa895b17..10e7d50b 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -32,7 +32,7 @@ def get(self, key): 
return self._offload_dict[key]["tensor"] def pop_all(self): - self._offload_dict = OrderedDict() + self._offload_dict.clear() def h2d_memcpy(self): for key,val in self._offload_dict.items(): @@ -65,8 +65,6 @@ def d2h_memcpy(self): fp32_offset += self._offload_dict[key]['numel'] def nearest_offload_module(module): - if module._mode == "OFFLOAD": - return [module] queue = deque([(module, 0)]) # 使用队列来进行广度优先搜索 nearest_modules = [] nearest_depth = float('inf') @@ -112,7 +110,8 @@ def zero_pre_forward(module, inputs): enter = True pipe = False if module._mode == "OFFLOAD": - module._offload_dict = Offload_Dict() + if not hasattr(module, "_offload_dict"): + module._offload_dict = Offload_Dict() pack_hook, unpack_hook = offload_wrapper(module._offload_dict) for n, m in module.named_modules(): if m.__class__.__name__ == "Linear": @@ -137,14 +136,14 @@ def zero_post_forward(module, inputs, outputs): exit = True if module._mode == "PIPE": exit = module._micro_idx == config['micros'] - 1 - elif module._mode == "OFFLOAD": + elif module._mode != "OFFLOAD" and ((not module._is_first_layer) and module._pre_module[0]._mode == "OFFLOAD"): + for pre_module in module._pre_module: + if pre_module._mode == "OFFLOAD": # torch._C._autograd._pop_saved_tensors_default_hooks() - current_stream = torch.cuda.current_stream() - current_stream.wait_stream(config["calc_stream"]) - with torch.cuda.stream(config["offload_stream"]): - if not hasattr(module._offload_dict, "fp16_storage"): - module._offload_dict.make_cpu_storage() - module._offload_dict.d2h_memcpy() + with torch.cuda.stream(config["offload_stream"]): + if not hasattr(pre_module._offload_dict, "fp16_storage"): + pre_module._offload_dict.make_cpu_storage() + pre_module._offload_dict.d2h_memcpy() if exit: module._forward_block_ctx.exit(forward_flag) @@ -184,10 +183,7 @@ def zero_post_backward(module, grad_inputs, grad_outputs): module._pre_module.pop() if module._mode == "OFFLOAD": module._on_device = False - current_stream = torch.cuda.current_stream() - current_stream.wait_stream(config["calc_stream"]) - with torch.cuda.stream(config["offload_stream"]): - module._offload_dict.pop_all() + module._offload_dict.pop_all() if module._is_first_layer and module._ref_count == 1: module._backward_block_ctx.exit(backward_flag, True) module._ref_count = -1 diff --git a/example/layers/attention.py b/example/layers/attention.py index 32c36578..1435e499 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -85,8 +85,14 @@ def forward( mask_bias = torch.zeros_like(attention_mask, device="cuda", dtype=torch.float16) # 创建与mask形状相同的全零张量 mask_bias[mask == False] -= torch.inf mask_bias = mask_bias.unsqueeze(1) + # if hasattr(self, "_offload_hook"): + # pack, unpack = self._offload_hook + # torch._C._autograd._push_saved_tensors_default_hooks( + # pack, unpack + # ) score = FlashAttnFunc.apply(h_q, h_k, h_v, mask_bias, False, None) - + # if hasattr(self, "_offload_hook"): + # torch._C._autograd._pop_saved_tensors_default_hooks() score = score.view(batch_size, len_q, self.num_heads * self.dim_head) diff --git a/example/layers/flash_triton.py b/example/layers/flash_triton.py index bba3ecee..1d687378 100644 --- a/example/layers/flash_triton.py +++ b/example/layers/flash_triton.py @@ -808,15 +808,13 @@ def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None): o, lse, ctx.softmax_scale = _flash_attn_forward( q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale ) - ctx.output = {"output":o} - ctx.save_for_backward(q, k, v, lse, bias) + 
ctx.save_for_backward(q, k, v, o, lse, bias) ctx.causal = causal return o @staticmethod def backward(ctx, do): - q, k, v, lse, bias = ctx.saved_tensors - o = ctx.output["output"] + q, k, v, o, lse, bias = ctx.saved_tensors assert not ctx.needs_input_grad[3], 'FlashAttention does not support bias gradient yet' # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. diff --git a/example/layers/linear.py b/example/layers/linear.py index 0afd2f40..a4d5381c 100644 --- a/example/layers/linear.py +++ b/example/layers/linear.py @@ -5,14 +5,12 @@ class CustomLinear(torch.autograd.Function): @staticmethod def forward(ctx, x, weight, bias=None): - ctx.input = {"input":x} - ctx.save_for_backward( weight, bias) + ctx.save_for_backward(x, weight, bias) return F.linear(x, weight, bias) @staticmethod def backward(ctx, grad_output): - x = ctx.input["input"] - weight, bias = ctx.saved_tensors + x, weight, bias = ctx.saved_tensors grad_x = grad_weight = grad_bias = None if x.requires_grad: grad_x = grad_output.matmul(weight) @@ -37,8 +35,15 @@ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtyp self.register_parameter('bias', None) def forward(self, input): - return CustomLinear.apply(input, self.weight, self.bias) - + if hasattr(self, "_offload_hook"): + pack, unpack = self._offload_hook + torch._C._autograd._push_saved_tensors_default_hooks( + pack, unpack + ) + res = CustomLinear.apply(input, self.weight, self.bias) + if hasattr(self, "_offload_hook"): + torch._C._autograd._pop_saved_tensors_default_hooks() + return res def extra_repr(self) -> str: return 'in_features={}, out_features={}, bias={}'.format( self.in_features, self.out_features, self.bias is not None diff --git a/example/layers/transformer.py b/example/layers/transformer.py index 1e1cffe0..4b867e0a 100644 --- a/example/layers/transformer.py +++ b/example/layers/transformer.py @@ -20,8 +20,8 @@ def forward(self, hidden : torch.Tensor, # (batch, seq_len, dim_model) mask : torch.BoolTensor, # (batch, seq_len, dim_model) position_bias : Optional[torch.Tensor] = None, # (batch, num_head, seq_len, seq_len) - ): - bmt.inspect.record_tensor(hidden, "hidden") + ): + # bmt.inspect.record_tensor(hidden, "hidden") x = self.ln_attn(hidden) x = self.attn(x, x, mask) hidden = hidden + x diff --git a/example/models/gpt.py b/example/models/gpt.py index 89b32fc3..2778d549 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py @@ -7,7 +7,7 @@ def __init__(self, num_layers : int, vocab_size : int, dim_model : int, dim_head : int, num_heads : int, dim_ff : int, max_distance : int, - bias : bool = True, dtype = None + bias : bool = True, dtype = None, offload = False, ) -> None: super().__init__() @@ -15,15 +15,19 @@ def __init__(self, self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype) self.pos_emb = Embedding(max_distance, dim_model, dtype=dtype) - ckpt_mask = [True for i in range(num_layers)] - offload_mask = [False for i in range(num_layers)] + if offload: + offload_mask = [True if i%4 == 0 else False for i in range(num_layers)] + ckpt_mask = [not offload_mask[i] for i in range(num_layers)] + else: + ckpt_mask = [ True for i in range(num_layers) ] + offload_mask = [ False for i in range(num_layers) ] self.transformers = bmt.TransformerBlockList([ bmt.CheckpointBlock( TransformerEncoder( dim_model, dim_head, num_heads, dim_ff, bias, dtype - ),use_checkpoint=True, + 
),use_checkpoint=ckpt_mask[i],use_offload=offload_mask[i] ) - for _ in range(num_layers) + for i in range(num_layers) ]) self.layernorm = Layernorm(dim_model, dtype=dtype) @@ -44,6 +48,6 @@ def forward(self, out = self.layernorm(out) logits = self.word_emb(out, projection=True) - bmt.inspect.record_tensor(logits, "logits") + # bmt.inspect.record_tensor(logits, "logits") return logits \ No newline at end of file diff --git a/example/train.py b/example/train.py index 8b071445..f6c32104 100644 --- a/example/train.py +++ b/example/train.py @@ -8,17 +8,19 @@ def main(): seed=0, zero_level=3, ) - + offload = False + seq_len = True model = GPT( num_layers=48, - vocab_size=80000, + vocab_size=80000, dim_model=4096, dim_head=128, num_heads=32, dim_ff=10240, - max_distance=1024, + max_distance=seq_len, bias=False, - dtype=torch.half + dtype=torch.half, + offload=offload ) bmt.init_parameters(model) @@ -31,10 +33,7 @@ def main(): # data # generate dummy data for each rank torch.manual_seed(1234) - - batch_size = 2 - seq_len = 512 - + batch_size = 4 for i in range(bmt.world_size()): sent = torch.randint(0, 10240, (batch_size, seq_len + 1)) enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda() @@ -50,7 +49,7 @@ def main(): if i == bmt.rank(): break - loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=100) optimizer = bmt.optim.AdamOptimizer(model.parameters(), weight_decay=1e-2) lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0) @@ -66,36 +65,36 @@ def main(): # load data st = time.time() - with bmt.inspect.inspect_tensor() as inspector: - pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) - logits = model( - enc_input, - pos, - pos < enc_length[:, None] - ) - batch, seq_len, vocab_out_size = logits.size() + # with bmt.inspect.inspect_tensor() as inspector: + pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) + logits = model( + enc_input, + pos, + pos < enc_length[:, None] + ) + batch, seq_len, vocab_out_size = logits.size() - loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) - - global_loss = bmt.sum_loss(loss).item() + loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) + + global_loss = bmt.sum_loss(loss).item() - optim_manager.zero_grad() + optim_manager.zero_grad() - optim_manager.backward(loss) + optim_manager.backward(loss) # print inspected tensors in the forward & backward pass # print parameters of the model - if iteration % 100 == 0: - bmt.print_rank( - bmt.inspect.format_summary( - inspector.get_summary() - ) - ) - bmt.print_rank( - bmt.inspect.format_summary( - bmt.inspect.inspect_model(model, "*") - ) - ) + # if iteration % 100 == 0: + # bmt.print_rank( + # bmt.inspect.format_summary( + # inspector.get_summary() + # ) + # ) + # bmt.print_rank( + # bmt.inspect.format_summary( + # bmt.inspect.inspect_model(model, "*") + # ) + # ) optim_manager.step() From 88b5bd3c384e0ed7d2704505b18087d48254df46 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 17 Aug 2023 20:31:11 +0800 Subject: [PATCH 061/122] fix reentrant --- bmtrain/block_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index dc5748a9..693908ef 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -278,7 +278,7 @@ def forward(self, *args): arg_list = self.pre_hook(*args) 
if self.use_checkpoint: - out = checkpoint(self._module, *arg_list, use_reentrant=False) + out = checkpoint(self._module, *arg_list, use_reentrant=not self.all_input_no_grad) else: out = self._module(*arg_list) From fd4931170e9bff4f9282073a78edef932adc5ff7 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Sun, 20 Aug 2023 14:40:19 +0800 Subject: [PATCH 062/122] refactor CheckpointBlock --- bmtrain/block_layer.py | 76 +++++++++++++++++++----------------------- bmtrain/hook_func.py | 21 ++++-------- bmtrain/init.py | 4 --- 3 files changed, 40 insertions(+), 61 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 693908ef..d920f4f9 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -7,7 +7,6 @@ from .synchronize import wait_loader from .parameter import DistributedParameter, OpAllGather from .checkpointing import ( - ScopedTensorInspectorContext, CheckpointBlockContext ) @@ -50,32 +49,6 @@ def _get_param_kw(param : DistributedParameter): group_name = "_g_" + param.group return type_name + grad_name + group_name -class BMTBlockContext: - def __init__(self): - self._pre_module = None - self._first = True - - def link_module(self, module): - if not self._first and module._ref_count == -1: - self._pre_module = module - module._ref_count = 1 - return - - if self._pre_module is None: - module._ref_count = 1 - module._is_first_layer = True - else: - if module._ref_count == 0: - module._is_first_layer = False - self._pre_module.set_next_module(module) - self._pre_module._is_last_layer = False - self._pre_module = module - self._first = False - - def clear(self): - self._pre_module = None - self._first = True - class CheckpointBlock(torch.nn.Module): """ A bmtrain block containing two memory-saving methods of ZeRO-2/3 and checkpoint. @@ -94,7 +67,7 @@ class CheckpointBlock(torch.nn.Module): >>> y2, ... 
= transformer_block(x) >>> assert torch.allclose(y1, y2) """ - def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, block_context=None): + def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True): super().__init__() self._module = inner_module self._inputs = None @@ -222,25 +195,35 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, block_co self.use_checkpoint = use_checkpoint self._is_first_layer = True self._is_last_layer = True - self._pre_module = [] - self._next_module = [] - self._ref_count = 0 + self._release_list = [True] + self._next_module = [] #save the next module of self + self._pre_module = [] #save the pre module of self + self._ref_count = 0 #incremental in forward and decreasing in backward self._mode = "BLOCK" #BLOCK or ZERO or PIPE self.return_hidden_states = False self.hidden_states = [] - self.block_context = block_context - if block_context is None: - self.block_context = config['block_context'][config['rank']] self.all_input_no_grad = False + self.all_param_no_grad = False - def set_next_module(self, module): - self._next_module.append(module) - module._pre_module.append(self) - module._ref_count += 1 + def set_pre_module(self, pre_module): + if pre_module is not None: + self._pre_module.append(pre_module) + pre_module._next_module.append(self) + + def pre_module(self): + return self._pre_module[self._ref_count-1] + + def next_module(self): + assert len(self._next_module) == self._ref_count, "{} != {}".format(len(self._next_module), self._ref_count) + return self._next_module[self._ref_count-1] + + def backward_release(self, flag): + if self._ref_count == 1: + self._backward_block_ctx.exit(flag, True) + config['load_stream'].record_event(config['load_event']) + self._ref_count -= 1 def pre_hook(self, *args): - if self._mode != "PIPE": - self.block_context.link_module(self) grad_tensors = [] grad_index = [] arg_list = list(args) @@ -255,9 +238,11 @@ def pre_hook(self, *args): arg_list[grad_index[i]] = pre_out[i] if self._mode != "PIPE" and len(grad_tensors) == 0: + self.all_param_no_grad = True for param in self._param_info: if param['parameter'].requires_grad: param['parameter'].register_hook(lambda grad: hook_func.zero_post_backward(self, grad, None)) + self.all_param_no_grad = False break self.all_input_no_grad = True else: @@ -537,16 +522,23 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) super().__init__() self._modules = {} + release_list = [] + pre_module = None for i, module in enumerate(modules): if not isinstance(module, CheckpointBlock): module = CheckpointBlock(module) module._mode = "ZERO" - module._is_last_layer = True if i == len(modules) -1 else False - module._is_first_layer = True if i == 0 else False + module.set_pre_module(pre_module) + pre_module = module + self._is_first_layer = False + self._is_last_layer = False self._modules[str(i)] = module self.add_module(str(i), module) + + self._modules[str(0)]._is_first_layer = True + self._modules[str(len(modules)-1)]._is_last_layer = True self.num_hidden = num_hidden diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 78738d40..0af5239d 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -25,20 +25,16 @@ def zero_post_forward(module, inputs, outputs): if exit: module._forward_block_ctx.exit(forward_flag) + if module._mode != "PIPE": + module._ref_count += 1 def zero_pre_backward(module, grad_outputs): backward_flag = 2 if config['zero_level'] == 2 else 0 if module._mode != "PIPE": 
module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict) module._backward_block_ctx.enter(backward_flag, True) - if not module._is_last_layer and len(module._next_module) > 0 and module._next_module[-1]._backward_block_ctx is not None: - if module._next_module[-1]._ref_count == 1: - module._next_module[-1]._ref_count = 0 - module._next_module.pop()._backward_block_ctx.exit(backward_flag, True) - config['load_stream'].record_event(config['load_event']) - else: - module._next_module[-1]._ref_count -= 1 - + if not module._is_last_layer: + module.next_module().backward_release(backward_flag) else: if module._micro_idx == config['micros'] - 1: module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict, pipe=True) @@ -47,15 +43,10 @@ def zero_pre_backward(module, grad_outputs): def zero_post_backward(module, grad_inputs, grad_outputs): backward_flag = 2 if config['zero_level'] == 2 else 0 if module._mode != "PIPE": - if module._is_first_layer and module._ref_count == 1: - module._backward_block_ctx.exit(backward_flag, True) - module._ref_count = -1 - config['load_stream'].record_event(config['load_event']) - if not module._is_first_layer and len(module._pre_module) > 0: - module._pre_module.pop() + if module._is_first_layer: + module.backward_release(backward_flag) else: if module._micro_idx == 0: - module._ref_count = -1 if module._is_first_layer else 0 module._backward_block_ctx.exit(backward_flag, True) config['load_stream'].record_event(config['load_event']) diff --git a/bmtrain/init.py b/bmtrain/init.py index 5772f963..bc48b78e 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -7,7 +7,6 @@ from .global_var import config from . import nccl from .synchronize import synchronize -from .block_layer import BMTBlockContext def init_distributed( init_method : str = "env://", @@ -74,9 +73,6 @@ def init_distributed( config["zero_level"] = zero_level config["topology"] = topology(config) config["zero_rank"] = config["topology"].get_group_rank("zero") if pipe_size > 1 else config['rank'] - config["block_context"] = [] - for i in range(world_size): - config["block_context"].append(BMTBlockContext()) cpus_this_worker = None all_available_cpus = sorted(list(os.sched_getaffinity(0))) From 221bdc3687f9904f05ce77f8de7377c5834efd4d Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Sun, 20 Aug 2023 15:15:35 +0800 Subject: [PATCH 063/122] refactor pipe --- bmtrain/block_layer.py | 4 +-- bmtrain/hook_func.py | 55 ++----------------------------------- bmtrain/pipe_layer.py | 62 +++++++++++++++++++++++++++++++++++++----- 3 files changed, 59 insertions(+), 62 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index d920f4f9..8abf2859 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -577,7 +577,7 @@ def forward(self, *args, return_hidden_states = False): hidden_states = [] for i in range(len(self)): if return_hidden_states: - self._modules[str(i)].return_hidden_states = return_hidden_states + return_hidden_states.append(outputs) self._modules[str(i)].hidden_states = hidden_states outputs = self._modules[str(i)]._call_impl(*args) if not isinstance(outputs, tuple): @@ -591,6 +591,6 @@ def forward(self, *args, return_hidden_states = False): ] if return_hidden_states: - return outputs + tuple(hidden_states[:self.num_hidden]) + return outputs + tuple(hidden_states) else: return tuple(outputs[:self.num_hidden]) if self.num_hidden > 1 else outputs[0] diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 0af5239d..7e6211a3 
100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -25,7 +25,6 @@ def zero_post_forward(module, inputs, outputs): if exit: module._forward_block_ctx.exit(forward_flag) - if module._mode != "PIPE": module._ref_count += 1 def zero_pre_backward(module, grad_outputs): @@ -47,63 +46,20 @@ def zero_post_backward(module, grad_inputs, grad_outputs): module.backward_release(backward_flag) else: if module._micro_idx == 0: - module._backward_block_ctx.exit(backward_flag, True) - config['load_stream'].record_event(config['load_event']) - -class PipePreFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, inputs, stage_id): - pre_inputs = recv_activations(stage_id - 1, config['pipe_comm']) - pre_inputs.requires_grad_() - return pre_inputs - - @staticmethod - def backward(ctx, grads): - return grads, None - -def pipe_pre_forward(module, inputs): - if not module._is_first_stage: - if module._is_first_layer: - return (PipePreFunction.apply(inputs[0], module.stage_id), ) + inputs[1:] - -def pipe_post_forward(module, inputs, outputs): - if not module._is_last_stage: - if module._is_last_layer: - send_data = outputs[0] if isinstance(outputs, tuple) else outputs - send_activations(send_data.detach(), module.stage_id + 1, config['pipe_comm']) - -def pipe_pre_backward(module, grad_inputs): - if not module._is_last_stage: - if module._is_last_layer: - pre_grad_inputs = recv_activations(module.stage_id + 1, config['pipe_comm']) - return (pre_grad_inputs, ) + grad_inputs[1:] - -def pipe_post_backward(module, grad_inputs, grad_outputs): - if not module._is_first_stage: - if module._is_first_layer: - send_data = grad_inputs[0] if isinstance(grad_inputs, tuple) else grad_inputs - send_activations(send_data, module.stage_id - 1, config['pipe_comm']) + module.backward_release(backward_flag) + module._micro_idx -= 1 - module._micro_idx -= 1 class PreHookFunc(torch.autograd.Function): @staticmethod def forward(ctx, module, *x): ctx.module = module - if module._mode == "PIPE": - pipe_out = pipe_pre_forward(module, x) - x = pipe_out if pipe_out is not None else x - - if module.return_hidden_states: - module.hidden_states.append(x[0]) zero_pre_forward(module, x) return x @staticmethod def backward(ctx, *grads): zero_post_backward(ctx.module, grads, None) - if ctx.module._mode == "PIPE": - pipe_post_backward(ctx.module, grads, None) return None, *grads class PostHookFunc(torch.autograd.Function): @@ -111,16 +67,9 @@ class PostHookFunc(torch.autograd.Function): def forward(ctx, module, *out): ctx.module = module zero_post_forward(module, None, out) - if module._mode == "PIPE": - pipe_post_forward(module, None, out) return out @staticmethod def backward(ctx, *grads): zero_pre_backward(ctx.module, grads) - if ctx.module._mode == "PIPE": - pipe_grads = pipe_pre_backward(ctx.module, grads) - grads = pipe_grads[0] if pipe_grads is not None else grads - if not isinstance(grads, tuple): - return None, grads return None, *grads diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 01b7160f..913b727e 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -124,6 +124,44 @@ def backward(ctx, grads, grad_middle=None): else: return grad_list +class StagePreFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, stage_id): + ctx.stage_id = stage_id + ctx.is_first_stage = stage_id == 0 + ctx.is_last_stage = stage_id == config['pipe_size'] - 1 + if not ctx.is_first_stage: + input = recv_activations(stage_id - 1, config['pipe_comm']) + input.requires_grad_() + return 
input + return input + + @staticmethod + def backward(ctx, grad_outputs): + if not ctx.is_first_stage: + send_data = grad_outputs[0] if isinstance(grad_outputs, tuple) else grad_outputs + send_activations(send_data, ctx.stage_id - 1, config['pipe_comm']) + return grad_outputs, None + +class StagePostFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, outputs, stage_id): + ctx.stage_id = stage_id + ctx.is_first_stage = stage_id == 0 + ctx.is_last_stage = stage_id == config['pipe_size'] - 1 + if not ctx.is_last_stage: + send_data = outputs[0] if isinstance(outputs, tuple) else outputs + send_activations(send_data.detach(), stage_id + 1, config['pipe_comm']) + return outputs + + @staticmethod + def backward(ctx, grad_outputs): + if not ctx.is_last_stage: + pre_grad_inputs = recv_activations(ctx.stage_id + 1, config['pipe_comm']) + return pre_grad_inputs, None + return grad_outputs, None + + class PipelineTransformerBlockList(torch.nn.Module): r""" TransformerBlockList is a list of CheckpointBlocks. @@ -168,12 +206,18 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: self.layer_ids = self.get_range_by_stage_id(self.stage_id) + pre_module = None for i,layer_id in enumerate(self.layer_ids): - self._modules[str(layer_id)].layer_id = layer_id - self._modules[str(layer_id)]._is_first_stage = True if self.stage_id == 0 else False - self._modules[str(layer_id)]._is_last_stage = True if self.stage_id == self.stages-1 else False - self._modules[str(layer_id)]._is_first_layer = True if i == 0 else False - self._modules[str(layer_id)]._is_last_layer = True if i == len(self.layer_ids)-1 else False + module = self._modules[str(layer_id)] + module.set_pre_module(pre_module) + pre_module = module + + module._is_first_stage = True if self.stage_id == 0 else False + module._is_last_stage = True if self.stage_id == self.stages-1 else False + module._is_first_layer = False + module._is_last_layer = False + self._modules[str(self.layer_ids[0])]._is_first_layer = True + self._modules[str(self.layer_ids[-1])]._is_last_layer = True self.partition_modules(self.layer_ids) self.next_rank = pipe_group[self.pipe_idx, self.stage_id + 1] if self.stage_id < config['pipe_size'] - 1 else -1 @@ -205,12 +249,16 @@ def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=Fa for micro_idx, (hidden_state, arg) in enumerate(zip(hidden_state_list, args_list)): micro_hidden_states = [] + + hidden_state = StagePreFunction.apply(hidden_state, self.stage_id) + for idx,layer_id in enumerate(self.layer_ids): self._modules[str(layer_id)]._micro_idx = micro_idx if return_hidden_states: - self._modules[str(layer_id)].return_hidden_states = return_hidden_states - self._modules[str(layer_id)].hidden_states = micro_hidden_states + micro_hidden_states.append(hidden_state) hidden_state = self._modules[str(layer_id)](hidden_state, *arg) + hidden_state = StagePostFunction.apply(hidden_state, self.stage_id) + outputs.append(hidden_state) if return_hidden_states: hidden_states.append(torch.stack(micro_hidden_states, dim=0)) From 9c634070a88de6b00da1b9402de8daac3ba4bb53 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Sun, 20 Aug 2023 18:59:34 +0800 Subject: [PATCH 064/122] fix all input no grad --- bmtrain/__init__.py | 11 ++++---- bmtrain/block_layer.py | 13 +++++----- bmtrain/checkpointing.py | 8 ++---- bmtrain/hook_func.py | 36 +++++++++++++++++++++++++++ tests/test_middle_hidden.py | 4 --- tests/test_requires_grad_multi_gpu.py | 2 -- tests/test_training.py | 2 -- 7 files 
changed, 50 insertions(+), 26 deletions(-) diff --git a/bmtrain/__init__.py b/bmtrain/__init__.py index ad2a17da..7c7d6c2c 100644 --- a/bmtrain/__init__.py +++ b/bmtrain/__init__.py @@ -1,14 +1,10 @@ -from .utils import print_block, print_dict, print_rank, see_memory, load_nccl_pypi -try: - from . import nccl -except: - load_nccl_pypi() from .global_var import config, world_size, rank from .init import init_distributed from .parameter import DistributedParameter, ParameterInitializer from .layer import DistributedModule from .param_init import init_parameters, grouped_parameters +from .utils import print_block, print_dict, print_rank, see_memory from .synchronize import synchronize, sum_loss, wait_loader, gather_result from .block_layer import CheckpointBlock, TransformerBlockList from .wrapper import BMTrainModelWrapper @@ -16,6 +12,9 @@ from . import debug from .store import save, load +from . import benchmark +from . import optim +from . import inspect +from . import lr_scheduler from . import loss from . import distributed -from . import nn diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 8abf2859..da8984eb 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -211,6 +211,7 @@ def set_pre_module(self, pre_module): pre_module._next_module.append(self) def pre_module(self): + assert len(self._next_module) == self._ref_count, "{} != {}".format(len(self._next_module), self._ref_count) return self._pre_module[self._ref_count-1] def next_module(self): @@ -241,7 +242,6 @@ def pre_hook(self, *args): self.all_param_no_grad = True for param in self._param_info: if param['parameter'].requires_grad: - param['parameter'].register_hook(lambda grad: hook_func.zero_post_backward(self, grad, None)) self.all_param_no_grad = False break self.all_input_no_grad = True @@ -254,14 +254,16 @@ def post_hook(self, out): post_out = hook_func.PostHookFunc.apply(self, *tuple_out) if isinstance(out, torch.Tensor) and isinstance(post_out, tuple): return post_out[0] - - if isinstance(post_out, list): - return tuple(post_out) + post_out = tuple(post_out) return post_out def forward(self, *args): arg_list = self.pre_hook(*args) + if self.all_input_no_grad: + placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) + return hook_func.OneStepNoGradFunc.apply(self, placeholder, *arg_list) + if self.use_checkpoint: out = checkpoint(self._module, *arg_list, use_reentrant=not self.all_input_no_grad) else: @@ -577,8 +579,7 @@ def forward(self, *args, return_hidden_states = False): hidden_states = [] for i in range(len(self)): if return_hidden_states: - return_hidden_states.append(outputs) - self._modules[str(i)].hidden_states = hidden_states + hidden_states.append(outputs) outputs = self._modules[str(i)]._call_impl(*args) if not isinstance(outputs, tuple): outputs = (outputs, ) diff --git a/bmtrain/checkpointing.py b/bmtrain/checkpointing.py index 66249cd3..008bab5f 100644 --- a/bmtrain/checkpointing.py +++ b/bmtrain/checkpointing.py @@ -4,7 +4,7 @@ from . import debug from . 
import nccl from .global_var import config -from .synchronize import wait_loader +from .synchronize import wait_loader, synchronize class ScopedDebugTensorList: def __init__(self) -> None: @@ -169,11 +169,7 @@ def exit(self, flag=0, backward=False): end = param["end"] param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].storage(), begin, end) if param["parameter"].requires_grad and self.block._storage_params[kw_name].grad is not None: - if config['world_size'] > 1 and not self.block.all_input_no_grad: - param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end) - if config['world_size'] == 1 and self.block.all_input_no_grad: - param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end) - param["parameter"].grad.data = param['parameter'].grad.data.view(param['shape']) + param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end) if flag == 1: for i in self._param_buffer: self.ctx_dict[i] = self._param_buffer[i] diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 7e6211a3..a1bb4cf8 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -14,11 +14,15 @@ def zero_pre_forward(module, inputs): forward_flag = 1 if zero_level == 2 else 0 if zero_level == 2 and module._ref_count > 1: forward_flag = 2 # repeating forward in same layer + if module.all_param_no_grad: #only forward + forward_flag = 0 module._forward_block_ctx = CheckpointBlockContext(module, module._layer_dict, pipe=pipe) module._forward_block_ctx.enter(forward_flag) def zero_post_forward(module, inputs, outputs): forward_flag = 1 if config['zero_level'] == 2 else 0 + if module.all_param_no_grad: + forward_flag = 0 exit = True if module._mode == "PIPE": exit = module._micro_idx == config['micros'] - 1 @@ -49,6 +53,38 @@ def zero_post_backward(module, grad_inputs, grad_outputs): module.backward_release(backward_flag) module._micro_idx -= 1 +class OneStepNoGradFunc(torch.autograd.Function): + """ + requires_grad = False for all inputs + """ + @staticmethod + def forward(ctx, module, placeholder, *x): + ctx.x = x + ctx.module = module + ctx.rng_state = torch.cuda.get_rng_state() + + with torch.no_grad(): + out = module._module(*x) + zero_post_forward(module, None, out) + if not isinstance(out, torch.Tensor): + return tuple(out) + return out + + @staticmethod + def backward(ctx, grads): + zero_pre_backward(ctx.module, grads) + with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True): + torch.cuda.set_rng_state(ctx.rng_state) + x = ctx.x + with torch.enable_grad(): + out = ctx.module._module(*x) + torch.autograd.backward(out, grads) + zero_post_backward(ctx.module, grads, None) + grads = [] + for _ in x: + grads.append(None) + return None, None, *grads + class PreHookFunc(torch.autograd.Function): @staticmethod diff --git a/tests/test_middle_hidden.py b/tests/test_middle_hidden.py index 86b9e552..cd50351e 100644 --- a/tests/test_middle_hidden.py +++ b/tests/test_middle_hidden.py @@ -3,7 +3,6 @@ import bmtrain as bmt import random import torch -from bmtrain import config from bmtrain.block_layer import CheckpointBlock, TransformerBlockList from bmtrain.pipe_layer import PipelineTransformerBlockList import torch.nn.functional as F @@ -146,7 +145,6 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len, 
only_last=False, only_mid ret += bmt.inspect.format_summary( bmt.inspect.inspect_model(m, '*') ) - config['block_context'][config['rank']].clear() if only_middle: logits, hidden_states = m(inp, return_hidden_states=True) loss = sum([ @@ -158,7 +156,6 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len, only_last=False, only_mid ret += bmt.inspect.format_summary( bmt.inspect.inspect_model(m, '*') ) - config['block_context'][config['rank']].clear() if mix_test: logits, hidden_states = m(inp, return_hidden_states=True) loss = sum([ @@ -170,7 +167,6 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len, only_last=False, only_mid ret += bmt.inspect.format_summary( bmt.inspect.inspect_model(m, '*') ) - config['block_context'][config['rank']].clear() return ret + "\n" # replace for matching None grad with zero_grad def run(name, cls, num_layer=4, dim=4096, batch=32, seq_len=256): diff --git a/tests/test_requires_grad_multi_gpu.py b/tests/test_requires_grad_multi_gpu.py index eb5de5e4..2b2ed95e 100644 --- a/tests/test_requires_grad_multi_gpu.py +++ b/tests/test_requires_grad_multi_gpu.py @@ -2,7 +2,6 @@ import bmtrain as bmt import torch -from bmtrain import config from bmtrain.block_layer import CheckpointBlockContext, CheckpointBlock, TransformerBlockList from bmtrain.pipe_layer import PipelineTransformerBlockList from typing import List @@ -34,7 +33,6 @@ def run(m, a, b): logits = m(inp) loss = logits.sum() loss.backward() - config['block_context'][config['rank']].clear() sm = bmt.inspect.format_summary( bmt.inspect.inspect_model(m, '*') diff --git a/tests/test_training.py b/tests/test_training.py index b58701d4..4fd8f934 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -6,7 +6,6 @@ import math import torch.nn.functional as F import bmtrain as bmt -from bmtrain.global_var import config import os class Attention(torch.nn.Module): @@ -396,7 +395,6 @@ def pipe_model(): def add_to_check_list(m, l, o): key, value = train((m, models[m]), (l, loss_funcs[l]), (o, optimizers[o])) ret[key] = value - config['block_context'][config['rank']].clear() if test_fp16: kwargs["dtype"] = torch.half From f72fcfcfd906d127eec887b3079bc41fb0c67300 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Sun, 20 Aug 2023 20:31:50 +0800 Subject: [PATCH 065/122] fix hiddenstate --- bmtrain/block_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index da8984eb..b75d6b92 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -579,7 +579,7 @@ def forward(self, *args, return_hidden_states = False): hidden_states = [] for i in range(len(self)): if return_hidden_states: - hidden_states.append(outputs) + hidden_states.append(args[0]) outputs = self._modules[str(i)]._call_impl(*args) if not isinstance(outputs, tuple): outputs = (outputs, ) From ebdf519d20be84ab431d645ca1fe990019d704ec Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 21 Aug 2023 10:25:21 +0800 Subject: [PATCH 066/122] fix test --- tests/test_middle_hidden.py | 13 +++++++------ tests/test_optim.py | 5 +++-- tests/test_optim_state.py | 9 +++++---- tests/test_other_hidden.py | 13 +++++++------ tests/test_requires_grad.py | 5 +++-- tests/test_requires_grad_multi_gpu.py | 5 +++-- tests/test_training.py | 5 +++-- 7 files changed, 31 insertions(+), 24 deletions(-) diff --git a/tests/test_middle_hidden.py b/tests/test_middle_hidden.py index cd50351e..688cdfe5 100644 --- a/tests/test_middle_hidden.py +++ b/tests/test_middle_hidden.py @@ -6,6 +6,7 @@ from 
bmtrain.block_layer import CheckpointBlock, TransformerBlockList from bmtrain.pipe_layer import PipelineTransformerBlockList import torch.nn.functional as F +from bmtrain import inspect class Linear(bmt.DistributedModule): def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None: @@ -142,8 +143,8 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len, only_last=False, only_mid loss = (logits * last_weight).sum() loss.backward() ret += f"========================only last========================\n" - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += inspect.format_summary( + inspect.inspect_model(m, '*') ) if only_middle: logits, hidden_states = m(inp, return_hidden_states=True) @@ -153,8 +154,8 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len, only_last=False, only_mid ]) loss.backward() ret += f"========================only middle========================\n" - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += inspect.format_summary( + inspect.inspect_model(m, '*') ) if mix_test: logits, hidden_states = m(inp, return_hidden_states=True) @@ -164,8 +165,8 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len, only_last=False, only_mid ]) + (logits * last_weight).sum() loss.backward() ret += f"========================mix========================\n" - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += inspect.format_summary( + inspect.inspect_model(m, '*') ) return ret + "\n" # replace for matching None grad with zero_grad diff --git a/tests/test_optim.py b/tests/test_optim.py index 81356ede..fdb64521 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -1,5 +1,6 @@ import torch import bmtrain as bmt +from bmtrain import optim class TestModule(torch.nn.Module): def __init__(self): @@ -29,8 +30,8 @@ def main(): model2 = model2.cuda() model3 = model3.cuda() - opt1 = bmt.optim.AdamOptimizer(model1.parameters(), weight_decay=1e-3) - opt2 = bmt.optim.AdamOffloadOptimizer(model2.parameters(), weight_decay=1e-3) + opt1 = optim.AdamOptimizer(model1.parameters(), weight_decay=1e-3) + opt2 = optim.AdamOffloadOptimizer(model2.parameters(), weight_decay=1e-3) opt3 = torch.optim.Adam(model3.parameters(), weight_decay=1e-3) for _ in range(100): diff --git a/tests/test_optim_state.py b/tests/test_optim_state.py index df697f49..cef06734 100644 --- a/tests/test_optim_state.py +++ b/tests/test_optim_state.py @@ -2,6 +2,7 @@ import bmtrain as bmt import os from copy import deepcopy +from bmtrain import optim class TestSubModule(bmt.DistributedModule): def __init__(self): @@ -67,10 +68,10 @@ def main(): bmt.load(model2, f"test_optim_state_model1.pt") bmt.load(model3, f"test_optim_state_model1.pt") - opt1 = bmt.optim.AdamOptimizer(model1.parameters(), weight_decay=1e-3) - opt2 = bmt.optim.AdamOffloadOptimizer(model2.parameters(), weight_decay=1e-3) + opt1 = optim.AdamOptimizer(model1.parameters(), weight_decay=1e-3) + opt2 = optim.AdamOffloadOptimizer(model2.parameters(), weight_decay=1e-3) opt3 = torch.optim.Adam(model3.parameters(), weight_decay=1e-3) - optim_manager = bmt.optim.OptimManager(loss_scale=256) + optim_manager = optim.OptimManager(loss_scale=256) optim_manager.add_optimizer(opt1) optim_manager.add_optimizer(opt2) optim_manager.add_optimizer(opt3) @@ -121,4 +122,4 @@ def main(): if __name__ == "__main__": bmt.init_distributed() - main() \ No newline at end of file + main() diff --git a/tests/test_other_hidden.py b/tests/test_other_hidden.py index 
d1e317ad..1f6c8c65 100644 --- a/tests/test_other_hidden.py +++ b/tests/test_other_hidden.py @@ -7,6 +7,7 @@ from bmtrain.block_layer import CheckpointBlock, TransformerBlockList from bmtrain.pipe_layer import PipelineTransformerBlockList import torch.nn.functional as F +from bmtrain import inspect class Linear(bmt.DistributedModule): def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None: @@ -142,22 +143,22 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len, only_pre=False, only_post loss = (pre.weight * last_weight).sum() loss.backward() ret += f"========================only last========================\n" - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += inspect.format_summary( + inspect.inspect_model(m, '*') ) if only_post: loss = (post.weight * last_weight).sum() loss.backward() ret += f"========================only middle========================\n" - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += inspect.format_summary( + inspect.inspect_model(m, '*') ) if mix_test: loss = (pre.weight * last_weight).sum() + (post.weight * last_weight).sum() loss.backward() ret += f"========================mix========================\n" - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += inspect.format_summary( + inspect.inspect_model(m, '*') ) return ret + "\n" # replace for matching None grad with zero_grad diff --git a/tests/test_requires_grad.py b/tests/test_requires_grad.py index 104b8125..943275c3 100644 --- a/tests/test_requires_grad.py +++ b/tests/test_requires_grad.py @@ -7,6 +7,7 @@ from bmtrain.pipe_layer import PipelineTransformerBlockList from typing import List import torch.nn.functional as F +from bmtrain import inspect class Linear(bmt.DistributedModule): def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None: @@ -37,8 +38,8 @@ def run(m, a, b): loss = logits.sum() loss.backward() bmt.synchronize() - sm = bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + sm = inspect.format_summary( + inspect.inspect_model(m, '*') ) assert_eq(bias.requires_grad, False) return a.weight.grad is None, a.bias.grad is None, sm diff --git a/tests/test_requires_grad_multi_gpu.py b/tests/test_requires_grad_multi_gpu.py index 2b2ed95e..4a2670ae 100644 --- a/tests/test_requires_grad_multi_gpu.py +++ b/tests/test_requires_grad_multi_gpu.py @@ -6,6 +6,7 @@ from bmtrain.pipe_layer import PipelineTransformerBlockList from typing import List import torch.nn.functional as F +from bmtrain import inspect class Linear(bmt.DistributedModule): def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None: @@ -34,8 +35,8 @@ def run(m, a, b): loss = logits.sum() loss.backward() - sm = bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + sm = inspect.format_summary( + inspect.inspect_model(m, '*') ) return sm diff --git a/tests/test_training.py b/tests/test_training.py index 4fd8f934..f6432101 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -7,6 +7,7 @@ import torch.nn.functional as F import bmtrain as bmt import os +from bmtrain import inspect class Attention(torch.nn.Module): def __init__(self, @@ -249,7 +250,7 @@ def sub_train_torch(model, loss_func_cls, optimizer_cls): )) logs.append(global_loss) - summary = bmt.inspect.inspect_model(model, "*") + summary = inspect.inspect_model(model, "*") return logs, summary def sub_train(model, 
loss_func_cls, optimizer_cls): @@ -310,7 +311,7 @@ def sub_train(model, loss_func_cls, optimizer_cls): )) logs.append(global_loss) - summary = bmt.inspect.inspect_model(model, "*") + summary = inspect.inspect_model(model, "*") return logs, summary def train(model, loss_func, optimizer): From 780ca20755421c98c323b54d285abdca45d50697 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 21 Aug 2023 10:49:09 +0800 Subject: [PATCH 067/122] fix --- bmtrain/__init__.py | 12 ++++++++---- bmtrain/block_layer.py | 7 +++---- example/train.py | 14 ++++++++------ tests/test_all.py | 1 + tests/test_inspector_hidden.py | 29 +++++++++++++++-------------- 5 files changed, 35 insertions(+), 28 deletions(-) diff --git a/bmtrain/__init__.py b/bmtrain/__init__.py index 7c7d6c2c..ae243e65 100644 --- a/bmtrain/__init__.py +++ b/bmtrain/__init__.py @@ -1,10 +1,14 @@ +from .utils import print_block, print_dict, print_rank, see_memory, load_nccl_pypi +try: + from . import nccl +except: + load_nccl_pypi() from .global_var import config, world_size, rank from .init import init_distributed from .parameter import DistributedParameter, ParameterInitializer from .layer import DistributedModule from .param_init import init_parameters, grouped_parameters -from .utils import print_block, print_dict, print_rank, see_memory from .synchronize import synchronize, sum_loss, wait_loader, gather_result from .block_layer import CheckpointBlock, TransformerBlockList from .wrapper import BMTrainModelWrapper @@ -12,9 +16,9 @@ from . import debug from .store import save, load -from . import benchmark +from . import loss +from . import distributed +from . import nn from . import optim from . import inspect from . import lr_scheduler -from . import loss -from . import distributed diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index b75d6b92..01de1cce 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -200,8 +200,6 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True): self._pre_module = [] #save the pre module of self self._ref_count = 0 #incremental in forward and decreasing in backward self._mode = "BLOCK" #BLOCK or ZERO or PIPE - self.return_hidden_states = False - self.hidden_states = [] self.all_input_no_grad = False self.all_param_no_grad = False @@ -260,7 +258,7 @@ def post_hook(self, out): def forward(self, *args): arg_list = self.pre_hook(*args) - if self.all_input_no_grad: + if self.all_input_no_grad and not self.all_param_no_grad: placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) return hook_func.OneStepNoGradFunc.apply(self, placeholder, *arg_list) @@ -579,7 +577,8 @@ def forward(self, *args, return_hidden_states = False): hidden_states = [] for i in range(len(self)): if return_hidden_states: - hidden_states.append(args[0]) + for hidden_state in args[:self.num_hidden]: + hidden_states.append(hidden_state) outputs = self._modules[str(i)]._call_impl(*args) if not isinstance(outputs, tuple): outputs = (outputs, ) diff --git a/example/train.py b/example/train.py index 7bc92400..1a744e20 100644 --- a/example/train.py +++ b/example/train.py @@ -2,6 +2,8 @@ import bmtrain as bmt from models import GPT import time +from bmtrain import optim +from bmtrain import inspect def main(): bmt.init_distributed( @@ -51,10 +53,10 @@ def main(): break loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) - optimizer = bmt.optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2) + optimizer = optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2) 
lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0) - optim_manager = bmt.optim.OptimManager(loss_scale=2**20) + optim_manager = optim.OptimManager(loss_scale=2**20) optim_manager.add_optimizer(optimizer, lr_scheduler) bmt.synchronize() @@ -66,7 +68,7 @@ def main(): # load data st = time.time() - with bmt.inspect.inspect_tensor() as inspector: + with inspect.inspect_tensor() as inspector: pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) logits = model( enc_input, @@ -87,13 +89,13 @@ def main(): # print parameters of the model if iteration % 100 == 0: bmt.print_rank( - bmt.inspect.format_summary( + inspect.format_summary( inspector.get_summary() ) ) bmt.print_rank( - bmt.inspect.format_summary( - bmt.inspect.inspect_model(model, "*") + inspect.format_summary( + inspect.inspect_model(model, "*") ) ) diff --git a/tests/test_all.py b/tests/test_all.py index b614d3eb..6682aa93 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -24,6 +24,7 @@ ("send_recv", 4), ("nccl_backward", 4), + ("no_grad", 1), ("training", 4), ]) diff --git a/tests/test_inspector_hidden.py b/tests/test_inspector_hidden.py index 731884ad..c39de5fb 100644 --- a/tests/test_inspector_hidden.py +++ b/tests/test_inspector_hidden.py @@ -7,6 +7,7 @@ from bmtrain.block_layer import CheckpointBlock, TransformerBlockList from bmtrain.pipe_layer import PipelineTransformerBlockList import torch.nn.functional as F +from bmtrain import inspect class Linear(bmt.DistributedModule): def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None: @@ -48,7 +49,7 @@ def __init__(self, dim : int): def forward(self, x): x = self.m1(x) - bmt.inspect.record_tensor(x, "hidden") + inspect.record_tensor(x, "hidden") x = self.m2(x) return x @@ -160,10 +161,10 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len): bmt.init_parameters(m) m = cls(pre, [m for m in ms], post) ret = "" - with bmt.inspect.inspect_tensor() as inspector: + with inspect.inspect_tensor() as inspector: logits = m(inp) - ret += bmt.inspect.format_summary( + ret += inspect.format_summary( inspector.get_summary() ) + "\n" @@ -171,32 +172,32 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len): for i in range(len(ms)//2): loss = loss + (inspector.summary[i]['tensor'] * middle_weight[i]).sum() - with bmt.inspect.inspect_tensor(): + with inspect.inspect_tensor(): loss.backward() - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += inspect.format_summary( + inspect.inspect_model(m, '*') ) + "\n" - ret += bmt.inspect.format_summary( + ret += inspect.format_summary( inspector.get_summary() ) + "\n" - with bmt.inspect.inspect_tensor() as inspector: + with inspect.inspect_tensor() as inspector: logits = m(inp) - ret += bmt.inspect.format_summary( + ret += inspect.format_summary( inspector.get_summary() ) + "\n" loss = (logits * last_weight).sum() - with bmt.inspect.inspect_tensor(): + with inspect.inspect_tensor(): loss.backward() - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += inspect.format_summary( + inspect.inspect_model(m, '*') ) + "\n" - ret += bmt.inspect.format_summary( + ret += inspect.format_summary( inspector.get_summary() ) + "\n" @@ -237,4 +238,4 @@ def test_main(): if __name__ == "__main__": bmt.init_distributed(pipe_size=2) - test_main() \ No newline at end of file + test_main() From 6df85e712f0ca2dec3b0facc65b6c0c8e36947ca Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: 
Mon, 21 Aug 2023 11:13:42 +0800 Subject: [PATCH 068/122] remove unused import --- bmtrain/block_layer.py | 4 ---- bmtrain/checkpointing.py | 4 +--- bmtrain/hook_func.py | 1 - bmtrain/pipe_layer.py | 1 - 4 files changed, 1 insertion(+), 9 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 01de1cce..6723ce1a 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -4,7 +4,6 @@ from .global_var import config import torch from . import nccl -from .synchronize import wait_loader from .parameter import DistributedParameter, OpAllGather from .checkpointing import ( CheckpointBlockContext @@ -74,12 +73,10 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True): self._layer_dict = {} self._forward_block_ctx = None self._backward_block_ctx = None - self._forward_enter_count = 0 # build large parameter&grad here self._param_info = [] self._storage_params : Dict[str, torch.nn.Parameter] = {} self._storage_info = {} - self._ready = False # sort parameters by name ordered_parameters = list(self._module.named_parameters()) @@ -522,7 +519,6 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) super().__init__() self._modules = {} - release_list = [] pre_module = None for i, module in enumerate(modules): if not isinstance(module, CheckpointBlock): diff --git a/bmtrain/checkpointing.py b/bmtrain/checkpointing.py index 008bab5f..b2c9ec07 100644 --- a/bmtrain/checkpointing.py +++ b/bmtrain/checkpointing.py @@ -1,10 +1,8 @@ import torch -from typing import Callable, TypeVar -from functools import wraps from . import debug from . import nccl from .global_var import config -from .synchronize import wait_loader, synchronize +from .synchronize import wait_loader class ScopedDebugTensorList: def __init__(self) -> None: diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index a1bb4cf8..6a56300e 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -1,7 +1,6 @@ import torch from .global_var import config from .checkpointing import CheckpointBlockContext -from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations def zero_pre_forward(module, inputs): enter = True diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 913b727e..0a34ac46 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -9,7 +9,6 @@ from .global_var import config from . import nccl from .checkpointing import ( - ScopedTensorInspectorContext, CheckpointBlockContext ) from . 
import debug From bb482d64129563defaa61ff8266caa945514a6f8 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 21 Aug 2023 11:19:07 +0800 Subject: [PATCH 069/122] fix pre_module --- bmtrain/block_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 6723ce1a..b1ceb742 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -206,7 +206,7 @@ def set_pre_module(self, pre_module): pre_module._next_module.append(self) def pre_module(self): - assert len(self._next_module) == self._ref_count, "{} != {}".format(len(self._next_module), self._ref_count) + assert len(self._pre_module) == self._ref_count, "{} != {}".format(len(self._pre_module), self._ref_count) return self._pre_module[self._ref_count-1] def next_module(self): From 1010d26efd7b5bd8d8e2046ce55d0cfc5f37852b Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 21 Aug 2023 11:26:32 +0800 Subject: [PATCH 070/122] recovery some code --- bmtrain/block_layer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index b1ceb742..8e52c68a 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -77,6 +77,7 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True): self._param_info = [] self._storage_params : Dict[str, torch.nn.Parameter] = {} self._storage_info = {} + self._ready = False # sort parameters by name ordered_parameters = list(self._module.named_parameters()) From b580530b56a413049c26ca8ff33453633f91aac0 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 21 Aug 2023 12:31:18 +0800 Subject: [PATCH 071/122] add test_no_grad.py --- tests/test_no_grad.py | 46 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 tests/test_no_grad.py diff --git a/tests/test_no_grad.py b/tests/test_no_grad.py new file mode 100644 index 00000000..3629921b --- /dev/null +++ b/tests/test_no_grad.py @@ -0,0 +1,46 @@ +import torch +import bmtrain as bmt + +class Layer(torch.nn.Module): + def __init__(self): + super(Layer, self).__init__() + self.linear = bmt.nn.Linear(32, 32) + self.count = 0 + def forward(self, x): + self.count += 1 + return self.linear(x) + +def test_no_grad(): + x = torch.randn(32, 32, device='cuda') + + layer1 = bmt.CheckpointBlock(Layer()) + layer2 = bmt.CheckpointBlock(Layer()) + layer1.linear.weight.requires_grad_(False) + layer1.linear.bias.requires_grad_(False) + y = layer1(x) + assert y.requires_grad == False + y = layer2(y) + y.sum().backward() + assert layer1.count == 1 + assert layer2.count == 2 + +def test_all_input_no_grad(): + linear1 = bmt.nn.Linear(32, 32) + linear2 = bmt.nn.Linear(32, 32) + + x = torch.randn(32,32, device='cuda') + + linear1 = bmt.CheckpointBlock(linear1) + linear2 = bmt.CheckpointBlock(linear2) + y = linear1(x) + y = linear2(y) + y.sum().backward() + assert linear1.weight.grad is not None + assert linear1.bias.grad is not None + assert x.grad is None + +if __name__ == '__main__': + bmt.init_distributed() + + test_no_grad() + test_all_input_no_grad() From 767a875ffc4a7b179cfc86aeef820a71de9e8196 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 21 Aug 2023 12:52:28 +0800 Subject: [PATCH 072/122] test unroll block list --- tests/test_training.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/test_training.py b/tests/test_training.py index f6432101..1d6481c9 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -151,6 +151,7 @@ def __init__(self, 
) for _ in range(num_layers) ]) + self.run_unroll = False self.layernorm = torch.nn.LayerNorm(dim_model, dtype=dtype) @@ -166,7 +167,7 @@ def forward(self, input_emb = self.pos_emb(pos) + self.word_emb(input) out = input_emb - if isinstance(self.transformers, torch.nn.ModuleList): + if isinstance(self.transformers, torch.nn.ModuleList) or self.run_unroll: for layer in self.transformers: out = layer(out, mask_2d, None) else: @@ -376,11 +377,20 @@ def pipe_model(): bmt.load(pipe_model, ckpt_path) return model + def unroll_list_model(): + model = GPT(**kwargs) + list_model = bmt.BMTrainModelWrapper(model) + list_model.transformers = bmt.TransformerBlockList([m for m in list_model.transformers]) + bmt.load(list_model, ckpt_path) + model.run_unroll = True + return model + models = { "torch": torch_model, "wrapper": wrap_model, "blocklist": list_model, "pipelist": pipe_model, + "unroll_blocklist": unroll_list_model, } loss_funcs = { "bmt_entropy": bmt.loss.FusedCrossEntropy, @@ -406,6 +416,7 @@ def add_to_check_list(m, l, o): add_to_check_list("pipelist", "bmt_entropy", "bmt_adam") add_to_check_list("blocklist", "torch_entropy", "bmt_adam") add_to_check_list("blocklist", "bmt_entropy", "bmt_adam_offload") + add_to_check_list("unroll_blocklist", "bmt_entropy", "bmt_adam") if bmt.rank() == 0: os.remove(ckpt_path) check(ret) @@ -419,6 +430,7 @@ def add_to_check_list(m, l, o): add_to_check_list("pipelist", "torch_entropy", "bmt_adam") add_to_check_list("blocklist", "torch_entropy", "bmt_adam_offload") add_to_check_list("blocklist", "torch_entropy", "torch_adam") + add_to_check_list("unroll_blocklist", "bmt_entropy", "bmt_adam") if bmt.rank() == 0: os.remove(ckpt_path) check(ret) From d19a627347d4874f75ecc0f83b621802a6e822b3 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 21 Aug 2023 13:57:48 +0800 Subject: [PATCH 073/122] fix test_fp32 --- tests/test_training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_training.py b/tests/test_training.py index 1d6481c9..5e385ae3 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -430,7 +430,7 @@ def add_to_check_list(m, l, o): add_to_check_list("pipelist", "torch_entropy", "bmt_adam") add_to_check_list("blocklist", "torch_entropy", "bmt_adam_offload") add_to_check_list("blocklist", "torch_entropy", "torch_adam") - add_to_check_list("unroll_blocklist", "bmt_entropy", "bmt_adam") + add_to_check_list("unroll_blocklist", "torch_entropy", "bmt_adam") if bmt.rank() == 0: os.remove(ckpt_path) check(ret) @@ -454,4 +454,4 @@ def check_param(info1, info2): if __name__ == '__main__': bmt.init_distributed(pipe_size=2) - test_main(test_fp16=True, test_fp32=True) + test_main(test_fp16=False, test_fp32=True) From bf986a73ae37ee046a441aef244cd1fe76f2aed7 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 21 Aug 2023 14:49:20 +0800 Subject: [PATCH 074/122] cross_entropy support fp32 --- bmtrain/loss/cross_entropy.py | 9 +++++++++ tests/test_training.py | 4 ++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/bmtrain/loss/cross_entropy.py b/bmtrain/loss/cross_entropy.py index 160ef421..31223640 100644 --- a/bmtrain/loss/cross_entropy.py +++ b/bmtrain/loss/cross_entropy.py @@ -185,6 +185,15 @@ def __init__(self, self.inplace = inplace def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + if input.dtype == torch.float32: + return torch.nn.functional.cross_entropy( + input, + target.long(), + weight=self.weight, + ignore_index=self.ignore_index, + reduction=self.reduction, + 
label_smoothing=self.label_smoothing) + if self.inplace: ret = OpFusedCrossEntropyInplace.apply(input, target.int(), self.ignore_index) # return float tensor else: diff --git a/tests/test_training.py b/tests/test_training.py index 5e385ae3..1d6481c9 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -430,7 +430,7 @@ def add_to_check_list(m, l, o): add_to_check_list("pipelist", "torch_entropy", "bmt_adam") add_to_check_list("blocklist", "torch_entropy", "bmt_adam_offload") add_to_check_list("blocklist", "torch_entropy", "torch_adam") - add_to_check_list("unroll_blocklist", "torch_entropy", "bmt_adam") + add_to_check_list("unroll_blocklist", "bmt_entropy", "bmt_adam") if bmt.rank() == 0: os.remove(ckpt_path) check(ret) @@ -454,4 +454,4 @@ def check_param(info1, info2): if __name__ == '__main__': bmt.init_distributed(pipe_size=2) - test_main(test_fp16=False, test_fp32=True) + test_main(test_fp16=True, test_fp32=True) From b28cb3fa9186e096ae8b59909dccd35924b54eed Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Mon, 21 Aug 2023 15:13:16 +0800 Subject: [PATCH 075/122] offload context --- bmtrain/hook_func.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 10e7d50b..ec4041e9 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -3,6 +3,8 @@ from .checkpointing import CheckpointBlockContext from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations from collections import deque,OrderedDict +from contextlib import contextmanager + class Offload_Dict: def __init__(self): @@ -106,6 +108,17 @@ def unpack_hook(packed): return tensor return pack_hook, unpack_hook +@contextmanager +def offload_context(module): + if hasattr(module, "_offload_hook"): + pack_hook, unpack_hook = module._offload_hook + torch._C._autograd._push_saved_tensors_default_hooks( + pack_hook, unpack_hook + ) + yield + if hasattr(module, "_offload_hook"): + torch._C._autograd._pop_saved_tensors_default_hooks() + def zero_pre_forward(module, inputs): enter = True pipe = False From f94afa2f221b755ac7a4e63dc5a377ae5007203f Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Mon, 21 Aug 2023 18:58:24 +0800 Subject: [PATCH 076/122] cpm live for offloading test --- bmtrain/block_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 9d10fe6d..c4db4e61 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -68,7 +68,7 @@ class CheckpointBlock(torch.nn.Module): >>> y2, ... 
= transformer_block(x) >>> assert torch.allclose(y1, y2) """ - def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True): + def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, use_offload=False): super().__init__() self._module = inner_module self._inputs = None From bc65a2eeacf4b5c24596a56522090d7de2fe3322 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 22 Aug 2023 13:27:11 +0800 Subject: [PATCH 077/122] Better hack for offload --- bmtrain/hook_func.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 8c0f0169..ae8a025a 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -119,6 +119,17 @@ def offload_context(module): if hasattr(module, "_offload_hook"): torch._C._autograd._pop_saved_tensors_default_hooks() +def offload_pre_hook(module, input): + if hasattr(module, "_offload_hook"): + pack_hook, unpack_hook = module._offload_hook + torch._C._autograd._push_saved_tensors_default_hooks( + pack_hook, unpack_hook + ) + +def offload_post_hook(module, input, output): + if hasattr(module, "_offload_hook"): + torch._C._autograd._pop_saved_tensors_default_hooks() + def zero_pre_forward(module, inputs): enter = True pipe = False @@ -126,9 +137,21 @@ def zero_pre_forward(module, inputs): if not hasattr(module, "_offload_dict"): module._offload_dict = Offload_Dict() pack_hook, unpack_hook = offload_wrapper(module._offload_dict) - for n, m in module.named_modules(): - if m.__class__.__name__ == "Linear": - m._offload_hook = (pack_hook, unpack_hook) + if module.offload_level == 1: + match_module = ["Linear"] + for n, m in module.named_modules(): + if m.__class__.__name__ in match_module and not hasattr(m, "_offload_hook"): + print("register hook") + m._offload_hook = (pack_hook, unpack_hook) + m.register_forward_pre_hook(offload_pre_hook) + m.register_forward_hook(offload_post_hook) + elif module.offload_level == 2: + if not hasattr(module, "_offload_hook"): + print("register HOOK for CheckpointBlock") + module._offload_hook = (pack_hook, unpack_hook) + module.register_forward_pre_hook(offload_pre_hook) + module.register_forward_hook(offload_post_hook) + # torch._C._autograd._push_saved_tensors_default_hooks( # pack_hook, unpack_hook # ) From 76f816213a356c73b7c60607ecaab823b150a3ae Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 22 Aug 2023 14:19:17 +0800 Subject: [PATCH 078/122] fix OFFLOAD _mode bug --- bmtrain/block_layer.py | 5 +++-- bmtrain/hook_func.py | 2 -- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index c4db4e61..7f96b7e7 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -68,7 +68,7 @@ class CheckpointBlock(torch.nn.Module): >>> y2, ... 
= transformer_block(x) >>> assert torch.allclose(y1, y2) """ - def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, use_offload=False): + def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, use_offload=False, offload_level=0): super().__init__() self._module = inner_module self._inputs = None @@ -203,6 +203,7 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, use_offl if use_offload: self._mode = "OFFLOAD" self._on_device = False + self.offload_level = offload_level self.all_input_no_grad = False self.all_param_no_grad = False @@ -531,7 +532,7 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) if not isinstance(module, CheckpointBlock): module = CheckpointBlock(module) - module._mode = "ZERO" + module._mode = "ZERO" if module._mode == "BLOCK" else module._mode module.set_pre_module(pre_module) pre_module = module self._is_first_layer = False diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index ae8a025a..19176d2e 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -141,13 +141,11 @@ def zero_pre_forward(module, inputs): match_module = ["Linear"] for n, m in module.named_modules(): if m.__class__.__name__ in match_module and not hasattr(m, "_offload_hook"): - print("register hook") m._offload_hook = (pack_hook, unpack_hook) m.register_forward_pre_hook(offload_pre_hook) m.register_forward_hook(offload_post_hook) elif module.offload_level == 2: if not hasattr(module, "_offload_hook"): - print("register HOOK for CheckpointBlock") module._offload_hook = (pack_hook, unpack_hook) module.register_forward_pre_hook(offload_pre_hook) module.register_forward_hook(offload_post_hook) From 0d4ea37fabd2f539964eb0238fe4e3086ca97a6e Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Tue, 22 Aug 2023 15:35:52 +0800 Subject: [PATCH 079/122] fix is_first_layer --- bmtrain/block_layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 8e52c68a..18438b8c 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -528,8 +528,8 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) module._mode = "ZERO" module.set_pre_module(pre_module) pre_module = module - self._is_first_layer = False - self._is_last_layer = False + module._is_first_layer = False + module._is_last_layer = False self._modules[str(i)] = module self.add_module(str(i), module) From 6ffcf5cbf79e0c0acad097e2fcbfce286ba5afdd Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Wed, 23 Aug 2023 16:17:23 +0800 Subject: [PATCH 080/122] Fix async bug --- .gitignore | 4 ++-- bmtrain/block_layer.py | 10 ++++---- bmtrain/hook_func.py | 53 +++++++++++++++++++++++++++++++----------- 3 files changed, 46 insertions(+), 21 deletions(-) diff --git a/.gitignore b/.gitignore index 0222862f..9c9a3f28 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ __pycache__/ *.py[cod] *$py.class - +*nsys-rep # C extensions *.so @@ -150,4 +150,4 @@ log .vscode !bmtrain/dist -tests/test_log.txt \ No newline at end of file +tests/test_log.txt diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 7f96b7e7..e0a1bfdd 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -82,7 +82,7 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, use_offl self._ready = False # sort parameters by nam_next_modulee ordered_parameters = list(self._module.named_parameters()) - assert not (use_checkpoint and 
use_offload) + assert not (use_checkpoint and use_offload), "It does not make sense to use offload and checkpointing at the same time" # calc total number of parameters for name, param in ordered_parameters: if not isinstance(param, DistributedParameter): @@ -531,14 +531,14 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) for i, module in enumerate(modules): if not isinstance(module, CheckpointBlock): module = CheckpointBlock(module) - module._mode = "ZERO" if module._mode == "BLOCK" else module._mode module.set_pre_module(pre_module) pre_module = module - self._is_first_layer = False - self._is_last_layer = False + module._is_first_layer = False + module._is_last_layer = False self._modules[str(i)] = module + module._idx = i self.add_module(str(i), module) self._modules[str(0)]._is_first_layer = True @@ -567,7 +567,7 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) self.save_list = save_list else: self.save_list = [(i, i) for i in range(len(self))] - + def __len__(self) -> int: return len(self._modules) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 19176d2e..0d841e74 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -19,6 +19,11 @@ def add(self, tensor): self._offload_dict[tensor_id]["shape"] = tensor.shape self._device = "cuda" return tensor_id + + def get_total(self): + fp16_total = sum([v['numel'] for v in self._offload_dict.values() if v['dtype'] == torch.float16]) + fp32_total = sum([v['numel'] for v in self._offload_dict.values() if v['dtype'] == torch.float32]) + return fp16_total,fp32_total def make_cpu_storage(self): fp16_total = sum([v['numel'] for v in self._offload_dict.values() if v['dtype'] == torch.float16]) @@ -94,6 +99,8 @@ def offload_wrapper(offload_dict): def pack_hook(tensor): if isinstance(tensor, torch.nn.Parameter): return (tensor,) + elif tensor.dtype not in [torch.float16]: + return (tensor,) else: key = offload_dict.add(tensor) return (tensor.device, key) @@ -131,6 +138,13 @@ def offload_post_hook(module, input, output): torch._C._autograd._pop_saved_tensors_default_hooks() def zero_pre_forward(module, inputs): + def find_pre_module_helper(m): + if m._mode == "OFFLOAD": + return m + elif m._is_first_layer: + return None + else: + return find_pre_module_helper(m._pre_module[0]) enter = True pipe = False if module._mode == "OFFLOAD": @@ -138,18 +152,33 @@ def zero_pre_forward(module, inputs): module._offload_dict = Offload_Dict() pack_hook, unpack_hook = offload_wrapper(module._offload_dict) if module.offload_level == 1: - match_module = ["Linear"] for n, m in module.named_modules(): - if m.__class__.__name__ in match_module and not hasattr(m, "_offload_hook"): + if m.__class__.__name__ == "Linear" and not hasattr(m, "_offload_hook"): m._offload_hook = (pack_hook, unpack_hook) m.register_forward_pre_hook(offload_pre_hook) m.register_forward_hook(offload_post_hook) elif module.offload_level == 2: if not hasattr(module, "_offload_hook"): module._offload_hook = (pack_hook, unpack_hook) - module.register_forward_pre_hook(offload_pre_hook) - module.register_forward_hook(offload_post_hook) - + torch._C._autograd._push_saved_tensors_default_hooks( + pack_hook, unpack_hook + ) + elif module._mode != "OFFLOAD" and ((len(module._pre_module) > 0) and module._pre_module[0]._mode == "OFFLOAD"): + for pre_module in module._pre_module: + if len(pre_module._pre_module) == 0: + pre_offload_module = None + else: + pre_offload_module = 
find_pre_module_helper(pre_module._pre_module[0]) + if pre_offload_module is not None: + torch.cuda.current_stream().wait_event(pre_offload_module.offload_event) + if pre_module._mode == "OFFLOAD": + with torch.cuda.stream(config["offload_stream"]): + config["offload_stream"].wait_event(pre_module.calc_event) + if not hasattr(pre_module._offload_dict, "fp16_storage"): + pre_module._offload_dict.make_cpu_storage() + pre_module._offload_dict.d2h_memcpy() + pre_module.offload_event = torch.cuda.Event() + config["offload_stream"].record_event(pre_module.offload_event) # torch._C._autograd._push_saved_tensors_default_hooks( # pack_hook, unpack_hook # ) @@ -174,14 +203,10 @@ def zero_post_forward(module, inputs, outputs): exit = True if module._mode == "PIPE": exit = module._micro_idx == config['micros'] - 1 - elif module._mode != "OFFLOAD" and ((not module._is_first_layer) and module._pre_module[0]._mode == "OFFLOAD"): - for pre_module in module._pre_module: - if pre_module._mode == "OFFLOAD": - # torch._C._autograd._pop_saved_tensors_default_hooks() - with torch.cuda.stream(config["offload_stream"]): - if not hasattr(pre_module._offload_dict, "fp16_storage"): - pre_module._offload_dict.make_cpu_storage() - pre_module._offload_dict.d2h_memcpy() + elif module._mode == "OFFLOAD" and module.offload_level == 2: + module.calc_event = torch.cuda.Event() + torch.cuda.current_stream().record_event(module.calc_event) + torch._C._autograd._pop_saved_tensors_default_hooks() if exit: module._forward_block_ctx.exit(forward_flag) module._ref_count += 1 @@ -191,7 +216,7 @@ def zero_pre_backward(module, grad_outputs): if module._mode != "PIPE": if module._mode != "OFFLOAD": count = len([m for m in module._pre_module if m._mode=="OFFLOAD"]) - if module._is_last_layer or module._next_module[0]._mode == "OFFLOAD": + if (len(module._next_module) == 0) or module._next_module[0]._mode == "OFFLOAD": for pre_module in nearest_offload_module(module): if pre_module._mode == "OFFLOAD": pre_module._on_device = True From 3063afb129f3a97b3f13ed84b2ca983886cba61e Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 23 Aug 2023 19:36:55 +0800 Subject: [PATCH 081/122] tensor parallel --- bmtrain/block_layer.py | 49 ++++++--- bmtrain/checkpointing.py | 12 +-- bmtrain/init.py | 63 ++++++++--- bmtrain/loss/cross_entropy.py | 38 ++++--- bmtrain/nn/__init__.py | 4 + bmtrain/nn/column_parallel_linear.py | 35 ++++++ bmtrain/nn/cross_entropy.py | 135 ++++++++++++++++++++++++ bmtrain/nn/embedding.py | 87 +++++++++++++++ bmtrain/nn/parallel_linear_hook_func.py | 91 ++++++++++++++++ bmtrain/nn/row_parallel_linear.py | 37 +++++++ bmtrain/param_init.py | 19 +++- bmtrain/parameter.py | 71 ++++++++++--- example/layers/attention.py | 52 ++++++--- example/layers/embedding.py | 9 +- example/layers/feedforward.py | 15 ++- example/models/gpt.py | 18 +++- example/train.py | 15 ++- 17 files changed, 659 insertions(+), 91 deletions(-) create mode 100644 bmtrain/nn/column_parallel_linear.py create mode 100644 bmtrain/nn/cross_entropy.py create mode 100644 bmtrain/nn/embedding.py create mode 100644 bmtrain/nn/parallel_linear_hook_func.py create mode 100644 bmtrain/nn/row_parallel_linear.py diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 18438b8c..4bc1cc2d 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -94,7 +94,8 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True): "total": 0, "storage_type": storage_type, "requires_grad": param.requires_grad, - "group": param.group + "group": 
param.group, + "zero_comm" : param._zero_comm } param_shape = param._original_shape @@ -108,11 +109,14 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True): offsets = {} # intialize storage buffers for kw, val in self._storage_info.items(): - val["world_size"] = config["world_size"] + comm = val['zero_comm'] + world_size = nccl.commCount(comm) + rank = nccl.commRank(comm) + val["world_size"] = world_size #config["dp_size"] partition_size = round_up(val["total"], val["world_size"]) // val["world_size"] val["partition_size"] = partition_size - val["begin"] = config['rank'] * partition_size - val["end"] = (config['rank'] + 1) * partition_size + val["begin"] = rank * partition_size #config['zero_rank'] * partition_size + val["end"] = (rank+1) * partition_size #(config['zero_rank'] + 1) * partition_size offsets[kw] = 0 @@ -301,13 +305,16 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, if key in state_dict: # load here input_param = state_dict[key] + param = it['parameter'] + tp_mode = param._tp_mode if input_param.__class__.__name__ == "DistributedTensorWrapper": input_param = input_param.broadcast() - if input_param.shape != it["shape"]: + if not tp_mode and input_param.shape != it["shape"]: error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' 'the shape in current model is {}.' .format(key, input_param.shape, it["shape"])) continue + param_st = it["offset"] param_end = it["offset"] + it["size"] kw_name = it["kw_name"] @@ -321,8 +328,18 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, continue # copy to buffer - assert input_param.numel() == it["size"] + if not tp_mode: + assert input_param.numel() == it["size"] + contiguous_param = input_param.to(it["parameter"].dtype).cuda().contiguous() + + if tp_mode: + tp_split_dim = param._tp_split_dim + if tp_split_dim >= 0: + param_list = contiguous_param.chunk(config['tp_size'], dim=tp_split_dim) + sub_tensor = param_list[config['topology'].tp_id] + contiguous_param = torch.empty(sub_tensor.shape, device=sub_tensor.device, dtype=sub_tensor.dtype) + contiguous_param.copy_(sub_tensor) offset_st = max(storage_st - param_st, 0) offset_end = min(storage_end - param_st, contiguous_param.numel()) @@ -330,7 +347,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, to_offset_st = offset_st + param_st - storage_st to_offset_end = offset_end + param_st - storage_st - + # copy to buffer # PyTorch 1.11 changed the API of storage.__getitem__ d_dtype = self._storage_params[kw_name].dtype @@ -397,7 +414,7 @@ def init_parameters(self): param = it["parameter"] if isinstance(param, DistributedParameter) and param._init_method is not None: # initialzie here - tmp_tensor = torch.empty(it["shape"], device=param.device, dtype=param.dtype) + tmp_tensor = torch.empty(param._tp_original_shape, device=param.device, dtype=param.dtype) param._init_method(tmp_tensor) param_st = it["offset"] param_end = it["offset"] + it["size"] @@ -411,16 +428,18 @@ def init_parameters(self): if param_end <= storage_st: continue + if param._tp_mode and param._tp_split_dim >= 0: + tensor_list = tmp_tensor.chunk(config['tp_size'], dim=param._tp_split_dim) + sub_tensor = tensor_list[config['topology'].tp_id].contiguous() + tmp_tensor = torch.empty(sub_tensor.shape, device=param.device, dtype=sub_tensor.dtype) + tmp_tensor.copy_(sub_tensor) # copy to buffer assert tmp_tensor.is_contiguous() and it["size"] == tmp_tensor.numel() - offset_st = max(storage_st 
- param_st, 0) - offset_end = min(storage_end - param_st, tmp_tensor.numel()) + offset_st = max(storage_st - param_st, 0) + offset_end = min(storage_end - param_st, tmp_tensor.numel()) assert offset_st < offset_end - to_offset_st = offset_st + param_st - storage_st - to_offset_end = offset_end + param_st - storage_st - # copy to buffer # PyTorch 1.11 changed the API of storage.__getitem__ d_dtype = self._storage_params[kw_name].dtype @@ -528,8 +547,8 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) module._mode = "ZERO" module.set_pre_module(pre_module) pre_module = module - module._is_first_layer = False - module._is_last_layer = False + self._is_first_layer = False + self._is_last_layer = False self._modules[str(i)] = module self.add_module(str(i), module) diff --git a/bmtrain/checkpointing.py b/bmtrain/checkpointing.py index b2c9ec07..3adbc105 100644 --- a/bmtrain/checkpointing.py +++ b/bmtrain/checkpointing.py @@ -39,10 +39,8 @@ def __init__(self, block : 'CheckpointBlock', ctx_dict : dict = None, pipe = Fal self._param_tensor = {} self._grad_tensor = {} self._need_release = False - if pipe: - self.comm = config["zero_comm"] - else: - self.comm = config["comm"] + #self.comm = config["zero_comm"] + def enter(self, flag=0, requires_grad=False): """ gather parameters @@ -74,7 +72,8 @@ def enter(self, flag=0, requires_grad=False): nccl.allGather( self.block._storage_params[kw].storage(), self._param_buffer[kw], - self.comm + #self.comm + val['zero_comm'] ) nccl.groupEnd() @@ -144,7 +143,8 @@ def exit(self, flag=0, backward=False): self._grad_buffer[kw], local_param.grad.storage(), "sum", - self.comm + #self.comm + val['zero_comm'] ) nccl.groupEnd() diff --git a/bmtrain/init.py b/bmtrain/init.py index 1fa0712d..a403caf7 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -10,12 +10,12 @@ from . import nccl from .synchronize import synchronize - def init_distributed( init_method : str = "env://", seed : int = 0, zero_level: int = 3, pipe_size: int = -1, + tp_size = 1, num_micro_batches: int = None, ): """Initialize distributed training. 
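# A minimal sketch of how the new tp_size argument is meant to be used; the values
# below are illustrative only, and world_size must be divisible by
# pipe_size * tp_size (see the topology assertion added further down).
import bmtrain as bmt

bmt.init_distributed(seed=0, zero_level=2, tp_size=2)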
@@ -75,8 +75,10 @@ def init_distributed( config['barrier_stream'] = torch.cuda.Stream() config["load_event"] = torch.cuda.Event() config["zero_level"] = zero_level + config["tp_size"] = tp_size if tp_size > 0 else 1 config["topology"] = topology(config) - config["zero_rank"] = config["topology"].get_group_rank("zero") if pipe_size > 1 else config['rank'] + config["zero_rank"] = config['topology'].get_group_rank("zero") + config["tp_zero_rank"] = config['topology'].get_group_rank("tp_zero") cpus_this_worker = None all_available_cpus = sorted(list(os.sched_getaffinity(0))) @@ -105,6 +107,8 @@ def init_distributed( unique_id = bytes.fromhex(store.get("BMTRAIN_UNIQUE_ID").decode()) config['comm'] = nccl.commInitRank(unique_id, world_size, rank) + config['zero_comm'] = config['comm'] + if config['pipe_enabled']: config["micros"] = num_micro_batches if num_micro_batches else config["pipe_size"] topo = config['topology'] @@ -113,13 +117,24 @@ def init_distributed( store.set(f"PIPE_UNIQUE_ID{topo.pipe_idx}", unique_id.hex()) unique_id = bytes.fromhex(store.get(f"PIPE_UNIQUE_ID{topo.pipe_idx}").decode()) config ['pipe_comm'] = nccl.commInitRank(unique_id, pipe_size, topo.stage_id) - if topo.zero_id == 0: + + if config['tp_size'] > 1: + topo = config['topology'] + if topo.tp_id == 0: unique_id = nccl.getUniqueId() - store.set(f"ZERO_UNIQUE_ID{topo.zero_idx}", unique_id.hex() ) - unique_id = bytes.fromhex(store.get(f"ZERO_UNIQUE_ID{topo.zero_idx}").decode()) - config ['zero_comm'] = nccl.commInitRank(unique_id, world_size//pipe_size, topo.zero_id) + store.set(f"TP_UNIQUE_ID{topo.tp_idx}", unique_id.hex()) + unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode()) + config['tp_comm'] = nccl.commInitRank(unique_id, tp_size, topo.tp_id) + + if not config['pipe_enabled'] and config['tp_size'] <= 1: + config['tp_zero_comm'] = config['comm'] else: - config['zero_comm'] = config['comm'] + if topo.tp_zero_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) + config ['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.tp_zero_id) + for i in range(world_size): if i == rank: print_dict("Initialization", { @@ -132,24 +147,40 @@ def init_distributed( "cpus": cpus_this_worker }) synchronize() + class topology: def __init__(self,config): # pipe_idx is the idx of the pipeline in the group self.rank = config['rank'] pp_size = config["pipe_size"] + tp_size = config["tp_size"] world_size = config["world_size"] - assert world_size % pp_size == 0, "The nums of GPUs must be divisible by the pipeline parallel size" + assert world_size % (pp_size * tp_size) == 0, "The nums of GPUs must be divisible by the pipeline parallel size * tensor parallel size" - dp_size = world_size // pp_size - topo=torch.tensor(range(dp_size*pp_size),dtype=torch.int,device='cuda') - topo=topo.view(pp_size,dp_size) + dp_size = world_size // (pp_size * tp_size) + config['tp_zero_size'] = dp_size + config['zero_size'] = world_size // pp_size + topo=torch.tensor(range(dp_size*tp_size*pp_size),dtype=torch.int,device='cuda') + topo=topo.view(pp_size,dp_size*tp_size) self.pp_group=topo.transpose(0,1).reshape(-1,pp_size) - self.dp_group=topo self.stage_id = (self.pp_group == self.rank).nonzero()[0,-1].item() self.stages = config['pipe_size'] self.pipe_idx = (self.pp_group == self.rank).nonzero()[0, 0].item() # x axes self.zero_id = 
self.pipe_idx self.zero_idx = self.stage_id + + self.tp_group = topo.reshape(pp_size, dp_size, tp_size) + self.tp_id = (self.tp_group == self.rank).nonzero()[0,2].item() + self.tp_idx = (self.tp_group == self.rank).nonzero()[0,1 if dp_size > 1 else 0].item() + + if pp_size == 1 and tp_size == 1: + self.tp_zero_id = self.rank + self.tp_zero_idx = 0 + else: + self.dp_group = self.tp_group.permute(0,2,1) + self.tp_zero_id = (self.dp_group == self.rank).nonzero()[0,2 if tp_size > 1 else 0].item() + self.tp_zero_idx = (self.dp_group == self.rank).nonzero()[0,1 if tp_size > 1 else 2].item() + self.next_rank = self.stage_id+1 if self.stage_id < config['pipe_size'] - 1 else -1 self.prev_rank = self.stage_id-1 if self.stage_id > 0 else -1 self.tails = self.pp_group[self.pipe_idx, self.stage_id:].tolist() @@ -160,12 +191,20 @@ def get_group_id(self,group_name): return self.pipe_idx elif group_name == "zero": return self.zero_idx + elif group_name == "tp_zero": + return self.tp_zero_idx + elif group_name == "tp": + return self.tp_idx def get_group_rank(self,group_name): if group_name == "pipe": return self.stage_id elif group_name == "zero": return self.zero_id + elif group_name == "tp_zero": + return self.tp_zero_id + elif group_name == "tp": + return self.tp_id def is_initialized() -> bool: return config["initialized"] diff --git a/bmtrain/loss/cross_entropy.py b/bmtrain/loss/cross_entropy.py index 31223640..a5d43a55 100644 --- a/bmtrain/loss/cross_entropy.py +++ b/bmtrain/loss/cross_entropy.py @@ -1,6 +1,10 @@ from typing import Optional import torch from . import _function as F +from bmtrain.nn import fused_cross_entropy +from bmtrain.global_var import config +from bmtrain.distributed import all_gather + class OpFusedCrossEntropy(torch.autograd.Function): """ CrossEntropy dim = 1 @@ -185,19 +189,23 @@ def __init__(self, self.inplace = inplace def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - if input.dtype == torch.float32: - return torch.nn.functional.cross_entropy( - input, - target.long(), - weight=self.weight, - ignore_index=self.ignore_index, - reduction=self.reduction, - label_smoothing=self.label_smoothing) - - if self.inplace: - ret = OpFusedCrossEntropyInplace.apply(input, target.int(), self.ignore_index) # return float tensor + if config['tp_size'] > 1: + target = all_gather(target, comm=config['tp_comm']).flatten(0,1) + ret = fused_cross_entropy(input, target.long(), self.ignore_index) else: - ret = OpFusedCrossEntropy.apply(input, target.int(), self.ignore_index) # return float tensor + if input.dtype == torch.float32: + return torch.nn.functional.cross_entropy( + input, + target.long(), + weight=self.weight, + ignore_index=self.ignore_index, + reduction=self.reduction, + label_smoothing=self.label_smoothing) + + if self.inplace: + ret = OpFusedCrossEntropyInplace.apply(input, target.int(), self.ignore_index) # return float tensor + else: + ret = OpFusedCrossEntropy.apply(input, target.int(), self.ignore_index) # return float tensor if self.weight is not None: if self.weight.dim() != 1 or self.weight.size(0) != input.size(1): @@ -208,6 +216,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: w = (target != self.ignore_index).int() ret = w * ret + + if config['tp_size'] > 1: + ret_list = ret.chunk(config['tp_size'], dim=0) + ret = ret_list[config['topology'].tp_id] + w_list = w.chunk(config['tp_size'], dim=0) + w = w_list[config['topology'].tp_id] if self.reduction == "none": return ret diff --git a/bmtrain/nn/__init__.py 
b/bmtrain/nn/__init__.py index 67f9fdee..12540b57 100644 --- a/bmtrain/nn/__init__.py +++ b/bmtrain/nn/__init__.py @@ -1 +1,5 @@ from .linear import Linear +from .column_parallel_linear import ColumnParallelLinear +from .row_parallel_linear import RowParallelLinear +from .embedding import Embedding +from .cross_entropy import fused_cross_entropy diff --git a/bmtrain/nn/column_parallel_linear.py b/bmtrain/nn/column_parallel_linear.py new file mode 100644 index 00000000..9d6444b4 --- /dev/null +++ b/bmtrain/nn/column_parallel_linear.py @@ -0,0 +1,35 @@ +import torch +from torch.nn.parameter import Parameter + +import bmtrain as bmt +from bmtrain.global_var import config +from .parallel_linear_hook_func import ( + LinearHookFunc, + ReduceType) + +class ColumnParallelLinear(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None, gather_output=False) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + tp_size = config['tp_size'] + assert out_features % tp_size == 0 + self.out_features_per_partition = out_features // tp_size + self.weight = bmt.DistributedParameter(torch.empty(self.out_features_per_partition, in_features, dtype=dtype, device="cuda"), init_method=torch.nn.init.xavier_normal_, tp_split_dim=0, tp_mode=True) + if bias: + self.bias = bmt.DistributedParameter(torch.empty(self.out_features_per_partition, dtype=dtype, device="cuda"), init_method=torch.nn.init.zeros_, tp_split_dim=0, tp_mode=True) + else: + self.register_parameter('bias', None) + + def forward(self, input): + gather_input = True + split_input = False + reduce_output_type = None + return LinearHookFunc.apply(input, self.weight, self.bias, gather_input, self.gather_output, split_input, reduce_output_type) + + def extra_repr(self) -> str: + return 'in_features={}, out_features={}, bias={}'.format( + self.in_features, self.out_features_per_partitions, self.bias is not None + ) + diff --git a/bmtrain/nn/cross_entropy.py b/bmtrain/nn/cross_entropy.py new file mode 100644 index 00000000..8643c03b --- /dev/null +++ b/bmtrain/nn/cross_entropy.py @@ -0,0 +1,135 @@ +import torch +import bmtrain as bmt +from bmtrain.global_var import config +from bmtrain.distributed import all_reduce, all_gather + +class FusedCrossEntropyFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0): + comm = config['tp_comm'] + rank = config['topology'].tp_id + world_size = config['tp_size'] + # Maximum value along vocab dimension across all GPUs. + logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] + logits_max = all_reduce(logits_max, op="max", comm=comm) + # Subtract the maximum value. + vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1) + + # Get the partition's vocab indecies + #get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size + partition_vocab_size = vocab_parallel_logits.size()[-1] + vocab_start_index = rank * partition_vocab_size + vocab_end_index = (rank + 1) * partition_vocab_size + + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target.clone() - vocab_start_index + masked_target[target_mask] = 0 + + # Get predicted-logits = logits[target]. + # For Simplicity, we convert logits to a 2-D tensor with size + # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. 
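# The indexing below relies on the masking above: target ids outside this rank's
# vocab slice were clamped to index 0 and recorded in target_mask, so their gathered
# logits are zeroed out afterwards and the all_reduce across the tp group
# reconstructs the true target logit for every token exactly once.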
+ logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], + device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + predicted_logits_1d = predicted_logits_1d.clone().contiguous() + predicted_logits = predicted_logits_1d.view_as(target) + predicted_logits[target_mask] = 0.0 + # All reduce is needed to get the chunks from other GPUs. + predicted_logits = all_reduce(predicted_logits, op="sum", comm=comm) + + # Sum of exponential of logits along vocab dimension across all GPUs. + exp_logits = vocab_parallel_logits + torch.exp(vocab_parallel_logits, out=exp_logits) + sum_exp_logits = exp_logits.sum(dim=-1) + #if config['rank'] == 0: + #print("before", sum_exp_logits.shape, predicted_logits.shape) + sum_exp_logits = all_reduce(sum_exp_logits, op="sum", comm=comm) + #if config['rank'] == 0: + # print(sum_exp_logits.shape) + + # Loss = log(sum(exp(logits))) - predicted-logit. + loss = torch.log(sum_exp_logits.view(predicted_logits.shape)) - predicted_logits + + # Normalize and optionally smooth logits + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + + vocab_size = exp_logits.size(-1) + if label_smoothing > 0: + """ + We'd like to assign 1 / (K - 1) probability mass to every index that is not the ground truth. + = (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt}) + = (1 - alpha) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i + = ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i + = (K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i} y_i + = (1 - (alpha * K) / (K - 1)) * y_gt + ( (alpha * K) / (K - 1) ) * \sum_{i} y_i / K + From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py + """ + assert 1.0 > label_smoothing > 0.0 + smoothing = label_smoothing * vocab_size / (vocab_size - 1) + + # Exp logits at this point are normalized probabilities. So we can just take the log to get log-probs. + log_probs = torch.log(exp_logits) + mean_log_probs = log_probs.mean(dim=-1) + loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs + + ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size + ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) + + # Store softmax, target-mask and masked-target for backward pass. + ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) + + return loss + + @staticmethod + def backward(ctx, grad_output): + # Retreive tensors from the forward path. + softmax, target_mask, masked_target_1d = ctx.saved_tensors + label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size + + # All the inputs have softmax as thier gradient. + grad_input = softmax + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], + device=grad_2d.device) + + softmax_update = 1.0 - target_mask.view(-1).float() + + if label_smoothing > 0: + smoothing = label_smoothing * vocab_size / (vocab_size - 1) + grad_2d[arange_1d, masked_target_1d] -= (1.0 - smoothing) * softmax_update + average_grad = 1 / vocab_size + grad_2d[arange_1d, :] -= smoothing * average_grad + else: + grad_2d[arange_1d, masked_target_1d] -= softmax_update + + # Finally elementwise multiplication with the output gradients. 
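# grad_input aliases the softmax tensor saved in forward, so the scaling by
# grad_output below happens in place and no extra (tokens, vocab_partition) buffer
# is allocated; grad_output is flattened first so its shape lines up with that layout.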
+ #grad_input.mul_(grad_output.unsqueeze(dim=-1)) + grad_input.mul_(grad_output.flatten(0,1).unsqueeze(dim=-1)) + + return grad_input, None, None + + +def fused_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0): + """ + Performs cross entropy loss when logits are split across tensor parallel ranks + + Arguments: + vocab_parallel_logits: logits split across tensor parallel ranks + dimension is [sequence_length, batch_size, hidden_size] + + target: correct vocab ids of dimseion [sequence_length, micro_batch_size] + + lobal_smoothing: smoothing factor, must be in range [0.0, 1.0) + default is no smoothing (=0.0) + """ + out = FusedCrossEntropyFunc.apply(vocab_parallel_logits.to(torch.float32), target, label_smoothing) + return out + diff --git a/bmtrain/nn/embedding.py b/bmtrain/nn/embedding.py new file mode 100644 index 00000000..52be1d72 --- /dev/null +++ b/bmtrain/nn/embedding.py @@ -0,0 +1,87 @@ +import torch +from torch.nn.parameter import Parameter +import torch.nn.functional as F +import math + +import bmtrain as bmt +from bmtrain.global_var import config +from bmtrain.distributed import all_reduce, all_gather +from .parallel_linear_hook_func import LinearHookFunc + +class Embedding(bmt.DistributedModule): + def __init__( + self, + vocab_size: int, + embedding_size: int, + dtype: torch.dtype = torch.half, + scale: bool = True, + init_mean: float = 0.0, + init_std: float = 1, + ): + super().__init__() + + self.dim_model = embedding_size + assert vocab_size % config['tp_size'] == 0 + self.vocab_size_per_partition = vocab_size // config['tp_size'] + self.start_index = config['topology'].tp_id * self.vocab_size_per_partition + self.end_index = (config['topology'].tp_id+1) * self.vocab_size_per_partition + self.weight = bmt.DistributedParameter( + torch.empty(self.vocab_size_per_partition, embedding_size, dtype=dtype), + init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std), + tp_mode=True, + tp_split_dim=0, + ) + self.scale = scale + + def forward(self, ids: torch.Tensor): + """ + Args: + ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens. + Return: + :obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output. + """ # noqa: E501 + + if config['tp_size'] > 1: + ids = all_gather(ids, comm=config['tp_comm']) + input_mask = (ids < self.start_index) | (ids >= self.end_index) + ids = ids.clone() - self.start_index + ids[input_mask] = 0 + + #if self.scale: + #embeds = F.embedding(ids, self.weight) / math.sqrt(self.dim_model) + #else: + embeds = F.embedding(ids, self.weight) + + if config['tp_size'] > 1: + embeds[input_mask, :] = 0.0 + embeds = all_reduce(embeds, op="sum", comm=config['tp_comm']) + embed_list = embeds.chunk(config['tp_size'], dim=0) + embeds = embed_list[config['topology'].tp_id].flatten(0,1) + #print(embeds.sum()) + + return embeds.clone() + + def projection(self, x: torch.Tensor): + """ + Projection based on embedding's weight. For example, embedding map vocab_size to embed_size, than projection map embed_size back to vocab_size. + Args: + x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection + Returns: + :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``: The projection output. 
+ """ # noqa: E501 + gather_input = True + split_input = False + reduce_output_type = None + gather_output = False + if self.scale: + #out = LinearHookFunc.apply(x / math.sqrt(self.dim_model), self.weight, None, gather_input, gather_output, split_input, reduce_output_type) + #print(x.sum()) + #print(x.shape, self.weight.shape) + out = LinearHookFunc.apply(x , self.weight, None, gather_input, gather_output, split_input, reduce_output_type) / math.sqrt(self.dim_model) + #out_list = out.chunk(config['tp_size'], dim=0) + #out = out_list[config['topology'].tp_id] + #print(out.sum()) + return out + else: + return LinearHookFunc.apply(x, self.weight, None, gather_input, gather_output, split_input, reduce_output_type) + return logits diff --git a/bmtrain/nn/parallel_linear_hook_func.py b/bmtrain/nn/parallel_linear_hook_func.py new file mode 100644 index 00000000..af7bea0a --- /dev/null +++ b/bmtrain/nn/parallel_linear_hook_func.py @@ -0,0 +1,91 @@ +import torch +import torch.nn.functional as F +from bmtrain.global_var import config +from ..distributed import all_gather, all_reduce +from .. import nccl +import bmtrain as bmt +from enum import Enum + +class ReduceType(Enum): + ALL_REDUCE = 1 + REDUCE_SCATTER = 2 + +def preprocess_input(input, gather_input, split_input): + if gather_input: + input = all_gather(input, config['tp_comm']) + input = input.flatten(0, 1) + + if split_input: + all_input_list = input.chunk(config['tp_size'], dim=1) + input = all_input_list[config['topology'].tp_id] + return input + +class LinearHookFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias=None, gather_input=False, gather_output=False, split_input=False, reduce_output_type=None): + ctx.save_for_backward(input, weight, bias) + ctx.gather_output = gather_output + ctx.split_input = split_input + ctx.gather_input = gather_input + ctx.reduce_output_type = reduce_output_type + all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) + out = F.linear(all_input, weight, bias) + if gather_output: + all_output_list = all_gather(out, config['tp_comm']) + all_output_list = all_output_list.chunk(config['tp_size'], dim=0) + out = torch.cat(all_output_list, dim=all_output_list[0].dim()-1).flatten(0,1) + + if reduce_output_type is None: + return out + + if reduce_output_type == ReduceType.ALL_REDUCE: + nccl.allReduce(out.storage(), out.storage(), "sum", config['tp_comm']) + return out + elif reduce_output_type == ReduceType.REDUCE_SCATTER: + shape = list(out.shape) + shape[0] = shape[0] // config['tp_size'] + reduce_out = torch.empty(shape, dtype=out.dtype, device=out.device) + nccl.reduceScatter(out.storage(), reduce_out.storage(), "sum", config['tp_comm']) + return reduce_out + + @staticmethod + def backward(ctx, grad_output): + input, weight, bias = ctx.saved_tensors + gather_output = ctx.gather_output + + if ctx.reduce_output_type == ReduceType.REDUCE_SCATTER: + grad_output = all_gather(grad_output, config['tp_comm']) + grad_output = grad_output.flatten(0, 1) + + if gather_output: + tp_size = config['tp_size'] + tp_id = config['topology'].tp_id + grad_output_list = grad_output.chunk(tp_size, dim=1) + grad_output = grad_output_list[tp_id] + + grad_input = grad_weight = grad_bias = None + + if input.requires_grad or weight.requires_grad: + all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) + + if input.requires_grad: + #gather can async with grad_out.matmul(weight) + #TODO: gather on load_stream + grad_all_input = grad_output.matmul(weight) + grad_input = 
torch.empty_like(input) + if ctx.gather_input: + nccl.reduceScatter(grad_all_input.storage(), grad_input.storage(), "sum", config['tp_comm']) + else: + grad_input = grad_all_input + + if ctx.split_input: + grad_input = all_gather(grad_input, config['tp_comm']) + + if weight.requires_grad: + dim = grad_output.dim() + grad_weight = grad_output.reshape(-1, + grad_output.shape[-1]).t().matmul(all_input.reshape(-1, all_input.shape[-1])) + + if bias is not None and bias.requires_grad: + grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0) + return grad_input, grad_weight, grad_bias, None, None, None, None diff --git a/bmtrain/nn/row_parallel_linear.py b/bmtrain/nn/row_parallel_linear.py new file mode 100644 index 00000000..a6877b7c --- /dev/null +++ b/bmtrain/nn/row_parallel_linear.py @@ -0,0 +1,37 @@ +import torch +from torch.nn.parameter import Parameter + +import bmtrain as bmt +from bmtrain.global_var import config +from .parallel_linear_hook_func import ( + LinearHookFunc, + ReduceType) + +class RowParallelLinear(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None, split_input=False, all_reduce_output=False) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.split_input = split_input + self.all_reduce_output = all_reduce_output + tp_size = config['tp_size'] + assert in_features % tp_size == 0 + self.in_features_per_partition = in_features // tp_size + self.weight = bmt.DistributedParameter(torch.empty(self.out_features, self.in_features_per_partition, dtype=dtype, device="cuda"), init_method=torch.nn.init.xavier_normal_, tp_split_dim=1, tp_mode=True) + if bias: + self.bias = bmt.DistributedParameter(torch.empty(self.out_features, dtype=dtype, device="cuda"), init_method=torch.nn.init.zeros_, tp_split_dim=-1, tp_mode=True) + else: + self.register_parameter('bias', None) + + def forward(self, input): + gather_input = self.split_input + gather_output = False + reduce_output_type = ReduceType.ALL_REDUCE if self.all_reduce_output else ReduceType.REDUCE_SCATTER + out = LinearHookFunc.apply(input, self.weight, None, gather_input, gather_output, self.split_input, reduce_output_type) + out = out + self.bias + return out + + def extra_repr(self) -> str: + return 'in_features={}, out_features={}, bias={}'.format( + self.in_features_per_partition, self.out_features, self.bias is not None + ) diff --git a/bmtrain/param_init.py b/bmtrain/param_init.py index 8b74c580..5e165fe1 100644 --- a/bmtrain/param_init.py +++ b/bmtrain/param_init.py @@ -3,6 +3,7 @@ from .block_layer import CheckpointBlock from .parameter import DistributedParameter from .global_var import config +from . 
import nccl def init_distributed_parameter(params : Iterable[torch.nn.Parameter]): @@ -13,17 +14,27 @@ def init_distributed_parameter(params : Iterable[torch.nn.Parameter]): continue with torch.no_grad(): partition_size = param.storage().size() - global_size = partition_size * config['world_size'] - + global_size = partition_size * config['tp_zero_size'] * config['tp_size'] tmp_storage = param.storage_type()(global_size) tmp_tensor = torch.tensor([], dtype=param.dtype, device="cuda") - tmp_tensor.set_(tmp_storage, 0, param._original_shape) + tmp_tensor.set_(tmp_storage, 0, param._tp_original_shape) param._init_method(tmp_tensor) + if param._tp_mode and param._tp_split_dim >= 0: + tensor_list = tmp_tensor.chunk(config['tp_size'], dim=param._tp_split_dim) + sub_tensor = tensor_list[config['topology'].tp_id].contiguous() + tmp_tensor = torch.empty(sub_tensor.shape, device=param.device, dtype=sub_tensor.dtype) + tmp_tensor.copy_(sub_tensor) + + if param._tp_mode: + begin = config['tp_zero_rank'] + else: + begin = config['zero_rank'] + end = begin + 1 # Pytorch 1.11 changed the API of storage.__getitem__ torch.tensor([], dtype=param.dtype, device=param.device).set_(param.storage())[:] = \ - torch.tensor([], dtype=param.dtype, device=param.device).set_(tmp_storage)[partition_size * config['rank'] : partition_size * (config['rank'] + 1)] + torch.tensor([], dtype=param.dtype, device=param.device).set_(tmp_tensor.storage())[partition_size * begin : partition_size * end] # param.storage().copy_(tmp_storage[partition_size * config['rank'] : partition_size * (config['rank'] + 1)]) def iterate_parameters(model : torch.nn.Module): diff --git a/bmtrain/parameter.py b/bmtrain/parameter.py index f965cdac..f68d53ac 100644 --- a/bmtrain/parameter.py +++ b/bmtrain/parameter.py @@ -3,6 +3,7 @@ from .utils import round_up from .global_var import config from . 
import nccl +from .distributed import all_gather class DistributedParameter(torch.nn.Parameter): r""" @@ -31,7 +32,9 @@ def __new__(cls, data : torch.Tensor, requires_grad : bool = True, init_method : Optional[Callable[['DistributedParameter'], None]] = None, - group : Optional[str] = None + group : Optional[str] = None, + tp_mode : bool = False, + tp_split_dim : int = -1, ): if not config["initialized"]: raise RuntimeError("BMTrain is not initialized") @@ -39,14 +42,25 @@ def __new__(cls, num_of_elements = data.numel() cuda_tensor = torch.tensor([], dtype=data.dtype, device="cuda") - cuda_storage_size = round_up(num_of_elements, config["world_size"]) // config["world_size"] + if tp_mode: + comm = config['tp_zero_comm'] + else: + comm = config['zero_comm'] + world_size = nccl.commCount(comm) + rank = nccl.commRank(comm) + cuda_storage_size = round_up(num_of_elements, world_size) // world_size original_shape = data.size() + tp_original_shape = original_shape + if tp_mode and tp_split_dim >= 0: + list_shape = list(original_shape) + list_shape[tp_split_dim] *= config['tp_size'] + tp_original_shape = list_shape cuda_storage = cuda_tensor.storage_type()(cuda_storage_size) - start_of_partition = cuda_storage_size * config["rank"] - end_of_partition = min(num_of_elements, cuda_storage_size * (config["rank"] + 1)) + start_of_partition = cuda_storage_size * rank + end_of_partition = min(num_of_elements, cuda_storage_size * (rank + 1)) # FX: cuda_tensor_size < 0 if num_of_elements is too small cuda_tensor_size = max(end_of_partition - start_of_partition, 0) @@ -60,7 +74,12 @@ def __new__(cls, setattr(ret, "_end_partition", end_of_partition) setattr(ret, "_init_method", init_method) setattr(ret, "_in_checkpoint_block", False) - setattr(ret, "_group", group) + setattr(ret, "_group", group if not tp_mode else "tp") + + setattr(ret, "_tp_mode", tp_mode) + setattr(ret, "_zero_comm", comm) + setattr(ret, "_tp_split_dim", tp_split_dim) + setattr(ret, "_tp_original_shape", tp_original_shape) return ret @property @@ -83,24 +102,52 @@ def gather(self) -> torch.Tensor: current_stream.wait_stream(config['load_stream']) return output_tensor + def gather_all(self) -> torch.tensor: + zero_param = self.gather() + if config['tp_size'] > 1 and self._tp_split_dim >= 0: + output_tensor = all_gather(zero_param, config['tp_comm']) + if self._tp_split_dim == 1: + output_list = output_tensor.chunk(config['tp_size'], dim=0) + output = torch.cat(output_list, dim=output_list[0].dim()-1).flatten(0,1) + return output + else: + return output_tensor.flatten(0,1) + else: + return zero_param + + def tp_gather(self) -> torch.tensor: + if config['tp_size'] > 1 and self._tp_split_dim >= 0: + output_tensor = all_gather(self, config['tp_comm']) + if self._tp_split_dim == 1: + output_list = output_tensor.chunk(config['tp_size'], dim=0) + output = torch.cat(output_list, dim=output_list[0].dim()-1).flatten(0,1) + return output + else: + return output_tensor.flatten(0,1) + else: + return self + def _copy_data(self, data : torch.Tensor): self.data.copy_(data.view(-1)[self._start_partition : self._end_partition]) - class OpAllGather(torch.autograd.Function): @staticmethod def forward(ctx, value : DistributedParameter): assert isinstance(value, DistributedParameter) + comm = value._zero_comm #config['zero_comm'] + world_size = nccl.commCount(comm) + ctx.comm = comm + ctx.world_size = world_size partition_size = value.storage().size() - global_size = partition_size * config['world_size'] + global_size = partition_size * world_size storage = 
value.storage_type()(global_size) nccl.allGather( value.storage(), storage, - config['comm'] + comm ) output_tensor = torch.tensor([], dtype=value.dtype, device="cuda") @@ -117,19 +164,19 @@ def backward(ctx, grad_output : torch.Tensor): grad_storage = grad_output.storage_type()(ctx.partition_size) grad_output_storage = grad_output.storage() - if grad_output_storage.size() == ctx.partition_size * config['world_size']: + if grad_output_storage.size() == ctx.partition_size * ctx.world_size: pass else: - grad_output_storage.resize_(ctx.partition_size * config['world_size']) + grad_output_storage.resize_(ctx.partition_size * ctx.world_size) nccl.reduceScatter( grad_output_storage, grad_storage, 'sum', - config['comm'] + ctx.comm ) grad_tensor = torch.tensor([], dtype=grad_output.dtype, device="cuda") grad_tensor.set_(grad_storage, 0, (ctx.tensor_size,)) - return grad_tensor + return grad_tensor, None class ParameterInitializer: """ diff --git a/example/layers/attention.py b/example/layers/attention.py index 243df3ea..b2ddebb9 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -1,8 +1,14 @@ from typing import Optional import torch import bmtrain as bmt -from bmtrain.nn import Linear +from bmtrain.nn import ( + Linear, + ColumnParallelLinear, + RowParallelLinear, +) import math +from bmtrain.global_var import config +from bmtrain.distributed import all_gather class Attention(bmt.DistributedModule): def __init__(self, @@ -12,11 +18,17 @@ def __init__(self, ) -> None: super().__init__() - self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + if config['tp_size'] <= 1: + self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) + else: + self.project_q = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_k = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_v = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) + self.project_out = RowParallelLinear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) self.softmax = torch.nn.Softmax(dim=-1) self.num_heads = num_heads @@ -35,29 +47,35 @@ def forward(self, h_q : torch.Tensor = self.project_q(hidden_q) h_k : torch.Tensor = self.project_k(hidden_kv) h_v : torch.Tensor = self.project_v(hidden_kv) + if config['tp_size'] > 1: + #batch_size will changed in TensorParallel + batch_size = h_v.shape[0] - h_q = h_q.view(batch_size, seq_q, self.num_heads, self.dim_head) - h_k = h_k.view(batch_size, seq_kv, self.num_heads, self.dim_head) - h_v = h_v.view(batch_size, seq_kv, self.num_heads, self.dim_head) + h_q = h_q.view(batch_size, seq_q, -1, self.dim_head) + h_k = h_k.view(batch_size, seq_kv, -1, self.dim_head) + h_v = h_v.view(batch_size, seq_kv, -1, self.dim_head) h_q = h_q.permute(0, 2, 1, 3).contiguous() h_k = h_k.permute(0, 2, 1, 3).contiguous() h_v = h_v.permute(0, 2, 1, 3).contiguous() - h_q = h_q.view(batch_size * self.num_heads, seq_q, self.dim_head) - h_k = 
h_k.view(batch_size * self.num_heads, seq_kv, self.dim_head) - h_v = h_v.view(batch_size * self.num_heads, seq_kv, self.dim_head) + h_q = h_q.view(-1, seq_q, self.dim_head) + h_k = h_k.view(-1, seq_kv, self.dim_head) + h_v = h_v.view(-1, seq_kv, self.dim_head) score = torch.bmm( h_q, h_k.transpose(1, 2) ) score = score / math.sqrt(self.dim_head) - score = score.view(batch_size, self.num_heads, seq_q, seq_kv) + score = score.view(batch_size, -1, seq_q, seq_kv) if position_bias is not None: - score = score + position_bias.view(batch_size, self.num_heads, seq_q, seq_kv) - + score = score + position_bias.view(batch_size, -1, seq_q, seq_kv) + + if config['tp_size'] > 1: + mask = all_gather(mask, config['tp_comm']) + score = torch.where( mask.view(batch_size, 1, seq_q, seq_kv), score, @@ -70,14 +88,14 @@ def forward(self, torch.scalar_tensor(0, device=score.device, dtype=score.dtype) ) - score = score.view(batch_size * self.num_heads, seq_q, seq_kv) + score = score.view(-1, seq_q, seq_kv) h_out = torch.bmm( score, h_v ) - h_out = h_out.view(batch_size, self.num_heads, seq_q, self.dim_head) + h_out = h_out.view(batch_size, -1, seq_q, self.dim_head) h_out = h_out.permute(0, 2, 1, 3).contiguous() - h_out = h_out.view(batch_size, seq_q, self.num_heads * self.dim_head) + h_out = h_out.view(batch_size, seq_q, -1) attn_out = self.project_out(h_out) return attn_out diff --git a/example/layers/embedding.py b/example/layers/embedding.py index 13c47384..4faaa133 100644 --- a/example/layers/embedding.py +++ b/example/layers/embedding.py @@ -77,11 +77,14 @@ def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, def forward(self, input: torch.Tensor, projection : bool = False) -> torch.Tensor: if not projection: - return F.embedding( + out = F.embedding( input, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse) + return out else: - return F.linear(input, self.weight) / math.sqrt(self.embedding_dim) + out = F.linear(input, self.weight) / math.sqrt(self.embedding_dim) + #print(out.sum()) + return out def extra_repr(self) -> str: s = '{num_embeddings}, {embedding_dim}' @@ -97,4 +100,4 @@ def extra_repr(self) -> str: s += ', sparse=True' return s.format(**self.__dict__) - \ No newline at end of file + diff --git a/example/layers/feedforward.py b/example/layers/feedforward.py index 99d2dc3b..e88d2495 100644 --- a/example/layers/feedforward.py +++ b/example/layers/feedforward.py @@ -1,16 +1,23 @@ import torch import bmtrain as bmt -from bmtrain.nn import Linear +from bmtrain.nn import ( + Linear, + ColumnParallelLinear, + RowParallelLinear) +from bmtrain.global_var import config class Feedforward(bmt.DistributedModule): def __init__(self, dim_model : int, dim_ff : int, bias : bool = True, dtype = None) -> None: super().__init__() - self.w_in = Linear(dim_model, dim_ff, bias = bias, dtype=dtype) - self.w_out = Linear(dim_ff, dim_model, bias = bias, dtype=dtype) + if config['tp_size'] > 1: + self.w_in = ColumnParallelLinear(dim_model, dim_ff, bias = bias, dtype=dtype) + self.w_out = RowParallelLinear(dim_ff, dim_model, bias = bias, dtype=dtype) + else: + self.w_in = Linear(dim_model, dim_ff, bias=bias, dtype=dtype) + self.w_out = Linear(dim_ff, dim_model, bias=bias, dtype=dtype) self.relu = torch.nn.ReLU() def forward(self, input : torch.Tensor) -> torch.Tensor: - return self.w_out(self.relu(self.w_in(input))) diff --git a/example/models/gpt.py b/example/models/gpt.py index 78d77a7d..d218b391 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py 
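# Under tensor parallelism the example GPT swaps its word embedding for
# bmt.nn.Embedding: that layer shards the vocabulary across the tp group and its
# projection() returns vocab-partitioned logits, which is the input layout expected
# by the tp branch of bmt.loss.FusedCrossEntropy.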
@@ -1,6 +1,7 @@ import torch import bmtrain as bmt from layers import TransformerEncoder, Layernorm, Embedding, TransformerEncoder +from bmtrain.global_var import config class GPT(bmt.DistributedModule): def __init__(self, @@ -13,14 +14,17 @@ def __init__(self, self.max_distance = max_distance - self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype) + if config['tp_size'] > 1: + self.word_emb = bmt.nn.Embedding(vocab_size, dim_model, dtype=dtype) + else: + self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype) self.pos_emb = Embedding(max_distance, dim_model, dtype=dtype) self.transformers = bmt.TransformerBlockList([ bmt.CheckpointBlock( TransformerEncoder( dim_model, dim_head, num_heads, dim_ff, bias, dtype - ) + ), use_checkpoint=False ) for _ in range(num_layers) ]) @@ -37,12 +41,18 @@ def forward(self, mask_2d = mask_2d & (pos[:, None, :] >= pos[:, :, None]) out = self.pos_emb(pos) + self.word_emb(input) + bmt.synchronize() # for layer in self.transformers: out = self.transformers(out, mask_2d, None) out = self.layernorm(out) - logits = self.word_emb(out, projection=True) + bmt.synchronize() + if config['tp_size'] > 1: + logits = self.word_emb.projection(out)#self.word_emb(out, projection=True) + else: + logits = self.word_emb(out, projection=True) + bmt.synchronize() bmt.inspect.record_tensor(logits, "logits") - return logits \ No newline at end of file + return logits diff --git a/example/train.py b/example/train.py index 1a744e20..50036675 100644 --- a/example/train.py +++ b/example/train.py @@ -3,12 +3,14 @@ from models import GPT import time from bmtrain import optim +from bmtrain.global_var import config from bmtrain import inspect def main(): bmt.init_distributed( seed=0, zero_level=2, + tp_size=4, ) model = GPT( @@ -24,11 +26,13 @@ def main(): ) bmt.init_parameters(model) + #bmt.load(model, "example_model.pt") # print_inspect(model, "*") bmt.print_rank("Model memory") bmt.print_rank(torch.cuda.memory_summary()) bmt.synchronize() + #bmt.save(model, "example_model.pt") # data # generate dummy data for each rank @@ -52,7 +56,11 @@ def main(): if i == bmt.rank(): break - loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) + if config['tp_size'] > 1: + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100) + else: + loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) + optimizer = optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2) lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0) @@ -77,7 +85,10 @@ def main(): ) batch, seq_len, vocab_out_size = logits.size() - loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) + if config['tp_size'] > 1: + loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets) + else: + loss = loss_func(logits.float().view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) global_loss = bmt.sum_loss(loss).item() From 8648f5baea9b599e668a92ba1fd788c2942363cf Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 23 Aug 2023 20:20:56 +0800 Subject: [PATCH 082/122] rm unused code --- bmtrain/block_layer.py | 6 +++--- bmtrain/checkpointing.py | 3 --- bmtrain/nn/cross_entropy.py | 5 ----- bmtrain/nn/embedding.py | 20 ++------------------ example/layers/attention.py | 12 ++++++------ example/layers/embedding.py | 3 +-- example/models/gpt.py | 7 ++----- example/train.py | 5 +---- 8 files changed, 15 insertions(+), 46 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 
f30b13f6..d01a01cb 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -112,11 +112,11 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True): comm = val['zero_comm'] world_size = nccl.commCount(comm) rank = nccl.commRank(comm) - val["world_size"] = world_size #config["dp_size"] + val["world_size"] = world_size partition_size = round_up(val["total"], val["world_size"]) // val["world_size"] val["partition_size"] = partition_size - val["begin"] = rank * partition_size #config['zero_rank'] * partition_size - val["end"] = (rank+1) * partition_size #(config['zero_rank'] + 1) * partition_size + val["begin"] = rank * partition_size + val["end"] = (rank+1) * partition_size offsets[kw] = 0 diff --git a/bmtrain/checkpointing.py b/bmtrain/checkpointing.py index 3adbc105..550225be 100644 --- a/bmtrain/checkpointing.py +++ b/bmtrain/checkpointing.py @@ -39,7 +39,6 @@ def __init__(self, block : 'CheckpointBlock', ctx_dict : dict = None, pipe = Fal self._param_tensor = {} self._grad_tensor = {} self._need_release = False - #self.comm = config["zero_comm"] def enter(self, flag=0, requires_grad=False): """ @@ -72,7 +71,6 @@ def enter(self, flag=0, requires_grad=False): nccl.allGather( self.block._storage_params[kw].storage(), self._param_buffer[kw], - #self.comm val['zero_comm'] ) nccl.groupEnd() @@ -143,7 +141,6 @@ def exit(self, flag=0, backward=False): self._grad_buffer[kw], local_param.grad.storage(), "sum", - #self.comm val['zero_comm'] ) nccl.groupEnd() diff --git a/bmtrain/nn/cross_entropy.py b/bmtrain/nn/cross_entropy.py index 8643c03b..49460b5b 100644 --- a/bmtrain/nn/cross_entropy.py +++ b/bmtrain/nn/cross_entropy.py @@ -45,11 +45,7 @@ def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0): exp_logits = vocab_parallel_logits torch.exp(vocab_parallel_logits, out=exp_logits) sum_exp_logits = exp_logits.sum(dim=-1) - #if config['rank'] == 0: - #print("before", sum_exp_logits.shape, predicted_logits.shape) sum_exp_logits = all_reduce(sum_exp_logits, op="sum", comm=comm) - #if config['rank'] == 0: - # print(sum_exp_logits.shape) # Loss = log(sum(exp(logits))) - predicted-logit. loss = torch.log(sum_exp_logits.view(predicted_logits.shape)) - predicted_logits @@ -111,7 +107,6 @@ def backward(ctx, grad_output): grad_2d[arange_1d, masked_target_1d] -= softmax_update # Finally elementwise multiplication with the output gradients. 
- #grad_input.mul_(grad_output.unsqueeze(dim=-1)) grad_input.mul_(grad_output.flatten(0,1).unsqueeze(dim=-1)) return grad_input, None, None diff --git a/bmtrain/nn/embedding.py b/bmtrain/nn/embedding.py index 52be1d72..c8acafae 100644 --- a/bmtrain/nn/embedding.py +++ b/bmtrain/nn/embedding.py @@ -14,7 +14,6 @@ def __init__( vocab_size: int, embedding_size: int, dtype: torch.dtype = torch.half, - scale: bool = True, init_mean: float = 0.0, init_std: float = 1, ): @@ -31,7 +30,6 @@ def __init__( tp_mode=True, tp_split_dim=0, ) - self.scale = scale def forward(self, ids: torch.Tensor): """ @@ -47,9 +45,6 @@ def forward(self, ids: torch.Tensor): ids = ids.clone() - self.start_index ids[input_mask] = 0 - #if self.scale: - #embeds = F.embedding(ids, self.weight) / math.sqrt(self.dim_model) - #else: embeds = F.embedding(ids, self.weight) if config['tp_size'] > 1: @@ -57,7 +52,6 @@ def forward(self, ids: torch.Tensor): embeds = all_reduce(embeds, op="sum", comm=config['tp_comm']) embed_list = embeds.chunk(config['tp_size'], dim=0) embeds = embed_list[config['topology'].tp_id].flatten(0,1) - #print(embeds.sum()) return embeds.clone() @@ -73,15 +67,5 @@ def projection(self, x: torch.Tensor): split_input = False reduce_output_type = None gather_output = False - if self.scale: - #out = LinearHookFunc.apply(x / math.sqrt(self.dim_model), self.weight, None, gather_input, gather_output, split_input, reduce_output_type) - #print(x.sum()) - #print(x.shape, self.weight.shape) - out = LinearHookFunc.apply(x , self.weight, None, gather_input, gather_output, split_input, reduce_output_type) / math.sqrt(self.dim_model) - #out_list = out.chunk(config['tp_size'], dim=0) - #out = out_list[config['topology'].tp_id] - #print(out.sum()) - return out - else: - return LinearHookFunc.apply(x, self.weight, None, gather_input, gather_output, split_input, reduce_output_type) - return logits + out = LinearHookFunc.apply(x , self.weight, None, gather_input, gather_output, split_input, reduce_output_type) + return out diff --git a/example/layers/attention.py b/example/layers/attention.py index b2ddebb9..a49edabb 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -18,17 +18,17 @@ def __init__(self, ) -> None: super().__init__() - if config['tp_size'] <= 1: + if config['tp_size'] > 1: + self.project_q = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_k = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_v = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_out = RowParallelLinear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) + else: self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) - else: - self.project_q = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_k = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_v = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_out = RowParallelLinear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) self.softmax = torch.nn.Softmax(dim=-1) self.num_heads = num_heads diff --git a/example/layers/embedding.py 
b/example/layers/embedding.py index 4faaa133..f62151c4 100644 --- a/example/layers/embedding.py +++ b/example/layers/embedding.py @@ -82,8 +82,7 @@ def forward(self, input: torch.Tensor, projection : bool = False) -> torch.Tenso self.norm_type, self.scale_grad_by_freq, self.sparse) return out else: - out = F.linear(input, self.weight) / math.sqrt(self.embedding_dim) - #print(out.sum()) + out = F.linear(input, self.weight) return out def extra_repr(self) -> str: diff --git a/example/models/gpt.py b/example/models/gpt.py index d218b391..d2dff467 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py @@ -24,7 +24,7 @@ def __init__(self, bmt.CheckpointBlock( TransformerEncoder( dim_model, dim_head, num_heads, dim_ff, bias, dtype - ), use_checkpoint=False + ) ) for _ in range(num_layers) ]) @@ -41,18 +41,15 @@ def forward(self, mask_2d = mask_2d & (pos[:, None, :] >= pos[:, :, None]) out = self.pos_emb(pos) + self.word_emb(input) - bmt.synchronize() # for layer in self.transformers: out = self.transformers(out, mask_2d, None) out = self.layernorm(out) - bmt.synchronize() if config['tp_size'] > 1: - logits = self.word_emb.projection(out)#self.word_emb(out, projection=True) + logits = self.word_emb.projection(out) else: logits = self.word_emb(out, projection=True) - bmt.synchronize() bmt.inspect.record_tensor(logits, "logits") return logits diff --git a/example/train.py b/example/train.py index 50036675..44a7d5d2 100644 --- a/example/train.py +++ b/example/train.py @@ -10,7 +10,7 @@ def main(): bmt.init_distributed( seed=0, zero_level=2, - tp_size=4, + tp_size=2, ) model = GPT( @@ -26,13 +26,10 @@ def main(): ) bmt.init_parameters(model) - #bmt.load(model, "example_model.pt") - # print_inspect(model, "*") bmt.print_rank("Model memory") bmt.print_rank(torch.cuda.memory_summary()) bmt.synchronize() - #bmt.save(model, "example_model.pt") # data # generate dummy data for each rank From 763b4080816fb6b9e8d541ca352ea12ad12dd88a Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 24 Aug 2023 09:46:38 +0800 Subject: [PATCH 083/122] refactor nccl group; remove partition_modules in pipe_layer.py --- bmtrain/init.py | 42 ++++++++++------------- bmtrain/pipe_layer.py | 77 +------------------------------------------ 2 files changed, 19 insertions(+), 100 deletions(-) diff --git a/bmtrain/init.py b/bmtrain/init.py index a403caf7..915e30c7 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -107,7 +107,6 @@ def init_distributed( unique_id = bytes.fromhex(store.get("BMTRAIN_UNIQUE_ID").decode()) config['comm'] = nccl.commInitRank(unique_id, world_size, rank) - config['zero_comm'] = config['comm'] if config['pipe_enabled']: config["micros"] = num_micro_batches if num_micro_batches else config["pipe_size"] @@ -126,15 +125,18 @@ def init_distributed( unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode()) config['tp_comm'] = nccl.commInitRank(unique_id, tp_size, topo.tp_id) - if not config['pipe_enabled'] and config['tp_size'] <= 1: - config['tp_zero_comm'] = config['comm'] - else: if topo.tp_zero_id == 0: unique_id = nccl.getUniqueId() - store.set(f"ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) - unique_id = bytes.fromhex(store.get(f"ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) + store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) config ['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.tp_zero_id) + if topo.zero_id 
== 0: + unique_id = nccl.getUniqueId() + store.set(f"ZERO_UNIQUE_ID{topo.zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"ZERO_UNIQUE_ID{topo.zero_idx}").decode()) + config ['zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size']), topo.zero_id) + for i in range(world_size): if i == rank: print_dict("Initialization", { @@ -162,29 +164,21 @@ def __init__(self,config): config['zero_size'] = world_size // pp_size topo=torch.tensor(range(dp_size*tp_size*pp_size),dtype=torch.int,device='cuda') topo=topo.view(pp_size,dp_size*tp_size) - self.pp_group=topo.transpose(0,1).reshape(-1,pp_size) - self.stage_id = (self.pp_group == self.rank).nonzero()[0,-1].item() self.stages = config['pipe_size'] - self.pipe_idx = (self.pp_group == self.rank).nonzero()[0, 0].item() # x axes - self.zero_id = self.pipe_idx - self.zero_idx = self.stage_id - self.tp_group = topo.reshape(pp_size, dp_size, tp_size) - self.tp_id = (self.tp_group == self.rank).nonzero()[0,2].item() - self.tp_idx = (self.tp_group == self.rank).nonzero()[0,1 if dp_size > 1 else 0].item() - - if pp_size == 1 and tp_size == 1: - self.tp_zero_id = self.rank - self.tp_zero_idx = 0 - else: - self.dp_group = self.tp_group.permute(0,2,1) - self.tp_zero_id = (self.dp_group == self.rank).nonzero()[0,2 if tp_size > 1 else 0].item() - self.tp_zero_idx = (self.dp_group == self.rank).nonzero()[0,1 if tp_size > 1 else 2].item() + for i in range(world_size): + self.pipe_idx = self.rank % pp_size + self.stage_id = self.rank // pp_size + self.tp_id = self.rank % tp_size + self.tp_idx = self.rank // tp_size + self.zero_idx = self.stage_id if pp_size > 1 else 0 + self.zero_id = self.pipe_idx if pp_size > 1 else self.rank + self.tp_zero_idx = self.tp_id + self.tp_zero_id = self.tp_idx if dp_size > 1 else 0 self.next_rank = self.stage_id+1 if self.stage_id < config['pipe_size'] - 1 else -1 self.prev_rank = self.stage_id-1 if self.stage_id > 0 else -1 - self.tails = self.pp_group[self.pipe_idx, self.stage_id:].tolist() - self.heads = self.pp_group[self.pipe_idx, :self.stage_id + 1].tolist() + def get_group_id(self,group_name): if group_name == "pipe": diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 0a34ac46..efce4049 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -189,7 +189,6 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: rank = config['rank'] topo = config['topology'] self.layer_ids = [] - pipe_group = topo.pp_group self.stages = topo.stages self.stage_id = topo.stage_id self.pipe_idx = topo.pipe_idx @@ -218,11 +217,6 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: self._modules[str(self.layer_ids[0])]._is_first_layer = True self._modules[str(self.layer_ids[-1])]._is_last_layer = True - self.partition_modules(self.layer_ids) - self.next_rank = pipe_group[self.pipe_idx, self.stage_id + 1] if self.stage_id < config['pipe_size'] - 1 else -1 - self.prev_rank = pipe_group[self.pipe_idx, self.stage_id - 1] if self.stage_id > 0 else -1 - # self.micro_batches = config['num_micro_batches'] - self.save_list = [(i, i) for i in range(len(self.layer_ids))] def __len__(self) -> int: @@ -295,77 +289,8 @@ def get_stage_by_layer_id(self, layer_id : int) -> int: else: return rest + (layer_id - rest * (part_len+1)) // part_len - def partition_modules(self, idxs) -> None: - for i in range(len(self)): - contiguous_params = {} - for kw, val in self[i]._storage_info.items(): - storage_type = val["storage_type"] - contiguous_params[kw] = 
storage_type(round_up(val["total"], config["world_size"] // config["pipe_size"])) - nccl.allGather( - self[i]._storage_params[kw].storage(), - contiguous_params[kw], - config["comm"] - ) - - if i not in idxs: - for name, param in self[i]._module.named_parameters(): - param.data = torch.tensor([], dtype = param.dtype, device = param.device) - for kw, val in self[i]._storage_info.items(): - val["begin"] = self.stage_id - val["end"] = self.stage_id + 1 - val["partition_size"] = 1 - val["total"] = val["world_size"] - dtype = self[i]._storage_params[kw].dtype - device = self[i]._storage_params[kw].device - self[i]._storage_params[kw] = \ - torch.nn.Parameter(torch.tensor([0], dtype = dtype, device=device)) - else: - for kw, val in self[i]._storage_info.items(): - storage_type = val["storage_type"] - val["world_size"] = config["world_size"] // config["pipe_size"] - partition_size = round_up(val["total"], val["world_size"]) // val["world_size"] - val["partition_size"] = partition_size - val["begin"] = config['zero_rank'] * partition_size - val["end"] = (config['zero_rank'] + 1) * partition_size - storage_param_buffer = storage_type(partition_size) - dtype = storage_param_buffer.dtype - device = storage_param_buffer.device - self[i]._storage_params[kw] = torch.nn.Parameter( - torch.tensor([], dtype=dtype, device=device).set_(storage_param_buffer) - ) - if val["requires_grad"]: - self[i]._storage_params[kw].requires_grad_(True) - else: - self[i]._storage_params[kw].requires_grad_(False) - ordered_parameters = list(self[i]._module.named_parameters()) - for idx, named_param in enumerate(ordered_parameters): - name, param = named_param - param_info = self[i]._param_info[idx] - kw_name = _get_param_kw(param) - storage_info = self[i]._storage_info[kw_name] - storage_st = storage_info["begin"] - storage_end = storage_info["end"] - param_st = param_info["offset"] - param_end = param_st + param_info["size"] - if not (param_st >= storage_end or param_end <= storage_st): - # copy offset in parameter storage - offset_st = max(storage_st - param_st, 0) - offset_end = min(storage_end - param_st, param_info["size"]) - assert offset_st < offset_end - to_offset_st = offset_st + param_st - storage_st - to_offset_end = offset_end + param_st - storage_st - d_dtype = self[i]._storage_params[kw_name].dtype - d_device = self[i]._storage_params[kw_name].device - param.data = torch.tensor([], dtype=param.dtype, device=param.device).set_(self[i]._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,)) - param_info["begin"] = to_offset_st - param_info["end"] = (to_offset_end - to_offset_st,) - param.data[:] = \ - torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_params[kw], storage_st+to_offset_st, (to_offset_end - to_offset_st,))[:] - else: - param.data = torch.tensor([], dtype=param.dtype, device=param.device) - del contiguous_params - def _save_to_state_dict(self, destination, prefix, keep_vars): + print('call _save_to_state_dict') for name, module in self._modules.items(): idx = int(name) name = prefix + name + '.' 
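The "refactor nccl group" patch above replaces Topology's tensor-reshaping bookkeeping with plain integer arithmetic on the global rank; the "fix topology" patches further below adjust those formulas again. For orientation, the following is a minimal, self-contained sketch of that rank-to-coordinate arithmetic as it stands after those fixes, assuming a stage-major layout with world_size = pipe_size * tp_size * dp_size. The helper name group_coords is an illustrative assumption rather than BMTrain's API, and the tp_zero_* coordinates are omitted because their formula changes between commits.

def group_coords(rank: int, world_size: int, pipe_size: int, tp_size: int) -> dict:
    # Illustrative sketch only; in BMTrain these values are stored as Topology attributes.
    assert world_size % (pipe_size * tp_size) == 0
    stage_size = world_size // pipe_size            # ranks per pipeline stage
    return {
        "stage_id": rank // stage_size,             # which pipeline stage owns this rank
        "pipe_idx": rank % stage_size,              # position of the rank inside its stage
        "tp_id": rank % tp_size,                    # position inside the tensor-parallel group
        "tp_idx": rank // tp_size,                  # which tensor-parallel group
        "zero_idx": rank // stage_size,             # ZeRO communicators are built per stage
        "zero_id": rank % stage_size,               # position inside the per-stage ZeRO group
    }

# Example: 8 ranks, pipe_size=2, tp_size=2 (hence dp_size=2)
for r in range(8):
    print(r, group_coords(r, world_size=8, pipe_size=2, tp_size=2))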
From 4c50567a94b7143cceecfe2f90977a659e0e8dc0 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 24 Aug 2023 10:43:51 +0800 Subject: [PATCH 084/122] fix by review comment --- bmtrain/block_layer.py | 26 +++++++------------ bmtrain/init.py | 2 +- bmtrain/loss/cross_entropy.py | 4 +-- bmtrain/nn/__init__.py | 4 +-- bmtrain/nn/column_parallel_linear.py | 11 ++++---- ...ropy.py => parallel_cross_entropy_func.py} | 6 ++--- .../{embedding.py => parallel_embedding.py} | 12 +++++---- ...r_hook_func.py => parallel_linear_func.py} | 2 +- bmtrain/nn/row_parallel_linear.py | 6 ++--- bmtrain/param_init.py | 1 - bmtrain/parameter.py | 7 +++-- bmtrain/pipe_layer.py | 1 - bmtrain/utils.py | 7 +++++ 13 files changed, 45 insertions(+), 44 deletions(-) rename bmtrain/nn/{cross_entropy.py => parallel_cross_entropy_func.py} (95%) rename bmtrain/nn/{embedding.py => parallel_embedding.py} (84%) rename bmtrain/nn/{parallel_linear_hook_func.py => parallel_linear_func.py} (98%) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index d01a01cb..c479a7e5 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, Iterator, Union, List -from .utils import round_up +from .utils import (round_up, tp_split_tensor) from .global_var import config import torch from . import nccl @@ -309,10 +309,12 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, tp_mode = param._tp_mode if input_param.__class__.__name__ == "DistributedTensorWrapper": input_param = input_param.broadcast() - if not tp_mode and input_param.shape != it["shape"]: + + verify_shape = it["shape"] if not tp_mode else param._tp_original_shape + if input_param.shape != verify_shape: error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' 'the shape in current model is {}.' 
- .format(key, input_param.shape, it["shape"])) + .format(key, input_param.shape, verify_shape)) continue param_st = it["offset"] @@ -328,18 +330,13 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, continue # copy to buffer - if not tp_mode: - assert input_param.numel() == it["size"] + verify_size = verify_shape.numel() + assert input_param.numel() == verify_size contiguous_param = input_param.to(it["parameter"].dtype).cuda().contiguous() - if tp_mode: - tp_split_dim = param._tp_split_dim - if tp_split_dim >= 0: - param_list = contiguous_param.chunk(config['tp_size'], dim=tp_split_dim) - sub_tensor = param_list[config['topology'].tp_id] - contiguous_param = torch.empty(sub_tensor.shape, device=sub_tensor.device, dtype=sub_tensor.dtype) - contiguous_param.copy_(sub_tensor) + if tp_mode and tp_split_dim >= 0: + contiguous_param = tp_split_tensor(contiguous_param, tp_split_dim) offset_st = max(storage_st - param_st, 0) offset_end = min(storage_end - param_st, contiguous_param.numel()) @@ -429,10 +426,7 @@ def init_parameters(self): continue if param._tp_mode and param._tp_split_dim >= 0: - tensor_list = tmp_tensor.chunk(config['tp_size'], dim=param._tp_split_dim) - sub_tensor = tensor_list[config['topology'].tp_id].contiguous() - tmp_tensor = torch.empty(sub_tensor.shape, device=param.device, dtype=sub_tensor.dtype) - tmp_tensor.copy_(sub_tensor) + tmp_tensor = tp_split_tensor(tmp_tensor, param._tp_split_dim) # copy to buffer assert tmp_tensor.is_contiguous() and it["size"] == tmp_tensor.numel() diff --git a/bmtrain/init.py b/bmtrain/init.py index 915e30c7..23999e39 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -15,7 +15,7 @@ def init_distributed( seed : int = 0, zero_level: int = 3, pipe_size: int = -1, - tp_size = 1, + tp_size : int = 1, num_micro_batches: int = None, ): """Initialize distributed training. diff --git a/bmtrain/loss/cross_entropy.py b/bmtrain/loss/cross_entropy.py index a5d43a55..962cfdeb 100644 --- a/bmtrain/loss/cross_entropy.py +++ b/bmtrain/loss/cross_entropy.py @@ -1,7 +1,7 @@ from typing import Optional import torch from . 
import _function as F -from bmtrain.nn import fused_cross_entropy +from bmtrain.nn import parallel_cross_entropy_func from bmtrain.global_var import config from bmtrain.distributed import all_gather @@ -191,7 +191,7 @@ def __init__(self, def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if config['tp_size'] > 1: target = all_gather(target, comm=config['tp_comm']).flatten(0,1) - ret = fused_cross_entropy(input, target.long(), self.ignore_index) + ret = parallel_cross_entropy_func(input, target.long(), self.ignore_index) else: if input.dtype == torch.float32: return torch.nn.functional.cross_entropy( diff --git a/bmtrain/nn/__init__.py b/bmtrain/nn/__init__.py index 12540b57..05026738 100644 --- a/bmtrain/nn/__init__.py +++ b/bmtrain/nn/__init__.py @@ -1,5 +1,5 @@ from .linear import Linear from .column_parallel_linear import ColumnParallelLinear from .row_parallel_linear import RowParallelLinear -from .embedding import Embedding -from .cross_entropy import fused_cross_entropy +from .parallel_embedding import ParallelEmbedding +from .parallel_cross_entropy_func import parallel_cross_entropy_func diff --git a/bmtrain/nn/column_parallel_linear.py b/bmtrain/nn/column_parallel_linear.py index 9d6444b4..2baf79c2 100644 --- a/bmtrain/nn/column_parallel_linear.py +++ b/bmtrain/nn/column_parallel_linear.py @@ -3,16 +3,17 @@ import bmtrain as bmt from bmtrain.global_var import config -from .parallel_linear_hook_func import ( - LinearHookFunc, +from .parallel_linear_func import ( + ParallelLinearFunc, ReduceType) class ColumnParallelLinear(bmt.DistributedModule): - def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None, gather_output=False) -> None: + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None, gather_output=False, gather_input=True) -> None: super().__init__() self.in_features = in_features self.out_features = out_features self.gather_output = gather_output + self.gather_input = gather_input tp_size = config['tp_size'] assert out_features % tp_size == 0 self.out_features_per_partition = out_features // tp_size @@ -23,10 +24,10 @@ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtyp self.register_parameter('bias', None) def forward(self, input): - gather_input = True + gather_input = self.gather_input split_input = False reduce_output_type = None - return LinearHookFunc.apply(input, self.weight, self.bias, gather_input, self.gather_output, split_input, reduce_output_type) + return ParallelLinearFunc.apply(input, self.weight, self.bias, gather_input, self.gather_output, split_input, reduce_output_type) def extra_repr(self) -> str: return 'in_features={}, out_features={}, bias={}'.format( diff --git a/bmtrain/nn/cross_entropy.py b/bmtrain/nn/parallel_cross_entropy_func.py similarity index 95% rename from bmtrain/nn/cross_entropy.py rename to bmtrain/nn/parallel_cross_entropy_func.py index 49460b5b..55aa78cb 100644 --- a/bmtrain/nn/cross_entropy.py +++ b/bmtrain/nn/parallel_cross_entropy_func.py @@ -3,7 +3,7 @@ from bmtrain.global_var import config from bmtrain.distributed import all_reduce, all_gather -class FusedCrossEntropyFunc(torch.autograd.Function): +class ParallelCrossEntropyFunc(torch.autograd.Function): @staticmethod def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0): @@ -112,7 +112,7 @@ def backward(ctx, grad_output): return grad_input, None, None -def fused_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0): +def 
parallel_cross_entropy_func(vocab_parallel_logits, target, label_smoothing=0.0): """ Performs cross entropy loss when logits are split across tensor parallel ranks @@ -125,6 +125,6 @@ def fused_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0): lobal_smoothing: smoothing factor, must be in range [0.0, 1.0) default is no smoothing (=0.0) """ - out = FusedCrossEntropyFunc.apply(vocab_parallel_logits.to(torch.float32), target, label_smoothing) + out = ParallelCrossEntropyFunc.apply(vocab_parallel_logits.to(torch.float32), target, label_smoothing) return out diff --git a/bmtrain/nn/embedding.py b/bmtrain/nn/parallel_embedding.py similarity index 84% rename from bmtrain/nn/embedding.py rename to bmtrain/nn/parallel_embedding.py index c8acafae..1d7a1330 100644 --- a/bmtrain/nn/embedding.py +++ b/bmtrain/nn/parallel_embedding.py @@ -6,9 +6,9 @@ import bmtrain as bmt from bmtrain.global_var import config from bmtrain.distributed import all_reduce, all_gather -from .parallel_linear_hook_func import LinearHookFunc +from .parallel_linear_func import ParallelLinearFunc -class Embedding(bmt.DistributedModule): +class ParallelEmbedding(bmt.DistributedModule): def __init__( self, vocab_size: int, @@ -31,16 +31,18 @@ def __init__( tp_split_dim=0, ) - def forward(self, ids: torch.Tensor): + def forward(self, ids: torch.Tensor, gather_input=True): """ Args: ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens. + gather_input (bool) : whether gather input is required between tensor parallel group) Return: :obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output. """ # noqa: E501 if config['tp_size'] > 1: - ids = all_gather(ids, comm=config['tp_comm']) + if gather_input: + ids = all_gather(ids, comm=config['tp_comm']) input_mask = (ids < self.start_index) | (ids >= self.end_index) ids = ids.clone() - self.start_index ids[input_mask] = 0 @@ -67,5 +69,5 @@ def projection(self, x: torch.Tensor): split_input = False reduce_output_type = None gather_output = False - out = LinearHookFunc.apply(x , self.weight, None, gather_input, gather_output, split_input, reduce_output_type) + out = ParallelLineakFunc.apply(x , self.weight, None, gather_input, gather_output, split_input, reduce_output_type) return out diff --git a/bmtrain/nn/parallel_linear_hook_func.py b/bmtrain/nn/parallel_linear_func.py similarity index 98% rename from bmtrain/nn/parallel_linear_hook_func.py rename to bmtrain/nn/parallel_linear_func.py index af7bea0a..5c7a30bc 100644 --- a/bmtrain/nn/parallel_linear_hook_func.py +++ b/bmtrain/nn/parallel_linear_func.py @@ -20,7 +20,7 @@ def preprocess_input(input, gather_input, split_input): input = all_input_list[config['topology'].tp_id] return input -class LinearHookFunc(torch.autograd.Function): +class ParallelLinearFunc(torch.autograd.Function): @staticmethod def forward(ctx, input, weight, bias=None, gather_input=False, gather_output=False, split_input=False, reduce_output_type=None): ctx.save_for_backward(input, weight, bias) diff --git a/bmtrain/nn/row_parallel_linear.py b/bmtrain/nn/row_parallel_linear.py index a6877b7c..acc9378d 100644 --- a/bmtrain/nn/row_parallel_linear.py +++ b/bmtrain/nn/row_parallel_linear.py @@ -3,8 +3,8 @@ import bmtrain as bmt from bmtrain.global_var import config -from .parallel_linear_hook_func import ( - LinearHookFunc, +from .parallel_linear_func import ( + ParallelLinearFunc, ReduceType) class RowParallelLinear(bmt.DistributedModule): @@ -27,7 +27,7 @@ def forward(self, input): 
gather_input = self.split_input gather_output = False reduce_output_type = ReduceType.ALL_REDUCE if self.all_reduce_output else ReduceType.REDUCE_SCATTER - out = LinearHookFunc.apply(input, self.weight, None, gather_input, gather_output, self.split_input, reduce_output_type) + out = ParallelLinearFunc.apply(input, self.weight, None, gather_input, gather_output, self.split_input, reduce_output_type) out = out + self.bias return out diff --git a/bmtrain/param_init.py b/bmtrain/param_init.py index 5e165fe1..d5c86225 100644 --- a/bmtrain/param_init.py +++ b/bmtrain/param_init.py @@ -3,7 +3,6 @@ from .block_layer import CheckpointBlock from .parameter import DistributedParameter from .global_var import config -from . import nccl def init_distributed_parameter(params : Iterable[torch.nn.Parameter]): diff --git a/bmtrain/parameter.py b/bmtrain/parameter.py index f68d53ac..b5e5e9ae 100644 --- a/bmtrain/parameter.py +++ b/bmtrain/parameter.py @@ -53,9 +53,8 @@ def __new__(cls, original_shape = data.size() tp_original_shape = original_shape if tp_mode and tp_split_dim >= 0: - list_shape = list(original_shape) - list_shape[tp_split_dim] *= config['tp_size'] - tp_original_shape = list_shape + tp_original_shape = list(original_shape) + tp_original_shape[tp_split_dim] *= config['tp_size'] cuda_storage = cuda_tensor.storage_type()(cuda_storage_size) @@ -176,7 +175,7 @@ def backward(ctx, grad_output : torch.Tensor): ) grad_tensor = torch.tensor([], dtype=grad_output.dtype, device="cuda") grad_tensor.set_(grad_storage, 0, (ctx.tensor_size,)) - return grad_tensor, None + return grad_tensor class ParameterInitializer: """ diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index efce4049..a21bb9af 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -290,7 +290,6 @@ def get_stage_by_layer_id(self, layer_id : int) -> int: return rest + (layer_id - rest * (part_len+1)) // part_len def _save_to_state_dict(self, destination, prefix, keep_vars): - print('call _save_to_state_dict') for name, module in self._modules.items(): idx = int(name) name = prefix + name + '.' diff --git a/bmtrain/utils.py b/bmtrain/utils.py index a5687c7d..8cb87808 100644 --- a/bmtrain/utils.py +++ b/bmtrain/utils.py @@ -107,6 +107,13 @@ def see_memory(message, detail=False): """) torch.cuda.reset_peak_memory_stats() +def tp_split_tensor(tensor, split_dim): + tensor_list = tensor.chunk(config['tp_size'], dim=split_dim) + sub_tensor = tensor_list[config['topology'].tp_id].contiguous() + tmp_tensor = torch.empty(sub_tensor.shape, device=sub_tensor.device, dtype=sub_tensor.dtype) + tmp_tensor.copy_(sub_tensor) + return tmp_tensor + class AverageRecorder: """A utility class to record the average value of a quantity over time. From 825139c5f6c48afca9ac24e09dcfb6b426522f4e Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 24 Aug 2023 11:26:56 +0800 Subject: [PATCH 085/122] fix topology --- bmtrain/init.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/bmtrain/init.py b/bmtrain/init.py index 23999e39..17d217b1 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -15,8 +15,8 @@ def init_distributed( seed : int = 0, zero_level: int = 3, pipe_size: int = -1, - tp_size : int = 1, num_micro_batches: int = None, + tp_size : int = 1, ): """Initialize distributed training. This function will initialize the distributed training, set the random seed and global configurations. @@ -25,6 +25,9 @@ def init_distributed( Args: seed (int): The random seed. 
zero_level (int): The ZeRO optimization level. 2 for stage-2, 3 for stage-3. + pipe_size (int) : pipe_size means that all processes will be divided into pipe_size groups + num_micro_batches (int) : means that the input batchs will be divided into num_micro_batches small batches. used in pipeline mode. + tp_size (int) : tp_size means the size of each of tensor parallel group **init_distributed** reads the following environment variables: @@ -107,24 +110,23 @@ def init_distributed( unique_id = bytes.fromhex(store.get("BMTRAIN_UNIQUE_ID").decode()) config['comm'] = nccl.commInitRank(unique_id, world_size, rank) + topo = config['topology'] if config['pipe_enabled']: config["micros"] = num_micro_batches if num_micro_batches else config["pipe_size"] - topo = config['topology'] if topo.stage_id == 0: unique_id = nccl.getUniqueId() store.set(f"PIPE_UNIQUE_ID{topo.pipe_idx}", unique_id.hex()) unique_id = bytes.fromhex(store.get(f"PIPE_UNIQUE_ID{topo.pipe_idx}").decode()) config ['pipe_comm'] = nccl.commInitRank(unique_id, pipe_size, topo.stage_id) - if config['tp_size'] > 1: - topo = config['topology'] - if topo.tp_id == 0: - unique_id = nccl.getUniqueId() - store.set(f"TP_UNIQUE_ID{topo.tp_idx}", unique_id.hex()) - unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode()) - config['tp_comm'] = nccl.commInitRank(unique_id, tp_size, topo.tp_id) + if topo.tp_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"TP_UNIQUE_ID{topo.tp_idx}", unique_id.hex()) + unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode()) + config['tp_comm'] = nccl.commInitRank(unique_id, tp_size, topo.tp_id) + if config['tp_size'] > 1: if topo.tp_zero_id == 0: unique_id = nccl.getUniqueId() store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) @@ -166,13 +168,14 @@ def __init__(self,config): topo=topo.view(pp_size,dp_size*tp_size) self.stages = config['pipe_size'] + stage_size = world_size // pp_size for i in range(world_size): - self.pipe_idx = self.rank % pp_size - self.stage_id = self.rank // pp_size + self.pipe_idx = self.rank % stage_size + self.stage_id = self.rank // stage_size self.tp_id = self.rank % tp_size self.tp_idx = self.rank // tp_size - self.zero_idx = self.stage_id if pp_size > 1 else 0 - self.zero_id = self.pipe_idx if pp_size > 1 else self.rank + self.zero_idx = self.stage_id + self.zero_id = self.pipe_idx self.tp_zero_idx = self.tp_id self.tp_zero_id = self.tp_idx if dp_size > 1 else 0 From f08bc832cc8142fea509d4b05485cba4cc525de3 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Thu, 24 Aug 2023 11:42:55 +0800 Subject: [PATCH 086/122] offload event wait --- bmtrain/block_layer.py | 10 ++- bmtrain/hook_func.py | 16 ++-- example/layers/attention.py | 158 ++++++++++++++++++------------------ example/models/gpt.py | 5 +- example/train.py | 18 ++-- 5 files changed, 109 insertions(+), 98 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index e0a1bfdd..b8aa1a4c 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -200,7 +200,7 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, use_offl self._pre_module = [] #save the pre module of self self._ref_count = 0 #incremental in forward and decreasing in backward self._mode = "BLOCK" #BLOCK or ZERO or PIPE - if use_offload: + if use_offload and offload_level != 0: self._mode = "OFFLOAD" self._on_device = False self.offload_level = offload_level @@ -528,6 +528,7 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, 
sqrt=False) self._modules = {} pre_module = None + offload = 0 for i, module in enumerate(modules): if not isinstance(module, CheckpointBlock): module = CheckpointBlock(module) @@ -536,11 +537,14 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) pre_module = module module._is_first_layer = False module._is_last_layer = False - + if module._mode == "OFFLOAD": + offload+=1 + module.calc_event = torch.cuda.Event() + module.offload_event = torch.cuda.Event() self._modules[str(i)] = module module._idx = i self.add_module(str(i), module) - + print(f"offload layer: {offload}") self._modules[str(0)]._is_first_layer = True self._modules[str(len(modules)-1)]._is_last_layer = True diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 0d841e74..03f0be36 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -177,7 +177,7 @@ def find_pre_module_helper(m): if not hasattr(pre_module._offload_dict, "fp16_storage"): pre_module._offload_dict.make_cpu_storage() pre_module._offload_dict.d2h_memcpy() - pre_module.offload_event = torch.cuda.Event() + # if len(module._next_module) > 0: config["offload_stream"].record_event(pre_module.offload_event) # torch._C._autograd._push_saved_tensors_default_hooks( # pack_hook, unpack_hook @@ -203,10 +203,10 @@ def zero_post_forward(module, inputs, outputs): exit = True if module._mode == "PIPE": exit = module._micro_idx == config['micros'] - 1 - elif module._mode == "OFFLOAD" and module.offload_level == 2: - module.calc_event = torch.cuda.Event() + elif module._mode == "OFFLOAD": + if module.offload_level == 2: + torch._C._autograd._pop_saved_tensors_default_hooks() torch.cuda.current_stream().record_event(module.calc_event) - torch._C._autograd._pop_saved_tensors_default_hooks() if exit: module._forward_block_ctx.exit(forward_flag) module._ref_count += 1 @@ -221,11 +221,14 @@ def zero_pre_backward(module, grad_outputs): if pre_module._mode == "OFFLOAD": pre_module._on_device = True with torch.cuda.stream(config["offload_stream"]): + if (len(module._next_module) != 0): + torch.cuda.current_stream().wait_event(module._next_module[0].calc_event) pre_module._offload_dict.h2d_memcpy() + torch.cuda.current_stream().record_event(pre_module.offload_event) else: current_stream = torch.cuda.current_stream() - current_stream.wait_stream(config["offload_stream"]) - module._offload_dict.record_stream(config["calc_stream"]) + current_stream.wait_event(module.offload_event) + module._offload_dict.record_stream(current_stream) module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict) module._backward_block_ctx.enter(backward_flag, True) if not module._is_last_layer: @@ -241,6 +244,7 @@ def zero_post_backward(module, grad_inputs, grad_outputs): if module._mode == "OFFLOAD": module._on_device = False module._offload_dict.pop_all() + torch.cuda.current_stream().record_event(module.calc_event) if module._is_first_layer: module.backward_release(backward_flag) else: diff --git a/example/layers/attention.py b/example/layers/attention.py index 54c1ceca..b5d337dc 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -3,8 +3,8 @@ import bmtrain as bmt from bmtrain.nn import Linear import math -from .flash_triton import FlashAttnFunc -class Attention(bmt.DistributedModule): +# from .flash_triton import FlashAttnFunc +class FlashAttention(bmt.DistributedModule): def __init__( self, dim_model: int, @@ -100,83 +100,83 @@ def forward( return score -#class Attention(bmt.DistributedModule): -# def 
__init__(self, -# dim_model : int, dim_head : int, -# num_heads : int, bias : bool = True, -# dtype = None -# ) -> None: -# super().__init__() -# -# self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) -# self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) -# self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) -# -# self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) -# -# self.softmax = torch.nn.Softmax(dim=-1) -# self.num_heads = num_heads -# self.dim_head = dim_head -# self.dim_model = dim_model -# -# def forward(self, -# hidden_q : torch.Tensor, # (batch_size, seq_q, dim_model) -# hidden_kv : torch.Tensor, # (batch_size, seq_kv, dim_model) -# mask : torch.BoolTensor, # (batch_size, seq_q, seq_kv) -# position_bias : Optional[torch.Tensor] = None, # (batch, num_heads, seq_q, seq_kv) -# ) -> torch.Tensor: -# batch_size, seq_q, dim_model = hidden_q.size() -# seq_kv = hidden_kv.size(1) -# -# h_q : torch.Tensor = self.project_q(hidden_q) -# h_k : torch.Tensor = self.project_k(hidden_kv) -# h_v : torch.Tensor = self.project_v(hidden_kv) -# -# h_q = h_q.view(batch_size, seq_q, self.num_heads, self.dim_head) -# h_k = h_k.view(batch_size, seq_kv, self.num_heads, self.dim_head) -# h_v = h_v.view(batch_size, seq_kv, self.num_heads, self.dim_head) -# -# h_q = h_q.permute(0, 2, 1, 3).contiguous() -# h_k = h_k.permute(0, 2, 1, 3).contiguous() -# h_v = h_v.permute(0, 2, 1, 3).contiguous() -# -# h_q = h_q.view(batch_size * self.num_heads, seq_q, self.dim_head) -# h_k = h_k.view(batch_size * self.num_heads, seq_kv, self.dim_head) -# h_v = h_v.view(batch_size * self.num_heads, seq_kv, self.dim_head) -# -# score = torch.bmm( -# h_q, h_k.transpose(1, 2) -# ) -# score = score / math.sqrt(self.dim_head) -# -# score = score.view(batch_size, self.num_heads, seq_q, seq_kv) -# -# if position_bias is not None: -# score = score + position_bias.view(batch_size, self.num_heads, seq_q, seq_kv) -# -# score = torch.where( -# mask.view(batch_size, 1, seq_q, seq_kv), -# score, -# torch.scalar_tensor(float('-inf'), device=score.device, dtype=score.dtype) -# ) -# -# score = torch.where( -# mask.view(batch_size, 1, seq_q, seq_kv), -# self.softmax(score), -# torch.scalar_tensor(0, device=score.device, dtype=score.dtype) -# ) -# -# score = score.view(batch_size * self.num_heads, seq_q, seq_kv) -# -# h_out = torch.bmm( -# score, h_v -# ) -# h_out = h_out.view(batch_size, self.num_heads, seq_q, self.dim_head) -# h_out = h_out.permute(0, 2, 1, 3).contiguous() -# h_out = h_out.view(batch_size, seq_q, self.num_heads * self.dim_head) -# -# attn_out = self.project_out(h_out) -# return attn_out +class Attention(bmt.DistributedModule): + def __init__(self, + dim_model : int, dim_head : int, + num_heads : int, bias : bool = True, + dtype = None + ) -> None: + super().__init__() + + self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + + self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) + + self.softmax = torch.nn.Softmax(dim=-1) + self.num_heads = num_heads + self.dim_head = dim_head + self.dim_model = dim_model + + def forward(self, + hidden_q : torch.Tensor, # (batch_size, seq_q, dim_model) + hidden_kv : torch.Tensor, # (batch_size, seq_kv, dim_model) + mask : torch.BoolTensor, # (batch_size, 
seq_q, seq_kv) + position_bias : Optional[torch.Tensor] = None, # (batch, num_heads, seq_q, seq_kv) + ) -> torch.Tensor: + batch_size, seq_q, dim_model = hidden_q.size() + seq_kv = hidden_kv.size(1) + + h_q : torch.Tensor = self.project_q(hidden_q) + h_k : torch.Tensor = self.project_k(hidden_kv) + h_v : torch.Tensor = self.project_v(hidden_kv) + + h_q = h_q.view(batch_size, seq_q, self.num_heads, self.dim_head) + h_k = h_k.view(batch_size, seq_kv, self.num_heads, self.dim_head) + h_v = h_v.view(batch_size, seq_kv, self.num_heads, self.dim_head) + + h_q = h_q.permute(0, 2, 1, 3).contiguous() + h_k = h_k.permute(0, 2, 1, 3).contiguous() + h_v = h_v.permute(0, 2, 1, 3).contiguous() + + h_q = h_q.view(batch_size * self.num_heads, seq_q, self.dim_head) + h_k = h_k.view(batch_size * self.num_heads, seq_kv, self.dim_head) + h_v = h_v.view(batch_size * self.num_heads, seq_kv, self.dim_head) + + score = torch.bmm( + h_q, h_k.transpose(1, 2) + ) + score = score / math.sqrt(self.dim_head) + + score = score.view(batch_size, self.num_heads, seq_q, seq_kv) + + if position_bias is not None: + score = score + position_bias.view(batch_size, self.num_heads, seq_q, seq_kv) + + score = torch.where( + mask.view(batch_size, 1, seq_q, seq_kv), + score, + torch.scalar_tensor(float('-inf'), device=score.device, dtype=score.dtype) + ) + + score = torch.where( + mask.view(batch_size, 1, seq_q, seq_kv), + self.softmax(score), + torch.scalar_tensor(0, device=score.device, dtype=score.dtype) + ) + + score = score.view(batch_size * self.num_heads, seq_q, seq_kv) + + h_out = torch.bmm( + score, h_v + ) + h_out = h_out.view(batch_size, self.num_heads, seq_q, self.dim_head) + h_out = h_out.permute(0, 2, 1, 3).contiguous() + h_out = h_out.view(batch_size, seq_q, self.num_heads * self.dim_head) + + attn_out = self.project_out(h_out) + return attn_out diff --git a/example/models/gpt.py b/example/models/gpt.py index 2778d549..f31aabbc 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py @@ -7,7 +7,7 @@ def __init__(self, num_layers : int, vocab_size : int, dim_model : int, dim_head : int, num_heads : int, dim_ff : int, max_distance : int, - bias : bool = True, dtype = None, offload = False, + bias : bool = True, dtype = None, offload = False, offload_level = 0 ) -> None: super().__init__() @@ -18,6 +18,7 @@ def __init__(self, if offload: offload_mask = [True if i%4 == 0 else False for i in range(num_layers)] ckpt_mask = [not offload_mask[i] for i in range(num_layers)] + offload_level = offload_level else: ckpt_mask = [ True for i in range(num_layers) ] offload_mask = [ False for i in range(num_layers) ] @@ -25,7 +26,7 @@ def __init__(self, bmt.CheckpointBlock( TransformerEncoder( dim_model, dim_head, num_heads, dim_ff, bias, dtype - ),use_checkpoint=ckpt_mask[i],use_offload=offload_mask[i] + ),use_checkpoint=ckpt_mask[i],use_offload=offload_mask[i],offload_level=offload_level ) for i in range(num_layers) ]) diff --git a/example/train.py b/example/train.py index 3e115d71..44451c9e 100644 --- a/example/train.py +++ b/example/train.py @@ -10,19 +10,21 @@ def main(): seed=0, zero_level=3, ) - offload = False - seq_len = True + offload = True + seq_len = 4096 + offload_level = 0 model = GPT( - num_layers=48, + num_layers=24, vocab_size=80000, - dim_model=4096, - dim_head=128, - num_heads=32, - dim_ff=10240, + dim_model=1024, + dim_head=64, + num_heads=16, + dim_ff=4096, max_distance=seq_len, bias=False, dtype=torch.half, - offload=offload + offload=offload, + offload_level=offload_level ) bmt.init_parameters(model) From 
4ff0f419f5428ff38d9863b2afed103babe8baac Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 24 Aug 2023 13:58:52 +0800 Subject: [PATCH 087/122] fix topology --- bmtrain/init.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bmtrain/init.py b/bmtrain/init.py index 17d217b1..8f4f7063 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -176,8 +176,8 @@ def __init__(self,config): self.tp_idx = self.rank // tp_size self.zero_idx = self.stage_id self.zero_id = self.pipe_idx - self.tp_zero_idx = self.tp_id - self.tp_zero_id = self.tp_idx if dp_size > 1 else 0 + self.tp_zero_idx = self.stage_id * tp_size + self.tp_id + self.tp_zero_id = self.pipe_idx // tp_size self.next_rank = self.stage_id+1 if self.stage_id < config['pipe_size'] - 1 else -1 self.prev_rank = self.stage_id-1 if self.stage_id > 0 else -1 From a5d7ba63d90009576deaf299e64af20f944bf9b0 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 24 Aug 2023 15:43:56 +0800 Subject: [PATCH 088/122] fix --- bmtrain/nn/parallel_cross_entropy_func.py | 46 ++++++++++------------- bmtrain/nn/parallel_embedding.py | 2 +- 2 files changed, 20 insertions(+), 28 deletions(-) diff --git a/bmtrain/nn/parallel_cross_entropy_func.py b/bmtrain/nn/parallel_cross_entropy_func.py index 55aa78cb..cd1f63bf 100644 --- a/bmtrain/nn/parallel_cross_entropy_func.py +++ b/bmtrain/nn/parallel_cross_entropy_func.py @@ -6,19 +6,19 @@ class ParallelCrossEntropyFunc(torch.autograd.Function): @staticmethod - def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0): + def forward(ctx, logits, target, label_smoothing=0.0): comm = config['tp_comm'] rank = config['topology'].tp_id world_size = config['tp_size'] - # Maximum value along vocab dimension across all GPUs. - logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] - logits_max = all_reduce(logits_max, op="max", comm=comm) - # Subtract the maximum value. - vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1) - - # Get the partition's vocab indecies - #get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size - partition_vocab_size = vocab_parallel_logits.size()[-1] + + # local max + max_logits = torch.max(logits, dim=-1)[0] + # global max + max_logits = all_reduce(max_logits, op="max", comm=comm) + + logits = logits - max_logits.unsqueeze(dim=-1) + + partition_vocab_size = logits.size()[-1] vocab_start_index = rank * partition_vocab_size vocab_end_index = (rank + 1) * partition_vocab_size @@ -27,10 +27,7 @@ def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0): masked_target = target.clone() - vocab_start_index masked_target[target_mask] = 0 - # Get predicted-logits = logits[target]. - # For Simplicity, we convert logits to a 2-D tensor with size - # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. - logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) + logits_2d = logits.view(-1, partition_vocab_size) masked_target_1d = masked_target.view(-1) arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) @@ -38,12 +35,12 @@ def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0): predicted_logits_1d = predicted_logits_1d.clone().contiguous() predicted_logits = predicted_logits_1d.view_as(target) predicted_logits[target_mask] = 0.0 + # All reduce is needed to get the chunks from other GPUs. predicted_logits = all_reduce(predicted_logits, op="sum", comm=comm) - # Sum of exponential of logits along vocab dimension across all GPUs. 
- exp_logits = vocab_parallel_logits - torch.exp(vocab_parallel_logits, out=exp_logits) + exp_logits = logits + torch.exp(logits, out=exp_logits) sum_exp_logits = exp_logits.sum(dim=-1) sum_exp_logits = all_reduce(sum_exp_logits, op="sum", comm=comm) @@ -106,25 +103,20 @@ def backward(ctx, grad_output): else: grad_2d[arange_1d, masked_target_1d] -= softmax_update - # Finally elementwise multiplication with the output gradients. grad_input.mul_(grad_output.flatten(0,1).unsqueeze(dim=-1)) return grad_input, None, None -def parallel_cross_entropy_func(vocab_parallel_logits, target, label_smoothing=0.0): +def parallel_cross_entropy_func(logits, target, label_smoothing=0.0): """ Performs cross entropy loss when logits are split across tensor parallel ranks Arguments: - vocab_parallel_logits: logits split across tensor parallel ranks - dimension is [sequence_length, batch_size, hidden_size] - - target: correct vocab ids of dimseion [sequence_length, micro_batch_size] - - lobal_smoothing: smoothing factor, must be in range [0.0, 1.0) - default is no smoothing (=0.0) + logits: logits split across tensor parallel ranks dimension is [batch * seq_len, hidden_size]. + target: correct vocab ids of dimseion [batch * seq_len]. + lobal_smoothing: smoothing factor, must be in range [0.0, 1.0). default is 0.0. """ - out = ParallelCrossEntropyFunc.apply(vocab_parallel_logits.to(torch.float32), target, label_smoothing) + out = ParallelCrossEntropyFunc.apply(logits.to(torch.float32), target, label_smoothing) return out diff --git a/bmtrain/nn/parallel_embedding.py b/bmtrain/nn/parallel_embedding.py index 1d7a1330..7ffb74ef 100644 --- a/bmtrain/nn/parallel_embedding.py +++ b/bmtrain/nn/parallel_embedding.py @@ -69,5 +69,5 @@ def projection(self, x: torch.Tensor): split_input = False reduce_output_type = None gather_output = False - out = ParallelLineakFunc.apply(x , self.weight, None, gather_input, gather_output, split_input, reduce_output_type) + out = ParallelLinearFunc.apply(x , self.weight, None, gather_input, gather_output, split_input, reduce_output_type) return out From 2951d70a2f430d5b54b2f6ad5a6a6b73d2c17846 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 24 Aug 2023 16:01:32 +0800 Subject: [PATCH 089/122] use ParallelEmbedding --- example/models/gpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/models/gpt.py b/example/models/gpt.py index d2dff467..64474ba8 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py @@ -15,7 +15,7 @@ def __init__(self, self.max_distance = max_distance if config['tp_size'] > 1: - self.word_emb = bmt.nn.Embedding(vocab_size, dim_model, dtype=dtype) + self.word_emb = bmt.nn.ParallelEmbedding(vocab_size, dim_model, dtype=dtype) else: self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype) self.pos_emb = Embedding(max_distance, dim_model, dtype=dtype) From 2f4ca8ac2941294fd2333ef12c3a4db91f3fc9b7 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Thu, 24 Aug 2023 16:36:56 +0800 Subject: [PATCH 090/122] Offload Correct Version --- bmtrain/hook_func.py | 73 ++++++++++++++++++++---------------------- bmtrain/nn/__init__.py | 2 +- bmtrain/nn/linear.py | 2 +- 3 files changed, 36 insertions(+), 41 deletions(-) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 03f0be36..487a7f25 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -11,7 +11,7 @@ def __init__(self): self._offload_dict = OrderedDict() def add(self, tensor): - tensor_id = id(tensor) + tensor_id = tensor.data_ptr() 
self._offload_dict[tensor_id] = {} self._offload_dict[tensor_id]["numel"] = tensor.numel() self._offload_dict[tensor_id]['dtype'] = tensor.dtype @@ -54,8 +54,8 @@ def d2h_memcpy(self): fp32_offset = 0 fp16_total = sum([v['numel'] for v in self._offload_dict.values() if v['dtype'] == torch.float16]) fp32_total = sum([v['numel'] for v in self._offload_dict.values() if v['dtype'] == torch.float32]) - assert fp16_total <= self.fp16_total - assert fp32_total <= self.fp32_total + assert fp16_total == self.fp16_total + assert fp32_total == self.fp32_total fp16_storage = self.fp16_storage fp32_storage = self.fp32_storage for key,val in self._offload_dict.items(): @@ -64,7 +64,6 @@ def d2h_memcpy(self): offset = fp16_offset if val['dtype'] == torch.float16 else fp32_offset cpu_tensor = torch.tensor([], dtype=val['dtype'], device="cpu") \ .set_(storage, offset, val['shape']) - self._offload_dict[key]['tensor'].record_stream(config['offload_stream']) self._offload_dict[key]['tensor'] = cpu_tensor.copy_(self._offload_dict[key]['tensor'], non_blocking=True) if val['dtype'] == torch.float16: fp16_offset += self._offload_dict[key]['numel'] @@ -115,16 +114,6 @@ def unpack_hook(packed): return tensor return pack_hook, unpack_hook -@contextmanager -def offload_context(module): - if hasattr(module, "_offload_hook"): - pack_hook, unpack_hook = module._offload_hook - torch._C._autograd._push_saved_tensors_default_hooks( - pack_hook, unpack_hook - ) - yield - if hasattr(module, "_offload_hook"): - torch._C._autograd._pop_saved_tensors_default_hooks() def offload_pre_hook(module, input): if hasattr(module, "_offload_hook"): @@ -164,24 +153,22 @@ def find_pre_module_helper(m): pack_hook, unpack_hook ) elif module._mode != "OFFLOAD" and ((len(module._pre_module) > 0) and module._pre_module[0]._mode == "OFFLOAD"): - for pre_module in module._pre_module: - if len(pre_module._pre_module) == 0: - pre_offload_module = None - else: - pre_offload_module = find_pre_module_helper(pre_module._pre_module[0]) - if pre_offload_module is not None: - torch.cuda.current_stream().wait_event(pre_offload_module.offload_event) - if pre_module._mode == "OFFLOAD": - with torch.cuda.stream(config["offload_stream"]): - config["offload_stream"].wait_event(pre_module.calc_event) - if not hasattr(pre_module._offload_dict, "fp16_storage"): - pre_module._offload_dict.make_cpu_storage() - pre_module._offload_dict.d2h_memcpy() - # if len(module._next_module) > 0: + pre_module = module._pre_module[0] + if len(pre_module._pre_module) == 0: + pre_offload_module = None + else: + pre_offload_module = find_pre_module_helper(pre_module._pre_module[0]) + if pre_offload_module is not None: + torch.cuda.current_stream().wait_event(pre_offload_module.offload_event) + if pre_module._mode == "OFFLOAD": + with torch.cuda.stream(config["offload_stream"]): + config["offload_stream"].wait_event(pre_module.calc_event) + if not hasattr(pre_module._offload_dict, "fp16_storage"): + pre_module._offload_dict.make_cpu_storage() + pre_module._offload_dict.record_stream(config["offload_stream"]) + pre_module._offload_dict.d2h_memcpy() + if len(module._next_module) > 0: config["offload_stream"].record_event(pre_module.offload_event) - # torch._C._autograd._push_saved_tensors_default_hooks( - # pack_hook, unpack_hook - # ) if module._mode == "PIPE": enter = module._micro_idx == 0 @@ -212,19 +199,27 @@ def zero_post_forward(module, inputs, outputs): module._ref_count += 1 def zero_pre_backward(module, grad_outputs): + def find_pre_module_helper(m): + if m._mode == 
"OFFLOAD": + return m + else: + if len(m._pre_module) != 0: + return find_pre_module_helper(m._pre_module[0]) + else: + return None backward_flag = 2 if config['zero_level'] == 2 else 0 if module._mode != "PIPE": if module._mode != "OFFLOAD": count = len([m for m in module._pre_module if m._mode=="OFFLOAD"]) if (len(module._next_module) == 0) or module._next_module[0]._mode == "OFFLOAD": - for pre_module in nearest_offload_module(module): - if pre_module._mode == "OFFLOAD": - pre_module._on_device = True - with torch.cuda.stream(config["offload_stream"]): - if (len(module._next_module) != 0): - torch.cuda.current_stream().wait_event(module._next_module[0].calc_event) - pre_module._offload_dict.h2d_memcpy() - torch.cuda.current_stream().record_event(pre_module.offload_event) + pre_module = find_pre_module_helper(module) + if pre_module is not None: + pre_module._on_device = True + with torch.cuda.stream(config["offload_stream"]): + if (len(module._next_module) != 0): + torch.cuda.current_stream().wait_event(module._next_module[0].calc_event) + pre_module._offload_dict.h2d_memcpy() + torch.cuda.current_stream().record_event(pre_module.offload_event) else: current_stream = torch.cuda.current_stream() current_stream.wait_event(module.offload_event) diff --git a/bmtrain/nn/__init__.py b/bmtrain/nn/__init__.py index 67f9fdee..7002ffe4 100644 --- a/bmtrain/nn/__init__.py +++ b/bmtrain/nn/__init__.py @@ -1 +1 @@ -from .linear import Linear +from .linear import Linear,OpLinear \ No newline at end of file diff --git a/bmtrain/nn/linear.py b/bmtrain/nn/linear.py index faf0770e..e2c9cd65 100644 --- a/bmtrain/nn/linear.py +++ b/bmtrain/nn/linear.py @@ -2,7 +2,7 @@ import torch.nn.functional as F import bmtrain as bmt -class CustomLinear(torch.autograd.Function): +class OpLinear(torch.autograd.Function): @staticmethod def forward(ctx, x, weight, bias=None): ctx.save_for_backward(x, weight, bias) From 39319e117e7b17aece40330a056956331cf40856 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 24 Aug 2023 19:55:42 +0800 Subject: [PATCH 091/122] overlap parallel linear backward --- bmtrain/nn/parallel_linear_func.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/bmtrain/nn/parallel_linear_func.py b/bmtrain/nn/parallel_linear_func.py index 5c7a30bc..252162b2 100644 --- a/bmtrain/nn/parallel_linear_func.py +++ b/bmtrain/nn/parallel_linear_func.py @@ -73,13 +73,17 @@ def backward(ctx, grad_output): #TODO: gather on load_stream grad_all_input = grad_output.matmul(weight) grad_input = torch.empty_like(input) + current_stream = torch.cuda.current_stream() + config['tp_comm_stream'].wait_stream(current_stream) if ctx.gather_input: - nccl.reduceScatter(grad_all_input.storage(), grad_input.storage(), "sum", config['tp_comm']) + with torch.cuda.stream(config['tp_comm_stream']): + nccl.reduceScatter(grad_all_input.storage(), grad_input.storage(), "sum", config['tp_comm']) else: grad_input = grad_all_input if ctx.split_input: - grad_input = all_gather(grad_input, config['tp_comm']) + with torch.cuda.stream(config['tp_comm_stream']): + grad_input = all_gather(grad_input, config['tp_comm']) if weight.requires_grad: dim = grad_output.dim() @@ -88,4 +92,7 @@ def backward(ctx, grad_output): if bias is not None and bias.requires_grad: grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0) + + current_stream = torch.cuda.current_stream() + current_stream.wait_stream(config['tp_comm_stream']) return grad_input, grad_weight, grad_bias, None, None, None, None From 
df3fd8f8452c8f3eea982d375aa5847c64bd0551 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 24 Aug 2023 20:28:46 +0800 Subject: [PATCH 092/122] add tp_comm_stream --- bmtrain/init.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/bmtrain/init.py b/bmtrain/init.py index 8f4f7063..9126f8d8 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -75,6 +75,7 @@ def init_distributed( config["world_size"] = world_size config["calc_stream"] = torch.cuda.current_stream() config["load_stream"] = torch.cuda.Stream(priority=-1) + config["tp_comm_stream"] = torch.cuda.Stream(priority=-1) config['barrier_stream'] = torch.cuda.Stream() config["load_event"] = torch.cuda.Event() config["zero_level"] = zero_level @@ -126,12 +127,12 @@ def init_distributed( unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode()) config['tp_comm'] = nccl.commInitRank(unique_id, tp_size, topo.tp_id) - if config['tp_size'] > 1: - if topo.tp_zero_id == 0: - unique_id = nccl.getUniqueId() - store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) - unique_id = bytes.fromhex(store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) - config ['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.tp_zero_id) + #if config['tp_size'] > 1: + if topo.tp_zero_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) + config ['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.tp_zero_id) if topo.zero_id == 0: unique_id = nccl.getUniqueId() From 99efba3b307122f381db31ab243f154ff46124f9 Mon Sep 17 00:00:00 2001 From: Achazwl <323163497@qq.com> Date: Thu, 24 Aug 2023 21:25:08 +0800 Subject: [PATCH 093/122] fix tp --- bmtrain/block_layer.py | 3 ++- bmtrain/init.py | 12 ++++++------ bmtrain/loss/cross_entropy.py | 4 +++- bmtrain/nn/__init__.py | 3 ++- bmtrain/nn/column_parallel_linear.py | 4 ++-- bmtrain/nn/linear.py | 4 ++-- bmtrain/nn/parallel_embedding.py | 5 ++--- bmtrain/nn/parallel_linear_func.py | 5 ++++- bmtrain/nn/row_parallel_linear.py | 4 ++-- setup.py | 2 +- 10 files changed, 26 insertions(+), 20 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index c479a7e5..e0a2c3e6 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -310,7 +310,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, if input_param.__class__.__name__ == "DistributedTensorWrapper": input_param = input_param.broadcast() - verify_shape = it["shape"] if not tp_mode else param._tp_original_shape + verify_shape = torch.Size(it["shape"] if not tp_mode else param._tp_original_shape) if input_param.shape != verify_shape: error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' 'the shape in current model is {}.' 
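One plausible reason both branches are normalized through torch.Size here: the recorded shape can come back as a plain Python list, and torch.Size is a tuple subclass, so a tuple-vs-list comparison is never equal even when the dimensions match. A minimal illustration with stock PyTorch (not BMTrain code):

    import torch

    x = torch.zeros(2, 3)
    print(x.shape == (2, 3))               # True  - torch.Size is a tuple subclass
    print(x.shape == [2, 3])               # False - a tuple never equals a list
    print(x.shape == torch.Size([2, 3]))   # True

    # Wrapping verify_shape in torch.Size keeps `input_param.shape != verify_shape`
    # from reporting a mismatch purely because of the container type.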
@@ -335,6 +335,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, contiguous_param = input_param.to(it["parameter"].dtype).cuda().contiguous() + tp_split_dim = param._tp_split_dim if tp_mode and tp_split_dim >= 0: contiguous_param = tp_split_tensor(contiguous_param, tp_split_dim) diff --git a/bmtrain/init.py b/bmtrain/init.py index 8f4f7063..eb176ad4 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -81,6 +81,7 @@ def init_distributed( config["tp_size"] = tp_size if tp_size > 0 else 1 config["topology"] = topology(config) config["zero_rank"] = config['topology'].get_group_rank("zero") + config["tp_rank"] = config['topology'].get_group_rank("tp") config["tp_zero_rank"] = config['topology'].get_group_rank("tp_zero") cpus_this_worker = None @@ -126,12 +127,11 @@ def init_distributed( unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode()) config['tp_comm'] = nccl.commInitRank(unique_id, tp_size, topo.tp_id) - if config['tp_size'] > 1: - if topo.tp_zero_id == 0: - unique_id = nccl.getUniqueId() - store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) - unique_id = bytes.fromhex(store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) - config ['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.tp_zero_id) + if topo.tp_zero_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) + config['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.tp_zero_id) if topo.zero_id == 0: unique_id = nccl.getUniqueId() diff --git a/bmtrain/loss/cross_entropy.py b/bmtrain/loss/cross_entropy.py index 962cfdeb..c4b36a7f 100644 --- a/bmtrain/loss/cross_entropy.py +++ b/bmtrain/loss/cross_entropy.py @@ -180,6 +180,7 @@ def __init__(self, reduction: str = 'mean', label_smoothing: float = 0.0, # TODO not supported yet inplace: bool = False, + parallel: bool = False, ) -> None: super().__init__() self.weight = weight @@ -187,9 +188,10 @@ def __init__(self, self.reduction = reduction self.label_smoothing = label_smoothing self.inplace = inplace + self.parallel = parallel def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - if config['tp_size'] > 1: + if self.parallel: target = all_gather(target, comm=config['tp_comm']).flatten(0,1) ret = parallel_cross_entropy_func(input, target.long(), self.ignore_index) else: diff --git a/bmtrain/nn/__init__.py b/bmtrain/nn/__init__.py index 05026738..e22d8c55 100644 --- a/bmtrain/nn/__init__.py +++ b/bmtrain/nn/__init__.py @@ -1,5 +1,6 @@ -from .linear import Linear +from .linear import Linear, OpLinear from .column_parallel_linear import ColumnParallelLinear from .row_parallel_linear import RowParallelLinear from .parallel_embedding import ParallelEmbedding from .parallel_cross_entropy_func import parallel_cross_entropy_func +from .parallel_linear_func import OpParallelLinear \ No newline at end of file diff --git a/bmtrain/nn/column_parallel_linear.py b/bmtrain/nn/column_parallel_linear.py index 2baf79c2..e8f554c8 100644 --- a/bmtrain/nn/column_parallel_linear.py +++ b/bmtrain/nn/column_parallel_linear.py @@ -4,7 +4,7 @@ import bmtrain as bmt from bmtrain.global_var import config from .parallel_linear_func import ( - ParallelLinearFunc, + OpParallelLinear, ReduceType) class ColumnParallelLinear(bmt.DistributedModule): @@ -27,7 +27,7 @@ 
def forward(self, input): gather_input = self.gather_input split_input = False reduce_output_type = None - return ParallelLinearFunc.apply(input, self.weight, self.bias, gather_input, self.gather_output, split_input, reduce_output_type) + return OpParallelLinear.apply(input, self.weight, self.bias, gather_input, self.gather_output, split_input, reduce_output_type) def extra_repr(self) -> str: return 'in_features={}, out_features={}, bias={}'.format( diff --git a/bmtrain/nn/linear.py b/bmtrain/nn/linear.py index faf0770e..cb04863a 100644 --- a/bmtrain/nn/linear.py +++ b/bmtrain/nn/linear.py @@ -2,7 +2,7 @@ import torch.nn.functional as F import bmtrain as bmt -class CustomLinear(torch.autograd.Function): +class OpLinear(torch.autograd.Function): @staticmethod def forward(ctx, x, weight, bias=None): ctx.save_for_backward(x, weight, bias) @@ -35,7 +35,7 @@ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtyp self.register_parameter('bias', None) def forward(self, input): - return CustomLinear.apply(input, self.weight, self.bias) + return OpLinear.apply(input, self.weight, self.bias) def extra_repr(self) -> str: return 'in_features={}, out_features={}, bias={}'.format( diff --git a/bmtrain/nn/parallel_embedding.py b/bmtrain/nn/parallel_embedding.py index 7ffb74ef..cd567b4e 100644 --- a/bmtrain/nn/parallel_embedding.py +++ b/bmtrain/nn/parallel_embedding.py @@ -6,7 +6,7 @@ import bmtrain as bmt from bmtrain.global_var import config from bmtrain.distributed import all_reduce, all_gather -from .parallel_linear_func import ParallelLinearFunc +from .parallel_linear_func import OpParallelLinear class ParallelEmbedding(bmt.DistributedModule): def __init__( @@ -35,7 +35,6 @@ def forward(self, ids: torch.Tensor, gather_input=True): """ Args: ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens. - gather_input (bool) : whether gather input is required between tensor parallel group) Return: :obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output. 
""" # noqa: E501 @@ -69,5 +68,5 @@ def projection(self, x: torch.Tensor): split_input = False reduce_output_type = None gather_output = False - out = ParallelLinearFunc.apply(x , self.weight, None, gather_input, gather_output, split_input, reduce_output_type) + out = OpParallelLinear.apply(x , self.weight, None, gather_input, gather_output, split_input, reduce_output_type) return out diff --git a/bmtrain/nn/parallel_linear_func.py b/bmtrain/nn/parallel_linear_func.py index 5c7a30bc..dc2d3816 100644 --- a/bmtrain/nn/parallel_linear_func.py +++ b/bmtrain/nn/parallel_linear_func.py @@ -20,9 +20,12 @@ def preprocess_input(input, gather_input, split_input): input = all_input_list[config['topology'].tp_id] return input -class ParallelLinearFunc(torch.autograd.Function): +class OpParallelLinear(torch.autograd.Function): @staticmethod def forward(ctx, input, weight, bias=None, gather_input=False, gather_output=False, split_input=False, reduce_output_type=None): + if reduce_output_type is not None: + reduce_output_type = ReduceType(reduce_output_type) + ctx.save_for_backward(input, weight, bias) ctx.gather_output = gather_output ctx.split_input = split_input diff --git a/bmtrain/nn/row_parallel_linear.py b/bmtrain/nn/row_parallel_linear.py index acc9378d..71a4297d 100644 --- a/bmtrain/nn/row_parallel_linear.py +++ b/bmtrain/nn/row_parallel_linear.py @@ -4,7 +4,7 @@ import bmtrain as bmt from bmtrain.global_var import config from .parallel_linear_func import ( - ParallelLinearFunc, + OpParallelLinear, ReduceType) class RowParallelLinear(bmt.DistributedModule): @@ -27,7 +27,7 @@ def forward(self, input): gather_input = self.split_input gather_output = False reduce_output_type = ReduceType.ALL_REDUCE if self.all_reduce_output else ReduceType.REDUCE_SCATTER - out = ParallelLinearFunc.apply(input, self.weight, None, gather_input, gather_output, self.split_input, reduce_output_type) + out = OpParallelLinear.apply(input, self.weight, None, gather_input, gather_output, self.split_input, reduce_output_type) out = out + self.bias return out diff --git a/setup.py b/setup.py index 2bbb55d8..ad1c8905 100644 --- a/setup.py +++ b/setup.py @@ -92,7 +92,7 @@ def build_extension(self, ext): ] setup( name='bmtrain', - version='0.2.3.post2', + version='0.2.3.post3', author="Guoyang Zeng", author_email="qbjooo@qq.com", description="A toolkit for training big models", From 9f8a5b437c2cadd22a5e65114693e6e3f9fa2edb Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Fri, 25 Aug 2023 10:42:42 +0800 Subject: [PATCH 094/122] new hook storage --- bmtrain/hook_func.py | 58 +++++++++++++++++++++++++++----------------- bmtrain/nn/linear.py | 2 +- 2 files changed, 37 insertions(+), 23 deletions(-) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 487a7f25..3196e3ff 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -11,23 +11,31 @@ def __init__(self): self._offload_dict = OrderedDict() def add(self, tensor): - tensor_id = tensor.data_ptr() - self._offload_dict[tensor_id] = {} - self._offload_dict[tensor_id]["numel"] = tensor.numel() - self._offload_dict[tensor_id]['dtype'] = tensor.dtype - self._offload_dict[tensor_id]["tensor"] = tensor - self._offload_dict[tensor_id]["shape"] = tensor.shape + tensor_id = id(tensor) + data_ptr = tensor.data_ptr() + if data_ptr not in self._offload_dict: + self._offload_dict[data_ptr] = {} + self._offload_dict[data_ptr]["stor"] = tensor.storage() + self._offload_dict[data_ptr]["size"] = tensor.storage().size() + self._offload_dict[data_ptr]["dtype"] = 
tensor.storage().dtype + self._offload_dict[data_ptr]["tensors"] = {} + self._offload_dict[data_ptr]["tensors"][id(tensor)] = {} + self._offload_dict[data_ptr]["tensors"][id(tensor)]["numel"] = tensor.numel() + self._offload_dict[data_ptr]["tensors"][id(tensor)]['dtype'] = tensor.dtype + self._offload_dict[data_ptr]["tensors"][id(tensor)]['offset'] = tensor.storage_offset() + self._offload_dict[data_ptr]["tensors"][id(tensor)]['tensor'] = tensor + self._offload_dict[data_ptr]["tensors"][id(tensor)]["shape"] = tensor.shape self._device = "cuda" - return tensor_id + return (data_ptr,tensor_id) def get_total(self): - fp16_total = sum([v['numel'] for v in self._offload_dict.values() if v['dtype'] == torch.float16]) - fp32_total = sum([v['numel'] for v in self._offload_dict.values() if v['dtype'] == torch.float32]) + fp16_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float16]) + fp32_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float32]) return fp16_total,fp32_total def make_cpu_storage(self): - fp16_total = sum([v['numel'] for v in self._offload_dict.values() if v['dtype'] == torch.float16]) - fp32_total = sum([v['numel'] for v in self._offload_dict.values() if v['dtype'] == torch.float32]) + fp16_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float16]) + fp32_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float32]) fp16_storage = torch.HalfStorage(fp16_total).pin_memory() fp32_storage = torch.FloatStorage(fp32_total).pin_memory() self.fp16_storage = fp16_storage @@ -36,24 +44,29 @@ def make_cpu_storage(self): self.fp32_total = fp32_total def get(self, key): - return self._offload_dict[key]["tensor"] + data_ptr, tensor_id = key + return self._offload_dict[data_ptr]['tensors'][tensor_id]["tensor"] def pop_all(self): self._offload_dict.clear() def h2d_memcpy(self): for key,val in self._offload_dict.items(): - self._offload_dict[key]['tensor'] = self._offload_dict[key]['tensor'].cuda(non_blocking=True) + val['stor'] = val['stor'].cuda(non_blocking=True) + for id_val in val['tensors'].values(): + id_val['tensor'] = torch.tensor([], dtype=id_val['dtype'],device=val['stor'].device) + id_val['tensor'].set_(val['stor'], id_val['offset'], id_val['shape']) def record_stream(self, stream): for key, val in self._offload_dict.items(): - self._offload_dict[key]['tensor'].record_stream(stream) + for id_val in val['tensors'].values(): + id_val['tensor'].record_stream(stream) def d2h_memcpy(self): fp16_offset = 0 fp32_offset = 0 - fp16_total = sum([v['numel'] for v in self._offload_dict.values() if v['dtype'] == torch.float16]) - fp32_total = sum([v['numel'] for v in self._offload_dict.values() if v['dtype'] == torch.float32]) + fp16_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float16]) + fp32_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float32]) assert fp16_total == self.fp16_total assert fp32_total == self.fp32_total fp16_storage = self.fp16_storage @@ -62,16 +75,17 @@ def d2h_memcpy(self): assert val['dtype'] in [torch.float16, torch.float32] storage = fp16_storage if val['dtype'] == torch.float16 else fp32_storage offset = fp16_offset if val['dtype'] == torch.float16 else fp32_offset - cpu_tensor = torch.tensor([], dtype=val['dtype'], device="cpu") \ - .set_(storage, offset, val['shape']) - self._offload_dict[key]['tensor'] = cpu_tensor.copy_(self._offload_dict[key]['tensor'], 
non_blocking=True) + for id_val in val['tensors'].values(): + cpu_tensor = torch.tensor([], dtype=id_val['dtype'], device="cpu") \ + .set_(storage, offset+id_val['offset'], id_val['shape']) + id_val['tensor'] = cpu_tensor.copy_(id_val['tensor'], non_blocking=True) if val['dtype'] == torch.float16: - fp16_offset += self._offload_dict[key]['numel'] + fp16_offset += val['size'] else: - fp32_offset += self._offload_dict[key]['numel'] + fp32_offset += val['size'] def nearest_offload_module(module): - queue = deque([(module, 0)]) # 使用队列来进行广度优先搜索 + queue = deque([(module, 0)]) nearest_modules = [] nearest_depth = float('inf') diff --git a/bmtrain/nn/linear.py b/bmtrain/nn/linear.py index e2c9cd65..cb04863a 100644 --- a/bmtrain/nn/linear.py +++ b/bmtrain/nn/linear.py @@ -35,7 +35,7 @@ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtyp self.register_parameter('bias', None) def forward(self, input): - return CustomLinear.apply(input, self.weight, self.bias) + return OpLinear.apply(input, self.weight, self.bias) def extra_repr(self) -> str: return 'in_features={}, out_features={}, bias={}'.format( From 725fe573ba23581cb291db0253c215d0e3be17c8 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Fri, 25 Aug 2023 10:51:40 +0800 Subject: [PATCH 095/122] Offload storage function fix --- bmtrain/hook_func.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 3196e3ff..196c8c24 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -11,8 +11,9 @@ def __init__(self): self._offload_dict = OrderedDict() def add(self, tensor): + tensor = tensor.contiguous() tensor_id = id(tensor) - data_ptr = tensor.data_ptr() + data_ptr = id(tensor.storage()) if data_ptr not in self._offload_dict: self._offload_dict[data_ptr] = {} self._offload_dict[data_ptr]["stor"] = tensor.storage() From ec63e1b3a608b67116c349f306fe3ce4c0c96d89 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Fri, 25 Aug 2023 11:19:44 +0800 Subject: [PATCH 096/122] storage dont release fix --- bmtrain/hook_func.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 196c8c24..09082805 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -52,11 +52,16 @@ def pop_all(self): self._offload_dict.clear() def h2d_memcpy(self): + fp16_storage_cuda = self.fp16_storage.cuda(non_blocking=True) + fp32_storage_cuda = self.fp32_storage.cuda(non_blocking=True) for key,val in self._offload_dict.items(): - val['stor'] = val['stor'].cuda(non_blocking=True) for id_val in val['tensors'].values(): - id_val['tensor'] = torch.tensor([], dtype=id_val['dtype'],device=val['stor'].device) - id_val['tensor'].set_(val['stor'], id_val['offset'], id_val['shape']) + id_val['tensor'] = torch.tensor([], dtype=id_val['dtype'],device=fp16_storage_cuda.device) + if id_val['dtype'] == torch.float16: + id_val['tensor'].set_(fp16_storage_cuda, id_val['abs_offset'], id_val['shape']) + elif id_val['dtype'] == torch.float32: + id_val['tensor'].set_(fp32_storage_cuda, id_val['abs_offset'], id_val['shape']) + def record_stream(self, stream): for key, val in self._offload_dict.items(): @@ -68,8 +73,8 @@ def d2h_memcpy(self): fp32_offset = 0 fp16_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float16]) fp32_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float32]) - assert fp16_total == self.fp16_total - 
assert fp32_total == self.fp32_total + assert fp16_total <= self.fp16_total + assert fp32_total <= self.fp32_total fp16_storage = self.fp16_storage fp32_storage = self.fp32_storage for key,val in self._offload_dict.items(): @@ -79,11 +84,13 @@ def d2h_memcpy(self): for id_val in val['tensors'].values(): cpu_tensor = torch.tensor([], dtype=id_val['dtype'], device="cpu") \ .set_(storage, offset+id_val['offset'], id_val['shape']) + id_val["abs_offset"] = offset+id_val['offset'] id_val['tensor'] = cpu_tensor.copy_(id_val['tensor'], non_blocking=True) if val['dtype'] == torch.float16: fp16_offset += val['size'] else: fp32_offset += val['size'] + val['stor'] = None def nearest_offload_module(module): queue = deque([(module, 0)]) From f1b4fd7127124eb98d9ec32f16424f81312e32b5 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Fri, 25 Aug 2023 12:50:32 +0800 Subject: [PATCH 097/122] fix load_state_dict --- bmtrain/layer.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/bmtrain/layer.py b/bmtrain/layer.py index ebbef815..7de83e5e 100644 --- a/bmtrain/layer.py +++ b/bmtrain/layer.py @@ -1,6 +1,8 @@ import torch from .parameter import DistributedParameter +from .global_var import config import itertools +from .utils import tp_split_tensor class DistributedModule(torch.nn.Module): """ @@ -11,7 +13,7 @@ class DistributedModule(torch.nn.Module): def __getattr__(self, name: str): ret = super().__getattr__(name) # gather distributed parameters if not in CheckpointBlock - if isinstance(ret, DistributedParameter) and not ret._in_checkpoint_block: + if isinstance(ret, DistributedParameter) and not ret._in_checkpoint_block: return ret.gather() return ret @@ -30,8 +32,11 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): """ for name, param in self._parameters.items(): if param is not None: - if isinstance(param, DistributedParameter) and not param._in_checkpoint_block: - destination[prefix + name] = param.gather().detach().cpu() # sync operation + if isinstance(param, DistributedParameter):#and not param._in_checkpoint_block: + if param._in_checkpoint_block: + destination[prefix + name] = param.tp_gather().detach().cpu() # sync operation + else: + destination[prefix + name] = param.gather_all().detach().cpu() # sync operation else: destination[prefix + name] = param if keep_vars else param.detach().cpu() for name, buf in self._buffers.items(): @@ -81,6 +86,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, for name, param in local_state.items(): key = prefix + name if key in state_dict: + tp_mode = param._tp_mode input_param = state_dict[key] if input_param.__class__.__name__ == "DistributedTensorWrapper": input_param = input_param.broadcast() @@ -98,13 +104,17 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 'the shape in current model is {}.' .format(key, input_param.shape, param.shape)) continue - if not is_param_lazy and isinstance(param, DistributedParameter) and input_param.shape != param._original_shape: + verify_shape = torch.Size(param._original_shape if not tp_mode else param._tp_original_shape) + if not is_param_lazy and isinstance(param, DistributedParameter) and input_param.shape != verify_shape: error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' 'the shape in current model is {}.' 
- .format(key, input_param.shape, param.shape)) + .format(key, input_param.shape, verify_shape)) try: with torch.no_grad(): if isinstance(param, DistributedParameter): + tp_split_dim = param._tp_split_dim + if tp_mode and tp_split_dim >= 0: + input_param = tp_split_tensor(input_param, tp_split_dim) param._copy_data(input_param) else: param.copy_(input_param) From 677a316228fd623d5b7e4e71a01676f2927ffc0e Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Fri, 25 Aug 2023 15:23:08 +0800 Subject: [PATCH 098/122] test parallel linear --- tests/test_all.py | 2 + tests/test_column_parallel_linear.py | 55 ++++++++++++++++++++++++++++ tests/test_row_parallel_linear.py | 54 +++++++++++++++++++++++++++ 3 files changed, 111 insertions(+) create mode 100644 tests/test_column_parallel_linear.py create mode 100644 tests/test_row_parallel_linear.py diff --git a/tests/test_all.py b/tests/test_all.py index 6682aa93..aa382676 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -25,6 +25,8 @@ ("send_recv", 4), ("nccl_backward", 4), ("no_grad", 1), + ("column_parallel_linear", 2), + ("row_parallel_linear", 2), ("training", 4), ]) diff --git a/tests/test_column_parallel_linear.py b/tests/test_column_parallel_linear.py new file mode 100644 index 00000000..1c49570a --- /dev/null +++ b/tests/test_column_parallel_linear.py @@ -0,0 +1,55 @@ +import torch +import bmtrain as bmt +from bmtrain.global_var import config +import numpy as np + +def run_bmt(x, gather_output, ckp_path, tp_size=2): + linear = bmt.nn.ColumnParallelLinear(8,8, gather_output=gather_output) + linear = bmt.CheckpointBlock(linear) + bmt.init_parameters(linear) + y = linear(x[config['topology'].tp_id]) + y.sum().backward() + bmt.save(linear, ckp_path) + bmt.synchronize() + return y, linear._parameters['weight'].grad, linear._parameters['bias'].grad + +def run_torch(x, ckp_path): + linear = torch.nn.Linear(8, 8) + linear_dict = torch.load(ckp_path) + linear.load_state_dict(linear_dict) + linear = linear.cuda() + linear.weight.requires_grad_() + y = linear(x) + y.sum().backward() + return y, linear.weight.grad, linear.bias.grad + +def run(gather_output, ckp_path): + tp_size = bmt.config['tp_size'] + torch.cuda.manual_seed(100) + x = torch.randn(tp_size, 8,8, device='cuda').requires_grad_() + y1, weight_grad1, bias_grad1 = run_bmt(x, gather_output, ckp_path) + y2, weight_grad2, bias_grad2 = run_torch(x, ckp_path) + tp_rank = config['topology'].tp_id + if gather_output: + assert np.allclose(y1.detach().cpu().numpy(), y2.flatten(0,1).detach().cpu().numpy()) + else: + torch_out_list = torch.split(y2, y2.size()[-1] // tp_size, dim=y2.dim()-1) + assert np.allclose(y1.detach().cpu().numpy(), torch_out_list[tp_rank].flatten(0,1).detach().cpu().numpy()) + + weight_grad_list = weight_grad2.chunk(tp_size, dim=0) + assert np.allclose(weight_grad1.reshape(weight_grad_list[tp_rank].shape).cpu().numpy(), weight_grad_list[tp_rank].cpu().numpy()) + + bias_grad_list = bias_grad2.chunk(tp_size, dim=0) + assert np.allclose(bias_grad1.reshape(bias_grad_list[tp_rank].shape).cpu().numpy(), bias_grad_list[tp_rank].cpu().numpy()) + +def test_gather_output(): + run(True, 'linear.ckp') + +def test_no_gather_output(): + run(False, 'linear_no_gather.ckp') + +if __name__ == "__main__": + bmt.init_distributed(tp_size=2) + test_gather_output() + test_no_gather_output() + diff --git a/tests/test_row_parallel_linear.py b/tests/test_row_parallel_linear.py new file mode 100644 index 00000000..f89b6dc5 --- /dev/null +++ b/tests/test_row_parallel_linear.py @@ -0,0 +1,54 @@ +import 
torch +import bmtrain as bmt +from bmtrain.global_var import config +import numpy as np + +def run_bmt(x, ckp_path, split_input=True, use_checkpoint_block=True): + linear = bmt.nn.RowParallelLinear(8,8, split_input=split_input, all_reduce_output=True) + if use_checkpoint_block: + linear = bmt.CheckpointBlock(linear) + bmt.init_parameters(linear) + y = linear(x) + y.sum().backward() + bmt.save(linear, ckp_path) + bmt.synchronize() + return y, linear._parameters['weight'].grad, linear._parameters['bias'].grad + +def run_torch(x, ckp_path): + linear = torch.nn.Linear(8, 8) + linear_dict = torch.load(ckp_path) + linear.load_state_dict(linear_dict) + linear = linear.cuda() + linear.weight.requires_grad_() + y = linear(x) + y.sum().backward() + return y, linear.weight.grad, linear.bias.grad + +def run(split_input, use_checkpoint_block, ckp_path): + tp_size = bmt.config['tp_size'] + torch.cuda.manual_seed(100) + tp_rank = config['topology'].tp_id + x = torch.randn(8,8, device='cuda').requires_grad_() + rank_x = x.chunk(tp_size, dim=0 if split_input else 1)[tp_rank] + y1, weight_grad1, bias_grad1 = run_bmt(rank_x, ckp_path, split_input, use_checkpoint_block) + y2, weight_grad2, bias_grad2 = run_torch(x, ckp_path) + np.testing.assert_allclose(y1.detach().cpu().numpy(), y2.detach().cpu().numpy(), atol=1e-5) + + weight_grad_list = weight_grad2.chunk(tp_size, dim=1) + assert np.allclose(weight_grad1.reshape(weight_grad_list[tp_rank].shape).cpu().numpy(), weight_grad_list[tp_rank].cpu().numpy()) + + assert np.allclose(bias_grad1.cpu().numpy(), bias_grad2.cpu().numpy()) + +def test_split_input(): + run(True, False, 'row_parallel_linear.ckp') + run(True, True, 'row_parallel_linear.ckp') + +def test_no_split_input(): + run(False, False, 'row_parallel_linear_no_split.ckp') + run(False, True, 'row_parallel_linear_no_split.ckp') + +if __name__ == "__main__": + bmt.init_distributed(tp_size=2) + test_no_split_input() + test_split_input() + From 743253e8f31745000c6f3c73c0cb599cc6837166 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Fri, 25 Aug 2023 15:39:10 +0800 Subject: [PATCH 099/122] mv zero_level to CheckpointBlock --- bmtrain/block_layer.py | 3 ++- bmtrain/hook_func.py | 11 +++++++---- bmtrain/init.py | 3 --- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 18438b8c..4e9a6c01 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -66,7 +66,7 @@ class CheckpointBlock(torch.nn.Module): >>> y2, ... 
= transformer_block(x) >>> assert torch.allclose(y1, y2) """ - def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True): + def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_level=3): super().__init__() self._module = inner_module self._inputs = None @@ -200,6 +200,7 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True): self._mode = "BLOCK" #BLOCK or ZERO or PIPE self.all_input_no_grad = False self.all_param_no_grad = False + self._zero_level = zero_level def set_pre_module(self, pre_module): if pre_module is not None: diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 6a56300e..08d37b5d 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -9,7 +9,7 @@ def zero_pre_forward(module, inputs): enter = module._micro_idx == 0 pipe = True if enter: - zero_level = config['zero_level'] + zero_level = module._zero_level #config['zero_level'] forward_flag = 1 if zero_level == 2 else 0 if zero_level == 2 and module._ref_count > 1: forward_flag = 2 # repeating forward in same layer @@ -19,7 +19,8 @@ def zero_pre_forward(module, inputs): module._forward_block_ctx.enter(forward_flag) def zero_post_forward(module, inputs, outputs): - forward_flag = 1 if config['zero_level'] == 2 else 0 + #forward_flag = 1 if config['zero_level'] == 2 else 0 + forward_flag = 1 if module._zero_level == 2 else 0 if module.all_param_no_grad: forward_flag = 0 exit = True @@ -31,7 +32,8 @@ def zero_post_forward(module, inputs, outputs): module._ref_count += 1 def zero_pre_backward(module, grad_outputs): - backward_flag = 2 if config['zero_level'] == 2 else 0 + #backward_flag = 2 if config['zero_level'] == 2 else 0 + backward_flag = 2 if module._zero_level == 2 else 0 if module._mode != "PIPE": module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict) module._backward_block_ctx.enter(backward_flag, True) @@ -43,7 +45,8 @@ def zero_pre_backward(module, grad_outputs): module._backward_block_ctx.enter(backward_flag, True) def zero_post_backward(module, grad_inputs, grad_outputs): - backward_flag = 2 if config['zero_level'] == 2 else 0 + #backward_flag = 2 if config['zero_level'] == 2 else 0 + backward_flag = 2 if module._zero_level == 2 else 0 if module._mode != "PIPE": if module._is_first_layer: module.backward_release(backward_flag) diff --git a/bmtrain/init.py b/bmtrain/init.py index 1fa0712d..f3c1faa7 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -14,7 +14,6 @@ def init_distributed( init_method : str = "env://", seed : int = 0, - zero_level: int = 3, pipe_size: int = -1, num_micro_batches: int = None, ): @@ -24,7 +23,6 @@ def init_distributed( Args: seed (int): The random seed. - zero_level (int): The ZeRO optimization level. 2 for stage-2, 3 for stage-3. 
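With the argument removed here, the ZeRO level is now chosen per block through the zero_level parameter that this patch adds to CheckpointBlock, rather than once at initialization. A minimal sketch of the resulting call pattern (illustrative only, mirroring the tests added earlier in this series):

    import bmtrain as bmt

    bmt.init_distributed(seed=0)                     # no zero_level argument any more
    layer = bmt.CheckpointBlock(bmt.nn.Linear(8, 8), zero_level=2)  # per-block ZeRO level
    bmt.init_parameters(layer)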
**init_distributed** reads the following environment variables: @@ -74,7 +72,6 @@ def init_distributed( config["load_stream"] = torch.cuda.Stream(priority=-1) config['barrier_stream'] = torch.cuda.Stream() config["load_event"] = torch.cuda.Event() - config["zero_level"] = zero_level config["topology"] = topology(config) config["zero_rank"] = config["topology"].get_group_rank("zero") if pipe_size > 1 else config['rank'] cpus_this_worker = None From 849382864b487231bb904434432992070f7bf1c2 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Fri, 25 Aug 2023 15:43:24 +0800 Subject: [PATCH 100/122] use dataptr as storage id --- bmtrain/hook_func.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 09082805..7a0465f6 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -4,6 +4,7 @@ from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations from collections import deque,OrderedDict from contextlib import contextmanager +from .utils import round_up class Offload_Dict: @@ -13,7 +14,7 @@ def __init__(self): def add(self, tensor): tensor = tensor.contiguous() tensor_id = id(tensor) - data_ptr = id(tensor.storage()) + data_ptr = tensor.storage().data_ptr() if data_ptr not in self._offload_dict: self._offload_dict[data_ptr] = {} self._offload_dict[data_ptr]["stor"] = tensor.storage() From 8919f18bb44f27edbce05e494e2a42a4e04b1e85 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Fri, 25 Aug 2023 16:29:47 +0800 Subject: [PATCH 101/122] fix prev confilct --- bmtrain/block_layer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 51e76dd5..03aa30c9 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -66,11 +66,7 @@ class CheckpointBlock(torch.nn.Module): >>> y2, ... 
= transformer_block(x) >>> assert torch.allclose(y1, y2) """ -<<<<<<< HEAD def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, use_offload=False, offload_level=0, zero_level=3): -======= - def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_level=3): ->>>>>>> 7d62a181c248c332f1ef5223a02c63c8de6e27ca super().__init__() self._module = inner_module self._inputs = None From 604ddfee44578148794561c7dd50a122c7110b9a Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Fri, 25 Aug 2023 16:53:19 +0800 Subject: [PATCH 102/122] fix overlap --- bmtrain/nn/parallel_linear_func.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/bmtrain/nn/parallel_linear_func.py b/bmtrain/nn/parallel_linear_func.py index 9e76db9e..df4fb0b5 100644 --- a/bmtrain/nn/parallel_linear_func.py +++ b/bmtrain/nn/parallel_linear_func.py @@ -31,8 +31,10 @@ def forward(ctx, input, weight, bias=None, gather_input=False, gather_output=Fal ctx.split_input = split_input ctx.gather_input = gather_input ctx.reduce_output_type = reduce_output_type + all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) out = F.linear(all_input, weight, bias) + if gather_output: all_output_list = all_gather(out, config['tp_comm']) all_output_list = all_output_list.chunk(config['tp_size'], dim=0) @@ -44,12 +46,15 @@ def forward(ctx, input, weight, bias=None, gather_input=False, gather_output=Fal if reduce_output_type == ReduceType.ALL_REDUCE: nccl.allReduce(out.storage(), out.storage(), "sum", config['tp_comm']) return out + elif reduce_output_type == ReduceType.REDUCE_SCATTER: shape = list(out.shape) shape[0] = shape[0] // config['tp_size'] reduce_out = torch.empty(shape, dtype=out.dtype, device=out.device) nccl.reduceScatter(out.storage(), reduce_out.storage(), "sum", config['tp_comm']) return reduce_out + else: + assert False, "no support reduce type{}".format(reduce_output_type) @staticmethod def backward(ctx, grad_output): @@ -72,20 +77,23 @@ def backward(ctx, grad_output): all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) if input.requires_grad: - #gather can async with grad_out.matmul(weight) - #TODO: gather on load_stream grad_all_input = grad_output.matmul(weight) grad_input = torch.empty_like(input) - current_stream = torch.cuda.current_stream() - config['tp_comm_stream'].wait_stream(current_stream) if ctx.gather_input: with torch.cuda.stream(config['tp_comm_stream']): + current_stream = torch.cuda.current_stream() + config['tp_comm_stream'].wait_stream(current_stream) + grad_input.record_stream(config['tp_comm_stream']) + grad_all_input.record_stream(config['tp_comm_stream']) nccl.reduceScatter(grad_all_input.storage(), grad_input.storage(), "sum", config['tp_comm']) else: grad_input = grad_all_input if ctx.split_input: with torch.cuda.stream(config['tp_comm_stream']): + current_stream = torch.cuda.current_stream() + config['tp_comm_stream'].wait_stream(current_stream) + grad_input.record_stream(config['tp_comm_stream']) grad_input = all_gather(grad_input, config['tp_comm']) if weight.requires_grad: From 0aee817ef193e8f00a44b46017f60e1bcdc8fd9f Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Fri, 25 Aug 2023 21:25:49 +0800 Subject: [PATCH 103/122] gather once in atten --- bmtrain/nn/parallel_linear_func.py | 3 +-- example/layers/attention.py | 16 ++++++++++------ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/bmtrain/nn/parallel_linear_func.py b/bmtrain/nn/parallel_linear_func.py index df4fb0b5..f7c4573a 100644 --- 
a/bmtrain/nn/parallel_linear_func.py +++ b/bmtrain/nn/parallel_linear_func.py @@ -77,11 +77,11 @@ def backward(ctx, grad_output): all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) if input.requires_grad: + current_stream = torch.cuda.current_stream() grad_all_input = grad_output.matmul(weight) grad_input = torch.empty_like(input) if ctx.gather_input: with torch.cuda.stream(config['tp_comm_stream']): - current_stream = torch.cuda.current_stream() config['tp_comm_stream'].wait_stream(current_stream) grad_input.record_stream(config['tp_comm_stream']) grad_all_input.record_stream(config['tp_comm_stream']) @@ -91,7 +91,6 @@ def backward(ctx, grad_output): if ctx.split_input: with torch.cuda.stream(config['tp_comm_stream']): - current_stream = torch.cuda.current_stream() config['tp_comm_stream'].wait_stream(current_stream) grad_input.record_stream(config['tp_comm_stream']) grad_input = all_gather(grad_input, config['tp_comm']) diff --git a/example/layers/attention.py b/example/layers/attention.py index a49edabb..61eeb9f2 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -19,9 +19,9 @@ def __init__(self, super().__init__() if config['tp_size'] > 1: - self.project_q = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_k = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_v = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_q = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False) + self.project_k = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False) + self.project_v = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False) self.project_out = RowParallelLinear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) else: self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) @@ -44,9 +44,12 @@ def forward(self, batch_size, seq_q, dim_model = hidden_q.size() seq_kv = hidden_kv.size(1) + if config['tp_size'] > 1: + hidden_q = all_gather(hidden_q, comm=config['tp_comm']).flatten(0,1) + h_q : torch.Tensor = self.project_q(hidden_q) - h_k : torch.Tensor = self.project_k(hidden_kv) - h_v : torch.Tensor = self.project_v(hidden_kv) + h_k : torch.Tensor = self.project_k(hidden_q) + h_v : torch.Tensor = self.project_v(hidden_q) if config['tp_size'] > 1: #batch_size will changed in TensorParallel batch_size = h_v.shape[0] @@ -74,7 +77,8 @@ def forward(self, score = score + position_bias.view(batch_size, -1, seq_q, seq_kv) if config['tp_size'] > 1: - mask = all_gather(mask, config['tp_comm']) + with torch.no_grad(): + mask = all_gather(mask, config['tp_comm']).flatten(0,1) score = torch.where( mask.view(batch_size, 1, seq_q, seq_kv), From bd0bad0e647c7387836a1f515ae18f6f0d7e6b42 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Fri, 25 Aug 2023 22:25:17 +0800 Subject: [PATCH 104/122] fix sub grad_input in parallel linear --- bmtrain/nn/parallel_linear_func.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bmtrain/nn/parallel_linear_func.py b/bmtrain/nn/parallel_linear_func.py index f7c4573a..1f0a362a 100644 --- a/bmtrain/nn/parallel_linear_func.py +++ b/bmtrain/nn/parallel_linear_func.py @@ -79,7 +79,7 @@ def backward(ctx, grad_output): if input.requires_grad: current_stream = torch.cuda.current_stream() grad_all_input = grad_output.matmul(weight) - grad_input = 
torch.empty_like(input) + grad_input = torch.zeros_like(input) if ctx.gather_input: with torch.cuda.stream(config['tp_comm_stream']): config['tp_comm_stream'].wait_stream(current_stream) From 15460b6da3722b3d82aa4ef136e4a650331da0e5 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Sat, 26 Aug 2023 11:16:23 +0800 Subject: [PATCH 105/122] fix gather_output --- bmtrain/nn/parallel_linear_func.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bmtrain/nn/parallel_linear_func.py b/bmtrain/nn/parallel_linear_func.py index 1f0a362a..55741530 100644 --- a/bmtrain/nn/parallel_linear_func.py +++ b/bmtrain/nn/parallel_linear_func.py @@ -16,7 +16,7 @@ def preprocess_input(input, gather_input, split_input): input = input.flatten(0, 1) if split_input: - all_input_list = input.chunk(config['tp_size'], dim=1) + all_input_list = input.chunk(config['tp_size'], dim=-1) input = all_input_list[config['topology'].tp_id] return input @@ -68,7 +68,7 @@ def backward(ctx, grad_output): if gather_output: tp_size = config['tp_size'] tp_id = config['topology'].tp_id - grad_output_list = grad_output.chunk(tp_size, dim=1) + grad_output_list = grad_output.chunk(tp_size, dim=-1) grad_output = grad_output_list[tp_id] grad_input = grad_weight = grad_bias = None From 66a04f3f7957f746e38bc2e84e508e74e6900c0d Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Sat, 26 Aug 2023 13:13:46 +0800 Subject: [PATCH 106/122] better overlap --- bmtrain/hook_func.py | 81 ++++++++++++++++++-------------------------- 1 file changed, 33 insertions(+), 48 deletions(-) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 4151ef94..5f7ae2d4 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -92,6 +92,13 @@ def d2h_memcpy(self): else: fp32_offset += val['size'] val['stor'] = None +def find_pre_module_helper(m): + if len(m) == 0: + return None + if m[0]._mode == "OFFLOAD": + return m[0] + else: + return find_pre_module_helper(m[0]._pre_module) def nearest_offload_module(module): queue = deque([(module, 0)]) @@ -150,13 +157,6 @@ def offload_post_hook(module, input, output): torch._C._autograd._pop_saved_tensors_default_hooks() def zero_pre_forward(module, inputs): - def find_pre_module_helper(m): - if m._mode == "OFFLOAD": - return m - elif m._is_first_layer: - return None - else: - return find_pre_module_helper(m._pre_module[0]) enter = True pipe = False if module._mode == "OFFLOAD": @@ -175,23 +175,6 @@ def find_pre_module_helper(m): torch._C._autograd._push_saved_tensors_default_hooks( pack_hook, unpack_hook ) - elif module._mode != "OFFLOAD" and ((len(module._pre_module) > 0) and module._pre_module[0]._mode == "OFFLOAD"): - pre_module = module._pre_module[0] - if len(pre_module._pre_module) == 0: - pre_offload_module = None - else: - pre_offload_module = find_pre_module_helper(pre_module._pre_module[0]) - if pre_offload_module is not None: - torch.cuda.current_stream().wait_event(pre_offload_module.offload_event) - if pre_module._mode == "OFFLOAD": - with torch.cuda.stream(config["offload_stream"]): - config["offload_stream"].wait_event(pre_module.calc_event) - if not hasattr(pre_module._offload_dict, "fp16_storage"): - pre_module._offload_dict.make_cpu_storage() - pre_module._offload_dict.record_stream(config["offload_stream"]) - pre_module._offload_dict.d2h_memcpy() - if len(module._next_module) > 0: - config["offload_stream"].record_event(pre_module.offload_event) if module._mode == "PIPE": enter = module._micro_idx == 0 @@ -214,39 +197,41 @@ def zero_post_forward(module, inputs, 
outputs): if module._mode == "PIPE": exit = module._micro_idx == config['micros'] - 1 elif module._mode == "OFFLOAD": + torch.cuda.current_stream().record_event(module.calc_event) + pre_offload_module = find_pre_module_helper(module._pre_module) + if pre_offload_module is not None: + torch.cuda.current_stream().wait_event(pre_offload_module.offload_event) + with torch.cuda.stream(config["offload_stream"]): + config["offload_stream"].wait_event(module.calc_event) + if not hasattr(module._offload_dict, "fp16_storage"): + module._offload_dict.make_cpu_storage() + module._offload_dict.record_stream(config["offload_stream"]) + module._offload_dict.d2h_memcpy() + if len(module._next_module) > 0: + config["offload_stream"].record_event(module.offload_event) if module.offload_level == 2: torch._C._autograd._pop_saved_tensors_default_hooks() - torch.cuda.current_stream().record_event(module.calc_event) if exit: module._forward_block_ctx.exit(forward_flag) module._ref_count += 1 def zero_pre_backward(module, grad_outputs): - def find_pre_module_helper(m): - if m._mode == "OFFLOAD": - return m - else: - if len(m._pre_module) != 0: - return find_pre_module_helper(m._pre_module[0]) - else: - return None backward_flag = 2 if module._zero_level == 2 else 0 if module._mode != "PIPE": - if module._mode != "OFFLOAD": - count = len([m for m in module._pre_module if m._mode=="OFFLOAD"]) - if (len(module._next_module) == 0) or module._next_module[0]._mode == "OFFLOAD": - pre_module = find_pre_module_helper(module) - if pre_module is not None: - pre_module._on_device = True - with torch.cuda.stream(config["offload_stream"]): - if (len(module._next_module) != 0): - torch.cuda.current_stream().wait_event(module._next_module[0].calc_event) - pre_module._offload_dict.h2d_memcpy() - torch.cuda.current_stream().record_event(pre_module.offload_event) - else: - current_stream = torch.cuda.current_stream() - current_stream.wait_event(module.offload_event) - module._offload_dict.record_stream(current_stream) + if module._mode == "OFFLOAD" or (len(module._next_module) == 0): + if len(module._next_module) != 0: + current_stream = torch.cuda.current_stream() + current_stream.wait_event(module.offload_event) + pre_module = find_pre_module_helper(module._pre_module) + if pre_module is not None: + pre_module._on_device = True + with torch.cuda.stream(config["offload_stream"]): + if (len(module._next_module) != 0): + torch.cuda.current_stream().wait_event(module.calc_event) + pre_module._offload_dict.h2d_memcpy() + torch.cuda.current_stream().record_event(pre_module.offload_event) + if (len(module._next_module) != 0): + module._offload_dict.record_stream(current_stream) module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict) module._backward_block_ctx.enter(backward_flag, True) if not module._is_last_layer: From b44a62e9c5570c34861673a4b7c57a6250aeeb02 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Sat, 26 Aug 2023 15:13:27 +0800 Subject: [PATCH 107/122] fix train.py --- example/train.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/example/train.py b/example/train.py index 44a7d5d2..8aaf65e4 100644 --- a/example/train.py +++ b/example/train.py @@ -9,7 +9,6 @@ def main(): bmt.init_distributed( seed=0, - zero_level=2, tp_size=2, ) @@ -54,7 +53,7 @@ def main(): break if config['tp_size'] > 1: - loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100) + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) else: loss_func = 
torch.nn.CrossEntropyLoss(ignore_index=-100) From b208e9f6120cac7f9ac9de1e80edaa1a974a609b Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Sat, 26 Aug 2023 16:08:14 +0800 Subject: [PATCH 108/122] rm unused code --- bmtrain/hook_func.py | 25 +------------------------ 1 file changed, 1 insertion(+), 24 deletions(-) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 5f7ae2d4..758a6a87 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -92,6 +92,7 @@ def d2h_memcpy(self): else: fp32_offset += val['size'] val['stor'] = None + def find_pre_module_helper(m): if len(m) == 0: return None @@ -100,30 +101,6 @@ def find_pre_module_helper(m): else: return find_pre_module_helper(m[0]._pre_module) -def nearest_offload_module(module): - queue = deque([(module, 0)]) - nearest_modules = [] - nearest_depth = float('inf') - - while queue: - curr_module, curr_depth = queue.popleft() - - if curr_depth > nearest_depth: - break - - for m in curr_module._pre_module: - if m._mode == "OFFLOAD" and not m._on_device: - if curr_depth < nearest_depth: - nearest_modules = [m] - nearest_depth = curr_depth - elif curr_depth == nearest_depth: - nearest_modules.append(m) - else: - queue.append((m, curr_depth + 1)) - - return nearest_modules - - def offload_wrapper(offload_dict): def pack_hook(tensor): if isinstance(tensor, torch.nn.Parameter): From de32538e5858c5b525664458190f910c191f2ab3 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 29 Aug 2023 10:58:29 +0800 Subject: [PATCH 109/122] fix tp feature --- bmtrain/nn/row_parallel_linear.py | 3 ++- example/layers/attention.py | 20 ++++++++++++++------ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/bmtrain/nn/row_parallel_linear.py b/bmtrain/nn/row_parallel_linear.py index 71a4297d..d246a2c9 100644 --- a/bmtrain/nn/row_parallel_linear.py +++ b/bmtrain/nn/row_parallel_linear.py @@ -28,7 +28,8 @@ def forward(self, input): gather_output = False reduce_output_type = ReduceType.ALL_REDUCE if self.all_reduce_output else ReduceType.REDUCE_SCATTER out = OpParallelLinear.apply(input, self.weight, None, gather_input, gather_output, self.split_input, reduce_output_type) - out = out + self.bias + if self.bias: + out = out + self.bias return out def extra_repr(self) -> str: diff --git a/example/layers/attention.py b/example/layers/attention.py index 41721528..de50844a 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -45,12 +45,20 @@ def forward(self, batch_size, seq_q, dim_model = hidden_q.size() seq_kv = hidden_kv.size(1) - if config['tp_size'] > 1: - hidden_q = all_gather(hidden_q, comm=config['tp_comm']).flatten(0,1) - - h_q : torch.Tensor = self.project_q(hidden_q) - h_k : torch.Tensor = self.project_k(hidden_q) - h_v : torch.Tensor = self.project_v(hidden_q) + if isinstance(self.project_q, ColumnParallelLinear): + assert hidden_q.data_ptr() == hidden_kv.data_ptr() + hidden_q = bmt.nn.OpParallelLinear.apply( + hidden_q, + torch.cat([self.project_q.weight, self.project_k.weight, self.project_v.weight], dim=0), + torch.cat([self.project_q.bias, self.project_k.bias, self.project_v.bias], dim=0) if self.project_q.bias is not None else None, + True, False, + False, None + ) + h_q, h_k, h_v = hidden_q.chunk(3, dim=-1) + else: + h_q : torch.Tensor = self.project_q(hidden_q) + h_k : torch.Tensor = self.project_k(hidden_q) + h_v : torch.Tensor = self.project_v(hidden_q) if config['tp_size'] > 1: #batch_size will changed in TensorParallel batch_size = h_v.shape[0] From 
c64da6f6c1c5c37260c63dd0a31dc6903fe90d07 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 29 Aug 2023 11:00:41 +0800 Subject: [PATCH 110/122] update pre module interface --- bmtrain/hook_func.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 758a6a87..cb9abdba 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -21,6 +21,7 @@ def add(self, tensor): self._offload_dict[data_ptr]["size"] = tensor.storage().size() self._offload_dict[data_ptr]["dtype"] = tensor.storage().dtype self._offload_dict[data_ptr]["tensors"] = {} + self._offload_dict[data_ptr]["tensors"][id(tensor)] = {} self._offload_dict[data_ptr]["tensors"][id(tensor)]["numel"] = tensor.numel() self._offload_dict[data_ptr]["tensors"][id(tensor)]['dtype'] = tensor.dtype @@ -63,7 +64,6 @@ def h2d_memcpy(self): elif id_val['dtype'] == torch.float32: id_val['tensor'].set_(fp32_storage_cuda, id_val['abs_offset'], id_val['shape']) - def record_stream(self, stream): for key, val in self._offload_dict.items(): for id_val in val['tensors'].values(): @@ -96,10 +96,10 @@ def d2h_memcpy(self): def find_pre_module_helper(m): if len(m) == 0: return None - if m[0]._mode == "OFFLOAD": - return m[0] + if m._mode == "OFFLOAD": + return m else: - return find_pre_module_helper(m[0]._pre_module) + return find_pre_module_helper(m.pre_module()) def offload_wrapper(offload_dict): def pack_hook(tensor): @@ -121,7 +121,6 @@ def unpack_hook(packed): return tensor return pack_hook, unpack_hook - def offload_pre_hook(module, input): if hasattr(module, "_offload_hook"): pack_hook, unpack_hook = module._offload_hook @@ -175,7 +174,7 @@ def zero_post_forward(module, inputs, outputs): exit = module._micro_idx == config['micros'] - 1 elif module._mode == "OFFLOAD": torch.cuda.current_stream().record_event(module.calc_event) - pre_offload_module = find_pre_module_helper(module._pre_module) + pre_offload_module = find_pre_module_helper(module.pre_module()) if pre_offload_module is not None: torch.cuda.current_stream().wait_event(pre_offload_module.offload_event) with torch.cuda.stream(config["offload_stream"]): @@ -199,7 +198,7 @@ def zero_pre_backward(module, grad_outputs): if len(module._next_module) != 0: current_stream = torch.cuda.current_stream() current_stream.wait_event(module.offload_event) - pre_module = find_pre_module_helper(module._pre_module) + pre_module = find_pre_module_helper(module.pre_module()) if pre_module is not None: pre_module._on_device = True with torch.cuda.stream(config["offload_stream"]): From 5819ce4eba868174f929439b6a44e791ceceea9c Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 29 Aug 2023 11:19:00 +0800 Subject: [PATCH 111/122] .gitignore back --- .gitignore | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 9c9a3f28..0222862f 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ __pycache__/ *.py[cod] *$py.class -*nsys-rep + # C extensions *.so @@ -150,4 +150,4 @@ log .vscode !bmtrain/dist -tests/test_log.txt +tests/test_log.txt \ No newline at end of file From 832141a5c42da4adc33ae68ade90d055027e3918 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 29 Aug 2023 11:21:20 +0800 Subject: [PATCH 112/122] example back to origin --- example/layers/attention.py | 73 +++++++++-------------------- example/layers/embedding.py | 8 ++-- example/layers/feedforward.py | 15 ++---- example/layers/transformer.py | 6 +-- 
example/models/gpt.py | 27 +++-------- example/run.sh | 4 +- example/train.py | 88 ++++++++++++++++------------------- 7 files changed, 81 insertions(+), 140 deletions(-) diff --git a/example/layers/attention.py b/example/layers/attention.py index de50844a..243df3ea 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -1,14 +1,8 @@ from typing import Optional import torch import bmtrain as bmt -from bmtrain.nn import ( - Linear, - ColumnParallelLinear, - RowParallelLinear, -) +from bmtrain.nn import Linear import math -from bmtrain.global_var import config -from bmtrain.distributed import all_gather class Attention(bmt.DistributedModule): def __init__(self, @@ -18,21 +12,14 @@ def __init__(self, ) -> None: super().__init__() - if config['tp_size'] > 1: - self.project_q = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False) - self.project_k = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False) - self.project_v = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False) - self.project_out = RowParallelLinear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) - else: - self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) + self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) self.softmax = torch.nn.Softmax(dim=-1) self.num_heads = num_heads - self.num_kv_heads = num_heads self.dim_head = dim_head self.dim_model = dim_model @@ -45,50 +32,32 @@ def forward(self, batch_size, seq_q, dim_model = hidden_q.size() seq_kv = hidden_kv.size(1) - if isinstance(self.project_q, ColumnParallelLinear): - assert hidden_q.data_ptr() == hidden_kv.data_ptr() - hidden_q = bmt.nn.OpParallelLinear.apply( - hidden_q, - torch.cat([self.project_q.weight, self.project_k.weight, self.project_v.weight], dim=0), - torch.cat([self.project_q.bias, self.project_k.bias, self.project_v.bias], dim=0) if self.project_q.bias is not None else None, - True, False, - False, None - ) - h_q, h_k, h_v = hidden_q.chunk(3, dim=-1) - else: - h_q : torch.Tensor = self.project_q(hidden_q) - h_k : torch.Tensor = self.project_k(hidden_q) - h_v : torch.Tensor = self.project_v(hidden_q) - if config['tp_size'] > 1: - #batch_size will changed in TensorParallel - batch_size = h_v.shape[0] - - h_q = h_q.view(batch_size, seq_q, -1, self.dim_head) - h_k = h_k.view(batch_size, seq_kv, -1, self.dim_head) - h_v = h_v.view(batch_size, seq_kv, -1, self.dim_head) + h_q : torch.Tensor = self.project_q(hidden_q) + h_k : torch.Tensor = self.project_k(hidden_kv) + h_v : torch.Tensor = self.project_v(hidden_kv) + + h_q = h_q.view(batch_size, seq_q, self.num_heads, self.dim_head) + h_k = h_k.view(batch_size, seq_kv, self.num_heads, self.dim_head) + h_v = h_v.view(batch_size, seq_kv, self.num_heads, self.dim_head) h_q = h_q.permute(0, 2, 1, 3).contiguous() h_k = h_k.permute(0, 2, 1, 3).contiguous() h_v = h_v.permute(0, 2, 1, 3).contiguous() - h_q = 
h_q.view(-1, seq_q, self.dim_head) - h_k = h_k.view(-1, seq_kv, self.dim_head) - h_v = h_v.view(-1, seq_kv, self.dim_head) + h_q = h_q.view(batch_size * self.num_heads, seq_q, self.dim_head) + h_k = h_k.view(batch_size * self.num_heads, seq_kv, self.dim_head) + h_v = h_v.view(batch_size * self.num_heads, seq_kv, self.dim_head) score = torch.bmm( h_q, h_k.transpose(1, 2) ) score = score / math.sqrt(self.dim_head) - score = score.view(batch_size, -1, seq_q, seq_kv) + score = score.view(batch_size, self.num_heads, seq_q, seq_kv) if position_bias is not None: - score = score + position_bias.view(batch_size, -1, seq_q, seq_kv) - - if config['tp_size'] > 1: - with torch.no_grad(): - mask = all_gather(mask, config['tp_comm']).flatten(0,1) - + score = score + position_bias.view(batch_size, self.num_heads, seq_q, seq_kv) + score = torch.where( mask.view(batch_size, 1, seq_q, seq_kv), score, @@ -101,14 +70,14 @@ def forward(self, torch.scalar_tensor(0, device=score.device, dtype=score.dtype) ) - score = score.view(-1, seq_q, seq_kv) + score = score.view(batch_size * self.num_heads, seq_q, seq_kv) h_out = torch.bmm( score, h_v ) - h_out = h_out.view(batch_size, -1, seq_q, self.dim_head) + h_out = h_out.view(batch_size, self.num_heads, seq_q, self.dim_head) h_out = h_out.permute(0, 2, 1, 3).contiguous() - h_out = h_out.view(batch_size, seq_q, -1) + h_out = h_out.view(batch_size, seq_q, self.num_heads * self.dim_head) attn_out = self.project_out(h_out) return attn_out diff --git a/example/layers/embedding.py b/example/layers/embedding.py index f62151c4..13c47384 100644 --- a/example/layers/embedding.py +++ b/example/layers/embedding.py @@ -77,13 +77,11 @@ def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, def forward(self, input: torch.Tensor, projection : bool = False) -> torch.Tensor: if not projection: - out = F.embedding( + return F.embedding( input, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse) - return out else: - out = F.linear(input, self.weight) - return out + return F.linear(input, self.weight) / math.sqrt(self.embedding_dim) def extra_repr(self) -> str: s = '{num_embeddings}, {embedding_dim}' @@ -99,4 +97,4 @@ def extra_repr(self) -> str: s += ', sparse=True' return s.format(**self.__dict__) - + \ No newline at end of file diff --git a/example/layers/feedforward.py b/example/layers/feedforward.py index e88d2495..99d2dc3b 100644 --- a/example/layers/feedforward.py +++ b/example/layers/feedforward.py @@ -1,23 +1,16 @@ import torch import bmtrain as bmt -from bmtrain.nn import ( - Linear, - ColumnParallelLinear, - RowParallelLinear) -from bmtrain.global_var import config +from bmtrain.nn import Linear class Feedforward(bmt.DistributedModule): def __init__(self, dim_model : int, dim_ff : int, bias : bool = True, dtype = None) -> None: super().__init__() - if config['tp_size'] > 1: - self.w_in = ColumnParallelLinear(dim_model, dim_ff, bias = bias, dtype=dtype) - self.w_out = RowParallelLinear(dim_ff, dim_model, bias = bias, dtype=dtype) - else: - self.w_in = Linear(dim_model, dim_ff, bias=bias, dtype=dtype) - self.w_out = Linear(dim_ff, dim_model, bias=bias, dtype=dtype) + self.w_in = Linear(dim_model, dim_ff, bias = bias, dtype=dtype) + self.w_out = Linear(dim_ff, dim_model, bias = bias, dtype=dtype) self.relu = torch.nn.ReLU() def forward(self, input : torch.Tensor) -> torch.Tensor: + return self.w_out(self.relu(self.w_in(input))) diff --git a/example/layers/transformer.py b/example/layers/transformer.py index 
4b867e0a..7cda1bb9 100644 --- a/example/layers/transformer.py +++ b/example/layers/transformer.py @@ -20,10 +20,10 @@ def forward(self, hidden : torch.Tensor, # (batch, seq_len, dim_model) mask : torch.BoolTensor, # (batch, seq_len, dim_model) position_bias : Optional[torch.Tensor] = None, # (batch, num_head, seq_len, seq_len) - ): - # bmt.inspect.record_tensor(hidden, "hidden") + ): + bmt.inspect.record_tensor(hidden, "hidden") x = self.ln_attn(hidden) - x = self.attn(x, x, mask) + x = self.attn(x, x, mask, position_bias) hidden = hidden + x x = self.ln_ff(hidden) diff --git a/example/models/gpt.py b/example/models/gpt.py index 32d57624..78d77a7d 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py @@ -1,38 +1,28 @@ import torch import bmtrain as bmt from layers import TransformerEncoder, Layernorm, Embedding, TransformerEncoder -from bmtrain.global_var import config class GPT(bmt.DistributedModule): def __init__(self, num_layers : int, vocab_size : int, dim_model : int, dim_head : int, num_heads : int, dim_ff : int, max_distance : int, - bias : bool = True, dtype = None, offload = False, offload_level = 0 + bias : bool = True, dtype = None ) -> None: super().__init__() self.max_distance = max_distance - if config['tp_size'] > 1: - self.word_emb = bmt.nn.ParallelEmbedding(vocab_size, dim_model, dtype=dtype) - else: - self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype) + self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype) self.pos_emb = Embedding(max_distance, dim_model, dtype=dtype) - if offload: - offload_mask = [True if i%4 == 0 else False for i in range(num_layers)] - ckpt_mask = [not offload_mask[i] for i in range(num_layers)] - offload_level = offload_level - else: - ckpt_mask = [ True for i in range(num_layers) ] - offload_mask = [ False for i in range(num_layers) ] + self.transformers = bmt.TransformerBlockList([ bmt.CheckpointBlock( TransformerEncoder( dim_model, dim_head, num_heads, dim_ff, bias, dtype - ),use_checkpoint=ckpt_mask[i],use_offload=offload_mask[i],offload_level=offload_level + ) ) - for i in range(num_layers) + for _ in range(num_layers) ]) self.layernorm = Layernorm(dim_model, dtype=dtype) @@ -52,10 +42,7 @@ def forward(self, out = self.transformers(out, mask_2d, None) out = self.layernorm(out) - if config['tp_size'] > 1: - logits = self.word_emb.projection(out) - else: - logits = self.word_emb(out, projection=True) + logits = self.word_emb(out, projection=True) bmt.inspect.record_tensor(logits, "logits") - return logits + return logits \ No newline at end of file diff --git a/example/run.sh b/example/run.sh index 1beb71bd..542e5252 100644 --- a/example/run.sh +++ b/example/run.sh @@ -1 +1,3 @@ -torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=localhost $1 +export NCCL_P2P_DISABLE=1 +export CUDA_LAUNCH_BLOCKING=1 +torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=localhost train.py diff --git a/example/train.py b/example/train.py index 94dbdd7e..1a744e20 100644 --- a/example/train.py +++ b/example/train.py @@ -3,32 +3,28 @@ from models import GPT import time from bmtrain import optim -from bmtrain.global_var import config from bmtrain import inspect def main(): bmt.init_distributed( seed=0, - tp_size=2, + zero_level=2, ) - offload = True - seq_len = 4096 - offload_level = 0 + model = GPT( - num_layers=24, - vocab_size=80000, - dim_model=1024, - dim_head=64, - num_heads=16, - dim_ff=4096, - max_distance=seq_len, - bias=False, - dtype=torch.half, - 
offload=offload, - offload_level=offload_level + num_layers=8, + vocab_size=10240, + dim_model=2560, + dim_head=80, + num_heads=32, + dim_ff=8192, + max_distance=1024, + bias=True, + dtype=torch.half ) bmt.init_parameters(model) + # print_inspect(model, "*") bmt.print_rank("Model memory") bmt.print_rank(torch.cuda.memory_summary()) @@ -37,7 +33,10 @@ def main(): # data # generate dummy data for each rank torch.manual_seed(1234) - batch_size = 4 + + batch_size = 2 + seq_len = 512 + for i in range(bmt.world_size()): sent = torch.randint(0, 10240, (batch_size, seq_len + 1)) enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda() @@ -53,11 +52,7 @@ def main(): if i == bmt.rank(): break - if config['tp_size'] > 1: - loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) - else: - loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) - + loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) optimizer = optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2) lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0) @@ -69,43 +64,40 @@ def main(): avg_time_recorder = bmt.utils.AverageRecorder() avg_loss_recorder = bmt.utils.AverageRecorder() - for iteration in range(30): + for iteration in range(1000): # load data st = time.time() - # with bmt.inspect.inspect_tensor() as inspector: - pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) - logits = model( - enc_input, - pos, - pos < enc_length[:, None] - ) - batch, seq_len, vocab_out_size = logits.size() + with inspect.inspect_tensor() as inspector: + pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) + logits = model( + enc_input, + pos, + pos < enc_length[:, None] + ) + batch, seq_len, vocab_out_size = logits.size() - if config['tp_size'] > 1: - loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets) - else: - loss = loss_func(logits.float().view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) + loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) global_loss = bmt.sum_loss(loss).item() - optim_manager.zero_grad() + optim_manager.zero_grad() - optim_manager.backward(loss) + optim_manager.backward(loss) # print inspected tensors in the forward & backward pass # print parameters of the model - # if iteration % 100 == 0: - # bmt.print_rank( - # bmt.inspect.format_summary( - # inspector.get_summary() - # ) - # ) - # bmt.print_rank( - # bmt.inspect.format_summary( - # bmt.inspect.inspect_model(model, "*") - # ) - # ) + if iteration % 100 == 0: + bmt.print_rank( + inspect.format_summary( + inspector.get_summary() + ) + ) + bmt.print_rank( + inspect.format_summary( + inspect.inspect_model(model, "*") + ) + ) optim_manager.step() From 8bd6475b48f142b1766ec3ff43b39a9bbe11d27b Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 29 Aug 2023 11:34:06 +0800 Subject: [PATCH 113/122] delete test file --- example/layers/attention.py | 68 ++- example/layers/embedding.py | 8 +- example/layers/feedforward.py | 15 +- example/layers/flash_triton.py | 830 --------------------------------- example/layers/test_linear.py | 152 ------ example/models/gpt.py | 13 +- example/test_attn.py | 49 -- example/test_block.py | 72 --- example/train.py | 15 +- 9 files changed, 85 insertions(+), 1137 deletions(-) delete mode 100644 example/layers/flash_triton.py delete mode 100644 example/layers/test_linear.py delete mode 100644 
example/test_attn.py delete mode 100644 example/test_block.py diff --git a/example/layers/attention.py b/example/layers/attention.py index 243df3ea..8fbb7510 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -1,8 +1,14 @@ from typing import Optional import torch import bmtrain as bmt -from bmtrain.nn import Linear +from bmtrain.nn import ( + Linear, + ColumnParallelLinear, + RowParallelLinear, +) import math +from bmtrain.global_var import config +from bmtrain.distributed import all_gather class Attention(bmt.DistributedModule): def __init__(self, @@ -12,11 +18,17 @@ def __init__(self, ) -> None: super().__init__() - self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + if config['tp_size'] > 1: + self.project_q = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False) + self.project_k = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False) + self.project_v = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False) + self.project_out = RowParallelLinear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) + else: + self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) - self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) self.softmax = torch.nn.Softmax(dim=-1) self.num_heads = num_heads @@ -32,32 +44,48 @@ def forward(self, batch_size, seq_q, dim_model = hidden_q.size() seq_kv = hidden_kv.size(1) - h_q : torch.Tensor = self.project_q(hidden_q) - h_k : torch.Tensor = self.project_k(hidden_kv) - h_v : torch.Tensor = self.project_v(hidden_kv) + assert hidden_q.data_ptr() == hidden_kv.data_ptr() - h_q = h_q.view(batch_size, seq_q, self.num_heads, self.dim_head) - h_k = h_k.view(batch_size, seq_kv, self.num_heads, self.dim_head) - h_v = h_v.view(batch_size, seq_kv, self.num_heads, self.dim_head) + hidden_q = bmt.nn.OpParallelLinear.apply( + hidden_q, + torch.cat([self.project_q.weight, self.project_k.weight, self.project_v.weight], dim=0), + torch.cat([self.project_q.bias, self.project_k.bias, self.project_v.bias], dim=0), + True, False, + False, None + ) + + h_q, h_k, h_v = hidden_q.chunk(3, dim=-1) + + if config['tp_size'] > 1: + #batch_size will changed in TensorParallel + batch_size = h_v.shape[0] + + h_q = h_q.view(batch_size, seq_q, -1, self.dim_head) + h_k = h_k.view(batch_size, seq_kv, -1, self.dim_head) + h_v = h_v.view(batch_size, seq_kv, -1, self.dim_head) h_q = h_q.permute(0, 2, 1, 3).contiguous() h_k = h_k.permute(0, 2, 1, 3).contiguous() h_v = h_v.permute(0, 2, 1, 3).contiguous() - h_q = h_q.view(batch_size * self.num_heads, seq_q, self.dim_head) - h_k = h_k.view(batch_size * self.num_heads, seq_kv, self.dim_head) - h_v = h_v.view(batch_size * self.num_heads, seq_kv, self.dim_head) + h_q = h_q.view(-1, seq_q, self.dim_head) + h_k = h_k.view(-1, seq_kv, self.dim_head) + h_v = h_v.view(-1, seq_kv, self.dim_head) score = torch.bmm( h_q, h_k.transpose(1, 2) ) score = score / math.sqrt(self.dim_head) - score = score.view(batch_size, 
self.num_heads, seq_q, seq_kv) + score = score.view(batch_size, -1, seq_q, seq_kv) if position_bias is not None: - score = score + position_bias.view(batch_size, self.num_heads, seq_q, seq_kv) - + score = score + position_bias.view(batch_size, -1, seq_q, seq_kv) + + if config['tp_size'] > 1: + with torch.no_grad(): + mask = all_gather(mask, config['tp_comm']).flatten(0,1) + score = torch.where( mask.view(batch_size, 1, seq_q, seq_kv), score, @@ -70,14 +98,14 @@ def forward(self, torch.scalar_tensor(0, device=score.device, dtype=score.dtype) ) - score = score.view(batch_size * self.num_heads, seq_q, seq_kv) + score = score.view(-1, seq_q, seq_kv) h_out = torch.bmm( score, h_v ) - h_out = h_out.view(batch_size, self.num_heads, seq_q, self.dim_head) + h_out = h_out.view(batch_size, -1, seq_q, self.dim_head) h_out = h_out.permute(0, 2, 1, 3).contiguous() - h_out = h_out.view(batch_size, seq_q, self.num_heads * self.dim_head) + h_out = h_out.view(batch_size, seq_q, -1) attn_out = self.project_out(h_out) return attn_out diff --git a/example/layers/embedding.py b/example/layers/embedding.py index 13c47384..f62151c4 100644 --- a/example/layers/embedding.py +++ b/example/layers/embedding.py @@ -77,11 +77,13 @@ def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, def forward(self, input: torch.Tensor, projection : bool = False) -> torch.Tensor: if not projection: - return F.embedding( + out = F.embedding( input, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse) + return out else: - return F.linear(input, self.weight) / math.sqrt(self.embedding_dim) + out = F.linear(input, self.weight) + return out def extra_repr(self) -> str: s = '{num_embeddings}, {embedding_dim}' @@ -97,4 +99,4 @@ def extra_repr(self) -> str: s += ', sparse=True' return s.format(**self.__dict__) - \ No newline at end of file + diff --git a/example/layers/feedforward.py b/example/layers/feedforward.py index 99d2dc3b..e88d2495 100644 --- a/example/layers/feedforward.py +++ b/example/layers/feedforward.py @@ -1,16 +1,23 @@ import torch import bmtrain as bmt -from bmtrain.nn import Linear +from bmtrain.nn import ( + Linear, + ColumnParallelLinear, + RowParallelLinear) +from bmtrain.global_var import config class Feedforward(bmt.DistributedModule): def __init__(self, dim_model : int, dim_ff : int, bias : bool = True, dtype = None) -> None: super().__init__() - self.w_in = Linear(dim_model, dim_ff, bias = bias, dtype=dtype) - self.w_out = Linear(dim_ff, dim_model, bias = bias, dtype=dtype) + if config['tp_size'] > 1: + self.w_in = ColumnParallelLinear(dim_model, dim_ff, bias = bias, dtype=dtype) + self.w_out = RowParallelLinear(dim_ff, dim_model, bias = bias, dtype=dtype) + else: + self.w_in = Linear(dim_model, dim_ff, bias=bias, dtype=dtype) + self.w_out = Linear(dim_ff, dim_model, bias=bias, dtype=dtype) self.relu = torch.nn.ReLU() def forward(self, input : torch.Tensor) -> torch.Tensor: - return self.w_out(self.relu(self.w_in(input))) diff --git a/example/layers/flash_triton.py b/example/layers/flash_triton.py deleted file mode 100644 index 1d687378..00000000 --- a/example/layers/flash_triton.py +++ /dev/null @@ -1,830 +0,0 @@ -""" -*Experimental* implementation of FlashAttention in Triton. - -We use the FlashAttention implementation from Phil Tillet a starting point. -https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py - -Changes: -- Implement both causal and non-causal attention. -- Implement both self-attention and cross-attention. 
-- Support arbitrary seqlens (not just multiples of 128), for both forward and backward. -- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward. -- Support attention bias. -- Speed up the forward pass a bit, and only store the LSE instead of m and l. -- Make the backward for d=128 much faster by reducing register spilling. -- Optionally parallelize the backward pass across seqlen_k, to deal with the case of -small batch size * nheads. - -Caution: -- This is an *experimental* implementation. The forward pass should be quite robust but -I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler). -- This implementation has only been tested on A100. -- If you plan to use headdim other than 64 and 128, you should test for race conditions -(due to the Triton compiler), as done in tests/test_flash_attn.py -"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions -for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident -that there are none left for other head dimensions. - -Differences between this Triton version and the CUDA version: -- Triton version doesn't support dropout. -- Triton forward is generally faster than CUDA forward, while Triton backward is -generally slower than CUDA backward. Overall Triton forward + backward is slightly slower -than CUDA forward + backward. -- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor). -- Triton version supports attention bias, while CUDA version doesn't. -""" - -import math - -import torch - -import triton -import triton.language as tl - - -# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128 -# @triton.autotune( -# configs=[ -# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1), -# # This config has a race condition when EVEN_M == False, disabling it for now. -# # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1), -# ], -# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'] -# ) -@triton.heuristics( - { - "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, - "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, - "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], - } -) -@triton.jit -def _fwd_kernel( - Q, K, V, Bias, Out, - Lse, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug - softmax_scale, - stride_qb, stride_qh, stride_qm, - stride_kb, stride_kh, stride_kn, - stride_vb, stride_vh, stride_vn, - stride_bb, stride_bh, stride_bm, - stride_ob, stride_oh, stride_om, - nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, - CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, - BIAS_TYPE: tl.constexpr, - IS_CAUSAL: tl.constexpr, - BLOCK_HEADDIM: tl.constexpr, - EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, -): - start_m = tl.program_id(0) - off_hb = tl.program_id(1) - off_b = off_hb // nheads - off_h = off_hb % nheads - # off_b = tl.program_id(1) - # off_h = tl.program_id(2) - # off_hb = off_b * nheads + off_h - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_HEADDIM) - # Initialize pointers to Q, K, V - # Adding parenthesis around indexing might use int32 math instead of int64 math? 
- # https://github.com/openai/triton/issues/741 - # I'm seeing a tiny bit of difference (5-7us) - q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) - k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) - v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) - if BIAS_TYPE == 'vector': - b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n - elif BIAS_TYPE == 'matrix': - b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :]) - # initialize pointer to m and l - t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m - lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) - # load q: it will stay in SRAM throughout - # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call - # tl.load(q_ptrs), we get the wrong output! - if EVEN_M & EVEN_N: - if EVEN_HEADDIM: - q = tl.load(q_ptrs) - else: - q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) - else: - q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0) - # loop over k, v and update accumulator - end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) - for start_n in range(0, end_n, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition - if EVEN_HEADDIM: - k = tl.load(k_ptrs + start_n * stride_kn) - else: - k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, - other=0.0) - else: - k = tl.load(k_ptrs + start_n * stride_kn, - mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), - other=0.0) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k, trans_b=True) - # Trying to combine the two masks seem to make the result wrong - if not EVEN_N: # Need to mask out otherwise the softmax is wrong - qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) - if IS_CAUSAL: - qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) - if BIAS_TYPE != 'none': - if BIAS_TYPE == 'vector': - if EVEN_N: - bias = tl.load(b_ptrs + start_n).to(tl.float32) - else: - bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32) - bias = bias[None, :] - elif BIAS_TYPE == 'matrix': - if EVEN_M & EVEN_N: - bias = tl.load(b_ptrs + start_n).to(tl.float32) - else: - bias = tl.load(b_ptrs + start_n, - mask=(offs_m[:, None] < seqlen_q) - & ((start_n + offs_n)[None, :] < seqlen_k), - other=0.0).to(tl.float32) - # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler - # can then fuse the mult and add into an fma instruction. But if we have bias we need to - # to multiply with softmax_scale here. 
- qk = qk * softmax_scale + bias - m_ij = tl.maximum(tl.max(qk, 1), lse_i) - - m_ij = tl.where(m_ij==float("-inf"),0,m_ij) - p = tl.exp(qk - m_ij[:, None]) - else: - m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i) - p = tl.exp(qk * softmax_scale - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # p = tl.where(p==float("-inf"), 0, p) - # l_ij = tl.maximum(tl.sum(p, 1),-1e16) - # scale acc_o - acc_o_scale = tl.exp(m_i - m_ij) - # mask_sum = tl.sum(bias == float("-inf"), axis=1) == BLOCK_M - # acc_o_scale = tl.where(mask_sum, 0, acc_o_scale) - # # -- update output accumulator -- - # BUG: have to store and immediately load - tl.store(t_ptrs, acc_o_scale) - acc_o_scale = tl.load(t_ptrs) - acc_o = acc_o * acc_o_scale[:, None] - # update acc_o - if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition - if EVEN_HEADDIM: - v = tl.load(v_ptrs + start_n * stride_vn) - else: - v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, - other=0.0) - else: - v = tl.load(v_ptrs + start_n * stride_vn, - mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), - other=0.0) - p = p.to(v.dtype) - acc_o += tl.dot(p, v) - - # -- update statistics - m_i = m_ij - l_i_new = tl.exp(lse_i - m_ij) + l_ij - lse_i = m_ij + tl.log(l_i_new) - lse_i = tl.where(lse_i == float("-inf"), 0, lse_i) - o_scale = tl.exp(m_i - lse_i) - # BUG: have to store and immediately load - tl.store(t_ptrs, o_scale) - o_scale = tl.load(t_ptrs) - acc_o = acc_o * o_scale[:, None] - # rematerialize offsets to save registers - start_m = tl.program_id(0) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - # write back l and m - lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m - tl.store(lse_ptrs, lse_i) - # initialize pointers to output - offs_d = tl.arange(0, BLOCK_HEADDIM) - out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :]) - if EVEN_M: - if EVEN_HEADDIM: - tl.store(out_ptrs, acc_o) - else: - tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) - else: - if EVEN_HEADDIM: - tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) - else: - tl.store(out_ptrs, acc_o, - mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)) - - -@triton.jit -def _bwd_preprocess_do_o_dot( - Out, DO, Delta, - stride_ob, stride_oh, stride_om, - stride_dob, stride_doh, stride_dom, - nheads, seqlen_q, seqlen_q_rounded, headdim, - BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, -): - start_m = tl.program_id(0) - off_hb = tl.program_id(1) - off_b = off_hb // nheads - off_h = off_hb % nheads - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_d = tl.arange(0, BLOCK_HEADDIM) - # load - o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], - mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32) - do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :], - mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32) - delta = tl.sum(o * do, axis=1) - # write-back - tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) - - -@triton.jit -def _bwd_store_dk_dv( - dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, - EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, -): - # 
[2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False, - # if we just call tl.store(dv_ptrs), there's a race condition - if EVEN_N & EVEN_M: - if EVEN_HEADDIM: - tl.store(dv_ptrs, dv) - tl.store(dk_ptrs, dk) - else: - tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) - tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) - else: - if EVEN_HEADDIM: - tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) - tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) - else: - tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) - tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) - - -@triton.jit -def _bwd_kernel_one_col_block( - start_n, - Q, K, V, Bias, - DO, DQ, DK, DV, - LSE, D, - softmax_scale, - stride_qm, stride_kn, stride_vn, stride_bm, - stride_dom, stride_dqm, stride_dkn, stride_dvn, - seqlen_q, seqlen_k, headdim, - ATOMIC_ADD: tl.constexpr, - BIAS_TYPE: tl.constexpr, - IS_CAUSAL: tl.constexpr, - BLOCK_HEADDIM: tl.constexpr, - EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, -): - # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N) - begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M - # initialize row/col offsets - offs_qm = begin_m + tl.arange(0, BLOCK_M) - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - offs_m = tl.arange(0, BLOCK_M) - offs_d = tl.arange(0, BLOCK_HEADDIM) - # initialize pointers to value-like data - q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :]) - k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) - v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) - do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) - dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :]) - if BIAS_TYPE == 'vector': - b_ptrs = Bias + offs_n - elif BIAS_TYPE == 'matrix': - b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :]) - # initialize dv and dk - dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) - dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) - # There seems to be some problem with Triton pipelining that makes results wrong for - # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop - # may have zero step, and pipelining with the bias matrix could screw it up. - # So we just exit early. - if begin_m >= seqlen_q: - dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) - dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) - _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, - EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM) - return - # k and v stay in SRAM throughout - # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False, - # if we just call tl.load(k_ptrs), we get the wrong output! 
- if EVEN_N & EVEN_M: - if EVEN_HEADDIM: - k = tl.load(k_ptrs) - v = tl.load(v_ptrs) - else: - k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) - v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) - v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) - else: - k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), - other=0.0) - v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), - other=0.0) - # loop over rows - num_block_m = tl.cdiv(seqlen_q, BLOCK_M) - for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M): - start_m = tl.multiple_of(start_m, BLOCK_M) - offs_m_curr = start_m + offs_m - # load q, k, v, do on-chip - # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117) - if EVEN_M & EVEN_HEADDIM: - q = tl.load(q_ptrs) - else: - if EVEN_HEADDIM: - q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) - else: - q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) - & (offs_d[None, :] < headdim), other=0.0) - # recompute p = softmax(qk, dim=-1).T - qk = tl.dot(q, k, trans_b=True) - # Trying to combine the two masks seem to make the result wrong - if not EVEN_N: # Need to mask out otherwise the softmax is wrong - qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf")) - if IS_CAUSAL: - qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) - if BIAS_TYPE != 'none': - tl.debug_barrier() # Race condition otherwise - if BIAS_TYPE == 'vector': - if EVEN_N: - bias = tl.load(b_ptrs).to(tl.float32) - else: - bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32) - bias = bias[None, :] - elif BIAS_TYPE == 'matrix': - if EVEN_M & EVEN_N: - bias = tl.load(b_ptrs).to(tl.float32) - else: - bias = tl.load(b_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) - & (offs_n[None, :] < seqlen_k), - other=0.0).to(tl.float32) - qk = qk * softmax_scale + bias - # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong. - # Also wrong for headdim=64. - if not (EVEN_M & EVEN_HEADDIM): - tl.debug_barrier() - lse_i = tl.load(LSE + offs_m_curr) - if BIAS_TYPE == 'none': - p = tl.exp(qk * softmax_scale - lse_i[:, None]) - else: - p = tl.exp(qk - lse_i[:, None]) - # compute dv - # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call - # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs - # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512, - # the output is correct. - if EVEN_M & EVEN_HEADDIM: - do = tl.load(do_ptrs) - else: - # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask. - do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) - & (offs_d[None, :] < headdim), other=0.0) - # if EVEN_M: - # if EVEN_HEADDIM: - # do = tl.load(do_ptrs) - # else: - # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0) - # else: - # if EVEN_HEADDIM: - # do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) - # else: - # do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) - # & (offs_d[None, :] < headdim), other=0.0) - dv += tl.dot(p.to(do.dtype), do, trans_a=True) - # compute dp = dot(v, do) - # There seems to be a race condition when headdim=48/96, and dq, dk are wrong. 
- # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True - # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False - if not (EVEN_M & EVEN_HEADDIM): - tl.debug_barrier() - dp = tl.dot(do, v, trans_b=True) - # There's a race condition for headdim=48 - if not EVEN_HEADDIM: - tl.debug_barrier() - # compute ds = p * (dp - delta[:, None]) - # Putting the subtraction after the dp matmul (instead of before) is slightly faster - Di = tl.load(D + offs_m_curr) - # Converting ds to q.dtype here reduces register pressure and makes it much faster - # for BLOCK_HEADDIM=128 - ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) - # compute dk = dot(ds.T, q) - dk += tl.dot(ds, q, trans_a=True) - # compute dq - if not (EVEN_M & EVEN_HEADDIM): # Otherewise there's a race condition when BIAS_TYPE='matrix' - tl.debug_barrier() - if not ATOMIC_ADD: - if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M - dq = tl.load(dq_ptrs, eviction_policy="evict_last") - dq += tl.dot(ds, k) - tl.store(dq_ptrs, dq, eviction_policy="evict_last") - else: - if EVEN_HEADDIM: - dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0, - eviction_policy="evict_last") - dq += tl.dot(ds, k) - tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q, - eviction_policy="evict_last") - else: - dq = tl.load(dq_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0, eviction_policy="evict_last") - dq += tl.dot(ds, k) - tl.store(dq_ptrs, dq, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - eviction_policy="evict_last") - else: # If we're parallelizing across the seqlen_k dimension - dq = tl.dot(ds, k) - if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M - tl.atomic_add(dq_ptrs, dq) - else: - if EVEN_HEADDIM: - tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) - else: - tl.atomic_add(dq_ptrs, dq, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim)) - # increment pointers - dq_ptrs += BLOCK_M * stride_dqm - q_ptrs += BLOCK_M * stride_qm - do_ptrs += BLOCK_M * stride_dom - if BIAS_TYPE == 'matrix': - b_ptrs += BLOCK_M * stride_bm - # write-back - dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) - dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) - _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, - EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM) - - -def init_to_zero(name): - return lambda nargs: nargs[name].zero_() - - -@triton.autotune( - configs=[ - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), - # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now - # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4* - # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), - # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), - # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), - # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, 
pre_hook=init_to_zero('DQ')), - ], - key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'], -) -@triton.heuristics( - { - "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, - "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, - "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], - } -) -@triton.jit -def _bwd_kernel( - Q, K, V, Bias, - DO, DQ, DK, DV, - LSE, D, - softmax_scale, - stride_qb, stride_qh, stride_qm, - stride_kb, stride_kh, stride_kn, - stride_vb, stride_vh, stride_vn, - stride_bb, stride_bh, stride_bm, - stride_dob, stride_doh, stride_dom, - stride_dqb, stride_dqh, stride_dqm, - stride_dkb, stride_dkh, stride_dkn, - stride_dvb, stride_dvh, stride_dvn, - nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, - CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, - BIAS_TYPE: tl.constexpr, - IS_CAUSAL: tl.constexpr, - BLOCK_HEADDIM: tl.constexpr, - SEQUENCE_PARALLEL: tl.constexpr, - EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, -): - off_hb = tl.program_id(1) - off_b = off_hb // nheads - off_h = off_hb % nheads - # offset pointers for batch/head - Q += off_b * stride_qb + off_h * stride_qh - K += off_b * stride_kb + off_h * stride_kh - V += off_b * stride_vb + off_h * stride_vh - DO += off_b * stride_dob + off_h * stride_doh - DQ += off_b * stride_dqb + off_h * stride_dqh - DK += off_b * stride_dkb + off_h * stride_dkh - DV += off_b * stride_dvb + off_h * stride_dvh - if BIAS_TYPE != 'none': - Bias += off_b * stride_bb + off_h * stride_bh - # pointer to row-wise quantities in value-like data - D += off_hb * seqlen_q_rounded - LSE += off_hb * seqlen_q_rounded - if not SEQUENCE_PARALLEL: - num_block_n = tl.cdiv(seqlen_k, BLOCK_N) - for start_n in range(0, num_block_n): - _bwd_kernel_one_col_block( - start_n, - Q, K, V, Bias, - DO, DQ, DK, DV, - LSE, D, - softmax_scale, - stride_qm, stride_kn, stride_vn, stride_bm, - stride_dom, stride_dqm, stride_dkn, stride_dvn, - seqlen_q, seqlen_k, headdim, - ATOMIC_ADD=False, - BIAS_TYPE=BIAS_TYPE, - IS_CAUSAL=IS_CAUSAL, - BLOCK_HEADDIM=BLOCK_HEADDIM, - EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N - ) - else: - start_n = tl.program_id(0) - _bwd_kernel_one_col_block( - start_n, - Q, K, V, Bias, - DO, DQ, DK, DV, - LSE, D, - softmax_scale, - stride_qm, stride_kn, stride_vn, stride_bm, - stride_dom, stride_dqm, stride_dkn, stride_dvn, - seqlen_q, seqlen_k, headdim, - ATOMIC_ADD=True, - BIAS_TYPE=BIAS_TYPE, - IS_CAUSAL=IS_CAUSAL, - BLOCK_HEADDIM=BLOCK_HEADDIM, - EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N - ) - - -def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None): - # shape constraints - batch, seqlen_q, nheads, d = q.shape - _, seqlen_k, _, _ = k.shape - assert k.shape == (batch, seqlen_k, nheads, d) - assert v.shape == (batch, seqlen_k, nheads, d) - assert d <= 128, 'FlashAttention only support head dimensions up to 128' - assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type' - assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16' - assert q.is_cuda and k.is_cuda and v.is_cuda - softmax_scale = softmax_scale or 1.0 / math.sqrt(d) - - has_bias = bias is not None - bias_type = 'none' - if has_bias: - assert bias.dtype in [q.dtype, torch.float] - assert bias.is_cuda - assert bias.dim() == 4 - if bias.stride(-1) != 1: - bias = 
bias.contiguous() - if bias.shape[2:] == (1, seqlen_k): - bias_type = 'vector' - elif bias.shape[2:] == (seqlen_q, seqlen_k): - bias_type = 'matrix' - else: - raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)' - ' or (seqlen_q, seqlen_k)') - bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) - bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) - - seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 - lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) - tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) - o = torch.empty_like(q) - - BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) - BLOCK = 128 - num_warps = 4 if d <= 64 else 8 - grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) - _fwd_kernel[grid]( - q, k, v, bias, o, - lse, tmp, - softmax_scale, - q.stride(0), q.stride(2), q.stride(1), - k.stride(0), k.stride(2), k.stride(1), - v.stride(0), v.stride(2), v.stride(1), - *bias_strides, - o.stride(0), o.stride(2), o.stride(1), - nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, - seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations) - # Can't use kwargs here because triton autotune expects key to be args, not kwargs - # IS_CAUSAL=causal, BLOCK_HEADDIM=d, - bias_type, causal, BLOCK_HEADDIM, - BLOCK_M=BLOCK, BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return o, lse, softmax_scale # softmax_scale could have been updated - - -def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None): - # Make sure that the last dimension is contiguous - if do.stride(-1) != 1: - do = do.contiguous() - batch, seqlen_q, nheads, d = q.shape - _, seqlen_k, _, _ = k.shape - # assert d in {16, 32, 64, 128} - assert d <= 128 - seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 - assert lse.shape == (batch, nheads, seqlen_q_rounded) - assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1 - assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1 - softmax_scale = softmax_scale or 1.0 / math.sqrt(d) - # dq_accum = torch.zeros_like(q, dtype=torch.float32) - dq_accum = torch.empty_like(q, dtype=torch.float32) - delta = torch.empty_like(lse) - # delta = torch.zeros_like(lse) - - BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) - grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) - _bwd_preprocess_do_o_dot[grid]( - o, do, delta, - o.stride(0), o.stride(2), o.stride(1), - do.stride(0), do.stride(2), do.stride(1), - nheads, seqlen_q, seqlen_q_rounded, d, - BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM, - ) - - has_bias = bias is not None - bias_type = 'none' - if has_bias: - assert bias.dtype in [q.dtype, torch.float] - assert bias.is_cuda - assert bias.dim() == 4 - assert bias.stride(-1) == 1 - if bias.shape[2:] == (1, seqlen_k): - bias_type = 'vector' - elif bias.shape[2:] == (seqlen_q, seqlen_k): - bias_type = 'matrix' - else: - raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)' - ' or (seqlen_q, seqlen_k)') - bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) - bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) - - # BLOCK_M = 128 - # BLOCK_N = 64 - # num_warps = 4 - grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, - batch * nheads) - _bwd_kernel[grid]( - q, k, v, bias, - do, dq_accum, dk, dv, - lse, delta, - 
softmax_scale, - q.stride(0), q.stride(2), q.stride(1), - k.stride(0), k.stride(2), k.stride(1), - v.stride(0), v.stride(2), v.stride(1), - *bias_strides, - do.stride(0), do.stride(2), do.stride(1), - dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1), - dk.stride(0), dk.stride(2), dk.stride(1), - dv.stride(0), dv.stride(2), dv.stride(1), - nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, - seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations) - # Can't use kwargs here because triton autotune expects key to be args, not kwargs - # IS_CAUSAL=causal, BLOCK_HEADDIM=d, - bias_type, causal, BLOCK_HEADDIM, - # SEQUENCE_PARALLEL=False, - # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - # num_warps=num_warps, - # num_stages=1, - ) - dq.copy_(dq_accum) - - -class FlashAttnQKVPackedFunc(torch.autograd.Function): - - @staticmethod - def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None): - """ - qkv: (batch, seqlen, 3, nheads, headdim) - bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen). - For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen). - ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen) - """ - # Make sure that the last dimension is contiguous - if qkv.stride(-1) != 1: - qkv = qkv.contiguous() - o, lse, ctx.softmax_scale = _flash_attn_forward( - qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal, - softmax_scale=softmax_scale - ) - ctx.save_for_backward(qkv, o, lse, bias) - ctx.causal = causal - return o - - @staticmethod - def backward(ctx, do): - qkv, o, lse, bias = ctx.saved_tensors - assert not ctx.needs_input_grad[1], 'FlashAttention does not support bias gradient yet' - # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd - # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. - with torch.inference_mode(): - dqkv = torch.empty_like(qkv) - _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse, - dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2], - bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale) - return dqkv, None, None, None - - -flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply - - -class FlashAttnKVPackedFunc(torch.autograd.Function): - - @staticmethod - def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None): - """ - q: (batch, seqlen_q, nheads, headdim) - kv: (batch, seqlen_k, 2, nheads, headdim) - bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k). - For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k). - ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) - """ - # Make sure that the last dimension is contiguous - q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]] - o, lse, ctx.softmax_scale = _flash_attn_forward( - q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale - ) - ctx.save_for_backward(q, kv, o, lse, bias) - ctx.causal = causal - return o - - @staticmethod - def backward(ctx, do): - q, kv, o, lse, bias = ctx.saved_tensors - assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet' - # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd - # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. 
- with torch.inference_mode(): - dq = torch.empty_like(q) - dkv = torch.empty_like(kv) - _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse, - dq, dkv[:, :, 0], dkv[:, :, 1], - bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale) - return dq, dkv, None, None, None - - -flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply - - -class FlashAttnFunc(torch.autograd.Function): - - @staticmethod - def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None): - """ - q: (batch_size, seqlen_q, nheads, headdim) - k, v: (batch_size, seqlen_k, nheads, headdim) - bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k). - For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k). - ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) - """ - # Make sure that the last dimension is contiguous - q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]] - o, lse, ctx.softmax_scale = _flash_attn_forward( - q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale - ) - ctx.save_for_backward(q, k, v, o, lse, bias) - ctx.causal = causal - return o - - @staticmethod - def backward(ctx, do): - q, k, v, o, lse, bias = ctx.saved_tensors - assert not ctx.needs_input_grad[3], 'FlashAttention does not support bias gradient yet' - # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd - # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. - with torch.inference_mode(): - dq = torch.empty_like(q) - dk = torch.empty_like(k) - dv = torch.empty_like(v) - _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, - bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale) - return dq, dk, dv, None, None, None - - -flash_attn_func = FlashAttnFunc.apply diff --git a/example/layers/test_linear.py b/example/layers/test_linear.py deleted file mode 100644 index 27568e12..00000000 --- a/example/layers/test_linear.py +++ /dev/null @@ -1,152 +0,0 @@ -import torch -import torch.nn.functional as F -import bmtrain as bmt -from bmtrain.global_var import config -from . 
import TransformerEncoder - - -gb = 1024.0 * 1024.0 * 1024.0 - -class CustomLinear(torch.autograd.Function): - @staticmethod - def forward(ctx, x, weight, bias=None): - ctx.save_for_backward(x, weight, bias) - return F.linear(x, weight, bias) - - @staticmethod - def backward(ctx, grad_output): - x, weight, bias = ctx.saved_tensors - grad_x = grad_weight = grad_bias = None - if x.requires_grad: - grad_x = grad_output.matmul(weight) - if weight.requires_grad: - dim = grad_output.dim() - grad_weight = grad_output.reshape(-1, - grad_output.shape[-1]).t().matmul(x.reshape(-1, x.shape[-1])) - if bias is not None and bias.requires_grad: - grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0) - return grad_x, grad_weight, grad_bias - - -class LinearFunctionForZeroStage3(torch.autograd.Function): - # Note that both forward and backward are @staticmethods - @staticmethod - #@autocast_custom_fwd - # bias is an optional argument - def forward(ctx, input, weight, bias=None): - - ctx.save_for_backward(input, weight, bias) - - if input.dim() == 2 and bias is not None: - # fused op is marginally faster - ret = torch.addmm(bias, input, weight.t()) - else: - output = input.matmul(weight.t()) - if bias is not None: - output += bias - ret = output - - return ret - - # This function has only a single output, so it gets only one gradient - @staticmethod - #@autocast_custom_bwd - def backward(ctx, grad_output): - # This is a pattern that is very convenient - at the top of backward - # unpack saved_tensors and initialize all gradients w.r.t. inputs to - # None. Thanks to the fact that additional trailing Nones are - # ignored, the return statement is simple even when the function has - # optional inputs. - input, weight, bias = ctx.saved_tensors - - grad_input = grad_weight = grad_bias = None - - #print(f"backward shaped grad_output {grad_output.shape}, input {input.shape}, weight {weight.shape} and bias {bias.shape if bias is not None else None}") - # These needs_input_grad checks are optional and there only to - # improve efficiency. If you want to make your code simpler, you can - # skip them. Returning gradients for inputs that don't require it is - # not an error. 
- if ctx.needs_input_grad[0]: - #print(f"Computing grad input weight {weight.shape} grad_output {grad_output.shape}") - grad_input = grad_output.matmul(weight) - #print(f"Computed grad input {grad_input.shape}") - if ctx.needs_input_grad[1]: - #print("Computing grad weight") - dim = grad_output.dim() - if dim > 2: - grad_weight = grad_output.reshape(-1, - grad_output.shape[-1]).t().matmul(input.reshape(-1, input.shape[-1])) - else: - grad_weight = grad_output.t().matmul(input) - #print(f"Computed grad weight grad_weight {grad_weight.shape}") - if bias is not None and ctx.needs_input_grad[2]: - #print("Computing grad bias") - grad_bias = grad_output.sum(0) - #print("Done computing grad bias") - #print("needs bias") - #print(f"backward shaped grad_input {grad_input.shape}, grad_weight {grad_weight.shape}, grad_bias {grad_bias.shape if grad_bias is not None else None}") - return grad_input, grad_weight, grad_bias - - -class Linear(bmt.DistributedModule): - def __init__(self, in_features : int, out_features: int, bias: bool = False, dtype = torch.float16) -> None: - super().__init__() - - self.in_features = in_features - self.out_features = out_features - self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=dtype, device="cuda"), init_method=torch.nn.init.xavier_normal_) - if bias: - self.bias = bmt.DistributedParameter(torch.empty((1, out_features), dtype=dtype, device="cuda"), init_method=torch.nn.init.zeros_) - else: - self.register_parameter('bias', None) - - def forward(self, input): - #return CustomLinear.apply(input, self.weight, self.bias) - return LinearFunctionForZeroStage3.apply(input, self.weight, self.bias) - - def extra_repr(self) -> str: - return 'in_features={}, out_features={}, bias={}'.format( - self.in_features, self.out_features, self.bias is not None - ) - -class Feedforward(bmt.DistributedModule): - def __init__(self, dim_model : int, dim_ff : int, bias : bool = False, dtype = torch.float16) -> None: - super().__init__() - - self.w_in = Linear(dim_model, dim_ff, bias = bias, dtype=dtype) - self.w_out = Linear(dim_ff, dim_model, bias = bias, dtype=dtype) - self.gate = Linear(dim_model, dim_ff, bias = bias, dtype=dtype) - - self.relu = torch.nn.ReLU() - - def forward(self, input : torch.Tensor) -> torch.Tensor: - gate_out = self.relu(self.gate(input)) - return self.w_out(self.w_in(input) * gate_out) - -bmt.init_distributed(zero_level=2) - -linears = [] -for i in range(10): - linears.append(bmt.CheckpointBlock(TransformerEncoder(8192, 20480), use_checkpoint=False)) - -linears = bmt.TransformerBlockList(linears) - -device = torch.device('cuda') -bmt.synchronize() -if config['rank'] == 0: - print('before forward', torch.cuda.memory_allocated(device) / gb) - -x = torch.randn(4096, 8192, dtype=torch.float16, device=device).requires_grad_() -bmt.synchronize() -if config['rank'] == 0: - print('init input', torch.cuda.memory_allocated(device) / gb) - -y = linears(x) -bmt.synchronize() -if config['rank'] == 0: - print('after forward', torch.cuda.memory_allocated(device) / gb) - -y.sum().backward() -bmt.synchronize() -if config['rank'] == 0: - print('after backward', torch.cuda.memory_allocated(device) / gb) diff --git a/example/models/gpt.py b/example/models/gpt.py index 78d77a7d..64474ba8 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py @@ -1,6 +1,7 @@ import torch import bmtrain as bmt from layers import TransformerEncoder, Layernorm, Embedding, TransformerEncoder +from bmtrain.global_var import config class GPT(bmt.DistributedModule): 
def __init__(self, @@ -13,7 +14,10 @@ def __init__(self, self.max_distance = max_distance - self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype) + if config['tp_size'] > 1: + self.word_emb = bmt.nn.ParallelEmbedding(vocab_size, dim_model, dtype=dtype) + else: + self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype) self.pos_emb = Embedding(max_distance, dim_model, dtype=dtype) self.transformers = bmt.TransformerBlockList([ @@ -42,7 +46,10 @@ def forward(self, out = self.transformers(out, mask_2d, None) out = self.layernorm(out) - logits = self.word_emb(out, projection=True) + if config['tp_size'] > 1: + logits = self.word_emb.projection(out) + else: + logits = self.word_emb(out, projection=True) bmt.inspect.record_tensor(logits, "logits") - return logits \ No newline at end of file + return logits diff --git a/example/test_attn.py b/example/test_attn.py deleted file mode 100644 index 642f2a8b..00000000 --- a/example/test_attn.py +++ /dev/null @@ -1,49 +0,0 @@ -import torch -import torch.nn.functional as F -import bmtrain as bmt -from bmtrain.global_var import config -from layers import Attention - - -gb = 1024.0 * 1024.0 * 1024.0 - -bmt.init_distributed(zero_level=3) - -linears = [] -for i in range(10), : - linears.append(bmt.CheckpointBlock(Attention( - dim_model=8192, - dim_head=128, - num_heads=64, - dropout_p=0.0, - use_flash_attn=True, - dtype=torch.half - ), - use_checkpoint=False) - ) - -linears = bmt.TransformerBlockList(linears) - -device = torch.device('cuda') -bmt.synchronize() -if config['rank'] == 0: - print('before forward', torch.cuda.memory_allocated(device) / gb) -batch_size=1 -seq_len=4096 -x = torch.randn(batch_size, seq_len, 8192, dtype=torch.float16, device=device).requires_grad_() -bmt.synchronize() -if config['rank'] == 0: - print('init input', torch.cuda.memory_allocated(device) / gb) -enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda() -mask = torch.arange(seq_len).unsqueeze(0) <= torch.arange(seq_len).unsqueeze(1) -mask = mask.unsqueeze(0).unsqueeze(0) -print(mask.shape) -y = linears(x,x,mask) -bmt.synchronize() -if config['rank'] == 0: - print('after forward', torch.cuda.memory_allocated(device) / gb) - -y.sum().backward() -bmt.synchronize() -if config['rank'] == 0: - print('after backward', torch.cuda.memory_allocated(device) / gb) diff --git a/example/test_block.py b/example/test_block.py deleted file mode 100644 index 90f57182..00000000 --- a/example/test_block.py +++ /dev/null @@ -1,72 +0,0 @@ -import torch -import torch.nn.functional as F -import bmtrain as bmt -from bmtrain.global_var import config -from layers import TransformerEncoder - - -gb = 1024.0 * 1024.0 * 1024.0 -def reserved(device): - return torch.cuda.memory_reserved(device) / gb -def allocated(device): - return torch.cuda.memory_allocated(device) / gb -def max_allocated(device): - return torch.cuda.max_memory_allocated(device) / gb - -bmt.init_distributed(zero_level=3) - -linears = [] -for i in range(10), : - linears.append(TransformerEncoder( - dim_model=8192, - dim_head=128, - num_heads=64, - dim_ff=20480, - bias=False, - dtype=torch.half - ) - ) - -linears = bmt.TransformerBlockList(linears) -# linears = torch.nn.ModuleList(linears) - -optimizer = bmt.optim.AdamOffloadOptimizer(linears.parameters(), weight_decay=1e-2) -lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0) - -optim_manager = bmt.optim.OptimManager(loss_scale=2**20) -optim_manager.add_optimizer(optimizer, lr_scheduler) - -bmt.synchronize() - 
-device = torch.device('cuda') -bmt.synchronize() -if config['rank'] == 0: - print('before init input', allocated(device), reserved(device)) -batch_size=1 -seq_len=4096 - -for i in range(4): - x = torch.randn(batch_size, seq_len, 8192, dtype=torch.float16, device=device).requires_grad_() - bmt.synchronize() - if config['rank'] == 0: - print('init input', allocated(device), reserved(device)) - enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda() - mask = torch.arange(seq_len).unsqueeze(0) <= torch.arange(seq_len).unsqueeze(1) - mask = mask.unsqueeze(0).unsqueeze(0).to(device) -#y = linears(x,mask) - y = x - for encoder in linears: - y = encoder(y, mask) - bmt.synchronize() - if config['rank'] == 0: - print('after forward', allocated(device), reserved(device), max_allocated(device)) - - y.sum().backward() - bmt.synchronize() - if config['rank'] == 0: - print('after backward', allocated(device), reserved(device), max_allocated(device)) - optim_manager.step() - if config['rank'] == 0: - print('after optimizer', allocated(device), reserved(device)) -#torch.cuda.empty_cache() - optim_manager.zero_grad() diff --git a/example/train.py b/example/train.py index 1a744e20..8aaf65e4 100644 --- a/example/train.py +++ b/example/train.py @@ -3,12 +3,13 @@ from models import GPT import time from bmtrain import optim +from bmtrain.global_var import config from bmtrain import inspect def main(): bmt.init_distributed( seed=0, - zero_level=2, + tp_size=2, ) model = GPT( @@ -24,7 +25,6 @@ def main(): ) bmt.init_parameters(model) - # print_inspect(model, "*") bmt.print_rank("Model memory") bmt.print_rank(torch.cuda.memory_summary()) @@ -52,7 +52,11 @@ def main(): if i == bmt.rank(): break - loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) + if config['tp_size'] > 1: + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) + else: + loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) + optimizer = optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2) lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0) @@ -77,7 +81,10 @@ def main(): ) batch, seq_len, vocab_out_size = logits.size() - loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) + if config['tp_size'] > 1: + loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets) + else: + loss = loss_func(logits.float().view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) global_loss = bmt.sum_loss(loss).item() From a7270e30ba5f850bba38f45486a1a85bd7cff4c9 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 29 Aug 2023 11:37:43 +0800 Subject: [PATCH 114/122] format --- bmtrain/nn/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bmtrain/nn/__init__.py b/bmtrain/nn/__init__.py index b5ceb80d..e22d8c55 100644 --- a/bmtrain/nn/__init__.py +++ b/bmtrain/nn/__init__.py @@ -3,4 +3,4 @@ from .row_parallel_linear import RowParallelLinear from .parallel_embedding import ParallelEmbedding from .parallel_cross_entropy_func import parallel_cross_entropy_func -from .parallel_linear_func import OpParallelLinear +from .parallel_linear_func import OpParallelLinear \ No newline at end of file From 47905b82c1655d87b98427de44d1fbd73c603871 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 29 Aug 2023 11:38:43 +0800 Subject: [PATCH 115/122] version modify --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/setup.py b/setup.py index ad1c8905..2bbb55d8 100644 --- a/setup.py +++ b/setup.py @@ -92,7 +92,7 @@ def build_extension(self, ext): ] setup( name='bmtrain', - version='0.2.3.post3', + version='0.2.3.post2', author="Guoyang Zeng", author_email="qbjooo@qq.com", description="A toolkit for training big models", From 568b02a891ed0c86293ace68d0967a115fc5f236 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 29 Aug 2023 11:43:32 +0800 Subject: [PATCH 116/122] reformat code --- bmtrain/block_layer.py | 1 - bmtrain/hook_func.py | 130 +---------------------------------------- bmtrain/offload.py | 117 +++++++++++++++++++++++++++++++++++++ bmtrain/utils.py | 9 ++- 4 files changed, 127 insertions(+), 130 deletions(-) create mode 100644 bmtrain/offload.py diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 72d921a0..8a1b4086 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -557,7 +557,6 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) self._modules[str(i)] = module module._idx = i self.add_module(str(i), module) - print(f"offload layer: {offload}") self._modules[str(0)]._is_first_layer = True self._modules[str(len(modules)-1)]._is_last_layer = True diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index cb9abdba..60b7cc25 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -4,134 +4,8 @@ from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations from collections import deque,OrderedDict from contextlib import contextmanager -from .utils import round_up - -class Offload_Dict: - - def __init__(self): - self._offload_dict = OrderedDict() - - def add(self, tensor): - tensor = tensor.contiguous() - tensor_id = id(tensor) - data_ptr = tensor.storage().data_ptr() - if data_ptr not in self._offload_dict: - self._offload_dict[data_ptr] = {} - self._offload_dict[data_ptr]["stor"] = tensor.storage() - self._offload_dict[data_ptr]["size"] = tensor.storage().size() - self._offload_dict[data_ptr]["dtype"] = tensor.storage().dtype - self._offload_dict[data_ptr]["tensors"] = {} - - self._offload_dict[data_ptr]["tensors"][id(tensor)] = {} - self._offload_dict[data_ptr]["tensors"][id(tensor)]["numel"] = tensor.numel() - self._offload_dict[data_ptr]["tensors"][id(tensor)]['dtype'] = tensor.dtype - self._offload_dict[data_ptr]["tensors"][id(tensor)]['offset'] = tensor.storage_offset() - self._offload_dict[data_ptr]["tensors"][id(tensor)]['tensor'] = tensor - self._offload_dict[data_ptr]["tensors"][id(tensor)]["shape"] = tensor.shape - self._device = "cuda" - return (data_ptr,tensor_id) - - def get_total(self): - fp16_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float16]) - fp32_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float32]) - return fp16_total,fp32_total - - def make_cpu_storage(self): - fp16_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float16]) - fp32_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float32]) - fp16_storage = torch.HalfStorage(fp16_total).pin_memory() - fp32_storage = torch.FloatStorage(fp32_total).pin_memory() - self.fp16_storage = fp16_storage - self.fp32_storage = fp32_storage - self.fp16_total = fp16_total - self.fp32_total = fp32_total - - def get(self, key): - data_ptr, tensor_id = key - return self._offload_dict[data_ptr]['tensors'][tensor_id]["tensor"] - - def pop_all(self): - 
self._offload_dict.clear() - - def h2d_memcpy(self): - fp16_storage_cuda = self.fp16_storage.cuda(non_blocking=True) - fp32_storage_cuda = self.fp32_storage.cuda(non_blocking=True) - for key,val in self._offload_dict.items(): - for id_val in val['tensors'].values(): - id_val['tensor'] = torch.tensor([], dtype=id_val['dtype'],device=fp16_storage_cuda.device) - if id_val['dtype'] == torch.float16: - id_val['tensor'].set_(fp16_storage_cuda, id_val['abs_offset'], id_val['shape']) - elif id_val['dtype'] == torch.float32: - id_val['tensor'].set_(fp32_storage_cuda, id_val['abs_offset'], id_val['shape']) - - def record_stream(self, stream): - for key, val in self._offload_dict.items(): - for id_val in val['tensors'].values(): - id_val['tensor'].record_stream(stream) - - def d2h_memcpy(self): - fp16_offset = 0 - fp32_offset = 0 - fp16_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float16]) - fp32_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float32]) - assert fp16_total <= self.fp16_total - assert fp32_total <= self.fp32_total - fp16_storage = self.fp16_storage - fp32_storage = self.fp32_storage - for key,val in self._offload_dict.items(): - assert val['dtype'] in [torch.float16, torch.float32] - storage = fp16_storage if val['dtype'] == torch.float16 else fp32_storage - offset = fp16_offset if val['dtype'] == torch.float16 else fp32_offset - for id_val in val['tensors'].values(): - cpu_tensor = torch.tensor([], dtype=id_val['dtype'], device="cpu") \ - .set_(storage, offset+id_val['offset'], id_val['shape']) - id_val["abs_offset"] = offset+id_val['offset'] - id_val['tensor'] = cpu_tensor.copy_(id_val['tensor'], non_blocking=True) - if val['dtype'] == torch.float16: - fp16_offset += val['size'] - else: - fp32_offset += val['size'] - val['stor'] = None - -def find_pre_module_helper(m): - if len(m) == 0: - return None - if m._mode == "OFFLOAD": - return m - else: - return find_pre_module_helper(m.pre_module()) - -def offload_wrapper(offload_dict): - def pack_hook(tensor): - if isinstance(tensor, torch.nn.Parameter): - return (tensor,) - elif tensor.dtype not in [torch.float16]: - return (tensor,) - else: - key = offload_dict.add(tensor) - return (tensor.device, key) - def unpack_hook(packed): - if len(packed) == 2: - device, key = packed - tensor = offload_dict.get(key) - assert tensor.device == device - return tensor - else: - tensor, = packed - return tensor - return pack_hook, unpack_hook - -def offload_pre_hook(module, input): - if hasattr(module, "_offload_hook"): - pack_hook, unpack_hook = module._offload_hook - torch._C._autograd._push_saved_tensors_default_hooks( - pack_hook, unpack_hook - ) - -def offload_post_hook(module, input, output): - if hasattr(module, "_offload_hook"): - torch._C._autograd._pop_saved_tensors_default_hooks() - +from .utils import round_up, find_pre_module_helper +from .offload import Offload_Dict, offload_wrapper, offload_pre_hook, offload_post_hook def zero_pre_forward(module, inputs): enter = True pipe = False diff --git a/bmtrain/offload.py b/bmtrain/offload.py new file mode 100644 index 00000000..8095a66e --- /dev/null +++ b/bmtrain/offload.py @@ -0,0 +1,117 @@ +class Offload_Dict: + + def __init__(self): + self._offload_dict = OrderedDict() + + def add(self, tensor): + tensor = tensor.contiguous() + tensor_id = id(tensor) + data_ptr = tensor.storage().data_ptr() + if data_ptr not in self._offload_dict: + self._offload_dict[data_ptr] = {} + self._offload_dict[data_ptr]["stor"] = 
tensor.storage() + self._offload_dict[data_ptr]["size"] = tensor.storage().size() + self._offload_dict[data_ptr]["dtype"] = tensor.storage().dtype + self._offload_dict[data_ptr]["tensors"] = {} + + self._offload_dict[data_ptr]["tensors"][id(tensor)] = {} + self._offload_dict[data_ptr]["tensors"][id(tensor)]["numel"] = tensor.numel() + self._offload_dict[data_ptr]["tensors"][id(tensor)]['dtype'] = tensor.dtype + self._offload_dict[data_ptr]["tensors"][id(tensor)]['offset'] = tensor.storage_offset() + self._offload_dict[data_ptr]["tensors"][id(tensor)]['tensor'] = tensor + self._offload_dict[data_ptr]["tensors"][id(tensor)]["shape"] = tensor.shape + self._device = "cuda" + return (data_ptr,tensor_id) + + def get_total(self): + fp16_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float16]) + fp32_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float32]) + return fp16_total,fp32_total + + def make_cpu_storage(self): + fp16_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float16]) + fp32_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float32]) + fp16_storage = torch.HalfStorage(fp16_total).pin_memory() + fp32_storage = torch.FloatStorage(fp32_total).pin_memory() + self.fp16_storage = fp16_storage + self.fp32_storage = fp32_storage + self.fp16_total = fp16_total + self.fp32_total = fp32_total + + def get(self, key): + data_ptr, tensor_id = key + return self._offload_dict[data_ptr]['tensors'][tensor_id]["tensor"] + + def pop_all(self): + self._offload_dict.clear() + + def h2d_memcpy(self): + fp16_storage_cuda = self.fp16_storage.cuda(non_blocking=True) + fp32_storage_cuda = self.fp32_storage.cuda(non_blocking=True) + for key,val in self._offload_dict.items(): + for id_val in val['tensors'].values(): + id_val['tensor'] = torch.tensor([], dtype=id_val['dtype'],device=fp16_storage_cuda.device) + if id_val['dtype'] == torch.float16: + id_val['tensor'].set_(fp16_storage_cuda, id_val['abs_offset'], id_val['shape']) + elif id_val['dtype'] == torch.float32: + id_val['tensor'].set_(fp32_storage_cuda, id_val['abs_offset'], id_val['shape']) + + def record_stream(self, stream): + for key, val in self._offload_dict.items(): + for id_val in val['tensors'].values(): + id_val['tensor'].record_stream(stream) + + def d2h_memcpy(self): + fp16_offset = 0 + fp32_offset = 0 + fp16_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float16]) + fp32_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float32]) + assert fp16_total <= self.fp16_total + assert fp32_total <= self.fp32_total + fp16_storage = self.fp16_storage + fp32_storage = self.fp32_storage + for key,val in self._offload_dict.items(): + assert val['dtype'] in [torch.float16, torch.float32] + storage = fp16_storage if val['dtype'] == torch.float16 else fp32_storage + offset = fp16_offset if val['dtype'] == torch.float16 else fp32_offset + for id_val in val['tensors'].values(): + cpu_tensor = torch.tensor([], dtype=id_val['dtype'], device="cpu") \ + .set_(storage, offset+id_val['offset'], id_val['shape']) + id_val["abs_offset"] = offset+id_val['offset'] + id_val['tensor'] = cpu_tensor.copy_(id_val['tensor'], non_blocking=True) + if val['dtype'] == torch.float16: + fp16_offset += val['size'] + else: + fp32_offset += val['size'] + val['stor'] = None + +def offload_wrapper(offload_dict): + def pack_hook(tensor): + if isinstance(tensor, torch.nn.Parameter): + 
return (tensor,) + elif tensor.dtype not in [torch.float16]: + return (tensor,) + else: + key = offload_dict.add(tensor) + return (tensor.device, key) + def unpack_hook(packed): + if len(packed) == 2: + device, key = packed + tensor = offload_dict.get(key) + assert tensor.device == device + return tensor + else: + tensor, = packed + return tensor + return pack_hook, unpack_hook + +def offload_pre_hook(module, input): + if hasattr(module, "_offload_hook"): + pack_hook, unpack_hook = module._offload_hook + torch._C._autograd._push_saved_tensors_default_hooks( + pack_hook, unpack_hook + ) + +def offload_post_hook(module, input, output): + if hasattr(module, "_offload_hook"): + torch._C._autograd._pop_saved_tensors_default_hooks() \ No newline at end of file diff --git a/bmtrain/utils.py b/bmtrain/utils.py index 8cb87808..57249e67 100644 --- a/bmtrain/utils.py +++ b/bmtrain/utils.py @@ -32,7 +32,14 @@ def load_nccl_pypi(): if file_split[-1] == "so" or (len(file_split)>1 and file_split[-2] == "so"): ctypes.CDLL(os.path.join(path, file_so)) - +def find_pre_module_helper(m): + if len(m) == 0: + return None + if m._mode == "OFFLOAD": + return m + else: + return find_pre_module_helper(m.pre_module()) + def round_up(x, d): return (x + d - 1) // d * d From b249adcd6fcc911a4139423d70a0462dfb624ef9 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 29 Aug 2023 12:56:17 +0800 Subject: [PATCH 117/122] fix pre module --- bmtrain/block_layer.py | 12 ++++++++---- bmtrain/utils.py | 4 ++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 8a1b4086..29b588f8 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -217,12 +217,16 @@ def set_pre_module(self, pre_module): pre_module._next_module.append(self) def pre_module(self): - assert len(self._pre_module) == self._ref_count, "{} != {}".format(len(self._pre_module), self._ref_count) - return self._pre_module[self._ref_count-1] + if len(self._pre_module) > 0: + return self._pre_module[self._ref_count-1] + else: + return None def next_module(self): - assert len(self._next_module) == self._ref_count, "{} != {}".format(len(self._next_module), self._ref_count) - return self._next_module[self._ref_count-1] + if len(self._next_module) > 0: + return self._next_module[self._ref_count-1] + else: + return None def backward_release(self, flag): if self._ref_count == 1 and self._backward_block_ctx is not None: diff --git a/bmtrain/utils.py b/bmtrain/utils.py index 57249e67..8ca560c9 100644 --- a/bmtrain/utils.py +++ b/bmtrain/utils.py @@ -33,8 +33,8 @@ def load_nccl_pypi(): ctypes.CDLL(os.path.join(path, file_so)) def find_pre_module_helper(m): - if len(m) == 0: - return None + if m is None: + return m if m._mode == "OFFLOAD": return m else: From a1b8eee21e93ce80fd2a18df857164e8dfb47fad Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 29 Aug 2023 12:59:45 +0800 Subject: [PATCH 118/122] modify comment --- bmtrain/block_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 29b588f8..cd11aa48 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -78,7 +78,7 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, use_offl self._storage_params : Dict[str, torch.nn.Parameter] = {} self._storage_info = {} self._ready = False - # sort parameters by nam_next_modulee + # sort parameters by name ordered_parameters = list(self._module.named_parameters()) 
assert not (use_checkpoint and use_offload), "It does not make sense to use offload and checkpointing at the same time" # calc total number of parameters From 92b863094604a98b1af2751efd43bcad3a2401dd Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 29 Aug 2023 13:10:19 +0800 Subject: [PATCH 119/122] dont expose use offload interface outside --- bmtrain/block_layer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index cd11aa48..47562977 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -66,7 +66,7 @@ class CheckpointBlock(torch.nn.Module): >>> y2, ... = transformer_block(x) >>> assert torch.allclose(y1, y2) """ - def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, use_offload=False, offload_level=0, zero_level=3): + def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, offload_level=0, zero_level=3): super().__init__() self._module = inner_module self._inputs = None @@ -80,6 +80,7 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, use_offl self._ready = False # sort parameters by name ordered_parameters = list(self._module.named_parameters()) + use_offload = offload_level in [1,2] assert not (use_checkpoint and use_offload), "It does not make sense to use offload and checkpointing at the same time" # calc total number of parameters for name, param in ordered_parameters: @@ -202,7 +203,7 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, use_offl self._pre_module = [] #save the pre module of self self._ref_count = 0 #incremental in forward and decreasing in backward self._mode = "BLOCK" #BLOCK or ZERO or PIPE - if use_offload and offload_level != 0: + if use_offload: self._mode = "OFFLOAD" self._on_device = False self.offload_level = offload_level From 1fac581c5bd9632bc16600c3a8f13aa092f3e9d2 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 29 Aug 2023 15:58:09 +0800 Subject: [PATCH 120/122] print tools --- bmtrain/block_layer.py | 4 ++-- bmtrain/utils.py | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 47562977..bb13d281 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -203,11 +203,11 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, offload_ self._pre_module = [] #save the pre module of self self._ref_count = 0 #incremental in forward and decreasing in backward self._mode = "BLOCK" #BLOCK or ZERO or PIPE + self.offload_level = offload_level if use_offload: self._mode = "OFFLOAD" self._on_device = False - self.offload_level = offload_level - + self.all_input_no_grad = False self.all_param_no_grad = False self._zero_level = zero_level diff --git a/bmtrain/utils.py b/bmtrain/utils.py index 8ca560c9..0d72106b 100644 --- a/bmtrain/utils.py +++ b/bmtrain/utils.py @@ -87,6 +87,17 @@ def print_rank(*args, rank=0, **kwargs): if config["rank"] == rank: print(*args, **kwargs) +def print_strategy(model): + print_rank(" "*24+"|"+" Offload Level |" + " ZeRO Level |"+" Activation Recompute |") + for idx,ckpt in enumerate(model): + print_rank(f"CheckpointBlock Layer {idx} |{ckpt.offload_level:^14} | {ckpt._zero_level:^10} | {ckpt.use_checkpoint.__repr__():^20} |") + +def print_inspect(model): + model_inspect = bmt.inspect.inspect_model(model, "*") + print_rank(bmt.inspect.format_summary(model_inspect)) + + + def see_memory(message, detail=False): """ Outputs 
a message followed by GPU memory status summary on rank 0. From f66c162364daf8706965b1fcaacd53d49b6f7377 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 29 Aug 2023 16:40:03 +0800 Subject: [PATCH 121/122] high priority for offload stream --- bmtrain/block_layer.py | 1 - bmtrain/init.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index bb13d281..d90a0961 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -560,7 +560,6 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) module.calc_event = torch.cuda.Event() module.offload_event = torch.cuda.Event() self._modules[str(i)] = module - module._idx = i self.add_module(str(i), module) self._modules[str(0)]._is_first_layer = True self._modules[str(len(modules)-1)]._is_last_layer = True diff --git a/bmtrain/init.py b/bmtrain/init.py index d5640cc9..e002ae9c 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -72,9 +72,9 @@ def init_distributed( config["rank"] = rank config["world_size"] = world_size config["calc_stream"] = torch.cuda.current_stream() + config["offload_stream"] = torch.cuda.Stream(priority=-1) config["load_stream"] = torch.cuda.Stream(priority=-1) config["tp_comm_stream"] = torch.cuda.Stream(priority=-1) - config["offload_stream"] = torch.cuda.Stream() config["pp_comm_stream"] = torch.cuda.Stream(priority=-1) config['barrier_stream'] = torch.cuda.Stream() config["load_event"] = torch.cuda.Event() From aef48994af46b5780e076545f81244457546e4e0 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 29 Aug 2023 19:17:15 +0800 Subject: [PATCH 122/122] fix import --- bmtrain/hook_func.py | 1 - bmtrain/offload.py | 4 ++++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 60b7cc25..b74c3149 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -2,7 +2,6 @@ from .global_var import config from .checkpointing import CheckpointBlockContext from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations -from collections import deque,OrderedDict from contextlib import contextmanager from .utils import round_up, find_pre_module_helper from .offload import Offload_Dict, offload_wrapper, offload_pre_hook, offload_post_hook diff --git a/bmtrain/offload.py b/bmtrain/offload.py index 8095a66e..d589d663 100644 --- a/bmtrain/offload.py +++ b/bmtrain/offload.py @@ -1,3 +1,6 @@ +import torch +from collections import OrderedDict + class Offload_Dict: def __init__(self): @@ -85,6 +88,7 @@ def d2h_memcpy(self): fp32_offset += val['size'] val['stor'] = None + def offload_wrapper(offload_dict): def pack_hook(tensor): if isinstance(tensor, torch.nn.Parameter):
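
The offload patches above hang activation offloading on PyTorch saved-tensor hooks: pack_hook intercepts half-precision activations as autograd saves them for backward so they can be copied into pinned CPU storage, and unpack_hook hands the re-uploaded tensors back when backward needs them. Below is a minimal sketch of that mechanism only, written against the public torch.autograd.graph.saved_tensors_hooks API rather than the internal _push_saved_tensors_default_hooks helpers used here, and with the copies done inline instead of batched through Offload_Dict; the name make_cpu_offload_hooks is illustrative and not part of BMTrain.

import torch

def make_cpu_offload_hooks():
    # pack_hook runs when autograd saves a tensor for the backward pass.
    def pack_hook(tensor):
        if isinstance(tensor, torch.nn.Parameter):
            return (tensor,)                                   # keep parameters resident on GPU
        cpu = torch.empty(tensor.size(), dtype=tensor.dtype, pin_memory=True)
        cpu.copy_(tensor, non_blocking=True)                   # async D2H copy of the activation
        return (tensor.device, cpu)

    # unpack_hook runs when backward actually consumes the saved tensor.
    def unpack_hook(packed):
        if len(packed) == 1:
            return packed[0]                                   # parameter: returned as-is
        device, cpu = packed
        return cpu.to(device, non_blocking=True)               # H2D copy back for gradient computation
    return pack_hook, unpack_hook

layer = torch.nn.Linear(1024, 1024).cuda().half()
x = torch.randn(8, 1024, device="cuda", dtype=torch.float16, requires_grad=True)

pack, unpack = make_cpu_offload_hooks()
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = layer(x)                                               # saved activations land in pinned CPU memory
y.float().sum().backward()                                     # activations are brought back on demand

In the actual implementation the copies are grouped per CheckpointBlock through Offload_Dict (d2h_memcpy / h2d_memcpy) and issued on the dedicated high-priority "offload_stream", presumably so the D2H/H2D traffic overlaps with compute instead of serializing on the default stream.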