diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py
index aa781c4d..5269673b 100644
--- a/bmtrain/block_layer.py
+++ b/bmtrain/block_layer.py
@@ -680,29 +680,30 @@ def __repr__(self):
 
 class OpTransformerBlockList(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, hidden_state, *args):
+    def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, num_hidden, *args):
         tensors = []
         others = []
-        for arg in args:
+        for arg in args[num_hidden:]:
             if torch.is_tensor(arg):
                 tensors.append(arg)
                 others.append(None)
             else:
                 tensors.append(None)
                 others.append(arg)
+        hidden_states = args[:num_hidden]
 
         ctx.nontensor_inputs = others
         ctx.self = self
         ctx.save_list = copy.deepcopy(save_list)
         ctx.num_save_needed = save_list[-1][1]+1
-        ctx.layers_dict=[{} for _ in range(len(self))]
+        ctx.layers_dict = [{} for _ in range(len(self))]
         layer_inputs = []
         layer_inspector = []
         cuda_rng_state = []
         for i in range(len(self)):
             with torch.no_grad():
                 if save_list[i][0] == i:
-                    layer_inputs.append(hidden_state.detach())
+                    layer_inputs += [hidden_state.detach() for hidden_state in hidden_states]
                 cuda_rng_state.append( torch.cuda.get_rng_state() )
                 if config['zero_level']==2:
                     flag = 1
@@ -713,7 +714,9 @@ def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, hidden_s
             block_ctx.enter()
             # call inner module directly
             with ScopedTensorInspectorContext() as inspector:
-                hidden_state = self._modules[str(i)]._module._call_impl(hidden_state, *args)
+                hidden_states = self._modules[str(i)]._module._call_impl(*hidden_states, *args[num_hidden:])
+                if not isinstance(hidden_states, tuple):
+                    hidden_states = (hidden_states,)
             block_ctx.exit()
             for it in inspector.hidden_states:
                 debug.append("_inspect_hidden_states", it)
@@ -721,6 +724,7 @@ def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, hidden_s
 
         ctx.layer_inspector = layer_inspector
         ctx.cuda_rng_state = cuda_rng_state
+        ctx.num_hidden = num_hidden
 
         ctx.save_for_backward(*layer_inputs, *tensors)
 
@@ -728,14 +732,20 @@ def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, hidden_s
             middle_hiddens = layer_inputs
             for mid in middle_hiddens:
                 mid.requires_grad_()
-            middle_hiddens = torch.stack(middle_hiddens, dim=0)
+            middle_hiddens = [
+                torch.stack(middle_hiddens[i::num_hidden], dim=0)
+                for i in range(num_hidden)
+            ]
         else:
-            middle_hiddens = None
-        return tuple([hidden_state, middle_hiddens] + [it["tensor"] for inspector_hiddens in ctx.layer_inspector for it in inspector_hiddens])
+            middle_hiddens = [None] * num_hidden
+        return tuple(list(hidden_states) + middle_hiddens + [it["tensor"] for inspector_hiddens in ctx.layer_inspector for it in inspector_hiddens])
 
 
     @staticmethod
-    def backward(ctx, grad_hidden_state : torch.Tensor, grad_middle: List[torch.Tensor], *grad_inspectors):
+    def backward(ctx, *grads):
+        grad_hidden_states = grads[:ctx.num_hidden]
+        grad_middles = grads[ctx.num_hidden:2*ctx.num_hidden]
+        grad_inspectors = grads[2*ctx.num_hidden:]
         def exit_prev(prev_ctx, prev_grad):
             if prev_ctx is not None:
                 if prev_grad:
@@ -755,8 +765,8 @@ def exit_prev(prev_ctx, prev_grad):
         all_inputs = []
         input_requires_grad = []
 
-        layer_inputs = ctx.saved_tensors[:ctx.num_save_needed]
-        save_args = ctx.saved_tensors[ctx.num_save_needed:]
+        layer_inputs = ctx.saved_tensors[:ctx.num_save_needed * ctx.num_hidden]
+        save_args = ctx.saved_tensors[ctx.num_save_needed * ctx.num_hidden:]
         for tensor, other in zip(save_args, ctx.nontensor_inputs):
             if tensor is None:
                 all_inputs.append(other)
@@ -786,14 +796,23 @@ def exit_prev(prev_ctx, prev_grad):
                     block_ctx = CheckpointBlockContext(ctx.self._modules[str(j)], ctx.layers_dict[j], flag)
                     block_ctx.enter()
                     exit_prev(prev_ctx, prev_grad)
