diff --git a/doc/releases/changelog-0.40.0.md b/doc/releases/changelog-0.40.0.md
index b381ea18ed0..e3769326619 100644
--- a/doc/releases/changelog-0.40.0.md
+++ b/doc/releases/changelog-0.40.0.md
@@ -625,6 +625,9 @@ same information.
Bug fixes 🐛
+* Adds validation so the device vjp is only used when the device actually supports it.
+ [(#6755)](https://github.com/PennyLaneAI/pennylane/pull/6755/)
+
* `qml.counts` returns all outcomes when the `all_outcomes` argument is `True` and mid-circuit measurements are present.
[(#6732)](https://github.com/PennyLaneAI/pennylane/pull/6732)
diff --git a/pennylane/workflow/resolution.py b/pennylane/workflow/resolution.py
index 502a9810ed6..77d85aae9b3 100644
--- a/pennylane/workflow/resolution.py
+++ b/pennylane/workflow/resolution.py
@@ -86,9 +86,7 @@ def _use_tensorflow_autograph():
return not tf.executing_eagerly()
-def _resolve_interface(
- interface: Union[str, Interface, None], tapes: QuantumScriptBatch
-) -> Interface:
+def _resolve_interface(interface: Union[str, Interface], tapes: QuantumScriptBatch) -> Interface:
"""Helper function to resolve an interface based on a set of tapes.
Args:
@@ -249,22 +247,24 @@ def _resolve_execution_config(
):
updated_values["grad_on_execution"] = False
- if execution_config.use_device_jacobian_product and isinstance(
- device, qml.devices.LegacyDeviceFacade
- ):
- raise qml.QuantumFunctionError(
- "device provided jacobian products are not compatible with the old device interface."
- )
-
if (
"lightning" in device.name
- and (transform_program and qml.metric_tensor in transform_program)
+ and transform_program
+ and qml.metric_tensor in transform_program
and execution_config.gradient_method == "best"
):
execution_config = replace(execution_config, gradient_method=qml.gradients.param_shift)
-
execution_config = _resolve_diff_method(execution_config, device, tape=tapes[0])
+ if execution_config.use_device_jacobian_product and not device.supports_vjp(
+ execution_config, tapes[0]
+ ):
+ raise qml.QuantumFunctionError(
+ f"device_vjp=True is not supported for device {device},"
+ f" diff_method {execution_config.gradient_method},"
+ " and the provided circuit."
+ )
+
if execution_config.gradient_method is qml.gradients.param_shift_cv:
updated_values["gradient_keyword_arguments"]["dev"] = device
diff --git a/tests/test_qnode.py b/tests/test_qnode.py
index 22f16a45ff9..5ac318333a3 100644
--- a/tests/test_qnode.py
+++ b/tests/test_qnode.py
@@ -873,6 +873,7 @@ def circuit(x, y):
assert np.allclose(res, expected, atol=tol, rtol=0)
+ # pylint: disable=too-many-positional-arguments
@pytest.mark.parametrize("dev_name", ["default.qubit", "default.mixed"])
@pytest.mark.parametrize("first_par", np.linspace(0.15, np.pi - 0.3, 3))
@pytest.mark.parametrize("sec_par", np.linspace(0.15, np.pi - 0.3, 3))
@@ -1170,6 +1171,27 @@ def decomposition(self) -> list:
res = qml.execute([tape], dev)
assert qml.math.get_interface(res) == "numpy"
+ def test_error_device_vjp_unsuppoprted(self):
+ """Test that an error is raised in the device_vjp is unsupported."""
+
+ class DummyDev(qml.devices.Device):
+
+ def execute(self, circuits, execution_config=qml.devices.ExecutionConfig()):
+ return 0
+
+ def supports_derivatives(self, execution_config=None, circuit=None):
+ return execution_config and execution_config.gradient_method == "vjp_grad"
+
+ def supports_vjp(self, execution_config=None, circuit=None) -> bool:
+ return execution_config and execution_config.gradient_method == "vjp_grad"
+
+ @qml.qnode(DummyDev(), diff_method="parameter-shift", device_vjp=True)
+ def circuit():
+ return qml.expval(qml.Z(0))
+
+ with pytest.raises(qml.QuantumFunctionError, match="device_vjp=True is not supported"):
+ circuit()
+
class TestShots:
"""Unit tests for specifying shots per call."""
diff --git a/tests/workflow/test_resolve_execution_config.py b/tests/workflow/test_resolve_execution_config.py
index 87a67c8292e..96e7cb9f595 100644
--- a/tests/workflow/test_resolve_execution_config.py
+++ b/tests/workflow/test_resolve_execution_config.py
@@ -116,3 +116,30 @@ def test_jax_jit_interface():
expected_mcm_config = MCMConfig(mcm_method="deferred", postselect_mode="fill-shots")
assert resolved_config.mcm_config == expected_mcm_config
+
+
+# pylint: disable=unused-argument
+def test_no_device_vjp_if_not_supported():
+ """Test that an error is raised for device_vjp=True if the device does not support it."""
+
+ class DummyDev(qml.devices.Device):
+
+ def execute(self, circuits, execution_config=qml.devices.ExecutionConfig()):
+ return 0
+
+ def supports_derivatives(self, execution_config=None, circuit=None):
+ return execution_config and execution_config.gradient_method == "vjp_grad"
+
+ def supports_vjp(self, execution_config=None, circuit=None) -> bool:
+ return execution_config and execution_config.gradient_method == "vjp_grad"
+
+ config_vjp_grad = ExecutionConfig(use_device_jacobian_product=True, gradient_method="vjp_grad")
+ tape = qml.tape.QuantumScript()
+ # no error
+ _ = _resolve_execution_config(config_vjp_grad, DummyDev(), (tape,))
+
+ config_parameter_shift = ExecutionConfig(
+ use_device_jacobian_product=True, gradient_method="parameter-shift"
+ )
+ with pytest.raises(qml.QuantumFunctionError, match="device_vjp=True is not supported"):
+ _resolve_execution_config(config_parameter_shift, DummyDev(), (tape,))