diff --git a/artifacts/idea_raw.jpg b/artifacts/idea_raw.jpg new file mode 100644 index 0000000..7a27b27 Binary files /dev/null and b/artifacts/idea_raw.jpg differ diff --git a/docs/buildpdf b/docs/buildpdf new file mode 100755 index 0000000..ba19b77 --- /dev/null +++ b/docs/buildpdf @@ -0,0 +1,2 @@ +#!/bin/bash +jb build --path-output _build/pdf src/ --builder pdfhtml diff --git a/docs/src/_toc.yml b/docs/src/_toc.yml index d6bdf87..0b37d6d 100644 --- a/docs/src/_toc.yml +++ b/docs/src/_toc.yml @@ -17,8 +17,6 @@ parts: - caption: Optimization chapters: - file: part3/compilation - sections: - - file: part3/torch_export - caption: Other chapters: - file: bibliography diff --git a/docs/src/part2/adapting.ipynb b/docs/src/part2/adapting.ipynb index 8f2ee14..98bbacb 100644 --- a/docs/src/part2/adapting.ipynb +++ b/docs/src/part2/adapting.ipynb @@ -4,8 +4,10 @@ "cell_type": "markdown", "metadata": {}, "source": [ + "(part2:adapting)=\n", "# Adapting the Encoder\n", "\n", + "\n", "```{contents}\n", "```\n", "\n", diff --git a/docs/src/part3/compilation.ipynb b/docs/src/part3/compilation.ipynb new file mode 100644 index 0000000..b5600e4 --- /dev/null +++ b/docs/src/part3/compilation.ipynb @@ -0,0 +1,1790 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Compilation\n", + "\n", + "```{contents}\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "One of PyTorch main strengths is its flexibility and ease of use. The user can write a model with almost no restrictions as long as each operator is differentiable and PyTorch will take care of the rest. On each forward pass, it will evaluate the operators on-the-fly and dynamically construct the computation graph, which is then used to compute the gradients during the backward pass. This is called Eager execution mode and it is the default behavior of PyTorch.\n", + "\n", + "This mode comes in handy when the computation graph is not static, for example when the model has if-statements or loops that depend on the input data or when the input has dynamic shapes (imagine training a model with multiple resolutions). However, this flexibility comes at a cost, because we can't optimize a model if we don't know what operations, which shapes, types or even order of operations will be executed until runtime. This is where compilers come in." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compilers 101\n", + "\n", + "A compiler is a program that translates instructions written in one representation (source) into another representation (target). Nowadays, compilers usually are separated in a frontend and a backend. The frontend is responsible for parsing the source code and generating an intermediate representation (IR) that is independent of the source language. The backend is responsible for translating the IR into the target language. This separation allows for reusability of the frontend with different backends, and vice versa as we can see in {numref}`Figure {number} `.\n", + "\n", + ":::{figure-md} retargetable\n", + "\"Retargetable\n", + "\n", + "Frontends produce an intermediate representation (IR) common to all backends. Backends take the IR and generate code for a specific target. {cite}`aosabook`\n", + ":::\n", + "\n", + "So far, we've talked about compilation as a process that happens before the program is executed, also known as ahead-of-time (AOT) compilation. However, there are other ways to execute a program. 
Some languages, like Python, are *interpreted*, where programs are executed line by line by the runtime interpreter. Furthermore, some runtimes might use just-in-time (JIT) compilation, where parts of the program are compiled while it is being executed. This allows for optimizations that can only be done at runtime, like specializing code for specific inputs." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## ML Compilers\n", + "\n", + "Machine learning compilers take a model written in some framework (e.g. PyTorch), translate it into a program that can be executed in some runtime (e.g. TensorRT, CoreML, PyTorch TorchInductor) which then ends up optimized for some specialized hardware (e.g. GPUs, TPUs, Apple Silicon). \n", + " \n", + "PyTorch has had a few different compiler solutions over the years, the most popular being TorchScript. This, however, has changed since PyTorch 2, as the new compiler stack has been introduced. The main component of this new stack is TorchDynamo, a new compiler frontend with better properties and more Python support than TorchScript. \n", + "\n", + "Along with TorchDynamo, PyTorch 2 has introduced two new APIs, `torch.export` and `torch.compile`, that leverage this technology. On one hand, `torch.export`'s goal is to act as an ahead-of-time frontend which captures the full semantics of the program into an IR independent of Python, while `torch.compile` is meant to be used as a full JIT compiler that can leverage other backends (TorchInductor, TensorRT, ONNX) to optimize parts of the model at runtime and fallback to native Python if necessary. \n", + "\n", + "For edge devices specifically, we are most interested in the `torch.export` API, as it allows us to dispose of the expensive overhead of the Python Runtime and allows us to take advantage of native optimized frameworks for our target hardware, like CoreML for Apple devices or TensorRT (C++) for NVIDIA GPUs. \n", + "\n", + "For convenience, we'll use the popular abbreviation of PyTorch 2 Export, `PT2E`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## PT2E 101" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The main idea of `torch.export` is that it translates an Eager Mode PyTorch model into a graph-based intermediate representation called *Export IR*. This allows compiler backends to take this IR and further transform and optimize it for a target device. A general overview of the process is shown in the figure [below](torchexport).\n", + "\n", + ":::{figure-md} torchexport\n", + "\"torch.export\"\n", + "\n", + "PyTorch 2 Export\n", + ":::" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This IR needs to fulfill a couple of properties for it to be useful to compilers. For example:\n", + "1. Operators have to be general enough for backends to notice patterns and optimize them: Many runtimes have specialized kernels for common operators like convolutions or even more complex ones like a `conv2 + relu` (operator fusion, see examples [here](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#fusion-types)). If the IR reduces all operators to sums, products and views, noticing these patterns becomes too hard.\n", + "2. The number of operators has to be small enough for the backend to implement all of them. \n", + "3. Operators have to be functional, that is, without side effects. 
For example: If two functions read and modify the same parameters, the order of execution matters and the compiler has to be careful when parallelizing them.\n", + "\n", + "Notice that properties 1 and 2 are in conflict with each other. The more operators we have, the more expressive the IR is, but the harder it is to implement all of them. This is a trade-off that the PyTorch team has to balance. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "TODO:\n", + "- [ ] Introduce ATEN (dialects), fx.Graph and link to Export IR, functionalization" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For now, let's get some practical intuition with an example." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Hands on with PT2E" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's use a simple network to see how `torch.export` works." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "import torch\n", + "import pprint\n", + "from part3_artifacts.simple_net import SimpleNet\n", + "import torch.fx.graph_module\n", + "from myst_nb import glue" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[0;31mInit signature:\u001b[0m \u001b[0mSimpleNet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mSource:\u001b[0m \n", + "\u001b[0;32mclass\u001b[0m \u001b[0mSimpleNet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mModule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;34m\"\"\"\u001b[0m\n", + "\u001b[0;34m Just a simple network\u001b[0m\n", + "\u001b[0;34m \"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mConv2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mConv2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m4704\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m 
\u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mz\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mz\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mz\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mo\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mFile:\u001b[0m ~/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py\n", + "\u001b[0;31mType:\u001b[0m type\n", + "\u001b[0;31mSubclasses:\u001b[0m " + ] + } + ], + "source": [ + "SimpleNet??" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To export a model we must first define a sample input. This is used to `trace` the model and generate the Export IR. \n", + "\n", + "```{note}\n", + "`Tracing` refers to the process of recording the operations executed by a model when given a specific input along with their metadata. \n", + "\n", + "The way tracing works efficiently is by using `torch._subclasses.fake_tensor.FakeTensor`. FakeTensors are a special type of tensor that only store metadata such as `dtype`, `shape` and `device` and overload all operators to simulate the computation without actually looking at the values. \n", + "\n", + "For example, doing matrix multiplications of FakeTensors of shapes `(N, M)` and `(M, K)` will return a FakeTensor of shape `(N, K)` in constant time instead of the normal cubic complexity of multiplication.\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For our case, the model will be deployed on a camera with a fixed resolution, so we can just define a statically shaped tensor of `batch_size` 1. 
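(As a side note, if a variable batch size were ever needed, `torch.export` accepts a `dynamic_shapes` specification built from `torch.export.Dim`. The following is a minimal sketch assuming a dynamic batch dimension, which we do not need for a fixed-resolution camera.)

```python
# Sketch only: mark the batch dimension as dynamic instead of specializing it.
batch = torch.export.Dim("batch", min=1, max=32)
ep_dyn = torch.export.export(
    SimpleNet().eval(),
    (torch.randn(2, 3, 32, 32),),      # example input with batch > 1
    dynamic_shapes={"x": {0: batch}},  # dim 0 of argument `x` is dynamic
)
```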
If you want to support dynamically shaped inputs, refer to the [documentation](https://pytorch.org/docs/main/export.html#expressing-dynamism).\n", + "\n", + "Once we have the input, we can call the `torch.export.export` function.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "x = torch.randn(1, 3, 32, 32) \n", + "ep: torch.export.ExportedProgram = torch.export.export(SimpleNet().eval(), (x,))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And that's it, we have exported our model. The new object is a `torch.export.ExportedProgram` which contains the model and parameters in the Export IR. Let's inspect it one by one." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The first and most important attribute is the `graph_module` which stores the computational graph of the model. We can print it using the `print_readable` method:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "class GraphModule(torch.nn.Module):\n", + " def forward(self, p_conv1_weight: \"\u001b[31mf32\u001b[0m\u001b[34m[6, 3, 5, 5]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", p_conv1_bias: \"\u001b[31mf32\u001b[0m\u001b[34m[6]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", p_conv2_weight: \"\u001b[31mf32\u001b[0m\u001b[34m[6, 3, 5, 5]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", p_conv2_bias: \"\u001b[31mf32\u001b[0m\u001b[34m[6]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", p_fc_weight: \"\u001b[31mf32\u001b[0m\u001b[34m[10, 4704]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", p_fc_bias: \"\u001b[31mf32\u001b[0m\u001b[34m[10]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", x: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 3, 32, 32]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\"):\n", + " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:16 in forward, code: z = self.conv1(x)\u001b[0m\n", + " conv2d: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 6, 28, 28]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.conv2d.default(x, p_conv1_weight, p_conv1_bias); \u001b[2mp_conv1_weight = p_conv1_bias = None\u001b[0m\n", + " \n", + " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:17 in forward, code: z = F.relu(z)\u001b[0m\n", + " relu: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 6, 28, 28]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.relu.default(conv2d); \u001b[2mconv2d = None\u001b[0m\n", + " \n", + " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:18 in forward, code: y = self.conv2(x)\u001b[0m\n", + " conv2d_1: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 6, 28, 28]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.conv2d.default(x, p_conv2_weight, p_conv2_bias); \u001b[2mx = p_conv2_weight = p_conv2_bias = None\u001b[0m\n", + " \n", + " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:19 in forward, code: y = F.relu(y)\u001b[0m\n", + " relu_1: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 6, 28, 
28]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.relu.default(conv2d_1); \u001b[2mconv2d_1 = None\u001b[0m\n", + " \n", + " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:20 in forward, code: o = z + y\u001b[0m\n", + " add: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 6, 28, 28]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.add.Tensor(relu, relu_1); \u001b[2mrelu = relu_1 = None\u001b[0m\n", + " \n", + " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:21 in forward, code: o = torch.flatten(o, 1)\u001b[0m\n", + " view: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 4704]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.view.default(add, [\u001b[34m1\u001b[0m, \u001b[34m4704\u001b[0m]); \u001b[2madd = None\u001b[0m\n", + " \n", + " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:22 in forward, code: o = self.fc(o)\u001b[0m\n", + " linear: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 10]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.linear.default(view, p_fc_weight, p_fc_bias); \u001b[2mview = p_fc_weight = p_fc_bias = None\u001b[0m\n", + " return (linear,)\n", + " \n" + ] + } + ], + "source": [ + "graph_module: torch.fx.GraphModule = ep.graph_module\n", + "print(graph_module.print_readable(print_output=False, colored=True, include_device=True))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we can see all *nodes* (`conv2d`, `relu`, `conv2d_1`, etc.), their shapes, dtypes, devices and the aten operators that are being used (`torch.ops.aten.conv2d.default`), with their accompanying file, line and code. We can also see that the graph inputs expects not only the model inputs but also its parameters (buffers and constants too)." 
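If you want to sanity-check the exported graph, `ExportedProgram.module()` returns a runnable `torch.fx.GraphModule` that takes only the original user inputs (the parameters are re-attached as attributes). A quick check, reusing the `x` defined above, could look like this:

```python
# Run the exported program and confirm the output shape matches the eager model.
exported_module = ep.module()
with torch.no_grad():
    out = exported_module(x)
print(out.shape)  # expected: torch.Size([1, 10])
```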
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "graph():\n", + " %p_conv1_weight : [num_users=1] = placeholder[target=p_conv1_weight]\n", + " %p_conv1_bias : [num_users=1] = placeholder[target=p_conv1_bias]\n", + " %p_conv2_weight : [num_users=1] = placeholder[target=p_conv2_weight]\n", + " %p_conv2_bias : [num_users=1] = placeholder[target=p_conv2_bias]\n", + " %p_fc_weight : [num_users=1] = placeholder[target=p_fc_weight]\n", + " %p_fc_bias : [num_users=1] = placeholder[target=p_fc_bias]\n", + " %x : [num_users=2] = placeholder[target=x]\n", + " %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv1_weight, %p_conv1_bias), kwargs = {})\n", + " %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%conv2d,), kwargs = {})\n", + " %conv2d_1 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv2_weight, %p_conv2_bias), kwargs = {})\n", + " %relu_1 : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%conv2d_1,), kwargs = {})\n", + " %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%relu, %relu_1), kwargs = {})\n", + " %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%add, [1, 4704]), kwargs = {})\n", + " %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%view, %p_fc_weight, %p_fc_bias), kwargs = {})\n", + " return (linear,)" + ] + }, + "metadata": { + "scrapbook": { + "mime_prefix": "", + "name": "graphmodule_graph" + } + }, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "[p_conv1_weight,\n", + " p_conv1_bias,\n", + " p_conv2_weight,\n", + " p_conv2_bias,\n", + " p_fc_weight,\n", + " p_fc_bias,\n", + " x,\n", + " conv2d,\n", + " relu,\n", + " conv2d_1,\n", + " relu_1,\n", + " add,\n", + " view,\n", + " linear,\n", + " output]" + ] + }, + "metadata": { + "scrapbook": { + "mime_prefix": "", + "name": "graphmodule_graph_nodes" + } + }, + "output_type": "display_data" + }, + { + "data": { + "application/papermill.record/text/plain": "'call_function'" + }, + "metadata": { + "scrapbook": { + "mime_prefix": "application/papermill.record/", + "name": "relu_1_op" + } + }, + "output_type": "display_data" + }, + { + "data": { + "application/papermill.record/text/plain": "" + }, + "metadata": { + "scrapbook": { + "mime_prefix": "application/papermill.record/", + "name": "relu_1_target" + } + }, + "output_type": "display_data" + }, + { + "data": { + "application/papermill.record/text/plain": "(conv2d_1,)" + }, + "metadata": { + "scrapbook": { + "mime_prefix": "application/papermill.record/", + "name": "relu_1_args" + } + }, + "output_type": "display_data" + }, + { + "data": { + "application/papermill.record/text/plain": " File \"/home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py\", line 19, in forward\n y = F.relu(y)\n" + }, + "metadata": { + "scrapbook": { + "mime_prefix": "application/papermill.record/", + "name": "relu_1_stack_trace_2" + } + }, + "output_type": "display_data" + }, + { + "data": { + "application/papermill.record/text/plain": "'relu_1'" + }, + "metadata": { + "scrapbook": { + "mime_prefix": "application/papermill.record/", + "name": "relu_1_name" + } + }, + "output_type": "display_data" + }, + { + "data": { + "application/papermill.record/text/plain": "{'stack_trace': ' File 
\"/home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py\", line 19, in forward\\n y = F.relu(y)\\n',\n 'nn_module_stack': {'L__self__': ('',\n 'part3_artifacts.simple_net.SimpleNet')},\n 'source_fn_stack': [('relu_1',\n torch.Tensor>)],\n 'original_aten': ,\n 'from_node': [('y_1',\n torch.Tensor>)],\n 'seq_nr': 50,\n 'torch_fn': ('relu_2', 'function.relu'),\n 'val': FakeTensor(..., size=(1, 6, 28, 28)),\n 'tensor_meta': TensorMetadata(shape=torch.Size([1, 6, 28, 28]), dtype=torch.float32, requires_grad=True, stride=(4704, 784, 28, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})}" + }, + "metadata": { + "scrapbook": { + "mime_prefix": "application/papermill.record/", + "name": "relu_1_meta" + } + }, + "output_type": "display_data" + }, + { + "data": { + "application/papermill.record/text/plain": "[add]" + }, + "metadata": { + "scrapbook": { + "mime_prefix": "application/papermill.record/", + "name": "relu_1_users" + } + }, + "output_type": "display_data" + } + ], + "source": [ + "def graph_formatter(graph, pp, cycle):\n", + " pp.text(str(graph))\n", + "\n", + "# def graph_nodes_formatter(nodes, pp, cycle):\n", + "# pp.\n", + "# for node in nodes:\n", + "# pp.text(str(node))\n", + "\n", + "from IPython import get_ipython\n", + "import torch.fx.graph as fx_graph\n", + "plain = get_ipython().display_formatter.formatters['text/plain']\n", + "plain.for_type(torch.fx.Graph, graph_formatter)\n", + "# plain.for_type(fx_graph._node_list, graph_nodes_formatter)\n", + "glue(\"graphmodule_graph\", graph_module.graph)\n", + "glue(\"graphmodule_graph_nodes\", list(graph_module.graph.nodes))\n", + "\n", + "class StackTrace(object):\n", + " def __init__(self, stack_trace):\n", + " self.stack_trace = stack_trace\n", + "\n", + "def stack_trace_formatter(stack_trace, pp, cycle):\n", + " pp.text(stack_trace.stack_trace)\n", + "\n", + "plain.for_type(StackTrace, stack_trace_formatter)\n", + "\n", + "relu_1 = next(filter(lambda n: n.name == \"relu_1\", graph_module.graph.nodes))\n", + "glue(\"relu_1_op\", relu_1.op, display=False)\n", + "glue(\"relu_1_target\", relu_1.target, display=False)\n", + "glue(\"relu_1_args\", relu_1.args, display=False)\n", + "glue(\"relu_1_stack_trace_2\", StackTrace(relu_1.stack_trace), display=False)\n", + "glue(\"relu_1_name\", relu_1.name, display=False)\n", + "glue(\"relu_1_meta\", relu_1.meta, display=False)\n", + "glue(\"relu_1_users\", list(relu_1.users), display=False)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "::::{note}\n", + "\n", + "A `torch.fx.GraphModule` is just a wrapper around its `fx.Graph`, and you can access it through `graph_module.graph`. This is useful for two reasons:\n", + "- Most of the compiler steps will work with `fx.Graph` directly, so it's good to get acquainted with its attributes in case you need to debug an error.\n", + "- You *might* need to manipulate the graph directly to ensure compatibility ([example](https://leimao.github.io/blog/PyTorch-Eager-Mode-Quantization-TensorRT-Acceleration/)).\n", + "\n", + "\n", + "To start, if we want to print the underlying graph, we can do it like this:\n", + "\n", + "```python\n", + "print(str(graph_module.graph))\n", + "```\n", + "\n", + "```{glue} graphmodule_graph\n", + "```\n", + "\n", + "This is similar enough to the `graph_module`'s output, so let's move on. 
Each \"variable\" in the graph is a `Node` object, and we can access them like this:\n", + "\n", + "```python\n", + "print(list(graph_module.graph.nodes))\n", + "```\n", + "\n", + "```{glue} graphmodule_graph_nodes\n", + "```\n", + "\n", + "Specifically, if we're interested in a particular node, like the `relu_1` node, we can filter it by name:\n", + "\n", + "```python\n", + "relu_1 = next(filter(lambda n: n.name == \"relu_1\", graph_module.graph.nodes))\n", + "```\n", + "\n", + "Some of its most important attributes are the `name`, `op`, `args`, `stack_trace`, `target` and `users`. Let's print them and see what they store.\n", + "\n", + "The `name` is just the unique name of the node:\n", + "\n", + "```python\n", + "print(relu_1.name)\n", + "```\n", + "\n", + "```{glue} relu_1_name \n", + "```\n", + "\n", + "The `op` is the operator that the node represents. It refers to the high-level function that specifies the type of node. It is accompanied by a `target` and together they define the behavior of the node.\n", + "For example `Node(op=placeholder, target=p_p_conv1_weight)` means that the node is a placeholder for the weight of the first convolutional layer. Inputs, weights, etc are tagged as `placeholder` nodes.\n", + "\n", + "On the other hand, `call_function` nodes represent a function call to their `target`. For example, `Node(op=call_function, target=torch.ops.aten.relu.default)` means that the node is a call to the `relu` function, as we can see next:\n", + "\n", + "```python\n", + "print(relu_1.op)\n", + "```\n", + "\n", + "```{glue} relu_1_op \n", + "```\n", + "\n", + "```python\n", + "print(relu_1.target)\n", + "```\n", + "\n", + "```{glue} relu_1_target \n", + "```\n", + "\n", + "As we can see, *operator* is almost used interchangeably with *function* in this context.\n", + "\n", + "The `args` are the arguments of the node's function. In our case, since `relu_1` takes as input the output of `conv2d_1`, we should see a reference to that node.\n", + "\n", + "```python\n", + "print(relu_1.args)\n", + "```\n", + "\n", + "```{glue} relu_1_args \n", + "```\n", + "\n", + "Similarly, the `users` are the nodes that take the output of `relu_1` as input. Both of these attributes are useful to traverse the graph and understand the dependencies between nodes.\n", + "\n", + "```python\n", + "print(relu_1.users)\n", + "```\n", + "\n", + "```{glue} relu_1_users \n", + "```\n", + "\n", + "Finally, the `stack_trace` is the piece of code that generated the node. This is also useful for debugging and it helps with localizing the source code that should be rewritten in case of an error.\n", + "```python\n", + "print(relu_1.stack_trace)\n", + "```\n", + "\n", + "```{glue} relu_1_stack_trace_2\n", + "```\n", + "\n", + "For more information refer to the [documentation](https://pytorch.org/docs/main/export.ir_spec.html).\n", + "\n", + "::::" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Back to the `ExportedProgram`, the second most important attribute is its `graph_signature`. This object contains information about the inputs (actual inputs, parameters, constant tensors, etc) and outputs of the model. 
This is particularly useful if you want to check whether a tensor is being folded as a constant.\n", + "\n", + "We can print it like this:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ExportGraphSignature(input_specs=[InputSpec(kind=,\n", + " arg=TensorArgument(name='p_conv1_weight'),\n", + " target='conv1.weight',\n", + " persistent=None),\n", + " InputSpec(kind=,\n", + " arg=TensorArgument(name='p_conv1_bias'),\n", + " target='conv1.bias',\n", + " persistent=None),\n", + " InputSpec(kind=,\n", + " arg=TensorArgument(name='p_conv2_weight'),\n", + " target='conv2.weight',\n", + " persistent=None),\n", + " InputSpec(kind=,\n", + " arg=TensorArgument(name='p_conv2_bias'),\n", + " target='conv2.bias',\n", + " persistent=None),\n", + " InputSpec(kind=,\n", + " arg=TensorArgument(name='p_fc_weight'),\n", + " target='fc.weight',\n", + " persistent=None),\n", + " InputSpec(kind=,\n", + " arg=TensorArgument(name='p_fc_bias'),\n", + " target='fc.bias',\n", + " persistent=None),\n", + " InputSpec(kind=,\n", + " arg=TensorArgument(name='x'),\n", + " target=None,\n", + " persistent=None)],\n", + " output_specs=[OutputSpec(kind=,\n", + " arg=TensorArgument(name='linear'),\n", + " target=None)])\n" + ] + } + ], + "source": [ + "pprint.pp(ep._graph_signature) # you can also just use print" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you want to access the parameters and buffers directly, you can reference the `state_dict` attribute." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc.weight', 'fc.bias'])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ep._state_dict.keys()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Constants are tensors that during the forward pass are found to not change (think of a tensor that contains the shape of the input). It is a bit less common to find them, but somestimes ensuring they are constant can help the compiler to parse the model correctly. Our simple network doesn't have any constants, but you can access them like this:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{}\n" + ] + } + ], + "source": [ + "print(ep.constants)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we can save our exported program using the `torch.export.save` function." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "torch.export.save(ep, \"simple_net.pt2\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## TensorRT\n", + "\n", + "- [ ] Introduction to TensorRT" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compiling the model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The working script to export and compile our model with the TensorRT backend is `scripts.export_tensorrt`.\n", + "\n", + "The easiest way to specify a compilation target, is by adding a config file at `scripts/config/export_tensorrt`. 
For example, if we want to compile our model's, we can use the config file located at `scripts/config/export_tensorrt/dinov2.yaml` as follows:\n", + "\n", + "```sh\n", + "python -m scripts.export_tensorrt --config-name dinov2\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This config file specifies information such like:\n", + "- `image`: The sample image's file path, height and width. \n", + " - *Set to target camera dimensions*.\n", + "- `amp_dtype`: `fp16` or `bf16` for `torch.amp.autocast` usage, `fp32` to disable. \n", + " - *Set to `fp32` and use `trt.enabled_precisions`.*\n", + "- `trt`: The kwargs to override `torch_tensorrt.dynamo.compile`. \n", + " - *Set `enabled_precisions` to `fp32`, `fp16` and if new GPU (Ampere or newer) to `bf16`*.\n", + " - *Set `require_full_compilation=False` if necessary. If possible rewrite the code to remove unsupported nodes because making partial compilation work is harder and error prone.*\n", + " - *Set `use_fast_partitioner=False` if partitioner bugs appear, doesn't usually solve anything but sometimes helps with error diagnosis.*\n", + " - *Set `enable_experimental_decompositions=False` if unsupported nodes appear, doesn't solve much but sometimes helps with error diagnosis.*\n", + "- `model`: The path to the model's config file, it's checkpoints and argument overrides.\n", + " - *Try to specialize the model as much as possible. For example, for `timm` ViTs, disable dynamic image sizes/padding and fix the image_size to your camera's dimensions.*\n", + "\n", + "As an example, here's the config file for our model." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[0mimage\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mheight\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;36m512\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mwidth\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;36m512\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mpath\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m\"artifacts/idea_raw.jpg\"\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m\u001b[0mamp_dtype\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m\"fp32\"\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m\u001b[0mtrt\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0menabled_precisions\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;34m\"fp32\"\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;34m\"fp16\"\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;34m\"bf16\"\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m\"projects/dino_dinov2/configs/models/dino_dinov2.py\"\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mckpt_path\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m\"artifacts/model_final.pth\"\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mopts\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;34m\"model.backbone.net.img_size=[512, 
512]\"\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;34m\"model.backbone.net.dynamic_img_size=False\"\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;34m\"model.backbone.net.dynamic_img_pad=False\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n" + ] + } + ], + "source": [ + "%pycat ../../../scripts/config/export_tensorrt/dinov2.yaml" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Although this script is a useful entrypoint, the challenge when compiling a model lies in making the models' source code compatible with both TorchDynamo and the backend of choice (TensorRT in this case). This is a bit harder to explain because during the debugging procedure, you'll attempt many possible fixes that are informed by insights of the codebase's state at that time, many of which will be deemed unsuccessful or unnecessary. For example, you might find a way to solve a bug which will itself be fixed by another more important bug. Furthermore, one bug might appear/disappear with newer versions of the libararies. \n", + "\n", + "Because of this, I'll cover two apparently similar but very different case studies and share some of the relevant insights and tricks in the following two sections:\n", + "1. DinoV2 + ViTDet + DINO: Successful compilation, minimal final rewrites.\n", + "2. ViT + ViTDet + Cascade Mask RCNN: Almost successful, many final rewrites." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CS1: Compiling DinoV2+ViTDet+DINO" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's start from where we left off at {ref}`part2:adapting`" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/dgcnz/development/amsterdam/edge\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n", + " self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n" + ] + } + ], + "source": [ + "import sys; from pathlib import Path\n", + "\n", + "__DIRS = list(Path().cwd().resolve().parents) + [Path().cwd().resolve()]\n", + "WDIR = next(p for p in __DIRS if (p / \".project-root\").exists())\n", + "sys.path.append(str(WDIR))\n", + "%cd {WDIR}" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "tags": [ + "hide-input", + "remove-output" + ] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/dgcnz/development/amsterdam/edge/detrex/detrex/layers/dcn_v3.py:24: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.\n", + " def forward(\n", + "/home/dgcnz/development/amsterdam/edge/detrex/detrex/layers/dcn_v3.py:53: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. 
Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.\n", + " def backward(ctx, grad_output):\n", + "/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers\n", + " warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.layers\", FutureWarning)\n" + ] + } + ], + "source": [ + "import torch\n", + "from omegaconf import OmegaConf\n", + "from detrex.modeling.backbone import TimmBackbone\n", + "from detectron2.config import LazyConfig, instantiate, LazyCall\n", + "import detectron2\n", + "import torch_tensorrt\n", + "from src.utils import TracingAdapter, load_input_fixed\n", + "import detrex\n", + "import warnings\n", + "\n", + "import logging\n", + "logging.basicConfig(level=logging.ERROR)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "tags": [ + "remove-output" + ] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32m[10/30 15:39:19 timm backbone]: \u001b[0mbackbone out_indices: (11,)\n", + "\u001b[32m[10/30 15:39:19 timm backbone]: \u001b[0mbackbone out_channels: [768]\n", + "\u001b[32m[10/30 15:39:19 timm backbone]: \u001b[0mbackbone out_strides: [16]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:py.warnings:/home/dgcnz/development/amsterdam/edge/projects/dino_dinov2/modeling/exportable/dino.py:107: UserWarning: device argument is deprecated and has no effect.\n", + " warn(\"device argument is deprecated and has no effect.\")\n", + "\n" + ] + } + ], + "source": [ + "cfg = LazyConfig.load(\"projects/dino_dinov2/configs/models/dino_dinov2.py\")\n", + "cfg.model.backbone.net = LazyCall(TimmBackbone)(\n", + " model_name=\"vit_base_patch14_dinov2.lvd142m\",\n", + " features_only=True,\n", + " out_indices=(-1,),\n", + " patch_size=16,\n", + ")\n", + "model = instantiate(OmegaConf.to_object(cfg.model))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Before trying anything we must make three small changes from the original code at {ref}`part2:adapting`:\n", + "1. For convenience, we'll deactivate the custom CUDA multi scale deformable attention kernel and opt for the python implementation. Although you could technically register a custom operator with PT2E compatibility, it's not worth the effort because of the constant tensor specialization issue we'll face later and the fact that TensorRT can optimize the python implementation well enough.\n", + "2. Instead of using a random input, use a sample image and resize it to the appropriate dimensions. This might seem like an innocuous change, but if there is some data-dependent computation (for example, some filtering based on the values of the features), then the `torch.export` will fail but it will show uninformative error logs and guide you erroneously to fix bugs that are not relevant to the real inputs.\n", + "3. Export the model with the appropriate device (`cuda`) and forward type (`eval`, `torch.no_grad`). This is important because sometimes, some operators might decide to use one implementation based on the device of the tensor and some operators are only supported without autograd." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "detrex.layers.multi_scale_deform_attn._ENABLE_CUDA_MSDA = False\n", + "img, inputs = load_input_fixed(height=518, width=518, device=\"cuda\")\n", + "model = model.eval().cuda()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's try to export the model." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "W1030 15:39:20.764723 23824 site-packages/torch/fx/experimental/symbolic_shapes.py:6047] [1/0] failed during evaluate_expr(u0, hint=None, size_oblivious=False, forcing_spec=False\n", + "E1030 15:39:20.765498 23824 site-packages/torch/fx/experimental/recording.py:298] [1/0] failed while running evaluate_expr(*(u0, None), **{'fx_node': False})\n", + "ERROR:root:Could not extract specialized integer from data-dependent expression u0 (unhinted: u0). (Size-like symbols: none)\n", + "\n", + "Caused by: torch.linspace(0.5, H - 0.5, H, dtype=torch.float32, device=device), # development/amsterdam/edge/projects/dino_dinov2/modeling/exportable/dino_transformer.py:373 in get_reference_points (_dynamo/utils.py:2260 in run_node)\n", + "For more information, run with TORCH_LOGS=\"dynamic\"\n", + "For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"u0\"\n", + "If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n", + "For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n", + "\n", + "User Stack (most recent call last):\n", + " (snipped, see stack below for prefix)\n", + " File \"/home/dgcnz/development/amsterdam/edge/projects/dino_dinov2/modeling/exportable/dino.py\", line 284, in forward\n", + " ) = self.transformer(\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1747, in _call_impl\n", + " return forward_call(*args, **kwargs)\n", + " File \"/home/dgcnz/development/amsterdam/edge/projects/dino_dinov2/modeling/exportable/dino_transformer.py\", line 439, in forward\n", + " reference_points = self.get_reference_points(\n", + " File \"/home/dgcnz/development/amsterdam/edge/projects/dino_dinov2/modeling/exportable/dino_transformer.py\", line 373, in get_reference_points\n", + " torch.linspace(0.5, H - 0.5, H, dtype=torch.float32, device=device),\n", + "\n", + "For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n", + "For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#constrain-as-size-example\n", + "\n", + "from user code:\n", + " File \"/home/dgcnz/development/amsterdam/edge/projects/dino_dinov2/modeling/exportable/dino.py\", line 284, in forward\n", + " ) = self.transformer(\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1747, in _call_impl\n", + " return forward_call(*args, **kwargs)\n", + " File \"/home/dgcnz/development/amsterdam/edge/projects/dino_dinov2/modeling/exportable/dino_transformer.py\", line 439, in forward\n", + " reference_points = self.get_reference_points(\n", + " File \"/home/dgcnz/development/amsterdam/edge/projects/dino_dinov2/modeling/exportable/dino_transformer.py\", line 373, in get_reference_points\n", + " torch.linspace(0.5, H - 0.5, H, dtype=torch.float32, device=device),\n", + "\n", + "Set 
TORCH_LOGS=\"+dynamo\" and TORCHDYNAMO_VERBOSE=1 for more information\n", + "\n" + ] + } + ], + "source": [ + "try:\n", + " with torch.no_grad():\n", + " ep = torch.export.export(model, inputs)\n", + "except Exception as e:\n", + " logging.error(e)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Oh no, we've stumbled into a data-dependent expression error. These errors occur because PT2E currently doesn't support data-dependent expressions out of the box (check [docs](https://pytorch.org/docs/main/export.html#data-shape-dependent-control-flow)). Luckily, this case specifically doesn't really contain a data-dependent expression, it's only a compiler bug." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "By looking at the code, we can find out that `linspace` is creating a tensor with a shape that depends on `spatial_shapes` values. If such values are unknown at compile time, then all the following computation will be considered data-dependent computation.\n", + "\n", + "\n", + "```python\n", + "class DINOTransformer(nn.Module):\n", + " ...\n", + " def forward(\n", + " self,\n", + " multi_level_feats: list[torch.Tensor],\n", + " ...,\n", + " **kwargs,\n", + " ):\n", + " ...\n", + " spatial_shapes: List[Tuple[int, int]] = []\n", + "\n", + " ...\n", + " for lvl, (feat, ...) in enumerate(zip(multi_level_feats, ...)):\n", + " spatial_shapes.append(feat.shape[2:])\n", + " ...\n", + "\n", + " ...\n", + " spatial_shapes = torch.tensor(\n", + " spatial_shapes, dtype=torch.long, device=feat_flatten.device\n", + " )\n", + " ...\n", + " reference_points = self.get_reference_points(spatial_shapes, ...)\n", + "\n", + "\n", + " @staticmethod\n", + " def get_reference_points(spatial_shapes, ...):\n", + " ...\n", + " for lvl, (H, W) in enumerate(spatial_shapes):\n", + " ref_y, ref_x = torch.meshgrid(\n", + " torch.linspace(0.5, H - 0.5, H, dtype=torch.float32, device=device),\n", + " torch.linspace(0.5, W - 0.5, W, dtype=torch.float32, device=device),\n", + " )\n", + " ...\n", + " ...\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "However, with a bit of debugging, we can find out that `spatial_shapes` is actually a constant (given a fixed image resolution), because the feature pyramid (`multi_level_feats`)'s shapes are known at compile time. So, what's happening here?\n", + "\n", + "The problem is that constant tensors are still not well supported, as documented in the conversations I had with the PyTorch maintainers (check issue [pytorch/pytorch/136642](https://github.com/pytorch/pytorch/issues/136642)). To summarize the error, PT2E is only folding constant tensors if they are small enough. \n", + "\n", + "There are two ways to solve this issue:\n", + "1. Increasing constant tensor limit with `torch._subclasses.fake_tensor.CONSTANT_NUMEL_LIMIT`.\n", + "2. Rewriting code to never handle `spatial_shapes` as a list of tuples instead of as a tensor. This is because, lists and integers are specialized by default and are well supported.\n", + "\n", + "Although we'll go with the second option, it's always better to first try the first one because it's less intrusive and sometimes is enough." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### [❌] Increasing constant tensor limit" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:py.warnings:/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch/functional.py:539: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3612.)\n", + " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n", + "\n", + "ERROR:root:Dynamic slicing on data-dependent value is not supported\n", + "\n", + "from user code:\n", + " File \"/home/dgcnz/development/amsterdam/edge/projects/dino_dinov2/modeling/exportable/dino.py\", line 284, in forward\n", + " ) = self.transformer(\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1747, in _call_impl\n", + " return forward_call(*args, **kwargs)\n", + " File \"/home/dgcnz/development/amsterdam/edge/projects/dino_dinov2/modeling/exportable/dino_transformer.py\", line 456, in forward\n", + " output_memory, output_proposals = self.gen_encoder_output_proposals(\n", + " File \"/home/dgcnz/development/amsterdam/edge/projects/dino_dinov2/modeling/exportable/dino_transformer.py\", line 312, in gen_encoder_output_proposals\n", + " mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H * W)].view(\n", + "\n", + "Set TORCH_LOGS=\"+dynamo\" and TORCHDYNAMO_VERBOSE=1 for more information\n", + "\n" + ] + } + ], + "source": [ + "try:\n", + " torch._subclasses.fake_tensor.CONSTANT_NUMEL_LIMIT = 1000\n", + " with torch.no_grad():\n", + " ep = torch.export.export(model, inputs)\n", + " torch._subclasses.fake_tensor.CONSTANT_NUMEL_LIMIT = 1 # reset to default\n", + "except Exception as e:\n", + " logging.error(e)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This error is similar to the previous one, and it's fixable by adding some code to `PyTorch`'s codebase as I mentioned in [this comment](https://github.com/pytorch/pytorch/issues/136642#issuecomment-2441177631) from the aforementioned issue. This will be fixed in the future by the PyTorch team with a different approach, so it's not worth the effort to take this route." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### [✅] Rewriting code for non-tensor constants" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "ERROR:root:It looks like one of the outputs with type `` is not supported or pytree-flattenable. \n", + "Exported graphs outputs can only contain the following supported types: [, , , , , , , , , , , , , , , , , , ]. 
\n", + "If you are using a custom class object, please register a pytree_flatten/unflatten function using `torch.utils._pytree.register_pytree_node` or `torch.export.register_dataclass`.\n" + ] + } + ], + "source": [ + "try:\n", + " model.transformer.specialize_with_list = True\n", + " with torch.no_grad():\n", + " ep = torch.export.export(model, inputs)\n", + "except Exception as e:\n", + " logging.error(e)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Nice, new error, that means we're making progress.\n", + "\n", + "This new error is due to PT2E not knowing how to handle the output of our model, which is a `detectron2.structures.instances.Instances` object. The way this is solved is by specifying a way to *flatten* the object, that is, to convert it to a standard container (list, dict, etc) of known *flattenable* objects. For example, the `Boxes` class, can be flattened to a tuple of tensors. \n", + "\n", + "There are 2 ways to do this:\n", + "1. PT2E's suggested method: Register a `pytree` node with `flatten_fn` and `unflatten_fn`.\n", + "2. Manually do the flattening in the model's `forward` method.\n", + "\n", + "We'll use the second solution, because `torch_tensorrt` is not totally compatible with the first one. However, we'll introduce both, as the first one is more general and could be useful with other backends." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### [❌] Handling model I/O with PyTree Node Registrations" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "def unflatten_detectron2_boxes(values, _):\n", + " boxes = object.__new__(detectron2.structures.boxes.Boxes)\n", + " boxes.tensor = values[0]\n", + " return boxes\n", + "\n", + "\n", + "def unflatten_detectron2_instances(values, _):\n", + " instances = object.__new__(detectron2.structures.instances.Instances)\n", + " instances._image_size = values[0]\n", + " instances._fields = values[1]\n", + " return instances\n", + "\n", + "\n", + "def flatten_detectron2_instances(x):\n", + " return ([x._image_size, x._fields], None)\n", + "\n", + "\n", + "def flatten_detectron2_boxes(x):\n", + " return ([x.tensor], None)\n", + "\n", + "\n", + "torch.utils._pytree.register_pytree_node(\n", + " detectron2.structures.boxes.Boxes,\n", + " flatten_fn=flatten_detectron2_boxes,\n", + " unflatten_fn=unflatten_detectron2_boxes,\n", + " serialized_type_name=\"detectron2.structures.boxes.Boxes\",\n", + ")\n", + "\n", + "torch.utils._pytree.register_pytree_node(\n", + " detectron2.structures.instances.Instances,\n", + " flatten_fn=flatten_detectron2_instances,\n", + " unflatten_fn=unflatten_detectron2_instances,\n", + " serialized_type_name=\"detectron2.structures.instances.Instances\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " ep = torch.export.export(model, inputs)\n", + "except Exception as e:\n", + " logging.error(e)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Oh nice, it worked, let's try our luck with tensorrt?" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "ERROR:root:Invalid input type encountered in the dynamo_compile input parsing. 
Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}\n" + ] + } + ], + "source": [ + "try:\n", + " trt_gm = torch_tensorrt.dynamo.compile(ep, inputs)\n", + "except Exception as e:\n", + " logging.error(e)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This happens because `torch_tensorrt` expects the model inputs and outputs to be flattened containers (list, dict, tuple) of tensors, and our `height`, `width` integers are not supported. Furthermore, our model outputs `detectron2.structures.instances.Instances`, which poses another problem. Although this is possible, it will involve creating a model wrapper that hardcodes the input/output flattening and specialization. We'll introduce a more general option next." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### [✅] Handling model I/O with `TracingAdapter`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "I've added PT2E support to `detectron2.export.flatten.TracingAdapter` which does all the flattening for you and also optionally folds the non-tensor inputs as model constants, which applies to our case (height, width are constants)." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "adapter = TracingAdapter(\n", + " model, inputs=inputs, allow_non_tensor=False, specialize_non_tensor=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "remove-output" + ] + }, + "outputs": [], + "source": [ + "try:\n", + " compilation_successful = True\n", + " with torch.no_grad():\n", + " ep = torch.export.export(adapter, adapter.flattened_inputs)\n", + " trt_gm = torch_tensorrt.dynamo.compile(ep, adapter.flattened_inputs)\n", + "except Exception as e:\n", + " logging.error(e)\n", + " compilation_successful = False" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "compilation_successful" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Nice, that worked." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Dealing with image sizes" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this subsection we'll add one last change that is necessary to deploying our model: Specializing the input image sizes. \n", + "\n", + "Since, we won't be using dynamic shapes, we can fix the image size to the camera's resolution, and disable dynamic images and padding. Although disabling vit's dynamic inputs is not strictly necessary in the current version of PT2E, it was a source of errors in the previous ones and it's good practice to do it.\n", + "\n", + "As such, let's change our image resolution to (512, 512) and put everything together." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [], + "source": [ + "img, inputs = load_input_fixed(height=512, width=512, device=\"cuda\")" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": { + "tags": [ + "remove-output" + ] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32m[10/30 16:19:43 timm backbone]: \u001b[0mbackbone out_indices: (11,)\n", + "\u001b[32m[10/30 16:19:43 timm backbone]: \u001b[0mbackbone out_channels: [768]\n", + "\u001b[32m[10/30 16:19:43 timm backbone]: \u001b[0mbackbone out_strides: [16]\n" + ] + } + ], + "source": [ + "cfg = LazyConfig.load(\"projects/dino_dinov2/configs/models/dino_dinov2.py\")\n", + "cfg.model.backbone.net = LazyCall(TimmBackbone)(\n", + " model_name=\"vit_base_patch14_dinov2.lvd142m\",\n", + " features_only=True,\n", + " out_indices=(-1,),\n", + " patch_size=16,\n", + " img_size=[512, 512],\n", + " dynamic_img_size=False,\n", + " dynamic_img_pad=False,\n", + ")\n", + "cfg.model.transformer.specialize_with_list=True\n", + "model = instantiate(OmegaConf.to_object(cfg.model)).eval().cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [], + "source": [ + "adapter = TracingAdapter(\n", + " model, inputs=inputs, allow_non_tensor=False, specialize_non_tensor=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": { + "tags": [ + "remove-output" + ] + }, + "outputs": [], + "source": [ + "try:\n", + " compilation_successful = True\n", + " with torch.no_grad():\n", + " ep = torch.export.export(adapter, adapter.flattened_inputs)\n", + " trt_gm = torch_tensorrt.dynamo.compile(ep, adapter.flattened_inputs)\n", + "except Exception as e:\n", + " logging.error(e)\n", + " compilation_successful = False" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 61, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "compilation_successful" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CS2: Compiling ViT+ViTDet+CascadeMaskRCNN" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this section we'll cover an unsuccessful case study. As an official `detectron2` model, I expected it to be easier to compile, but it turns out it is just not possible without a lot of semantically meaningful rewrites.\n", + "\n", + "Anyway, let's take what we learned from the previous case study and apply it here." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "cfg = LazyConfig.load(\"detrex/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_b_100ep.py\")\n", + "model = instantiate(OmegaConf.to_object(cfg.model)).eval().cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "img, inputs = load_input_fixed(height=1024, width=1024, device=\"cuda\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "ERROR:root:Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. 
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands\n", + "\n", + "from user code:\n", + " File \"/home/dgcnz/development/amsterdam/edge/detrex/detectron2/detectron2/export/flatten.py\", line 348, in forward\n", + " outputs = self.inference_func(self.model, *inputs_orig_format)\n", + " File \"/home/dgcnz/development/amsterdam/edge/detrex/detectron2/detectron2/export/flatten.py\", line 265, in \n", + " inference_func = lambda model, *inputs: model(*inputs) # noqa\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1747, in _call_impl\n", + " return forward_call(*args, **kwargs)\n", + " File \"/home/dgcnz/development/amsterdam/edge/detrex/detectron2/detectron2/modeling/meta_arch/rcnn.py\", line 150, in forward\n", + " return self.inference(batched_inputs)\n", + " File \"/home/dgcnz/development/amsterdam/edge/detrex/detectron2/detectron2/modeling/meta_arch/rcnn.py\", line 208, in inference\n", + " proposals, _ = self.proposal_generator(images, features, None)\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch/nn/modules/module.py\", line 1747, in _call_impl\n", + " return forward_call(*args, **kwargs)\n", + " File \"/home/dgcnz/development/amsterdam/edge/detrex/detectron2/detectron2/modeling/proposal_generator/rpn.py\", line 477, in forward\n", + " proposals = self.predict_proposals(\n", + " File \"/home/dgcnz/development/amsterdam/edge/detrex/detectron2/detectron2/modeling/proposal_generator/rpn.py\", line 503, in predict_proposals\n", + " return find_top_rpn_proposals(\n", + " File \"/home/dgcnz/development/amsterdam/edge/detrex/detectron2/detectron2/modeling/proposal_generator/proposal_utils.py\", line 116, in find_top_rpn_proposals\n", + " if not valid_mask.all():\n", + "\n", + "Set TORCH_LOGS=\"+dynamo\" and TORCHDYNAMO_VERBOSE=1 for more information\n", + "\n" + ] + } + ], + "source": [ + "try:\n", + " torch._subclasses.fake_tensor.CONSTANT_NUMEL_LIMIT = 1000\n", + " compilation_successful = True\n", + " adapter = TracingAdapter(\n", + " model, inputs=inputs, allow_non_tensor=False, specialize_non_tensor=True\n", + " )\n", + " with torch.no_grad():\n", + " ep = torch.export.export(adapter, adapter.flattened_inputs)\n", + " trt_gm = torch_tensorrt.dynamo.compile(ep, adapter.flattened_inputs)\n", + "except Exception as e:\n", + " logging.error(e)\n", + " compilation_successful = False" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we stumble upon a data-dependent expression error on the postprocessing step of the region proposal network. For context, this model first generates ~1000 region proposals that then are fed to the rest of the model to generate the final predictions. 
In between these steps, there are filtering algorithms to reduce the number of proposals to a more manageable number.\n", + "\n", + "We can look at the relevant code to understand the error:\n", + "\n", + "```python\n", + " # 1: filter non-finite boxes\n", + " valid_mask = torch.isfinite(boxes.tensor).all(dim=1) & torch.isfinite(scores_per_img)\n", + " if not valid_mask.all():\n", + " ...\n", + " boxes = boxes[valid_mask]\n", + " scores_per_img = scores_per_img[valid_mask]\n", + " lvl = lvl[valid_mask]\n", + "\n", + " ...\n", + "\n", + " # 2: filter empty boxes\n", + " keep = boxes.nonempty(threshold=min_box_size)\n", + " if _is_tracing() or keep.sum().item() != len(boxes):\n", + " boxes, scores_per_img, lvl = boxes[keep], scores_per_img[keep], lvl[keep]\n", + "\n", + " # 3: filter based on non-maximum-suppression\n", + " keep = batched_nms(boxes.tensor, scores_per_img, lvl, nms_thresh)\n", + " ...\n", + " boxes = boxes[keep]\n", + "```\n", + "\n", + "The issue here is that all of these steps are intrinsically data-dependent: There's no way to know the final number of boxes at compile time. As such, we're left with no choice than to skip these steps and hope that the model will still be useful. \n", + "\n", + "There are two other places with similar data-dependent expressions, so we'll skip them too.\n", + "\n", + "For this, I've added the following flags:\n", + "- `detectron2.modeling.proposal_generator.proposal_utils.SKIP_NMS`.\n", + "- `detectron2.modeling.roi_heads.fast_rcnn.SKIP_FILTER_CONFIDENCE` \n", + "- `detectron2.modeling.roi_heads.fast_rcnn.SKIP_NMS`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "ERROR:root:'int' object has no attribute 'size'\n", + "Traceback (most recent call last):\n", + " File \"/tmp/ipykernel_86302/546930148.py\", line 12, in \n", + " trt_gm = torch_tensorrt.dynamo.compile(\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch_tensorrt/dynamo/_compiler.py\", line 318, in compile\n", + " trt_gm = compile_module(\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch_tensorrt/dynamo/_compiler.py\", line 366, in compile_module\n", + " gm, settings.debug, settings.torch_executed_ops\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch_tensorrt/dynamo/partitioning/common.py\", line 195, in get_graph_converter_support\n", + " if op_support.is_node_supported(module_dict, node):\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch_tensorrt/dynamo/partitioning/_global_partitioner.py\", line 152, in is_node_supported\n", + " (node in CONVERTERS or node.op == \"get_attr\")\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py\", line 504, in __contains__\n", + " self.__getitem__(key)\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py\", line 457, in __getitem__\n", + " or not node_has_dynamic_shapes(node)\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py\", line 104, in node_has_dynamic_shapes\n", + " return _has_dynamic_shapes(node=node)\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py\", line 167, in _has_dynamic_shapes\n", + " if 
arg_positions_to_check is None and _is_subnode_dynamic(node):\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py\", line 159, in _is_subnode_dynamic\n", + " shape = subnode.meta[\"val\"].size()\n", + "AttributeError: 'int' object has no attribute 'size'\n" + ] + } + ], + "source": [ + "try:\n", + " torch._subclasses.fake_tensor.CONSTANT_NUMEL_LIMIT = 1000\n", + " detectron2.modeling.proposal_generator.proposal_utils.SKIP_NMS = True\n", + " detectron2.modeling.roi_heads.fast_rcnn.SKIP_NMS = True\n", + " detectron2.modeling.roi_heads.fast_rcnn.SKIP_FILTER_CONFIDENCE = True\n", + " compilation_successful = True\n", + " adapter = TracingAdapter(\n", + " model, inputs=inputs, allow_non_tensor=False, specialize_non_tensor=True\n", + " )\n", + " with torch.no_grad():\n", + " ep = torch.export.export(adapter, adapter.flattened_inputs)\n", + " trt_gm = torch_tensorrt.dynamo.compile(\n", + " ep,\n", + " adapter.flattened_inputs,\n", + " )\n", + "except Exception as e:\n", + " logging.error(e, exc_info=True)\n", + " compilation_successful = False" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With a bit of debugging we can discover that the node that is causing this error is a node with `target=torch.ops.aten.sym_size.int`. This node is supported by TensorRT but some bug is preventing it to be converted correctly.\n", + "\n", + "To see how far we can go, we can bypass this by telling `torch_tensorrt` to not convert nodes with target `torch.ops.aten.sym_size.int`." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "utils.cpp:2468: CHECK(output_shape.size() == rep_vector.size()) failed. 
\n", + "ERROR:torch_tensorrt [TensorRT Conversion Context]:Error Code: 9: Skipping tactic 0x0000000000000000 due to exception No Myelin Error exists\n", + "ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 10: Internal Error (Could not find any implementation for node {ForeignNode[model.backbone.net.blocks.0.norm1/native_layer_norm_weight + model.backbone.net.blocks.0.norm1/native_layer_norm_expand_weight_expand_broadcast...[SHUFFLE]-[aten_ops.permute.default]-[model.backbone.net/permute_223]]}.)\n", + "ERROR:root:\n", + "Traceback (most recent call last):\n", + " File \"/tmp/ipykernel_86302/3725039079.py\", line 12, in \n", + " trt_gm = torch_tensorrt.dynamo.compile(\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch_tensorrt/dynamo/_compiler.py\", line 318, in compile\n", + " trt_gm = compile_module(\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch_tensorrt/dynamo/_compiler.py\", line 534, in compile_module\n", + " submodule,\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py\", line 91, in convert_module\n", + " interpreter_result = interpret_module_to_result(\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py\", line 70, in interpret_module_to_result\n", + " interpreter_result = interpreter.run()\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py\", line 639, in run\n", + " assert serialized_engine\n", + "AssertionError\n" + ] + } + ], + "source": [ + "try:\n", + " torch._subclasses.fake_tensor.CONSTANT_NUMEL_LIMIT = 1000\n", + " detectron2.modeling.proposal_generator.proposal_utils.SKIP_NMS = True\n", + " detectron2.modeling.roi_heads.fast_rcnn.SKIP_NMS = True\n", + " detectron2.modeling.roi_heads.fast_rcnn.SKIP_FILTER_CONFIDENCE = True\n", + " compilation_successful = True\n", + " adapter = TracingAdapter(\n", + " model, inputs=inputs, allow_non_tensor=False, specialize_non_tensor=True\n", + " )\n", + " with torch.no_grad():\n", + " ep = torch.export.export(adapter, adapter.flattened_inputs)\n", + " trt_gm = torch_tensorrt.dynamo.compile(\n", + " ep,\n", + " adapter.flattened_inputs,\n", + " torch_executed_ops={\"torch.ops.aten.sym_size.int\"}\n", + " )\n", + "except Exception as e:\n", + " logging.error(e, exc_info=True)\n", + " compilation_successful = False" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This new error is tricky, but we can pinpoint its location by looking at the name of the node: `ForeignNode[model.backbone.net.blocks.0.norm1/native_layer_norm_weight...]`. After cross-referencing the operators we see in the node with the source code, we find out that the culprit is the window attention module. We can disable it and use only global attention to bypass this error." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "cfg = LazyConfig.load(\"detrex/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_b_100ep.py\")\n", + "cfg.model.backbone.net.window_block_indexes = []\n", + "model = instantiate(OmegaConf.to_object(cfg.model)).eval().cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "ERROR:root:Cannot convert symbols to int\n", + "Traceback (most recent call last):\n", + " File \"/tmp/ipykernel_86302/3725039079.py\", line 12, in \n", + " trt_gm = torch_tensorrt.dynamo.compile(\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch_tensorrt/dynamo/_compiler.py\", line 318, in compile\n", + " trt_gm = compile_module(\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch_tensorrt/dynamo/_compiler.py\", line 506, in compile_module\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch_tensorrt/dynamo/partitioning/common.py\", line 91, in construct_submodule_inputs\n", + " get_input(input_shape, input_meta.dtype, name=input.name)\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch_tensorrt/dynamo/partitioning/common.py\", line 61, in get_input\n", + " return construct_dynamic_input(\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch_tensorrt/dynamo/partitioning/common.py\", line 32, in construct_dynamic_input\n", + " min_max_opt = extract_var_range_info(dim)\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/torch_tensorrt/dynamo/utils.py\", line 345, in extract_var_range_info\n", + " min_val, max_val, opt_val = int(var_range.lower), int(var_range.upper), int(var_val)\n", + " File \"/home/dgcnz/.conda/envs/cu124/lib/python3.10/site-packages/sympy/core/expr.py\", line 307, in __int__\n", + " raise TypeError(\"Cannot convert symbols to int\")\n", + "TypeError: Cannot convert symbols to int\n" + ] + } + ], + "source": [ + "try:\n", + " torch._subclasses.fake_tensor.CONSTANT_NUMEL_LIMIT = 1000\n", + " detectron2.modeling.proposal_generator.proposal_utils.SKIP_NMS = True\n", + " detectron2.modeling.roi_heads.fast_rcnn.SKIP_NMS = True\n", + " detectron2.modeling.roi_heads.fast_rcnn.SKIP_FILTER_CONFIDENCE = True\n", + " compilation_successful = True\n", + " adapter = TracingAdapter(\n", + " model, inputs=inputs, allow_non_tensor=False, specialize_non_tensor=True\n", + " )\n", + " with torch.no_grad():\n", + " ep = torch.export.export(adapter, adapter.flattened_inputs)\n", + " trt_gm = torch_tensorrt.dynamo.compile(\n", + " ep,\n", + " adapter.flattened_inputs,\n", + " torch_executed_ops={\"torch.ops.aten.sym_size.int\"}\n", + " )\n", + "except Exception as e:\n", + " logging.error(e, exc_info=True)\n", + " compilation_successful = False" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "No luck.\n", + "\n", + "This is where we stop. This framework-specific bugs are hard to debug and fix as they often are bugs in the compiler itself. In my experience with the previous case study, these bugs fixed themselves by rewriting the model in order to avoid graph partitioning alltogether. 
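To get a sense of how much such rewrites would have to cover, we can ask the compiler itself which operators it failed to convert, as shown in the sketch below (same `ep` and `adapter` as in the previous cell, with debug output enabled):

```python
import torch_tensorrt

# Sketch: re-run the compilation with debug logging so that torch_tensorrt reports
# the nodes it cannot convert and will keep executing in PyTorch instead.
trt_gm = torch_tensorrt.dynamo.compile(
    ep,
    adapter.flattened_inputs,
    debug=True,
    torch_executed_ops={"torch.ops.aten.sym_size.int"},
)
```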
We can obtain the unsupported nodes by feeding `debug=True` to `torch_tensorrt.dynamo.compile`.\n", + "\n", + "For this model, the unsupported nodes after the non-maximum-suppresion rewrites are:\n", + "- `torch.ops.aten.nonzero.default`\n", + "- `torch.ops.aten.index.Tensor`\n", + "- `torch.ops.torchvision.roi_align.default`\n", + "- `torch.ops.aten.index_put.default`\n", + "\n", + "However, we've already rewritten essential parts of the model and my guess is that if we continued with more rewrites, the resulting model would not be usable. For example, the weights of window attention do not have the same the same shape as that of the global attention, so the pre-trained model likely already needs finetuning." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "cu124", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/src/part3/compilation.md b/docs/src/part3/compilation.md deleted file mode 100644 index 85792fa..0000000 --- a/docs/src/part3/compilation.md +++ /dev/null @@ -1,31 +0,0 @@ -# Compilation - -```{contents} -``` - -One of PyTorch main strengths is its flexibility and ease of use. The user can write a model with almost no restrictions as long as each operator is differentiable and PyTorch will take care of the rest. On each forward pass, it will evaluate the operators on-the-fly and dynamically construct the computation graph, which is then used to compute the gradients during the backward pass. This is called Eager execution mode and it is the default behavior of PyTorch. - -This mode comes in handy when the computation graph is not static, for example when the model has if-statements or loops that depend on the input data or when the input has dynamic shapes (imagine training a model with multiple resolutions). However, this flexibility comes at a cost, because we can't optimize a model if we don't know what operations, which shapes, types or even order of operations will be executed until runtime. This is where compilers come in. - -## Compilers 101 - -A compiler is a program that translates instructions written in one representation (source) into another representation (target). Nowadays, compilers usually are separated in a frontend and a backend. The frontend is responsible for parsing the source code and generating an intermediate representation (IR) that is independent of the source language. The backend is responsible for translating the IR into the target language. This separation allows for reusability of the frontend with different backends, and vice versa as we can see in {numref}`Figure {number} `. - -:::{figure-md} retargetable -Retargetable Compilers - -Frontends produce an intermediate representation (IR) common to all backends. Backends take the IR and generate code for a specific target. {cite}`aosabook` -::: - -So far, we've talked about compilation as a process that happens before the program is executed, also known as ahead-of-time (AOT) compilation. However, there are other ways to execute a program. Some languages, like Python, are *interpreted*, where programs are executed line by line by the runtime interpreter. 
Furthermore, some runtimes might use just-in-time (JIT) compilation, where parts of the program are compiled while it is being executed. This allows for optimizations that can only be done at runtime, like specializing code for specific inputs. - - -## Machine Learning Compilers - -Machine learning compilers take a model written in some framework (e.g. PyTorch), translate it into a program that can be executed in some runtime (e.g. TensorRT, CoreML, PyTorch TorchInductor) which then ends up optimized for some specialized hardware (e.g. GPUs, TPUs, Apple Silicon). - -PyTorch has had a few different compiler solutions over the years, the most popular being TorchScript. This, however, has changed since PyTorch 2, as the new compiler stack has been introduced. The main component of this new stack is TorchDynamo, a new compiler frontend with better properties and more Python support than TorchScript. - -Along with TorchDynamo, PyTorch 2 has introduced two new APIs, `torch.export` and `torch.compile`, that leverage this technology. On one hand, `torch.export`'s goal is to act as an ahead-of-time frontend which captures the full semantics of the program into an IR independent of Python, while `torch.compile` is meant to be used as a full JIT compiler that can leverage other backends (TorchInductor, TensorRT, ONNX) to optimize parts of the model at runtime and fallback to native Python if necessary. - -For edge devices specifically, we are most interested in the `torch.export` API, as it allows us to dispose of the expensive overhead of the Python Runtime and allows us to take advantage of native optimized frameworks for our target hardware, like CoreML for Apple devices or TensorRT (C++) for NVIDIA GPUs. diff --git a/docs/src/part3/torch_export.ipynb b/docs/src/part3/torch_export.ipynb deleted file mode 100644 index 98e684d..0000000 --- a/docs/src/part3/torch_export.ipynb +++ /dev/null @@ -1,661 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# PyTorch 2 Export" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```{contents}\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## torch.export 101\n", - "\n", - "The main idea of `torch.export` is that it translates an Eager Mode PyTorch model into a graph-based intermediate representation called *Export IR*. This allows compiler backends to take this IR and further transform and optimize it for a target device. A general overview of the process is shown in the figure [below](torchexport).\n", - "\n", - ":::{figure-md} torchexport\n", - "\"torch.export\"\n", - "\n", - "PyTorch 2 Export\n", - ":::" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This IR needs to fulfill a couple of properties for it to be useful to compilers. For example:\n", - "1. Operators have to be general enough for backends to notice patterns and optimize them: Many runtimes have specialized kernels for common operators like convolutions or even more complex ones like a `conv2 + relu` (operator fusion, see examples [here](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#fusion-types)). If the IR reduces all operators to sums, products and views, noticing these patterns becomes too hard.\n", - "2. The number of operators has to be small enough for the backend to implement all of them. \n", - "3. Operators have to be functional, that is, without side effects. 
For example: If two functions read and modify the same parameters, the order of execution matters and the compiler has to be careful when parallelizing them.\n", - "\n", - "Notice that properties 1 and 2 are in conflict with each other. The more operators we have, the more expressive the IR is, but the harder it is to implement all of them. This is a trade-off that the PyTorch team has to balance. " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "TODO:\n", - "- [ ] Introduce ATEN (dialects), fx.Graph and link to Export IR, functionalization" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "For now, let's get some practical intuition with an example." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Hands on with torch.export" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's use a simple network to see how `torch.export` works." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "tags": [ - "hide-input" - ] - }, - "outputs": [], - "source": [ - "import torch\n", - "import pprint\n", - "from part3_artifacts.simple_net import SimpleNet\n", - "import torch.fx.graph_module\n", - "from myst_nb import glue" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "tags": [ - "hide-input" - ] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[0;31mInit signature:\u001b[0m \u001b[0mSimpleNet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mSource:\u001b[0m \n", - "\u001b[0;32mclass\u001b[0m \u001b[0mSimpleNet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mModule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0;34m\"\"\"\u001b[0m\n", - "\u001b[0;34m Just a simple network\u001b[0m\n", - "\u001b[0;34m \"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mConv2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mConv2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m4704\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", - 
"\u001b[0;34m\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mz\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mz\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mz\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mo\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mFile:\u001b[0m ~/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py\n", - "\u001b[0;31mType:\u001b[0m type\n", - "\u001b[0;31mSubclasses:\u001b[0m " - ] - } - ], - "source": [ - "SimpleNet??" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To export a model we must first define a sample input. This is used to `trace` the model and generate the Export IR. \n", - "\n", - "```{note}\n", - "`Tracing` refers to the process of recording the operations executed by a model when given a specific input along with their metadata. \n", - "\n", - "The way tracing works efficiently is by using `torch._subclasses.fake_tensor.FakeTensor`. FakeTensors are a special type of tensor that only store metadata such as `dtype`, `shape` and `device` and overload all operators to simulate the computation without actually looking at the values. \n", - "\n", - "For example, doing matrix multiplications of FakeTensors of shapes `(N, M)` and `(M, K)` will return a FakeTensor of shape `(N, K)` in constant time instead of the normal cubic complexity of multiplication.\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "For our case, the model will be deployed on a camera with a fixed resolution, so we can just define a statically shaped tensor of `batch_size` 1. 
If you want to support dynamically shaped inputs, refer to the [documentation](https://pytorch.org/docs/main/export.html#expressing-dynamism).\n", - "\n", - "Once we have the input, we can call the `torch.export.export` function.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "x = torch.randn(1, 3, 32, 32) \n", - "ep: torch.export.ExportedProgram = torch.export.export(SimpleNet().eval(), (x,))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "And that's it, we have exported our model. The new object is a `torch.export.ExportedProgram` which contains the model and parameters in the Export IR. Let's inspect it one by one." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The first and most important attribute is the `graph_module` which stores the computational graph of the model. We can print it using the `print_readable` method:" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "class GraphModule(torch.nn.Module):\n", - " def forward(self, p_conv1_weight: \"\u001b[31mf32\u001b[0m\u001b[34m[6, 3, 5, 5]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", p_conv1_bias: \"\u001b[31mf32\u001b[0m\u001b[34m[6]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", p_conv2_weight: \"\u001b[31mf32\u001b[0m\u001b[34m[6, 3, 5, 5]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", p_conv2_bias: \"\u001b[31mf32\u001b[0m\u001b[34m[6]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", p_fc_weight: \"\u001b[31mf32\u001b[0m\u001b[34m[10, 4704]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", p_fc_bias: \"\u001b[31mf32\u001b[0m\u001b[34m[10]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\", x: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 3, 32, 32]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\"):\n", - " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:16 in forward, code: z = self.conv1(x)\u001b[0m\n", - " conv2d: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 6, 28, 28]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.conv2d.default(x, p_conv1_weight, p_conv1_bias); \u001b[2mp_conv1_weight = p_conv1_bias = None\u001b[0m\n", - " \n", - " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:17 in forward, code: z = F.relu(z)\u001b[0m\n", - " relu: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 6, 28, 28]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.relu.default(conv2d); \u001b[2mconv2d = None\u001b[0m\n", - " \n", - " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:18 in forward, code: y = self.conv2(x)\u001b[0m\n", - " conv2d_1: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 6, 28, 28]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.conv2d.default(x, p_conv2_weight, p_conv2_bias); \u001b[2mx = p_conv2_weight = p_conv2_bias = None\u001b[0m\n", - " \n", - " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:19 in forward, code: y = F.relu(y)\u001b[0m\n", - " relu_1: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 6, 28, 
28]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.relu.default(conv2d_1); \u001b[2mconv2d_1 = None\u001b[0m\n", - " \n", - " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:20 in forward, code: o = z + y\u001b[0m\n", - " add: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 6, 28, 28]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.add.Tensor(relu, relu_1); \u001b[2mrelu = relu_1 = None\u001b[0m\n", - " \n", - " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:21 in forward, code: o = torch.flatten(o, 1)\u001b[0m\n", - " view: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 4704]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.view.default(add, \u001b[34m[1, 4704]\u001b[0m); \u001b[2madd = None\u001b[0m\n", - " \n", - " \u001b[2m# File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:22 in forward, code: o = self.fc(o)\u001b[0m\n", - " linear: \"\u001b[31mf32\u001b[0m\u001b[34m[1, 10]\u001b[0m\u001b[2m\u001b[34m\u001b[0m\u001b[2m\u001b[32mcpu\u001b[0m\" = torch.ops.aten.linear.default(view, p_fc_weight, p_fc_bias); \u001b[2mview = p_fc_weight = p_fc_bias = None\u001b[0m\n", - " return (linear,)\n", - " \n" - ] - } - ], - "source": [ - "graph_module: torch.fx.GraphModule = ep.graph_module\n", - "print(graph_module.print_readable(print_output=False, colored=True, include_device=True))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Here we can see all *nodes* (`conv2d`, `relu`, `conv2d_1`, etc.), their shapes, dtypes, devices and the aten operators that are being used (`torch.ops.aten.conv2d.default`), with their accompanying file, line and code. We can also see that the graph inputs expects not only the model inputs but also its parameters (buffers and constants too)." 
- ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "tags": [ - "remove-cell" - ] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "graph():\n", - " %p_conv1_weight : [num_users=1] = placeholder[target=p_conv1_weight]\n", - " %p_conv1_bias : [num_users=1] = placeholder[target=p_conv1_bias]\n", - " %p_conv2_weight : [num_users=1] = placeholder[target=p_conv2_weight]\n", - " %p_conv2_bias : [num_users=1] = placeholder[target=p_conv2_bias]\n", - " %p_fc_weight : [num_users=1] = placeholder[target=p_fc_weight]\n", - " %p_fc_bias : [num_users=1] = placeholder[target=p_fc_bias]\n", - " %x : [num_users=2] = placeholder[target=x]\n", - " %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv1_weight, %p_conv1_bias), kwargs = {})\n", - " %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%conv2d,), kwargs = {})\n", - " %conv2d_1 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv2_weight, %p_conv2_bias), kwargs = {})\n", - " %relu_1 : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%conv2d_1,), kwargs = {})\n", - " %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%relu, %relu_1), kwargs = {})\n", - " %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%add, [1, 4704]), kwargs = {})\n", - " %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%view, %p_fc_weight, %p_fc_bias), kwargs = {})\n", - " return (linear,)" - ] - }, - "metadata": { - "scrapbook": { - "mime_prefix": "", - "name": "graphmodule_graph" - } - }, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "[p_conv1_weight,\n", - " p_conv1_bias,\n", - " p_conv2_weight,\n", - " p_conv2_bias,\n", - " p_fc_weight,\n", - " p_fc_bias,\n", - " x,\n", - " conv2d,\n", - " relu,\n", - " conv2d_1,\n", - " relu_1,\n", - " add,\n", - " view,\n", - " linear,\n", - " output]" - ] - }, - "metadata": { - "scrapbook": { - "mime_prefix": "", - "name": "graphmodule_graph_nodes" - } - }, - "output_type": "display_data" - }, - { - "data": { - "application/papermill.record/text/plain": "'call_function'" - }, - "metadata": { - "scrapbook": { - "mime_prefix": "application/papermill.record/", - "name": "relu_1_op" - } - }, - "output_type": "display_data" - }, - { - "data": { - "application/papermill.record/text/plain": "" - }, - "metadata": { - "scrapbook": { - "mime_prefix": "application/papermill.record/", - "name": "relu_1_target" - } - }, - "output_type": "display_data" - }, - { - "data": { - "application/papermill.record/text/plain": "(conv2d_1,)" - }, - "metadata": { - "scrapbook": { - "mime_prefix": "application/papermill.record/", - "name": "relu_1_args" - } - }, - "output_type": "display_data" - }, - { - "data": { - "application/papermill.record/text/plain": " File \"/home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py\", line 19, in forward\n y = F.relu(y)\n" - }, - "metadata": { - "scrapbook": { - "mime_prefix": "application/papermill.record/", - "name": "relu_1_stack_trace_2" - } - }, - "output_type": "display_data" - }, - { - "data": { - "application/papermill.record/text/plain": "'relu_1'" - }, - "metadata": { - "scrapbook": { - "mime_prefix": "application/papermill.record/", - "name": "relu_1_name" - } - }, - "output_type": "display_data" - }, - { - "data": { - "application/papermill.record/text/plain": "{'stack_trace': ' File 
\"/home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py\", line 19, in forward\\n y = F.relu(y)\\n',\n 'nn_module_stack': {'L__self__': ('',\n 'part3_artifacts.simple_net.SimpleNet')},\n 'source_fn_stack': [('relu_1',\n torch.Tensor>)],\n 'original_aten': ,\n 'from_node': [('y_1',\n torch.Tensor>)],\n 'seq_nr': 50,\n 'torch_fn': ('relu_2', 'function.relu'),\n 'val': FakeTensor(..., size=(1, 6, 28, 28)),\n 'tensor_meta': TensorMetadata(shape=torch.Size([1, 6, 28, 28]), dtype=torch.float32, requires_grad=True, stride=(4704, 784, 28, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})}" - }, - "metadata": { - "scrapbook": { - "mime_prefix": "application/papermill.record/", - "name": "relu_1_meta" - } - }, - "output_type": "display_data" - }, - { - "data": { - "application/papermill.record/text/plain": "[add]" - }, - "metadata": { - "scrapbook": { - "mime_prefix": "application/papermill.record/", - "name": "relu_1_users" - } - }, - "output_type": "display_data" - } - ], - "source": [ - "def graph_formatter(graph, pp, cycle):\n", - " pp.text(str(graph))\n", - "\n", - "# def graph_nodes_formatter(nodes, pp, cycle):\n", - "# pp.\n", - "# for node in nodes:\n", - "# pp.text(str(node))\n", - "\n", - "from IPython import get_ipython\n", - "import torch.fx.graph as fx_graph\n", - "plain = get_ipython().display_formatter.formatters['text/plain']\n", - "plain.for_type(torch.fx.Graph, graph_formatter)\n", - "# plain.for_type(fx_graph._node_list, graph_nodes_formatter)\n", - "glue(\"graphmodule_graph\", graph_module.graph)\n", - "glue(\"graphmodule_graph_nodes\", list(graph_module.graph.nodes))\n", - "\n", - "class StackTrace(object):\n", - " def __init__(self, stack_trace):\n", - " self.stack_trace = stack_trace\n", - "\n", - "def stack_trace_formatter(stack_trace, pp, cycle):\n", - " pp.text(stack_trace.stack_trace)\n", - "\n", - "plain.for_type(StackTrace, stack_trace_formatter)\n", - "\n", - "relu_1 = next(filter(lambda n: n.name == \"relu_1\", graph_module.graph.nodes))\n", - "glue(\"relu_1_op\", relu_1.op, display=False)\n", - "glue(\"relu_1_target\", relu_1.target, display=False)\n", - "glue(\"relu_1_args\", relu_1.args, display=False)\n", - "glue(\"relu_1_stack_trace_2\", StackTrace(relu_1.stack_trace), display=False)\n", - "glue(\"relu_1_name\", relu_1.name, display=False)\n", - "glue(\"relu_1_meta\", relu_1.meta, display=False)\n", - "glue(\"relu_1_users\", list(relu_1.users), display=False)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "::::{note}\n", - "\n", - "A `torch.fx.GraphModule` is just a wrapper around its `fx.Graph`, and you can access it through `graph_module.graph`. This is useful for two reasons:\n", - "- Most of the compiler steps will work with `fx.Graph` directly, so it's good to get acquainted with its attributes in case you need to debug an error.\n", - "- You *might* need to manipulate the graph directly to ensure compatibility ([example](https://leimao.github.io/blog/PyTorch-Eager-Mode-Quantization-TensorRT-Acceleration/)).\n", - "\n", - "\n", - "To start, if we want to print the underlying graph, we can do it like this:\n", - "\n", - "```python\n", - "print(str(graph_module.graph))\n", - "```\n", - "\n", - "```{glue} graphmodule_graph\n", - "```\n", - "\n", - "This is similar enough to the `graph_module`'s output, so let's move on. 
Each \"variable\" in the graph is a `Node` object, and we can access them like this:\n", - "\n", - "```python\n", - "print(list(graph_module.graph.nodes))\n", - "```\n", - "\n", - "```{glue} graphmodule_graph_nodes\n", - "```\n", - "\n", - "Specifically, if we're interested in a particular node, like the `relu_1` node, we can filter it by name:\n", - "\n", - "```python\n", - "relu_1 = next(filter(lambda n: n.name == \"relu_1\", graph_module.graph.nodes))\n", - "```\n", - "\n", - "Some of its most important attributes are the `name`, `op`, `args`, `stack_trace`, `target` and `users`. Let's print them and see what they store.\n", - "\n", - "The `name` is just the unique name of the node:\n", - "\n", - "```python\n", - "print(relu_1.name)\n", - "```\n", - "\n", - "```{glue} relu_1_name \n", - "```\n", - "\n", - "The `op` is the operator that the node represents. It refers to the high-level function that specifies the type of node. It is accompanied by a `target` and together they define the behavior of the node.\n", - "For example `Node(op=placeholder, target=p_p_conv1_weight)` means that the node is a placeholder for the weight of the first convolutional layer. Inputs, weights, etc are tagged as `placeholder` nodes.\n", - "\n", - "On the other hand, `call_function` nodes represent a function call to their `target`. For example, `Node(op=call_function, target=torch.ops.aten.relu.default)` means that the node is a call to the `relu` function, as we can see next:\n", - "\n", - "```python\n", - "print(relu_1.op)\n", - "```\n", - "\n", - "```{glue} relu_1_op \n", - "```\n", - "\n", - "```python\n", - "print(relu_1.target)\n", - "```\n", - "\n", - "```{glue} relu_1_target \n", - "```\n", - "\n", - "As we can see, *operator* is almost used interchangeably with *function* in this context.\n", - "\n", - "The `args` are the arguments of the node's function. In our case, since `relu_1` takes as input the output of `conv2d_1`, we should see a reference to that node.\n", - "\n", - "```python\n", - "print(relu_1.args)\n", - "```\n", - "\n", - "```{glue} relu_1_args \n", - "```\n", - "\n", - "Similarly, the `users` are the nodes that take the output of `relu_1` as input. Both of these attributes are useful to traverse the graph and understand the dependencies between nodes.\n", - "\n", - "```python\n", - "print(relu_1.users)\n", - "```\n", - "\n", - "```{glue} relu_1_users \n", - "```\n", - "\n", - "Finally, the `stack_trace` is the piece of code that generated the node. This is also useful for debugging and it helps with localizing the source code that should be rewritten in case of an error.\n", - "```python\n", - "print(relu_1.stack_trace)\n", - "```\n", - "\n", - "```{glue} relu_1_stack_trace_2\n", - "```\n", - "\n", - "For more information refer to the [documentation](https://pytorch.org/docs/main/export.ir_spec.html).\n", - "\n", - "::::" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Back to the `ExportedProgram`, the second most important attribute is its `graph_signature`. This object contains information about the inputs (actual inputs, parameters, constant tensors, etc) and outputs of the model. 
This is particularly useful if you want to check whether a tensor is being folded as a constant.\n", - "\n", - "We can print it like this:" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ExportGraphSignature(input_specs=[InputSpec(kind=,\n", - " arg=TensorArgument(name='p_conv1_weight'),\n", - " target='conv1.weight',\n", - " persistent=None),\n", - " InputSpec(kind=,\n", - " arg=TensorArgument(name='p_conv1_bias'),\n", - " target='conv1.bias',\n", - " persistent=None),\n", - " InputSpec(kind=,\n", - " arg=TensorArgument(name='p_conv2_weight'),\n", - " target='conv2.weight',\n", - " persistent=None),\n", - " InputSpec(kind=,\n", - " arg=TensorArgument(name='p_conv2_bias'),\n", - " target='conv2.bias',\n", - " persistent=None),\n", - " InputSpec(kind=,\n", - " arg=TensorArgument(name='p_fc_weight'),\n", - " target='fc.weight',\n", - " persistent=None),\n", - " InputSpec(kind=,\n", - " arg=TensorArgument(name='p_fc_bias'),\n", - " target='fc.bias',\n", - " persistent=None),\n", - " InputSpec(kind=,\n", - " arg=TensorArgument(name='x'),\n", - " target=None,\n", - " persistent=None)],\n", - " output_specs=[OutputSpec(kind=,\n", - " arg=TensorArgument(name='linear'),\n", - " target=None)])\n" - ] - } - ], - "source": [ - "pprint.pp(ep._graph_signature)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "If you want to access the parameters and buffers directly, you can reference the `state_dict` attribute." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "dict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc.weight', 'fc.bias'])" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "ep._state_dict.keys()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Constants are tensors that during the forward pass are found to not change (think of a tensor that contains the shape of the input). It is a bit less common to find them, but somestimes ensuring they are constant can help the compiler to parse the model correctly. Our simple network doesn't have any constants, but you can access them like this:" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{}\n" - ] - } - ], - "source": [ - "print(ep.constants)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Finally, we can save our exported program using the `torch.export.save` function." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "torch.export.save(ep, \"simple_net.pt2\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "cu124", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -}