7x7 nnx.Conv using float32 parameter dtype overflows(?) to nan when sharded #24848
Labels: bug (Something isn't working)
Description
A sufficiently large (7x7) `nnx.Conv` with `float32` parameter dtype, when sharded across multiple devices, generates `nan`, seemingly due to overflow. The `nan` is avoided by making any one of the following changes:
- `param_dtype='float64'`
- running inside `with jax.experimental.enable_x64()`
float32 is sufficient for a 7x7 convolution when not sharded, suggesting that it ought to work when sharded as well.
The issue can be worked around, but every workaround is unsatisfactory in some way: it either reduces the size of the convolution, doubles the memory usage, or restricts training to a single device.
Minimal examples. This version works because the computation is moved to float64:
This version fails due to the combination of sharding, float32, and a 7x7 convolution:
This version works due to using a smaller convolution, still with float32:
This version works due to running on a single device, but with float32 and a 7x7 convolution:
System info (python version, jaxlib version, accelerator, etc.)