import torch
from torch import nn, Tensor

class Model(nn.Module):
    """Tiny embed -> project -> unembed network used as a tracking example.

    The model maps token ids from a vocabulary of 10 into 4-dimensional
    embeddings, applies a 4 -> 4 linear projection, and maps back to 10
    logits per token.
    """

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.project = nn.Linear(4, 4)
        self.unembed = nn.Linear(4, 10)

    def forward(self, tokens: Tensor) -> Tensor:
        """Return the cross-entropy loss of the model's logits.

        Args:
            tokens: Integer token ids in ``[0, 10)``. The same tensor is
                used both as the model input and as the classification
                target.

        Returns:
            The loss produced by ``nn.functional.cross_entropy`` (a scalar
            under its default mean reduction).
        """
        logits = self.unembed(self.project(self.embed(tokens)))
        return nn.functional.cross_entropy(logits, tokens)
# Seed the global RNG so parameter initialisation and the sampled inputs
# are reproducible from run to run.
torch.manual_seed(100)
module = Model()
# Three token ids drawn from [0, 10), matching the embedding's vocabulary size.
inputs = torch.randint(0, 10, (3,))

Using tensor_tracker:
import tensor_tracker

# Run one forward and one backward pass while the tracker is attached to
# `module`; the stashes it collects are inspected below.
with tensor_tracker.track(module) as tracker:
    module(inputs).backward()

# Iterating the tracker yields one Stash per recorded tensor.
print(list(tracker))
# => [Stash(name="embed", type=nn.Embedding, grad=False, value=tensor(...)),
#     ...]

# `display` is the notebook (IPython) rich-output helper; it renders the
# summary frame as a table.
display(tracker.to_frame())

[Stash(name='embed', type=<class 'torch.nn.modules.sparse.Embedding'>, grad=False, value=tensor([[ 0.4698, 1.2426, 0.5403, -1.1454],
 [-0.8425, -0.6475, -0.2189, -1.1326],
 [ 0.1268, 1.3564, 0.5632, -0.1039]])), Stash(name='project', type=<class 'torch.nn.modules.linear.Linear'>, grad=False, value=tensor([[-0.6237, -0.1652, 0.3782, -0.8841],
 [-0.9278, -0.2848, -0.8688, -0.4719],
 [-0.3449, 0.3643, 0.3935, -0.6302]])), Stash(name='unembed', type=<class 'torch.nn.modules.linear.Linear'>, grad=False, value=tensor([[-0.2458, 1.0003, -0.8231, -0.1405, -0.2964, 0.5837, 0.2889, 0.2059,
 -0.6114, -0.5916],
 [-0.6345, 1.0882, -0.4304, -0.2196, -0.0426, 0.9428, 0.2051, 0.5897,
 -0.2217, -0.9132],
 [-0.0822, 0.9985, -0.7097, -0.3139, -0.4805, 0.6878, 0.2560, 0.3254,
 -0.4447, -0.3332]])), Stash(name='', type=<class '__main__.Model'>, grad=False, value=tensor(2.5663)), Stash(name='', type=<class '__main__.Model'>, grad=True, value=(tensor(1.),)), Stash(name='unembed', type=<class 'torch.nn.modules.linear.Linear'>, grad=True, value=(tensor([[ 0.0237, 0.0824, -0.3200, 0.0263, 0.0225, 0.0543, 0.0404, 0.0372,
 0.0164, 0.0168],
 [ 0.0139, 0.0779, 0.0171, 0.0211, 0.0251, 0.0673, 0.0322, -0.2860,
 0.0210, 0.0105],
 [-0.3066, 0.0787, 0.0143, 0.0212, 0.0179, 0.0577, 0.0374, 0.0401,
 0.0186, 0.0208]]),)), Stash(name='project', type=<class 'torch.nn.modules.linear.Linear'>, grad=True, value=(tensor([[-0.1755, 0.1306, 0.0443, -0.1823],
 [ 0.1202, -0.0728, 0.0066, -0.0839],
 [-0.1863, 0.0470, -0.1055, -0.0353]]),)), Stash(name='embed', type=<class 'torch.nn.modules.sparse.Embedding'>, grad=True, value=(tensor([[-0.0108, 0.1086, -0.1304, -0.0370],
 [ 0.0534, -0.0029, 0.0078, -0.0074],
 [-0.0829, 0.0152, -0.1170, -0.0625]]),))]
|   | name | type | grad | std |
|---|---|---|---|---|
| 0 | embed | torch.nn.modules.sparse.Embedding | False | 0.853265 |
| 1 | project | torch.nn.modules.linear.Linear | False | 0.494231 |
| 2 | unembed | torch.nn.modules.linear.Linear | False | 0.581503 |
| 3 |  | __main__.Model | False | NaN |
| 4 |  | __main__.Model | True | NaN |
| 5 | unembed | torch.nn.modules.linear.Linear | True | 0.105266 |
| 6 | project | torch.nn.modules.linear.Linear | True | 0.112392 |
| 7 | embed | torch.nn.modules.sparse.Embedding | True | 0.068816 |