diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 45445d4b0..4e81b70b6 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -531,7 +531,7 @@ def flash_attention(self, query_layer, key_layer, value_layer):
         )
 
         # Combined k/v into [b * sk, 2, np, hn].
-        kv = torch.concat([key_layer, value_layer], dim=1)
+        kv = torch.cat([key_layer, value_layer], dim=1)
 
         output = self.flash_kv_fn(
             query_layer,
@@ -553,7 +553,7 @@ def flash_attention(self, query_layer, key_layer, value_layer):
         )
 
         # Combined q/k/v into [b * s, 3, np, hn].
-        qkv = torch.concat([query_layer, key_layer, value_layer], dim=1)
+        qkv = torch.cat([query_layer, key_layer, value_layer], dim=1)
 
         output = self.flash_qkv_fn(
             qkv,
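
Note (not part of the patch): in recent PyTorch releases torch.concat is an alias of torch.cat, so this change does not alter behavior; it most likely improves compatibility with older PyTorch versions where only torch.cat exists. A minimal sketch of the k/v packing from the first hunk, assuming key_layer and value_layer have already been reshaped to [b * sk, 1, np, hn] (consistent with the comment in the hunk; the concrete dimension values below are illustrative only):

import torch

b, sk, np_, hn = 2, 16, 4, 64  # illustrative batch, kv seq length, heads, head dim

# Assumed pre-reshaped inputs: [b * sk, 1, np, hn].
key_layer = torch.randn(b * sk, 1, np_, hn)
value_layer = torch.randn(b * sk, 1, np_, hn)

# Pack k and v along dim=1, giving [b * sk, 2, np, hn] as the hunk's comment states.
kv_cat = torch.cat([key_layer, value_layer], dim=1)
kv_concat = torch.concat([key_layer, value_layer], dim=1)  # alias of torch.cat in newer PyTorch

assert kv_cat.shape == (b * sk, 2, np_, hn)
assert torch.equal(kv_cat, kv_concat)

The q/k/v packing in the second hunk works the same way with three [b * s, 1, np, hn] tensors concatenated along dim=1 to form [b * s, 3, np, hn].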