Skip to content

Commit

Permalink
[Capture] Add a QmlPrimitive class to differentiate between differe…
Browse files Browse the repository at this point in the history
…nt types of primitives (#6847)

This PR adds a `QmlPrimitive` subclass of `jax.core.Primitive`. This
class contains a `prim_type` property set using a new `PrimitiveType`
enum. `PrimitiveType`s currently available are "default", "operator",
"measurement", "transform", and "higher_order". This can be made more or
less fine grained as needed, but should be enough to differentiate
between different types of primitives for now. Additionally, this PR:

* updates `NonInterpPrimitive` to be a subclass of `QmlPrimitive`
* updates all existing PennyLane primitives to be either `QmlPrimitive`
or `NonInterpPrimitive`. See [this
comment](#6851 (comment))
to see the logic used to determine which `Primitive` subclass is used
for each primitive.
* updates `PlxprInterpreter.eval` and `CancelInversesInterpreter.eval`
to use this `prim_type` property.

[sc-82420]

---------

Co-authored-by: Pietropaolo Frisoni <[email protected]>
  • Loading branch information
mudit2812 and PietropaoloFrisoni authored Jan 22, 2025
1 parent fdf34ec commit 3e1521b
Show file tree
Hide file tree
Showing 16 changed files with 209 additions and 73 deletions.
6 changes: 6 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@

<h3>Internal changes ⚙️</h3>

* Added a `QmlPrimitive` class that inherits `jax.core.Primitive` to a new `qml.capture.custom_primitives` module.
This class contains a `prim_type` property so that we can differentiate between different sets of PennyLane primitives.
Consequently, `QmlPrimitive` is now used to define all PennyLane primitives.
[(#6847)](https://github.com/PennyLaneAI/pennylane/pull/6847)

<h3>Documentation 📝</h3>

* The docstrings for `qml.unary_mapping`, `qml.binary_mapping`, `qml.christiansen_mapping`,
Expand All @@ -115,4 +120,5 @@ Diksha Dhawan,
Pietropaolo Frisoni,
Marcus Gisslén,
Christina Lee,
Mudit Pandey,
Andrija Paurevic
13 changes: 6 additions & 7 deletions pennylane/capture/base_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@

from .flatfn import FlatFn
from .primitives import (
AbstractMeasurement,
AbstractOperator,
adjoint_transform_prim,
cond_prim,
ctrl_transform_prim,
Expand Down Expand Up @@ -311,20 +309,21 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list:
self._env[constvar] = const

for eqn in jaxpr.eqns:
primitive = eqn.primitive
custom_handler = self._primitive_registrations.get(primitive, None)

custom_handler = self._primitive_registrations.get(eqn.primitive, None)
if custom_handler:
invals = [self.read(invar) for invar in eqn.invars]
outvals = custom_handler(self, *invals, **eqn.params)
elif isinstance(eqn.outvars[0].aval, AbstractOperator):
elif getattr(primitive, "prim_type", "") == "operator":
outvals = self.interpret_operation_eqn(eqn)
elif isinstance(eqn.outvars[0].aval, AbstractMeasurement):
elif getattr(primitive, "prim_type", "") == "measurement":
outvals = self.interpret_measurement_eqn(eqn)
else:
invals = [self.read(invar) for invar in eqn.invars]
outvals = eqn.primitive.bind(*invals, **eqn.params)
outvals = primitive.bind(*invals, **eqn.params)

if not eqn.primitive.multiple_results:
if not primitive.multiple_results:
outvals = [outvals]
for outvar, outval in zip(eqn.outvars, outvals, strict=True):
self._env[outvar] = outval
Expand Down
41 changes: 11 additions & 30 deletions pennylane/capture/capture_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,34 +24,6 @@
has_jax = False


@lru_cache
def create_non_interpreted_prim():
"""Create a primitive type ``NonInterpPrimitive``, which binds to JAX's JVPTrace
and BatchTrace objects like a standard Python function and otherwise behaves like jax.core.Primitive.
"""

if not has_jax: # pragma: no cover
return None

# pylint: disable=too-few-public-methods
class NonInterpPrimitive(jax.core.Primitive):
"""A subclass to JAX's Primitive that works like a Python function
when evaluating JVPTracers and BatchTracers."""

def bind_with_trace(self, trace, args, params):
"""Bind the ``NonInterpPrimitive`` with a trace.
If the trace is a ``JVPTrace``or a ``BatchTrace``, binding falls back to a standard Python function call.
Otherwise, the bind call of JAX's standard Primitive is used."""
if isinstance(
trace, (jax.interpreters.ad.JVPTrace, jax.interpreters.batching.BatchTrace)
):
return self.impl(*args, **params)
return super().bind_with_trace(trace, args, params)

return NonInterpPrimitive


@lru_cache
def _get_grad_prim():
"""Create a primitive for gradient computations.
Expand All @@ -60,8 +32,11 @@ def _get_grad_prim():
if not has_jax: # pragma: no cover
return None

grad_prim = create_non_interpreted_prim()("grad")
from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel

grad_prim = NonInterpPrimitive("grad")
grad_prim.multiple_results = True # pylint: disable=attribute-defined-outside-init
grad_prim.prim_type = "higher_order"

# pylint: disable=too-many-arguments
@grad_prim.def_impl
Expand Down Expand Up @@ -91,8 +66,14 @@ def _get_jacobian_prim():
"""Create a primitive for Jacobian computations.
This primitive is used when capturing ``qml.jacobian``.
"""
jacobian_prim = create_non_interpreted_prim()("jacobian")
if not has_jax: # pragma: no cover
return None

from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel

jacobian_prim = NonInterpPrimitive("jacobian")
jacobian_prim.multiple_results = True # pylint: disable=attribute-defined-outside-init
jacobian_prim.prim_type = "higher_order"

# pylint: disable=too-many-arguments
@jacobian_prim.def_impl
Expand Down
15 changes: 12 additions & 3 deletions pennylane/capture/capture_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,10 @@ def create_measurement_obs_primitive(
if not has_jax:
return None

primitive = jax.core.Primitive(name + "_obs")
from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel

primitive = NonInterpPrimitive(name + "_obs")
primitive.prim_type = "measurement"

@primitive.def_impl
def _(obs, **kwargs):
Expand Down Expand Up @@ -165,7 +168,10 @@ def create_measurement_mcm_primitive(
if not has_jax:
return None

primitive = jax.core.Primitive(name + "_mcm")
from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel

primitive = NonInterpPrimitive(name + "_mcm")
primitive.prim_type = "measurement"

@primitive.def_impl
def _(*mcms, single_mcm=True, **kwargs):
Expand Down Expand Up @@ -200,7 +206,10 @@ def create_measurement_wires_primitive(
if not has_jax:
return None

primitive = jax.core.Primitive(name + "_wires")
from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel

primitive = NonInterpPrimitive(name + "_wires")
primitive.prim_type = "measurement"

@primitive.def_impl
def _(*args, has_eigvals=False, **kwargs):
Expand Down
7 changes: 4 additions & 3 deletions pennylane/capture/capture_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@

import pennylane as qml

from .capture_diff import create_non_interpreted_prim

has_jax = True
try:
import jax
Expand Down Expand Up @@ -103,7 +101,10 @@ def create_operator_primitive(
if not has_jax:
return None

primitive = create_non_interpreted_prim()(operator_type.__name__)
from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel

primitive = NonInterpPrimitive(operator_type.__name__)
primitive.prim_type = "operator"

@primitive.def_impl
def _(*args, **kwargs):
Expand Down
64 changes: 64 additions & 0 deletions pennylane/capture/custom_primitives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This submodule offers custom primitives for the PennyLane capture module.
"""
from enum import Enum
from typing import Union

import jax


class PrimitiveType(Enum):
"""Enum to define valid set of primitive classes"""

DEFAULT = "default"
OPERATOR = "operator"
MEASUREMENT = "measurement"
HIGHER_ORDER = "higher_order"
TRANSFORM = "transform"


# pylint: disable=too-few-public-methods,abstract-method
class QmlPrimitive(jax.core.Primitive):
"""A subclass for JAX's Primitive that differentiates between different
classes of primitives."""

_prim_type: PrimitiveType = PrimitiveType.DEFAULT

@property
def prim_type(self):
"""Value of Enum representing the primitive type to differentiate between various
sets of PennyLane primitives."""
return self._prim_type.value

@prim_type.setter
def prim_type(self, value: Union[str, PrimitiveType]):
"""Setter for QmlPrimitive.prim_type."""
self._prim_type = PrimitiveType(value)


# pylint: disable=too-few-public-methods,abstract-method
class NonInterpPrimitive(QmlPrimitive):
"""A subclass to JAX's Primitive that works like a Python function
when evaluating JVPTracers and BatchTracers."""

def bind_with_trace(self, trace, args, params):
"""Bind the ``NonInterpPrimitive`` with a trace.
If the trace is a ``JVPTrace``or a ``BatchTrace``, binding falls back to a standard Python function call.
Otherwise, the bind call of JAX's standard Primitive is used."""
if isinstance(trace, (jax.interpreters.ad.JVPTrace, jax.interpreters.batching.BatchTrace)):
return self.impl(*args, **params)
return super().bind_with_trace(trace, args, params)
17 changes: 12 additions & 5 deletions pennylane/compiler/qjit_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from collections.abc import Callable

import pennylane as qml
from pennylane.capture.capture_diff import create_non_interpreted_prim
from pennylane.capture.flatfn import FlatFn

from .compiler import (
Expand Down Expand Up @@ -405,10 +404,14 @@ def _decorator(body_fn: Callable) -> Callable:
def _get_while_loop_qfunc_prim():
"""Get the while_loop primitive for quantum functions."""

import jax # pylint: disable=import-outside-toplevel
# pylint: disable=import-outside-toplevel
import jax

while_loop_prim = create_non_interpreted_prim()("while_loop")
from pennylane.capture.custom_primitives import NonInterpPrimitive

while_loop_prim = NonInterpPrimitive("while_loop")
while_loop_prim.multiple_results = True
while_loop_prim.prim_type = "higher_order"

@while_loop_prim.def_impl
def _(*args, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_slice):
Expand Down Expand Up @@ -626,10 +629,14 @@ def _decorator(body_fn):
def _get_for_loop_qfunc_prim():
"""Get the loop_for primitive for quantum functions."""

import jax # pylint: disable=import-outside-toplevel
# pylint: disable=import-outside-toplevel
import jax

from pennylane.capture.custom_primitives import NonInterpPrimitive

for_loop_prim = create_non_interpreted_prim()("for_loop")
for_loop_prim = NonInterpPrimitive("for_loop")
for_loop_prim.multiple_results = True
for_loop_prim.prim_type = "higher_order"

# pylint: disable=too-many-arguments
@for_loop_prim.def_impl
Expand Down
7 changes: 5 additions & 2 deletions pennylane/measurements/mid_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,12 @@ def _create_mid_measure_primitive():
measurement.
"""
import jax # pylint: disable=import-outside-toplevel
# pylint: disable=import-outside-toplevel
import jax

mid_measure_p = jax.core.Primitive("measure")
from pennylane.capture.custom_primitives import NonInterpPrimitive

mid_measure_p = NonInterpPrimitive("measure")

@mid_measure_p.def_impl
def _(wires, reset=False, postselect=None):
Expand Down
9 changes: 6 additions & 3 deletions pennylane/ops/op_math/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from typing import Callable, overload

import pennylane as qml
from pennylane.capture.capture_diff import create_non_interpreted_prim
from pennylane.compiler import compiler
from pennylane.math import conj, moveaxis, transpose
from pennylane.operation import Observable, Operation, Operator
Expand Down Expand Up @@ -190,10 +189,14 @@ def create_adjoint_op(fn, lazy):
def _get_adjoint_qfunc_prim():
"""See capture/explanations.md : Higher Order primitives for more information on this code."""
# if capture is enabled, jax should be installed
import jax # pylint: disable=import-outside-toplevel
# pylint: disable=import-outside-toplevel
import jax

from pennylane.capture.custom_primitives import NonInterpPrimitive

adjoint_prim = create_non_interpreted_prim()("adjoint_transform")
adjoint_prim = NonInterpPrimitive("adjoint_transform")
adjoint_prim.multiple_results = True
adjoint_prim.prim_type = "higher_order"

@adjoint_prim.def_impl
def _(*args, jaxpr, lazy, n_consts):
Expand Down
9 changes: 6 additions & 3 deletions pennylane/ops/op_math/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import pennylane as qml
from pennylane import QueuingManager
from pennylane.capture.capture_diff import create_non_interpreted_prim
from pennylane.capture.flatfn import FlatFn
from pennylane.compiler import compiler
from pennylane.measurements import MeasurementValue
Expand Down Expand Up @@ -681,10 +680,14 @@ def _get_mcm_predicates(conditions: tuple[MeasurementValue]) -> list[Measurement
def _get_cond_qfunc_prim():
"""Get the cond primitive for quantum functions."""

import jax # pylint: disable=import-outside-toplevel
# pylint: disable=import-outside-toplevel
import jax

cond_prim = create_non_interpreted_prim()("cond")
from pennylane.capture.custom_primitives import NonInterpPrimitive

cond_prim = NonInterpPrimitive("cond")
cond_prim.multiple_results = True
cond_prim.prim_type = "higher_order"

@cond_prim.def_impl
def _(*all_args, jaxpr_branches, consts_slices, args_slice):
Expand Down
10 changes: 7 additions & 3 deletions pennylane/ops/op_math/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import pennylane as qml
from pennylane import math as qmlmath
from pennylane import operation
from pennylane.capture.capture_diff import create_non_interpreted_prim
from pennylane.compiler import compiler
from pennylane.operation import Operator
from pennylane.wires import Wires, WiresLike
Expand Down Expand Up @@ -233,10 +232,15 @@ def wrapper(*args, **kwargs):
def _get_ctrl_qfunc_prim():
"""See capture/explanations.md : Higher Order primitives for more information on this code."""
# if capture is enabled, jax should be installed
import jax # pylint: disable=import-outside-toplevel

ctrl_prim = create_non_interpreted_prim()("ctrl_transform")
# pylint: disable=import-outside-toplevel
import jax

from pennylane.capture.custom_primitives import NonInterpPrimitive

ctrl_prim = NonInterpPrimitive("ctrl_transform")
ctrl_prim.multiple_results = True
ctrl_prim.prim_type = "higher_order"

@ctrl_prim.def_impl
def _(*args, n_control, jaxpr, control_values, work_wires, n_consts):
Expand Down
5 changes: 3 additions & 2 deletions pennylane/transforms/core/transform_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,12 +540,13 @@ def final_transform(self):
def _create_transform_primitive(name):
try:
# pylint: disable=import-outside-toplevel
import jax
from pennylane.capture.custom_primitives import NonInterpPrimitive
except ImportError:
return None

transform_prim = jax.core.Primitive(name + "_transform")
transform_prim = NonInterpPrimitive(name + "_transform")
transform_prim.multiple_results = True
transform_prim.prim_type = "transform"

@transform_prim.def_impl
def _(
Expand Down
Loading

0 comments on commit 3e1521b

Please sign in to comment.