Merge pull request EleutherAI#996 from mariebiscuit/main
Replaced all torch.concat with torch.cat
StellaAthena authored Jul 18, 2023
2 parents 303d7be + edfa39b · commit 408e29d
Showing 1 changed file with 2 additions and 2 deletions.
megatron/model/transformer.py (4 changes: 2 additions & 2 deletions)
@@ -531,7 +531,7 @@ def flash_attention(self, query_layer, key_layer, value_layer):
             )

             # Combined k/v into [b * sk, 2, np, hn].
-            kv = torch.concat([key_layer, value_layer], dim=1)
+            kv = torch.cat([key_layer, value_layer], dim=1)

             output = self.flash_kv_fn(
                 query_layer,
@@ -553,7 +553,7 @@ def flash_attention(self, query_layer, key_layer, value_layer):
             )

             # Combined q/k/v into [b * s, 3, np, hn].
-            qkv = torch.concat([query_layer, key_layer, value_layer], dim=1)
+            qkv = torch.cat([query_layer, key_layer, value_layer], dim=1)

             output = self.flash_qkv_fn(
                 qkv,
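For context, torch.concat is a documented alias of torch.cat (added around PyTorch 1.10), while torch.cat exists in every PyTorch release, so standardizing on torch.cat keeps the code compatible with older versions without changing behavior. Below is a minimal sketch of the equivalence and of the packed k/v layout the diff comments describe; the tensor names and dimension sizes are illustrative, not taken from the repository:

import torch

# Illustrative sizes: b_sk is the flattened batch * sequence dimension,
# n_heads the number of attention heads, head_dim the per-head hidden size.
b_sk, n_heads, head_dim = 8, 4, 16

# key/value shaped [b * sk, 1, np, hn], matching the layout in the diff comments.
key_layer = torch.randn(b_sk, 1, n_heads, head_dim)
value_layer = torch.randn(b_sk, 1, n_heads, head_dim)

# torch.cat along dim=1 packs k and v into [b * sk, 2, np, hn].
kv = torch.cat([key_layer, value_layer], dim=1)
assert kv.shape == (b_sk, 2, n_heads, head_dim)

# torch.concat is an alias of torch.cat, so both spellings give identical
# results on PyTorch versions that have the alias (>= ~1.10).
assert torch.equal(kv, torch.concat([key_layer, value_layer], dim=1))

The qkv hunk follows the same pattern with three tensors, producing the [b * s, 3, np, hn] layout expected by the fused flash-attention kernel.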
