From 992034e58cd0494d2e83fa024b4d9f774e10e623 Mon Sep 17 00:00:00 2001
From: rees-c
Date: Tue, 7 Dec 2021 14:17:51 -0600
Subject: [PATCH] - added `allow_unused` arg to `gradient_update_parameters`
 for models with task-specific parameters

---
 torchmeta/utils/gradient_based.py | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/torchmeta/utils/gradient_based.py b/torchmeta/utils/gradient_based.py
index fe855cb..9e41dd8 100644
--- a/torchmeta/utils/gradient_based.py
+++ b/torchmeta/utils/gradient_based.py
@@ -8,7 +8,8 @@ def gradient_update_parameters(model,
                                loss,
                                params=None,
                                step_size=0.5,
-                               first_order=False):
+                               first_order=False,
+                               allow_unused=False):
     """Update of the meta-parameters with one step of gradient descent on the
     loss function.
 
@@ -33,6 +34,10 @@ def gradient_update_parameters(model,
     first_order : bool (default: `False`)
         If `True`, then the first order approximation of MAML is used.
 
+    allow_unused : bool (default: `False`)
+        If `True`, set `allow_unused` to `True` when computing gradients. This
+        is useful, e.g., when your model has task-specific parameters.
+
     Returns
     -------
     updated_params : `collections.OrderedDict` instance
@@ -48,16 +53,23 @@ def gradient_update_parameters(model,
 
     grads = torch.autograd.grad(loss,
                                 params.values(),
-                                create_graph=not first_order)
+                                create_graph=not first_order,
+                                allow_unused=allow_unused)
 
     updated_params = OrderedDict()
 
     if isinstance(step_size, (dict, OrderedDict)):
         for (name, param), grad in zip(params.items(), grads):
-            updated_params[name] = param - step_size[name] * grad
+            if grad is not None:
+                updated_params[name] = param - step_size[name] * grad
+            else:
+                updated_params[name] = param
 
     else:
         for (name, param), grad in zip(params.items(), grads):
-            updated_params[name] = param - step_size * grad
+            if grad is not None:
+                updated_params[name] = param - step_size * grad
+            else:
+                updated_params[name] = param
 
     return updated_params
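
Usage note (not part of the patch): below is a minimal sketch of the use case
this flag targets, assuming the patch above is applied. `gradient_update_parameters`,
`MetaModule`, and `MetaLinear` are existing torchmeta APIs; the `MultiHeadModel`
class, its `head_a`/`head_b` heads, and the tensor shapes are hypothetical and
purely for illustration. With the stock `allow_unused=False`, `torch.autograd.grad`
raises an error here, because the parameters of the unused head are not part of
the graph for the inner loss.

import torch
import torch.nn.functional as F

from torchmeta.modules import MetaModule, MetaLinear
from torchmeta.utils.gradient_based import gradient_update_parameters


class MultiHeadModel(MetaModule):
    """Hypothetical model with a shared encoder and task-specific heads."""

    def __init__(self):
        super().__init__()
        self.encoder = MetaLinear(10, 16)
        # Task-specific heads; only one is used in a given inner-loop step.
        self.head_a = MetaLinear(16, 1)
        self.head_b = MetaLinear(16, 1)

    def forward(self, x, head='head_a', params=None):
        features = torch.relu(
            self.encoder(x, params=self.get_subdict(params, 'encoder')))
        head_module = getattr(self, head)
        return head_module(features, params=self.get_subdict(params, head))


model = MultiHeadModel()
x = torch.randn(8, 10)
y = torch.randn(8, 1)

# The inner loss only involves `encoder` and `head_a`, so the gradients of
# `head_b`'s parameters come back as `None`.
inner_loss = F.mse_loss(model(x, head='head_a'), y)

# With allow_unused=True, parameters that received no gradient are simply
# carried over unchanged instead of triggering an error in torch.autograd.grad.
updated_params = gradient_update_parameters(model, inner_loss,
                                            allow_unused=True)

outer_loss = F.mse_loss(model(x, head='head_a', params=updated_params), y)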