diff --git a/doc/releases/changelog-0.40.0.md b/doc/releases/changelog-0.40.0.md index b381ea18ed0..e3769326619 100644 --- a/doc/releases/changelog-0.40.0.md +++ b/doc/releases/changelog-0.40.0.md @@ -625,6 +625,9 @@ same information.

Bug fixes 🐛

+* Adds validation so the device vjp is only used when the device actually supports it. + [(#6755)](https://github.com/PennyLaneAI/pennylane/pull/6755/) + * `qml.counts` returns all outcomes when the `all_outcomes` argument is `True` and mid-circuit measurements are present. [(#6732)](https://github.com/PennyLaneAI/pennylane/pull/6732) diff --git a/pennylane/workflow/resolution.py b/pennylane/workflow/resolution.py index 502a9810ed6..77d85aae9b3 100644 --- a/pennylane/workflow/resolution.py +++ b/pennylane/workflow/resolution.py @@ -86,9 +86,7 @@ def _use_tensorflow_autograph(): return not tf.executing_eagerly() -def _resolve_interface( - interface: Union[str, Interface, None], tapes: QuantumScriptBatch -) -> Interface: +def _resolve_interface(interface: Union[str, Interface], tapes: QuantumScriptBatch) -> Interface: """Helper function to resolve an interface based on a set of tapes. Args: @@ -249,22 +247,24 @@ def _resolve_execution_config( ): updated_values["grad_on_execution"] = False - if execution_config.use_device_jacobian_product and isinstance( - device, qml.devices.LegacyDeviceFacade - ): - raise qml.QuantumFunctionError( - "device provided jacobian products are not compatible with the old device interface." - ) - if ( "lightning" in device.name - and (transform_program and qml.metric_tensor in transform_program) + and transform_program + and qml.metric_tensor in transform_program and execution_config.gradient_method == "best" ): execution_config = replace(execution_config, gradient_method=qml.gradients.param_shift) - execution_config = _resolve_diff_method(execution_config, device, tape=tapes[0]) + if execution_config.use_device_jacobian_product and not device.supports_vjp( + execution_config, tapes[0] + ): + raise qml.QuantumFunctionError( + f"device_vjp=True is not supported for device {device}," + f" diff_method {execution_config.gradient_method}," + " and the provided circuit." + ) + if execution_config.gradient_method is qml.gradients.param_shift_cv: updated_values["gradient_keyword_arguments"]["dev"] = device diff --git a/tests/test_qnode.py b/tests/test_qnode.py index 22f16a45ff9..5ac318333a3 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -873,6 +873,7 @@ def circuit(x, y): assert np.allclose(res, expected, atol=tol, rtol=0) + # pylint: disable=too-many-positional-arguments @pytest.mark.parametrize("dev_name", ["default.qubit", "default.mixed"]) @pytest.mark.parametrize("first_par", np.linspace(0.15, np.pi - 0.3, 3)) @pytest.mark.parametrize("sec_par", np.linspace(0.15, np.pi - 0.3, 3)) @@ -1170,6 +1171,27 @@ def decomposition(self) -> list: res = qml.execute([tape], dev) assert qml.math.get_interface(res) == "numpy" + def test_error_device_vjp_unsuppoprted(self): + """Test that an error is raised in the device_vjp is unsupported.""" + + class DummyDev(qml.devices.Device): + + def execute(self, circuits, execution_config=qml.devices.ExecutionConfig()): + return 0 + + def supports_derivatives(self, execution_config=None, circuit=None): + return execution_config and execution_config.gradient_method == "vjp_grad" + + def supports_vjp(self, execution_config=None, circuit=None) -> bool: + return execution_config and execution_config.gradient_method == "vjp_grad" + + @qml.qnode(DummyDev(), diff_method="parameter-shift", device_vjp=True) + def circuit(): + return qml.expval(qml.Z(0)) + + with pytest.raises(qml.QuantumFunctionError, match="device_vjp=True is not supported"): + circuit() + class TestShots: """Unit tests for specifying shots per call.""" diff --git a/tests/workflow/test_resolve_execution_config.py b/tests/workflow/test_resolve_execution_config.py index 87a67c8292e..96e7cb9f595 100644 --- a/tests/workflow/test_resolve_execution_config.py +++ b/tests/workflow/test_resolve_execution_config.py @@ -116,3 +116,30 @@ def test_jax_jit_interface(): expected_mcm_config = MCMConfig(mcm_method="deferred", postselect_mode="fill-shots") assert resolved_config.mcm_config == expected_mcm_config + + +# pylint: disable=unused-argument +def test_no_device_vjp_if_not_supported(): + """Test that an error is raised for device_vjp=True if the device does not support it.""" + + class DummyDev(qml.devices.Device): + + def execute(self, circuits, execution_config=qml.devices.ExecutionConfig()): + return 0 + + def supports_derivatives(self, execution_config=None, circuit=None): + return execution_config and execution_config.gradient_method == "vjp_grad" + + def supports_vjp(self, execution_config=None, circuit=None) -> bool: + return execution_config and execution_config.gradient_method == "vjp_grad" + + config_vjp_grad = ExecutionConfig(use_device_jacobian_product=True, gradient_method="vjp_grad") + tape = qml.tape.QuantumScript() + # no error + _ = _resolve_execution_config(config_vjp_grad, DummyDev(), (tape,)) + + config_parameter_shift = ExecutionConfig( + use_device_jacobian_product=True, gradient_method="parameter-shift" + ) + with pytest.raises(qml.QuantumFunctionError, match="device_vjp=True is not supported"): + _resolve_execution_config(config_parameter_shift, DummyDev(), (tape,))