diff --git a/README.md b/README.md
index 6d5a9f4..720a97b 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
# Tensor tracker
-[API documentation](https://graphcore-research.github.io/pytorch-tensor-tracker/) | [Example](https://graphcore-research.github.io/pytorch-tensor-tracker/usage.html)
+[API documentation](https://graphcore-research.github.io/pytorch-tensor-tracker/) | [Example](doc/Usage.ipynb)
Flexibly track outputs and grad-outputs of `torch.nn.Module`.
@@ -37,7 +37,7 @@ display(tracker.to_frame())
-See the [documentation](https://graphcore-research.github.io/pytorch-tensor-tracker/) for more info, or for a more practical example, see our demo of [visualising transformer activations & gradients using UMAP](doc/Example.ipynb).
+See the [documentation](https://graphcore-research.github.io/pytorch-tensor-tracker/) for more info, or for a more practical example, see our demo of [visualising transformer activations & gradients using UMAP](doc/Example.ipynb). To use on IPU with PopTorch, please see [Usage (PopTorch)](doc/UsagePopTorch.ipynb).
## License
diff --git a/doc/UsagePopTorch.ipynb b/doc/UsagePopTorch.ipynb
new file mode 100644
index 0000000..1d23e19
--- /dev/null
+++ b/doc/UsagePopTorch.ipynb
@@ -0,0 +1,126 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Copyright (c) 2023 Graphcore Ltd. All rights reserved.\n",
+ "\n",
+ "# Usage example (PopTorch)\n",
+ "\n",
+ "Create a toy model to track:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from torch import nn, Tensor\n",
+ "\n",
+ "class Model(nn.Module):\n",
+ " def __init__(self):\n",
+ " super().__init__()\n",
+ " self.embed = nn.Embedding(10, 4)\n",
+ " self.project = nn.Linear(4, 4)\n",
+ " self.unembed = nn.Linear(4, 10)\n",
+ "\n",
+ " def forward(self, tokens: Tensor) -> Tensor:\n",
+ " logits = self.unembed(self.project(self.embed(tokens)))\n",
+ " return nn.functional.cross_entropy(logits, tokens)\n",
+ "\n",
+ "torch.manual_seed(100)\n",
+ "module = Model()\n",
+ "inputs = torch.randint(0, 10, (3,))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**PopTorch:**\n",
+ "\n",
+ "A few modifications to work with PopTorch:\n",
+ " - Any tracking should be contained within `forward()`.\n",
+ " - We shouldn't call `tensor.cpu()`, as this is implicit on returned tensors.\n",
+ " - We don't have access to the backward pass."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[13:48:29.771] [poptorch:cpp] [warning] [DISPATCHER] Type coerced from Long to Int for tensor id 138\n",
+ "Graph compilation: 100%|██████████| 100/100 [00:04<00:00]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[Stash(name='embed', type=, grad=False, value=tensor([[ 0.4520, -0.1066, 1.1028, -1.1578],\n",
+ " [-0.4866, -0.1484, -1.6819, 0.7740],\n",
+ " [-1.0324, 0.2063, -0.7983, 0.4695]])),\n",
+ " Stash(name='project', type=, grad=False, value=tensor([[ 1.2474, 0.4518, 0.2115, -0.6991],\n",
+ " [-0.3698, -0.1035, -0.2358, -0.3482],\n",
+ " [ 0.2165, 0.2673, -0.1278, -0.1348]])),\n",
+ " Stash(name='unembed', type=, grad=False, value=tensor([[-0.2676, 0.0945, 0.4727, 0.0716, -0.1146, 0.2311, 0.4380, -0.1172,\n",
+ " 0.6078, -0.0632],\n",
+ " [ 0.2343, -0.0936, 0.1143, -0.0777, 0.0148, -0.0783, 0.2015, 0.1975,\n",
+ " 0.2441, -0.3956],\n",
+ " [ 0.1521, -0.0814, 0.2678, 0.0481, 0.1128, -0.0149, 0.3953, 0.2135,\n",
+ " 0.3824, -0.2818]]))]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from typing import Dict\n",
+ "import poptorch\n",
+ "import tensor_tracker\n",
+ "\n",
+ "class TrackingModel(Model):\n",
+ " def forward(self, inputs: Tensor) -> Dict[str, Tensor]:\n",
+ " with tensor_tracker.track(self, stash_value=lambda t: t) as tracker:\n",
+ " loss = super().forward(inputs)\n",
+ " return loss, [t.__dict__ for t in tracker]\n",
+ "\n",
+ "loss, tracked = poptorch.inferenceModel(TrackingModel())(inputs)\n",
+ "tracked = [tensor_tracker.Stash(**d) for d in tracked]\n",
+ "display(tracked)\n",
+ "# => [Stash(name=\"embed\", type=nn.Embedding, grad=False, value=tensor(...)),\n",
+ "# ...]"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}