diff --git a/README.md b/README.md index c0fb825..4560556 100644 --- a/README.md +++ b/README.md @@ -8,47 +8,57 @@ See article [Clarifying exceptions and visualizing tensor operations in deep lea To help myself and other programmers debug tensor code, I built this library. TensorSensor clarifies exceptions by augmenting messages and visualizing Python code to indicate the shape of tensor variables (see figure to the right for a teaser). It works with [Tensorflow](https://www.tensorflow.org/), [PyTorch](https://pytorch.org/), [JAX](https://github.com/google/jax), and [Numpy](https://numpy.org/), as well as higher-level libraries like [Keras](https://keras.io/) and [fastai](https://www.fast.ai/). -*TensorSensor is currently at 0.1.2 (May 2021) so I'm happy to receive issues created at this repo or direct email*. +*TensorSensor is currently at 1.0 (December 2021)*. ## Visualizations For more, see [examples.ipynb](testing/examples.ipynb). ```python -import torch -import tsensor -W = torch.rand(d,n_neurons) -b = torch.rand(n_neurons,1) -X = torch.rand(n,d) -with tsensor.clarify(): +import numpy as np + +n = 200 # number of instances +d = 764 # number of instance features +n_neurons = 100 # how many neurons in this layer? + +W = np.random.rand(d,n_neurons) +b = np.random.rand(n_neurons,1) +X = np.random.rand(n,d) +with tsensor.clarify() as c: Y = W @ X.T + b ``` Displays this in a jupyter notebook or separate window: - + Instead of the following default exception message: ``` -RuntimeError: size mismatch, m1: [764 x 100], m2: [764 x 200] at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorMath.cpp:41 +ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 764 is different from 100) ``` TensorSensor augments the message with more information about which operator caused the problem and includes the shape of the operands: ``` -Cause: @ on tensor operand W w/shape [764, 100] and operand X.T w/shape [764, 200] +Cause: @ on tensor operand W w/shape (764, 100) and operand X.T w/shape (764, 200) ``` You can also get the full computation graph for an expression that includes all of these sub result shapes. ```python -tsensor.astviz("b = W@b + (h+3).dot(h) + torch.abs(torch.tensor(34))", sys._getframe()) +W = torch.rand(size=(2000,2000), dtype=torch.float64) +b = torch.rand(size=(2000,1), dtype=torch.float64) +h = torch.zeros(size=(1_000_000,), dtype=int) +x = torch.rand(size=(2000,1)) +z = torch.rand(size=(2000,1), dtype=torch.complex64) + +tsensor.astviz("b = W@b + (h+3).dot(h) + z", sys._getframe()) ``` yields the following abstract syntax tree with shapes: - + ## Install @@ -70,7 +80,7 @@ $ pip list | grep -i numpy numpy 1.19.5 numpydoc 1.1.0 $ pip list | grep -i torch -torch 1.9.0 +torch 1.10.0 torchvision 0.10.0 $ pip list | grep -i jax jax 0.2.20 diff --git a/images/ast.svg b/images/ast.svg index 652bfcb..ae14c78 100644 --- a/images/ast.svg +++ b/images/ast.svg @@ -1,424 +1,274 @@ - - - + + G - - + + -leaf140529612873584 - -b +leaf140396598700880 + +b - + -leaf140529612872336 - -= +leaf140396332381616 + += - - + + -leaf140529612870608 - -W +leaf140395790355424 + +W - - + + -leaf140529612870800 - -@ +leaf140395790354080 + +@ - - + + -leaf140529612871664 - -b +leaf140395790355952 + +b - - + + -leaf140529612873440 - -+ +leaf140395790353216 + ++ - - + + -leaf140529612871472 - -( +leaf140395790355040 + +( - - + + -leaf140529612872720 - -h +leaf140395790354752 + +h - - + + -leaf140529612870704 - -+ +leaf140395790354320 + ++ - - + + -leaf140529612869792 - -3 +leaf140395790352688 + +3 - - + + -leaf140529612872480 - -) +leaf140395790353360 + +) - - + + -leaf140529612869696 - -. +leaf140396596003168 + +. - - + + -leaf140529615243536 - -dot +leaf140396596004224 + +dot - - + + -leaf140529615243728 - -( +leaf140396596005760 + +( - - + + -leaf140529615243440 - -h +leaf140396599103696 + +h - - + + -leaf140529615244112 - -) +leaf140396599104896 + +) - - + + -leaf140529615244544 - -+ +leaf140396599104944 + ++ - - + + -leaf140529615244016 - -torch +leaf140396599104608 + +z - - + + -leaf140529615244304 - -. - - - +node140396598701936 + +@ +2kx1 +<float64> + + + +node140396598701936->leaf140395790355424 + + + + + +node140396598701936->leaf140395790355952 + + + + -leaf140529615244352 - -abs - - - +node140396599104128 + ++ +1m +<int64> + + + +node140396599104128->leaf140395790354752 + + + + + +node140396599104128->leaf140395790352688 + + + + -leaf140529615244160 - -( - - - +node140396599104800 + +. + + + +node140396599104800->leaf140396596004224 + + + + + +node140396599104800->node140396599104128 + + + + -leaf140529615243680 - -torch - - - +node140396599105136 + +dot() + + + +node140396599105136->leaf140396599103696 + + + + + +node140396599105136->node140396599104800 + + + + -leaf140529613484096 - -. - - - +node140396599104416 + ++ +2kx1 +<float64> + + + +node140396599104416->node140396598701936 + + + + + +node140396599104416->node140396599105136 + + + + -leaf140529613484144 - -tensor +node140396599104320 + ++ +2kx1 +<complex128> - - - -leaf140529613484192 - -( - - - - -leaf140529613484240 - -34 - - - - -leaf140529613484288 - -) - - - - -leaf140529613484336 - -) - - - - -node140529612873152 - -@ -2kx1 - - + + +node140396599104320->leaf140396599104608 + + + + -node140529612873152->leaf140529612870608 - - +node140396599104320->node140396599104416 + + - - -node140529612873152->leaf140529612871664 - - - - - -node140529613484432 - -+ -1m - - + + +node140396599104368 + += +2kx1 +<complex128> + + -node140529613484432->leaf140529612872720 - - +node140396599104368->leaf140396598700880 + + - + -node140529613484432->leaf140529612869792 - - - - - -node140529613484528 - -() -1m - - - -node140529613484528->node140529613484432 - - - - - -node140529613484672 - -. - - - -node140529613484672->leaf140529615243536 - - - - - -node140529613484672->node140529613484528 - - - - - -node140529613484720 - -dot() - - - -node140529613484720->leaf140529615243440 - - - - - -node140529613484720->node140529613484672 - - - - - -node140529613484816 - -+ -2kx1 - - - -node140529613484816->node140529612873152 - - - - - -node140529613484816->node140529613484720 - - - - - -node140529613484912 - -. - - - -node140529613484912->leaf140529615244016 - - - - - -node140529613484912->leaf140529615244352 - - - - - -node140529613485056 - -. - - - -node140529613485056->leaf140529615243680 - - - - - -node140529613485056->leaf140529613484144 - - - - - -node140529613485152 - -tensor() - - - -node140529613485152->leaf140529613484240 - - - - - -node140529613485152->node140529613485056 - - - - - -node140529613485200 - -abs() - - - -node140529613485200->node140529613484912 - - - - - -node140529613485200->node140529613485152 - - - - - -node140529613485248 - -+ -2kx1 - - - -node140529613485248->node140529613484816 - - - - - -node140529613485248->node140529613485200 - - - - - -node140529613485296 - -= -2kx1 - - - -node140529613485296->leaf140529612873584 - - - - - -node140529613485296->node140529613485248 - - +node140396599104368->node140396599104320 + + diff --git a/images/mm.svg b/images/mm.svg new file mode 100644 index 0000000..3a942d4 --- /dev/null +++ b/images/mm.svg @@ -0,0 +1,759 @@ + + + + + + + + + 2021-12-11T13:11:04.999774 + image/svg+xml + + + Matplotlib v3.3.4, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/testing/examples.ipynb b/testing/examples.ipynb index 251fbc0..5668bda 100644 --- a/testing/examples.ipynb +++ b/testing/examples.ipynb @@ -10752,7 +10752,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 41, "metadata": {}, "outputs": [], "source": [ diff --git a/testing/playground.ipynb b/testing/playground.ipynb index f3731d1..50e44d0 100644 --- a/testing/playground.ipynb +++ b/testing/playground.ipynb @@ -2,18 +2,462 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 8, "metadata": {}, "outputs": [ { "ename": "RuntimeError", - "evalue": "1D tensors expected, got 2D, 2D tensors at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorEvenMoreMath.cpp:83\nCause: torch.dot(b,b) tensor arg b w/shape [2, 1], arg b w/shape [2, 1]", + "evalue": "mat1 and mat2 shapes cannot be multiplied (10x20 and 10x500)\nCause: @ on tensor operand W w/shape [10, 20] and operand X.T w/shape [10, 500]", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtsensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclarify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0mW\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meye\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m@\u001b[0m\u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 15\u001b[0m \u001b[0;31m# W[33, 33] = 3\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mabs\u001b[0m\u001b[0;34m(\u001b[0m \u001b[0mW\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mb\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mRuntimeError\u001b[0m: 1D tensors expected, got 2D, 2D tensors at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorEvenMoreMath.cpp:83\nCause: torch.dot(b,b) tensor arg b w/shape [2, 1], arg b w/shape [2, 1]" + "\u001b[0;32m/var/folders/93/9kzk2ccm8xj8k70059b28jk80000gp/T/ipykernel_55352/3024069779.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mX\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrand\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m500\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtsensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexplain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msavefig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'/tmp/mm.pdf'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mY\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mW\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m: mat1 and mat2 shapes cannot be multiplied (10x20 and 10x500)\nCause: @ on tensor operand W w/shape [10, 20] and operand X.T w/shape [10, 500]" + ] + } + ], + "source": [ + "import torch\n", + "import tsensor\n", + "W = torch.rand(10,20)\n", + "b = torch.rand(10,1)\n", + "X = torch.rand(500,10)\n", + "with tsensor.explain(savefig='/tmp/mm.svg'):\n", + " Y = W @ X.T + b" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " 2021-12-11T13:06:18.685523\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "ename": "RuntimeError", + "evalue": "1D tensors expected, but got 2D and 2D tensors\nCause: torch.dot(b,b) tensor arg b w/shape [2, 1], arg b w/shape [2, 1]", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/var/folders/93/9kzk2ccm8xj8k70059b28jk80000gp/T/ipykernel_55352/2660314622.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtsensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclarify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0mW\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meye\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m@\u001b[0m\u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 15\u001b[0m \u001b[0;31m# W[33, 33] = 3\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mabs\u001b[0m\u001b[0;34m(\u001b[0m \u001b[0mW\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mb\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: 1D tensors expected, but got 2D and 2D tensors\nCause: torch.dot(b,b) tensor arg b w/shape [2, 1], arg b w/shape [2, 1]" ] } ], @@ -658,7 +1102,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -672,7 +1116,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.6" + "version": "3.8.8" } }, "nbformat": 4,