# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+ from collections import defaultdict
+ from typing import Optional
+
import torch
- from torch import nn
+ from torch import nn, Tensor
+ from torch.nn.utils.clip_grad import _no_grad, _tensor_or_tensors
+ from torch.utils._foreach_utils import _device_has_foreach_support, _has_foreach_support
+ from torchtune.utils._logging import deprecated


+ @deprecated(msg="Please use `scale_grads_` instead.")
def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
    """
    Utility to scale the gradients of a model.
@@ -29,3 +36,70 @@ def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
            scaler = scaler.to(device)
        if p.grad is not None:
            p.grad *= scaler
+
+
+ @_no_grad
+ def scale_grads_(
+     parameters: _tensor_or_tensors,
+     scaler: torch.Tensor,
+     foreach: Optional[bool] = None,
+ ) -> None:
+     r"""Scale gradients of an iterable of parameters.
+
+     This function is equivalent to :func:`torch.Tensor.mul_` applied to each gradient.
+     Gradients are modified in place, multiplied by the specified scaler.
+
+     Args:
+         parameters (_tensor_or_tensors): an iterable of Tensors or a
+             single Tensor whose gradients will be scaled
+         scaler (torch.Tensor): multiplier used to scale the gradients
+         foreach (Optional[bool]): use the faster foreach-based implementation.
+             If ``None``, use the foreach implementation for CUDA and CPU native tensors and
+             silently fall back to the slow implementation for other device types.
+             Default: ``None``
+
+     Returns:
+         None
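+
+     Example:
+         >>> # Minimal illustrative sketch: halve the gradients of a small model
+         >>> # after a backward pass
+         >>> model = torch.nn.Linear(2, 2)
+         >>> model(torch.randn(4, 2)).sum().backward()
+         >>> scale_grads_(model.parameters(), torch.tensor(0.5))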
+     """
+     if isinstance(parameters, torch.Tensor):
+         parameters = [parameters]
+     else:
+         parameters = list(parameters)
+     _scale_grad_(parameters, scaler, foreach)
+
+
+ def _group_tensors_by_device(
+     tensors: list[torch.Tensor],
+ ) -> dict[torch.device, list[Tensor]]:
+     # Bucket tensors by device so each group can be handled with a single foreach call
+     ret = defaultdict(list)
+     for tensor in tensors:
+         ret[tensor.device].append(tensor)
+
+     return ret
+
+
+ @_no_grad
+ def _scale_grad_(
+     parameters: _tensor_or_tensors,
+     scaler: torch.Tensor,
+     foreach: Optional[bool] = None,
+ ) -> None:
+     if isinstance(parameters, torch.Tensor):
+         parameters = [parameters]
+     grads = [p.grad for p in parameters if p.grad is not None]
+     if len(grads) == 0:
+         return
+     grouped_grads = _group_tensors_by_device(grads)
+
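+     # For each device bucket, prefer the fused foreach multiply when the device
+     # supports it; otherwise fall back to scaling each gradient tensor one by one.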
+     for device, device_grads in grouped_grads.items():
+         if (foreach is None and _has_foreach_support(device_grads, device)) or (
+             foreach and _device_has_foreach_support(device)
+         ):
+             torch._foreach_mul_(device_grads, scaler.to(device))
+         elif foreach:
+             raise RuntimeError(
+                 f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+             )
+         else:
+             scaler_device = scaler.to(device)
+             for g in device_grads:
+                 g.mul_(scaler_device)