torch.compile documentation updates for 2024.2 (openvinotoolkit#24803)
### Details:
 - Updated 'How to Use' and 'Options' sections
 - Added quantization and torchserve sections

---------

Co-authored-by: Mustafa Cavus <[email protected]>
Co-authored-by: suryasidd <[email protected]>
Co-authored-by: Roman Kazantsev <[email protected]>
Co-authored-by: Karol Blaszczak <[email protected]>
5 people authored Jun 5, 2024
1 parent aa04a9b commit 9e66235
Showing 1 changed file with 77 additions and 5 deletions: docs/articles_en/openvino-workflow/torch-compile.rst
@@ -22,20 +22,33 @@ By default, Torch code runs in eager-mode, but with the use of ``torch.compile``
How to Use
####################

To use ``torch.compile``, you need to define the ``openvino`` backend in your PyTorch application.
This way Torch FX subgraphs will be directly converted to OpenVINO representation without
any additional PyTorch-based tracing/scripting.
This approach works only for the **package distributed via pip**, as it is now configured with
the `torch_dynamo_backends entrypoint <https://pytorch.org/docs/stable/torch.compiler_custom_backends.html#registering-custom-backends>`__.

.. code-block:: python

   ...
   model = torch.compile(model, backend='openvino')
   ...

For OpenVINO installed via channels other than pip (such as conda), and for versions older
than 2024.1, an additional import statement is needed:

.. code-block:: python

   import openvino.torch
   ...
   model = torch.compile(model, backend='openvino')
   ...
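
A minimal end-to-end sketch of the flow described above (the ``torchvision`` model choice
and input shape are illustrative assumptions, not part of this guide):

.. code-block:: python

   import torch
   import torchvision.models as models

   model = models.resnet18(weights=None)  # any eager-mode PyTorch model
   model.eval()

   # Torch FX subgraphs are converted to OpenVINO on the first call
   compiled_model = torch.compile(model, backend='openvino')

   with torch.no_grad():
       output = compiled_model(torch.randn(1, 3, 224, 224))
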
Execution diagram:

.. image:: ../assets/images/torch_compile_backend_openvino.svg
   :alt: torch.compile execution diagram
   :width: 992px
   :height: 720px
   :scale: 60%
@@ -51,13 +64,26 @@ enable model caching, set the cache directory etc. You can use a dictionary of t
By default, the OpenVINO backend for ``torch.compile`` runs PyTorch applications
on CPU. If you set this variable to ``GPU.0``, for example, the application will
use the integrated graphics processor instead.
* ``aot_autograd`` - enables the ``aot_autograd`` graph capture, which is needed to support
  dynamic shapes or to fine-tune a model. For models with dynamic shapes, it is recommended
  to set this option to ``True``. By default, ``aot_autograd`` is set to ``False``.
* ``model_caching`` - enables saving the optimized model files to a hard drive after the
  first application run. This makes them available for subsequent application executions,
  reducing the first-inference latency. By default, this option is set to ``False``. Set it
  to ``True`` to enable caching.
* ``cache_dir`` - defines a custom directory for the model files (applies when
  ``model_caching`` is set to ``True``). By default, the OpenVINO IR is saved in the
  ``cache`` sub-directory, created in the application's root directory.
* ``decompositions`` - enables defining additional operator decompositions. By
default, this is an empty list. For example, to add a decomposition for
an operator ``my_op``, add ``'decompositions': [torch.ops.aten.my_op.default]``
to the options.
* ``disabled_ops`` - specifies operators to be excluded from OpenVINO execution, making
  them fall back to the native PyTorch runtime. For example, to disable an operator
  ``my_op`` from OpenVINO execution, add ``'disabled_ops': [torch.ops.aten.my_op.default]``
  to the options. By default, this is an empty list.
* ``config`` - enables passing any OpenVINO configuration option as a dictionary
to this variable. For details on the various options, refer to the
:ref:`OpenVINO Advanced Features <openvino-advanced-features>`.
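
Putting several of these options together, a brief sketch (the device name, cache path,
and option values below are illustrative assumptions):

.. code-block:: python

   import torch

   # 'GPU.0' assumes an Intel GPU is available; adjust for your system.
   options = {
       'device': 'GPU.0',
       'model_caching': True,
       'cache_dir': './ov_cache',
       'aot_autograd': True,  # recommended for models with dynamic shapes
   }
   model = torch.compile(model, backend='openvino', options=options)
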
@@ -79,8 +105,10 @@ You can also set OpenVINO specific configuration options by adding them as a dic
Windows support
+++++++++++++++++++++

PyTorch supports ``torch.compile`` officially on Windows from version 2.3.0 onwards.

For PyTorch versions below 2.3.0, the ``torch.compile`` feature is not officially supported
on Windows. However, it can be accessed by following the instructions below:

1. Install the PyTorch nightly wheel file `2.1.0.dev20230713 <https://download.pytorch.org/whl/nightly/cpu/torch-2.1.0.dev20230713%2Bcpu-cp38-cp38-win_amd64.whl>`__,
2. Update the file at ``<python_env_root>/Lib/site-packages/torch/_dynamo/eval_frames.py``
@@ -104,6 +132,50 @@ the below instructions:
   .. code-block:: python

      if sys.version_info >= (3, 11):
          raise RuntimeError("Python 3.11+ not yet supported for torch.compile")
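
The intent of the edit is to bypass the platform guard in this file. A sketch of the
result, assuming the checks live in ``check_if_dynamo_supported()`` as in ``torch``
releases of that period (the function name and the commented lines are assumptions,
not part of this document):

.. code-block:: python

   def check_if_dynamo_supported():
       # Commented out to allow torch.compile on Windows:
       # if sys.platform == "win32":
       #     raise RuntimeError("Windows not yet supported for torch.compile")
       if sys.version_info >= (3, 11):
           raise RuntimeError("Python 3.11+ not yet supported for torch.compile")
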
Support for PyTorch 2 export quantization (Preview)
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

PyTorch 2 export quantization is supported by the OpenVINO backend in ``torch.compile``. To be able
to access this feature, follow the steps provided in
`PyTorch 2 Export Post Training Quantization with X86 Backend through Inductor <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_x86_inductor.html>`__
and update the provided sample as explained below.

1. If you are using PyTorch version 2.3.0 or later, disable constant folding in quantization
   to benefit from the optimization in the OpenVINO backend. This can be done by passing
   the ``fold_quantize=False`` parameter to the ``convert_pt2e`` function. To do so, change
   this line:

   .. code-block:: python

      converted_model = convert_pt2e(prepared_model)

   to the following:

   .. code-block:: python

      converted_model = convert_pt2e(prepared_model, fold_quantize=False)

2. Set the ``torch.compile`` backend to OpenVINO and execute the model.

   Change this line:

   .. code-block:: python

      optimized_model = torch.compile(converted_model)

   to the following:

   .. code-block:: python

      optimized_model = torch.compile(converted_model, backend="openvino")
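
Putting both changes together, an abridged sketch of the modified tutorial flow
(``exported_model``, ``quantizer``, and ``example_inputs`` are assumed to be prepared
as in the linked tutorial):

.. code-block:: python

   import torch
   from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e

   prepared_model = prepare_pt2e(exported_model, quantizer)
   prepared_model(*example_inputs)  # calibration pass

   # Keep quantize ops unfolded so the OpenVINO backend can optimize them
   converted_model = convert_pt2e(prepared_model, fold_quantize=False)
   optimized_model = torch.compile(converted_model, backend="openvino")
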
TorchServe Integration
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

TorchServe is a performant, flexible, and easy-to-use tool for serving PyTorch models in production. For more details on TorchServe,
refer to the `TorchServe GitHub repository <https://github.com/pytorch/serve>`__. With the OpenVINO ``torch.compile`` integration in TorchServe,
you can serve PyTorch models in production and accelerate them with OpenVINO on various Intel hardware. Detailed instructions on how to use
OpenVINO with TorchServe are available in the `TorchServe examples <https://github.com/pytorch/serve/tree/master/examples/pt2/torch_compile_openvino>`__.

Support for Automatic1111 Stable Diffusion WebUI
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
