diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py
index 56359646b1..47483c67ea 100644
--- a/transformer_engine/jax/cpp_extensions/activation.py
+++ b/transformer_engine/jax/cpp_extensions/activation.py
@@ -467,7 +467,7 @@ def sharded_impl(x, amax, scale, scale_inv):
             local_x, local_amax = ActLuFp8Primitive.impl(
                 x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum
             )
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)

             return local_x, global_updated_amax

diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py
index 76ccec363b..0cbf847dcd 100644
--- a/transformer_engine/jax/cpp_extensions/attention.py
+++ b/transformer_engine/jax/cpp_extensions/attention.py
@@ -1011,7 +1011,7 @@ def sharded_impl(
             )
             global_dbias = local_dbias
             if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
-                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
+                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
             return local_dq, local_dk, local_dv, global_dbias

         return mesh, sharded_impl, out_shardings, arg_shardings
diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py
index caf9272b02..e85f28a06a 100644
--- a/transformer_engine/jax/cpp_extensions/normalization.py
+++ b/transformer_engine/jax/cpp_extensions/normalization.py
@@ -533,8 +533,8 @@ def sharded_impl(dz, x, mu, rsigma, gamma):
             local_dx, local_dgamma, local_dbeta = LayerNormBwdPrimitive.impl(
                 dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
             )
-            global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma)
-            global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta)
+            global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh)
+            global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta, mesh)
             return local_dx, global_dgamma, global_dbeta

         return mesh, sharded_impl, out_shardings, arg_shardings
@@ -935,7 +935,7 @@ def partition(epsilon, mesh, arg_infos, result_infos):

         def sharded_impl(dz, x, rsigma, gamma):
             local_dx, local_dgamma = RmsNormBwdPrimitive.impl(dz, x, rsigma, gamma, epsilon=epsilon)
-            global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma)
+            global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh)
             return local_dx, global_dgamma

         return mesh, sharded_impl, out_shardings, arg_shardings
@@ -1228,7 +1228,7 @@ def sharded_impl(x, gamma, beta, amax, scale, scale_inv):
                 zero_centered_gamma=zero_centered_gamma,
                 epsilon=epsilon,
             )
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)

             return local_x, local_mu, local_rsigma, global_updated_amax

@@ -1481,7 +1481,7 @@ def sharded_impl(x, gamma, amax, scale, scale_inv):
             local_x, local_rsigma, local_amax = RmsNormFwdFp8Primitive.impl(
                 x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon
             )
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)

             return local_x, local_rsigma, global_updated_amax

diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py
index 2c529e71c8..48bf4d969a 100644
--- a/transformer_engine/jax/cpp_extensions/quantization.py
+++ b/transformer_engine/jax/cpp_extensions/quantization.py
@@ -157,7 +157,7 @@ def sharded_impl(x, amax, scale, scale_inv):
             local_cx, local_updated_amax = CastFP8Primitive.impl(
                 x, amax, scale, scale_inv, out_dtype=out_dtype
             )
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh)

             return local_cx, global_updated_amax

diff --git a/transformer_engine/jax/cpp_extensions/transpose.py b/transformer_engine/jax/cpp_extensions/transpose.py
index e503792dc0..963d7f09e8 100644
--- a/transformer_engine/jax/cpp_extensions/transpose.py
+++ b/transformer_engine/jax/cpp_extensions/transpose.py
@@ -390,7 +390,7 @@ def sharded_impl(x, amax, scale, scale_inv):
                 static_axis_boundary=static_axis_boundary,
                 transpose_axis_boundary=transpose_axis_boundary,
             )
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh)

             return local_cx, local_cxt, global_updated_amax

@@ -646,8 +646,8 @@ def sharded_impl(dz, amax, scale, scale_inv):
                 static_axis_boundary=static_axis_boundary,
                 transpose_axis_boundary=transpose_axis_boundary,
             )
-            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
+            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
             return local_out, local_t_out, global_dbias, global_updated_amax

         return mesh, sharded_impl, out_shardings, arg_shardings
@@ -981,8 +981,8 @@ def sharded_impl(dz, x, amax, scale, scale_inv):
                     act_enum=act_enum,
                 )
             )
-            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
+            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
             return local_out, local_t_out, global_dbias, global_updated_amax

         return mesh, sharded_impl, out_shardings, arg_shardings
@@ -1225,7 +1225,7 @@ def sharded_impl(dz, x, amax, scale, scale_inv):
                 static_axis_boundary=static_axis_boundary,
                 act_enum=act_enum,
             )
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
+            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
             return local_out, local_t_out, global_updated_amax

         return mesh, sharded_impl, out_shardings, arg_shardings
diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py
index c0b60fe61e..586e1a70c9 100644
--- a/transformer_engine/jax/sharding.py
+++ b/transformer_engine/jax/sharding.py
@@ -30,8 +30,7 @@
 W_JOINED_AXES = "nvte_w_joined"


-def _get_mesh_info(resource: str):
-    mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
+def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh):
     assert resource in mesh.axis_names, f"{resource} is not in the axis_names of Mesh {mesh}."
     return mesh.shape[resource], resource

@@ -132,12 +131,12 @@ def get_padded_spec(spec, ndim):
     return spec + (None,) * (ndim - len(spec))


-def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str):
+def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str, mesh: jax.sharding.Mesh):
     """
     A wrapper function to invoke lax.p* operations, like psum.
     """
     if mesh_resource is not None:
-        _, resource = _get_mesh_info(mesh_resource)
+        _, resource = _get_mesh_info(mesh_resource, mesh)
         return ops(x, resource)
     return x

@@ -201,22 +200,22 @@ def global_mesh_resource() -> MeshResource:
     return _GLOBAL_MESH_RESOURCE


-def all_reduce_sum_along_dp_fsdp(x: jnp.array):
+def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh):
     """
     All-Reduce (Sum) along DP and FSDP mesh axes.
     """
-    x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource)
-    return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource)
+    x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource, mesh)
+    return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh)


-def all_reduce_max_along_all_axes_except_PP(x: jnp.array):
+def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh):
     """
     All-Reduce (Max) along all mesh axes.
     """
     all_axes = get_all_mesh_axes()
     for axis in all_axes:
         if axis != global_mesh_resource().pp_resource:
-            x = lax_paral_op(x, jax.lax.pmax, axis)
+            x = lax_paral_op(x, jax.lax.pmax, axis, mesh)
     return x