XLA_DISABLE_FUNCTIONALIZATION=0 with ZeRO-1 diverges for Mistral on NxD #26

Open

@michaelbenayoun

It seems that, depending on the XLA_DISABLE_FUNCTIONALIZATION flag and on whether ZeRO-1 is enabled, the loss either fails to converge or we OOM.

System info

aws-neuronx-runtime-discovery==2.9
libneuronxla==2.0.2335
neuronx-cc==2.14.213.0+013d129b
neuronx-distributed==0.8.0
torch==2.1.2
torch-neuronx==2.1.2.2.2.0
torch-xla==2.1.3
torchvision==0.16.2

I ran the same training job with four settings: XLA_DISABLE_FUNCTIONALIZATION = 0 | 1, each with ZeRO-1 enabled or disabled. The results for each combination are below, after a sketch of how the two knobs are toggled.
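For context, here is a minimal sketch of the two toggles; the optimizer choice, learning rate, and helper name are placeholders, not the actual Optimum Neuron code:

```python
import os

# XLA_DISABLE_FUNCTIONALIZATION must be set before torch_xla is imported
# so that it takes effect at initialization.
os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "0"  # or "1"

import torch
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

USE_ZERO1 = True  # toggled between runs

def build_optimizer(model):
    if USE_ZERO1:
        # ZeRO-1: optimizer states (and gradient reduction) are sharded
        # across the data-parallel ranks.
        return ZeroRedundancyOptimizer(
            model.parameters(),
            torch.optim.AdamW,
            lr=1e-4,  # placeholder hyperparameter
        )
    # Regular, unsharded optimizer.
    return torch.optim.AdamW(model.parameters(), lr=1e-4)
```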

XLA_DISABLE_FUNCTIONALIZATION=0 and ZeRO-1

In this case the loss diverges.

[Screenshot: loss curve, 2024-07-17 15:45:51]

Note: since I am using Optimum Neuron, I am not sure whether this comes from my integration of the ZeroRedundancyOptimizer or whether it is an actual bug on your end and/or in torch_xla.
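For reference, my understanding of the expected per-step pattern with ZeRO-1 on XLA is sketched below (assuming an HF-style model that returns a .loss, and torch_xla's ZeroRedundancyOptimizer, which reduce-scatters gradients inside step() and therefore replaces xm.optimizer_step() rather than being wrapped by it):

```python
import torch_xla.core.xla_model as xm

# Sketch of one training step with ZeRO-1 (not the actual Optimum Neuron loop).
outputs = model(**batch)   # model and batch assumed to live on the XLA device
outputs.loss.backward()
optimizer.step()           # reduce-scatters grads, updates the local shard,
                           # then all-gathers the updated parameters
optimizer.zero_grad()
xm.mark_step()             # cut and execute the XLA graph for this step
```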

XLA_DISABLE_FUNCTIONALIZATION=1 and ZeRO-1

In this case the loss diverges to inf.

[Screenshot: loss curve, 2024-07-17 15:36:27]

XLA_DISABLE_FUNCTIONALIZATION=0 and regular optimizer

In this case we OOM.

XLA_DISABLE_FUNCTIONALIZATION=1 and regular optimizer

The loss converges.

[Screenshot: loss curve, 2024-07-17 15:15:19]
