From 0ee5ccda2eacf96fbc9c7ef1f7c084a71e0df7a6 Mon Sep 17 00:00:00 2001
From: Xin Yao
Date: Thu, 19 Sep 2024 08:55:48 +0800
Subject: [PATCH] [PyTorch] Relax the contiguous check for flash attention
 (#1176)

* relax contiguous check for flash attention

Signed-off-by: Xin Yao

* force contiguous for cp

Signed-off-by: Xin Yao

---------

Signed-off-by: Xin Yao
---
 transformer_engine/pytorch/attention.py | 17 ++++++-----------
 1 file changed, 6 insertions(+), 11 deletions(-)

diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 1e33819e9f..192f430ae1 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -4881,20 +4881,19 @@ def forward(
                     )
                 else:
                     query_layer, key_layer, value_layer = [
-                        x.transpose(0, 1).contiguous()
-                        for x in (query_layer, key_layer, value_layer)
+                        x.transpose(0, 1) for x in (query_layer, key_layer, value_layer)
                     ]
-            elif qkv_format in ["bshd", "thd"]:
+            if context_parallel:
                 query_layer, key_layer, value_layer = [
                     x.contiguous() for x in (query_layer, key_layer, value_layer)
                 ]
         else:
             if qkv_format == "sbhd":
                 query_layer._data, key_layer._data, value_layer._data = [
-                    x.transpose(0, 1).contiguous()
+                    x.transpose(0, 1)
                     for x in (query_layer._data, key_layer._data, value_layer._data)
                 ]
-            elif qkv_format in ["bshd", "thd"]:
+            if context_parallel:
                 query_layer._data, key_layer._data, value_layer._data = [
                     x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
                 ]
@@ -5092,11 +5091,7 @@ def forward(
                 output.reshape(batch_size * max_seqlen_q // cp_size, -1).transpose_2d()
                 output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
             else:
-                output = (
-                    output.view(batch_size, max_seqlen_q // cp_size, -1)
-                    .transpose(0, 1)
-                    .contiguous()
-                )
+                output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1)
         elif qkv_format == "bshd":
             # (bs)hd -> bs(hd)
             output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
@@ -5104,7 +5099,7 @@ def forward(
             # thd -> t(hd)
             output = output.reshape(output.shape[0], -1)
 
-        return output
+        return output.contiguous()
 
 
 def _combine_tensors(
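
A minimal PyTorch sketch of the stride behavior the relaxed check relies on; the tensor names and shapes below are illustrative and not taken from the patch. torch.Tensor.transpose returns a view whose last-dimension stride stays 1, which is typically all that flash-attention style kernels require, so the eager .contiguous() copies can be dropped; the patch still forces contiguous inputs under context parallelism (cp) and makes the final output contiguous before returning it.

import torch

# Hypothetical "sbhd" tensor: (seq, batch, heads, head_dim); shape chosen only for illustration.
q_sbhd = torch.randn(128, 2, 16, 64)

# transpose(0, 1) produces a "bshd" view: no data copy, only the strides change.
q_bshd = q_sbhd.transpose(0, 1)

print(q_bshd.is_contiguous())  # False: the transposed view is not fully contiguous
print(q_bshd.stride(-1))       # 1: the last (head_dim) dimension is still contiguous

# Flash-attention interfaces generally only require stride(-1) == 1 (and fall back to
# .contiguous() themselves otherwise), so the copy removed by this patch was avoidable
# in the non-context-parallel path.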