From 3e1521bdef05235d915ddb8273ea177fcf62d755 Mon Sep 17 00:00:00 2001 From: Mudit Pandey <mudit.pandey@xanadu.ai> Date: Wed, 22 Jan 2025 11:09:07 -0500 Subject: [PATCH] [Capture] Add a `QmlPrimitive` class to differentiate between different types of primitives (#6847) This PR adds a `QmlPrimitive` subclass of `jax.core.Primitive`. This class contains a `prim_type` property set using a new `PrimitiveType` enum. `PrimitiveType`s currently available are "default", "operator", "measurement", "transform", and "higher_order". This can be made more or less fine grained as needed, but should be enough to differentiate between different types of primitives for now. Additionally, this PR: * updates `NonInterpPrimitive` to be a subclass of `QmlPrimitive` * updates all existing PennyLane primitives to be either `QmlPrimitive` or `NonInterpPrimitive`. See [this comment](https://github.com/PennyLaneAI/pennylane/pull/6851#discussion_r1922462699) to see the logic used to determine which `Primitive` subclass is used for each primitive. * updates `PlxprInterpreter.eval` and `CancelInversesInterpreter.eval` to use this `prim_type` property. [sc-82420] --------- Co-authored-by: Pietropaolo Frisoni <pietropaolo.frisoni@xanadu.ai> --- doc/releases/changelog-dev.md | 6 ++ pennylane/capture/base_interpreter.py | 13 ++-- pennylane/capture/capture_diff.py | 41 ++++-------- pennylane/capture/capture_measurements.py | 15 ++++- pennylane/capture/capture_operators.py | 7 +- pennylane/capture/custom_primitives.py | 64 +++++++++++++++++++ pennylane/compiler/qjit_api.py | 17 +++-- pennylane/measurements/mid_measure.py | 7 +- pennylane/ops/op_math/adjoint.py | 9 ++- pennylane/ops/op_math/condition.py | 9 ++- pennylane/ops/op_math/controlled.py | 10 ++- .../transforms/core/transform_dispatcher.py | 5 +- .../optimization/cancel_inverses.py | 8 +-- pennylane/workflow/_capture_qnode.py | 4 +- tests/capture/test_custom_primitives.py | 48 ++++++++++++++ tests/capture/test_switches.py | 19 ++++-- 16 files changed, 209 insertions(+), 73 deletions(-) create mode 100644 pennylane/capture/custom_primitives.py create mode 100644 tests/capture/test_custom_primitives.py diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 6b08ff74363..4eb878d6478 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -90,6 +90,11 @@ <h3>Internal changes ⚙️</h3> +* Added a `QmlPrimitive` class that inherits `jax.core.Primitive` to a new `qml.capture.custom_primitives` module. + This class contains a `prim_type` property so that we can differentiate between different sets of PennyLane primitives. + Consequently, `QmlPrimitive` is now used to define all PennyLane primitives. + [(#6847)](https://github.com/PennyLaneAI/pennylane/pull/6847) + <h3>Documentation 📝</h3> * The docstrings for `qml.unary_mapping`, `qml.binary_mapping`, `qml.christiansen_mapping`, @@ -115,4 +120,5 @@ Diksha Dhawan, Pietropaolo Frisoni, Marcus Gisslén, Christina Lee, +Mudit Pandey, Andrija Paurevic diff --git a/pennylane/capture/base_interpreter.py b/pennylane/capture/base_interpreter.py index 4af11cb6198..2b7314f7c2e 100644 --- a/pennylane/capture/base_interpreter.py +++ b/pennylane/capture/base_interpreter.py @@ -25,8 +25,6 @@ from .flatfn import FlatFn from .primitives import ( - AbstractMeasurement, - AbstractOperator, adjoint_transform_prim, cond_prim, ctrl_transform_prim, @@ -311,20 +309,21 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list: self._env[constvar] = const for eqn in jaxpr.eqns: + primitive = eqn.primitive + custom_handler = self._primitive_registrations.get(primitive, None) - custom_handler = self._primitive_registrations.get(eqn.primitive, None) if custom_handler: invals = [self.read(invar) for invar in eqn.invars] outvals = custom_handler(self, *invals, **eqn.params) - elif isinstance(eqn.outvars[0].aval, AbstractOperator): + elif getattr(primitive, "prim_type", "") == "operator": outvals = self.interpret_operation_eqn(eqn) - elif isinstance(eqn.outvars[0].aval, AbstractMeasurement): + elif getattr(primitive, "prim_type", "") == "measurement": outvals = self.interpret_measurement_eqn(eqn) else: invals = [self.read(invar) for invar in eqn.invars] - outvals = eqn.primitive.bind(*invals, **eqn.params) + outvals = primitive.bind(*invals, **eqn.params) - if not eqn.primitive.multiple_results: + if not primitive.multiple_results: outvals = [outvals] for outvar, outval in zip(eqn.outvars, outvals, strict=True): self._env[outvar] = outval diff --git a/pennylane/capture/capture_diff.py b/pennylane/capture/capture_diff.py index 482f692df69..ba7c3846693 100644 --- a/pennylane/capture/capture_diff.py +++ b/pennylane/capture/capture_diff.py @@ -24,34 +24,6 @@ has_jax = False -@lru_cache -def create_non_interpreted_prim(): - """Create a primitive type ``NonInterpPrimitive``, which binds to JAX's JVPTrace - and BatchTrace objects like a standard Python function and otherwise behaves like jax.core.Primitive. - """ - - if not has_jax: # pragma: no cover - return None - - # pylint: disable=too-few-public-methods - class NonInterpPrimitive(jax.core.Primitive): - """A subclass to JAX's Primitive that works like a Python function - when evaluating JVPTracers and BatchTracers.""" - - def bind_with_trace(self, trace, args, params): - """Bind the ``NonInterpPrimitive`` with a trace. - - If the trace is a ``JVPTrace``or a ``BatchTrace``, binding falls back to a standard Python function call. - Otherwise, the bind call of JAX's standard Primitive is used.""" - if isinstance( - trace, (jax.interpreters.ad.JVPTrace, jax.interpreters.batching.BatchTrace) - ): - return self.impl(*args, **params) - return super().bind_with_trace(trace, args, params) - - return NonInterpPrimitive - - @lru_cache def _get_grad_prim(): """Create a primitive for gradient computations. @@ -60,8 +32,11 @@ def _get_grad_prim(): if not has_jax: # pragma: no cover return None - grad_prim = create_non_interpreted_prim()("grad") + from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel + + grad_prim = NonInterpPrimitive("grad") grad_prim.multiple_results = True # pylint: disable=attribute-defined-outside-init + grad_prim.prim_type = "higher_order" # pylint: disable=too-many-arguments @grad_prim.def_impl @@ -91,8 +66,14 @@ def _get_jacobian_prim(): """Create a primitive for Jacobian computations. This primitive is used when capturing ``qml.jacobian``. """ - jacobian_prim = create_non_interpreted_prim()("jacobian") + if not has_jax: # pragma: no cover + return None + + from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel + + jacobian_prim = NonInterpPrimitive("jacobian") jacobian_prim.multiple_results = True # pylint: disable=attribute-defined-outside-init + jacobian_prim.prim_type = "higher_order" # pylint: disable=too-many-arguments @jacobian_prim.def_impl diff --git a/pennylane/capture/capture_measurements.py b/pennylane/capture/capture_measurements.py index 59bf5490679..e23457bc7b7 100644 --- a/pennylane/capture/capture_measurements.py +++ b/pennylane/capture/capture_measurements.py @@ -128,7 +128,10 @@ def create_measurement_obs_primitive( if not has_jax: return None - primitive = jax.core.Primitive(name + "_obs") + from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel + + primitive = NonInterpPrimitive(name + "_obs") + primitive.prim_type = "measurement" @primitive.def_impl def _(obs, **kwargs): @@ -165,7 +168,10 @@ def create_measurement_mcm_primitive( if not has_jax: return None - primitive = jax.core.Primitive(name + "_mcm") + from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel + + primitive = NonInterpPrimitive(name + "_mcm") + primitive.prim_type = "measurement" @primitive.def_impl def _(*mcms, single_mcm=True, **kwargs): @@ -200,7 +206,10 @@ def create_measurement_wires_primitive( if not has_jax: return None - primitive = jax.core.Primitive(name + "_wires") + from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel + + primitive = NonInterpPrimitive(name + "_wires") + primitive.prim_type = "measurement" @primitive.def_impl def _(*args, has_eigvals=False, **kwargs): diff --git a/pennylane/capture/capture_operators.py b/pennylane/capture/capture_operators.py index 23c98f38944..2124b5b9fe4 100644 --- a/pennylane/capture/capture_operators.py +++ b/pennylane/capture/capture_operators.py @@ -20,8 +20,6 @@ import pennylane as qml -from .capture_diff import create_non_interpreted_prim - has_jax = True try: import jax @@ -103,7 +101,10 @@ def create_operator_primitive( if not has_jax: return None - primitive = create_non_interpreted_prim()(operator_type.__name__) + from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel + + primitive = NonInterpPrimitive(operator_type.__name__) + primitive.prim_type = "operator" @primitive.def_impl def _(*args, **kwargs): diff --git a/pennylane/capture/custom_primitives.py b/pennylane/capture/custom_primitives.py new file mode 100644 index 00000000000..183ae05771b --- /dev/null +++ b/pennylane/capture/custom_primitives.py @@ -0,0 +1,64 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This submodule offers custom primitives for the PennyLane capture module. +""" +from enum import Enum +from typing import Union + +import jax + + +class PrimitiveType(Enum): + """Enum to define valid set of primitive classes""" + + DEFAULT = "default" + OPERATOR = "operator" + MEASUREMENT = "measurement" + HIGHER_ORDER = "higher_order" + TRANSFORM = "transform" + + +# pylint: disable=too-few-public-methods,abstract-method +class QmlPrimitive(jax.core.Primitive): + """A subclass for JAX's Primitive that differentiates between different + classes of primitives.""" + + _prim_type: PrimitiveType = PrimitiveType.DEFAULT + + @property + def prim_type(self): + """Value of Enum representing the primitive type to differentiate between various + sets of PennyLane primitives.""" + return self._prim_type.value + + @prim_type.setter + def prim_type(self, value: Union[str, PrimitiveType]): + """Setter for QmlPrimitive.prim_type.""" + self._prim_type = PrimitiveType(value) + + +# pylint: disable=too-few-public-methods,abstract-method +class NonInterpPrimitive(QmlPrimitive): + """A subclass to JAX's Primitive that works like a Python function + when evaluating JVPTracers and BatchTracers.""" + + def bind_with_trace(self, trace, args, params): + """Bind the ``NonInterpPrimitive`` with a trace. + + If the trace is a ``JVPTrace``or a ``BatchTrace``, binding falls back to a standard Python function call. + Otherwise, the bind call of JAX's standard Primitive is used.""" + if isinstance(trace, (jax.interpreters.ad.JVPTrace, jax.interpreters.batching.BatchTrace)): + return self.impl(*args, **params) + return super().bind_with_trace(trace, args, params) diff --git a/pennylane/compiler/qjit_api.py b/pennylane/compiler/qjit_api.py index 08d88988b79..797a7437abb 100644 --- a/pennylane/compiler/qjit_api.py +++ b/pennylane/compiler/qjit_api.py @@ -17,7 +17,6 @@ from collections.abc import Callable import pennylane as qml -from pennylane.capture.capture_diff import create_non_interpreted_prim from pennylane.capture.flatfn import FlatFn from .compiler import ( @@ -405,10 +404,14 @@ def _decorator(body_fn: Callable) -> Callable: def _get_while_loop_qfunc_prim(): """Get the while_loop primitive for quantum functions.""" - import jax # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + import jax - while_loop_prim = create_non_interpreted_prim()("while_loop") + from pennylane.capture.custom_primitives import NonInterpPrimitive + + while_loop_prim = NonInterpPrimitive("while_loop") while_loop_prim.multiple_results = True + while_loop_prim.prim_type = "higher_order" @while_loop_prim.def_impl def _(*args, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_slice): @@ -626,10 +629,14 @@ def _decorator(body_fn): def _get_for_loop_qfunc_prim(): """Get the loop_for primitive for quantum functions.""" - import jax # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + import jax + + from pennylane.capture.custom_primitives import NonInterpPrimitive - for_loop_prim = create_non_interpreted_prim()("for_loop") + for_loop_prim = NonInterpPrimitive("for_loop") for_loop_prim.multiple_results = True + for_loop_prim.prim_type = "higher_order" # pylint: disable=too-many-arguments @for_loop_prim.def_impl diff --git a/pennylane/measurements/mid_measure.py b/pennylane/measurements/mid_measure.py index 5cdcd8cd708..3c9bdc8f1a8 100644 --- a/pennylane/measurements/mid_measure.py +++ b/pennylane/measurements/mid_measure.py @@ -243,9 +243,12 @@ def _create_mid_measure_primitive(): measurement. """ - import jax # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + import jax - mid_measure_p = jax.core.Primitive("measure") + from pennylane.capture.custom_primitives import NonInterpPrimitive + + mid_measure_p = NonInterpPrimitive("measure") @mid_measure_p.def_impl def _(wires, reset=False, postselect=None): diff --git a/pennylane/ops/op_math/adjoint.py b/pennylane/ops/op_math/adjoint.py index 400f2fc83c0..5bda04440d4 100644 --- a/pennylane/ops/op_math/adjoint.py +++ b/pennylane/ops/op_math/adjoint.py @@ -18,7 +18,6 @@ from typing import Callable, overload import pennylane as qml -from pennylane.capture.capture_diff import create_non_interpreted_prim from pennylane.compiler import compiler from pennylane.math import conj, moveaxis, transpose from pennylane.operation import Observable, Operation, Operator @@ -190,10 +189,14 @@ def create_adjoint_op(fn, lazy): def _get_adjoint_qfunc_prim(): """See capture/explanations.md : Higher Order primitives for more information on this code.""" # if capture is enabled, jax should be installed - import jax # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + import jax + + from pennylane.capture.custom_primitives import NonInterpPrimitive - adjoint_prim = create_non_interpreted_prim()("adjoint_transform") + adjoint_prim = NonInterpPrimitive("adjoint_transform") adjoint_prim.multiple_results = True + adjoint_prim.prim_type = "higher_order" @adjoint_prim.def_impl def _(*args, jaxpr, lazy, n_consts): diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index deace92e73c..a15fdafff1d 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -20,7 +20,6 @@ import pennylane as qml from pennylane import QueuingManager -from pennylane.capture.capture_diff import create_non_interpreted_prim from pennylane.capture.flatfn import FlatFn from pennylane.compiler import compiler from pennylane.measurements import MeasurementValue @@ -681,10 +680,14 @@ def _get_mcm_predicates(conditions: tuple[MeasurementValue]) -> list[Measurement def _get_cond_qfunc_prim(): """Get the cond primitive for quantum functions.""" - import jax # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + import jax - cond_prim = create_non_interpreted_prim()("cond") + from pennylane.capture.custom_primitives import NonInterpPrimitive + + cond_prim = NonInterpPrimitive("cond") cond_prim.multiple_results = True + cond_prim.prim_type = "higher_order" @cond_prim.def_impl def _(*all_args, jaxpr_branches, consts_slices, args_slice): diff --git a/pennylane/ops/op_math/controlled.py b/pennylane/ops/op_math/controlled.py index d49209660c6..17e62323223 100644 --- a/pennylane/ops/op_math/controlled.py +++ b/pennylane/ops/op_math/controlled.py @@ -28,7 +28,6 @@ import pennylane as qml from pennylane import math as qmlmath from pennylane import operation -from pennylane.capture.capture_diff import create_non_interpreted_prim from pennylane.compiler import compiler from pennylane.operation import Operator from pennylane.wires import Wires, WiresLike @@ -233,10 +232,15 @@ def wrapper(*args, **kwargs): def _get_ctrl_qfunc_prim(): """See capture/explanations.md : Higher Order primitives for more information on this code.""" # if capture is enabled, jax should be installed - import jax # pylint: disable=import-outside-toplevel - ctrl_prim = create_non_interpreted_prim()("ctrl_transform") + # pylint: disable=import-outside-toplevel + import jax + + from pennylane.capture.custom_primitives import NonInterpPrimitive + + ctrl_prim = NonInterpPrimitive("ctrl_transform") ctrl_prim.multiple_results = True + ctrl_prim.prim_type = "higher_order" @ctrl_prim.def_impl def _(*args, n_control, jaxpr, control_values, work_wires, n_consts): diff --git a/pennylane/transforms/core/transform_dispatcher.py b/pennylane/transforms/core/transform_dispatcher.py index e00bda09c8d..1cefb724cd2 100644 --- a/pennylane/transforms/core/transform_dispatcher.py +++ b/pennylane/transforms/core/transform_dispatcher.py @@ -540,12 +540,13 @@ def final_transform(self): def _create_transform_primitive(name): try: # pylint: disable=import-outside-toplevel - import jax + from pennylane.capture.custom_primitives import NonInterpPrimitive except ImportError: return None - transform_prim = jax.core.Primitive(name + "_transform") + transform_prim = NonInterpPrimitive(name + "_transform") transform_prim.multiple_results = True + transform_prim.prim_type = "transform" @transform_prim.def_impl def _( diff --git a/pennylane/transforms/optimization/cancel_inverses.py b/pennylane/transforms/optimization/cancel_inverses.py index 418e68941a6..85dc0320fb7 100644 --- a/pennylane/transforms/optimization/cancel_inverses.py +++ b/pennylane/transforms/optimization/cancel_inverses.py @@ -70,7 +70,7 @@ def _get_plxpr_cancel_inverses(): # pylint: disable=missing-function-docstring, # pylint: disable=import-outside-toplevel from jax import make_jaxpr - from pennylane.capture import AbstractMeasurement, AbstractOperator, PlxprInterpreter + from pennylane.capture import PlxprInterpreter from pennylane.operation import Operator except ImportError: # pragma: no cover return None, None @@ -204,15 +204,15 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: self.interpret_all_previous_ops() invals = [self.read(invar) for invar in eqn.invars] outvals = custom_handler(self, *invals, **eqn.params) - elif len(eqn.outvars) > 0 and isinstance(eqn.outvars[0].aval, AbstractOperator): + elif getattr(eqn.primitive, "prim_type", "") == "operator": outvals = self.interpret_operation_eqn(eqn) - elif len(eqn.outvars) > 0 and isinstance(eqn.outvars[0].aval, AbstractMeasurement): + elif getattr(eqn.primitive, "prim_type", "") == "measurement": self.interpret_all_previous_ops() outvals = self.interpret_measurement_eqn(eqn) else: # Transform primitives don't have custom handlers, so we check for them here # to purge the stored ops in self.previous_ops - if eqn.primitive.name.endswith("_transform"): + if getattr(eqn.primitive, "prim_type", "") == "transform": self.interpret_all_previous_ops() invals = [self.read(invar) for invar in eqn.invars] outvals = eqn.primitive.bind(*invals, **eqn.params) diff --git a/pennylane/workflow/_capture_qnode.py b/pennylane/workflow/_capture_qnode.py index e56b5dd0583..d4b23f2cb6a 100644 --- a/pennylane/workflow/_capture_qnode.py +++ b/pennylane/workflow/_capture_qnode.py @@ -117,6 +117,7 @@ import pennylane as qml from pennylane.capture import FlatFn +from pennylane.capture.custom_primitives import QmlPrimitive from pennylane.typing import TensorLike @@ -177,8 +178,9 @@ def _get_shapes_for(*measurements, shots=None, num_device_wires=0, batch_shape=( return shapes -qnode_prim = jax.core.Primitive("qnode") +qnode_prim = QmlPrimitive("qnode") qnode_prim.multiple_results = True +qnode_prim.prim_type = "higher_order" # pylint: disable=too-many-arguments, unused-argument diff --git a/tests/capture/test_custom_primitives.py b/tests/capture/test_custom_primitives.py new file mode 100644 index 00000000000..3d35e3e57e4 --- /dev/null +++ b/tests/capture/test_custom_primitives.py @@ -0,0 +1,48 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Unit tests for PennyLane custom primitives. +""" +# pylint: disable=wrong-import-position +import pytest + +jax = pytest.importorskip("jax") + +from pennylane.capture.custom_primitives import PrimitiveType, QmlPrimitive + +pytestmark = pytest.mark.jax + + +def test_qml_primitive_prim_type_default(): + """Test that the default prim_type of a QmlPrimitive is set correctly.""" + prim = QmlPrimitive("primitive") + assert prim._prim_type == PrimitiveType("default") # pylint: disable=protected-access + assert prim.prim_type == "default" + + +@pytest.mark.parametrize("cast_in_enum", [True, False]) +@pytest.mark.parametrize("prim_type", ["operator", "measurement", "transform", "higher_order"]) +def test_qml_primitive_prim_type_setter(prim_type, cast_in_enum): + """Test that the QmlPrimitive.prim_type setter works correctly""" + prim = QmlPrimitive("primitive") + prim.prim_type = PrimitiveType(prim_type) if cast_in_enum else prim_type + assert prim._prim_type == PrimitiveType(prim_type) # pylint: disable=protected-access + assert prim.prim_type == prim_type + + +def test_qml_primitive_prim_type_setter_invalid(): + """Test that setting an invalid prim_type raises an error""" + prim = QmlPrimitive("primitive") + with pytest.raises(ValueError, match="not a valid PrimitiveType"): + prim.prim_type = "blah" diff --git a/tests/capture/test_switches.py b/tests/capture/test_switches.py index 52f50321740..72a170205e7 100644 --- a/tests/capture/test_switches.py +++ b/tests/capture/test_switches.py @@ -32,10 +32,15 @@ def test_switches_with_jax(): def test_switches_without_jax(): """Test switches and status reporting function.""" - - assert qml.capture.enabled() is False - with pytest.raises(ImportError, match="plxpr requires JAX to be installed."): - qml.capture.enable() - assert qml.capture.enabled() is False - assert qml.capture.disable() is None - assert qml.capture.enabled() is False + # We want to skip the test if jax is installed + try: + # pylint: disable=import-outside-toplevel, unused-import + import jax + except ImportError: + + assert qml.capture.enabled() is False + with pytest.raises(ImportError, match="plxpr requires JAX to be installed."): + qml.capture.enable() + assert qml.capture.enabled() is False + assert qml.capture.disable() is None + assert qml.capture.enabled() is False