diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 0f9e46664b9..9cd3f6c51ff 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -6,6 +6,11 @@
Improvements ðŸ›
+* Devices can now configure whether or not the data is converted to numpy and `jax.pure_callback`
+ is used by the new `ExecutionConfig.convert_to_numpy` property. Finite shot executions
+ on `default.qubit` can now be jitted end-to-end, even with parameter shift.
+ [(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788)
+
Breaking changes 💔
Deprecations 👋
diff --git a/pennylane/devices/default_mixed.py b/pennylane/devices/default_mixed.py
index 9e18ba09249..d55f9eec623 100644
--- a/pennylane/devices/default_mixed.py
+++ b/pennylane/devices/default_mixed.py
@@ -1008,8 +1008,8 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio
"best",
}
updated_values["grad_on_execution"] = False
- if not execution_config.gradient_method in {"best", "backprop", None}:
- execution_config.interface = None
+ if execution_config.gradient_method not in {"best", "backprop", None}:
+ updated_values["interface"] = None
# Add device options
updated_values["device_options"] = dict(execution_config.device_options) # copy
diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py
index fcdb25d2783..f1ac48d6461 100644
--- a/pennylane/devices/default_qubit.py
+++ b/pennylane/devices/default_qubit.py
@@ -581,6 +581,11 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio
"""
updated_values = {}
+ updated_values["convert_to_numpy"] = (
+ execution_config.interface.value not in {"jax", "jax-jit"}
+ or execution_config.gradient_method == "adjoint"
+ )
+
for option in execution_config.device_options:
if option not in self._device_options:
raise qml.DeviceError(f"device option {option} not present on {self}")
@@ -616,7 +621,6 @@ def execute(
execution_config: ExecutionConfig = DefaultExecutionConfig,
) -> Union[Result, ResultBatch]:
self.reset_prng_key()
-
max_workers = execution_config.device_options.get("max_workers", self._max_workers)
self._state_cache = {} if execution_config.use_device_jacobian_product else None
interface = (
diff --git a/pennylane/devices/execution_config.py b/pennylane/devices/execution_config.py
index fc953a2dc22..81d63ee79a7 100644
--- a/pennylane/devices/execution_config.py
+++ b/pennylane/devices/execution_config.py
@@ -17,7 +17,7 @@
from dataclasses import dataclass, field
from typing import Optional, Union
-from pennylane.math import get_canonical_interface_name
+from pennylane.math import Interface, get_canonical_interface_name
from pennylane.transforms.core import TransformDispatcher
@@ -87,7 +87,7 @@ class ExecutionConfig:
device_options: Optional[dict] = None
"""Various options for the device executing a quantum circuit"""
- interface: Optional[str] = None
+ interface: Interface = Interface.NUMPY
"""The machine learning framework to use"""
derivative_order: int = 1
@@ -96,6 +96,13 @@ class ExecutionConfig:
mcm_config: MCMConfig = field(default_factory=MCMConfig)
"""Configuration options for handling mid-circuit measurements"""
+ convert_to_numpy: bool = True
+ """Whether or not to convert parameters to numpy before execution.
+
+ If ``False`` and using the jax-jit, no pure callback will occur and the device
+ execution itself will be jitted.
+ """
+
def __post_init__(self):
"""
Validate the configured execution options.
@@ -124,7 +131,7 @@ def __post_init__(self):
)
if isinstance(self.mcm_config, dict):
- self.mcm_config = MCMConfig(**self.mcm_config)
+ self.mcm_config = MCMConfig(**self.mcm_config) # pylint: disable=not-a-mapping
elif not isinstance(self.mcm_config, MCMConfig):
raise ValueError(f"Got invalid type {type(self.mcm_config)} for 'mcm_config'")
diff --git a/pennylane/devices/qubit/sampling.py b/pennylane/devices/qubit/sampling.py
index d8036a6fb99..527e3296a5c 100644
--- a/pennylane/devices/qubit/sampling.py
+++ b/pennylane/devices/qubit/sampling.py
@@ -330,7 +330,7 @@ def _process_single_shot(samples):
prng_key=prng_key,
)
except ValueError as e:
- if str(e) != "probabilities contain NaN":
+ if "probabilities contain nan" not in str(e).lower():
raise e
samples = qml.math.full((shots.total_shots, len(wires)), 0)
@@ -580,6 +580,6 @@ def _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key=None,
_, key = jax_random_split(prng_key)
samples = jax.random.choice(key, basis_states, shape=(shots,), p=probs)
- powers_of_two = 1 << jnp.arange(num_wires, dtype=jnp.int64)[::-1]
+ powers_of_two = 1 << jnp.arange(num_wires, dtype=int)[::-1]
states_sampled_base_ten = samples[..., None] & powers_of_two
- return (states_sampled_base_ten > 0).astype(jnp.int64)
+ return (states_sampled_base_ten > 0).astype(int)
diff --git a/pennylane/workflow/_setup_transform_program.py b/pennylane/workflow/_setup_transform_program.py
index 5866c8f7bf4..67cfab372b9 100644
--- a/pennylane/workflow/_setup_transform_program.py
+++ b/pennylane/workflow/_setup_transform_program.py
@@ -117,7 +117,8 @@ def _setup_transform_program(
# changing this set of conditions causes a bunch of tests to break.
interface_data_supported = (
- resolved_execution_config.interface is Interface.NUMPY
+ (not resolved_execution_config.convert_to_numpy)
+ or resolved_execution_config.interface is Interface.NUMPY
or resolved_execution_config.gradient_method == "backprop"
or (
getattr(device, "short_name", "") == "default.mixed"
diff --git a/pennylane/workflow/run.py b/pennylane/workflow/run.py
index 129f3de6039..b3fcf123e64 100644
--- a/pennylane/workflow/run.py
+++ b/pennylane/workflow/run.py
@@ -208,7 +208,7 @@ def _get_ml_boundary_execute(
elif interface == Interface.TORCH:
from .interfaces.torch import execute as ml_boundary
- elif interface == Interface.JAX_JIT:
+ elif interface == Interface.JAX_JIT and resolved_execution_config.convert_to_numpy:
from .interfaces.jax_jit import jax_jit_jvp_execute as ml_boundary
else: # interface is jax
diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py
index 139b4053ed0..42faaa10809 100644
--- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py
+++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py
@@ -389,8 +389,7 @@ def func(x, y, z):
results1 = func1(*params)
jaxpr = str(jax.make_jaxpr(func)(*params))
- assert "pure_callback" in jaxpr
- pytest.xfail("QNode cannot be compiled with jax.jit.")
+ assert "pure_callback" not in jaxpr
func2 = jax.jit(func)
results2 = func2(*params)
diff --git a/tests/devices/default_qubit/test_default_qubit_preprocessing.py b/tests/devices/default_qubit/test_default_qubit_preprocessing.py
index 40f45a9383e..59a9098f7ce 100644
--- a/tests/devices/default_qubit/test_default_qubit_preprocessing.py
+++ b/tests/devices/default_qubit/test_default_qubit_preprocessing.py
@@ -141,6 +141,32 @@ def circuit(x):
assert dev.tracker.totals["execute_and_derivative_batches"] == 1
+ @pytest.mark.parametrize("interface", ("jax", "jax-jit"))
+ def test_not_convert_to_numpy_with_jax(self, interface):
+ """Test that we will not convert to numpy when working with jax."""
+
+ dev = qml.device("default.qubit")
+ config = qml.devices.ExecutionConfig(
+ gradient_method=qml.gradients.param_shift, interface=interface
+ )
+ processed = dev.setup_execution_config(config)
+ assert not processed.convert_to_numpy
+
+ def test_convert_to_numpy_with_adjoint(self):
+ """Test that we will convert to numpy with adjoint."""
+ config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface="jax-jit")
+ dev = qml.device("default.qubit")
+ processed = dev.setup_execution_config(config)
+ assert processed.convert_to_numpy
+
+ @pytest.mark.parametrize("interface", ("autograd", "torch", "tf"))
+ def test_convert_to_numpy_non_jax(self, interface):
+ """Test that other interfaces are still converted to numpy."""
+ config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface=interface)
+ dev = qml.device("default.qubit")
+ processed = dev.setup_execution_config(config)
+ assert processed.convert_to_numpy
+
# pylint: disable=too-few-public-methods
class TestPreprocessing:
diff --git a/tests/gradients/core/test_pulse_gradient.py b/tests/gradients/core/test_pulse_gradient.py
index bfd3e244fcb..95fa71d706c 100644
--- a/tests/gradients/core/test_pulse_gradient.py
+++ b/tests/gradients/core/test_pulse_gradient.py
@@ -1489,7 +1489,6 @@ def circuit(params):
assert qml.math.allclose(j[0], e, atol=tol, rtol=0.0)
jax.clear_caches()
- @pytest.mark.xfail
@pytest.mark.parametrize("num_split_times", [1, 2])
@pytest.mark.parametrize("time_interface", ["python", "numpy", "jax"])
def test_simple_qnode_jit(self, num_split_times, time_interface):
diff --git a/tests/param_shift_dev.py b/tests/param_shift_dev.py
index 12fd11eea16..7c7442161e2 100644
--- a/tests/param_shift_dev.py
+++ b/tests/param_shift_dev.py
@@ -34,6 +34,7 @@ def preprocess(self, execution_config=qml.devices.DefaultExecutionConfig):
execution_config, use_device_jacobian_product=True
)
program, config = super().preprocess(execution_config)
+ config = dataclasses.replace(config, convert_to_numpy=True)
program.add_transform(qml.transform(qml.gradients.param_shift.expand_transform))
return program, config
diff --git a/tests/workflow/interfaces/test_jacobian_products.py b/tests/workflow/interfaces/test_jacobian_products.py
index 1991ba2d523..4d90f2e5012 100644
--- a/tests/workflow/interfaces/test_jacobian_products.py
+++ b/tests/workflow/interfaces/test_jacobian_products.py
@@ -136,7 +136,7 @@ def test_device_jacobians_repr(self):
r" use_device_jacobian_product=None,"
r" gradient_method='adjoint', gradient_keyword_arguments={},"
r" device_options={}, interface=, derivative_order=1,"
- r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>"
+ r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None), convert_to_numpy=True)>"
)
assert repr(jpc) == expected
@@ -155,7 +155,7 @@ def test_device_jacobian_products_repr(self):
r" use_device_jacobian_product=None,"
r" gradient_method='adjoint', gradient_keyword_arguments={}, device_options={},"
r" interface=, derivative_order=1,"
- r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>"
+ r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None), convert_to_numpy=True)>"
)
assert repr(jpc) == expected
diff --git a/tests/workflow/test_setup_transform_program.py b/tests/workflow/test_setup_transform_program.py
index c81ed75ae2f..79207491014 100644
--- a/tests/workflow/test_setup_transform_program.py
+++ b/tests/workflow/test_setup_transform_program.py
@@ -140,9 +140,7 @@ def test_prune_dynamic_transform_warning_raised():
def test_interface_data_not_supported():
"""Test that convert_to_numpy_parameters transform is correctly added."""
- config = ExecutionConfig()
- config.interface = "autograd"
- config.gradient_method = "adjoint"
+ config = ExecutionConfig(interface="autograd", gradient_method="adjoint")
device = qml.device("default.qubit")
user_transform_program = TransformProgram()
@@ -154,10 +152,8 @@ def test_interface_data_not_supported():
def test_interface_data_supported():
"""Test that convert_to_numpy_parameters transform is not added for these cases."""
- config = ExecutionConfig()
+ config = ExecutionConfig(interface="autograd", gradient_method=None)
- config.interface = "autograd"
- config.gradient_method = None
device = qml.device("default.mixed", wires=1)
user_transform_program = TransformProgram()
@@ -165,10 +161,8 @@ def test_interface_data_supported():
assert qml.transforms.convert_to_numpy_parameters not in inner_tp
- config = ExecutionConfig()
+ config = ExecutionConfig(interface="autograd", gradient_method="backprop")
- config.interface = "autograd"
- config.gradient_method = "backprop"
device = qml.device("default.qubit")
user_transform_program = TransformProgram()
@@ -176,10 +170,8 @@ def test_interface_data_supported():
assert qml.transforms.convert_to_numpy_parameters not in inner_tp
- config = ExecutionConfig()
+ config = ExecutionConfig(interface=None, gradient_method="backprop")
- config.interface = None
- config.gradient_method = "backprop"
device = qml.device("default.qubit")
user_transform_program = TransformProgram()
@@ -187,6 +179,13 @@ def test_interface_data_supported():
assert qml.transforms.convert_to_numpy_parameters not in inner_tp
+ config = ExecutionConfig(
+ convert_to_numpy=False, interface="jax", gradient_method=qml.gradients.param_shift
+ )
+
+ _, inner_tp = _setup_transform_program(TransformProgram(), device, config)
+ assert qml.transforms.convert_to_numpy_parameters not in inner_tp
+
def test_cache_handling():
"""Test that caching is handled correctly."""