From 43b9e1ee5e59fc87dab1f505bd2253b3744655ae Mon Sep 17 00:00:00 2001
From: Md Fahim Faysal Khan
Date: Tue, 15 Oct 2024 18:00:21 -0700
Subject: [PATCH] fix assertion bug for SWA API in TE-JAX (#1242)

fixed assertion bug for SWA

Signed-off-by: Md Fahim Faysal Khan
Co-authored-by: Kirthi Shankar Sivamani
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
---
 transformer_engine/jax/cpp_extensions/attention.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py
index c70c03280e..54a5327f08 100644
--- a/transformer_engine/jax/cpp_extensions/attention.py
+++ b/transformer_engine/jax/cpp_extensions/attention.py
@@ -1061,7 +1061,7 @@ def partition(config, mesh, arg_infos, result_infos):
         # Call base implementation for non-context parallel mesh to avoid unecessary work.
         is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
         assert (
-            is_context_parallel and config.window_size[0] == -1
+            not is_context_parallel or config.window_size[0] == -1
         ), "Sliding window attention is not supported when context parallelism is enabled"
         if not is_context_parallel:
             return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)
@@ -1158,7 +1158,7 @@ def partition(config, mesh, arg_infos, result_infos):
         # Call base implementation for non-context parallel mesh to avoid unecessary work.
         is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
         assert (
-            is_context_parallel and config.window_size[0] == -1
+            not is_context_parallel or config.window_size[0] == -1
         ), "Sliding window attention is not supported when context parallelism is enabled"
         if not is_context_parallel:
             return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
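
Reviewer note (not part of the patch): the intended invariant is an implication,
"if context parallelism is enabled, then sliding window attention must be
disabled (window_size[0] == -1)". The old predicate `A and B` is false for every
non-context-parallel mesh, so the assert fired even when no sliding window was
requested; the fix rewrites it as `not A or B`. Below is a minimal standalone
Python sketch of that logic; the helper names are illustrative only and are not
part of the TE-JAX API.

    def old_check(is_context_parallel, window_size):
        # Buggy predicate: evaluates to False for any non-CP mesh,
        # so the assert trips even when SWA is not in use.
        return is_context_parallel and window_size[0] == -1

    def new_check(is_context_parallel, window_size):
        # Fixed predicate: equivalent to "is_context_parallel implies
        # window_size[0] == -1".
        return not is_context_parallel or window_size[0] == -1

    # Non-CP mesh, no sliding window: must pass, but the old check rejected it.
    assert new_check(False, (-1, -1)) and not old_check(False, (-1, -1))
    # Non-CP mesh with sliding window: allowed by the fix.
    assert new_check(False, (128, 0))
    # CP mesh with sliding window: correctly rejected (assert should fire).
    assert not new_check(True, (128, 0))
    # CP mesh without sliding window: allowed.
    assert new_check(True, (-1, -1))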