Commit e2c9264

Improve documentation

DouglasOrr committed Oct 4, 2023
1 parent 8393f26 commit e2c9264
Showing 3 changed files with 99 additions and 34 deletions.
21 changes: 19 additions & 2 deletions README.md
@@ -1,17 +1,33 @@
# Tensor tracker

Flexibly track outputs and grad-outputs of `torch.nn.Module`. [API documentation](https://graphcore-research.github.io/pytorch-tensor-tracker/).
[API documentation](https://graphcore-research.github.io/pytorch-tensor-tracker/) | [Example](https://graphcore-research.github.io/pytorch-tensor-tracker/usage.html)

Flexibly track outputs and grad-outputs of `torch.nn.Module`.

**Installation:**

```bash
pip install git+https://github.com/graphcore-research/pytorch-tensor-tracker
```

**Usage:**

Use `tensor_tracker.track(module)` as a context manager to start capturing tensors from within your module's forward and backward passes:

```python
import tensor_tracker

with tensor_tracker.track(module) as tracker:
module(inputs).backward()

print(tracker) # => Tracker(stashes=8, tracking=0)
```
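Tracking only happens inside the `with` block, which is presumably why the repr above, printed after the context has exited, shows `tracking=0`. A minimal sketch of what this implies, assuming the tracker supports `len()` since it behaves like a list:

```python
# After the context exits, further calls run untracked.
module(inputs)       # not captured
print(len(tracker))  # => 8, the stashes from the tracked call are kept
```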

Now the `Tracker` is filled with stashes containing copies of fwd/bwd tensors at (sub)module outputs. (Note: this can consume a lot of memory.)

The tracker behaves like a list of `Stash` objects, each with an attached `value`, usually a tensor or tuple of tensors. We can also use `to_frame()` to get a Pandas table of summary statistics:

```python
print(list(tracker))
# => [Stash(name="0.linear", type=nn.Linear, grad=False, value=tensor(...)),
# ...]
@@ -21,7 +37,8 @@ display(tracker.to_frame())

<img src="doc/usage_to_frame.png" alt="tensor tracker to_frame output" style="width:30em;"/>
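The summary statistic defaults to `torch.std` (per the library docstring). Assuming `to_frame()` accepts an alternative reduction function as its first argument — an assumption based on that docstring rather than a confirmed signature — a sketch of customising it:

```python
# Hypothetical: summarise each stash by mean absolute value instead of std.
display(tracker.to_frame(lambda t: t.abs().mean()))
```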

See our [example of visualising transformer activations & gradients using UMAP](doc/Example.ipynb).
See the [documentation](https://graphcore-research.github.io/pytorch-tensor-tracker/) for more detail, or our demo of [visualising transformer activations & gradients using UMAP](doc/Example.ipynb) for a more practical example.
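For quick ad-hoc checks without Pandas, the tracker can also be consumed directly as a list. A minimal sketch using only the documented `Stash` fields (`name`, `grad`, `value`), where `value` may be a tensor or a tuple of tensors:

```python
for stash in tracker:
    values = stash.value if isinstance(stash.value, tuple) else (stash.value,)
    shapes = [tuple(v.shape) for v in values]
    print(stash.name or "<root>", "bwd" if stash.grad else "fwd", shapes)
```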


## License

110 changes: 79 additions & 31 deletions doc/Usage.ipynb
@@ -8,7 +8,7 @@
"\n",
"# Usage example\n",
"\n",
"General setup:"
"Create a toy model to track:"
]
},
{
@@ -40,7 +40,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Using `tensor_tracker`:"
"Use `tensor_tracker` to capture forward pass activations and backward pass gradients from our toy model. By default, the tracker saves full tensors, as a list of `tensor_tracker.Stash` objects."
]
},
{
@@ -52,27 +52,84 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[Stash(name='embed', type=<class 'torch.nn.modules.sparse.Embedding'>, grad=False, value=tensor([[ 0.4698, 1.2426, 0.5403, -1.1454],\n",
" [-0.8425, -0.6475, -0.2189, -1.1326],\n",
" [ 0.1268, 1.3564, 0.5632, -0.1039]])), Stash(name='project', type=<class 'torch.nn.modules.linear.Linear'>, grad=False, value=tensor([[-0.6237, -0.1652, 0.3782, -0.8841],\n",
" [-0.9278, -0.2848, -0.8688, -0.4719],\n",
" [-0.3449, 0.3643, 0.3935, -0.6302]])), Stash(name='unembed', type=<class 'torch.nn.modules.linear.Linear'>, grad=False, value=tensor([[-0.2458, 1.0003, -0.8231, -0.1405, -0.2964, 0.5837, 0.2889, 0.2059,\n",
" -0.6114, -0.5916],\n",
" [-0.6345, 1.0882, -0.4304, -0.2196, -0.0426, 0.9428, 0.2051, 0.5897,\n",
" -0.2217, -0.9132],\n",
" [-0.0822, 0.9985, -0.7097, -0.3139, -0.4805, 0.6878, 0.2560, 0.3254,\n",
" -0.4447, -0.3332]])), Stash(name='', type=<class '__main__.Model'>, grad=False, value=tensor(2.5663)), Stash(name='', type=<class '__main__.Model'>, grad=True, value=(tensor(1.),)), Stash(name='unembed', type=<class 'torch.nn.modules.linear.Linear'>, grad=True, value=(tensor([[ 0.0237, 0.0824, -0.3200, 0.0263, 0.0225, 0.0543, 0.0404, 0.0372,\n",
" 0.0164, 0.0168],\n",
" [ 0.0139, 0.0779, 0.0171, 0.0211, 0.0251, 0.0673, 0.0322, -0.2860,\n",
" 0.0210, 0.0105],\n",
" [-0.3066, 0.0787, 0.0143, 0.0212, 0.0179, 0.0577, 0.0374, 0.0401,\n",
" 0.0186, 0.0208]]),)), Stash(name='project', type=<class 'torch.nn.modules.linear.Linear'>, grad=True, value=(tensor([[-0.1755, 0.1306, 0.0443, -0.1823],\n",
" [ 0.1202, -0.0728, 0.0066, -0.0839],\n",
" [-0.1863, 0.0470, -0.1055, -0.0353]]),)), Stash(name='embed', type=<class 'torch.nn.modules.sparse.Embedding'>, grad=True, value=(tensor([[-0.0108, 0.1086, -0.1304, -0.0370],\n",
" [ 0.0534, -0.0029, 0.0078, -0.0074],\n",
" [-0.0829, 0.0152, -0.1170, -0.0625]]),))]\n"
"Tracker(stashes=8, tracking=0)\n"
]
},
}
],
"source": [
"import tensor_tracker\n",
"\n",
"with tensor_tracker.track(module) as tracker:\n",
" module(inputs).backward()\n",
"\n",
"print(tracker)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that calls are only tracked within the `with` context. Then, the tracker behaves like a list of `Stash` objects, with attached `name`, `value` etc."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Stash(name='embed', type=<class 'torch.nn.modules.sparse.Embedding'>, grad=False, value=tensor([[ 0.4698, 1.2426, 0.5403, -1.1454],\n",
" [-0.8425, -0.6475, -0.2189, -1.1326],\n",
" [ 0.1268, 1.3564, 0.5632, -0.1039]])),\n",
" Stash(name='project', type=<class 'torch.nn.modules.linear.Linear'>, grad=False, value=tensor([[-0.6237, -0.1652, 0.3782, -0.8841],\n",
" [-0.9278, -0.2848, -0.8688, -0.4719],\n",
" [-0.3449, 0.3643, 0.3935, -0.6302]])),\n",
" Stash(name='unembed', type=<class 'torch.nn.modules.linear.Linear'>, grad=False, value=tensor([[-0.2458, 1.0003, -0.8231, -0.1405, -0.2964, 0.5837, 0.2889, 0.2059,\n",
" -0.6114, -0.5916],\n",
" [-0.6345, 1.0882, -0.4304, -0.2196, -0.0426, 0.9428, 0.2051, 0.5897,\n",
" -0.2217, -0.9132],\n",
" [-0.0822, 0.9985, -0.7097, -0.3139, -0.4805, 0.6878, 0.2560, 0.3254,\n",
" -0.4447, -0.3332]])),\n",
" Stash(name='', type=<class '__main__.Model'>, grad=False, value=tensor(2.5663)),\n",
" Stash(name='', type=<class '__main__.Model'>, grad=True, value=(tensor(1.),)),\n",
" Stash(name='unembed', type=<class 'torch.nn.modules.linear.Linear'>, grad=True, value=(tensor([[ 0.0237, 0.0824, -0.3200, 0.0263, 0.0225, 0.0543, 0.0404, 0.0372,\n",
" 0.0164, 0.0168],\n",
" [ 0.0139, 0.0779, 0.0171, 0.0211, 0.0251, 0.0673, 0.0322, -0.2860,\n",
" 0.0210, 0.0105],\n",
" [-0.3066, 0.0787, 0.0143, 0.0212, 0.0179, 0.0577, 0.0374, 0.0401,\n",
" 0.0186, 0.0208]]),)),\n",
" Stash(name='project', type=<class 'torch.nn.modules.linear.Linear'>, grad=True, value=(tensor([[-0.1755, 0.1306, 0.0443, -0.1823],\n",
" [ 0.1202, -0.0728, 0.0066, -0.0839],\n",
" [-0.1863, 0.0470, -0.1055, -0.0353]]),)),\n",
" Stash(name='embed', type=<class 'torch.nn.modules.sparse.Embedding'>, grad=True, value=(tensor([[-0.0108, 0.1086, -0.1304, -0.0370],\n",
" [ 0.0534, -0.0029, 0.0078, -0.0074],\n",
" [-0.0829, 0.0152, -0.1170, -0.0625]]),))]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"display(list(tracker))\n",
"# => [Stash(name=\"embed\", type=nn.Embedding, grad=False, value=tensor(...)),\n",
"# ...]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a higher-level API, `to_frame` computes summary statistics, defaulting to `torch.std`."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
@@ -178,15 +235,6 @@
}
],
"source": [
"import tensor_tracker\n",
"\n",
"with tensor_tracker.track(module) as tracker:\n",
" module(inputs).backward()\n",
"\n",
"print(list(tracker))\n",
"# => [Stash(name=\"embed\", type=nn.Embedding, grad=False, value=tensor(...)),\n",
"# ...]\n",
"\n",
"display(tracker.to_frame())"
]
}
2 changes: 1 addition & 1 deletion tensor_tracker/core.py
@@ -2,7 +2,7 @@

"""Utility for tracking activations and gradients at `nn.Module` outputs.
Use `track` to start tracking a module & submodule. Then use the original module
Use `track` to start tracking a module & submodules. Then use the original module
as usual. Your `Tracker` will be filled with a list of `Stash`es, containing
copies of fwd/bwd tensors at (sub)module outputs. (Beware, this can consume
a lot of memory.)
