diff --git a/docs/src/part3/part3_artifacts/simple_net.py b/docs/src/part3/part3_artifacts/simple_net.py
index ef356cf..759e446 100644
--- a/docs/src/part3/part3_artifacts/simple_net.py
+++ b/docs/src/part3/part3_artifacts/simple_net.py
@@ -9,14 +9,16 @@ class SimpleNet(nn.Module):
     def __init__(self):
         super().__init__()
         self.conv1 = nn.Conv2d(3, 6, 5)
-        self.conv2 = nn.Conv2d(6, 9, 5)
-        self.fc = nn.Linear(5184, 10)
+        self.conv2 = nn.Conv2d(3, 6, 5)
+        self.fc = nn.Linear(4704, 10)
 
     def forward(self, x: torch.Tensor):
-        x = self.conv1(x)
-        x = F.relu(x)
-        x = self.conv2(x)
-        x = F.relu(x)
-        x = torch.flatten(x, 1)
-        x = self.fc(x)
-        return x
+        z = self.conv1(x)
+        z = F.relu(z)
+        y = self.conv2(x)
+        y = F.relu(y)
+        o = z + y
+        o = torch.flatten(o, 1)
+        o = self.fc(o)
+        return o
+
\ No newline at end of file
diff --git a/docs/src/part3/simple_net.pt2 b/docs/src/part3/simple_net.pt2
deleted file mode 100644
index cc98ab0..0000000
Binary files a/docs/src/part3/simple_net.pt2 and /dev/null differ
diff --git a/docs/src/part3/torch_export.ipynb b/docs/src/part3/torch_export.ipynb
index d4fdce7..98e684d 100644
--- a/docs/src/part3/torch_export.ipynb
+++ b/docs/src/part3/torch_export.ipynb
@@ -35,7 +35,7 @@
    "metadata": {},
    "source": [
     "This IR needs to fulfill a couple of properties for it to be useful to compilers. For example:\n",
-    "1. Operators have to be general enough for backends to notice patterns and optimize them: Many runtimes have specialized kernels for common operators like convolutions or even more complex ones like a `conv2 + batchnorm` (operator fusion). If the IR reduces all operators to sums, products and views, noticing these patterns becomes too hard.\n",
+    "1. Operators have to be general enough for backends to notice patterns and optimize them: many runtimes have specialized kernels for common operators like convolutions, or even for fused combinations like `conv2d + relu` (operator fusion, see examples [here](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#fusion-types)). If the IR reduces all operators to sums, products and views, noticing these patterns becomes too hard.\n",
    "2. The number of operators has to be small enough for the backend to implement all of them. \n",
    "3. Operators have to be functional, that is, without side effects. 
For example: If two functions read and modify the same parameters, the order of execution matters and the compiler has to be careful when parallelizing them.\n", "\n", @@ -47,7 +47,7 @@ "metadata": {}, "source": [ "TODO:\n", - "- [ ] Introduce ATEN (dialects), fx.Graph and link to Export IR" + "- [ ] Introduce ATEN (dialects), fx.Graph and link to Export IR, functionalization" ] }, { @@ -84,13 +84,18 @@ "import torch\n", "import pprint\n", "from part3_artifacts.simple_net import SimpleNet\n", - "import torch.fx.graph_module" + "import torch.fx.graph_module\n", + "from myst_nb import glue" ] }, { "cell_type": "code", "execution_count": 2, - "metadata": {}, + "metadata": { + "tags": [ + "hide-input" + ] + }, "outputs": [ { "name": "stdout", @@ -105,17 +110,18 @@ "\u001b[0;34m\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mConv2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mConv2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m6\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m9\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5184\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mConv2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m4704\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mx\u001b[0m 
\u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mz\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mz\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mz\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mo\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mFile:\u001b[0m ~/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py\n", "\u001b[0;31mType:\u001b[0m type\n", "\u001b[0;31mSubclasses:\u001b[0m " @@ -130,8 +136,24 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "To export a model we must first define a sample input. This is used to trace the model and generate the Export IR. The way this works efficiently is by using `torch._subclasses.fake_tensor.FakeTensor`. 
FakeTensors are a special type of tensor that only store metadata such as `dtype`, `shape` and `device` and overload all operators to simulate the computation without actually looking at the values. For example, doing matrix multiplications of FakeTensors of shapes `(N, M)` and `(M, K)` will return a FakeTensor of shape `(N, K)` in constant time instead of the normal quadratic complexity.\n",
-    "\n"
+    "To export a model we must first define a sample input. This is used to *trace* the model and generate the Export IR.\n",
+    "\n",
+    "```{note}\n",
+    "*Tracing* refers to the process of recording the operations a model executes on a specific input, along with their metadata.\n",
+    "\n",
+    "Tracing is made efficient by `torch._subclasses.fake_tensor.FakeTensor`. FakeTensors are a special kind of tensor that store only metadata such as `dtype`, `shape` and `device`, and overload all operators to simulate the computation without ever touching the actual values.\n",
+    "\n",
+    "For example, multiplying FakeTensors of shapes `(N, M)` and `(M, K)` returns a FakeTensor of shape `(N, K)` in constant time, instead of the cubic cost of an actual matrix multiplication.\n",
+    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
+    "In our case, the model will be deployed on a camera with a fixed resolution, so we can define a statically shaped sample input with a batch size of 1. If you want to support dynamically shaped inputs, refer to the [documentation](https://pytorch.org/docs/main/export.html#expressing-dynamism).\n",
+    "\n",
+    "Once we have the input, we can call the `torch.export.export` function.\n"
   ]
  },
  {
@@ -168,23 +190,26 @@
     "output_type": "stream",
     "text": [
      "class GraphModule(torch.nn.Module):\n",
-      "    def forward(self, p_conv1_weight: \"\u001b[31mf32\u001b[0m\u001b[34m[6, 3, 5, 5]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", p_conv1_bias: \"\u001b[31mf32\u001b[0m\u001b[34m[6]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", p_conv2_weight: \"\u001b[31mf32\u001b[0m\u001b[34m[9, 6, 5, 5]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", p_conv2_bias: \"\u001b[31mf32\u001b[0m\u001b[34m[9]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", p_fc_weight: \"\u001b[31mf32\u001b[0m\u001b[34m[10, 5184]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", p_fc_bias: \"\u001b[31mf32\u001b[0m\u001b[34m[10]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", x: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 3, 32, 32]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\"):\n",
-      "         \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:16 in forward, code: x = self.conv1(x)\u001b[0m\n",
-      "        conv2d: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 6, 28, 28]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.conv2d.default(x, p_conv1_weight, p_conv1_bias); \u001b[2mx = p_conv1_weight = p_conv1_bias = None\u001b[0m\n",
+      "    def forward(self, p_conv1_weight: \"\u001b[31mf32\u001b[0m\u001b[34m[6, 3, 5, 5]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", p_conv1_bias: \"\u001b[31mf32\u001b[0m\u001b[34m[6]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", p_conv2_weight: \"\u001b[31mf32\u001b[0m\u001b[34m[6, 3, 5, 5]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", p_conv2_bias: 
\"\u001b[31mf32\u001b[0m\u001b[34m[6]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", p_fc_weight: \"\u001b[31mf32\u001b[0m\u001b[34m[10, 4704]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", p_fc_bias: \"\u001b[31mf32\u001b[0m\u001b[34m[10]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", x: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 3, 32, 32]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\"):\n", + " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:16 in forward, code: z = self.conv1(x)\u001b[0m\n", + " conv2d: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 6, 28, 28]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.conv2d.default(x, p_conv1_weight, p_conv1_bias); \u001b[2mp_conv1_weight = p_conv1_bias = None\u001b[0m\n", " \n", - " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:17 in forward, code: x = F.relu(x)\u001b[0m\n", + " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:17 in forward, code: z = F.relu(z)\u001b[0m\n", " relu: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 6, 28, 28]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.relu.default(conv2d); \u001b[2mconv2d = None\u001b[0m\n", " \n", - " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:18 in forward, code: x = self.conv2(x)\u001b[0m\n", - " conv2d_1: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 9, 24, 24]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.conv2d.default(relu, p_conv2_weight, p_conv2_bias); \u001b[2mrelu = p_conv2_weight = p_conv2_bias = None\u001b[0m\n", + " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:18 in forward, code: y = self.conv2(x)\u001b[0m\n", + " conv2d_1: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 6, 28, 28]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.conv2d.default(x, p_conv2_weight, p_conv2_bias); \u001b[2mx = p_conv2_weight = p_conv2_bias = None\u001b[0m\n", + " \n", + " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:19 in forward, code: y = F.relu(y)\u001b[0m\n", + " relu_1: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 6, 28, 28]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.relu.default(conv2d_1); \u001b[2mconv2d_1 = None\u001b[0m\n", " \n", - " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:19 in forward, code: x = F.relu(x)\u001b[0m\n", - " relu_1: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 9, 24, 24]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.relu.default(conv2d_1); \u001b[2mconv2d_1 = None\u001b[0m\n", + " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:20 in forward, code: o = z + y\u001b[0m\n", + " add: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 6, 28, 28]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.add.Tensor(relu, relu_1); \u001b[2mrelu = relu_1 = None\u001b[0m\n", " \n", - " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:20 in forward, code: x = torch.flatten(x, 1)\u001b[0m\n", - " view: 
\"\u001b[31mf32\u001b[0m\u001b[34m[1, 5184]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.view.default(relu_1, \u001b[34m[1, 5184]\u001b[0m); \u001b[2mrelu_1 = None\u001b[0m\n", + " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:21 in forward, code: o = torch.flatten(o, 1)\u001b[0m\n", + " view: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 4704]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.view.default(add, \u001b[34m[1, 4704]\u001b[0m); \u001b[2madd = None\u001b[0m\n", " \n", - " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:21 in forward, code: x = self.fc(x)\u001b[0m\n", + " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:22 in forward, code: o = self.fc(o)\u001b[0m\n", " linear: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 10]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.linear.default(view, p_fc_weight, p_fc_bias); \u001b[2mview = p_fc_weight = p_fc_bias = None\u001b[0m\n", " return (linear,)\n", " \n" @@ -200,108 +225,306 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Here we can see all nodes (`conv2d`, `relu`, `conv2d_1`, etc.), their shapes, dtypes, devices and the aten operators that are being used (`torch.ops.aten.conv2d.default`) with their accompanying file, line and code. We can also see that the graph inputs expects not only the model inputs but also its parameters (buffers and constants too)." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "A `torch.fx.GraphModule` is just a wrapper around the `fx.Graph`, and you can access it through `graph_module.graph`. This is useful because `fx.Graph` has a lot of methods to manipulate the graph, like `graph_module.graph.nodes` to access all nodes, `graph_module.graph.nodes[0].args` to access the arguments of the first node." + "Here we can see all *nodes* (`conv2d`, `relu`, `conv2d_1`, etc.), their shapes, dtypes, devices and the aten operators that are being used (`torch.ops.aten.conv2d.default`), with their accompanying file, line and code. We can also see that the graph inputs expects not only the model inputs but also its parameters (buffers and constants too)." 
] }, { "cell_type": "code", - "execution_count": 12, - "metadata": {}, + "execution_count": 5, + "metadata": { + "tags": [ + "remove-cell" + ] + }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "graph():\n", - " %p_conv1_weight : [num_users=1] = placeholder[target=p_conv1_weight]\n", - " %p_conv1_bias : [num_users=1] = placeholder[target=p_conv1_bias]\n", - " %p_conv2_weight : [num_users=1] = placeholder[target=p_conv2_weight]\n", - " %p_conv2_bias : [num_users=1] = placeholder[target=p_conv2_bias]\n", - " %p_fc_weight : [num_users=1] = placeholder[target=p_fc_weight]\n", - " %p_fc_bias : [num_users=1] = placeholder[target=p_fc_bias]\n", - " %x : [num_users=1] = placeholder[target=x]\n", - " %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv1_weight, %p_conv1_bias), kwargs = {})\n", - " %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%conv2d,), kwargs = {})\n", - " %conv2d_1 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%relu, %p_conv2_weight, %p_conv2_bias), kwargs = {})\n", - " %relu_1 : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%conv2d_1,), kwargs = {})\n", - " %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%relu_1, [1, 5184]), kwargs = {})\n", - " %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%view, %p_fc_weight, %p_fc_bias), kwargs = {})\n", - " return (linear,)\n" - ] + "data": { + "text/plain": [ + "graph():\n", + " %p_conv1_weight : [num_users=1] = placeholder[target=p_conv1_weight]\n", + " %p_conv1_bias : [num_users=1] = placeholder[target=p_conv1_bias]\n", + " %p_conv2_weight : [num_users=1] = placeholder[target=p_conv2_weight]\n", + " %p_conv2_bias : [num_users=1] = placeholder[target=p_conv2_bias]\n", + " %p_fc_weight : [num_users=1] = placeholder[target=p_fc_weight]\n", + " %p_fc_bias : [num_users=1] = placeholder[target=p_fc_bias]\n", + " %x : [num_users=2] = placeholder[target=x]\n", + " %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv1_weight, %p_conv1_bias), kwargs = {})\n", + " %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%conv2d,), kwargs = {})\n", + " %conv2d_1 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv2_weight, %p_conv2_bias), kwargs = {})\n", + " %relu_1 : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%conv2d_1,), kwargs = {})\n", + " %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%relu, %relu_1), kwargs = {})\n", + " %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%add, [1, 4704]), kwargs = {})\n", + " %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%view, %p_fc_weight, %p_fc_bias), kwargs = {})\n", + " return (linear,)" + ] + }, + "metadata": { + "scrapbook": { + "mime_prefix": "", + "name": "graphmodule_graph" + } + }, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "[p_conv1_weight,\n", + " p_conv1_bias,\n", + " p_conv2_weight,\n", + " p_conv2_bias,\n", + " p_fc_weight,\n", + " p_fc_bias,\n", + " x,\n", + " conv2d,\n", + " relu,\n", + " conv2d_1,\n", + " relu_1,\n", + " add,\n", + " view,\n", + " linear,\n", + " output]" + ] + }, + "metadata": { + "scrapbook": { + "mime_prefix": "", + "name": "graphmodule_graph_nodes" + } + }, + 
"output_type": "display_data" + }, + { + "data": { + "application/papermill.record/text/plain": "'call_function'" + }, + "metadata": { + "scrapbook": { + "mime_prefix": "application/papermill.record/", + "name": "relu_1_op" + } + }, + "output_type": "display_data" + }, + { + "data": { + "application/papermill.record/text/plain": "" + }, + "metadata": { + "scrapbook": { + "mime_prefix": "application/papermill.record/", + "name": "relu_1_target" + } + }, + "output_type": "display_data" + }, + { + "data": { + "application/papermill.record/text/plain": "(conv2d_1,)" + }, + "metadata": { + "scrapbook": { + "mime_prefix": "application/papermill.record/", + "name": "relu_1_args" + } + }, + "output_type": "display_data" + }, + { + "data": { + "application/papermill.record/text/plain": " File \"/home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py\", line 19, in forward\n y = F.relu(y)\n" + }, + "metadata": { + "scrapbook": { + "mime_prefix": "application/papermill.record/", + "name": "relu_1_stack_trace_2" + } + }, + "output_type": "display_data" + }, + { + "data": { + "application/papermill.record/text/plain": "'relu_1'" + }, + "metadata": { + "scrapbook": { + "mime_prefix": "application/papermill.record/", + "name": "relu_1_name" + } + }, + "output_type": "display_data" + }, + { + "data": { + "application/papermill.record/text/plain": "{'stack_trace': ' File \"/home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py\", line 19, in forward\\n y = F.relu(y)\\n',\n 'nn_module_stack': {'L__self__': ('',\n 'part3_artifacts.simple_net.SimpleNet')},\n 'source_fn_stack': [('relu_1',\n torch.Tensor>)],\n 'original_aten': ,\n 'from_node': [('y_1',\n torch.Tensor>)],\n 'seq_nr': 50,\n 'torch_fn': ('relu_2', 'function.relu'),\n 'val': FakeTensor(..., size=(1, 6, 28, 28)),\n 'tensor_meta': TensorMetadata(shape=torch.Size([1, 6, 28, 28]), dtype=torch.float32, requires_grad=True, stride=(4704, 784, 28, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})}" + }, + "metadata": { + "scrapbook": { + "mime_prefix": "application/papermill.record/", + "name": "relu_1_meta" + } + }, + "output_type": "display_data" + }, + { + "data": { + "application/papermill.record/text/plain": "[add]" + }, + "metadata": { + "scrapbook": { + "mime_prefix": "application/papermill.record/", + "name": "relu_1_users" + } + }, + "output_type": "display_data" } ], "source": [ - "print(graph_module.graph)" + "def graph_formatter(graph, pp, cycle):\n", + " pp.text(str(graph))\n", + "\n", + "# def graph_nodes_formatter(nodes, pp, cycle):\n", + "# pp.\n", + "# for node in nodes:\n", + "# pp.text(str(node))\n", + "\n", + "from IPython import get_ipython\n", + "import torch.fx.graph as fx_graph\n", + "plain = get_ipython().display_formatter.formatters['text/plain']\n", + "plain.for_type(torch.fx.Graph, graph_formatter)\n", + "# plain.for_type(fx_graph._node_list, graph_nodes_formatter)\n", + "glue(\"graphmodule_graph\", graph_module.graph)\n", + "glue(\"graphmodule_graph_nodes\", list(graph_module.graph.nodes))\n", + "\n", + "class StackTrace(object):\n", + " def __init__(self, stack_trace):\n", + " self.stack_trace = stack_trace\n", + "\n", + "def stack_trace_formatter(stack_trace, pp, cycle):\n", + " pp.text(stack_trace.stack_trace)\n", + "\n", + "plain.for_type(StackTrace, stack_trace_formatter)\n", + "\n", + "relu_1 = next(filter(lambda n: n.name == \"relu_1\", graph_module.graph.nodes))\n", + "glue(\"relu_1_op\", relu_1.op, display=False)\n", + 
"glue(\"relu_1_target\", relu_1.target, display=False)\n", + "glue(\"relu_1_args\", relu_1.args, display=False)\n", + "glue(\"relu_1_stack_trace_2\", StackTrace(relu_1.stack_trace), display=False)\n", + "glue(\"relu_1_name\", relu_1.name, display=False)\n", + "glue(\"relu_1_meta\", relu_1.meta, display=False)\n", + "glue(\"relu_1_users\", list(relu_1.users), display=False)\n" ] }, { - "cell_type": "code", - "execution_count": 15, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[p_conv1_weight,\n", - " p_conv1_bias,\n", - " p_conv2_weight,\n", - " p_conv2_bias,\n", - " p_fc_weight,\n", - " p_fc_bias,\n", - " x,\n", - " conv2d,\n", - " relu,\n", - " conv2d_1,\n", - " relu_1,\n", - " view,\n", - " linear,\n", - " output]\n" - ] - } - ], "source": [ - "pprint.pp(list(graph_module.graph.nodes))" + "::::{note}\n", + "\n", + "A `torch.fx.GraphModule` is just a wrapper around its `fx.Graph`, and you can access it through `graph_module.graph`. This is useful for two reasons:\n", + "- Most of the compiler steps will work with `fx.Graph` directly, so it's good to get acquainted with its attributes in case you need to debug an error.\n", + "- You *might* need to manipulate the graph directly to ensure compatibility ([example](https://leimao.github.io/blog/PyTorch-Eager-Mode-Quantization-TensorRT-Acceleration/)).\n", + "\n", + "\n", + "To start, if we want to print the underlying graph, we can do it like this:\n", + "\n", + "```python\n", + "print(str(graph_module.graph))\n", + "```\n", + "\n", + "```{glue} graphmodule_graph\n", + "```\n", + "\n", + "This is similar enough to the `graph_module`'s output, so let's move on. Each \"variable\" in the graph is a `Node` object, and we can access them like this:\n", + "\n", + "```python\n", + "print(list(graph_module.graph.nodes))\n", + "```\n", + "\n", + "```{glue} graphmodule_graph_nodes\n", + "```\n", + "\n", + "Specifically, if we're interested in a particular node, like the `relu_1` node, we can filter it by name:\n", + "\n", + "```python\n", + "relu_1 = next(filter(lambda n: n.name == \"relu_1\", graph_module.graph.nodes))\n", + "```\n", + "\n", + "Some of its most important attributes are the `name`, `op`, `args`, `stack_trace`, `target` and `users`. Let's print them and see what they store.\n", + "\n", + "The `name` is just the unique name of the node:\n", + "\n", + "```python\n", + "print(relu_1.name)\n", + "```\n", + "\n", + "```{glue} relu_1_name \n", + "```\n", + "\n", + "The `op` is the operator that the node represents. It refers to the high-level function that specifies the type of node. It is accompanied by a `target` and together they define the behavior of the node.\n", + "For example `Node(op=placeholder, target=p_p_conv1_weight)` means that the node is a placeholder for the weight of the first convolutional layer. Inputs, weights, etc are tagged as `placeholder` nodes.\n", + "\n", + "On the other hand, `call_function` nodes represent a function call to their `target`. 
+    "On the other hand, `call_function` nodes represent a function call to their `target`. For example, `Node(op=call_function, target=torch.ops.aten.relu.default)` means that the node is a call to the `relu` function, as we can see next:\n",
+    "\n",
+    "```python\n",
+    "print(relu_1.op)\n",
+    "```\n",
+    "\n",
+    "```{glue} relu_1_op\n",
+    "```\n",
+    "\n",
+    "```python\n",
+    "print(relu_1.target)\n",
+    "```\n",
+    "\n",
+    "```{glue} relu_1_target\n",
+    "```\n",
+    "\n",
+    "As we can see, *operator* is used almost interchangeably with *function* in this context.\n",
+    "\n",
+    "The `args` are the arguments of the node's function. In our case, since `relu_1` takes as input the output of `conv2d_1`, we should see a reference to that node.\n",
+    "\n",
+    "```python\n",
+    "print(relu_1.args)\n",
+    "```\n",
+    "\n",
+    "```{glue} relu_1_args\n",
+    "```\n",
+    "\n",
+    "Similarly, the `users` are the nodes that take the output of `relu_1` as input. Both of these attributes are useful for traversing the graph and understanding the dependencies between nodes.\n",
+    "\n",
+    "```python\n",
+    "print(relu_1.users)\n",
+    "```\n",
+    "\n",
+    "```{glue} relu_1_users\n",
+    "```\n",
+    "\n",
+    "Finally, the `stack_trace` records the piece of source code that generated the node. This is also useful for debugging: it localizes the source code that needs to be rewritten in case of an error.\n",
+    "\n",
+    "```python\n",
+    "print(relu_1.stack_trace)\n",
+    "```\n",
+    "\n",
+    "```{glue} relu_1_stack_trace_2\n",
+    "```\n",
+    "\n",
+    "For more information refer to the [documentation](https://pytorch.org/docs/main/export.ir_spec.html).\n",
+    "\n",
+    "::::"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Back to the `ExportedProgram`: the second most important attribute is its `graph_signature`. This object contains information about the inputs (actual inputs, parameters, constant tensors, etc.) and outputs of the model. This is particularly useful if you want to check whether a tensor is being folded as a constant.\n",
    "\n",
    "We can print it like this:"
   ]
  },
  {
   "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
@@ -346,9 +569,16 @@
    "pprint.pp(ep._graph_signature)"
   ]
  },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "If you want to access the parameters and buffers directly, you can reference the `state_dict` attribute.\n",
+   "\n",
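+   "For example, a minimal sketch that pulls out a single parameter (the keys match the original module's attribute names):\n",
+   "\n",
+   "```python\n",
+   "w = ep.state_dict[\"conv1.weight\"]  # a torch.Tensor of shape (6, 3, 5, 5)\n",
+   "print(w.shape)\n",
+   "```"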
+  ]
+ },
 {
  "cell_type": "code",
-  "execution_count": 6,
+  "execution_count": 7,
  "metadata": {},
  "outputs": [
   {
    "data": {
     "text/plain": [
      "dict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc.weight', 'fc.bias'])"
     ]
    },
-    "execution_count": 6,
+    "execution_count": 7,
    "metadata": {},
    "output_type": "execute_result"
   }
  ],
@@ -367,74 +597,44 @@
  ]
 },
 {
-  "cell_type": "code",
-  "execution_count": 7,
+  "cell_type": "markdown",
  "metadata": {},
-  "outputs": [],
  "source": [
-   "torch.export.save(ep, \"simple_net.pt2\")"
+   "Constants are tensors whose values turn out not to change during the forward pass (think of a tensor that contains the shape of the input). They are a bit less common, but sometimes ensuring that such tensors really are constant can help the compiler parse the model correctly. Our simple network doesn't have any constants, but you can access them like this:"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 8,
  "metadata": {},
-  "outputs": [],
-  "source": [
-   "x = [\n",
-   "    torch.rand(1, 3, 150, 100),\n",
-   "    torch.rand(1, 3, 75, 50),\n",
-   "    torch.rand(1, 3, 37, 25),\n",
-   "    torch.rand(1, 3, 19, 13),\n",
-   "]"
-  ]
- },
- {
-  "cell_type": "code",
-  "execution_count": 9,
-  "metadata": {},
  "outputs": [
   {
-    "data": {
-     "text/plain": [
-      "{}"
-     ]
-    },
-    "execution_count": 9,
-    "metadata": {},
-    "output_type": "execute_result"
+    "name": "stdout",
+    "output_type": "stream",
+    "text": [
+     "{}\n"
+    ]
   }
  ],
  "source": [
-   "ep.constants"
+   "print(ep.constants)"
  ]
 },
 {
-  "cell_type": "code",
-  "execution_count": 22,
+  "cell_type": "markdown",
  "metadata": {},
-  "outputs": [
-   {
-    "data": {
-     "text/plain": [
-      "{}"
-     ]
-    },
-    "execution_count": 22,
-    "metadata": {},
-    "output_type": "execute_result"
-   }
-  ],
  "source": [
-   "ep.constants"
+   "Finally, we can save our exported program using the `torch.export.save` function."
  ]
 },
 {
  "cell_type": "code",
-  "execution_count": null,
+  "execution_count": 9,
  "metadata": {},
  "outputs": [],
-  "source": []
+  "source": [
+   "torch.export.save(ep, \"simple_net.pt2\")"
+  ]
 }
],
"metadata": {