Run init_weights under no_grad (#747)
The initializations in `init_weights` can create gradients. This is
almost never intended.

The alternative would be to apply the `@torch.no_grad()` decorator to
`model.init_weights` directly. That would move the responsibility to the
model writer. I can change to that approach if it's preferred.
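For illustration, a minimal sketch of the decorator alternative mentioned above. The `MyModel` class and its layer are hypothetical, not from this repository; only the `init_weights(buffer_device=...)` signature mirrors the call site in `train.py`:

```python
import torch

class MyModel(torch.nn.Module):  # hypothetical model, for illustration only
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    @torch.no_grad()  # decorator form: initialization runs without autograd tracking
    def init_weights(self, buffer_device=None):
        # buffer_device is accepted only to match the call site in train.py
        torch.nn.init.trunc_normal_(self.linear.weight, std=0.02)
        torch.nn.init.zeros_(self.linear.bias)
```

With the decorator, every model must remember to apply it; the commit instead wraps the call site once in `train.py`, so the guarantee holds regardless of how a given model implements `init_weights`.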
carmocca authored Dec 17, 2024
1 parent 4f7f883 commit 5ce8a0c
Showing 1 changed file with 4 additions and 2 deletions.
train.py:

```diff
@@ -156,13 +156,15 @@ def loss_fn(pred, labels):
         # apply SPMD-style PT-D techniques
         models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
         m.to_empty(device=init_device)
-        m.init_weights(buffer_device=buffer_device)
+        with torch.no_grad():
+            m.init_weights(buffer_device=buffer_device)
         m.train()
     else:
         # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
         models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
         model.to_empty(device=init_device)
-        model.init_weights(buffer_device=buffer_device)
+        with torch.no_grad():
+            model.init_weights(buffer_device=buffer_device)
         model.train()

     model_parts = [model]
```
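To see why the guard matters, here is a small standalone example (not part of the commit): any tensor computed from a `Parameter` outside `no_grad` records autograd history, so a buffer built during initialization would silently carry a `grad_fn`:

```python
import torch

w = torch.nn.Parameter(torch.randn(4, 4))

freqs = w.cumsum(dim=0)       # derived tensor records autograd history
print(freqs.requires_grad)    # True: a buffer like this would leak autograd state

with torch.no_grad():
    freqs = w.cumsum(dim=0)   # same computation, no history recorded
print(freqs.requires_grad)    # False: initialization stays out of the graph
```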
