
Support gradient computation in multiple forward passes #4

Closed · wants to merge 7 commits

Conversation


@jusjusjus commented Feb 28, 2020

I rewrote parts of your module to allow computing gradients from multiple forward passes. The use case is best summarized by the test I added:

import torch
import torch.nn as nn

import autograd_hacks
# `Net` is the small MNIST convnet defined alongside the other tests.


def test_grad1_for_multiple_passes():
    torch.manual_seed(42)
    model = Net()
    loss_fn = nn.CrossEntropyLoss()

    def get_data(batch_size):
        return (torch.rand(batch_size, 1, 28, 28),
                torch.LongTensor(batch_size).random_(0, 10))

    n1 = 4
    n2 = 10

    autograd_hacks.add_hooks(model)

    # First pass: batch of n1 examples; keep a copy of the aggregate
    # gradients before zeroing them.
    data, targets = get_data(n1)
    output = model(data)
    loss_fn(output, targets).backward(retain_graph=True)
    grads = [{n: p.grad.clone() for n, p in model.named_parameters()}]
    model.zero_grad()

    # Second pass: batch of a different size, n2.
    data, targets = get_data(n2)
    output = model(data)
    loss_fn(output, targets).backward(retain_graph=True)
    grads.append({n: p.grad for n, p in model.named_parameters()})

    autograd_hacks.compute_grad1(model)
    autograd_hacks.disable_hooks()

    # p.grad1[i] holds the per-example gradients of pass i with shape
    # (batch_size, *p.shape); their mean recovers the aggregate gradient.
    for n, p in model.named_parameters():
        for i, grad in enumerate(grads):
            assert grad[n].shape == p.grad1[i].shape[1:]
            assert torch.allclose(grad[n], p.grad1[i].mean(dim=0))

The pull request also addresses Issue #3.
This, of course, requires another index on the grad1 property, which now becomes a list with one entry per forward/backward pass (see the sketch below). See if you like this major change to your API.
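
To make the shape change concrete, here is a minimal consumption sketch, continuing from the test above; `pass_idx` and `g1` are illustrative names, not part of the patch:

    # Old API: p.grad1 was a single tensor of shape (batch_size, *p.shape).
    # New API: p.grad1 is a list with one such tensor per forward/backward pass.
    for name, p in model.named_parameters():
        for pass_idx, g1 in enumerate(p.grad1):
            assert g1.shape[1:] == p.shape   # per-example grads match the parameter
            batch_size = g1.shape[0]         # may differ across passes (here n1 vs. n2)
            mean_grad = g1.mean(dim=0)       # recovers that pass's aggregate gradient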

In the future, I'd also like to release this module on PyPI for maintenance reasons.

We can now compute gradients for multiple passes through the network because activations and backprops are stored. Multiple passes are saved into a list `model.param.weight.grad1[...]`.
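
For context, a minimal sketch of the mechanism the commit message describes, under stated assumptions: the attribute names `activations_list` and `backprops_list` are hypothetical stand-ins for whatever the patch actually uses, and the per-example reduction is shown for nn.Linear with a mean-reduced loss only.

import torch
import torch.nn as nn

def capture_activations(layer, inputs, _output):
    # Forward hook: append this pass's input activations to a per-layer list.
    if not hasattr(layer, 'activations_list'):
        layer.activations_list = []
    layer.activations_list.append(inputs[0].detach())

def capture_backprops(layer, _grad_input, grad_output):
    # Backward hook: append this pass's output gradients.
    if not hasattr(layer, 'backprops_list'):
        layer.backprops_list = []
    layer.backprops_list.append(grad_output[0].detach())

def add_hooks_sketch(model: nn.Module):
    for layer in model.modules():
        if isinstance(layer, nn.Linear):
            layer.register_forward_hook(capture_activations)
            layer.register_backward_hook(capture_backprops)

def compute_grad1_sketch(model: nn.Module):
    # One per-example weight gradient per recorded pass.  For a mean-reduced
    # loss, autograd's grad_output carries a 1/batch_size factor, so scale it
    # back up to obtain true per-example gradients.
    for layer in model.modules():
        if isinstance(layer, nn.Linear):
            layer.weight.grad1 = [
                torch.einsum('ni,nj->nij', B * B.shape[0], A)
                for A, B in zip(layer.activations_list, layer.backprops_list)
            ]

Bias and Conv2d parameters follow the same append-per-pass pattern, with their own reductions in the compute step.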
@jusjusjus closed this Mar 3, 2020