Update cpp custom ops docs with sycl extension and xpu support #3391

Open · wants to merge 5 commits into main

Changes from all commits

107 changes: 88 additions & 19 deletions advanced_source/cpp_custom_ops.rst
@@ -1,7 +1,7 @@
.. _cpp-custom-ops-tutorial:

Custom C++ and CUDA Operators
=============================
Custom C++ and CUDA/SYCL Operators
==================================

**Author:** `Richard Zou <https://github.com/zou3519>`_

@@ -10,25 +10,30 @@ Custom C++ and CUDA Operators
.. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
:class-card: card-prerequisites

* How to integrate custom operators written in C++/CUDA with PyTorch
* How to integrate custom operators written in C++/CUDA/SYCL with PyTorch
* How to test custom operators using ``torch.library.opcheck``

.. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
:class-card: card-prerequisites

* PyTorch 2.4 or later
* Basic understanding of C++ and CUDA programming
* PyTorch 2.4 or later for C++/CUDA, and PyTorch 2.8 or later for SYCL
* Basic understanding of C++ and CUDA/SYCL programming

.. note::

This tutorial will also work on AMD ROCm with no additional modifications.

.. note::

``SYCL`` serves as the backend programming language for Intel GPUs (device label ``xpu``). For configuration details, see:
`Getting Started on Intel GPUs <https://docs.pytorch.org/docs/main/notes/get_start_xpu.html>`_.
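
For a quick sanity check that PyTorch can see an Intel GPU before building an
extension, something like the following can be used (a sketch; it assumes a
PyTorch build with XPU support and a visible Intel GPU):

.. code-block:: python

   import torch

   # True only if PyTorch was built with XPU support and an Intel GPU is visible.
   print(torch.xpu.is_available())
   # Tensors placed on the Intel GPU use the "xpu" device label.
   x = torch.ones(3, device="xpu")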

PyTorch offers a large library of operators that work on Tensors (e.g. torch.add, torch.sum, etc).
However, you may wish to bring a new custom operator to PyTorch. This tutorial demonstrates the
blessed path to authoring a custom operator written in C++/CUDA.
blessed path to authoring a custom operator written in C++/CUDA/SYCL.

For our tutorial, we’ll demonstrate how to author a fused multiply-add C++
and CUDA operator that composes with PyTorch subsystems. The semantics of
and CUDA/SYCL operator that composes with PyTorch subsystems. The semantics of
the operation are as follows:

.. code-block:: python
@@ -42,13 +47,13 @@ You can find the end-to-end working example for this tutorial
Setting up the Build System
---------------------------

If you are developing custom C++/CUDA code, it must be compiled.
If you are developing custom C++/CUDA/SYCL code, it must be compiled.
Note that if you’re interfacing with a Python library that already has bindings
to precompiled C++/CUDA code, you might consider writing a custom Python operator
to precompiled C++/CUDA/SYCL code, you might consider writing a custom Python operator
instead (:ref:`python-custom-ops-tutorial`).

Use `torch.utils.cpp_extension <https://pytorch.org/docs/stable/cpp_extension.html>`_
to compile custom C++/CUDA code for use with PyTorch
to compile custom C++/CUDA/SYCL code for use with PyTorch.
C++ extensions may be built either "ahead of time" with setuptools, or "just in time"
via `load_inline <https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.load_inline>`_;
we’ll focus on the "ahead of time" flavor.
@@ -73,10 +78,10 @@ Using ``cpp_extension`` is as simple as writing the following ``setup.py``:
options={"bdist_wheel": {"py_limited_api": "cp39"}} # 3.9 is minimum supported Python version
)

If you need to compile CUDA code (for example, ``.cu`` files), then instead use
`torch.utils.cpp_extension.CUDAExtension <https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.CUDAExtension>`_.
Please see `extension-cpp <https://github.com/pytorch/extension-cpp>`_ for an
example for how this is set up.
If you need to compile **CUDA** or **SYCL** code (for example, ``.cu`` or ``.sycl`` files), use
`torch.utils.cpp_extension.CUDAExtension <https://docs.pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.CUDAExtension>`_
or `torch.utils.cpp_extension.SyclExtension <https://docs.pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.SyclExtension>`_
respectively. For CUDA/SYCL examples, see `extension-cpp <https://github.com/pytorch/extension-cpp>`_.
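
For illustration, a minimal SYCL build might look like the following sketch; the
package name ``extension_cpp`` and the ``muladd.sycl`` source path are
placeholders, and ``SyclExtension`` requires PyTorch 2.8 or later:

.. code-block:: python

   from setuptools import setup
   from torch.utils.cpp_extension import SyclExtension, BuildExtension

   setup(
       name="extension_cpp",
       ext_modules=[
           # .sycl sources are compiled by the SYCL toolchain and target the xpu backend.
           SyclExtension("extension_cpp._C", ["muladd.sycl"])
       ],
       cmdclass={"build_ext": BuildExtension},
   )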

The above example represents what we refer to as a CPython agnostic wheel, meaning we are
building a single wheel that can be run across multiple CPython versions (similar to pure
@@ -126,7 +131,7 @@ to build a CPython agnostic wheel and will influence the naming of the wheel acc
)

It is necessary to specify ``py_limited_api=True`` as an argument to CppExtension/
CUDAExtension and also as an option to the ``"bdist_wheel"`` command with the minimal
CUDAExtension/SyclExtension and also as an option to the ``"bdist_wheel"`` command with the minimal
supported CPython version (in this case, 3.9). Consequently, the ``setup`` in our
tutorial would build one properly named wheel that could be installed across multiple
CPython versions ``>=3.9``.
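
As a sketch, the two places the flag must appear look roughly like this (the
extension name and ``muladd.cpp`` source are placeholders):

.. code-block:: python

   from setuptools import setup
   from torch.utils.cpp_extension import CppExtension, BuildExtension

   setup(
       name="extension_cpp",
       ext_modules=[
           CppExtension(
               "extension_cpp._C",
               ["muladd.cpp"],
               py_limited_api=True)],  # 1) passed to the extension constructor
       cmdclass={"build_ext": BuildExtension},
       options={"bdist_wheel": {"py_limited_api": "cp39"}},  # 2) passed to bdist_wheel
   )
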
@@ -181,7 +186,7 @@ Operator registration is a two-step process:

- **Defining the operator** - This step ensures that PyTorch is aware of the new operator.
- **Registering backend implementations** - In this step, implementations for various
backends, such as CPU and CUDA, are associated with the operator.
backends, such as CPU and CUDA/SYCL, are associated with the operator.

Defining an operator
^^^^^^^^^^^^^^^^^^^^
@@ -249,6 +254,70 @@ in a separate ``TORCH_LIBRARY_IMPL`` block:
m.impl("mymuladd", &mymuladd_cuda);
}

If you also have a SYCL implementation of ``mymuladd``, you can register it
in a separate ``TORCH_LIBRARY_IMPL`` block:

.. code-block:: cpp

   static void muladd_kernel(
       int numel, const float* a, const float* b, float c, float* result,
       const sycl::nd_item<1>& item) {
     // One work-item handles one element; guard against the rounded-up range.
     int idx = item.get_global_id(0);
     if (idx < numel) {
       result[idx] = a[idx] * b[idx] + c;
     }
   }

   // Functor wrapper so the kernel has a name usable with sycl::parallel_for.
   class MulAddKernelFunctor {
   public:
     MulAddKernelFunctor(int _numel, const float* _a, const float* _b, float _c, float* _result)
         : numel(_numel), a(_a), b(_b), c(_c), result(_result) {}

     void operator()(const sycl::nd_item<1>& item) const {
       muladd_kernel(numel, a, b, c, result, item);
     }

   private:
     int numel;
     const float* a;
     const float* b;
     float c;
     float* result;
   };

   at::Tensor mymuladd_xpu(const at::Tensor& a, const at::Tensor& b, double c) {
     TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same shape");
     TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor");
     TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor");
     TORCH_CHECK(a.device().is_xpu(), "a must be an XPU tensor");
     TORCH_CHECK(b.device().is_xpu(), "b must be an XPU tensor");

     at::Tensor a_contig = a.contiguous();
     at::Tensor b_contig = b.contiguous();
     at::Tensor result = at::empty_like(a_contig);

     const float* a_ptr = a_contig.data_ptr<float>();
     const float* b_ptr = b_contig.data_ptr<float>();
     float* res_ptr = result.data_ptr<float>();
     int numel = a_contig.numel();

     // Launch on the SYCL queue backing the current XPU stream.
     sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
     constexpr int threads = 256;
     int blocks = (numel + threads - 1) / threads;

     queue.submit([&](sycl::handler& cgh) {
       cgh.parallel_for<MulAddKernelFunctor>(
           sycl::nd_range<1>(blocks * threads, threads),
           MulAddKernelFunctor(numel, a_ptr, b_ptr, static_cast<float>(c), res_ptr)
       );
     });
     return result;
   }

   TORCH_LIBRARY_IMPL(extension_cpp, XPU, m) {
     m.impl("mymuladd", &mymuladd_xpu);
   }
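
Once the extension is built and imported, calls with XPU tensors dispatch to this
kernel. A minimal usage sketch (assuming the package is importable as
``extension_cpp`` and an Intel GPU is available):

.. code-block:: python

   import torch
   import extension_cpp  # importing the module runs the TORCH_LIBRARY registrations

   a = torch.randn(3, device="xpu")
   b = torch.randn(3, device="xpu")
   # XPU inputs route dispatch to the mymuladd_xpu implementation registered above.
   out = torch.ops.extension_cpp.mymuladd(a, b, 1.0)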

Adding ``torch.compile`` support for an operator
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
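
The first step is typically registering a ``FakeTensor`` ("meta") kernel so that
``torch.compile`` can trace through the operator without running the real kernel;
a minimal sketch for this tutorial's ``extension_cpp::mymuladd``:

.. code-block:: python

   import torch

   @torch.library.register_fake("extension_cpp::mymuladd")
   def _(a, b, c):
       # Only output metadata (shape/dtype/device) is computed here.
       torch._check(a.shape == b.shape)
       torch._check(a.dtype == torch.float)
       torch._check(b.dtype == torch.float)
       return torch.empty_like(a)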

@@ -285,7 +354,7 @@ for more details).

Setting up hybrid Python/C++ registration
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
In this tutorial, we defined a custom operator in C++, added CPU/CUDA
In this tutorial, we defined a custom operator in C++, added CPU/CUDA/SYCL
implementations in C++, and added ``FakeTensor`` kernels and backward formulas
in Python. The order in which these registrations are loaded (or imported)
matters (importing in the wrong order will lead to an error).
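
A sketch of an ordering that works, assuming the package layout used by the
``extension-cpp`` example (compiled library exposed as ``_C``, Python-side
registrations in ``ops``):

.. code-block:: python

   # extension_cpp/__init__.py (hypothetical layout)
   from . import _C    # 1) loads the shared library; runs the C++ TORCH_LIBRARY registrations
   from . import ops   # 2) Python registrations (register_fake, register_autograd, ...)
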
@@ -412,7 +481,7 @@ for more details).
"extension_cpp::mymuladd", _backward, setup_context=_setup_context)

Note that the backward must be a composition of PyTorch-understood operators.
If you wish to use another custom C++ or CUDA kernel in your backwards pass,
If you wish to use another custom C++, CUDA, or SYCL kernel in your backward pass,
it must be wrapped into a custom operator.
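
For context, a sketch of the ``_backward`` and ``_setup_context`` functions
referenced above for ``mymuladd``, treating ``c`` as a non-differentiable scalar
(the saved values and gradient formulas here are illustrative):

.. code-block:: python

   import torch

   def _setup_context(ctx, inputs, output):
       a, b, c = inputs
       # Save only what the backward pass needs.
       ctx.save_for_backward(a, b)

   def _backward(ctx, grad):
       a, b = ctx.saved_tensors
       grad_a = grad * b if ctx.needs_input_grad[0] else None
       grad_b = grad * a if ctx.needs_input_grad[1] else None
       return grad_a, grad_b, None  # no gradient for the scalar c

   torch.library.register_autograd(
       "extension_cpp::mymuladd", _backward, setup_context=_setup_context)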

If we had our own custom ``mymul`` kernel, we would need to wrap it into a
@@ -577,6 +646,6 @@ When defining the operator, we must specify that it mutates the out Tensor in th
Conclusion
----------
In this tutorial, we went over the recommended approach to integrating Custom C++
and CUDA operators with PyTorch. The ``TORCH_LIBRARY/torch.library`` APIs are fairly
and CUDA/SYCL operators with PyTorch. The ``TORCH_LIBRARY/torch.library`` APIs are fairly
low-level. For more information about how to use the API, see
`The Custom Operators Manual <https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html#the-custom-operators-manual>`_.
14 changes: 7 additions & 7 deletions advanced_source/custom_ops_landing_page.rst
@@ -1,7 +1,7 @@
.. _custom-ops-landing-page:

PyTorch Custom Operators
===========================
========================

PyTorch offers a large library of operators that work on Tensors (e.g. ``torch.add``,
``torch.sum``, etc). However, you may wish to bring a new custom operation to PyTorch
@@ -21,18 +21,18 @@ You may wish to author a custom operator from Python (as opposed to C++) if:

- you have a Python function you want PyTorch to treat as an opaque callable, especially with
respect to ``torch.compile`` and ``torch.export``.
- you have some Python bindings to C++/CUDA kernels and want those to compose with PyTorch
- you have some Python bindings to C++/CUDA/SYCL kernels and want those to compose with PyTorch
subsystems (like ``torch.compile`` or ``torch.autograd``)
- you are using Python (and not a C++-only environment like AOTInductor).

Integrating custom C++ and/or CUDA code with PyTorch
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Integrating custom C++ and/or CUDA/SYCL code with PyTorch
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Please see :ref:`cpp-custom-ops-tutorial`.

You may wish to author a custom operator from C++ (as opposed to Python) if:

- you have custom C++ and/or CUDA code.
- you have custom C++ and/or CUDA/SYCL code.
- you plan to use this code with ``AOTInductor`` to do Python-less inference.

The Custom Operators Manual
@@ -50,12 +50,12 @@ If your operation is expressible as a composition of built-in PyTorch operators
then please write it as a Python function and call it instead of creating a
custom operator. Use the operator registration APIs to create a custom operator if you
are calling into some library that PyTorch doesn't understand (e.g. custom C/C++ code,
a custom CUDA kernel, or Python bindings to C/C++/CUDA extensions).
a custom CUDA kernel, a custom SYCL kernel, or Python bindings to C/C++/CUDA/SYCL extensions).

Why should I create a Custom Operator?
--------------------------------------

It is possible to use a C/C++/CUDA kernel by grabbing a Tensor's data pointer
It is possible to use a C/C++/CUDA/SYCL kernel by grabbing a Tensor's data pointer
and passing it to a pybind'ed kernel. However, this approach doesn't compose with
PyTorch subsystems like autograd, torch.compile, vmap, and more. In order
for an operation to compose with PyTorch subsystems, it must be registered