Commit
code visibility
themurtazanazir committed Mar 12, 2024
1 parent b78a957 commit 96a7a1e
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions docs/neural_networks/transformer/transformer.md
@@ -261,7 +261,9 @@ class MultiHeadAttention(nn.Module):
         v_proj_heads = v_proj.view(bsize, -1, self.num_heads,
                                    int(self.d_model/self.num_heads))
 
-        ##put all heads in batch dim since scaled_dot_product_attn already handles batch
+        # put all heads in batch dim since scaled_dot_product_attn
+        # already handles batch
+
         # q_proj_batched: batch*n_heads, T_q, dq
         q_proj_batched = q_proj_heads.transpose(1,2).reshape(bsize*self.num_heads,
                                                              -1, int(self.d_model/self.num_heads) )
@@ -271,20 +273,18 @@ class MultiHeadAttention(nn.Module):
                                                              -1, int(self.d_model/self.num_heads) )
 
         ##attn_out: batch*n_heads, T_q, dq
-        attn_out = scaled_dot_product_attn(q_proj_batched, k_proj_batched, v_proj_batched,
-                                           mask)
+        attn_out = scaled_dot_product_attn(q_proj_batched, k_proj_batched,
+                                           v_proj_batched, mask)
         ## batch, n_heads, Tq, dq
-        attn_out = attn_out.view(bsize, self.num_heads, -1, int(self.d_model/self.num_heads))
+        attn_out = attn_out.view(bsize, self.num_heads,
+                                 -1, int(self.d_model/self.num_heads))
         ##batch, Tq, n_heads, dq
         attn_out = attn_out.transpose(1, 2)
         ##batch, Tq, n_heads, h*dq(d_model)
         concat_out = attn_out.reshape(bsize, -1, self.d_model)
         # concat_out = torch.cat(attns_outs, dim=-1)
         mha_out = self.o_layer(concat_out)
+        return mha_out
-
-
-
-        return mha_out
 ```
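
For context, the code being reformatted folds all attention heads into the batch dimension so that one batched call to the document's scaled_dot_product_attn covers every head at once, then unfolds the result back to (batch, T_q, d_model). The sketch below is not from the repository: it walks through that reshape round-trip with arbitrary illustrative sizes, using PyTorch's built-in torch.nn.functional.scaled_dot_product_attention as a stand-in for the document's helper.

```python
# Minimal sketch of the head-batching pattern shown in the diff, assuming
# d_model is divisible by num_heads. F.scaled_dot_product_attention stands in
# for the document's own scaled_dot_product_attn helper.
import torch
import torch.nn.functional as F

bsize, seq_len, d_model, num_heads = 2, 5, 16, 4
d_head = d_model // num_heads

# stand-ins for the outputs of the q/k/v projection layers
q_proj = torch.randn(bsize, seq_len, d_model)
k_proj = torch.randn(bsize, seq_len, d_model)
v_proj = torch.randn(bsize, seq_len, d_model)

def fold_heads_into_batch(x):
    # (batch, T, d_model) -> (batch, T, heads, d_head) -> (batch, heads, T, d_head)
    # -> (batch*heads, T, d_head); reshape rather than view because the
    # transpose makes the tensor non-contiguous.
    return (x.view(bsize, -1, num_heads, d_head)
             .transpose(1, 2)
             .reshape(bsize * num_heads, -1, d_head))

q_b, k_b, v_b = map(fold_heads_into_batch, (q_proj, k_proj, v_proj))

# one attention call over batch*heads "examples": (batch*heads, T_q, d_head)
attn_out = F.scaled_dot_product_attention(q_b, k_b, v_b)

# undo the folding: (batch, heads, T_q, d_head) -> (batch, T_q, heads, d_head)
# -> (batch, T_q, d_model); this is the concat-of-heads step.
concat_out = (attn_out.view(bsize, num_heads, -1, d_head)
                      .transpose(1, 2)
                      .reshape(bsize, -1, d_model))

print(concat_out.shape)  # torch.Size([2, 5, 16])
```

In the diffed code, concat_out is then passed through the output projection self.o_layer to produce the multi-head attention output.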


