diff --git a/train.py b/train.py index 3b157ad1..8dbe80f5 100644 --- a/train.py +++ b/train.py @@ -130,8 +130,9 @@ def loss_fn(pred, labels): pred.flatten(0, 1).float(), labels.flatten(0, 1) ) - if job_config.training.compile: - loss_fn = torch.compile(loss_fn) + # TODO: compiling the loss function causes CUDA errors, so it is disabled for now + # if job_config.training.compile: + # loss_fn = torch.compile(loss_fn) # move sharded model to CPU/GPU and initialize weights via DTensor if job_config.checkpoint.create_seed_checkpoint: