Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add device_vjp validation #6755

Merged
merged 11 commits into from
Jan 6, 2025
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,9 @@ same information.

<h3>Bug fixes 🐛</h3>

* Adds validation so the device vjp is only used when the device actually supports it.
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
[(#6755)](https://github.com/PennyLaneAI/pennylane/pull/6755/)

* `qml.counts` returns all outcomes when the `all_outcomes` argument is `True` and mid-circuit measurements are present.
[(#6732)](https://github.com/PennyLaneAI/pennylane/pull/6732)

Expand Down
29 changes: 15 additions & 14 deletions pennylane/workflow/resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
"""This module contains the necessary helper functions for setting up the workflow for execution.

"""
from collections.abc import Callable

albi3ro marked this conversation as resolved.
Show resolved Hide resolved
from dataclasses import replace
from typing import Literal, Optional, Union, get_args
from typing import Callable, Literal, Optional, Union, get_args
albi3ro marked this conversation as resolved.
Show resolved Hide resolved

import pennylane as qml
from pennylane.logging import debug_logger
Expand Down Expand Up @@ -86,9 +86,7 @@ def _use_tensorflow_autograph():
return not tf.executing_eagerly()


def _resolve_interface(
interface: Union[str, Interface, None], tapes: QuantumScriptBatch
) -> Interface:
def _resolve_interface(interface: Union[str, Interface], tapes: QuantumScriptBatch) -> Interface:
"""Helper function to resolve an interface based on a set of tapes.

Args:
Expand Down Expand Up @@ -249,21 +247,24 @@ def _resolve_execution_config(
):
updated_values["grad_on_execution"] = False

if execution_config.use_device_jacobian_product and isinstance(
device, qml.devices.LegacyDeviceFacade
):
raise qml.QuantumFunctionError(
"device provided jacobian products are not compatible with the old device interface."
)

if (
"lightning" in device.name
and (transform_program and qml.metric_tensor in transform_program)
and transform_program
and qml.metric_tensor in transform_program
and execution_config.gradient_method == "best"
):
execution_config = replace(execution_config, gradient_method=qml.gradients.param_shift)
else:
execution_config = _resolve_diff_method(execution_config, device, tape=tapes[0])
albi3ro marked this conversation as resolved.
Show resolved Hide resolved

execution_config = _resolve_diff_method(execution_config, device, tape=tapes[0])
if execution_config.use_device_jacobian_product and not device.supports_vjp(
execution_config, tapes[0]
):
raise qml.QuantumFunctionError(
f"device_vjp=True is not supported for device {device},"
f" diff_method {execution_config.gradient_method},"
" and the provided circuit."
)

if execution_config.gradient_method is qml.gradients.param_shift_cv:
updated_values["gradient_keyword_arguments"]["dev"] = device
Expand Down
19 changes: 19 additions & 0 deletions tests/test_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,7 @@ def circuit(x, y):

assert np.allclose(res, expected, atol=tol, rtol=0)

# pylint: disable=too-many-positional-arguments
@pytest.mark.parametrize("dev_name", ["default.qubit", "default.mixed"])
@pytest.mark.parametrize("first_par", np.linspace(0.15, np.pi - 0.3, 3))
@pytest.mark.parametrize("sec_par", np.linspace(0.15, np.pi - 0.3, 3))
Expand Down Expand Up @@ -1170,6 +1171,24 @@ def decomposition(self) -> list:
res = qml.execute([tape], dev)
assert qml.math.get_interface(res) == "numpy"

def test_error_device_vjp_unsuppoprted(self):
"""Test that an error is raised in the device_vjp is unsupported."""

class DummyDev(qml.devices.Device):

def execute(self, circuits, execution_config=qml.devices.ExecutionConfig()):
return 0

def supports_vjp(self, execution_config=None, circuit=None) -> bool:
return execution_config and execution_config.gradient_method == "vjp_grad"

@qml.qnode(DummyDev(), diff_method="parameter-shift", device_vjp=True)
def circuit():
return qml.expval(qml.Z(0))

with pytest.raises(qml.QuantumFunctionError, match="device_vjp=True is not supported"):
circuit()


class TestShots:
"""Unit tests for specifying shots per call."""
Expand Down
23 changes: 23 additions & 0 deletions tests/workflow/test_resolve_execution_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,26 @@ def test_jax_jit_interface():
expected_mcm_config = MCMConfig(mcm_method="deferred", postselect_mode="fill-shots")

assert resolved_config.mcm_config == expected_mcm_config


def test_no_device_vjp_if_not_supported():
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
"""Test that an error is raised for device_vjp=True if the device does not support it."""

class DummyDev(qml.devices.Device):

def execute(self, circuits, execution_config=qml.devices.ExecutionConfig()):
return 0

def supports_vjp(self, execution_config=None, circuit=None) -> bool:
return execution_config and execution_config.gradient_method == "vjp_grad"

config_vjp_grad = ExecutionConfig(use_device_jacobian_product=True, gradient_method="vjp_grad")
tape = qml.tape.QuantumScript()
# no error
_ = _resolve_execution_config(config_vjp_grad, DummyDev(), (tape,))

config_parameter_shift = ExecutionConfig(
use_device_jacobian_product=True, gradient_method="parameter-shift"
)
with pytest.raises(qml.QuantumFunctionError, match="device_vjp=True is not supported"):
_resolve_execution_config(config_parameter_shift, DummyDev(), (tape,))
Loading