From 09294307eab5e8544c6410fa291413959d7b3f79 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 3 Jan 2025 17:19:36 -0500 Subject: [PATCH 1/7] add devicE_vjp validation --- doc/releases/changelog-dev.md | 2 ++ pennylane/workflow/resolution.py | 32 ++++++++----------- tests/test_qnode.py | 19 +++++++++++ .../workflow/test_resolve_execution_config.py | 23 +++++++++++++ 4 files changed, 58 insertions(+), 18 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 71ad0fe840f..e760e0ba119 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -625,6 +625,8 @@ same information.

Bug fixes 🐛

+* Adds validation so the device vjp is only used when the device actually supports it. + * `qml.counts` returns all outcomes when the `all_outcomes` argument is `True` and mid-circuit measurements are present. [(#6732)](https://github.com/PennyLaneAI/pennylane/pull/6732) diff --git a/pennylane/workflow/resolution.py b/pennylane/workflow/resolution.py index 502a9810ed6..f269bbe7d68 100644 --- a/pennylane/workflow/resolution.py +++ b/pennylane/workflow/resolution.py @@ -14,7 +14,7 @@ """This module contains the necessary helper functions for setting up the workflow for execution. """ -from collections.abc import Callable + from dataclasses import replace from typing import Literal, Optional, Union, get_args @@ -86,9 +86,7 @@ def _use_tensorflow_autograph(): return not tf.executing_eagerly() -def _resolve_interface( - interface: Union[str, Interface, None], tapes: QuantumScriptBatch -) -> Interface: +def _resolve_interface(interface: Union[str, Interface], tapes: QuantumScriptBatch) -> Interface: """Helper function to resolve an interface based on a set of tapes. Args: @@ -244,26 +242,24 @@ def _resolve_execution_config( updated_values = {} updated_values["gradient_keyword_arguments"] = dict(execution_config.gradient_keyword_arguments) - if execution_config.interface in {Interface.JAX, Interface.JAX_JIT} and not isinstance( - execution_config.gradient_method, Callable - ): - updated_values["grad_on_execution"] = False - - if execution_config.use_device_jacobian_product and isinstance( - device, qml.devices.LegacyDeviceFacade - ): - raise qml.QuantumFunctionError( - "device provided jacobian products are not compatible with the old device interface." - ) - if ( "lightning" in device.name - and (transform_program and qml.metric_tensor in transform_program) + and transform_program + and qml.metric_tensor in transform_program and execution_config.gradient_method == "best" ): execution_config = replace(execution_config, gradient_method=qml.gradients.param_shift) + else: + execution_config = _resolve_diff_method(execution_config, device, tape=tapes[0]) - execution_config = _resolve_diff_method(execution_config, device, tape=tapes[0]) + if execution_config.use_device_jacobian_product and not device.supports_vjp( + execution_config, tapes[0] + ): + raise qml.QuantumFunctionError( + f"device_vjp=True is not supported for device {device}," + f" diff_method {execution_config.gradient_method}," + " and the provided circuit." + ) if execution_config.gradient_method is qml.gradients.param_shift_cv: updated_values["gradient_keyword_arguments"]["dev"] = device diff --git a/tests/test_qnode.py b/tests/test_qnode.py index 22f16a45ff9..cd51a20ab45 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -873,6 +873,7 @@ def circuit(x, y): assert np.allclose(res, expected, atol=tol, rtol=0) + # pylint: disable=too-many-positional-arguments @pytest.mark.parametrize("dev_name", ["default.qubit", "default.mixed"]) @pytest.mark.parametrize("first_par", np.linspace(0.15, np.pi - 0.3, 3)) @pytest.mark.parametrize("sec_par", np.linspace(0.15, np.pi - 0.3, 3)) @@ -1170,6 +1171,24 @@ def decomposition(self) -> list: res = qml.execute([tape], dev) assert qml.math.get_interface(res) == "numpy" + def test_error_device_vjp_unsuppoprted(self): + """Test that an error is raised in the device_vjp is unsupported.""" + + class DummyDev(qml.devices.Device): + + def execute(self, circuits, execution_config=qml.devices.ExecutionConfig()): + return 0 + + def supports_vjp(self, execution_config=None, circuit=None) -> bool: + return execution_config and execution_config.gradient_method == "vjp_grad" + + @qml.qnode(DummyDev(), diff_method="parameter-shift", device_vjp=True) + def circuit(): + return qml.expval(qml.Z(0)) + + with pytest.raises(qml.QuantumFunctionError, match="device_vjp=True is not supported"): + circuit() + class TestShots: """Unit tests for specifying shots per call.""" diff --git a/tests/workflow/test_resolve_execution_config.py b/tests/workflow/test_resolve_execution_config.py index 87a67c8292e..3d18bb5e5a1 100644 --- a/tests/workflow/test_resolve_execution_config.py +++ b/tests/workflow/test_resolve_execution_config.py @@ -116,3 +116,26 @@ def test_jax_jit_interface(): expected_mcm_config = MCMConfig(mcm_method="deferred", postselect_mode="fill-shots") assert resolved_config.mcm_config == expected_mcm_config + + +def test_no_device_vjp_if_not_supported(): + """Test that an error is raised for device_vjp=True if the device does not support it.""" + + class DummyDev(qml.devices.Device): + + def execute(self, circuits, execution_config=qml.devices.ExecutionConfig()): + return 0 + + def supports_vjp(self, execution_config=None, circuit=None) -> bool: + return execution_config and execution_config.gradient_method == "vjp_grad" + + config_vjp_grad = ExecutionConfig(use_device_jacobian_product=True, gradient_method="vjp_grad") + tape = qml.tape.QuantumScript() + # no error + _ = _resolve_execution_config(config_vjp_grad, DummyDev(), (tape,)) + + config_parameter_shift = ExecutionConfig( + use_device_jacobian_product=True, gradient_method="parameter-shift" + ) + with pytest.raises(qml.QuantumFunctionError, match="device_vjp=True is not supported"): + _resolve_execution_config(config_parameter_shift, DummyDev(), (tape,)) From 76ccc18f0ad19f3bfd910e6398d6ae7fbd21be4a Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Fri, 3 Jan 2025 17:20:43 -0500 Subject: [PATCH 2/7] Update doc/releases/changelog-dev.md --- doc/releases/changelog-dev.md | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index e760e0ba119..5d0e5ea7690 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -626,6 +626,7 @@ same information.

Bug fixes 🐛

* Adds validation so the device vjp is only used when the device actually supports it. + [(#6755)](https://github.com/PennyLaneAI/pennylane/pull/6755/) * `qml.counts` returns all outcomes when the `all_outcomes` argument is `True` and mid-circuit measurements are present. [(#6732)](https://github.com/PennyLaneAI/pennylane/pull/6732) From ad67fa191b18229064689437d1979d59ccfe8e57 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 3 Jan 2025 17:24:59 -0500 Subject: [PATCH 3/7] merging issues? --- pennylane/workflow/resolution.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pennylane/workflow/resolution.py b/pennylane/workflow/resolution.py index f269bbe7d68..32db341383b 100644 --- a/pennylane/workflow/resolution.py +++ b/pennylane/workflow/resolution.py @@ -16,7 +16,7 @@ """ from dataclasses import replace -from typing import Literal, Optional, Union, get_args +from typing import Callable, Literal, Optional, Union, get_args import pennylane as qml from pennylane.logging import debug_logger @@ -242,6 +242,11 @@ def _resolve_execution_config( updated_values = {} updated_values["gradient_keyword_arguments"] = dict(execution_config.gradient_keyword_arguments) + if execution_config.interface in {Interface.JAX, Interface.JAX_JIT} and not isinstance( + execution_config.gradient_method, Callable + ): + updated_values["grad_on_execution"] = False + if ( "lightning" in device.name and transform_program From ffa6acb5d28dc44da2fe3265fab131726e311651 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Fri, 3 Jan 2025 17:26:37 -0500 Subject: [PATCH 4/7] Apply suggestions from code review --- pennylane/workflow/resolution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pennylane/workflow/resolution.py b/pennylane/workflow/resolution.py index 32db341383b..05261e4e85c 100644 --- a/pennylane/workflow/resolution.py +++ b/pennylane/workflow/resolution.py @@ -14,9 +14,9 @@ """This module contains the necessary helper functions for setting up the workflow for execution. """ - +from collections.abc import Callable from dataclasses import replace -from typing import Callable, Literal, Optional, Union, get_args +from typing import Literal, Optional, Union, get_args import pennylane as qml from pennylane.logging import debug_logger From 9f0faff772951d80464dbeb74c8b9fae96898993 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Fri, 3 Jan 2025 17:27:42 -0500 Subject: [PATCH 5/7] Update pennylane/workflow/resolution.py --- pennylane/workflow/resolution.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pennylane/workflow/resolution.py b/pennylane/workflow/resolution.py index 05261e4e85c..77d85aae9b3 100644 --- a/pennylane/workflow/resolution.py +++ b/pennylane/workflow/resolution.py @@ -254,8 +254,7 @@ def _resolve_execution_config( and execution_config.gradient_method == "best" ): execution_config = replace(execution_config, gradient_method=qml.gradients.param_shift) - else: - execution_config = _resolve_diff_method(execution_config, device, tape=tapes[0]) + execution_config = _resolve_diff_method(execution_config, device, tape=tapes[0]) if execution_config.use_device_jacobian_product and not device.supports_vjp( execution_config, tapes[0] From 46b1fed08cdd357810bcde2088ddc1b5538d5f84 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 3 Jan 2025 17:45:40 -0500 Subject: [PATCH 6/7] oops --- tests/test_qnode.py | 3 +++ tests/workflow/test_resolve_execution_config.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/tests/test_qnode.py b/tests/test_qnode.py index cd51a20ab45..5ac318333a3 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -1179,6 +1179,9 @@ class DummyDev(qml.devices.Device): def execute(self, circuits, execution_config=qml.devices.ExecutionConfig()): return 0 + def supports_derivatives(self, execution_config=None, circuit=None): + return execution_config and execution_config.gradient_method == "vjp_grad" + def supports_vjp(self, execution_config=None, circuit=None) -> bool: return execution_config and execution_config.gradient_method == "vjp_grad" diff --git a/tests/workflow/test_resolve_execution_config.py b/tests/workflow/test_resolve_execution_config.py index 3d18bb5e5a1..4ea8e05556d 100644 --- a/tests/workflow/test_resolve_execution_config.py +++ b/tests/workflow/test_resolve_execution_config.py @@ -126,6 +126,9 @@ class DummyDev(qml.devices.Device): def execute(self, circuits, execution_config=qml.devices.ExecutionConfig()): return 0 + def supports_derivatives(self, execution_config=None, circuit=None): + return execution_config and execution_config.gradient_method == "vjp_grad" + def supports_vjp(self, execution_config=None, circuit=None) -> bool: return execution_config and execution_config.gradient_method == "vjp_grad" From ca8fc66e1f721d8b38882828a108b009772294b3 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Mon, 6 Jan 2025 09:25:17 -0500 Subject: [PATCH 7/7] Update tests/workflow/test_resolve_execution_config.py --- tests/workflow/test_resolve_execution_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/workflow/test_resolve_execution_config.py b/tests/workflow/test_resolve_execution_config.py index 4ea8e05556d..96e7cb9f595 100644 --- a/tests/workflow/test_resolve_execution_config.py +++ b/tests/workflow/test_resolve_execution_config.py @@ -118,6 +118,7 @@ def test_jax_jit_interface(): assert resolved_config.mcm_config == expected_mcm_config +# pylint: disable=unused-argument def test_no_device_vjp_if_not_supported(): """Test that an error is raised for device_vjp=True if the device does not support it."""