
Commit

Merge pull request #186 from OpenBMB/vpe
Vocab parallel Embedding impl and make example work when tp_size > 1
MayDomine authored Feb 26, 2024
2 parents 0def29f + f915f94 commit 5713d76
Showing 6 changed files with 31 additions and 59 deletions.
4 changes: 2 additions & 2 deletions bmtrain/nn/__init__.py
@@ -1,5 +1,5 @@
 from .linear import Linear, OpLinear
 from .column_parallel_linear import ColumnParallelLinear
 from .row_parallel_linear import RowParallelLinear
-from .parallel_embedding import Projection, VPProjection
-from .parallel_linear_func import OpParallelLinear
+from .parallel_embedding import VPEmbedding
+from .parallel_linear_func import OpParallelLinear
46 changes: 9 additions & 37 deletions bmtrain/nn/parallel_embedding.py
@@ -8,35 +8,8 @@
 from bmtrain.distributed import all_reduce, all_gather
 from .parallel_linear_func import OpParallelLinear
 
-class Projection(bmt.DistributedModule):
-    def __init__(
-        self,
-        vocab_size: int,
-        embedding_size: int,
-        dtype: torch.dtype = torch.half,
-        init_mean: float = 0.0,
-        init_std: float = 1,
-    ):
-        super().__init__()
-
-        self.dim_model = embedding_size
-        self.weight = bmt.DistributedParameter(
-            torch.empty(vocab_size, embedding_size, dtype=dtype),
-            init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
-        )
-
-    def forward(self, x: torch.Tensor):
-        """
-        Projection based on embedding's weight. For example, embedding map vocab_size to embed_size, than projection map embed_size back to vocab_size.
-        Args:
-            x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection
-        Returns:
-            :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``: The projection output.
-        """ # noqa: E501
-        logits = F.linear(x, self.weight)
-        return logits
 
-class VPProjection(bmt.DistributedModule):
+class VPEmbedding(bmt.DistributedModule):
     def __init__(
         self,
         vocab_size: int,
@@ -59,12 +32,11 @@ def __init__(
             tp_mode=True,
         )
 
-    def forward(self, x: torch.Tensor):
-        """
-        Projection based on embedding's weight. For example, embedding map vocab_size to embed_size, than projection map embed_size back to vocab_size.
-        Args:
-            x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection
-        Returns:
-            :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``: The projection output.
-        """ # noqa: E501
-        return bmt.nn.OpParallelLinear.apply(x, self.weight, None, False, False, False, None, 1)
+    def forward(self, x: torch.Tensor, projection=False):
+        if not projection:
+            weight = all_gather(self.weight, comm=config['tp_comm']).flatten(0,1)
+            out = F.embedding(x, weight)
+            return out
+        else:
+            x = bmt.distributed.all_gather(x, comm=bmt.config['tp_comm']).view(x.shape[0], -1, x.shape[-1])
+            return bmt.nn.OpParallelLinear.apply(x, self.weight, None, False, False, False, None, 1)
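
Note on the VPEmbedding diff above: each tensor-parallel rank holds one shard of the embedding weight (the constructor shown passes tp_mode=True), and the same shard serves both directions — the lookup path all-gathers the shards before F.embedding, while the projection path reuses them through OpParallelLinear to produce tied output logits. A minimal usage sketch, assuming bmtrain has been launched with tp_size > 1; the init call, sizes, and seeding below are illustrative and not part of this commit:

import torch
import bmtrain as bmt

bmt.init_distributed(tp_size=2)                            # assumed setup: any tp_size > 1
emb = bmt.nn.VPEmbedding(10240, 2560, dtype=torch.half)

torch.manual_seed(0)                                       # same ids on every rank, as in the example
ids_full = torch.randint(0, 10240, (2, 512)).cuda()        # (batch, seq_len) token ids
ids = ids_full.chunk(bmt.config["tp_size"], dim=1)[bmt.config["tp_rank"]]   # this rank's sequence shard, as in example/models/gpt.py
hidden = emb(ids)                                          # lookup path: weight shards all-gathered, then F.embedding
logits = emb(hidden, projection=True)                      # tied projection path via OpParallelLinear
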
16 changes: 8 additions & 8 deletions example/layers/attention.py
@@ -41,8 +41,7 @@ def forward(self,
             mask : torch.BoolTensor,                       # (batch_size, seq_q, seq_kv)
             position_bias : Optional[torch.Tensor] = None, # (batch, num_heads, seq_q, seq_kv)
         ) -> torch.Tensor:
-        batch_size, seq_q, dim_model = hidden_q.size()
-        seq_kv = hidden_kv.size(1)
+        batch_size = hidden_q.size()[0]
 
         assert hidden_q.data_ptr() == hidden_kv.data_ptr()
 
@@ -54,14 +53,16 @@ def forward(self,
                 True, False,
                 False, None
             )
+            hidden_q = hidden_q.view(batch_size, -1, hidden_q.shape[-1])
             h_q, h_k, h_v = hidden_q.chunk(3, dim=-1)
+            #batch_size will changed in TensorParallel
+            batch_size = h_v.shape[0]
         else:
             h_q : torch.Tensor = self.project_q(hidden_q)
             h_k : torch.Tensor = self.project_k(hidden_kv)
             h_v : torch.Tensor = self.project_v(hidden_kv)
 
+        seq_q = h_q.size()[1]
+        seq_kv = h_k.size(1)
 
         h_q = h_q.view(batch_size, seq_q, -1, self.dim_head)
         h_k = h_k.view(batch_size, seq_kv, -1, self.dim_head)
         h_v = h_v.view(batch_size, seq_kv, -1, self.dim_head)
@@ -84,10 +85,6 @@ def forward(self,
         if position_bias is not None:
             score = score + position_bias.view(batch_size, -1, seq_q, seq_kv)
 
-        if config['tp_size'] > 1:
-            with torch.no_grad():
-                mask = all_gather(mask, config['tp_comm']).flatten(0,1)
-
         score = torch.where(
             mask.view(batch_size, 1, seq_q, seq_kv),
             score,
@@ -108,8 +105,11 @@ def forward(self,
         h_out = h_out.view(batch_size, -1, seq_q, self.dim_head)
         h_out = h_out.permute(0, 2, 1, 3).contiguous()
         h_out = h_out.view(batch_size, seq_q, -1)
+        if config['tp_size'] > 1:
+            h_out = h_out.view(h_out.shape[0] * bmt.config["tp_size"], -1, h_out.shape[-1])
 
         attn_out = self.project_out(h_out)
 
         return attn_out
 
2 changes: 1 addition & 1 deletion example/layers/transformer.py
Expand Up @@ -28,7 +28,7 @@ def forward(self,

x = self.ln_ff(hidden)
x = self.ff(x)
hidden = hidden + x
hidden = hidden + x.view_as(hidden)

return hidden

14 changes: 6 additions & 8 deletions example/models/gpt.py
Expand Up @@ -14,8 +14,8 @@ def __init__(self,

self.max_distance = max_distance

if config['tp_size'] > 1:
self.word_emb = bmt.nn.ParallelEmbedding(vocab_size, dim_model, dtype=dtype)
if config["tp_size"] > 1:
self.word_emb = bmt.nn.VPEmbedding(vocab_size, dim_model, dtype=dtype)
else:
self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype)
self.pos_emb = Embedding(max_distance, dim_model, dtype=dtype)
@@ -50,17 +50,15 @@ def forward(self,
 
         mask_2d = mask[:, None, :] & mask[:, :, None]   # (batch, seq_len, seq_len)
         mask_2d = mask_2d & (pos[:, None, :] >= pos[:, :, None])
 
+        if config["tp_size"] > 1:
+            input = input.chunk(config["tp_size"], dim=1)[config["tp_rank"]]
+            pos = pos.chunk(config["tp_size"], dim=1)[config["tp_rank"]]
         out = self.pos_emb(pos) + self.word_emb(input)
 
         # for layer in self.transformers:
         out = self.transformers(out, mask_2d, None)
         out = self.layernorm(out)
 
-        if config['tp_size'] > 1:
-            logits = self.word_emb.projection(out)
-        else:
-            logits = self.word_emb(out, projection=True)
+        logits = self.word_emb(out, projection=True)
         bmt.inspect.record_tensor(logits, "logits")
 
         return logits
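
Note on the gpt.py diff above: with tp_size > 1 the token ids and positions are split along the sequence dimension, so each tensor-parallel rank embeds and runs the transformer on seq_len // tp_size tokens, and the logits now always come from the tied word_emb(out, projection=True) path. A small illustration of the chunk indexing with made-up sizes (plain PyTorch, no bmtrain needed):

import torch

seq_len, tp_size, tp_rank = 8, 2, 1                    # hypothetical values
input = torch.arange(2 * seq_len).view(2, seq_len)     # (batch=2, seq_len=8)
shard = input.chunk(tp_size, dim=1)[tp_rank]           # columns 4..7: this rank's slice
print(shard.shape)                                     # torch.Size([2, 4])
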
8 changes: 5 additions & 3 deletions example/train.py
@@ -36,8 +36,10 @@ def main():
 
     batch_size = 2
     seq_len = 512
+    world_size = bmt.config["world_size"] if bmt.config["tp_size"] == 1 else bmt.config["tp_zero_size"]
+    r = bmt.config["rank"] if bmt.config["tp_size"] == 1 else bmt.config["tp_zero_rank"]
 
-    for i in range(bmt.world_size()):
+    for i in range(world_size):
         sent = torch.randint(0, 10240, (batch_size, seq_len + 1))
         enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda()
         enc_input = sent[:, :-1].long().cuda()
@@ -49,7 +51,7 @@ def main():
             torch.full_like(targets, -100, dtype=torch.long)
         )
 
-        if i == bmt.rank():
+        if i == r:
             break
 
     if config['tp_size'] > 1:
@@ -82,7 +84,7 @@ def main():
     batch, seq_len, vocab_out_size = logits.size()
 
     if config['tp_size'] > 1:
-        loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets)
+        loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len))
     else:
         loss = loss_func(logits.float().view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len))
 
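
Note on the train.py diff above: the batch-selection index switches from the global rank to the ZeRO rank inside the tensor-parallel layout, so ranks within one tensor-parallel group end up keeping the same synthetic batch, and the loss call flattens the targets to line up with the flattened logits. A rough sketch of the selection logic with stand-in config values (a real run reads these from bmt.config; the numbers below are assumptions for illustration):

# Stand-in values; a real bmt.config is populated by bmt.init_distributed.
config = {"world_size": 8, "rank": 5, "tp_size": 4, "tp_zero_size": 2, "tp_zero_rank": 1}

world_size = config["world_size"] if config["tp_size"] == 1 else config["tp_zero_size"]
r = config["rank"] if config["tp_size"] == 1 else config["tp_zero_rank"]

# Every rank generates the same sequence of batches and keeps the r-th one,
# so tensor-parallel peers that share tp_zero_rank train on identical data.
for i in range(world_size):
    kept = i                      # placeholder for the random batch built in train.py
    if i == r:
        break
print("this rank keeps batch", kept)
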
