
Commit

delete test file
MayDomine committed Aug 29, 2023
1 parent 832141a commit 8bd6475
Showing 9 changed files with 85 additions and 1,137 deletions.
68 changes: 48 additions & 20 deletions example/layers/attention.py
@@ -1,8 +1,14 @@
from typing import Optional
import torch
import bmtrain as bmt
from bmtrain.nn import Linear
from bmtrain.nn import (
Linear,
ColumnParallelLinear,
RowParallelLinear,
)
import math
from bmtrain.global_var import config
from bmtrain.distributed import all_gather

class Attention(bmt.DistributedModule):
def __init__(self,
@@ -12,11 +18,17 @@ def __init__(self,
) -> None:
super().__init__()

self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
if config['tp_size'] > 1:
self.project_q = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False)
self.project_k = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False)
self.project_v = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False)
self.project_out = RowParallelLinear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype)
else:
self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype)

self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype)

self.softmax = torch.nn.Softmax(dim=-1)
self.num_heads = num_heads
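
Under tensor parallelism, each rank's column-parallel projection produces only its shard of the q/k/v features, i.e. a subset of the attention heads, which is why the reshapes in forward below use -1 rather than self.num_heads. A rough worked example with assumed sizes (not taken from this repository):

    num_heads, dim_head, tp_size = 32, 128, 4     # assumed example sizes
    local_heads = num_heads // tp_size            # 8 of the 32 heads live on each rank
    local_qkv_features = local_heads * dim_head   # 1024 of the full 4096 projection features per rank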
@@ -32,32 +44,48 @@ def forward(self,
batch_size, seq_q, dim_model = hidden_q.size()
seq_kv = hidden_kv.size(1)

h_q : torch.Tensor = self.project_q(hidden_q)
h_k : torch.Tensor = self.project_k(hidden_kv)
h_v : torch.Tensor = self.project_v(hidden_kv)
assert hidden_q.data_ptr() == hidden_kv.data_ptr()

h_q = h_q.view(batch_size, seq_q, self.num_heads, self.dim_head)
h_k = h_k.view(batch_size, seq_kv, self.num_heads, self.dim_head)
h_v = h_v.view(batch_size, seq_kv, self.num_heads, self.dim_head)
hidden_q = bmt.nn.OpParallelLinear.apply(
hidden_q,
torch.cat([self.project_q.weight, self.project_k.weight, self.project_v.weight], dim=0),
torch.cat([self.project_q.bias, self.project_k.bias, self.project_v.bias], dim=0),
True, False,
False, None
)

h_q, h_k, h_v = hidden_q.chunk(3, dim=-1)

if config['tp_size'] > 1:
# batch_size will be changed in TensorParallel
batch_size = h_v.shape[0]

h_q = h_q.view(batch_size, seq_q, -1, self.dim_head)
h_k = h_k.view(batch_size, seq_kv, -1, self.dim_head)
h_v = h_v.view(batch_size, seq_kv, -1, self.dim_head)

h_q = h_q.permute(0, 2, 1, 3).contiguous()
h_k = h_k.permute(0, 2, 1, 3).contiguous()
h_v = h_v.permute(0, 2, 1, 3).contiguous()

h_q = h_q.view(batch_size * self.num_heads, seq_q, self.dim_head)
h_k = h_k.view(batch_size * self.num_heads, seq_kv, self.dim_head)
h_v = h_v.view(batch_size * self.num_heads, seq_kv, self.dim_head)
h_q = h_q.view(-1, seq_q, self.dim_head)
h_k = h_k.view(-1, seq_kv, self.dim_head)
h_v = h_v.view(-1, seq_kv, self.dim_head)

score = torch.bmm(
h_q, h_k.transpose(1, 2)
)
score = score / math.sqrt(self.dim_head)

score = score.view(batch_size, self.num_heads, seq_q, seq_kv)
score = score.view(batch_size, -1, seq_q, seq_kv)

if position_bias is not None:
score = score + position_bias.view(batch_size, self.num_heads, seq_q, seq_kv)

score = score + position_bias.view(batch_size, -1, seq_q, seq_kv)

if config['tp_size'] > 1:
with torch.no_grad():
mask = all_gather(mask, config['tp_comm']).flatten(0,1)

score = torch.where(
mask.view(batch_size, 1, seq_q, seq_kv),
score,
@@ -70,14 +98,14 @@ def forward(self,
torch.scalar_tensor(0, device=score.device, dtype=score.dtype)
)

score = score.view(batch_size * self.num_heads, seq_q, seq_kv)
score = score.view(-1, seq_q, seq_kv)

h_out = torch.bmm(
score, h_v
)
h_out = h_out.view(batch_size, self.num_heads, seq_q, self.dim_head)
h_out = h_out.view(batch_size, -1, seq_q, self.dim_head)
h_out = h_out.permute(0, 2, 1, 3).contiguous()
h_out = h_out.view(batch_size, seq_q, self.num_heads * self.dim_head)
h_out = h_out.view(batch_size, seq_q, -1)

attn_out = self.project_out(h_out)
return attn_out
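
The rewritten forward pass fuses the three projections into a single matmul over the concatenated q/k/v weights (hence the assert that hidden_q and hidden_kv are the same tensor) and then splits the result with chunk. A minimal plain-torch sketch of that fusion, using F.linear and made-up sizes rather than bmt.nn.OpParallelLinear, and ignoring the tensor-parallel sharding and the mask all_gather:

    import torch
    import torch.nn.functional as F

    dim_model = dim_qkv = 64                       # assumed sizes for the sketch
    x = torch.randn(2, 8, dim_model)               # stands in for hidden_q (== hidden_kv)
    w_q, w_k, w_v = (torch.randn(dim_qkv, dim_model) for _ in range(3))
    b_q, b_k, b_v = (torch.randn(dim_qkv) for _ in range(3))

    # One projection over the concatenated weights ...
    fused = F.linear(x, torch.cat([w_q, w_k, w_v], dim=0), torch.cat([b_q, b_k, b_v], dim=0))
    h_q, h_k, h_v = fused.chunk(3, dim=-1)

    # ... matches three separate projections of the same input.
    assert torch.allclose(h_q, F.linear(x, w_q, b_q), atol=1e-5)
    assert torch.allclose(h_v, F.linear(x, w_v, b_v), atol=1e-5)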
8 changes: 5 additions & 3 deletions example/layers/embedding.py
@@ -77,11 +77,13 @@ def from_pretrained(cls, embeddings, freeze=True, padding_idx=None,

def forward(self, input: torch.Tensor, projection : bool = False) -> torch.Tensor:
if not projection:
return F.embedding(
out = F.embedding(
input, self.weight, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse)
return out
else:
return F.linear(input, self.weight) / math.sqrt(self.embedding_dim)
out = F.linear(input, self.weight)
return out

def extra_repr(self) -> str:
s = '{num_embeddings}, {embedding_dim}'
@@ -97,4 +99,4 @@ def extra_repr(self) -> str:
s += ', sparse=True'
return s.format(**self.__dict__)
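
The projection branch reuses the embedding weight as the output projection (tied input/output embeddings); note that this revision also drops the former 1 / sqrt(embedding_dim) scaling on that path. A small shape-level illustration with assumed sizes:

    import torch
    import torch.nn.functional as F

    vocab_size, embedding_dim = 100, 16             # assumed sizes
    weight = torch.randn(vocab_size, embedding_dim)

    token_ids = torch.randint(0, vocab_size, (2, 8))
    embedded = F.embedding(token_ids, weight)       # projection=False path: [2, 8, 16]

    hidden = torch.randn(2, 8, embedding_dim)
    logits = F.linear(hidden, weight)               # projection=True path:  [2, 8, 100]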



15 changes: 11 additions & 4 deletions example/layers/feedforward.py
@@ -1,16 +1,23 @@
import torch
import bmtrain as bmt
from bmtrain.nn import Linear
from bmtrain.nn import (
Linear,
ColumnParallelLinear,
RowParallelLinear)
from bmtrain.global_var import config

class Feedforward(bmt.DistributedModule):
def __init__(self, dim_model : int, dim_ff : int, bias : bool = True, dtype = None) -> None:
super().__init__()

self.w_in = Linear(dim_model, dim_ff, bias = bias, dtype=dtype)
self.w_out = Linear(dim_ff, dim_model, bias = bias, dtype=dtype)
if config['tp_size'] > 1:
self.w_in = ColumnParallelLinear(dim_model, dim_ff, bias = bias, dtype=dtype)
self.w_out = RowParallelLinear(dim_ff, dim_model, bias = bias, dtype=dtype)
else:
self.w_in = Linear(dim_model, dim_ff, bias=bias, dtype=dtype)
self.w_out = Linear(dim_ff, dim_model, bias=bias, dtype=dtype)

self.relu = torch.nn.ReLU()

def forward(self, input : torch.Tensor) -> torch.Tensor:

return self.w_out(self.relu(self.w_in(input)))
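
When tp_size > 1 the block follows the standard column-then-row pattern: w_in is sharded over its output features and w_out over its input features, so the dim_ff activations stay sharded and only the final partial outputs are reduced across the group. A single-process sketch of why the sharded partial results sum to the unsharded output (plain torch with assumed sizes and no bias, not BMTrain's parallel layers):

    import torch

    tp_size, dim_model, dim_ff = 4, 16, 64          # assumed sizes
    x = torch.randn(2, 8, dim_model)
    w_in = torch.randn(dim_ff, dim_model)           # full, unsharded weights
    w_out = torch.randn(dim_model, dim_ff)

    ref = torch.relu(x @ w_in.t()) @ w_out.t()      # unsharded Linear -> ReLU -> Linear

    shard = dim_ff // tp_size
    partials = []
    for r in range(tp_size):                        # each iteration plays the part of one tp rank
        w_in_r = w_in[r * shard:(r + 1) * shard]       # column shard: [dim_ff // tp_size, dim_model]
        w_out_r = w_out[:, r * shard:(r + 1) * shard]  # row shard:    [dim_model, dim_ff // tp_size]
        partials.append(torch.relu(x @ w_in_r.t()) @ w_out_r.t())

    # The cross-rank reduction in a row-parallel layer computes exactly this sum.
    assert torch.allclose(sum(partials), ref, atol=1e-4)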
