Commit 110258b

llama4 distributed: scale_grads with foreach
ghstack-source-id: 8072d3f
Pull Request resolved: #2624
1 parent 173a6fe commit 110258b

File tree

4 files changed, +92 -3 lines

recipes/configs/llama4/scout_17B_16E_full.yaml

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ compile: False
 # model: True
 # loss: True
 # optimizer_step: False
+# scale_grads: True
 
 # Reduced precision
 dtype: bf16
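
The recipe (recipes/full_finetune_distributed.py, below) reads this compile section either as a single bool or as a per-component mapping, with the new scale_grads key defaulting to True in the mapping form. A minimal sketch of that interpretation with OmegaConf; the variable names here are illustrative, not the recipe's own:

from omegaconf import DictConfig, OmegaConf

# Mapping form of the config section above.
cfg = OmegaConf.create(
    {"compile": {"model": True, "loss": True, "optimizer_step": False, "scale_grads": True}}
)
compile_cfg = cfg.compile

# A bare bool would toggle every component at once.
compile_bool = isinstance(compile_cfg, bool) and bool(compile_cfg)
compile_scale_grads = compile_bool
if isinstance(compile_cfg, DictConfig):
    # The new knob: defaults to True when the mapping form is used.
    compile_scale_grads = compile_cfg.get("scale_grads", True)

print(compile_scale_grads)  # True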

recipes/full_finetune_distributed.py

Lines changed: 14 additions & 1 deletion
@@ -313,10 +313,19 @@ def setup(self, cfg: DictConfig) -> None:
         self._compile_model = compile_bool
         self._compile_loss = compile_bool
         self._compile_optimizer_step = compile_bool
+        self._compile_scale_grads = compile_bool
         if isinstance(compile, DictConfig):
             self._compile_model = compile.get("model", True)
             self._compile_loss = compile.get("loss", True)
             self._compile_optimizer_step = compile.get("optimizer_step", False)
+            self._compile_scale_grads = compile.get("scale_grads", True)
+
+        # This indirection is needed to apply torch.compile to scale_grads step.
+        self._grad_scaler = training.scale_grads_
+        if self._compile_scale_grads:
+            self._grad_scaler = torch.compile(
+                self._grad_scaler, backend=self._compile_backend
+            )
 
         self._model = self._setup_model(
             cfg_model=cfg.model,
@@ -932,8 +941,12 @@ def train(self) -> None:
                     torch.distributed.all_reduce(num_tokens)
                     # This will ensure that the logged loss matches what we're optimizing
                     torch.distributed.all_reduce(running_loss)
+
                     # Manually scale the gradients from unnormalized loss by total # of tokens
-                    training.scale_grads(self._model, self.dp_degree / num_tokens)
+                    self._grad_scaler(
+                        self._model.parameters(), self.dp_degree / num_tokens
+                    )
+
                     if self._clip_grad_norm is not None:
                         grad_norm = torch.nn.utils.clip_grad_norm_(
                             self._model.parameters(),
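
Outside the recipe, the new call pattern looks roughly like the sketch below (an illustration, not part of the commit): scale_grads_ takes the parameter iterable rather than the module, which is what lets the recipe bind it once, optionally wrap it with torch.compile, and reuse it on every step.

import torch
from torchtune import training

# Toy model with gradients to scale (illustrative only).
model = torch.nn.Linear(8, 8)
model(torch.randn(2, 8)).sum().backward()

grad_scaler = training.scale_grads_
compile_scale_grads = False  # mirrors the compile.scale_grads flag
if compile_scale_grads:
    grad_scaler = torch.compile(grad_scaler)

# Same normalization as the recipe: scale unnormalized grads by dp_degree / num_tokens.
num_tokens = torch.tensor(16.0)
dp_degree = 1
grad_scaler(model.parameters(), dp_degree / num_tokens)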

torchtune/training/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -25,7 +25,7 @@
     shard_model,
     validate_no_params_on_meta_device,
 )
-from torchtune.training._grad_scaler import scale_grads
+from torchtune.training._grad_scaler import scale_grads, scale_grads_
 from torchtune.training._model_util import disable_dropout
 from torchtune.training._profiler import (
     DEFAULT_PROFILE_DIR,
@@ -139,6 +139,7 @@
     "OffloadActivations",
     "FormattedCheckpointFiles",
     "scale_grads",
+    "scale_grads_",
     "get_distributed_backend",
     "disable_dropout",
     "DATALOADER_KEY",

torchtune/training/_grad_scaler.py

Lines changed: 75 additions & 1 deletion
@@ -4,10 +4,17 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from collections import defaultdict
+from typing import Optional
+
 import torch
-from torch import nn
+from torch import nn, Tensor
+from torch.nn.utils.clip_grad import _no_grad, _tensor_or_tensors
+from torch.utils._foreach_utils import _device_has_foreach_support, _has_foreach_support
+from torchtune.utils._logging import deprecated
 
 
+@deprecated(msg="Please use `scale_grads_` instead.")
 def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
     """
     Utility to scale the gradients of a model.
@@ -29,3 +36,70 @@ def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
             scaler = scaler.to(device)
         if p.grad is not None:
             p.grad *= scaler
+
+
+@_no_grad
+def scale_grads_(
+    parameters: _tensor_or_tensors,
+    scaler: torch.Tensor,
+    foreach: Optional[bool] = None,
+) -> None:
+    r"""Scale gradients of iterable parameters.
+
+    This function is equivalent to :func:`torch.mul_` applied to each parameter.
+    Gradients are modified in-place, multiplying by specified scaler.
+
+    Args:
+        parameters (_tensor_or_tensors): an iterable of Tensors or a
+            single Tensor that will have gradients scaled
+        scaler (torch.Tensor): multiplier to scale gradients
+        foreach (Optional[bool]): use the faster foreach-based implementation.
+            If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
+            fall back to the slow implementation for other device types.
+            Default: ``None``
+    Returns:
+        None
+    """
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    else:
+        parameters = list(parameters)
+    _scale_grad_(parameters, scaler, foreach)
+
+
+def _group_tensors_by_device(
+    tensors: list[torch.Tensor],
+) -> dict[torch.device, list[Tensor]]:
+    ret = defaultdict(list)
+    for i, tensor in enumerate(tensors):
+        ret[tensor.device].append(tensor)
+
+    return ret
+
+
+@_no_grad
+def _scale_grad_(
+    parameters: _tensor_or_tensors,
+    scaler: torch.Tensor,
+    foreach: Optional[bool] = None,
+) -> None:
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    grads = [p.grad for p in parameters if p.grad is not None]
+    if len(grads) == 0:
+        return
+    grouped_grads = _group_tensors_by_device(grads)
+
+    for device, device_grads in grouped_grads.items():
+        if (foreach is None and _has_foreach_support(device_grads, device)) or (
+            foreach and _device_has_foreach_support(device)
+        ):
+            torch._foreach_mul_(device_grads, scaler.to(device))
+        elif foreach:
+            raise RuntimeError(
+                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+            )
+        else:
+            scaler_device = scaler.to(device)
+            for g in device_grads:
+                g.mul_(scaler_device)
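
The docstring above spells out the foreach fallback rules; a quick standalone check of that behaviour (a sketch, not shipped with the commit) could look like this:

import torch
from torchtune.training import scale_grads_

params = [torch.nn.Parameter(torch.randn(4)) for _ in range(3)]
for p in params:
    p.grad = torch.ones_like(p)

# Default path: foreach is used when the device supports it (CPU/CUDA here),
# so all grads are multiplied via a single torch._foreach_mul_ call.
scale_grads_(params, torch.tensor(0.5))
assert all(torch.allclose(p.grad, torch.full_like(p, 0.5)) for p in params)

# Explicitly disabling foreach falls back to per-tensor in-place mul_,
# producing the same result.
scale_grads_(params, torch.tensor(2.0), foreach=False)
assert all(torch.allclose(p.grad, torch.ones_like(p)) for p in params)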
