Stop using global mesh for custom_partitioning. (#1112)
Signed-off-by: Frederic Bastien <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>
nouiz and phu0ngng committed Aug 19, 2024
1 parent 350a4ff commit ee541e8
Showing 6 changed files with 22 additions and 23 deletions.
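
For context, all six files use the same `jax.experimental.custom_partitioning` pattern: JAX passes the `partition` callback the `Mesh` it is lowering under, and this commit threads that mesh into the reduction helpers instead of letting them read the thread-local global mesh (which is only populated inside a `with mesh:` context). Below is a minimal sketch of the pattern with the post-commit helper signature; `abs_max` and its shardings are hypothetical stand-ins for the Transformer Engine primitives, not code from this repository.

```python
import jax
import jax.numpy as jnp
from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import NamedSharding, PartitionSpec

from transformer_engine.jax.sharding import all_reduce_max_along_all_axes_except_PP


@custom_partitioning
def abs_max(x):
    # Reference (unpartitioned) semantics: global max of |x|.
    return jnp.max(jnp.abs(x))


def infer_sharding_from_operands(mesh, arg_infos, result_infos):
    # The scalar result is replicated on every device.
    return NamedSharding(mesh, PartitionSpec())


def partition(mesh, arg_infos, result_infos):
    arg_shardings = tuple(a.sharding for a in arg_infos)
    out_shardings = NamedSharding(mesh, PartitionSpec())

    def sharded_impl(x):
        local_amax = jnp.max(jnp.abs(x))
        # The fix in this commit: pass the mesh that `partition` received,
        # rather than having the helper read the thread-local global mesh.
        return all_reduce_max_along_all_axes_except_PP(local_amax, mesh)

    return mesh, sharded_impl, out_shardings, arg_shardings


abs_max.def_partition(
    infer_sharding_from_operands=infer_sharding_from_operands,
    partition=partition,
)
```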
2 changes: 1 addition & 1 deletion transformer_engine/jax/cpp_extensions/activation.py
```diff
@@ -467,7 +467,7 @@ def sharded_impl(x, amax, scale, scale_inv):
             local_x, local_amax = ActLuFp8Primitive.impl(
                 x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum
             )
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)

             return local_x, global_updated_amax

```
2 changes: 1 addition & 1 deletion transformer_engine/jax/cpp_extensions/attention.py
```diff
@@ -1011,7 +1011,7 @@ def sharded_impl(
             )
             global_dbias = local_dbias
             if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
-                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
+                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
             return local_dq, local_dk, local_dv, global_dbias

         return mesh, sharded_impl, out_shardings, arg_shardings
```
10 changes: 5 additions & 5 deletions transformer_engine/jax/cpp_extensions/normalization.py
```diff
@@ -533,8 +533,8 @@ def sharded_impl(dz, x, mu, rsigma, gamma):
             local_dx, local_dgamma, local_dbeta = LayerNormBwdPrimitive.impl(
                 dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
             )
-            global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma)
-            global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta)
+            global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh)
+            global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta, mesh)
             return local_dx, global_dgamma, global_dbeta

         return mesh, sharded_impl, out_shardings, arg_shardings
@@ -935,7 +935,7 @@ def partition(epsilon, mesh, arg_infos, result_infos):

         def sharded_impl(dz, x, rsigma, gamma):
             local_dx, local_dgamma = RmsNormBwdPrimitive.impl(dz, x, rsigma, gamma, epsilon=epsilon)
-            global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma)
+            global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh)
             return local_dx, global_dgamma

         return mesh, sharded_impl, out_shardings, arg_shardings
@@ -1228,7 +1228,7 @@ def sharded_impl(x, gamma, beta, amax, scale, scale_inv):
                 zero_centered_gamma=zero_centered_gamma,
                 epsilon=epsilon,
             )
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)

             return local_x, local_mu, local_rsigma, global_updated_amax

@@ -1481,7 +1481,7 @@ def sharded_impl(x, gamma, amax, scale, scale_inv):
             local_x, local_rsigma, local_amax = RmsNormFwdFp8Primitive.impl(
                 x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon
             )
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)

             return local_x, local_rsigma, global_updated_amax

```
2 changes: 1 addition & 1 deletion transformer_engine/jax/cpp_extensions/quantization.py
```diff
@@ -157,7 +157,7 @@ def sharded_impl(x, amax, scale, scale_inv):
             local_cx, local_updated_amax = CastFP8Primitive.impl(
                 x, amax, scale, scale_inv, out_dtype=out_dtype
             )
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh)

             return local_cx, global_updated_amax

```
12 changes: 6 additions & 6 deletions transformer_engine/jax/cpp_extensions/transpose.py
```diff
@@ -390,7 +390,7 @@ def sharded_impl(x, amax, scale, scale_inv):
                 static_axis_boundary=static_axis_boundary,
                 transpose_axis_boundary=transpose_axis_boundary,
             )
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh)

             return local_cx, local_cxt, global_updated_amax

@@ -646,8 +646,8 @@ def sharded_impl(dz, amax, scale, scale_inv):
                 static_axis_boundary=static_axis_boundary,
                 transpose_axis_boundary=transpose_axis_boundary,
             )
-            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
+            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
             return local_out, local_t_out, global_dbias, global_updated_amax

         return mesh, sharded_impl, out_shardings, arg_shardings
@@ -981,8 +981,8 @@ def sharded_impl(dz, x, amax, scale, scale_inv):
                     act_enum=act_enum,
                 )
             )
-            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
+            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
             return local_out, local_t_out, global_dbias, global_updated_amax

         return mesh, sharded_impl, out_shardings, arg_shardings
@@ -1225,7 +1225,7 @@ def sharded_impl(dz, x, amax, scale, scale_inv):
                 static_axis_boundary=static_axis_boundary,
                 act_enum=act_enum,
             )
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
             return local_out, local_t_out, global_updated_amax

         return mesh, sharded_impl, out_shardings, arg_shardings
```
17 changes: 8 additions & 9 deletions transformer_engine/jax/sharding.py
```diff
@@ -30,8 +30,7 @@
 W_JOINED_AXES = "nvte_w_joined"


-def _get_mesh_info(resource: str):
-    mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
+def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh):
     assert resource in mesh.axis_names, f"{resource} is not in the axis_names of Mesh {mesh}."
     return mesh.shape[resource], resource

@@ -132,12 +131,12 @@ def get_padded_spec(spec, ndim):
     return spec + (None,) * (ndim - len(spec))


-def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str):
+def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str, mesh: jax.sharding.Mesh):
     """
     A wrapper function to invoke lax.p* operations, like psum.
     """
     if mesh_resource is not None:
-        _, resource = _get_mesh_info(mesh_resource)
+        _, resource = _get_mesh_info(mesh_resource, mesh)
         return ops(x, resource)
     return x

@@ -201,22 +200,22 @@ def global_mesh_resource() -> MeshResource:
     return _GLOBAL_MESH_RESOURCE


-def all_reduce_sum_along_dp_fsdp(x: jnp.array):
+def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh):
     """
     All-Reduce (Sum) along DP and FSDP mesh axes.
     """
-    x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource)
-    return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource)
+    x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource, mesh)
+    return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh)


-def all_reduce_max_along_all_axes_except_PP(x: jnp.array):
+def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh):
     """
     All-Reduce (Max) along all mesh axes.
     """
     all_axes = get_all_mesh_axes()
     for axis in all_axes:
         if axis != global_mesh_resource().pp_resource:
-            x = lax_paral_op(x, jax.lax.pmax, axis)
+            x = lax_paral_op(x, jax.lax.pmax, axis, mesh)
     return x

```
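
With the new signatures, any caller must supply the mesh explicitly. A minimal usage sketch, assuming a one-axis mesh whose `"dp"` axis is registered as `dp_resource` through this module's `MeshResource` and `global_shard_guard`, and an array length divisible by the device count:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

from transformer_engine.jax.sharding import (
    MeshResource,
    all_reduce_sum_along_dp_fsdp,
    global_shard_guard,
)

mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))


def per_shard(x):
    local = x.sum()  # partial reduction over this device's shard
    # The mesh is passed explicitly; no enclosing `with mesh:` is required
    # for the helper to resolve the "dp" axis name.
    return all_reduce_sum_along_dp_fsdp(local, mesh)


with global_shard_guard(MeshResource(dp_resource="dp")):
    total = shard_map(per_shard, mesh=mesh, in_specs=P("dp"), out_specs=P())(
        jnp.arange(8.0)
    )
```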
