diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py
index c70c03280e..54a5327f08 100644
--- a/transformer_engine/jax/cpp_extensions/attention.py
+++ b/transformer_engine/jax/cpp_extensions/attention.py
@@ -1061,7 +1061,7 @@ def partition(config, mesh, arg_infos, result_infos):
         # Call base implementation for non-context parallel mesh to avoid unecessary work.
         is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
         assert (
-            is_context_parallel and config.window_size[0] == -1
+            not is_context_parallel or config.window_size[0] == -1
         ), "Sliding window attention is not supported when context parallelism is enabled"
         if not is_context_parallel:
             return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)
@@ -1158,7 +1158,7 @@ def partition(config, mesh, arg_infos, result_infos):
         # Call base implementation for non-context parallel mesh to avoid unecessary work.
         is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
         assert (
-            is_context_parallel and config.window_size[0] == -1
+            not is_context_parallel or config.window_size[0] == -1
         ), "Sliding window attention is not supported when context parallelism is enabled"
         if not is_context_parallel:
             return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
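
Rationale for the logic fix, with a minimal standalone sketch (the helper names old_check/new_check are hypothetical, not part of the patch): the invariant being asserted is an implication, "if context parallelism is enabled, then sliding window attention must be disabled (window_size[0] == -1)". The original condition, is_context_parallel and window_size[0] == -1, does not encode that implication; it fails on every mesh where context parallelism is off, even when sliding window attention is disabled. Rewriting the implication cp -> (w == -1) in its equivalent form (not cp) or (w == -1) yields the corrected condition.

    def old_check(is_context_parallel: bool, window_left: int) -> bool:
        # Original condition: False whenever context parallelism is off,
        # so the assertion fires even on valid single-device configs.
        return is_context_parallel and window_left == -1

    def new_check(is_context_parallel: bool, window_left: int) -> bool:
        # Fixed condition: always passes when context parallelism is off;
        # when it is on, passes only if sliding window attention is
        # disabled (window_left == -1 means "no sliding window").
        return not is_context_parallel or window_left == -1

    if __name__ == "__main__":
        # Truth table over the four cases. old_check wrongly returns False
        # for (cp=False, w=-1); new_check returns False only for the one
        # genuinely unsupported case (cp=True, w=128).
        for cp in (False, True):
            for w in (-1, 128):
                print(f"cp={cp!s:5} window_left={w:4} "
                      f"old={old_check(cp, w)!s:5} new={new_check(cp, w)}")

Running the sketch makes the bug visible: the only input rejected by new_check is the combination the error message describes (context parallelism on, sliding window on), whereas old_check also rejects the perfectly valid non-context-parallel cases.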