diff --git a/README.md b/README.md
index dd274f7..8b6786a 100644
--- a/README.md
+++ b/README.md
@@ -4,24 +4,23 @@ Extract useful quantities from PyTorch autograd
 
 ## Per-example gradients
 
-```
+```python
 autograd_hacks.add_hooks(model)
 output = model(data)
 loss_fn(output, targets).backward()
-autograd_hacks.compute_grad1()
+autograd_hacks.compute_grad1(model)
 
 # param.grad: gradient averaged over the batch
-# param.grad1[i]: gradient with respect to example i
+# param.grad1[0][i]: gradient with respect to example i (first forward pass)
 for param in model.parameters():
-    assert(torch.allclose(param.grad1.mean(dim=0), param.grad))
+    assert(torch.allclose(param.grad1[0].mean(dim=0), param.grad))
 ```
-
 
-## Hessians (assuming ReLU activations, oherwise produces Gauss-Newton matrix)
-```
+## Hessians (assuming ReLU activations, otherwise produces Gauss-Newton matrix)
+```python
 autograd_hacks.backprop_hess(model(data), hess_type='CrossEntropy')
 autograd_hacks.compute_hess(model)
 print(param.hess)  # print Hessian of param
diff --git a/autograd_hacks/__init__.py b/autograd_hacks/__init__.py
new file mode 100644
index 0000000..4a5180e
--- /dev/null
+++ b/autograd_hacks/__init__.py
@@ -0,0 +1,7 @@
+
+from .autograd_hacks import (
+    add_hooks, remove_hooks,
+    disable_hooks, enable_hooks,
+    compute_grad1, compute_hess,
+    backprop_hess
+)
diff --git a/autograd_hacks.py b/autograd_hacks/autograd_hacks.py
similarity index 69%
rename from autograd_hacks.py
rename to autograd_hacks/autograd_hacks.py
index 83fb4cc..533f800 100644
--- a/autograd_hacks.py
+++ b/autograd_hacks/autograd_hacks.py
@@ -27,9 +27,13 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-_supported_layers = ['Linear', 'Conv2d']  # Supported layer class types
-_hooks_disabled: bool = False  # work-around for https://github.com/pytorch/pytorch/issues/25723
-_enforce_fresh_backprop: bool = False  # global switch to catch double backprop errors on Hessian computation
+# Supported layer class types
+_supported_layers = ['Linear', 'Conv2d']
+# work-around for https://github.com/pytorch/pytorch/issues/25723
+_hooks_disabled: bool = False
+# global switch to catch double backprop errors on Hessian computation
+_enforce_fresh_backprop: bool = False
+_enforce_fresh_activation: bool = False  # catch activations left over from a previous forward pass
 
 
 def add_hooks(model: nn.Module) -> None:
@@ -101,17 +105,26 @@ def _layer_type(layer: nn.Module) -> str:
 
 
 def _capture_activations(layer: nn.Module, input: List[torch.Tensor], output: torch.Tensor):
-    """Save activations into layer.activations in forward pass"""
-
+    """Save activations into layer.activations_list in forward pass"""
+    global _enforce_fresh_activation
     if _hooks_disabled:
         return
-    assert _layer_type(layer) in _supported_layers, "Hook installed on unsupported layer, this shouldn't happen"
-    setattr(layer, "activations", input[0].detach())
+
+    if _enforce_fresh_activation:
+        assert not hasattr(layer, 'activations_list'), """
+            previous forward pass detected"""
+        _enforce_fresh_activation = False
+
+    if not hasattr(layer, 'activations_list'):
+        layer.activations_list = []
+
+    assert _layer_type(layer) in _supported_layers, """
+        Hook installed on unsupported layer, this shouldn't happen"""
+    layer.activations_list.append(input[0].detach())
 
 
 def _capture_backprops(layer: nn.Module, _input, output):
     """Append backprop to layer.backprops_list in backward pass."""
     global _enforce_fresh_backprop
-
     if _hooks_disabled:
         return
@@ -120,7 +133,8 @@ def _capture_backprops(layer: nn.Module, _input, output):
         _enforce_fresh_backprop = False
 
     if not hasattr(layer, 'backprops_list'):
-        setattr(layer, 'backprops_list', [])
+        layer.backprops_list = []
+
     layer.backprops_list.append(output[0].detach())
 
 
@@ -131,44 +145,73 @@ def clear_backprops(model: nn.Module) -> None:
             del layer.backprops_list
 
 
+def compute_grad1_for_linear(layer, A, B):
+    weight_grad = torch.einsum('ni,nj->nij', B, A)
+    append(layer.weight, 'grad1', weight_grad)
+    if layer.bias is not None:
+        append(layer.bias, 'grad1', B)
+
+
+def compute_grad1_for_conv2d(layer, A, B):
+    n = A.shape[0]
+    A = torch.nn.functional.unfold(A, layer.kernel_size, layer.dilation,
+                                   layer.padding, layer.stride)
+    B = B.reshape(n, -1, A.shape[-1])
+    grad1 = torch.einsum('ijk,ilk->ijl', B, A)
+    shape = [n] + list(layer.weight.shape)
+    append(layer.weight, 'grad1', grad1.reshape(shape))
+    if layer.bias is not None:
+        append(layer.bias, 'grad1', torch.sum(B, dim=2))
+
+
+def clear_grad1_for_linear_and_conv2d(layer):
+    if hasattr(layer.weight, 'grad1'):
+        del layer.weight.grad1
+    if layer.bias is not None and hasattr(layer.bias, 'grad1'):
+        del layer.bias.grad1
+
+
+def append(instance, attrib, item):
+    if not hasattr(instance, attrib):
+        setattr(instance, attrib, [])
+
+    getattr(instance, attrib).append(item)
+
+
+compute_grad1_for = {
+    'Linear': compute_grad1_for_linear,
+    'Conv2d': compute_grad1_for_conv2d
+}
+
+clear_grad_for = {
+    'Linear': clear_grad1_for_linear_and_conv2d,
+    'Conv2d': clear_grad1_for_linear_and_conv2d
+}
+
+
 def compute_grad1(model: nn.Module, loss_type: str = 'mean') -> None:
-    """
-    Compute per-example gradients and save them under 'param.grad1'. Must be called after loss.backprop()
+    """Compute per-example gradients and save them under 'param.grad1'. Must be
+    called after loss.backward()
 
     Args:
         model:
-        loss_type: either "mean" or "sum" depending whether backpropped loss was averaged or summed over batch
+        loss_type: "mean" or "sum", matching the reduction of the backpropped loss
     """
-
-    assert loss_type in ('sum', 'mean')
+    assert loss_type in ('mean', 'sum')
     for layer in model.modules():
         layer_type = _layer_type(layer)
         if layer_type not in _supported_layers:
             continue
-        assert hasattr(layer, 'activations'), "No activations detected, run forward after add_hooks(model)"
+        assert hasattr(layer, 'activations_list'), "No activations detected, run forward after add_hooks(model)"
         assert hasattr(layer, 'backprops_list'), "No backprops detected, run backward after add_hooks(model)"
-        assert len(layer.backprops_list) == 1, "Multiple backprops detected, make sure to call clear_backprops(model)"
+        assert len(layer.activations_list) == len(layer.backprops_list)
 
-        A = layer.activations
-        n = A.shape[0]
-        if loss_type == 'mean':
-            B = layer.backprops_list[0] * n
-        else:  # loss_type == 'sum':
-            B = layer.backprops_list[0]
+        clear_grad_for[layer_type](layer)
+        for A, B in zip(layer.activations_list, layer.backprops_list):
+            if loss_type == 'mean':
+                B = B * A.shape[0]
 
-        if layer_type == 'Linear':
-            setattr(layer.weight, 'grad1', torch.einsum('ni,nj->nij', B, A))
-            if layer.bias is not None:
-                setattr(layer.bias, 'grad1', B)
-
-        elif layer_type == 'Conv2d':
-            A = torch.nn.functional.unfold(A, layer.kernel_size)
-            B = B.reshape(n, -1, A.shape[-1])
-            grad1 = torch.einsum('ijk,ilk->ijl', B, A)
-            shape = [n] + list(layer.weight.shape)
-            setattr(layer.weight, 'grad1', grad1.reshape(shape))
-            if layer.bias is not None:
-                setattr(layer.bias, 'grad1', torch.sum(B, dim=2))
+            compute_grad1_for[layer_type](layer, A, B)
 
 
 def compute_hess(model: nn.Module,) -> None:
@@ -178,11 +221,12 @@ def compute_hess(model: nn.Module,) -> None:
         layer_type = _layer_type(layer)
         if layer_type not in _supported_layers:
             continue
-        assert hasattr(layer, 'activations'), "No activations detected, run forward after add_hooks(model)"
+        assert hasattr(layer, 'activations_list'), "No forward passes detected"
+        assert len(layer.activations_list) == 1
         assert hasattr(layer, 'backprops_list'), "No backprops detected, run backward after add_hooks(model)"
 
         if layer_type == 'Linear':
-            A = layer.activations
+            A = layer.activations_list[0]
             B = torch.stack(layer.backprops_list)
 
             n = A.shape[0]
@@ -190,10 +234,7 @@ def compute_hess(model: nn.Module,) -> None:
             A = torch.stack([A] * o)
 
             Jb = torch.einsum("oni,onj->onij", B, A).reshape(n*o, -1)
-            H = torch.einsum('ni,nj->ij', Jb, Jb) / n
-
-            setattr(layer.weight, 'hess', H)
-
+            layer.weight.hess = torch.einsum('ni,nj->ij', Jb, Jb) / n
             if layer.bias is not None:
                 setattr(layer.bias, 'hess', torch.einsum('oni,onj->ij', B, B)/n)
 
@@ -201,7 +242,7 @@ def compute_hess(model: nn.Module,) -> None:
             Kh, Kw = layer.kernel_size
             di, do = layer.in_channels, layer.out_channels
 
-            A = layer.activations.detach()
+            A = layer.activations_list[0].detach()
             A = torch.nn.functional.unfold(A, (Kh, Kw))  # n, di * Kh * Kw, Oh * Ow
             n = A.shape[0]
             B = torch.stack([Bt.reshape(n, do, -1) for Bt in layer.backprops_list])  # o, n, do, Oh*Ow
@@ -214,9 +255,9 @@ def compute_hess(model: nn.Module,) -> None:
             Jb_bias = torch.einsum('onij->oni', B)
             Hi_bias = torch.einsum('oni,onj->nij', Jb_bias, Jb_bias)
 
-            setattr(layer.weight, 'hess', Hi.mean(dim=0))
+            layer.weight.hess = Hi.mean(dim=0)
             if layer.bias is not None:
-                setattr(layer.bias, 'hess', Hi_bias.mean(dim=0))
+                layer.bias.hess = Hi_bias.mean(dim=0)
 
 
 def backprop_hess(output: torch.Tensor, hess_type: str) -> None:
@@ -245,16 +286,15 @@ def backprop_hess(output: torch.Tensor, hess_type: str) -> None:
         outer_prod_part = torch.einsum('ij,ik->ijk', batch, batch)
         hess = diag_part - outer_prod_part
         assert hess.shape == (n, o, o)
-
         for i in range(n):
             hess[i, :, :] = symsqrt(hess[i, :, :])
+        hess = hess.transpose(0, 1)
 
     elif hess_type == 'LeastSquares':
         hess = []
         assert len(output.shape) == 2
         batch_size, output_size = output.shape
-
         id_mat = torch.eye(output_size)
 
         for out_idx in range(output_size):
             hess.append(torch.stack([id_mat[out_idx]] * batch_size))
@@ -279,7 +319,4 @@ def symsqrt(a, cond=None, return_rank=False, dtype=torch.float32):
     u = u[:, above_cutoff]
     B = u @ torch.diag(psigma_diag) @ u.t()
 
-    if return_rank:
-        return B, len(psigma_diag)
-    else:
-        return B
+    return (B, len(psigma_diag)) if return_rank else B
diff --git a/autograd_hacks_test.py b/autograd_hacks/test_autograd_hacks.py
similarity index 60%
rename from autograd_hacks_test.py
rename to autograd_hacks/test_autograd_hacks.py
index 55c2312..fd6f57b 100644
--- a/autograd_hacks_test.py
+++ b/autograd_hacks/test_autograd_hacks.py
@@ -1,12 +1,33 @@
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import pytest
 
-import autograd_hacks
+from . import autograd_hacks
+
+
+class StriddenNet(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5, stride=2, padding=2)
+        self.conv2 = nn.Conv2d(20, 30, 5, stride=2, padding=2)
+        self.fc1_input_size = 7 * 7 * 30
+        self.fc1 = nn.Linear(self.fc1_input_size, 500)
+        self.fc2 = nn.Linear(500, 10)
+
+    def forward(self, x):
+        batch_size = x.shape[0]
+        x = F.relu(self.conv1(x))
+        x = F.relu(self.conv2(x))
+        x = x.view(batch_size, self.fc1_input_size)
+        x = F.relu(self.fc1(x))
+        x = self.fc2(x)
+        return x
 
 
-# Lenet-5 from https://github.com/pytorch/examples/blob/master/mnist/main.py
 class Net(nn.Module):
+    """Lenet-5 from https://github.com/pytorch/examples/blob/master/mnist/main.py"""
     def __init__(self):
         super(Net, self).__init__()
         self.conv1 = nn.Conv2d(1, 20, 5, 1)
@@ -25,8 +46,9 @@ def forward(self, x):
         return x
 
 
-# Tiny LeNet-5 for Hessian testing
 class TinyNet(nn.Module):
+    """Tiny LeNet-5 for Hessian testing"""
+
     def __init__(self):
         super(TinyNet, self).__init__()
         self.conv1 = nn.Conv2d(1, 2, 2, 1)
@@ -63,7 +85,8 @@ def hessian(y: torch.Tensor, x: torch.Tensor):
     return jacobian(jacobian(y, x, create_graph=True), x)
 
 
-def test_grad1():
+@pytest.mark.parametrize("Net", [Net, TinyNet, StriddenNet])
+def test_grad1(Net):
     torch.manual_seed(1)
     model = Net()
     loss_fn = nn.CrossEntropyLoss()
@@ -79,22 +102,55 @@ def test_grad1():
     autograd_hacks.disable_hooks()
 
     # Compare values against autograd
-    losses = torch.stack([loss_fn(output[i:i+1], targets[i:i+1]) for i in range(len(data))])
+    losses = torch.stack([loss_fn(output[i:i+1], targets[i:i+1])
+                          for i in range(len(data))])
 
     for layer in model.modules():
         if not autograd_hacks.is_supported(layer):
             continue
+
         for param in layer.parameters():
-            assert torch.allclose(param.grad, param.grad1.mean(dim=0))
-            assert torch.allclose(jacobian(losses, param), param.grad1)
+            assert torch.allclose(param.grad, param.grad1[0].mean(dim=0))
+            assert torch.allclose(jacobian(losses, param), param.grad1[0])
+
+
+def test_grad1_for_multiple_passes():
+    torch.manual_seed(42)
+    model = Net()
+    loss_fn = nn.CrossEntropyLoss()
+
+    def get_data(batch_size):
+        return (torch.rand(batch_size, 1, 28, 28),
+                torch.LongTensor(batch_size).random_(0, 10))
+    n1 = 4
+    n2 = 10
 
 
-def test_hess():
-    subtest_hess_type('CrossEntropy')
-    subtest_hess_type('LeastSquares')
+    autograd_hacks.add_hooks(model)
+    data, targets = get_data(n1)
+    output = model(data)
+    loss_fn(output, targets).backward(retain_graph=True)
+    grads = [{n: p.grad.clone() for n, p in model.named_parameters()}]
+    model.zero_grad()
 
 
-def subtest_hess_type(hess_type):
+    data, targets = get_data(n2)
+    output = model(data)
+    loss_fn(output, targets).backward(retain_graph=True)
+    grads.append({n: p.grad for n, p in model.named_parameters()})
+
+    autograd_hacks.compute_grad1(model)
+
+    autograd_hacks.disable_hooks()
+
+    for n, p in model.named_parameters():
+        for i, grad in enumerate(grads):
+            assert grad[n].shape == p.grad1[i].shape[1:]
+            assert torch.allclose(grad[n], p.grad1[i].mean(dim=0))
+
+
+@pytest.mark.parametrize("hess_type", ['CrossEntropy', 'LeastSquares'])
+def test_hess(hess_type):
     torch.manual_seed(1)
     model = TinyNet()
 
@@ -112,13 +168,15 @@ def least_squares_loss(data_, targets_):
     if hess_type == 'LeastSquares':
         targets = torch.rand(output.shape)
         loss_fn = least_squares_loss
-    else:  # hess_type == 'CrossEntropy':
+    elif hess_type == 'CrossEntropy':
         targets = torch.LongTensor(n).random_(0, 10)
         loss_fn = nn.CrossEntropyLoss()
+    else:
+        raise ValueError(f"Unknown hess_type: {hess_type}")
 
-    autograd_hacks.backprop_hess(output, hess_type=hess_type)
+    autograd_hacks.backprop_hess(output, hess_type)
     autograd_hacks.clear_backprops(model)
-    autograd_hacks.backprop_hess(output, hess_type=hess_type)
+    autograd_hacks.backprop_hess(output, hess_type)
     autograd_hacks.compute_hess(model)
     autograd_hacks.disable_hooks()
 
@@ -126,13 +184,9 @@ def least_squares_loss(data_, targets_):
     for layer in model.modules():
         if not autograd_hacks.is_supported(layer):
             continue
+
         for param in layer.parameters():
            loss = loss_fn(output, targets)
            hess_autograd = hessian(loss, param)
            hess = param.hess
            assert torch.allclose(hess, hess_autograd.reshape(hess.shape))
-
-
-if __name__ == '__main__':
-    test_grad1()
-    test_hess()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..9b11cf6
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,6 @@
+# Add other version constraints as a comma-separated list, e.g.
+# > torch>=1.0.0,<1.3.2
+# However, be sure all intermediate versions are supported.
+mypy==0.761
+torch==1.3.1
+pytest==5.3.5
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..e359686
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,23 @@
+
+from setuptools import setup
+
+def requirements():
+    reqs = []
+    with open('requirements.txt', 'r') as fp:
+        for req in fp:
+            # remove endline, whitespace, and anything after '#'
+            req = req.rstrip('\n').strip().split('#')[0]
+            if req == '':
+                continue
+
+            reqs.append(req)
+
+    return reqs
+
+setup(
+    name='autograd_hacks',
+    version='0.0.2',
+    packages=['autograd_hacks'],
+    long_description=open('README.md').read(),
+    install_requires=requirements()
+)
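
Usage sketch (not part of the patch): how the reworked per-example gradient API is meant to be driven across more than one forward/backward pass, following the new `test_grad1_for_multiple_passes` test. The model, batch sizes, and loss below are illustrative placeholders, not code from this repository:

```python
import torch
import torch.nn as nn

import autograd_hacks

# Placeholder model and loss; any model built from Linear/Conv2d layers works the same way.
model = nn.Sequential(nn.Linear(4, 3))
loss_fn = nn.CrossEntropyLoss()

autograd_hacks.add_hooks(model)

# Two independent forward/backward passes: the hooks append one entry per pass
# to each layer's activations_list and backprops_list.
for batch_size in (5, 8):
    data = torch.randn(batch_size, 4)
    targets = torch.randint(0, 3, (batch_size,))
    model.zero_grad()
    loss_fn(model(data), targets).backward()

autograd_hacks.compute_grad1(model)  # loss_type='mean' matches CrossEntropyLoss's default reduction
autograd_hacks.disable_hooks()

for param in model.parameters():
    # param.grad1[k][i] is the gradient for example i of the k-th pass.
    assert len(param.grad1) == 2
    assert torch.allclose(param.grad1[1].mean(dim=0), param.grad)
```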
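A second sketch for the Hessian path, mirroring the README snippet above; again the model and data are placeholders, and `param.hess` is only set on parameters of supported (`Linear`/`Conv2d`) layers:

```python
import torch
import torch.nn as nn

import autograd_hacks

# Placeholder model and batch; a single Linear layer keeps the example small.
model = nn.Sequential(nn.Linear(6, 10))
data = torch.randn(3, 6)

autograd_hacks.add_hooks(model)
output = model(data)

# backprop_hess fills layer.backprops_list (one entry per output dimension);
# compute_hess then assembles param.hess from activations_list[0] and those backprops.
autograd_hacks.backprop_hess(output, hess_type='CrossEntropy')
autograd_hacks.compute_hess(model)
autograd_hacks.disable_hooks()

for param in model.parameters():
    print(param.hess.shape)  # e.g. torch.Size([60, 60]) for the 10x6 weight
```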