Skip to content

Commit

Permalink
Tweak documentation and to_frame() column naming (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
DouglasOrr authored Oct 2, 2023
1 parent a0f9283 commit 8393f26
Show file tree
Hide file tree
Showing 9 changed files with 420 additions and 196 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ __pycache__

/build
/dist
/doc
/doc/tensor_tracker
/local

*.egg-info/
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ print(list(tracker))
display(tracker.to_frame())
```

See our example of [visualising transformer activations & gradients using UMAP](tests/Example.ipynb).
<img src="doc/usage_to_frame.png" alt="tensor tracker to_frame output" style="width:30em;"/>

See our [example of visualising transformer activations & gradients using UMAP](doc/Example.ipynb).

## License

Expand Down
33 changes: 17 additions & 16 deletions dev
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def cli(*args: Any, **kwargs: Any) -> Callable[[T], T]:

# Commands

PYTHON_ROOTS = ["tensor_tracker", "tests", "dev", "setup.py"]
PYTHON_ROOTS = ["tensor_tracker", "tests", "doc", "dev", "setup.py"]


@cli("-s", "--no-capture", action="store_false", dest="capture")
Expand All @@ -69,7 +69,7 @@ def tests(capture: bool, filter: Optional[str]) -> None:
def lint() -> None:
"""run static analysis"""
run(["python3", "-m", "flake8", *PYTHON_ROOTS])
run(["python3", "-m", "mypy", *PYTHON_ROOTS])
run(["python3", "-m", "mypy", *(r for r in PYTHON_ROOTS if r != "doc")])


@cli("--check", action="store_true")
Expand All @@ -83,7 +83,7 @@ def format(check: bool) -> None:
def copyright() -> None:
"""check for Graphcore copyright headers on relevant files"""
command = (
"find " + " ".join(PYTHON_ROOTS) + " -type f -not -name *.pyc"
"find " + " ".join(PYTHON_ROOTS) + " -type f -not -name *.pyc -not -name *.png"
" | xargs grep -L 'Copyright (c) 202. Graphcore Ltd[.] All rights reserved[.]'"
)
print(f"$ {command}", file=sys.stderr)
Expand Down Expand Up @@ -130,19 +130,20 @@ def doc() -> None:
"tensor_tracker",
]
)
run(
[
"jupyter",
"nbconvert",
"--to",
"html",
"tests/Example.ipynb",
"--output-dir",
"doc/tensor_tracker",
"--output",
"example.html",
]
)
for notebook in ["Example", "Usage"]:
run(
[
"jupyter",
"nbconvert",
"--to",
"html",
f"doc/{notebook}.ipynb",
"--output-dir",
"doc/tensor_tracker",
"--output",
f"{notebook.lower()}.html",
]
)


@cli("--skip", nargs="*", default=[], help="commands to skip")
Expand Down
169 changes: 169 additions & 0 deletions doc/Example.ipynb

Large diffs are not rendered by default.

216 changes: 216 additions & 0 deletions doc/Usage.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright (c) 2023 Graphcore Ltd. All rights reserved.\n",
"\n",
"# Usage example\n",
"\n",
"General setup:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn, Tensor\n",
"\n",
"class Model(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.embed = nn.Embedding(10, 4)\n",
" self.project = nn.Linear(4, 4)\n",
" self.unembed = nn.Linear(4, 10)\n",
"\n",
" def forward(self, tokens: Tensor) -> Tensor:\n",
" logits = self.unembed(self.project(self.embed(tokens)))\n",
" return nn.functional.cross_entropy(logits, tokens)\n",
"\n",
"torch.manual_seed(100)\n",
"module = Model()\n",
"inputs = torch.randint(0, 10, (3,))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using `tensor_tracker`:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Stash(name='embed', type=<class 'torch.nn.modules.sparse.Embedding'>, grad=False, value=tensor([[ 0.4698, 1.2426, 0.5403, -1.1454],\n",
" [-0.8425, -0.6475, -0.2189, -1.1326],\n",
" [ 0.1268, 1.3564, 0.5632, -0.1039]])), Stash(name='project', type=<class 'torch.nn.modules.linear.Linear'>, grad=False, value=tensor([[-0.6237, -0.1652, 0.3782, -0.8841],\n",
" [-0.9278, -0.2848, -0.8688, -0.4719],\n",
" [-0.3449, 0.3643, 0.3935, -0.6302]])), Stash(name='unembed', type=<class 'torch.nn.modules.linear.Linear'>, grad=False, value=tensor([[-0.2458, 1.0003, -0.8231, -0.1405, -0.2964, 0.5837, 0.2889, 0.2059,\n",
" -0.6114, -0.5916],\n",
" [-0.6345, 1.0882, -0.4304, -0.2196, -0.0426, 0.9428, 0.2051, 0.5897,\n",
" -0.2217, -0.9132],\n",
" [-0.0822, 0.9985, -0.7097, -0.3139, -0.4805, 0.6878, 0.2560, 0.3254,\n",
" -0.4447, -0.3332]])), Stash(name='', type=<class '__main__.Model'>, grad=False, value=tensor(2.5663)), Stash(name='', type=<class '__main__.Model'>, grad=True, value=(tensor(1.),)), Stash(name='unembed', type=<class 'torch.nn.modules.linear.Linear'>, grad=True, value=(tensor([[ 0.0237, 0.0824, -0.3200, 0.0263, 0.0225, 0.0543, 0.0404, 0.0372,\n",
" 0.0164, 0.0168],\n",
" [ 0.0139, 0.0779, 0.0171, 0.0211, 0.0251, 0.0673, 0.0322, -0.2860,\n",
" 0.0210, 0.0105],\n",
" [-0.3066, 0.0787, 0.0143, 0.0212, 0.0179, 0.0577, 0.0374, 0.0401,\n",
" 0.0186, 0.0208]]),)), Stash(name='project', type=<class 'torch.nn.modules.linear.Linear'>, grad=True, value=(tensor([[-0.1755, 0.1306, 0.0443, -0.1823],\n",
" [ 0.1202, -0.0728, 0.0066, -0.0839],\n",
" [-0.1863, 0.0470, -0.1055, -0.0353]]),)), Stash(name='embed', type=<class 'torch.nn.modules.sparse.Embedding'>, grad=True, value=(tensor([[-0.0108, 0.1086, -0.1304, -0.0370],\n",
" [ 0.0534, -0.0029, 0.0078, -0.0074],\n",
" [-0.0829, 0.0152, -0.1170, -0.0625]]),))]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>name</th>\n",
" <th>type</th>\n",
" <th>grad</th>\n",
" <th>std</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>embed</td>\n",
" <td>torch.nn.modules.sparse.Embedding</td>\n",
" <td>False</td>\n",
" <td>0.853265</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>project</td>\n",
" <td>torch.nn.modules.linear.Linear</td>\n",
" <td>False</td>\n",
" <td>0.494231</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>unembed</td>\n",
" <td>torch.nn.modules.linear.Linear</td>\n",
" <td>False</td>\n",
" <td>0.581503</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td></td>\n",
" <td>__main__.Model</td>\n",
" <td>False</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td></td>\n",
" <td>__main__.Model</td>\n",
" <td>True</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>unembed</td>\n",
" <td>torch.nn.modules.linear.Linear</td>\n",
" <td>True</td>\n",
" <td>0.105266</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>project</td>\n",
" <td>torch.nn.modules.linear.Linear</td>\n",
" <td>True</td>\n",
" <td>0.112392</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>embed</td>\n",
" <td>torch.nn.modules.sparse.Embedding</td>\n",
" <td>True</td>\n",
" <td>0.068816</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" name type grad std\n",
"0 embed torch.nn.modules.sparse.Embedding False 0.853265\n",
"1 project torch.nn.modules.linear.Linear False 0.494231\n",
"2 unembed torch.nn.modules.linear.Linear False 0.581503\n",
"3 __main__.Model False NaN\n",
"4 __main__.Model True NaN\n",
"5 unembed torch.nn.modules.linear.Linear True 0.105266\n",
"6 project torch.nn.modules.linear.Linear True 0.112392\n",
"7 embed torch.nn.modules.sparse.Embedding True 0.068816"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import tensor_tracker\n",
"\n",
"with tensor_tracker.track(module) as tracker:\n",
" module(inputs).backward()\n",
"\n",
"print(list(tracker))\n",
"# => [Stash(name=\"embed\", type=nn.Embedding, grad=False, value=tensor(...)),\n",
"# ...]\n",
"\n",
"display(tracker.to_frame())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Binary file added doc/usage_to_frame.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
21 changes: 13 additions & 8 deletions tensor_tracker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
copies of fwd/bwd tensors at (sub)module outputs. (Beware, this can consume
a lot of memory.)
Usage:
Usage ([notebook](usage.html)):
```
with tensor_tracker.track(model) as tracker:
Expand All @@ -34,8 +34,8 @@
- Manually register/unregister hooks:
`tracker = Tracker(); tracker.register(...); tracker.unregister()`
See also: example of
[visualising transformer activations & gradients using UMAP](example.html).
See also: [example of
visualising transformer activations & gradients using UMAP](example.html).
"""

import dataclasses
Expand Down Expand Up @@ -214,16 +214,21 @@ def __len__(self) -> int:
return len(self.stashes)

def to_frame(
self, stat: Callable[[Tensor], Tensor] = torch.std
self,
stat: Callable[[Tensor], Tensor] = torch.std,
stat_name: Optional[str] = None,
) -> "pandas.DataFrame": # type:ignore[name-defined] # NOQA: F821
import pandas

column_name = (
getattr(stat, "__name__", "value") if stat_name is None else stat_name
)

def to_item(stash: Stash) -> Dict[str, Any]:
d = stash.__dict__.copy()
first_value = stash.first_value
d["value"] = (
stat(first_value).item() if isinstance(first_value, Tensor) else None
)
d.pop("value")
v = stash.first_value
d[column_name] = stat(v).item() if isinstance(v, Tensor) else None
d["type"] = f"{stash.type.__module__}.{stash.type.__name__}"
return d

Expand Down
Loading

0 comments on commit 8393f26

Please sign in to comment.