Allow device to configure conversion to numpy and use of `pure_callback` (#6788)

**Context:**

While we have logic for sampling with jax, it does not integrate well
into the workflow. You can technically set `diff_method=None` right now
and jit the execution end-to-end, but doing so causes incomprehensible
error messages on devices other than `default.qubit` (DQ).

We want to *forbid* differentiation with `diff_method=None`, but keep a
way to jit a finite-shot execution.

**Description of the Change:**

To allow jitting finite-shot executions, the device needs a way to
configure whether or not the data is converted to numpy. To do so, we
add another property to the `ExecutionConfig`, `convert_to_numpy`.
If `True`, we use a `pure_callback` and convert the parameters to numpy.
If `False`, we do not use a `pure_callback` and the parameters are not
converted.
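
The effect of the flag can be sketched without PennyLane installed. The snippet below is a minimal stand-in, not the library code: `SketchConfig` and `needs_pure_callback` are hypothetical names, and the interface is modelled as a plain string rather than the `Interface` enum.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class SketchConfig:
    """Hypothetical stand-in for two fields of ExecutionConfig."""

    interface: str = "numpy"
    convert_to_numpy: bool = True


def needs_pure_callback(config: SketchConfig) -> bool:
    """Return True when a jitted workflow has to leave the trace.

    Assumption: only the jax-jit interface combined with
    convert_to_numpy=True takes the pure_callback path.
    """
    return config.interface == "jax-jit" and config.convert_to_numpy


default = SketchConfig(interface="jax-jit")
# A device opts out of conversion by returning an updated config.
jittable = replace(default, convert_to_numpy=False)

print(needs_pure_callback(default))   # True
print(needs_pure_callback(jittable))  # False
```

Keeping the flag on the config, rather than on the device, lets the same device pick a different answer per execution.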

**Benefits:**

Speedups from being able to jit the entire execution.


![image](https://github.com/user-attachments/assets/738076c6-7bb5-4c38-a8cc-97e138325dbc)


**Possible Drawbacks:**

`ExecutionConfig` gets an additional property, making it more
complicated.

**Related GitHub Issues:**

Fixes #6054 Fixes #3259  Blocks #6770

---------

Co-authored-by: Mudit Pandey <[email protected]>
albi3ro and mudit2812 authored Jan 15, 2025
1 parent 75c1671 commit 0bed5e8
Showing 21 changed files with 122 additions and 51 deletions.
11 changes: 8 additions & 3 deletions doc/development/plugins.rst
Original file line number Diff line number Diff line change
@@ -472,12 +472,15 @@ pieces of functionality:
Note that these properties are only applicable to devices that provided derivatives or VJPs. If your device
does not provide derivatives, you can safely ignore these properties.

The workflow options are ``use_device_gradient``, ``use_device_jacobian_product``, and ``grad_on_execution``.
The workflow options are ``use_device_gradient``, ``use_device_jacobian_product``, ``grad_on_execution``,
and ``convert_to_numpy``.
``use_device_gradient=True`` indicates that workflow should request derivatives from the device.
``grad_on_execution=True`` indicates a preference to use ``execute_and_compute_derivatives`` instead
of ``execute`` followed by ``compute_derivatives``. Finally, ``use_device_jacobian_product`` indicates
of ``execute`` followed by ``compute_derivatives``. ``use_device_jacobian_product`` indicates
a request to call ``compute_vjp`` instead of ``compute_derivatives``. Note that if ``use_device_jacobian_product``
is ``True``, this takes precedence over calculating the full jacobian.
is ``True``, this takes precedence over calculating the full jacobian. If the device can accept ML framework parameters, like
jax, ``convert_to_numpy=False`` should be specified. Then the parameters will not be converted, and special
interface-specific processing (like executing inside a ``jax.pure_callback`` when using ``jax.jit``) will be needed.

>>> config = qml.devices.ExecutionConfig(gradient_method="adjoint")
>>> processed_config = qml.device('default.qubit').setup_execution_config(config)
@@ -487,6 +490,8 @@ True
True
>>> processed_config.grad_on_execution
True
>>> processed_config.convert_to_numpy
True

Execution
---------
6 changes: 6 additions & 0 deletions doc/releases/changelog-dev.md
@@ -6,6 +6,12 @@

<h3>Improvements 🛠</h3>

* Finite shot and parameter-shift executions on `default.qubit` can now
be natively jitted end-to-end, leading to performance improvements.
Devices can now configure whether or not ML framework data is sent to them
via an `ExecutionConfig.convert_to_numpy` parameter.
[(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788)

* The coefficients of observables now have improved differentiability.
[(#6598)](https://github.com/PennyLaneAI/pennylane/pull/6598)

4 changes: 2 additions & 2 deletions pennylane/devices/default_mixed.py
@@ -1008,8 +1008,8 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio
"best",
}
updated_values["grad_on_execution"] = False
if not execution_config.gradient_method in {"best", "backprop", None}:
execution_config.interface = None
if execution_config.gradient_method not in {"best", "backprop", None}:
updated_values["interface"] = None

# Add device options
updated_values["device_options"] = dict(execution_config.device_options) # copy
18 changes: 17 additions & 1 deletion pennylane/devices/default_qubit.py
@@ -161,6 +161,14 @@ def _conditional_broastcast_expand(tape):
return (tape,), null_postprocessing


@qml.transform
def no_counts(tape):
"""Throws an error on counts measurements."""
if any(isinstance(mp, qml.measurements.CountsMP) for mp in tape.measurements):
raise NotImplementedError("The JAX-JIT interface doesn't support qml.counts.")
return (tape,), null_postprocessing


@qml.transform
def adjoint_state_measurements(
tape: QuantumScript, device_vjp=False
@@ -535,6 +543,8 @@ def preprocess(
config = self._setup_execution_config(execution_config)
transform_program = TransformProgram()

if config.interface == qml.math.Interface.JAX_JIT:
transform_program.add_transform(no_counts)
transform_program.add_transform(validate_device_wires, self.wires, name=self.name)
transform_program.add_transform(
mid_circuit_measurements, device=self, mcm_config=config.mcm_config
@@ -581,6 +591,13 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio
"""
updated_values = {}

jax_interfaces = {qml.math.Interface.JAX, qml.math.Interface.JAX_JIT}
updated_values["convert_to_numpy"] = (
execution_config.interface not in jax_interfaces
or execution_config.gradient_method == "adjoint"
# need numpy to use caching, and need caching higher order derivatives
or execution_config.derivative_order > 1
)
for option in execution_config.device_options:
if option not in self._device_options:
raise qml.DeviceError(f"device option {option} not present on {self}")
@@ -616,7 +633,6 @@ def execute(
execution_config: ExecutionConfig = DefaultExecutionConfig,
) -> Union[Result, ResultBatch]:
self.reset_prng_key()

max_workers = execution_config.device_options.get("max_workers", self._max_workers)
self._state_cache = {} if execution_config.use_device_jacobian_product else None
interface = (
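
The `convert_to_numpy` rule added to `_setup_execution_config` above can be restated as a small standalone function. This is a sketch only: `sketch_convert_to_numpy` is a hypothetical name, and interfaces and gradient methods are plain strings here rather than the `Interface` enum and transform objects the device receives.

```python
JAX_INTERFACES = {"jax", "jax-jit"}


def sketch_convert_to_numpy(interface, gradient_method, derivative_order=1):
    """Restate the rule: stay in jax only when all three conditions allow it."""
    return (
        interface not in JAX_INTERFACES
        or gradient_method == "adjoint"
        # numpy is needed for caching, and caching for higher-order derivatives
        or derivative_order > 1
    )


print(sketch_convert_to_numpy("jax-jit", "parameter-shift"))  # False
print(sketch_convert_to_numpy("jax-jit", "adjoint"))          # True
print(sketch_convert_to_numpy("torch", "adjoint"))            # True
print(sketch_convert_to_numpy("jax", "parameter-shift", derivative_order=2))  # True
```

The printed cases correspond to the scenarios covered by the new tests in `test_default_qubit_preprocessing.py`, plus one higher-order case.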
13 changes: 10 additions & 3 deletions pennylane/devices/execution_config.py
@@ -17,7 +17,7 @@
from dataclasses import dataclass, field
from typing import Optional, Union

from pennylane.math import get_canonical_interface_name
from pennylane.math import Interface, get_canonical_interface_name
from pennylane.transforms.core import TransformDispatcher


@@ -87,7 +87,7 @@ class ExecutionConfig:
device_options: Optional[dict] = None
"""Various options for the device executing a quantum circuit"""

interface: Optional[str] = None
interface: Interface = Interface.NUMPY
"""The machine learning framework to use"""

derivative_order: int = 1
@@ -96,6 +96,13 @@ class ExecutionConfig:
mcm_config: MCMConfig = field(default_factory=MCMConfig)
"""Configuration options for handling mid-circuit measurements"""

convert_to_numpy: bool = True
"""Whether or not to convert parameters to numpy before execution.
If ``False`` and using the jax-jit, no pure callback will occur and the device
execution itself will be jitted.
"""

def __post_init__(self):
"""
Validate the configured execution options.
@@ -124,7 +131,7 @@ def __post_init__(self):
)

if isinstance(self.mcm_config, dict):
self.mcm_config = MCMConfig(**self.mcm_config)
self.mcm_config = MCMConfig(**self.mcm_config) # pylint: disable=not-a-mapping

elif not isinstance(self.mcm_config, MCMConfig):
raise ValueError(f"Got invalid type {type(self.mcm_config)} for 'mcm_config'")
4 changes: 2 additions & 2 deletions pennylane/devices/qubit/sampling.py
@@ -580,6 +580,6 @@ def _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key=None,
_, key = jax_random_split(prng_key)
samples = jax.random.choice(key, basis_states, shape=(shots,), p=probs)

powers_of_two = 1 << jnp.arange(num_wires, dtype=jnp.int64)[::-1]
powers_of_two = 1 << jnp.arange(num_wires, dtype=int)[::-1]
states_sampled_base_ten = samples[..., None] & powers_of_two
return (states_sampled_base_ten > 0).astype(jnp.int64)
return (states_sampled_base_ten > 0).astype(int)
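
The bit-extraction step in `_sample_probs_jax` turns each sampled basis-state integer into one bit per wire by masking with descending powers of two. A plain-Python sketch of the same idea follows; `basis_state_to_bits` is a hypothetical helper, not part of the library.

```python
def basis_state_to_bits(state, num_wires):
    """Convert a basis-state integer into a list of bits, first wire first."""
    # Descending powers of two, e.g. [4, 2, 1] for three wires.
    powers_of_two = [1 << k for k in range(num_wires)][::-1]
    # A non-zero masked value means that wire's bit is set.
    return [int((state & power) > 0) for power in powers_of_two]


print(basis_state_to_bits(5, 3))  # [1, 0, 1]
print(basis_state_to_bits(2, 3))  # [0, 1, 0]
```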
3 changes: 2 additions & 1 deletion pennylane/workflow/_setup_transform_program.py
@@ -117,7 +117,8 @@ def _setup_transform_program(

# changing this set of conditions causes a bunch of tests to break.
interface_data_supported = (
resolved_execution_config.interface is Interface.NUMPY
(not resolved_execution_config.convert_to_numpy)
or resolved_execution_config.interface is Interface.NUMPY
or resolved_execution_config.gradient_method == "backprop"
or (
getattr(device, "short_name", "") == "default.mixed"
2 changes: 1 addition & 1 deletion pennylane/workflow/interfaces/jax.py
@@ -186,7 +186,7 @@ def _to_jax(result: qml.typing.ResultBatch) -> qml.typing.ResultBatch:
return result
if isinstance(result, (list, tuple)):
return tuple(_to_jax(r) for r in result)
return jnp.array(result)
return result if qml.math.get_interface(result) == "jax" else jnp.array(result)


def _execute_wrapper(params, tapes, execute_fn, jpc) -> ResultBatch:
2 changes: 1 addition & 1 deletion pennylane/workflow/interfaces/jax_jit.py
@@ -59,7 +59,7 @@ def _to_jax(result: qml.typing.ResultBatch) -> qml.typing.ResultBatch:
"""
if isinstance(result, dict):
return {key: jnp.array(value) for key, value in result.items()}
return {key: _to_jax(value) for key, value in result.items()}
if isinstance(result, (list, tuple)):
return tuple(_to_jax(r) for r in result)
return jnp.array(result)
2 changes: 1 addition & 1 deletion pennylane/workflow/run.py
@@ -204,7 +204,7 @@ def _get_ml_boundary_execute(
elif interface == Interface.TORCH:
from .interfaces.torch import execute as ml_boundary

elif interface == Interface.JAX_JIT:
elif interface == Interface.JAX_JIT and resolved_execution_config.convert_to_numpy:
from .interfaces.jax_jit import jax_jit_jvp_execute as ml_boundary

else: # interface is jax
6 changes: 1 addition & 5 deletions tests/devices/default_qubit/test_default_qubit_native_mcm.py
@@ -389,11 +389,7 @@ def func(x, y, z):
results1 = func1(*params)

jaxpr = str(jax.make_jaxpr(func)(*params))
if diff_method == "best":
assert "pure_callback" in jaxpr
pytest.xfail("QNode with diff_method='best' cannot be compiled with jax.jit.")
else:
assert "pure_callback" not in jaxpr
assert "pure_callback" not in jaxpr

func2 = jax.jit(func)
results2 = func2(*params)
26 changes: 26 additions & 0 deletions tests/devices/default_qubit/test_default_qubit_preprocessing.py
@@ -141,6 +141,32 @@ def circuit(x):

assert dev.tracker.totals["execute_and_derivative_batches"] == 1

@pytest.mark.parametrize("interface", ("jax", "jax-jit"))
def test_not_convert_to_numpy_with_jax(self, interface):
"""Test that we will not convert to numpy when working with jax."""

dev = qml.device("default.qubit")
config = qml.devices.ExecutionConfig(
gradient_method=qml.gradients.param_shift, interface=interface
)
processed = dev.setup_execution_config(config)
assert not processed.convert_to_numpy

def test_convert_to_numpy_with_adjoint(self):
"""Test that we will convert to numpy with adjoint."""
config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface="jax-jit")
dev = qml.device("default.qubit")
processed = dev.setup_execution_config(config)
assert processed.convert_to_numpy

@pytest.mark.parametrize("interface", ("autograd", "torch", "tf"))
def test_convert_to_numpy_non_jax(self, interface):
"""Test that other interfaces are still converted to numpy."""
config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface=interface)
dev = qml.device("default.qubit")
processed = dev.setup_execution_config(config)
assert processed.convert_to_numpy


# pylint: disable=too-few-public-methods
class TestPreprocessing:
1 change: 0 additions & 1 deletion tests/gradients/core/test_pulse_gradient.py
@@ -1485,7 +1485,6 @@ def circuit(params):
assert qml.math.allclose(j[0], e, atol=tol, rtol=0.0)
jax.clear_caches()

@pytest.mark.xfail
@pytest.mark.parametrize("num_split_times", [1, 2])
@pytest.mark.parametrize("time_interface", ["python", "numpy", "jax"])
def test_simple_qnode_jit(self, num_split_times, time_interface):
1 change: 1 addition & 0 deletions tests/param_shift_dev.py
@@ -34,6 +34,7 @@ def preprocess(self, execution_config=qml.devices.DefaultExecutionConfig):
execution_config, use_device_jacobian_product=True
)
program, config = super().preprocess(execution_config)
config = dataclasses.replace(config, convert_to_numpy=True)
program.add_transform(qml.transform(qml.gradients.param_shift.expand_transform))
return program, config

Original file line number Diff line number Diff line change
@@ -417,11 +417,13 @@ def circuit(state):


@pytest.mark.jax
@pytest.mark.parametrize("shots, atol", [(None, 0.005), (1000000, 0.05)])
def test_jacobians_with_and_without_jit_match(shots, atol, seed):
def test_jacobians_with_and_without_jit_match(seed):
"""Test that the Jacobian of the circuit is the same with and without jit."""
import jax

shots = None
atol = 0.005

dev = qml.device("default.qubit", shots=shots, seed=seed)
dev_no_shots = qml.device("default.qubit", shots=None)

@@ -433,7 +435,7 @@ def circuit(coeffs):
circuit_ps = qml.QNode(circuit, dev, diff_method="parameter-shift")
circuit_exact = qml.QNode(circuit, dev_no_shots)

params = jax.numpy.array([0.5, 0.5, 0.5, 0.5])
params = jax.numpy.array([0.5, 0.5, 0.5, 0.5], dtype=jax.numpy.float64)
jac_exact_fn = jax.jacobian(circuit_exact)
jac_fd_fn = jax.jacobian(circuit_fd)
jac_fd_fn_jit = jax.jit(jac_fd_fn)
7 changes: 4 additions & 3 deletions tests/test_qnode.py
@@ -613,12 +613,13 @@ def func(x, y):
assert tape.measurements == contents[3:]

@pytest.mark.jax
def test_jit_counts_raises_error(self):
@pytest.mark.parametrize("dev_name", ("default.qubit", "reference.qubit"))
def test_jit_counts_raises_error(self, dev_name):
"""Test that returning counts in a quantum function with trainable parameters while
jitting raises an error."""
import jax

dev = qml.device("default.qubit", wires=2, shots=5)
dev = qml.device(dev_name, wires=2, shots=5)

def circuit1(param):
qml.Hadamard(0)
@@ -632,7 +633,7 @@ def circuit1(param):
with pytest.raises(
NotImplementedError, match="The JAX-JIT interface doesn't support qml.counts."
):
jitted_qnode1(0.123)
_ = jitted_qnode1(0.123)

# Test with qnode decorator syntax
@qml.qnode(dev)
10 changes: 6 additions & 4 deletions tests/workflow/interfaces/execute/test_jax.py
@@ -693,24 +693,26 @@ def test_max_diff(self, tol):

def cost_fn(x):
ops = [qml.RX(x[0], 0), qml.RY(x[1], 1), qml.CNOT((0, 1))]
tape1 = qml.tape.QuantumScript(ops, [qml.var(qml.PauliZ(0) @ qml.PauliX(1))])
tape1 = qml.tape.QuantumScript(
ops, [qml.var(qml.PauliZ(0) @ qml.PauliX(1))], shots=50000
)

ops2 = [qml.RX(x[0], 0), qml.RY(x[0], 1), qml.CNOT((0, 1))]
tape2 = qml.tape.QuantumScript(ops2, [qml.probs(wires=1)])
tape2 = qml.tape.QuantumScript(ops2, [qml.probs(wires=1)], shots=50000)

result = execute([tape1, tape2], dev, diff_method=param_shift, max_diff=1)
return result[0] + result[1][0]

res = cost_fn(params)
x, y = params
expected = 0.5 * (3 + jnp.cos(x) ** 2 * jnp.cos(2 * y))
assert np.allclose(res, expected, atol=tol, rtol=0)
assert np.allclose(res, expected, atol=2e-2, rtol=0)

res = jax.grad(cost_fn)(params)
expected = jnp.array(
[-jnp.cos(x) * jnp.cos(2 * y) * jnp.sin(x), -jnp.cos(x) ** 2 * jnp.sin(2 * y)]
)
assert np.allclose(res, expected, atol=tol, rtol=0)
assert np.allclose(res, expected, atol=2e-2, rtol=0)

res = jax.jacobian(jax.grad(cost_fn))(params)
expected = jnp.zeros([2, 2])
19 changes: 14 additions & 5 deletions tests/workflow/interfaces/execute/test_jax_jit.py
@@ -886,14 +886,17 @@ def cost(x, y, device, interface, ek):

class TestJitAllCounts:

@pytest.mark.parametrize(
"device_name", (pytest.param("default.qubit", marks=pytest.mark.xfail), "reference.qubit")
)
@pytest.mark.parametrize("counts_wires", (None, (0, 1)))
def test_jit_allcounts(self, counts_wires):
def test_jit_allcounts(self, device_name, counts_wires):
"""Test jitting with counts with all_outcomes == True."""

tape = qml.tape.QuantumScript(
[qml.RX(0, 0), qml.I(1)], [qml.counts(wires=counts_wires, all_outcomes=True)], shots=50
)
device = qml.device("default.qubit")
device = qml.device(device_name, wires=2)

res = jax.jit(qml.execute, static_argnums=(1, 2))(
(tape,), device, qml.gradients.param_shift
@@ -904,15 +907,22 @@ def test_jit_allcounts(self, counts_wires):
for val in ["01", "10", "11"]:
assert qml.math.allclose(res[val], 0)

def test_jit_allcounts_broadcasting(self):
@pytest.mark.parametrize(
"device_name",
(
pytest.param("default.qubit", marks=pytest.mark.xfail),
pytest.param("reference.qubit", marks=pytest.mark.xfail),
),
)
def test_jit_allcounts_broadcasting(self, device_name):
"""Test jitting with counts with all_outcomes == True."""

tape = qml.tape.QuantumScript(
[qml.RX(np.array([0.0, 0.0]), 0)],
[qml.counts(wires=(0, 1), all_outcomes=True)],
shots=50,
)
device = qml.device("default.qubit")
device = qml.device(device_name, wires=2)

res = jax.jit(qml.execute, static_argnums=(1, 2))(
(tape,), device, qml.gradients.param_shift
@@ -927,7 +937,6 @@ def test_jit_allcounts_broadcasting(self):
assert qml.math.allclose(ri[val], 0)


@pytest.mark.xfail(reason="Need to figure out how to handle this case in a less ambiguous manner")
def test_diff_method_None_jit():
"""Test that jitted execution works when `diff_method=None`."""

3 changes: 2 additions & 1 deletion tests/workflow/interfaces/qnode/test_jax_jit_qnode.py
@@ -3185,7 +3185,8 @@ def test_complex64_return(self, diff_method):
jax.config.update("jax_enable_x64", False)

try:
tol = 2e-2 if diff_method == "finite-diff" else 1e-6
# finite diff with float32 ...
tol = 5e-2 if diff_method == "finite-diff" else 1e-6

@jax.jit
@qml.qnode(qml.device("default.qubit", wires=1), diff_method=diff_method)