Stop using global mesh for custom_partitioning. #1112

Merged · 2 commits · Aug 19, 2024
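In short: TE's JAX sharding helpers previously resolved the device mesh from JAX's global thread-local state (`_PXLA_THREAD_RESOURCES.env.physical_mesh`), which is only populated inside a `with mesh:` context. The `partition` callback of `custom_partitioning` already receives the correct mesh, so this change threads that mesh explicitly through `lax_paral_op`, `all_reduce_sum_along_dp_fsdp`, and `all_reduce_max_along_all_axes_except_PP`, and updates every call site. A minimal sketch of the resulting pattern (the callback is simplified and the sharding construction elided; only the imported helper is TE's real API):

    from transformer_engine.jax.sharding import all_reduce_max_along_all_axes_except_PP

    def partition(mesh, arg_infos, result_infos):
        out_shardings = arg_shardings = ...  # sharding construction elided

        def sharded_impl(local_amax):
            # `mesh` is the argument custom_partitioning passes in,
            # not a value read from a global context.
            return all_reduce_max_along_all_axes_except_PP(local_amax, mesh)

        return mesh, sharded_impl, out_shardings, arg_shardings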
2 changes: 1 addition & 1 deletion transformer_engine/jax/cpp_extensions/activation.py
@@ -467,7 +467,7 @@ def sharded_impl(x, amax, scale, scale_inv):
             local_x, local_amax = ActLuFp8Primitive.impl(
                 x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum
             )
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
 
             return local_x, global_updated_amax
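Conceptually, the helper these call sites use max-reduces the local FP8 amax over every mesh axis except the pipeline-parallel one, so all data- and tensor-parallel replicas agree on the scaling statistic. A hedged sketch of that semantics under an assumed axis naming (TE's real helper instead iterates the registered mesh resources, as the sharding.py hunks below show):

    import jax
    import jax.numpy as jnp

    def amax_all_reduce_except_pp(local_amax: jnp.ndarray,
                                  mesh: jax.sharding.Mesh,
                                  pp_axis: str = "pp") -> jnp.ndarray:
        # Must run inside a sharded computation (e.g. a sharded_impl like the
        # one above); max-reduces over every mesh axis except pipeline parallel.
        axes = tuple(name for name in mesh.axis_names if name != pp_axis)
        return jax.lax.pmax(local_amax, axes) if axes else local_amax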
2 changes: 1 addition & 1 deletion transformer_engine/jax/cpp_extensions/attention.py
@@ -1011,7 +1011,7 @@ def sharded_impl(
             )
             global_dbias = local_dbias
             if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
-                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
+                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
             return local_dq, local_dk, local_dv, global_dbias
 
         return mesh, sharded_impl, out_shardings, arg_shardings
10 changes: 5 additions & 5 deletions transformer_engine/jax/cpp_extensions/normalization.py
@@ -533,8 +533,8 @@ def sharded_impl(dz, x, mu, rsigma, gamma):
             local_dx, local_dgamma, local_dbeta = LayerNormBwdPrimitive.impl(
                 dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
             )
-            global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma)
-            global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta)
+            global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh)
+            global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta, mesh)
             return local_dx, global_dgamma, global_dbeta
 
         return mesh, sharded_impl, out_shardings, arg_shardings

@@ -935,7 +935,7 @@ def partition(epsilon, mesh, arg_infos, result_infos):
 
         def sharded_impl(dz, x, rsigma, gamma):
             local_dx, local_dgamma = RmsNormBwdPrimitive.impl(dz, x, rsigma, gamma, epsilon=epsilon)
-            global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma)
+            global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh)
             return local_dx, global_dgamma
 
         return mesh, sharded_impl, out_shardings, arg_shardings

@@ -1228,7 +1228,7 @@ def sharded_impl(x, gamma, beta, amax, scale, scale_inv):
                 zero_centered_gamma=zero_centered_gamma,
                 epsilon=epsilon,
             )
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
 
             return local_x, local_mu, local_rsigma, global_updated_amax

@@ -1481,7 +1481,7 @@ def sharded_impl(x, gamma, amax, scale, scale_inv):
             local_x, local_rsigma, local_amax = RmsNormFwdFp8Primitive.impl(
                 x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon
             )
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
 
             return local_x, local_rsigma, global_updated_amax
2 changes: 1 addition & 1 deletion transformer_engine/jax/cpp_extensions/quantization.py
@@ -157,7 +157,7 @@ def sharded_impl(x, amax, scale, scale_inv):
             local_cx, local_updated_amax = CastFP8Primitive.impl(
                 x, amax, scale, scale_inv, out_dtype=out_dtype
             )
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh)
 
             return local_cx, global_updated_amax
12 changes: 6 additions & 6 deletions transformer_engine/jax/cpp_extensions/transpose.py
@@ -390,7 +390,7 @@ def sharded_impl(x, amax, scale, scale_inv):
                 static_axis_boundary=static_axis_boundary,
                 transpose_axis_boundary=transpose_axis_boundary,
             )
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh)
 
             return local_cx, local_cxt, global_updated_amax

@@ -646,8 +646,8 @@ def sharded_impl(dz, amax, scale, scale_inv):
                 static_axis_boundary=static_axis_boundary,
                 transpose_axis_boundary=transpose_axis_boundary,
             )
-            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
+            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
             return local_out, local_t_out, global_dbias, global_updated_amax
 
         return mesh, sharded_impl, out_shardings, arg_shardings

@@ -981,8 +981,8 @@ def sharded_impl(dz, x, amax, scale, scale_inv):
                     act_enum=act_enum,
                 )
             )
-            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
+            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
             return local_out, local_t_out, global_dbias, global_updated_amax
 
         return mesh, sharded_impl, out_shardings, arg_shardings

@@ -1225,7 +1225,7 @@ def sharded_impl(dz, x, amax, scale, scale_inv):
                 static_axis_boundary=static_axis_boundary,
                 act_enum=act_enum,
             )
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
             return local_out, local_t_out, global_updated_amax
 
         return mesh, sharded_impl, out_shardings, arg_shardings
17 changes: 8 additions & 9 deletions transformer_engine/jax/sharding.py
@@ -30,8 +30,7 @@
 W_JOINED_AXES = "nvte_w_joined"
 
 
-def _get_mesh_info(resource: str):
-    mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
+def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh):
     assert resource in mesh.axis_names, f"{resource} is not in the axis_names of Mesh {mesh}."
     return mesh.shape[resource], resource
 
@@ -132,12 +131,12 @@ def get_padded_spec(spec, ndim):
     return spec + (None,) * (ndim - len(spec))
 
 
-def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str):
+def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str, mesh: jax.sharding.Mesh):
     """
     A wrapper function to invoke lax.p* operations, like psum.
     """
     if mesh_resource is not None:
-        _, resource = _get_mesh_info(mesh_resource)
+        _, resource = _get_mesh_info(mesh_resource, mesh)
         return ops(x, resource)
     return x
 
@@ -201,22 +200,22 @@ def global_mesh_resource() -> MeshResource:
     return _GLOBAL_MESH_RESOURCE
 
 
-def all_reduce_sum_along_dp_fsdp(x: jnp.array):
+def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh):
     """
     All-Reduce (Sum) along DP and FSDP mesh axes.
     """
-    x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource)
-    return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource)
+    x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource, mesh)
+    return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh)
 
 
-def all_reduce_max_along_all_axes_except_PP(x: jnp.array):
+def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh):
     """
     All-Reduce (Max) along all mesh axes.
     """
     all_axes = get_all_mesh_axes()
     for axis in all_axes:
         if axis != global_mesh_resource().pp_resource:
-            x = lax_paral_op(x, jax.lax.pmax, axis)
+            x = lax_paral_op(x, jax.lax.pmax, axis, mesh)
     return x
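To see the whole pattern end to end, here is a self-contained toy example (not TE code: the op, `psum_if_present`, and the "dp" axis name are all assumptions) showing how a `partition` callback forwards its `mesh` argument into a collective, so the op partitions correctly even without an ambient `with mesh:` block:

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.experimental.custom_partitioning import custom_partitioning
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    def psum_if_present(x, axis_name, mesh):
        # Toy analogue of lax_paral_op(x, jax.lax.psum, resource, mesh):
        # reduce only when the named axis exists in the mesh we were given.
        return jax.lax.psum(x, axis_name) if axis_name in mesh.axis_names else x

    @custom_partitioning
    def global_sum(x):
        return jnp.sum(x)  # single-device reference implementation

    def infer_sharding(mesh, arg_infos, result_infos):
        return NamedSharding(mesh, P())  # the scalar result is replicated

    def partition(mesh, arg_infos, result_infos):
        arg_shardings = (NamedSharding(mesh, P("dp")),)
        out_shardings = NamedSharding(mesh, P())

        def sharded_impl(x):
            # The pattern this PR adopts: `mesh` is the callback argument,
            # not a value fished out of a global thread-local.
            return psum_if_present(jnp.sum(x), "dp", mesh)

        return mesh, sharded_impl, out_shardings, arg_shardings

    global_sum.def_partition(
        infer_sharding_from_operands=infer_sharding, partition=partition
    )

    mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))
    x = jax.device_put(jnp.arange(8.0), NamedSharding(mesh, P("dp")))
    print(jax.jit(global_sum)(x))  # prints 28.0; no enclosing `with mesh:` needed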