Support gradient computation in multiple forward passes #4

Closed
wants to merge 7 commits into from
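
In short, the change makes the forward/backward hooks accumulate one activations/backprops pair per pass, so each parameter's `grad1` becomes a list with one entry per forward pass. A minimal sketch of the intended usage, assuming a toy `nn.Linear` model and two passes over random batches (the model, data, and shapes here are illustrative, not taken from this PR):

```python
import torch
import torch.nn as nn

import autograd_hacks

model = nn.Linear(4, 3)              # toy model; Linear and Conv2d layers are supported
loss_fn = nn.CrossEntropyLoss()

autograd_hacks.add_hooks(model)

# Two separate forward/backward passes: the hooks append one entry per pass
# to layer.activations_list and layer.backprops_list.
for _ in range(2):
    data = torch.randn(8, 4)
    targets = torch.randint(0, 3, (8,))
    loss_fn(model(data), targets).backward()

autograd_hacks.compute_grad1(model)  # loss_type='mean' by default

# param.grad1 is a list with one entry per pass;
# param.grad1[k][i] is the gradient for example i of pass k.
for param in model.parameters():
    assert len(param.grad1) == 2
    assert param.grad1[0].shape[0] == 8
```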
9 changes: 4 additions & 5 deletions README.md
@@ -4,24 +4,23 @@ Extract useful quantities from PyTorch autograd

## Per-example gradients

```
```python
autograd_hacks.add_hooks(model)
output = model(data)
loss_fn(output, targets).backward()
autograd_hacks.compute_grad1(model)

# param.grad: gradient averaged over the batch
# param.grad1[i]: gradient with respect to example i
# param.grad1[0][i]: gradient with respect to example i

for param in model.parameters():
assert(torch.allclose(param.grad1.mean(dim=0), param.grad))
assert(torch.allclose(param.grad1[0].mean(dim=0), param.grad))
```
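
With this change `param.grad1` is a list, so the single-pass check above uses `param.grad1[0]`. For the multi-pass case, a hedged sketch of the analogous check, reusing the `model`, `data`, `targets`, and `loss_fn` names above as assumptions: start from a fresh hook setup, push the same batch through two forward/backward passes with a mean-reduced loss and no `zero_grad` in between (so `param.grad` accumulates both passes), then call `compute_grad1` once:

```python
autograd_hacks.add_hooks(model)
for _ in range(2):
    loss_fn(model(data), targets).backward()
autograd_hacks.compute_grad1(model)

for param in model.parameters():
    assert len(param.grad1) == 2
    # each pass's per-example gradients average back to that pass's batch gradient
    assert torch.allclose(param.grad1[0].mean(dim=0), param.grad / 2, atol=1e-6)
    assert torch.allclose(param.grad1[1].mean(dim=0), param.grad / 2, atol=1e-6)
```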


## Hessians
(assuming ReLU activations, otherwise produces the Gauss-Newton matrix)

```
```python
autograd_hacks.backprop_hess(model(data), hess_type='CrossEntropy')
autograd_hacks.compute_hess(model)
print(param.hess) # print Hessian of param
7 changes: 7 additions & 0 deletions autograd_hacks/__init__.py
@@ -0,0 +1,7 @@

from .autograd_hacks import (
add_hooks, remove_hooks,
disable_hooks, enable_hooks,
compute_grad1, compute_hess,
backprop_hess
)
135 changes: 86 additions & 49 deletions autograd_hacks.py → autograd_hacks/autograd_hacks.py
@@ -27,9 +27,13 @@
import torch.nn as nn
import torch.nn.functional as F

_supported_layers = ['Linear', 'Conv2d'] # Supported layer class types
_hooks_disabled: bool = False # work-around for https://github.com/pytorch/pytorch/issues/25723
_enforce_fresh_backprop: bool = False # global switch to catch double backprop errors on Hessian computation
# Supported layer class types
_supported_layers = ['Linear', 'Conv2d']
# work-around for https://github.com/pytorch/pytorch/issues/25723
_hooks_disabled: bool = False
# global switch to catch double backprop errors on Hessian computation
_enforce_fresh_backprop: bool = False
_enforce_fresh_activation: bool = False


def add_hooks(model: nn.Module) -> None:
@@ -101,17 +105,26 @@ def _layer_type(layer: nn.Module) -> str:

def _capture_activations(layer: nn.Module, input: List[torch.Tensor], output: torch.Tensor):
"""Save activations into layer.activations in forward pass"""

global _enforce_fresh_activation
if _hooks_disabled:
return
assert _layer_type(layer) in _supported_layers, "Hook installed on unsupported layer, this shouldn't happen"
setattr(layer, "activations", input[0].detach())

if _enforce_fresh_activation:
assert not hasattr(layer, 'activations_list'), """
previous forward pass detected"""
_enforce_fresh_activation = False

if not hasattr(layer, 'activations_list'):
layer.activations_list = []

assert _layer_type(layer) in _supported_layers, """
Hook installed on unsupported layer, this shouldn't happen"""
layer.activations_list.append(input[0].detach())


def _capture_backprops(layer: nn.Module, _input, output):
"""Append backprop to layer.backprops_list in backward pass."""
global _enforce_fresh_backprop

if _hooks_disabled:
return

@@ -120,7 +133,8 @@ def _capture_backprops(layer: nn.Module, _input, output):
_enforce_fresh_backprop = False

if not hasattr(layer, 'backprops_list'):
setattr(layer, 'backprops_list', [])
layer.backprops_list = []

layer.backprops_list.append(output[0].detach())


@@ -131,44 +145,73 @@ def clear_backprops(model: nn.Module) -> None:
del layer.backprops_list


def compute_grad1_for_linear(layer, A, B):
weight_grad = torch.einsum('ni,nj->nij', B, A)
append(layer.weight, 'grad1', weight_grad)
if layer.bias is not None:
append(layer.bias, 'grad1', B)


def compute_grad1_for_conv2d(layer, A, B):
n = A.shape[0]
A = torch.nn.functional.unfold(A, layer.kernel_size, layer.dilation,
layer.padding, layer.stride)
B = B.reshape(n, -1, A.shape[-1])
grad1 = torch.einsum('ijk,ilk->ijl', B, A)
shape = [n] + list(layer.weight.shape)
append(layer.weight, 'grad1', grad1.reshape(shape))
if layer.bias is not None:
append(layer.bias, 'grad1', torch.sum(B, dim=2))


def clear_grad1_for_linear_and_conv2d(layer):
if hasattr(layer.weight, 'grad1'):
del layer.weight.grad1
if layer.bias is not None and hasattr(layer.bias, 'grad1'):
del layer.bias.grad1


def append(instance, attrib, item):
if not hasattr(instance, attrib):
setattr(instance, attrib, [])

getattr(instance, attrib).append(item)


compute_grad1_for = {
'Linear': compute_grad1_for_linear,
'Conv2d': compute_grad1_for_conv2d
}

clear_grad_for = {
'Linear': clear_grad1_for_linear_and_conv2d,
'Conv2d': clear_grad1_for_linear_and_conv2d
}


def compute_grad1(model: nn.Module, loss_type: str = 'mean') -> None:
"""
Compute per-example gradients and save them under 'param.grad1'. Must be called after loss.backprop()
"""Compute per-example gradients and save them under 'param.grad1'. Must be
called after loss.backward()

Args:
model:
loss_type: either "mean" or "sum" depending whether backpropped loss was averaged or summed over batch
loss_type: either "mean" or "sum" depending on the backpropped loss
"""

assert loss_type in ('sum', 'mean')
assert loss_type in ('mean', 'sum')
for layer in model.modules():
layer_type = _layer_type(layer)
if layer_type not in _supported_layers:
continue
assert hasattr(layer, 'activations'), "No activations detected, run forward after add_hooks(model)"
assert hasattr(layer, 'activations_list'), "No activations detected, run forward after add_hooks(model)"
assert hasattr(layer, 'backprops_list'), "No backprops detected, run backward after add_hooks(model)"
assert len(layer.backprops_list) == 1, "Multiple backprops detected, make sure to call clear_backprops(model)"
assert len(layer.activations_list) == len(layer.backprops_list)

A = layer.activations
n = A.shape[0]
if loss_type == 'mean':
B = layer.backprops_list[0] * n
else: # loss_type == 'sum':
B = layer.backprops_list[0]
clear_grad_for[layer_type](layer)
for A, B in zip(layer.activations_list, layer.backprops_list):
if loss_type == 'mean':
B *= A.shape[0]

if layer_type == 'Linear':
setattr(layer.weight, 'grad1', torch.einsum('ni,nj->nij', B, A))
if layer.bias is not None:
setattr(layer.bias, 'grad1', B)

elif layer_type == 'Conv2d':
A = torch.nn.functional.unfold(A, layer.kernel_size)
B = B.reshape(n, -1, A.shape[-1])
grad1 = torch.einsum('ijk,ilk->ijl', B, A)
shape = [n] + list(layer.weight.shape)
setattr(layer.weight, 'grad1', grad1.reshape(shape))
if layer.bias is not None:
setattr(layer.bias, 'grad1', torch.sum(B, dim=2))
compute_grad1_for[layer_type](layer, A, B)


def compute_hess(model: nn.Module,) -> None:
@@ -178,30 +221,28 @@ def compute_hess(model: nn.Module,) -> None:
layer_type = _layer_type(layer)
if layer_type not in _supported_layers:
continue
assert hasattr(layer, 'activations'), "No activations detected, run forward after add_hooks(model)"
assert hasattr(layer, 'activations_list'), "No forward passes detected"
assert len(layer.activations_list) == 1
assert hasattr(layer, 'backprops_list'), "No backprops detected, run backward after add_hooks(model)"

if layer_type == 'Linear':
A = layer.activations
A = layer.activations_list[0]
B = torch.stack(layer.backprops_list)

n = A.shape[0]
o = B.shape[0]

A = torch.stack([A] * o)
Jb = torch.einsum("oni,onj->onij", B, A).reshape(n*o, -1)
H = torch.einsum('ni,nj->ij', Jb, Jb) / n

setattr(layer.weight, 'hess', H)

layer.weight.hess = torch.einsum('ni,nj->ij', Jb, Jb) / n
if layer.bias is not None:
setattr(layer.bias, 'hess', torch.einsum('oni,onj->ij', B, B)/n)

elif layer_type == 'Conv2d':
Kh, Kw = layer.kernel_size
di, do = layer.in_channels, layer.out_channels

A = layer.activations.detach()
A = layer.activations_list[0].detach()
A = torch.nn.functional.unfold(A, (Kh, Kw)) # n, di * Kh * Kw, Oh * Ow
n = A.shape[0]
B = torch.stack([Bt.reshape(n, do, -1) for Bt in layer.backprops_list]) # o, n, do, Oh*Ow
@@ -214,9 +255,9 @@
Jb_bias = torch.einsum('onij->oni', B)
Hi_bias = torch.einsum('oni,onj->nij', Jb_bias, Jb_bias)

setattr(layer.weight, 'hess', Hi.mean(dim=0))
layer.weight.hess = Hi.mean(dim=0)
if layer.bias is not None:
setattr(layer.bias, 'hess', Hi_bias.mean(dim=0))
layer.bias.hess = Hi_bias.mean(dim=0)


def backprop_hess(output: torch.Tensor, hess_type: str) -> None:
@@ -245,16 +286,15 @@ def backprop_hess(output: torch.Tensor, hess_type: str) -> None:
outer_prod_part = torch.einsum('ij,ik->ijk', batch, batch)
hess = diag_part - outer_prod_part
assert hess.shape == (n, o, o)

for i in range(n):
hess[i, :, :] = symsqrt(hess[i, :, :])

hess = hess.transpose(0, 1)

elif hess_type == 'LeastSquares':
hess = []
assert len(output.shape) == 2
batch_size, output_size = output.shape

id_mat = torch.eye(output_size)
for out_idx in range(output_size):
hess.append(torch.stack([id_mat[out_idx]] * batch_size))
@@ -279,7 +319,4 @@ def symsqrt(a, cond=None, return_rank=False, dtype=torch.float32):
u = u[:, above_cutoff]

B = u @ torch.diag(psigma_diag) @ u.t()
if return_rank:
return B, len(psigma_diag)
else:
return B
return (B, len(psigma_diag)) if return_rank else B
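
For the Hessian path, the README covers the `'CrossEntropy'` case; below is a minimal sketch of the `'LeastSquares'` branch on a toy linear model, using a single forward pass as the updated `compute_hess` requires (the model and shapes are illustrative assumptions, not taken from this PR):

```python
import torch
import torch.nn as nn

import autograd_hacks

model = nn.Linear(5, 3)
data = torch.randn(16, 5)

autograd_hacks.add_hooks(model)
output = model(data)                 # exactly one forward pass
autograd_hacks.backprop_hess(output, hess_type='LeastSquares')
autograd_hacks.compute_hess(model)
autograd_hacks.disable_hooks()

print(model.weight.hess.shape)       # torch.Size([15, 15]): flattened (out*in) x (out*in)
print(model.bias.hess.shape)         # torch.Size([3, 3]): out x out
```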