From 721acb263c67ba7aa508257b214716f96858d377 Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Fri, 24 Jan 2025 13:18:39 -0500 Subject: [PATCH] Fix `InterfaceEnum` object (#6877) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Context:** `QNode.interface` returns a `str` type, ```python @property def interface(self) -> str: """The interface used by the QNode""" return self._interface.value ``` Therefore, this equality will never be true, ```python from pennylane.math.interface_utils import Interface >>> "jax" == Interface.JAX False ``` Moreover, `str` cannot be compared with `Enum` objects at all. 🤯 **Description of the Change:** - Adjust the `InterfaceEnum` class to have dunder methods to overwrite comparisons. This way it will error out if you try to compare it with a string. - Add test coverage for the private helpers in `qnode.py` Now we have, ```python from pennylane.math.interface_utils import Interface >>> "jax" == Interface.JAX TypeError: Cannot compare Interface with str ``` **Benefits:** Logic is correct. **Possible Drawbacks:** None identified. [sc-83034] --------- Co-authored-by: Yushao Chen (Jerry) --- doc/releases/changelog-dev.md | 3 ++ pennylane/math/interface_utils.py | 11 ++++++- pennylane/workflow/qnode.py | 4 +-- pennylane/workflow/resolution.py | 11 ++++--- tests/math/test_functions.py | 12 ++++++++ tests/test_qnode.py | 50 +++++++++++++++++++++++++++++++ 6 files changed, 82 insertions(+), 9 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 88280392227..f5c4c9620da 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -103,6 +103,9 @@

Internal changes ⚙️

+* Improved the `InterfaceEnum` object to prevent direct comparisons to `str` objects. + [(#6877)](https://github.com/PennyLaneAI/pennylane/pull/6877) + * Added a `QmlPrimitive` class that inherits `jax.core.Primitive` to a new `qml.capture.custom_primitives` module. This class contains a `prim_type` property so that we can differentiate between different sets of PennyLane primitives. Consequently, `QmlPrimitive` is now used to define all PennyLane primitives. diff --git a/pennylane/math/interface_utils.py b/pennylane/math/interface_utils.py index 4fc401b9fd6..df6eface81d 100644 --- a/pennylane/math/interface_utils.py +++ b/pennylane/math/interface_utils.py @@ -46,6 +46,15 @@ def get_like(self): } return mapping[self] + def __eq__(self, interface): + if isinstance(interface, str): + raise TypeError("Cannot compare Interface with str") + return super().__eq__(interface) + + def __hash__(self): + # pylint: disable=useless-super-delegation + return super().__hash__() + InterfaceLike = Union[str, Interface, None] @@ -225,7 +234,7 @@ def get_canonical_interface_name(user_input: InterfaceLike) -> Interface: Interface: canonical interface """ - if user_input in SUPPORTED_INTERFACE_NAMES: + if isinstance(user_input, Interface) and user_input in SUPPORTED_INTERFACE_NAMES: return user_input try: return INTERFACE_MAP[user_input] diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index ea1a9e33900..0929bc5bb8d 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -60,11 +60,11 @@ def _convert_to_interface(result, interface: Interface): def _make_execution_config( circuit: Optional["QNode"], diff_method=None, mcm_config=None ) -> "qml.devices.ExecutionConfig": - circuit_interface = getattr(circuit, "interface", Interface.NUMPY) + circuit_interface = getattr(circuit, "interface", Interface.NUMPY.value) execute_kwargs = getattr(circuit, "execute_kwargs", {}) gradient_kwargs = getattr(circuit, "gradient_kwargs", {}) grad_on_execution = execute_kwargs.get("grad_on_execution") - if circuit_interface == Interface.JAX: + if circuit_interface in {Interface.JAX.value, Interface.JAX_JIT.value}: grad_on_execution = False elif grad_on_execution == "best": grad_on_execution = None diff --git a/pennylane/workflow/resolution.py b/pennylane/workflow/resolution.py index 77d85aae9b3..6f9de5b0e82 100644 --- a/pennylane/workflow/resolution.py +++ b/pennylane/workflow/resolution.py @@ -96,7 +96,6 @@ def _resolve_interface(interface: Union[str, Interface], tapes: QuantumScriptBat Returns: Interface: resolved interface """ - interface = get_canonical_interface_name(interface) if interface == Interface.AUTO: @@ -104,11 +103,11 @@ def _resolve_interface(interface: Union[str, Interface], tapes: QuantumScriptBat for tape in tapes: params.extend(tape.get_parameters(trainable_only=False)) interface = get_interface(*params) - if interface != Interface.NUMPY: - try: - interface = get_canonical_interface_name(interface) - except ValueError: - interface = Interface.NUMPY + try: + interface = get_canonical_interface_name(interface) + except ValueError: + # If the interface is not recognized, default to numpy, like networkx + interface = Interface.NUMPY if interface == Interface.TF and _use_tensorflow_autograph(): interface = Interface.TF_AUTOGRAPH if interface == Interface.JAX: diff --git a/tests/math/test_functions.py b/tests/math/test_functions.py index 54990b137e8..bcf1510d19b 100644 --- a/tests/math/test_functions.py +++ b/tests/math/test_functions.py @@ -1024,6 +1024,18 @@ def test_get_interface(t, interface): assert res == interface +# pylint: disable=too-few-public-methods +class TestInterfaceEnum: + """Test the Interface enum class""" + + def test_eq(self): + """Test that an error is raised if comparing to string""" + assert fn.Interface.NUMPY == fn.Interface.NUMPY + with pytest.raises(TypeError, match="Cannot compare Interface with str"): + # pylint: disable=pointless-statement + fn.Interface.NUMPY == "numpy" + + @pytest.mark.parametrize("t", test_data) def test_toarray(t): """Test that the toarray method correctly converts the input diff --git a/tests/test_qnode.py b/tests/test_qnode.py index 5d270d1f9b8..68e91cd406b 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -29,6 +29,7 @@ from pennylane import qnode from pennylane.tape import QuantumScript, QuantumScriptBatch from pennylane.typing import PostprocessingFn +from pennylane.workflow.qnode import _make_execution_config def dummyfunc(): @@ -2073,3 +2074,52 @@ def circuit(x): circuit(qml.numpy.array(0.1)) assert circuit.interface == "auto" + + +class TestPrivateFunctions: + """Tests for private functions in the QNode class.""" + + def test_make_execution_config_with_no_qnode(self): + """Test that the _make_execution_config function correctly creates an execution config.""" + diff_method = "best" + mcm_config = qml.devices.MCMConfig(postselect_mode="fill-shots", mcm_method="deferred") + config = _make_execution_config(None, diff_method, mcm_config) + + expected_config = qml.devices.ExecutionConfig( + interface="numpy", + gradient_keyword_arguments={}, + use_device_jacobian_product=False, + grad_on_execution=None, + gradient_method=diff_method, + mcm_config=mcm_config, + ) + + assert config == expected_config + + @pytest.mark.parametrize("interface", ["autograd", "torch", "tf", "jax", "jax-jit"]) + def test_make_execution_config_with_qnode(self, interface): + """Test that a execution config is made correctly with no QNode.""" + if "jax" in interface: + grad_on_execution = False + else: + grad_on_execution = None + + @qml.qnode(qml.device("default.qubit"), interface=interface) + def circuit(): + qml.H(0) + return qml.probs() + + diff_method = "best" + mcm_config = qml.devices.MCMConfig(postselect_mode="fill-shots", mcm_method="deferred") + config = _make_execution_config(circuit, diff_method, mcm_config) + + expected_config = qml.devices.ExecutionConfig( + interface=interface, + gradient_keyword_arguments={}, + use_device_jacobian_product=False, + grad_on_execution=grad_on_execution, + gradient_method=diff_method, + mcm_config=mcm_config, + ) + + assert config == expected_config