fix assertion bug for SWA API in TE-JAX (#1242)
Fixed the assertion guarding sliding window attention (SWA) under context parallelism. The old predicate, is_context_parallel and config.window_size[0] == -1, held only when context parallelism was enabled with an unbounded window, so it tripped on every non-context-parallel mesh. The check now asserts the intended implication: if context parallelism is enabled, the attention window must be unbounded (window_size[0] == -1).

Signed-off-by: Md Fahim Faysal Khan <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>
3 people authored Oct 16, 2024
1 parent f6b766b commit 43b9e1e
Showing 1 changed file with 2 additions and 2 deletions.
transformer_engine/jax/cpp_extensions/attention.py (2 additions, 2 deletions)
@@ -1061,7 +1061,7 @@ def partition(config, mesh, arg_infos, result_infos):
         # Call base implementation for non-context parallel mesh to avoid unecessary work.
         is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
         assert (
-            is_context_parallel and config.window_size[0] == -1
+            not is_context_parallel or config.window_size[0] == -1
         ), "Sliding window attention is not supported when context parallelism is enabled"
         if not is_context_parallel:
             return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)
@@ -1158,7 +1158,7 @@ def partition(config, mesh, arg_infos, result_infos):
         # Call base implementation for non-context parallel mesh to avoid unecessary work.
         is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
         assert (
-            is_context_parallel and config.window_size[0] == -1
+            not is_context_parallel or config.window_size[0] == -1
         ), "Sliding window attention is not supported when context parallelism is enabled"
         if not is_context_parallel:
             return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
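
Both hunks make the same one-line change: the conjunction is replaced with the implication it was meant to encode (not A or B, i.e. A implies B). The sketch below is a standalone illustration of that logic, not TE code; the check_swa_config helper and the window_size values are hypothetical, and window_size[0] == -1 is taken to mean an unbounded (non-sliding) window, as in the diff above.

    # Standalone sketch of the assertion fix (hypothetical helper, not TE code).

    def check_swa_config(is_context_parallel: bool, window_size: tuple) -> None:
        # Buggy predicate: true only when context parallelism (CP) is on AND
        # the window is unbounded, so every non-CP run tripped the assertion:
        #   assert is_context_parallel and window_size[0] == -1
        # Fixed predicate: the implication "CP enabled => window unbounded";
        # non-CP runs now pass regardless of window_size.
        assert (
            not is_context_parallel or window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"

    check_swa_config(is_context_parallel=False, window_size=(128, 0))  # passes: SWA without CP
    check_swa_config(is_context_parallel=True, window_size=(-1, -1))  # passes: CP, unbounded window
    # check_swa_config(is_context_parallel=True, window_size=(128, 0))  # raises: SWA with CP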