-                    output = ctx.self._modules[str(j)]._module._call_impl(layer_inputs[ctx.save_list[j][1]], *all_inputs)
+                    outputs = ctx.self._modules[str(j)]._module._call_impl(
+                        layer_inputs[ctx.save_list[j][1]*ctx.num_hidden: ctx.save_list[j][1]*ctx.num_hidden+ctx.num_hidden],
+                        *all_inputs
+                    )
+                    if not isinstance(outputs, tuple):
+                        outputs = (outputs,)
                     prev_ctx = block_ctx
                     prev_grad = False
-                    layer_inputs[ctx.save_list[j+1][1]].copy_(output)
+                    for k, output in enumerate(outputs):
+                        layer_inputs[ctx.save_list[j+1][1]*ctx.num_hidden + k].copy_(output)
                     ctx.save_list[j+1][0] = j+1
 
                 torch.cuda.set_rng_state(ctx.cuda_rng_state[i])
-                ipt = layer_inputs[ctx.save_list[i][1]].detach().requires_grad_()
+                ipts = [
+                    layer_inputs[ctx.save_list[i][1]*ctx.num_hidden + k].detach().requires_grad_()
+                    for k in range(ctx.num_hidden)
+                ]
                 if config['zero_level'] == 2:
                     flag = 2
                 else:
@@ -805,7 +824,9 @@ def exit_prev(prev_ctx, prev_grad):
                 prev_grad = True
 
                 with ScopedTensorInspectorContext() as inspector:
-                    output = ctx.self._modules[str(i)]._module._call_impl(ipt, *all_inputs)
+                    outputs = ctx.self._modules[str(i)]._module._call_impl(*ipts, *all_inputs)
+                    if not isinstance(outputs, tuple):
+                        outputs = (outputs,)
 
                 assert len(ctx.layer_inspector[i]) == len(inspector.hidden_states), "Backward step changed"
                 for j, it in enumerate(inspector.hidden_states):
@@ -818,18 +839,20 @@ def exit_prev(prev_ctx, prev_grad):
                     ctx.layer_inspector[i][j]["requires_grad"] = it["requires_grad"]
                 if len(inspector.hidden_states) > 0:
                     torch.autograd.backward(
-                        [output] + [hidden_state["tensor"] for hidden_state in inspector.hidden_states],
-                        (grad_hidden_state,) + grad_inspectors[-len(inspector.hidden_states):],
+                        list(outputs) + [hidden_state["tensor"] for hidden_state in inspector.hidden_states],
+                        grad_hidden_states + grad_inspectors[-len(inspector.hidden_states):],
                     )
                     grad_inspectors = grad_inspectors[:-len(inspector.hidden_states)]
                 else:
                     torch.autograd.backward(
-                        [output],
-                        (grad_hidden_state,),
+                        outputs,
+                        grad_hidden_states,
                     )
-                grad_hidden_state = ipt.grad
-                if grad_middle is not None:
-                    grad_hidden_state = grad_hidden_state + grad_middle[i]
+                grad_hidden_states = [ipt.grad for ipt in ipts]
+                for k in range(ctx.num_hidden):
+                    if grad_middles[k] is not None:
+                        grad_hidden_states[k] = grad_hidden_states[k] + grad_middles[k][i]
+                grad_hidden_states = tuple(grad_hidden_states)
 
         exit_prev(prev_ctx, prev_grad)
 
@@ -839,7 +862,7 @@ def exit_prev(prev_ctx, prev_grad):
                 grads.append(inp.grad)
             else:
                 grads.append(None)
-        return (None, None, None, grad_hidden_state) + tuple(grads)
+        return (None, None, None, None) + tuple(grad_hidden_states) + tuple(grads)
 
 class TransformerBlockList(torch.nn.Module):
     r"""
@@ -862,7 +885,7 @@ class TransformerBlockList(torch.nn.Module):
     """
     _modules: Dict[str, CheckpointBlock]
 
-    def __init__(self, modules: Iterable[CheckpointBlock], sqrt=False) -> None:
+    def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) -> None:
         super().__init__()
 
         self._modules = {}
@@ -872,6 +895,8 @@ def __init__(self, modules: Iterable[CheckpointBlock], sqrt=False) -> None:
             self._modules[str(i)] = module
             self.add_module(str(i), module)
 
+        self.num_hidden = num_hidden
+
         if sqrt:
             length = len(self)
             num_save_needed = 0
@@ -901,12 +926,11 @@ def __iter__(self) -> Iterator[CheckpointBlock]:
     def __getitem__(self, index: Union[int, str]) -> CheckpointBlock:
         return self._modules[str(index)]
 
-    def forward(self, hidden_state, *args, return_hidden_states = False):
+    def forward(self, *args, return_hidden_states = False):
         self.return_hidden_states = return_hidden_states
         placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled())
-        outputs = OpTransformerBlockList.apply(placeholder, self, self.save_list, hidden_state, *args)
-        last_hidden, middle_hiddens = outputs[:2]
+        outputs = OpTransformerBlockList.apply(placeholder, self, self.save_list, self.num_hidden, *args)
         if return_hidden_states:
