gradient_based.py
import torch

from collections import OrderedDict
from torchmeta.modules import MetaModule


def gradient_update_parameters(model,
                               loss,
                               params=None,
                               step_size=0.5,
                               first_order=False,
                               allow_unused=False):
    """Update of the meta-parameters with one step of gradient descent on the
    loss function.

    Parameters
    ----------
    model : `torchmeta.modules.MetaModule` instance
        The model.

    loss : `torch.Tensor` instance
        The value of the inner-loss. This is the result of passing the
        training dataset through the loss function.

    params : `collections.OrderedDict` instance, optional
        Dictionary containing the meta-parameters of the model. If `None`, then
        the values stored in `model.meta_named_parameters()` are used. This is
        useful for running multiple steps of gradient descent as the inner-loop.

    step_size : int, `torch.Tensor`, or `collections.OrderedDict` instance (default: 0.5)
        The step size in the gradient update. If an `OrderedDict`, then the
        keys must match the keys in `params`.

    first_order : bool (default: `False`)
        If `True`, then the first-order approximation of MAML is used.

    allow_unused : bool (default: `False`)
        If `True`, set `allow_unused` to `True` when computing gradients. This
        is useful, e.g., when your model has task-specific parameters.

    Returns
    -------
    updated_params : `collections.OrderedDict` instance
        Dictionary containing the updated meta-parameters of the model, with one
        gradient update w.r.t. the inner-loss.
    """
    if not isinstance(model, MetaModule):
        raise ValueError('The model must be an instance of `torchmeta.modules.'
                         'MetaModule`, got `{0}`'.format(type(model)))

    if params is None:
        params = OrderedDict(model.meta_named_parameters())

    # Gradients of the inner-loss w.r.t. the meta-parameters. The graph is
    # kept (create_graph=True) unless the first-order approximation is used,
    # so that the outer-loop can backpropagate through this update.
    grads = torch.autograd.grad(loss,
                                params.values(),
                                create_graph=not first_order,
                                allow_unused=allow_unused)

    updated_params = OrderedDict()

    if isinstance(step_size, (dict, OrderedDict)):
        # Per-parameter step sizes, keyed by parameter name.
        for (name, param), grad in zip(params.items(), grads):
            if grad is not None:
                updated_params[name] = param - step_size[name] * grad
            else:
                # Parameters with no gradient (allow_unused=True) are left unchanged.
                updated_params[name] = param
    else:
        # A single step size shared by all parameters.
        for (name, param), grad in zip(params.items(), grads):
            if grad is not None:
                updated_params[name] = param - step_size * grad
            else:
                updated_params[name] = param

    return updated_params
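

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes a
# MetaModule-based model built from torchmeta's MetaSequential/MetaLinear and
# that `gradient_update_parameters` is defined in this file; the toy data,
# step size, and loss below are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn.functional as F

    from torchmeta.modules import MetaLinear, MetaSequential

    # A small MetaModule model; its forward pass accepts a `params` keyword,
    # which is what allows evaluating the adapted parameters functionally.
    model = MetaSequential(MetaLinear(4, 16),
                           torch.nn.ReLU(),
                           MetaLinear(16, 1))

    # Toy support (inner-loop) and query (outer-loop) data for a single task.
    x_support, y_support = torch.randn(8, 4), torch.randn(8, 1)
    x_query, y_query = torch.randn(8, 4), torch.randn(8, 1)

    # Inner loop: one gradient step on the support loss.
    inner_loss = F.mse_loss(model(x_support), y_support)
    params = gradient_update_parameters(model,
                                        inner_loss,
                                        step_size=0.4,
                                        first_order=False)

    # Outer loop: evaluate the adapted parameters on the query set and
    # backpropagate through the inner update (second-order MAML).
    outer_loss = F.mse_loss(model(x_query, params=params), y_query)
    outer_loss.backward()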