Flexibly track outputs and grad-outputs of `torch.nn.Module`.
Installation:

```bash
pip install git+https://github.com/graphcore-research/pytorch-tensor-tracker
```
Usage:

Use `tensor_tracker.track(module)` as a context manager to start capturing tensors from within your module's forward and backward passes:

```python
import tensor_tracker

with tensor_tracker.track(module) as tracker:
    module(inputs).backward()

print(tracker)  # => Tracker(stashes=8, tracking=0)
```
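For a self-contained sketch of the same flow, the snippet below stands in a small `nn.Sequential` model for `module`; the model, inputs and `.sum()` loss are illustrative assumptions, not part of the library:

```python
import torch
import torch.nn as nn

import tensor_tracker

# Illustrative stand-ins for `module` and `inputs` (not from the library).
module = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
inputs = torch.randn(2, 8)

with tensor_tracker.track(module) as tracker:
    # Reduce to a scalar so .backward() needs no explicit gradient argument.
    module(inputs).sum().backward()

print(tracker)  # e.g. Tracker(stashes=..., tracking=0)
```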
Now `Tracker` is filled with stashes, containing copies of fwd/bwd tensors at (sub)module outputs. (Note: this can consume a lot of memory.) It behaves like a list of `Stash` objects, with their attached `value`, usually a tensor or tuple of tensors. We can also use `to_frame()` to get a Pandas table of summary statistics:
```python
print(list(tracker))
# => [Stash(name="0.linear", type=nn.Linear, grad=False, value=tensor(...)),
#     ...]

display(tracker.to_frame())
```
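As a sketch of consuming the stashes directly, the snippet below uses only the `name`, `grad` and `value` attributes shown above; the `rms` helper and the printed names are illustrative, not part of the library:

```python
import torch

def rms(t: torch.Tensor) -> float:
    # Root-mean-square of a captured tensor, as a plain float.
    return t.detach().float().pow(2).mean().sqrt().item()

# Keep forward-pass stashes (grad=False) whose value is a single tensor,
# mapping each (sub)module name to the RMS of its captured output.
forward_rms = {
    stash.name: rms(stash.value)
    for stash in tracker
    if not stash.grad and isinstance(stash.value, torch.Tensor)
}
print(forward_rms)  # e.g. {"0.linear": 1.2, ...}
```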
See the documentation for more information, or see our demo of visualising transformer activations & gradients using UMAP for a more practical example. To use on IPU with PopTorch, please see Usage (PopTorch).
Copyright (c) 2023 Graphcore Ltd. Licensed under the MIT License (LICENSE).
Our dependencies are (see requirements.txt):
Component | About | License |
---|---|---|
torch | Machine learning framework | BSD 3-Clause |
We also use additional Python dependencies for development/testing (see requirements-dev.txt).