
Commit

Merge branch 'no-interface-boundary' into no-grad-on-diff-method-none
albi3ro authored Jan 8, 2025
2 parents b4036d5 + f6949e1 commit b13371d
Showing 13 changed files with 69 additions and 28 deletions.
5 changes: 5 additions & 0 deletions doc/releases/changelog-dev.md
@@ -6,6 +6,11 @@

<h3>Improvements 🛠</h3>

* Devices can now configure whether the data is converted to numpy and `jax.pure_callback`
  is used, via the new `ExecutionConfig.convert_to_numpy` property. Finite-shot executions
  on `default.qubit` can now be jitted end-to-end, even with parameter shift.
  [(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788)

<h3>Breaking changes 💔</h3>

<h3>Deprecations 👋</h3>
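The behaviour described in the changelog entry above can be exercised roughly as follows — a minimal sketch, assuming `jax` is installed; the circuit, shot count, and parameter value are placeholders, not taken from this commit:

```python
import jax
import jax.numpy as jnp
import pennylane as qml

# finite-shot device; with convert_to_numpy resolved to False internally,
# the execution can be traced by jax instead of going through pure_callback
dev = qml.device("default.qubit", wires=1, shots=1000)

@jax.jit
@qml.qnode(dev, interface="jax", diff_method="parameter-shift")
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

x = jnp.array(0.5)
print(circuit(x))            # jitted end-to-end
print(jax.grad(circuit)(x))  # parameter-shift gradient under jit
```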
4 changes: 2 additions & 2 deletions pennylane/devices/default_mixed.py
@@ -1008,8 +1008,8 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> ExecutionConfig:
"best",
}
updated_values["grad_on_execution"] = False
- if not execution_config.gradient_method in {"best", "backprop", None}:
-     execution_config.interface = None
+ if execution_config.gradient_method not in {"best", "backprop", None}:
+     updated_values["interface"] = None

# Add device options
updated_values["device_options"] = dict(execution_config.device_options) # copy
6 changes: 5 additions & 1 deletion pennylane/devices/default_qubit.py
@@ -581,6 +581,11 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> ExecutionConfig:
"""
updated_values = {}

updated_values["convert_to_numpy"] = (
execution_config.interface.value not in {"jax", "jax-jit"}
or execution_config.gradient_method == "adjoint"
)

for option in execution_config.device_options:
if option not in self._device_options:
raise qml.DeviceError(f"device option {option} not present on {self}")
@@ -616,7 +621,6 @@ def execute(
execution_config: ExecutionConfig = DefaultExecutionConfig,
) -> Union[Result, ResultBatch]:
self.reset_prng_key()

max_workers = execution_config.device_options.get("max_workers", self._max_workers)
self._state_cache = {} if execution_config.use_device_jacobian_product else None
interface = (
13 changes: 10 additions & 3 deletions pennylane/devices/execution_config.py
@@ -17,7 +17,7 @@
from dataclasses import dataclass, field
from typing import Optional, Union

- from pennylane.math import get_canonical_interface_name
+ from pennylane.math import Interface, get_canonical_interface_name
from pennylane.transforms.core import TransformDispatcher


@@ -87,7 +87,7 @@ class ExecutionConfig:
device_options: Optional[dict] = None
"""Various options for the device executing a quantum circuit"""

- interface: Optional[str] = None
+ interface: Interface = Interface.NUMPY
"""The machine learning framework to use"""

derivative_order: int = 1
@@ -96,6 +96,13 @@
mcm_config: MCMConfig = field(default_factory=MCMConfig)
"""Configuration options for handling mid-circuit measurements"""

convert_to_numpy: bool = True
"""Whether or not to convert parameters to numpy before execution.
If ``False`` and the ``jax-jit`` interface is used, no pure callback will occur and the
device execution itself will be jitted.
"""

def __post_init__(self):
"""
Validate the configured execution options.
@@ -124,7 +131,7 @@ def __post_init__(self):
)

if isinstance(self.mcm_config, dict):
- self.mcm_config = MCMConfig(**self.mcm_config)
+ self.mcm_config = MCMConfig(**self.mcm_config)  # pylint: disable=not-a-mapping

elif not isinstance(self.mcm_config, MCMConfig):
raise ValueError(f"Got invalid type {type(self.mcm_config)} for 'mcm_config'")
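For orientation, a small sketch of how the new field behaves — the constructed values here are illustrative, not taken from this diff:

```python
import dataclasses
import pennylane as qml

# by default, parameters are converted to numpy before execution
config = qml.devices.ExecutionConfig()
assert config.convert_to_numpy
assert config.interface == qml.math.Interface.NUMPY

# a device can opt out when resolving the config, e.g. for jax-jit with
# parameter shift, so that its execution is traced rather than called back
requested = qml.devices.ExecutionConfig(
    interface="jax-jit", gradient_method="parameter-shift"
)
resolved = dataclasses.replace(requested, convert_to_numpy=False)
assert not resolved.convert_to_numpy
```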
6 changes: 3 additions & 3 deletions pennylane/devices/qubit/sampling.py
@@ -330,7 +330,7 @@ def _process_single_shot(samples):
prng_key=prng_key,
)
except ValueError as e:
- if str(e) != "probabilities contain NaN":
+ if "probabilities contain nan" not in str(e).lower():
raise e
samples = qml.math.full((shots.total_shots, len(wires)), 0)

@@ -580,6 +580,6 @@ def _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key=None,
_, key = jax_random_split(prng_key)
samples = jax.random.choice(key, basis_states, shape=(shots,), p=probs)

- powers_of_two = 1 << jnp.arange(num_wires, dtype=jnp.int64)[::-1]
+ powers_of_two = 1 << jnp.arange(num_wires, dtype=int)[::-1]
states_sampled_base_ten = samples[..., None] & powers_of_two
- return (states_sampled_base_ten > 0).astype(jnp.int64)
+ return (states_sampled_base_ten > 0).astype(int)
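As a standalone illustration of the bit-extraction pattern in `_sample_probs_jax` above (sampled basis-state integers converted to per-wire bits), with made-up sample values:

```python
import jax.numpy as jnp

num_wires = 3
samples = jnp.array([5, 2, 7])  # sampled basis-state indices

# most-significant wire first: [4, 2, 1] for three wires
powers_of_two = 1 << jnp.arange(num_wires, dtype=int)[::-1]
states_sampled_base_ten = samples[..., None] & powers_of_two
bits = (states_sampled_base_ten > 0).astype(int)
# bits -> [[1, 0, 1], [0, 1, 0], [1, 1, 1]]
```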
3 changes: 2 additions & 1 deletion pennylane/workflow/_setup_transform_program.py
@@ -117,7 +117,8 @@ def _setup_transform_program(

# changing this set of conditions causes a bunch of tests to break.
interface_data_supported = (
- resolved_execution_config.interface is Interface.NUMPY
+ (not resolved_execution_config.convert_to_numpy)
+ or resolved_execution_config.interface is Interface.NUMPY
or resolved_execution_config.gradient_method == "backprop"
or (
getattr(device, "short_name", "") == "default.mixed"
2 changes: 1 addition & 1 deletion pennylane/workflow/run.py
@@ -208,7 +208,7 @@ def _get_ml_boundary_execute(
elif interface == Interface.TORCH:
from .interfaces.torch import execute as ml_boundary

- elif interface == Interface.JAX_JIT:
+ elif interface == Interface.JAX_JIT and resolved_execution_config.convert_to_numpy:
from .interfaces.jax_jit import jax_jit_jvp_execute as ml_boundary

else: # interface is jax
3 changes: 1 addition & 2 deletions tests/devices/default_qubit/test_default_qubit_native_mcm.py
@@ -389,8 +389,7 @@ def func(x, y, z):
results1 = func1(*params)

jaxpr = str(jax.make_jaxpr(func)(*params))
assert "pure_callback" in jaxpr
pytest.xfail("QNode cannot be compiled with jax.jit.")
assert "pure_callback" not in jaxpr

func2 = jax.jit(func)
results2 = func2(*params)
26 changes: 26 additions & 0 deletions tests/devices/default_qubit/test_default_qubit_preprocessing.py
@@ -141,6 +141,32 @@ def circuit(x):

assert dev.tracker.totals["execute_and_derivative_batches"] == 1

@pytest.mark.parametrize("interface", ("jax", "jax-jit"))
def test_not_convert_to_numpy_with_jax(self, interface):
"""Test that we will not convert to numpy when working with jax."""

dev = qml.device("default.qubit")
config = qml.devices.ExecutionConfig(
gradient_method=qml.gradients.param_shift, interface=interface
)
processed = dev.setup_execution_config(config)
assert not processed.convert_to_numpy

def test_convert_to_numpy_with_adjoint(self):
"""Test that we will convert to numpy with adjoint."""
config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface="jax-jit")
dev = qml.device("default.qubit")
processed = dev.setup_execution_config(config)
assert processed.convert_to_numpy

@pytest.mark.parametrize("interface", ("autograd", "torch", "tf"))
def test_convert_to_numpy_non_jax(self, interface):
"""Test that other interfaces are still converted to numpy."""
config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface=interface)
dev = qml.device("default.qubit")
processed = dev.setup_execution_config(config)
assert processed.convert_to_numpy


# pylint: disable=too-few-public-methods
class TestPreprocessing:
1 change: 0 additions & 1 deletion tests/gradients/core/test_pulse_gradient.py
@@ -1489,7 +1489,6 @@ def circuit(params):
assert qml.math.allclose(j[0], e, atol=tol, rtol=0.0)
jax.clear_caches()

- @pytest.mark.xfail
@pytest.mark.parametrize("num_split_times", [1, 2])
@pytest.mark.parametrize("time_interface", ["python", "numpy", "jax"])
def test_simple_qnode_jit(self, num_split_times, time_interface):
1 change: 1 addition & 0 deletions tests/param_shift_dev.py
@@ -34,6 +34,7 @@ def preprocess(self, execution_config=qml.devices.DefaultExecutionConfig):
execution_config, use_device_jacobian_product=True
)
program, config = super().preprocess(execution_config)
+ config = dataclasses.replace(config, convert_to_numpy=True)
program.add_transform(qml.transform(qml.gradients.param_shift.expand_transform))
return program, config

4 changes: 2 additions & 2 deletions tests/workflow/interfaces/test_jacobian_products.py
@@ -136,7 +136,7 @@ def test_device_jacobians_repr(self):
r" use_device_jacobian_product=None,"
r" gradient_method='adjoint', gradient_keyword_arguments={},"
r" device_options={}, interface=<Interface.NUMPY: 'numpy'>, derivative_order=1,"
r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>"
r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None), convert_to_numpy=True)>"
)

assert repr(jpc) == expected
@@ -155,7 +155,7 @@ def test_device_jacobian_products_repr(self):
r" use_device_jacobian_product=None,"
r" gradient_method='adjoint', gradient_keyword_arguments={}, device_options={},"
r" interface=<Interface.NUMPY: 'numpy'>, derivative_order=1,"
r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>"
r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None), convert_to_numpy=True)>"
)

assert repr(jpc) == expected
23 changes: 11 additions & 12 deletions tests/workflow/test_setup_transform_program.py
@@ -140,9 +140,7 @@ def test_prune_dynamic_transform_warning_raised():

def test_interface_data_not_supported():
"""Test that convert_to_numpy_parameters transform is correctly added."""
- config = ExecutionConfig()
- config.interface = "autograd"
- config.gradient_method = "adjoint"
+ config = ExecutionConfig(interface="autograd", gradient_method="adjoint")
device = qml.device("default.qubit")

user_transform_program = TransformProgram()
@@ -154,39 +152,40 @@ def test_interface_data_supported():

def test_interface_data_supported():
"""Test that convert_to_numpy_parameters transform is not added for these cases."""
- config = ExecutionConfig()
+ config = ExecutionConfig(interface="autograd", gradient_method=None)

- config.interface = "autograd"
- config.gradient_method = None
device = qml.device("default.mixed", wires=1)

user_transform_program = TransformProgram()
_, inner_tp = _setup_transform_program(user_transform_program, device, config)

assert qml.transforms.convert_to_numpy_parameters not in inner_tp

- config = ExecutionConfig()
+ config = ExecutionConfig(interface="autograd", gradient_method="backprop")

- config.interface = "autograd"
- config.gradient_method = "backprop"
device = qml.device("default.qubit")

user_transform_program = TransformProgram()
_, inner_tp = _setup_transform_program(user_transform_program, device, config)

assert qml.transforms.convert_to_numpy_parameters not in inner_tp

- config = ExecutionConfig()
+ config = ExecutionConfig(interface=None, gradient_method="backprop")

- config.interface = None
- config.gradient_method = "backprop"
device = qml.device("default.qubit")

user_transform_program = TransformProgram()
_, inner_tp = _setup_transform_program(user_transform_program, device, config)

assert qml.transforms.convert_to_numpy_parameters not in inner_tp

+ config = ExecutionConfig(
+     convert_to_numpy=False, interface="jax", gradient_method=qml.gradients.param_shift
+ )
+
+ _, inner_tp = _setup_transform_program(TransformProgram(), device, config)
+ assert qml.transforms.convert_to_numpy_parameters not in inner_tp


def test_cache_handling():
"""Test that caching is handled correctly."""