-            return last_hidden, middle_hiddens
+            return tuple(outputs[:2*self.num_hidden])
         else:
-            return last_hidden
+            return tuple(outputs[:self.num_hidden]) if self.num_hidden > 1 else outputs[0]
\ No newline at end of file
diff --git a/tests/test_all.py b/tests/test_all.py
index 957403b7..b614d3eb 100644
--- a/tests/test_all.py
+++ b/tests/test_all.py
@@ -15,6 +15,7 @@
 
     ("dropout", 1),
     ("loss_func", 1),
+    ("multi_return", 2),
    ("middle_hidden", 4),
    ("other_hidden", 4),
    ("inspector_hidden", 2),
diff --git a/tests/test_multi_return.py b/tests/test_multi_return.py
new file mode 100644
index 00000000..8c408a1f
--- /dev/null
+++ b/tests/test_multi_return.py
@@ -0,0 +1,126 @@
+from utils import *
+
+import bmtrain as bmt
+import torch
+import random
+from bmtrain import config
+from bmtrain.block_layer import CheckpointBlock, TransformerBlockList
+from bmtrain.pipe_layer import PipelineTransformerBlockList
+import torch.nn.functional as F
+
+class MultiInputReturn(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, a, b, c, d, e):
+        return a*2, b+d, c*4+e*5
+
+class Model_ZERO(torch.nn.Module):
+    def __init__(self, ms) -> None:
+        super().__init__()
+        self.ms = TransformerBlockList([
+            CheckpointBlock(m)
+            for m in ms
+        ], num_hidden=3)
+
+    def forward(self, x):
+        y = self.ms(*x)
+        return y
+
+class Model_PIPE(torch.nn.Module):
+    def __init__(self, ms) -> None:
+        super().__init__()
+        self.ms = PipelineTransformerBlockList([
+            CheckpointBlock(m)
+            for m in ms
+        ], num_hidden=3)
+
+    def forward(self, x):
+        y = self.ms(*x)
+        return y
+
+class Model_BLOCK(torch.nn.Module):
+    def __init__(self, ms) -> None:
+        super().__init__()
+        self.ms = torch.nn.ModuleList([
+            CheckpointBlock(m)
+            for m in ms
+        ])
+
+    def forward(self, x):
+        y = x[:3]
+        other = x[3:]
+        for m in self.ms:
+            y = m(*y, *other)
+        return y
+
+class Model_NORMAL(torch.nn.Module):
+    def __init__(self, ms) -> None:
+        super().__init__()
+        self.ms = torch.nn.ModuleList(ms)
+
+    def forward(self, x):
+        y = x[:3]
+        other = x[3:]
+        for m in self.ms:
+            y = m(*y, *other)
+        return y
+
+def manual_seed(seed=33):
+    torch.manual_seed(seed)
+    random.seed(seed)
+    try:
+        import numpy as np
+        np.random.seed(seed)
+    except ModuleNotFoundError:
+        pass
+
+def run(name, cls, num_layer=4, dim=4096):
+    manual_seed()
+
+    ms = [MultiInputReturn() for i in range(num_layer)]
+
+    inps = (
+        torch.randn((dim,)).cuda(),
+        torch.randn((dim,)).cuda(),
+        torch.randn((dim,)).cuda(),
+        torch.randn((dim,)).cuda(),
+        torch.randn((dim,)).cuda(),
+    )
+    last_weights = (
+        torch.randn((dim,)).cuda(),
+        torch.randn((dim,)).cuda(),
+        torch.randn((dim,)).cuda(),
+    )
+
+    for inp in inps:
+        inp.requires_grad_(True)
+    m = cls(ms)
+
+    ret = ""
+    logits = m(inps)
+    loss = (logits[0]*last_weights[0] + logits[1]*last_weights[1] + logits[2]*last_weights[2]).sum()
+    loss.backward()
+    return list(logits) + [
+        inp.grad
+        for inp in inps
+    ]
+
+def test_main():
+    ret = {}
+    ret["normal"] = run("normal", Model_NORMAL)
+    ret["block"] = run("block", Model_BLOCK)
+    ret["zero"] = run("zero", Model_ZERO)
+    # ret["pipe"] = run("pipe", Model_PIPE) # TODO pipeline not support multiple input-output yet
+    for k, r in ret.items():
+        bmt.print_rank(f"============={k}============")
+        bmt.print_rank(r)
+    for r in ret.values():
+        for r2 in ret.values():
+            for i in range(len(r)):
+                assert_lt((r[i]-r2[i]).abs().max(), 1e-5)
+
+if __name__ == "__main__":
+    bmt.init_distributed(pipe_size=2)
+
+    test_main()