support multiple input-output in transformerblocklist
Achazwl committed May 4, 2023
1 parent d531727 commit 7084623
Showing 3 changed files with 180 additions and 29 deletions.
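
In brief: TransformerBlockList and its autograd function OpTransformerBlockList now take a num_hidden argument, so each block can consume and return several hidden states instead of exactly one; any positional arguments after the first num_hidden are passed through unchanged to every block. The sketch below distills the intended usage from tests/test_multi_return.py added in this commit; the toy module, shapes, and the distributed launch are illustrative assumptions, not part of the diff.

import torch
import bmtrain as bmt
from bmtrain.block_layer import CheckpointBlock, TransformerBlockList

class TwoStreamBlock(torch.nn.Module):
    # toy block: takes two hidden states plus one pass-through tensor,
    # returns a tuple of two hidden states (length must equal num_hidden)
    def forward(self, a, b, bias):
        return a * 2 + bias, a + b

bmt.init_distributed()  # needs a multi-process launch, e.g. one process per GPU

blocks = TransformerBlockList(
    [CheckpointBlock(TwoStreamBlock()) for _ in range(4)],
    num_hidden=2,  # new in this commit: number of hidden states per block
)

a = torch.randn(8, 16).cuda().requires_grad_()
b = torch.randn(8, 16).cuda().requires_grad_()
bias = torch.randn(16).cuda()

# the first num_hidden positional args are hidden states, the rest are shared extras
out_a, out_b = blocks(a, b, bias)
(out_a.sum() + out_b.sum()).backward()  # gradients flow back to a and b
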
82 changes: 53 additions & 29 deletions bmtrain/block_layer.py
@@ -680,29 +680,30 @@ def __repr__(self):

class OpTransformerBlockList(torch.autograd.Function):
@staticmethod
def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, hidden_state, *args):
def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, num_hidden, *args):
tensors = []
others = []
for arg in args:
for arg in args[num_hidden:]:
if torch.is_tensor(arg):
tensors.append(arg)
others.append(None)
else:
tensors.append(None)
others.append(arg)
hidden_states = args[:num_hidden]

ctx.nontensor_inputs = others
ctx.self = self
ctx.save_list = copy.deepcopy(save_list)
ctx.num_save_needed = save_list[-1][1]+1
ctx.layers_dict=[{} for _ in range(len(self))]
ctx.layers_dict = [{} for _ in range(len(self))]
layer_inputs = []
layer_inspector = []
cuda_rng_state = []
for i in range(len(self)):
with torch.no_grad():
if save_list[i][0] == i:
layer_inputs.append(hidden_state.detach())
layer_inputs += [hidden_state.detach() for hidden_state in hidden_states]
cuda_rng_state.append( torch.cuda.get_rng_state() )
if config['zero_level']==2:
flag = 1
@@ -713,29 +714,38 @@ def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, hidden_s
block_ctx.enter()
# call inner module directly
with ScopedTensorInspectorContext() as inspector:
hidden_state = self._modules[str(i)]._module._call_impl(hidden_state, *args)
hidden_states = self._modules[str(i)]._module._call_impl(*hidden_states, *args[num_hidden:])
if not isinstance(hidden_states, tuple):
hidden_states = (hidden_states,)
block_ctx.exit()
for it in inspector.hidden_states:
debug.append("_inspect_hidden_states", it)
layer_inspector.append(inspector.hidden_states)

ctx.layer_inspector = layer_inspector
ctx.cuda_rng_state = cuda_rng_state
ctx.num_hidden = num_hidden

ctx.save_for_backward(*layer_inputs, *tensors)

if self.return_hidden_states:
middle_hiddens = layer_inputs
for mid in middle_hiddens:
mid.requires_grad_()
middle_hiddens = torch.stack(middle_hiddens, dim=0)
middle_hiddens = [
torch.stack(middle_hiddens[i::num_hidden], dim=0)
for i in range(num_hidden)
]
else:
middle_hiddens = None
return tuple([hidden_state, middle_hiddens] + [it["tensor"] for inspector_hiddens in ctx.layer_inspector for it in inspector_hiddens])
middle_hiddens = [None] * num_hidden
return tuple(list(hidden_states) + middle_hiddens + [it["tensor"] for inspector_hiddens in ctx.layer_inspector for it in inspector_hiddens])


@staticmethod
def backward(ctx, grad_hidden_state : torch.Tensor, grad_middle: List[torch.Tensor], *grad_inspectors):
def backward(ctx, *grads):
grad_hidden_states = grads[:ctx.num_hidden]
grad_middles = grads[ctx.num_hidden:2*ctx.num_hidden]
grad_inspectors = grads[2*ctx.num_hidden:]
def exit_prev(prev_ctx, prev_grad):
if prev_ctx is not None:
if prev_grad:
@@ -755,8 +765,8 @@ def exit_prev(prev_ctx, prev_grad):
all_inputs = []
input_requires_grad = []

layer_inputs = ctx.saved_tensors[:ctx.num_save_needed]
save_args = ctx.saved_tensors[ctx.num_save_needed:]
layer_inputs = ctx.saved_tensors[:ctx.num_save_needed * ctx.num_hidden]
save_args = ctx.saved_tensors[ctx.num_save_needed * ctx.num_hidden:]
for tensor, other in zip(save_args, ctx.nontensor_inputs):
if tensor is None:
all_inputs.append(other)
@@ -786,14 +796,23 @@ def exit_prev(prev_ctx, prev_grad):
block_ctx = CheckpointBlockContext(ctx.self._modules[str(j)], ctx.layers_dict[j], flag)
block_ctx.enter()
exit_prev(prev_ctx, prev_grad)
output = ctx.self._modules[str(j)]._module._call_impl(layer_inputs[ctx.save_list[j][1]], *all_inputs)
outputs = ctx.self._modules[str(j)]._module._call_impl(
layer_inputs[ctx.save_list[j][1]*ctx.num_hidden: ctx.save_list[j][1]*ctx.num_hidden+ctx.num_hidden],
*all_inputs
)
if not isinstance(outputs, tuple):
outputs = (outputs,)
prev_ctx = block_ctx
prev_grad = False
layer_inputs[ctx.save_list[j+1][1]].copy_(output)
for k, output in enumerate(outputs):
layer_inputs[ctx.save_list[j+1][1]*ctx.num_hidden + k].copy_(output)
ctx.save_list[j+1][0] = j+1

torch.cuda.set_rng_state(ctx.cuda_rng_state[i])
ipt = layer_inputs[ctx.save_list[i][1]].detach().requires_grad_()
ipts = [
layer_inputs[ctx.save_list[i][1]*ctx.num_hidden + k].detach().requires_grad_()
for k in range(ctx.num_hidden)
]
if config['zero_level'] == 2:
flag = 2
else:
@@ -805,7 +824,9 @@ def exit_prev(prev_ctx, prev_grad):
prev_grad = True

with ScopedTensorInspectorContext() as inspector:
output = ctx.self._modules[str(i)]._module._call_impl(ipt, *all_inputs)
outputs = ctx.self._modules[str(i)]._module._call_impl(*ipts, *all_inputs)
if not isinstance(outputs, tuple):
outputs = (outputs,)

assert len(ctx.layer_inspector[i]) == len(inspector.hidden_states), "Backward step changed"
for j, it in enumerate(inspector.hidden_states):
@@ -818,18 +839,20 @@ def exit_prev(prev_ctx, prev_grad):
ctx.layer_inspector[i][j]["requires_grad"] = it["requires_grad"]
if len(inspector.hidden_states) > 0:
torch.autograd.backward(
[output] + [hidden_state["tensor"] for hidden_state in inspector.hidden_states],
(grad_hidden_state,) + grad_inspectors[-len(inspector.hidden_states):],
list(outputs) + [hidden_state["tensor"] for hidden_state in inspector.hidden_states],
grad_hidden_states + grad_inspectors[-len(inspector.hidden_states):],
)
grad_inspectors = grad_inspectors[:-len(inspector.hidden_states)]
else:
torch.autograd.backward(
[output],
(grad_hidden_state,),
outputs,
grad_hidden_states,
)
grad_hidden_state = ipt.grad
if grad_middle is not None:
grad_hidden_state = grad_hidden_state + grad_middle[i]
grad_hidden_states = [ipt.grad for ipt in ipts]
for k in range(ctx.num_hidden):
if grad_middles[k] is not None:
grad_hidden_states[k] = grad_hidden_states[k] + grad_middles[k][i]
grad_hidden_states = tuple(grad_hidden_states)

exit_prev(prev_ctx, prev_grad)

@@ -839,7 +862,7 @@ def exit_prev(prev_ctx, prev_grad):
grads.append(inp.grad)
else:
grads.append(None)
return (None, None, None, grad_hidden_state) + tuple(grads)
return (None, None, None, None) + tuple(grad_hidden_states) + tuple(grads)

class TransformerBlockList(torch.nn.Module):
r"""
@@ -862,7 +885,7 @@ class TransformerBlockList(torch.nn.Module):
"""
_modules: Dict[str, CheckpointBlock]

def __init__(self, modules: Iterable[CheckpointBlock], sqrt=False) -> None:
def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) -> None:
super().__init__()

self._modules = {}
@@ -872,6 +895,8 @@ def __init__(self, modules: Iterable[CheckpointBlock], sqrt=False) -> None:
self._modules[str(i)] = module
self.add_module(str(i), module)

self.num_hidden = num_hidden

if sqrt:
length = len(self)
num_save_needed = 0
@@ -901,12 +926,11 @@ def __iter__(self) -> Iterator[CheckpointBlock]:
def __getitem__(self, index: Union[int, str]) -> CheckpointBlock:
return self._modules[str(index)]

def forward(self, hidden_state, *args, return_hidden_states = False):
def forward(self, *args, return_hidden_states = False):
self.return_hidden_states = return_hidden_states
placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled())
outputs = OpTransformerBlockList.apply(placeholder, self, self.save_list, hidden_state, *args)
last_hidden, middle_hiddens = outputs[:2]
outputs = OpTransformerBlockList.apply(placeholder, self, self.save_list, self.num_hidden, *args)
if return_hidden_states:
return last_hidden, middle_hiddens
return tuple(outputs[:2*self.num_hidden])
else:
return last_hidden
return tuple(outputs[:self.num_hidden]) if self.num_hidden > 1 else outputs[0]
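
A note to ease reading the checkpointing arithmetic above: layer_inputs is kept flat, with num_hidden tensors stored per checkpoint slot, so the k-th hidden state of save slot s lives at flat index s * num_hidden + k; that is exactly what the save_list[j][1] * ctx.num_hidden expressions compute. A tiny, hypothetical helper (not part of the commit) that spells out the same indexing:

def slot_slice(save_slot: int, num_hidden: int) -> slice:
    # flat indices of the num_hidden tensors saved for one checkpoint slot, matching
    # layer_inputs[save_list[j][1]*num_hidden : save_list[j][1]*num_hidden + num_hidden]
    start = save_slot * num_hidden
    return slice(start, start + num_hidden)

# example: with num_hidden = 3, slot 2 covers flat indices 6, 7 and 8
assert slot_slice(2, 3) == slice(6, 9)
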
1 change: 1 addition & 0 deletions tests/test_all.py
@@ -15,6 +15,7 @@
("dropout", 1),
("loss_func", 1),

("multi_return", 2),
("middle_hidden", 4),
("other_hidden", 4),
("inspector_hidden", 2),
126 changes: 126 additions & 0 deletions tests/test_multi_return.py
@@ -0,0 +1,126 @@
from utils import *

import bmtrain as bmt
import torch
import random
from bmtrain import config
from bmtrain.block_layer import CheckpointBlock, TransformerBlockList
from bmtrain.pipe_layer import PipelineTransformerBlockList
import torch.nn.functional as F

class MultiInputReturn(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, a, b, c, d, e):
return a*2, b+d, c*4+e*5

class Model_ZERO(torch.nn.Module):
def __init__(self, ms) -> None:
super().__init__()
self.ms = TransformerBlockList([
CheckpointBlock(m)
for m in ms
], num_hidden=3)

def forward(self, x):
y = self.ms(*x)
return y

class Model_PIPE(torch.nn.Module):
def __init__(self, ms) -> None:
super().__init__()
self.ms = PipelineTransformerBlockList([
CheckpointBlock(m)
for m in ms
], num_hidden=3)

def forward(self, x):
y = self.ms(*x)
return y

class Model_BLOCK(torch.nn.Module):
def __init__(self, ms) -> None:
super().__init__()
self.ms = torch.nn.ModuleList([
CheckpointBlock(m)
for m in ms
])

def forward(self, x):
y = x[:3]
other = x[3:]
for m in self.ms:
y = m(*y, *other)
return y

class Model_NORMAL(torch.nn.Module):
def __init__(self, ms) -> None:
super().__init__()
self.ms = torch.nn.ModuleList(ms)

def forward(self, x):
y = x[:3]
other = x[3:]
for m in self.ms:
y = m(*y, *other)
return y

def manual_seed(seed=33):
torch.manual_seed(seed)
random.seed(seed)
try:
import numpy as np
np.random.seed(seed)
except ModuleNotFoundError:
pass

def run(name, cls, num_layer=4, dim=4096):
manual_seed()

ms = [MultiInputReturn() for i in range(num_layer)]

inps = (
torch.randn((dim,)).cuda(),
torch.randn((dim,)).cuda(),
torch.randn((dim,)).cuda(),
torch.randn((dim,)).cuda(),
torch.randn((dim,)).cuda(),
)
last_weights = (
torch.randn((dim,)).cuda(),
torch.randn((dim,)).cuda(),
torch.randn((dim,)).cuda(),
)

for inp in inps:
inp.requires_grad_(True)
m = cls(ms)

ret = ""
logits = m(inps)
loss = (logits[0]*last_weights[0] + logits[1]*last_weights[1] + logits[2]*last_weights[2]).sum()
loss.backward()
return list(logits) + [
inp.grad
for inp in inps
]

def test_main():
ret = {}
ret["normal"] = run("normal", Model_NORMAL)
ret["block"] = run("block", Model_BLOCK)
ret["zero"] = run("zero", Model_ZERO)
# ret["pipe"] = run("pipe", Model_PIPE) # TODO pipeline not support multiple input-output yet
for k, r in ret.items():
bmt.print_rank(f"============={k}============")
bmt.print_rank(r)
for r in ret.values():
for r2 in ret.values():
for i in range(len(r)):
assert_lt((r[i]-r2[i]).abs().max(), 1e-5)

if __name__ == "__main__":
bmt.init_distributed(pipe_size=2)

test_main()
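
A usage note, not part of the commit: the script calls bmt.init_distributed(pipe_size=2), so it must be started with a multi-process launcher; something like torchrun --nproc_per_node=2 tests/test_multi_return.py is one plausible way to run it (launcher and process count are assumptions; the 2 in the new ("multi_return", 2) entry in tests/test_all.py is presumably the intended process count).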
