Skip to content

Commit

Permalink
Daily rc sync to master (#6771)
Browse files Browse the repository at this point in the history
Co-authored-by: PietropaoloFrisoni <[email protected]>
Co-authored-by: Christina Lee <[email protected]>
Co-authored-by: GitHub Actions Bot <>
  • Loading branch information
3 people authored Jan 7, 2025
1 parent f2662b6 commit 8bdf289
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 12 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-0.40.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,9 @@ same information.

<h3>Bug fixes 🐛</h3>

* Adds validation so the device vjp is only used when the device actually supports it.
[(#6755)](https://github.com/PennyLaneAI/pennylane/pull/6755/)

* `qml.counts` returns all outcomes when the `all_outcomes` argument is `True` and mid-circuit measurements are present.
[(#6732)](https://github.com/PennyLaneAI/pennylane/pull/6732)

Expand Down
24 changes: 12 additions & 12 deletions pennylane/workflow/resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,7 @@ def _use_tensorflow_autograph():
return not tf.executing_eagerly()


def _resolve_interface(
interface: Union[str, Interface, None], tapes: QuantumScriptBatch
) -> Interface:
def _resolve_interface(interface: Union[str, Interface], tapes: QuantumScriptBatch) -> Interface:
"""Helper function to resolve an interface based on a set of tapes.
Args:
Expand Down Expand Up @@ -249,22 +247,24 @@ def _resolve_execution_config(
):
updated_values["grad_on_execution"] = False

if execution_config.use_device_jacobian_product and isinstance(
device, qml.devices.LegacyDeviceFacade
):
raise qml.QuantumFunctionError(
"device provided jacobian products are not compatible with the old device interface."
)

if (
"lightning" in device.name
and (transform_program and qml.metric_tensor in transform_program)
and transform_program
and qml.metric_tensor in transform_program
and execution_config.gradient_method == "best"
):
execution_config = replace(execution_config, gradient_method=qml.gradients.param_shift)

execution_config = _resolve_diff_method(execution_config, device, tape=tapes[0])

if execution_config.use_device_jacobian_product and not device.supports_vjp(
execution_config, tapes[0]
):
raise qml.QuantumFunctionError(
f"device_vjp=True is not supported for device {device},"
f" diff_method {execution_config.gradient_method},"
" and the provided circuit."
)

if execution_config.gradient_method is qml.gradients.param_shift_cv:
updated_values["gradient_keyword_arguments"]["dev"] = device

Expand Down
22 changes: 22 additions & 0 deletions tests/test_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,7 @@ def circuit(x, y):

assert np.allclose(res, expected, atol=tol, rtol=0)

# pylint: disable=too-many-positional-arguments
@pytest.mark.parametrize("dev_name", ["default.qubit", "default.mixed"])
@pytest.mark.parametrize("first_par", np.linspace(0.15, np.pi - 0.3, 3))
@pytest.mark.parametrize("sec_par", np.linspace(0.15, np.pi - 0.3, 3))
Expand Down Expand Up @@ -1170,6 +1171,27 @@ def decomposition(self) -> list:
res = qml.execute([tape], dev)
assert qml.math.get_interface(res) == "numpy"

def test_error_device_vjp_unsuppoprted(self):
"""Test that an error is raised in the device_vjp is unsupported."""

class DummyDev(qml.devices.Device):

def execute(self, circuits, execution_config=qml.devices.ExecutionConfig()):
return 0

def supports_derivatives(self, execution_config=None, circuit=None):
return execution_config and execution_config.gradient_method == "vjp_grad"

def supports_vjp(self, execution_config=None, circuit=None) -> bool:
return execution_config and execution_config.gradient_method == "vjp_grad"

@qml.qnode(DummyDev(), diff_method="parameter-shift", device_vjp=True)
def circuit():
return qml.expval(qml.Z(0))

with pytest.raises(qml.QuantumFunctionError, match="device_vjp=True is not supported"):
circuit()


class TestShots:
"""Unit tests for specifying shots per call."""
Expand Down
27 changes: 27 additions & 0 deletions tests/workflow/test_resolve_execution_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,30 @@ def test_jax_jit_interface():
expected_mcm_config = MCMConfig(mcm_method="deferred", postselect_mode="fill-shots")

assert resolved_config.mcm_config == expected_mcm_config


# pylint: disable=unused-argument
def test_no_device_vjp_if_not_supported():
"""Test that an error is raised for device_vjp=True if the device does not support it."""

class DummyDev(qml.devices.Device):

def execute(self, circuits, execution_config=qml.devices.ExecutionConfig()):
return 0

def supports_derivatives(self, execution_config=None, circuit=None):
return execution_config and execution_config.gradient_method == "vjp_grad"

def supports_vjp(self, execution_config=None, circuit=None) -> bool:
return execution_config and execution_config.gradient_method == "vjp_grad"

config_vjp_grad = ExecutionConfig(use_device_jacobian_product=True, gradient_method="vjp_grad")
tape = qml.tape.QuantumScript()
# no error
_ = _resolve_execution_config(config_vjp_grad, DummyDev(), (tape,))

config_parameter_shift = ExecutionConfig(
use_device_jacobian_product=True, gradient_method="parameter-shift"
)
with pytest.raises(qml.QuantumFunctionError, match="device_vjp=True is not supported"):
_resolve_execution_config(config_parameter_shift, DummyDev(), (tape,))

0 comments on commit 8bdf289

Please sign in to comment.