Description
It seems that, depending on the XLA_DISABLE_FUNCTIONALIZATION flag and whether ZeRO-1 is enabled, the loss either fails to converge or we OOM.
System info
aws-neuronx-runtime-discovery==2.9
libneuronxla==2.0.2335
neuronx-cc==2.14.213.0+013d129b
neuronx-distributed==0.8.0
torch==2.1.2
torch-neuronx==2.1.2.2.2.0
torch-xla==2.1.3
torchvision==0.16.2
I ran the same training job with all 4 combinations of XLA_DISABLE_FUNCTIONALIZATION = 0 | 1 and ZeRO-1 enabled / disabled (a minimal sketch of how each combination is toggled follows the results below):
XLA_DISABLE_FUNCTIONALIZATION=0 and ZeRO-1
In this case the loss diverges.

Note: since I am using Optimum Neuron, I am not sure whether this comes from my integration of the ZeroRedundancyOptimizer or whether it is an actual bug on your end and/or in torch_xla.
XLA_DISABLE_FUNCTIONALIZATION=1 and ZeRO-1
In this case the loss diverges to inf.

XLA_DISABLE_FUNCTIONALIZATION=0 and regular optimizer
In this case we OOM.
XLA_DISABLE_FUNCTIONALIZATION=1 and regular optimizer
The loss converges.
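
For reference, here is a minimal sketch of how each combination is toggled, assuming torch_xla's ZeroRedundancyOptimizer for the ZeRO-1 runs. The model, learning rate, and the USE_ZERO1 switch are placeholders, not the actual Optimum Neuron training code:

```python
import os

# Set the flag before torch_xla initializes (assumption: it is read once
# at runtime init, so exporting it later has no effect).
os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1"  # "0" for the other runs

import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

device = xm.xla_device()
model = torch.nn.Linear(1024, 1024).to(device)  # placeholder for the real model

USE_ZERO1 = True  # placeholder switch, toggled per run

if USE_ZERO1:
    # ZeRO-1: optimizer states are sharded across the data-parallel ranks;
    # the wrapper reduce-scatters gradients and all-gathers updated params.
    optimizer = ZeroRedundancyOptimizer(model.parameters(), torch.optim.AdamW, lr=1e-4)
else:
    # Regular, unsharded optimizer.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
```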
