diff --git a/README.md b/README.md
index c0fb825..4560556 100644
--- a/README.md
+++ b/README.md
@@ -8,47 +8,57 @@ See article [Clarifying exceptions and visualizing tensor operations in deep lea
To help myself and other programmers debug tensor code, I built this library. TensorSensor clarifies exceptions by augmenting messages and visualizing Python code to indicate the shape of tensor variables (see figure to the right for a teaser). It works with [Tensorflow](https://www.tensorflow.org/), [PyTorch](https://pytorch.org/), [JAX](https://github.com/google/jax), and [Numpy](https://numpy.org/), as well as higher-level libraries like [Keras](https://keras.io/) and [fastai](https://www.fast.ai/).
-*TensorSensor is currently at 0.1.2 (May 2021) so I'm happy to receive issues created at this repo or direct email*.
+*TensorSensor is currently at 1.0 (December 2021)*.
## Visualizations
For more, see [examples.ipynb](testing/examples.ipynb).
```python
-import torch
-import tsensor
-W = torch.rand(d,n_neurons)
-b = torch.rand(n_neurons,1)
-X = torch.rand(n,d)
-with tsensor.clarify():
+import numpy as np
+
+n = 200 # number of instances
+d = 764 # number of instance features
+n_neurons = 100 # how many neurons in this layer?
+
+W = np.random.rand(d,n_neurons)
+b = np.random.rand(n_neurons,1)
+X = np.random.rand(n,d)
+with tsensor.clarify() as c:
Y = W @ X.T + b
```
Displays this in a jupyter notebook or separate window:
-
+
Instead of the following default exception message:
```
-RuntimeError: size mismatch, m1: [764 x 100], m2: [764 x 200] at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorMath.cpp:41
+ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 764 is different from 100)
```
TensorSensor augments the message with more information about which operator caused the problem and includes the shape of the operands:
```
-Cause: @ on tensor operand W w/shape [764, 100] and operand X.T w/shape [764, 200]
+Cause: @ on tensor operand W w/shape (764, 100) and operand X.T w/shape (764, 200)
```
You can also get the full computation graph for an expression that includes all of these sub result shapes.
```python
-tsensor.astviz("b = W@b + (h+3).dot(h) + torch.abs(torch.tensor(34))", sys._getframe())
+W = torch.rand(size=(2000,2000), dtype=torch.float64)
+b = torch.rand(size=(2000,1), dtype=torch.float64)
+h = torch.zeros(size=(1_000_000,), dtype=int)
+x = torch.rand(size=(2000,1))
+z = torch.rand(size=(2000,1), dtype=torch.complex64)
+
+tsensor.astviz("b = W@b + (h+3).dot(h) + z", sys._getframe())
```
yields the following abstract syntax tree with shapes:
-
+
## Install
@@ -70,7 +80,7 @@ $ pip list | grep -i numpy
numpy 1.19.5
numpydoc 1.1.0
$ pip list | grep -i torch
-torch 1.9.0
+torch 1.10.0
torchvision 0.10.0
$ pip list | grep -i jax
jax 0.2.20
diff --git a/images/ast.svg b/images/ast.svg
index 652bfcb..ae14c78 100644
--- a/images/ast.svg
+++ b/images/ast.svg
@@ -1,424 +1,274 @@
-
-
diff --git a/images/mm.svg b/images/mm.svg
new file mode 100644
index 0000000..3a942d4
--- /dev/null
+++ b/images/mm.svg
@@ -0,0 +1,759 @@
+
+
+
+
+
+
+
+
+ 2021-12-11T13:11:04.999774
+ image/svg+xml
+
+
+ Matplotlib v3.3.4, https://matplotlib.org/
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/testing/examples.ipynb b/testing/examples.ipynb
index 251fbc0..5668bda 100644
--- a/testing/examples.ipynb
+++ b/testing/examples.ipynb
@@ -10752,7 +10752,7 @@
},
{
"cell_type": "code",
- "execution_count": 36,
+ "execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
diff --git a/testing/playground.ipynb b/testing/playground.ipynb
index f3731d1..50e44d0 100644
--- a/testing/playground.ipynb
+++ b/testing/playground.ipynb
@@ -2,18 +2,462 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
- "evalue": "1D tensors expected, got 2D, 2D tensors at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorEvenMoreMath.cpp:83\nCause: torch.dot(b,b) tensor arg b w/shape [2, 1], arg b w/shape [2, 1]",
+ "evalue": "mat1 and mat2 shapes cannot be multiplied (10x20 and 10x500)\nCause: @ on tensor operand W w/shape [10, 20] and operand X.T w/shape [10, 500]",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtsensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclarify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0mW\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meye\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m@\u001b[0m\u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 15\u001b[0m \u001b[0;31m# W[33, 33] = 3\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mabs\u001b[0m\u001b[0;34m(\u001b[0m \u001b[0mW\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mb\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mRuntimeError\u001b[0m: 1D tensors expected, got 2D, 2D tensors at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorEvenMoreMath.cpp:83\nCause: torch.dot(b,b) tensor arg b w/shape [2, 1], arg b w/shape [2, 1]"
+ "\u001b[0;32m/var/folders/93/9kzk2ccm8xj8k70059b28jk80000gp/T/ipykernel_55352/3024069779.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mX\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrand\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m500\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtsensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexplain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msavefig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'/tmp/mm.pdf'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mY\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mW\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;31mRuntimeError\u001b[0m: mat1 and mat2 shapes cannot be multiplied (10x20 and 10x500)\nCause: @ on tensor operand W w/shape [10, 20] and operand X.T w/shape [10, 500]"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "import tsensor\n",
+ "W = torch.rand(10,20)\n",
+ "b = torch.rand(10,1)\n",
+ "X = torch.rand(500,10)\n",
+ "with tsensor.explain(savefig='/tmp/mm.svg'):\n",
+ " Y = W @ X.T + b"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " 2021-12-11T13:06:18.685523\n",
+ " image/svg+xml\n",
+ " \n",
+ " \n",
+ " Matplotlib v3.3.4, https://matplotlib.org/\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "ename": "RuntimeError",
+ "evalue": "1D tensors expected, but got 2D and 2D tensors\nCause: torch.dot(b,b) tensor arg b w/shape [2, 1], arg b w/shape [2, 1]",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m/var/folders/93/9kzk2ccm8xj8k70059b28jk80000gp/T/ipykernel_55352/2660314622.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtsensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclarify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0mW\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meye\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m@\u001b[0m\u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 15\u001b[0m \u001b[0;31m# W[33, 33] = 3\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mabs\u001b[0m\u001b[0;34m(\u001b[0m \u001b[0mW\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mb\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mRuntimeError\u001b[0m: 1D tensors expected, but got 2D and 2D tensors\nCause: torch.dot(b,b) tensor arg b w/shape [2, 1], arg b w/shape [2, 1]"
]
}
],
@@ -658,7 +1102,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -672,7 +1116,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.8.6"
+ "version": "3.8.8"
}
},
"nbformat": 4,