turning off compile on loss function
ghstack-source-id: 7d85c79184056bbc12e37d628796065711f2045f
Pull Request resolved: #755
tianyu-l committed Dec 19, 2024
1 parent 6274377 commit f4bb54a
Showing 1 changed file with 3 additions and 2 deletions.
train.py: 3 additions, 2 deletions
```diff
@@ -130,8 +130,9 @@ def loss_fn(pred, labels):
             pred.flatten(0, 1).float(), labels.flatten(0, 1)
         )

-    if job_config.training.compile:
-        loss_fn = torch.compile(loss_fn)
+    # TODO: compiling loss function causes CUDA errors, turning off for now
+    # if job_config.training.compile:
+    #     loss_fn = torch.compile(loss_fn)

     # move sharded model to CPU/GPU and initialize weights via DTensor
     if job_config.checkpoint.create_seed_checkpoint:
```
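For context, the pattern this commit disables can be sketched as follows. This is a minimal, self-contained illustration, not the actual train.py: `build_loss_fn` and the `compile_loss` flag are hypothetical stand-ins for the `job_config.training.compile` path, and the shapes are made up for the example.

```python
import torch
import torch.nn.functional as F


def build_loss_fn(compile_loss: bool = False):
    """Return a cross-entropy loss like the one in train.py, optionally compiled.

    `build_loss_fn` and `compile_loss` are hypothetical names for illustration;
    train.py reads the flag from job_config.training.compile.
    """

    def loss_fn(pred, labels):
        # Flatten (batch, seq, vocab) logits and (batch, seq) labels so
        # cross_entropy sees a 2D prediction tensor and 1D targets.
        return F.cross_entropy(
            pred.flatten(0, 1).float(), labels.flatten(0, 1)
        )

    if compile_loss:
        # This is the step the commit comments out: torch.compile on the
        # loss function was causing CUDA errors at the time.
        loss_fn = torch.compile(loss_fn)
    return loss_fn


# Eager path, matching the state of train.py after this commit.
loss_fn = build_loss_fn(compile_loss=False)
pred = torch.randn(2, 4, 8)            # (batch, seq, vocab)
labels = torch.randint(0, 8, (2, 4))   # (batch, seq)
loss = loss_fn(pred, labels)
```

Since `torch.compile` returns a wrapped callable with the same signature, re-enabling compilation later only requires uncommenting the guard, which is presumably why the commit comments the lines out rather than deleting them.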
