From 3e1521bdef05235d915ddb8273ea177fcf62d755 Mon Sep 17 00:00:00 2001
From: Mudit Pandey <mudit.pandey@xanadu.ai>
Date: Wed, 22 Jan 2025 11:09:07 -0500
Subject: [PATCH] [Capture] Add a `QmlPrimitive` class to differentiate between
 different types of primitives (#6847)

This PR adds a `QmlPrimitive` subclass of `jax.core.Primitive`. This
class contains a `prim_type` property set using a new `PrimitiveType`
enum. `PrimitiveType`s currently available are "default", "operator",
"measurement", "transform", and "higher_order". This can be made more or
less fine grained as needed, but should be enough to differentiate
between different types of primitives for now. Additionally, this PR:

* updates `NonInterpPrimitive` to be a subclass of `QmlPrimitive`
* updates all existing PennyLane primitives to be either `QmlPrimitive`
or `NonInterpPrimitive`. See [this
comment](https://github.com/PennyLaneAI/pennylane/pull/6851#discussion_r1922462699)
to see the logic used to determine which `Primitive` subclass is used
for each primitive.
* updates `PlxprInterpreter.eval` and `CancelInversesInterpreter.eval`
to use this `prim_type` property.

[sc-82420]

---------

Co-authored-by: Pietropaolo Frisoni <pietropaolo.frisoni@xanadu.ai>
---
 doc/releases/changelog-dev.md                 |  6 ++
 pennylane/capture/base_interpreter.py         | 13 ++--
 pennylane/capture/capture_diff.py             | 41 ++++--------
 pennylane/capture/capture_measurements.py     | 15 ++++-
 pennylane/capture/capture_operators.py        |  7 +-
 pennylane/capture/custom_primitives.py        | 64 +++++++++++++++++++
 pennylane/compiler/qjit_api.py                | 17 +++--
 pennylane/measurements/mid_measure.py         |  7 +-
 pennylane/ops/op_math/adjoint.py              |  9 ++-
 pennylane/ops/op_math/condition.py            |  9 ++-
 pennylane/ops/op_math/controlled.py           | 10 ++-
 .../transforms/core/transform_dispatcher.py   |  5 +-
 .../optimization/cancel_inverses.py           |  8 +--
 pennylane/workflow/_capture_qnode.py          |  4 +-
 tests/capture/test_custom_primitives.py       | 48 ++++++++++++++
 tests/capture/test_switches.py                | 19 ++++--
 16 files changed, 209 insertions(+), 73 deletions(-)
 create mode 100644 pennylane/capture/custom_primitives.py
 create mode 100644 tests/capture/test_custom_primitives.py

diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 6b08ff74363..4eb878d6478 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -90,6 +90,11 @@
 
 <h3>Internal changes ⚙️</h3>
 
+* Added a `QmlPrimitive` class that inherits `jax.core.Primitive` to a new `qml.capture.custom_primitives` module.
+  This class contains a `prim_type` property so that we can differentiate between different sets of PennyLane primitives.
+  Consequently, `QmlPrimitive` is now used to define all PennyLane primitives.
+  [(#6847)](https://github.com/PennyLaneAI/pennylane/pull/6847)
+
 <h3>Documentation 📝</h3>
 
 * The docstrings for `qml.unary_mapping`, `qml.binary_mapping`, `qml.christiansen_mapping`, 
@@ -115,4 +120,5 @@ Diksha Dhawan,
 Pietropaolo Frisoni,
 Marcus Gisslén,
 Christina Lee,
+Mudit Pandey,
 Andrija Paurevic
diff --git a/pennylane/capture/base_interpreter.py b/pennylane/capture/base_interpreter.py
index 4af11cb6198..2b7314f7c2e 100644
--- a/pennylane/capture/base_interpreter.py
+++ b/pennylane/capture/base_interpreter.py
@@ -25,8 +25,6 @@
 
 from .flatfn import FlatFn
 from .primitives import (
-    AbstractMeasurement,
-    AbstractOperator,
     adjoint_transform_prim,
     cond_prim,
     ctrl_transform_prim,
@@ -311,20 +309,21 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list:
             self._env[constvar] = const
 
         for eqn in jaxpr.eqns:
+            primitive = eqn.primitive
+            custom_handler = self._primitive_registrations.get(primitive, None)
 
-            custom_handler = self._primitive_registrations.get(eqn.primitive, None)
             if custom_handler:
                 invals = [self.read(invar) for invar in eqn.invars]
                 outvals = custom_handler(self, *invals, **eqn.params)
-            elif isinstance(eqn.outvars[0].aval, AbstractOperator):
+            elif getattr(primitive, "prim_type", "") == "operator":
                 outvals = self.interpret_operation_eqn(eqn)
-            elif isinstance(eqn.outvars[0].aval, AbstractMeasurement):
+            elif getattr(primitive, "prim_type", "") == "measurement":
                 outvals = self.interpret_measurement_eqn(eqn)
             else:
                 invals = [self.read(invar) for invar in eqn.invars]
-                outvals = eqn.primitive.bind(*invals, **eqn.params)
+                outvals = primitive.bind(*invals, **eqn.params)
 
-            if not eqn.primitive.multiple_results:
+            if not primitive.multiple_results:
                 outvals = [outvals]
             for outvar, outval in zip(eqn.outvars, outvals, strict=True):
                 self._env[outvar] = outval
diff --git a/pennylane/capture/capture_diff.py b/pennylane/capture/capture_diff.py
index 482f692df69..ba7c3846693 100644
--- a/pennylane/capture/capture_diff.py
+++ b/pennylane/capture/capture_diff.py
@@ -24,34 +24,6 @@
     has_jax = False
 
 
-@lru_cache
-def create_non_interpreted_prim():
-    """Create a primitive type ``NonInterpPrimitive``, which binds to JAX's JVPTrace
-    and BatchTrace objects like a standard Python function and otherwise behaves like jax.core.Primitive.
-    """
-
-    if not has_jax:  # pragma: no cover
-        return None
-
-    # pylint: disable=too-few-public-methods
-    class NonInterpPrimitive(jax.core.Primitive):
-        """A subclass to JAX's Primitive that works like a Python function
-        when evaluating JVPTracers and BatchTracers."""
-
-        def bind_with_trace(self, trace, args, params):
-            """Bind the ``NonInterpPrimitive`` with a trace.
-
-            If the trace is a ``JVPTrace``or a ``BatchTrace``, binding falls back to a standard Python function call.
-            Otherwise, the bind call of JAX's standard Primitive is used."""
-            if isinstance(
-                trace, (jax.interpreters.ad.JVPTrace, jax.interpreters.batching.BatchTrace)
-            ):
-                return self.impl(*args, **params)
-            return super().bind_with_trace(trace, args, params)
-
-    return NonInterpPrimitive
-
-
 @lru_cache
 def _get_grad_prim():
     """Create a primitive for gradient computations.
@@ -60,8 +32,11 @@ def _get_grad_prim():
     if not has_jax:  # pragma: no cover
         return None
 
-    grad_prim = create_non_interpreted_prim()("grad")
+    from .custom_primitives import NonInterpPrimitive  # pylint: disable=import-outside-toplevel
+
+    grad_prim = NonInterpPrimitive("grad")
     grad_prim.multiple_results = True  # pylint: disable=attribute-defined-outside-init
+    grad_prim.prim_type = "higher_order"
 
     # pylint: disable=too-many-arguments
     @grad_prim.def_impl
@@ -91,8 +66,14 @@ def _get_jacobian_prim():
     """Create a primitive for Jacobian computations.
     This primitive is used when capturing ``qml.jacobian``.
     """
-    jacobian_prim = create_non_interpreted_prim()("jacobian")
+    if not has_jax:  # pragma: no cover
+        return None
+
+    from .custom_primitives import NonInterpPrimitive  # pylint: disable=import-outside-toplevel
+
+    jacobian_prim = NonInterpPrimitive("jacobian")
     jacobian_prim.multiple_results = True  # pylint: disable=attribute-defined-outside-init
+    jacobian_prim.prim_type = "higher_order"
 
     # pylint: disable=too-many-arguments
     @jacobian_prim.def_impl
diff --git a/pennylane/capture/capture_measurements.py b/pennylane/capture/capture_measurements.py
index 59bf5490679..e23457bc7b7 100644
--- a/pennylane/capture/capture_measurements.py
+++ b/pennylane/capture/capture_measurements.py
@@ -128,7 +128,10 @@ def create_measurement_obs_primitive(
     if not has_jax:
         return None
 
-    primitive = jax.core.Primitive(name + "_obs")
+    from .custom_primitives import NonInterpPrimitive  # pylint: disable=import-outside-toplevel
+
+    primitive = NonInterpPrimitive(name + "_obs")
+    primitive.prim_type = "measurement"
 
     @primitive.def_impl
     def _(obs, **kwargs):
@@ -165,7 +168,10 @@ def create_measurement_mcm_primitive(
     if not has_jax:
         return None
 
-    primitive = jax.core.Primitive(name + "_mcm")
+    from .custom_primitives import NonInterpPrimitive  # pylint: disable=import-outside-toplevel
+
+    primitive = NonInterpPrimitive(name + "_mcm")
+    primitive.prim_type = "measurement"
 
     @primitive.def_impl
     def _(*mcms, single_mcm=True, **kwargs):
@@ -200,7 +206,10 @@ def create_measurement_wires_primitive(
     if not has_jax:
         return None
 
-    primitive = jax.core.Primitive(name + "_wires")
+    from .custom_primitives import NonInterpPrimitive  # pylint: disable=import-outside-toplevel
+
+    primitive = NonInterpPrimitive(name + "_wires")
+    primitive.prim_type = "measurement"
 
     @primitive.def_impl
     def _(*args, has_eigvals=False, **kwargs):
diff --git a/pennylane/capture/capture_operators.py b/pennylane/capture/capture_operators.py
index 23c98f38944..2124b5b9fe4 100644
--- a/pennylane/capture/capture_operators.py
+++ b/pennylane/capture/capture_operators.py
@@ -20,8 +20,6 @@
 
 import pennylane as qml
 
-from .capture_diff import create_non_interpreted_prim
-
 has_jax = True
 try:
     import jax
@@ -103,7 +101,10 @@ def create_operator_primitive(
     if not has_jax:
         return None
 
-    primitive = create_non_interpreted_prim()(operator_type.__name__)
+    from .custom_primitives import NonInterpPrimitive  # pylint: disable=import-outside-toplevel
+
+    primitive = NonInterpPrimitive(operator_type.__name__)
+    primitive.prim_type = "operator"
 
     @primitive.def_impl
     def _(*args, **kwargs):
diff --git a/pennylane/capture/custom_primitives.py b/pennylane/capture/custom_primitives.py
new file mode 100644
index 00000000000..183ae05771b
--- /dev/null
+++ b/pennylane/capture/custom_primitives.py
@@ -0,0 +1,64 @@
+# Copyright 2024 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This submodule offers custom primitives for the PennyLane capture module.
+"""
+from enum import Enum
+from typing import Union
+
+import jax
+
+
+class PrimitiveType(Enum):
+    """Enum to define valid set of primitive classes"""
+
+    DEFAULT = "default"
+    OPERATOR = "operator"
+    MEASUREMENT = "measurement"
+    HIGHER_ORDER = "higher_order"
+    TRANSFORM = "transform"
+
+
+# pylint: disable=too-few-public-methods,abstract-method
+class QmlPrimitive(jax.core.Primitive):
+    """A subclass for JAX's Primitive that differentiates between different
+    classes of primitives."""
+
+    _prim_type: PrimitiveType = PrimitiveType.DEFAULT
+
+    @property
+    def prim_type(self):
+        """Value of Enum representing the primitive type to differentiate between various
+        sets of PennyLane primitives."""
+        return self._prim_type.value
+
+    @prim_type.setter
+    def prim_type(self, value: Union[str, PrimitiveType]):
+        """Setter for QmlPrimitive.prim_type."""
+        self._prim_type = PrimitiveType(value)
+
+
+# pylint: disable=too-few-public-methods,abstract-method
+class NonInterpPrimitive(QmlPrimitive):
+    """A subclass to JAX's Primitive that works like a Python function
+    when evaluating JVPTracers and BatchTracers."""
+
+    def bind_with_trace(self, trace, args, params):
+        """Bind the ``NonInterpPrimitive`` with a trace.
+
+        If the trace is a ``JVPTrace``or a ``BatchTrace``, binding falls back to a standard Python function call.
+        Otherwise, the bind call of JAX's standard Primitive is used."""
+        if isinstance(trace, (jax.interpreters.ad.JVPTrace, jax.interpreters.batching.BatchTrace)):
+            return self.impl(*args, **params)
+        return super().bind_with_trace(trace, args, params)
diff --git a/pennylane/compiler/qjit_api.py b/pennylane/compiler/qjit_api.py
index 08d88988b79..797a7437abb 100644
--- a/pennylane/compiler/qjit_api.py
+++ b/pennylane/compiler/qjit_api.py
@@ -17,7 +17,6 @@
 from collections.abc import Callable
 
 import pennylane as qml
-from pennylane.capture.capture_diff import create_non_interpreted_prim
 from pennylane.capture.flatfn import FlatFn
 
 from .compiler import (
@@ -405,10 +404,14 @@ def _decorator(body_fn: Callable) -> Callable:
 def _get_while_loop_qfunc_prim():
     """Get the while_loop primitive for quantum functions."""
 
-    import jax  # pylint: disable=import-outside-toplevel
+    # pylint: disable=import-outside-toplevel
+    import jax
 
-    while_loop_prim = create_non_interpreted_prim()("while_loop")
+    from pennylane.capture.custom_primitives import NonInterpPrimitive
+
+    while_loop_prim = NonInterpPrimitive("while_loop")
     while_loop_prim.multiple_results = True
+    while_loop_prim.prim_type = "higher_order"
 
     @while_loop_prim.def_impl
     def _(*args, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_slice):
@@ -626,10 +629,14 @@ def _decorator(body_fn):
 def _get_for_loop_qfunc_prim():
     """Get the loop_for primitive for quantum functions."""
 
-    import jax  # pylint: disable=import-outside-toplevel
+    # pylint: disable=import-outside-toplevel
+    import jax
+
+    from pennylane.capture.custom_primitives import NonInterpPrimitive
 
-    for_loop_prim = create_non_interpreted_prim()("for_loop")
+    for_loop_prim = NonInterpPrimitive("for_loop")
     for_loop_prim.multiple_results = True
+    for_loop_prim.prim_type = "higher_order"
 
     # pylint: disable=too-many-arguments
     @for_loop_prim.def_impl
diff --git a/pennylane/measurements/mid_measure.py b/pennylane/measurements/mid_measure.py
index 5cdcd8cd708..3c9bdc8f1a8 100644
--- a/pennylane/measurements/mid_measure.py
+++ b/pennylane/measurements/mid_measure.py
@@ -243,9 +243,12 @@ def _create_mid_measure_primitive():
         measurement.
 
     """
-    import jax  # pylint: disable=import-outside-toplevel
+    # pylint: disable=import-outside-toplevel
+    import jax
 
-    mid_measure_p = jax.core.Primitive("measure")
+    from pennylane.capture.custom_primitives import NonInterpPrimitive
+
+    mid_measure_p = NonInterpPrimitive("measure")
 
     @mid_measure_p.def_impl
     def _(wires, reset=False, postselect=None):
diff --git a/pennylane/ops/op_math/adjoint.py b/pennylane/ops/op_math/adjoint.py
index 400f2fc83c0..5bda04440d4 100644
--- a/pennylane/ops/op_math/adjoint.py
+++ b/pennylane/ops/op_math/adjoint.py
@@ -18,7 +18,6 @@
 from typing import Callable, overload
 
 import pennylane as qml
-from pennylane.capture.capture_diff import create_non_interpreted_prim
 from pennylane.compiler import compiler
 from pennylane.math import conj, moveaxis, transpose
 from pennylane.operation import Observable, Operation, Operator
@@ -190,10 +189,14 @@ def create_adjoint_op(fn, lazy):
 def _get_adjoint_qfunc_prim():
     """See capture/explanations.md : Higher Order primitives for more information on this code."""
     # if capture is enabled, jax should be installed
-    import jax  # pylint: disable=import-outside-toplevel
+    # pylint: disable=import-outside-toplevel
+    import jax
+
+    from pennylane.capture.custom_primitives import NonInterpPrimitive
 
-    adjoint_prim = create_non_interpreted_prim()("adjoint_transform")
+    adjoint_prim = NonInterpPrimitive("adjoint_transform")
     adjoint_prim.multiple_results = True
+    adjoint_prim.prim_type = "higher_order"
 
     @adjoint_prim.def_impl
     def _(*args, jaxpr, lazy, n_consts):
diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py
index deace92e73c..a15fdafff1d 100644
--- a/pennylane/ops/op_math/condition.py
+++ b/pennylane/ops/op_math/condition.py
@@ -20,7 +20,6 @@
 
 import pennylane as qml
 from pennylane import QueuingManager
-from pennylane.capture.capture_diff import create_non_interpreted_prim
 from pennylane.capture.flatfn import FlatFn
 from pennylane.compiler import compiler
 from pennylane.measurements import MeasurementValue
@@ -681,10 +680,14 @@ def _get_mcm_predicates(conditions: tuple[MeasurementValue]) -> list[Measurement
 def _get_cond_qfunc_prim():
     """Get the cond primitive for quantum functions."""
 
-    import jax  # pylint: disable=import-outside-toplevel
+    # pylint: disable=import-outside-toplevel
+    import jax
 
-    cond_prim = create_non_interpreted_prim()("cond")
+    from pennylane.capture.custom_primitives import NonInterpPrimitive
+
+    cond_prim = NonInterpPrimitive("cond")
     cond_prim.multiple_results = True
+    cond_prim.prim_type = "higher_order"
 
     @cond_prim.def_impl
     def _(*all_args, jaxpr_branches, consts_slices, args_slice):
diff --git a/pennylane/ops/op_math/controlled.py b/pennylane/ops/op_math/controlled.py
index d49209660c6..17e62323223 100644
--- a/pennylane/ops/op_math/controlled.py
+++ b/pennylane/ops/op_math/controlled.py
@@ -28,7 +28,6 @@
 import pennylane as qml
 from pennylane import math as qmlmath
 from pennylane import operation
-from pennylane.capture.capture_diff import create_non_interpreted_prim
 from pennylane.compiler import compiler
 from pennylane.operation import Operator
 from pennylane.wires import Wires, WiresLike
@@ -233,10 +232,15 @@ def wrapper(*args, **kwargs):
 def _get_ctrl_qfunc_prim():
     """See capture/explanations.md : Higher Order primitives for more information on this code."""
     # if capture is enabled, jax should be installed
-    import jax  # pylint: disable=import-outside-toplevel
 
-    ctrl_prim = create_non_interpreted_prim()("ctrl_transform")
+    # pylint: disable=import-outside-toplevel
+    import jax
+
+    from pennylane.capture.custom_primitives import NonInterpPrimitive
+
+    ctrl_prim = NonInterpPrimitive("ctrl_transform")
     ctrl_prim.multiple_results = True
+    ctrl_prim.prim_type = "higher_order"
 
     @ctrl_prim.def_impl
     def _(*args, n_control, jaxpr, control_values, work_wires, n_consts):
diff --git a/pennylane/transforms/core/transform_dispatcher.py b/pennylane/transforms/core/transform_dispatcher.py
index e00bda09c8d..1cefb724cd2 100644
--- a/pennylane/transforms/core/transform_dispatcher.py
+++ b/pennylane/transforms/core/transform_dispatcher.py
@@ -540,12 +540,13 @@ def final_transform(self):
 def _create_transform_primitive(name):
     try:
         # pylint: disable=import-outside-toplevel
-        import jax
+        from pennylane.capture.custom_primitives import NonInterpPrimitive
     except ImportError:
         return None
 
-    transform_prim = jax.core.Primitive(name + "_transform")
+    transform_prim = NonInterpPrimitive(name + "_transform")
     transform_prim.multiple_results = True
+    transform_prim.prim_type = "transform"
 
     @transform_prim.def_impl
     def _(
diff --git a/pennylane/transforms/optimization/cancel_inverses.py b/pennylane/transforms/optimization/cancel_inverses.py
index 418e68941a6..85dc0320fb7 100644
--- a/pennylane/transforms/optimization/cancel_inverses.py
+++ b/pennylane/transforms/optimization/cancel_inverses.py
@@ -70,7 +70,7 @@ def _get_plxpr_cancel_inverses():  # pylint: disable=missing-function-docstring,
         # pylint: disable=import-outside-toplevel
         from jax import make_jaxpr
 
-        from pennylane.capture import AbstractMeasurement, AbstractOperator, PlxprInterpreter
+        from pennylane.capture import PlxprInterpreter
         from pennylane.operation import Operator
     except ImportError:  # pragma: no cover
         return None, None
@@ -204,15 +204,15 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list:
                     self.interpret_all_previous_ops()
                     invals = [self.read(invar) for invar in eqn.invars]
                     outvals = custom_handler(self, *invals, **eqn.params)
-                elif len(eqn.outvars) > 0 and isinstance(eqn.outvars[0].aval, AbstractOperator):
+                elif getattr(eqn.primitive, "prim_type", "") == "operator":
                     outvals = self.interpret_operation_eqn(eqn)
-                elif len(eqn.outvars) > 0 and isinstance(eqn.outvars[0].aval, AbstractMeasurement):
+                elif getattr(eqn.primitive, "prim_type", "") == "measurement":
                     self.interpret_all_previous_ops()
                     outvals = self.interpret_measurement_eqn(eqn)
                 else:
                     # Transform primitives don't have custom handlers, so we check for them here
                     # to purge the stored ops in self.previous_ops
-                    if eqn.primitive.name.endswith("_transform"):
+                    if getattr(eqn.primitive, "prim_type", "") == "transform":
                         self.interpret_all_previous_ops()
                     invals = [self.read(invar) for invar in eqn.invars]
                     outvals = eqn.primitive.bind(*invals, **eqn.params)
diff --git a/pennylane/workflow/_capture_qnode.py b/pennylane/workflow/_capture_qnode.py
index e56b5dd0583..d4b23f2cb6a 100644
--- a/pennylane/workflow/_capture_qnode.py
+++ b/pennylane/workflow/_capture_qnode.py
@@ -117,6 +117,7 @@
 
 import pennylane as qml
 from pennylane.capture import FlatFn
+from pennylane.capture.custom_primitives import QmlPrimitive
 from pennylane.typing import TensorLike
 
 
@@ -177,8 +178,9 @@ def _get_shapes_for(*measurements, shots=None, num_device_wires=0, batch_shape=(
     return shapes
 
 
-qnode_prim = jax.core.Primitive("qnode")
+qnode_prim = QmlPrimitive("qnode")
 qnode_prim.multiple_results = True
+qnode_prim.prim_type = "higher_order"
 
 
 # pylint: disable=too-many-arguments, unused-argument
diff --git a/tests/capture/test_custom_primitives.py b/tests/capture/test_custom_primitives.py
new file mode 100644
index 00000000000..3d35e3e57e4
--- /dev/null
+++ b/tests/capture/test_custom_primitives.py
@@ -0,0 +1,48 @@
+# Copyright 2024 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Unit tests for PennyLane custom primitives.
+"""
+# pylint: disable=wrong-import-position
+import pytest
+
+jax = pytest.importorskip("jax")
+
+from pennylane.capture.custom_primitives import PrimitiveType, QmlPrimitive
+
+pytestmark = pytest.mark.jax
+
+
+def test_qml_primitive_prim_type_default():
+    """Test that the default prim_type of a QmlPrimitive is set correctly."""
+    prim = QmlPrimitive("primitive")
+    assert prim._prim_type == PrimitiveType("default")  # pylint: disable=protected-access
+    assert prim.prim_type == "default"
+
+
+@pytest.mark.parametrize("cast_in_enum", [True, False])
+@pytest.mark.parametrize("prim_type", ["operator", "measurement", "transform", "higher_order"])
+def test_qml_primitive_prim_type_setter(prim_type, cast_in_enum):
+    """Test that the QmlPrimitive.prim_type setter works correctly"""
+    prim = QmlPrimitive("primitive")
+    prim.prim_type = PrimitiveType(prim_type) if cast_in_enum else prim_type
+    assert prim._prim_type == PrimitiveType(prim_type)  # pylint: disable=protected-access
+    assert prim.prim_type == prim_type
+
+
+def test_qml_primitive_prim_type_setter_invalid():
+    """Test that setting an invalid prim_type raises an error"""
+    prim = QmlPrimitive("primitive")
+    with pytest.raises(ValueError, match="not a valid PrimitiveType"):
+        prim.prim_type = "blah"
diff --git a/tests/capture/test_switches.py b/tests/capture/test_switches.py
index 52f50321740..72a170205e7 100644
--- a/tests/capture/test_switches.py
+++ b/tests/capture/test_switches.py
@@ -32,10 +32,15 @@ def test_switches_with_jax():
 
 def test_switches_without_jax():
     """Test switches and status reporting function."""
-
-    assert qml.capture.enabled() is False
-    with pytest.raises(ImportError, match="plxpr requires JAX to be installed."):
-        qml.capture.enable()
-    assert qml.capture.enabled() is False
-    assert qml.capture.disable() is None
-    assert qml.capture.enabled() is False
+    # We want to skip the test if jax is installed
+    try:
+        # pylint: disable=import-outside-toplevel, unused-import
+        import jax
+    except ImportError:
+
+        assert qml.capture.enabled() is False
+        with pytest.raises(ImportError, match="plxpr requires JAX to be installed."):
+            qml.capture.enable()
+        assert qml.capture.enabled() is False
+        assert qml.capture.disable() is None
+        assert qml.capture.enabled() is False