diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 0f9e46664b9..9cd3f6c51ff 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -6,6 +6,11 @@

Improvements 🛠

+* Devices can now configure whether the data is converted to numpy and `jax.pure_callback`
+  is used via the new `ExecutionConfig.convert_to_numpy` property. Finite-shot executions
+  on `default.qubit` can now be jitted end-to-end, even with parameter shift.
+  [(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788)
+

Breaking changes 💔

Deprecations 👋

diff --git a/pennylane/devices/default_mixed.py b/pennylane/devices/default_mixed.py
index 9e18ba09249..d55f9eec623 100644
--- a/pennylane/devices/default_mixed.py
+++ b/pennylane/devices/default_mixed.py
@@ -1008,8 +1008,8 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio
             "best",
         }
         updated_values["grad_on_execution"] = False
-        if not execution_config.gradient_method in {"best", "backprop", None}:
-            execution_config.interface = None
+        if execution_config.gradient_method not in {"best", "backprop", None}:
+            updated_values["interface"] = None
 
         # Add device options
         updated_values["device_options"] = dict(execution_config.device_options)  # copy
diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py
index fcdb25d2783..f1ac48d6461 100644
--- a/pennylane/devices/default_qubit.py
+++ b/pennylane/devices/default_qubit.py
@@ -581,6 +581,11 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio
         """
         updated_values = {}
 
+        updated_values["convert_to_numpy"] = (
+            execution_config.interface.value not in {"jax", "jax-jit"}
+            or execution_config.gradient_method == "adjoint"
+        )
+
         for option in execution_config.device_options:
             if option not in self._device_options:
                 raise qml.DeviceError(f"device option {option} not present on {self}")
@@ -616,7 +621,6 @@ def execute(
         execution_config: ExecutionConfig = DefaultExecutionConfig,
     ) -> Union[Result, ResultBatch]:
         self.reset_prng_key()
 
-        max_workers = execution_config.device_options.get("max_workers", self._max_workers)
         self._state_cache = {} if execution_config.use_device_jacobian_product else None
         interface = (
diff --git a/pennylane/devices/execution_config.py b/pennylane/devices/execution_config.py
index fc953a2dc22..81d63ee79a7 100644
--- a/pennylane/devices/execution_config.py
+++ b/pennylane/devices/execution_config.py
@@ -17,7 +17,7 @@
 from dataclasses import dataclass, field
 from typing import Optional, Union
 
-from pennylane.math import get_canonical_interface_name
+from pennylane.math import Interface, get_canonical_interface_name
 from pennylane.transforms.core import TransformDispatcher
 
 
@@ -87,7 +87,7 @@ class ExecutionConfig:
     device_options: Optional[dict] = None
     """Various options for the device executing a quantum circuit"""
 
-    interface: Optional[str] = None
+    interface: Interface = Interface.NUMPY
     """The machine learning framework to use"""
 
     derivative_order: int = 1
@@ -96,6 +96,13 @@ class ExecutionConfig:
     mcm_config: MCMConfig = field(default_factory=MCMConfig)
    """Configuration options for handling mid-circuit measurements"""
 
+    convert_to_numpy: bool = True
+    """Whether or not to convert parameters to numpy before execution.
+
+    If ``False`` and the jax-jit interface is used, no pure callback will occur and
+    the device execution itself will be jitted.
+    """
+
     def __post_init__(self):
         """
         Validate the configured execution options.
@@ -124,7 +131,7 @@ def __post_init__(self):
             )
 
         if isinstance(self.mcm_config, dict):
-            self.mcm_config = MCMConfig(**self.mcm_config)
+            self.mcm_config = MCMConfig(**self.mcm_config)  # pylint: disable=not-a-mapping
         elif not isinstance(self.mcm_config, MCMConfig):
             raise ValueError(f"Got invalid type {type(self.mcm_config)} for 'mcm_config'")
 
diff --git a/pennylane/devices/qubit/sampling.py b/pennylane/devices/qubit/sampling.py
index d8036a6fb99..527e3296a5c 100644
--- a/pennylane/devices/qubit/sampling.py
+++ b/pennylane/devices/qubit/sampling.py
@@ -330,7 +330,7 @@ def _process_single_shot(samples):
                 prng_key=prng_key,
             )
         except ValueError as e:
-            if str(e) != "probabilities contain NaN":
+            if "probabilities contain nan" not in str(e).lower():
                 raise e
             samples = qml.math.full((shots.total_shots, len(wires)), 0)
 
@@ -580,6 +580,6 @@ def _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key=None,
     _, key = jax_random_split(prng_key)
     samples = jax.random.choice(key, basis_states, shape=(shots,), p=probs)
 
-    powers_of_two = 1 << jnp.arange(num_wires, dtype=jnp.int64)[::-1]
+    powers_of_two = 1 << jnp.arange(num_wires, dtype=int)[::-1]
     states_sampled_base_ten = samples[..., None] & powers_of_two
-    return (states_sampled_base_ten > 0).astype(jnp.int64)
+    return (states_sampled_base_ten > 0).astype(int)
diff --git a/pennylane/workflow/_setup_transform_program.py b/pennylane/workflow/_setup_transform_program.py
index 5866c8f7bf4..67cfab372b9 100644
--- a/pennylane/workflow/_setup_transform_program.py
+++ b/pennylane/workflow/_setup_transform_program.py
@@ -117,7 +117,8 @@ def _setup_transform_program(
 
     # changing this set of conditions causes a bunch of tests to break.
     interface_data_supported = (
-        resolved_execution_config.interface is Interface.NUMPY
+        (not resolved_execution_config.convert_to_numpy)
+        or resolved_execution_config.interface is Interface.NUMPY
         or resolved_execution_config.gradient_method == "backprop"
         or (
             getattr(device, "short_name", "") == "default.mixed"
diff --git a/pennylane/workflow/run.py b/pennylane/workflow/run.py
index 129f3de6039..b3fcf123e64 100644
--- a/pennylane/workflow/run.py
+++ b/pennylane/workflow/run.py
@@ -208,7 +208,7 @@ def _get_ml_boundary_execute(
         elif interface == Interface.TORCH:
             from .interfaces.torch import execute as ml_boundary
 
-        elif interface == Interface.JAX_JIT:
+        elif interface == Interface.JAX_JIT and resolved_execution_config.convert_to_numpy:
             from .interfaces.jax_jit import jax_jit_jvp_execute as ml_boundary
 
         else:  # interface is jax
diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py
index 139b4053ed0..42faaa10809 100644
--- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py
+++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py
@@ -389,8 +389,7 @@ def func(x, y, z):
     results1 = func1(*params)
 
     jaxpr = str(jax.make_jaxpr(func)(*params))
-    assert "pure_callback" in jaxpr
-    pytest.xfail("QNode cannot be compiled with jax.jit.")
+    assert "pure_callback" not in jaxpr
 
     func2 = jax.jit(func)
     results2 = func2(*params)
diff --git a/tests/devices/default_qubit/test_default_qubit_preprocessing.py b/tests/devices/default_qubit/test_default_qubit_preprocessing.py
index 40f45a9383e..59a9098f7ce 100644
--- a/tests/devices/default_qubit/test_default_qubit_preprocessing.py
+++ b/tests/devices/default_qubit/test_default_qubit_preprocessing.py
@@ -141,6 +141,32 @@ def circuit(x):
 
         assert dev.tracker.totals["execute_and_derivative_batches"] == 1
 
+    @pytest.mark.parametrize("interface", ("jax", "jax-jit"))
+    def test_not_convert_to_numpy_with_jax(self, interface):
+        """Test that we will not convert to numpy when working with jax."""
+
+        dev = qml.device("default.qubit")
+        config = qml.devices.ExecutionConfig(
+            gradient_method=qml.gradients.param_shift, interface=interface
+        )
+        processed = dev.setup_execution_config(config)
+        assert not processed.convert_to_numpy
+
+    def test_convert_to_numpy_with_adjoint(self):
+        """Test that we will convert to numpy with adjoint."""
+        config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface="jax-jit")
+        dev = qml.device("default.qubit")
+        processed = dev.setup_execution_config(config)
+        assert processed.convert_to_numpy
+
+    @pytest.mark.parametrize("interface", ("autograd", "torch", "tf"))
+    def test_convert_to_numpy_non_jax(self, interface):
+        """Test that other interfaces are still converted to numpy."""
+        config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface=interface)
+        dev = qml.device("default.qubit")
+        processed = dev.setup_execution_config(config)
+        assert processed.convert_to_numpy
+
 
 # pylint: disable=too-few-public-methods
 class TestPreprocessing:
diff --git a/tests/gradients/core/test_pulse_gradient.py b/tests/gradients/core/test_pulse_gradient.py
index bfd3e244fcb..95fa71d706c 100644
--- a/tests/gradients/core/test_pulse_gradient.py
+++ b/tests/gradients/core/test_pulse_gradient.py
@@ -1489,7 +1489,6 @@ def circuit(params):
             assert qml.math.allclose(j[0], e, atol=tol, rtol=0.0)
         jax.clear_caches()
 
-    @pytest.mark.xfail
     @pytest.mark.parametrize("num_split_times", [1, 2])
     @pytest.mark.parametrize("time_interface", ["python", "numpy", "jax"])
     def test_simple_qnode_jit(self, num_split_times, time_interface):
diff --git a/tests/param_shift_dev.py b/tests/param_shift_dev.py
index 12fd11eea16..7c7442161e2 100644
--- a/tests/param_shift_dev.py
+++ b/tests/param_shift_dev.py
@@ -34,6 +34,7 @@ def preprocess(self, execution_config=qml.devices.DefaultExecutionConfig):
             execution_config, use_device_jacobian_product=True
         )
         program, config = super().preprocess(execution_config)
+        config = dataclasses.replace(config, convert_to_numpy=True)
         program.add_transform(qml.transform(qml.gradients.param_shift.expand_transform))
         return program, config
 
diff --git a/tests/workflow/interfaces/test_jacobian_products.py b/tests/workflow/interfaces/test_jacobian_products.py
index 1991ba2d523..4d90f2e5012 100644
--- a/tests/workflow/interfaces/test_jacobian_products.py
+++ b/tests/workflow/interfaces/test_jacobian_products.py
@@ -136,7 +136,7 @@ def test_device_jacobians_repr(self):
            r" use_device_jacobian_product=None,"
            r" gradient_method='adjoint', gradient_keyword_arguments={},"
            r" device_options={}, interface=, derivative_order=1,"
-           r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>"
+           r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None), convert_to_numpy=True)>"
        )
 
        assert repr(jpc) == expected
@@ -155,7 +155,7 @@ def test_device_jacobian_products_repr(self):
            r" use_device_jacobian_product=None,"
            r" gradient_method='adjoint', gradient_keyword_arguments={}, device_options={},"
            r" interface=, derivative_order=1,"
-           r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>"
+           r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None), convert_to_numpy=True)>"
        )
 
        assert repr(jpc) == expected
diff --git a/tests/workflow/test_setup_transform_program.py b/tests/workflow/test_setup_transform_program.py
index c81ed75ae2f..79207491014 100644
--- a/tests/workflow/test_setup_transform_program.py
+++ b/tests/workflow/test_setup_transform_program.py
@@ -140,9 +140,7 @@ def test_prune_dynamic_transform_warning_raised():
 
 def test_interface_data_not_supported():
     """Test that convert_to_numpy_parameters transform is correctly added."""
-    config = ExecutionConfig()
-    config.interface = "autograd"
-    config.gradient_method = "adjoint"
+    config = ExecutionConfig(interface="autograd", gradient_method="adjoint")
 
     device = qml.device("default.qubit")
     user_transform_program = TransformProgram()
@@ -154,10 +152,8 @@ def test_interface_data_supported():
     """Test that convert_to_numpy_parameters transform is not added for these cases."""
-    config = ExecutionConfig()
+    config = ExecutionConfig(interface="autograd", gradient_method=None)
 
-    config.interface = "autograd"
-    config.gradient_method = None
     device = qml.device("default.mixed", wires=1)
     user_transform_program = TransformProgram()
@@ -165,10 +161,8 @@
     assert qml.transforms.convert_to_numpy_parameters not in inner_tp
 
-    config = ExecutionConfig()
+    config = ExecutionConfig(interface="autograd", gradient_method="backprop")
 
-    config.interface = "autograd"
-    config.gradient_method = "backprop"
     device = qml.device("default.qubit")
     user_transform_program = TransformProgram()
@@ -176,10 +170,8 @@
     assert qml.transforms.convert_to_numpy_parameters not in inner_tp
 
-    config = ExecutionConfig()
+    config = ExecutionConfig(interface=None, gradient_method="backprop")
 
-    config.interface = None
-    config.gradient_method = "backprop"
     device = qml.device("default.qubit")
     user_transform_program = TransformProgram()
@@ -187,6 +179,13 @@
     assert qml.transforms.convert_to_numpy_parameters not in inner_tp
 
+    config = ExecutionConfig(
+        convert_to_numpy=False, interface="jax", gradient_method=qml.gradients.param_shift
+    )
+
+    _, inner_tp = _setup_transform_program(TransformProgram(), device, config)
+    assert qml.transforms.convert_to_numpy_parameters not in inner_tp
+
 
 def test_cache_handling():
     """Test that caching is handled correctly."""
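
A minimal usage sketch of the behaviour the changelog entry above describes, assuming a
jax-enabled environment; the shot count, seed, gate choice, and parameter value are
arbitrary illustrative assumptions and are not taken from the patch itself:

    import jax
    import jax.numpy as jnp
    import pennylane as qml

    # Finite-shot device; passing a jax PRNG key as the seed keeps sampling traceable under jit.
    dev = qml.device("default.qubit", shots=1000, seed=jax.random.PRNGKey(0))

    @qml.qnode(dev, interface="jax-jit", diff_method="parameter-shift")
    def circuit(x):
        qml.RX(x, wires=0)
        return qml.expval(qml.PauliZ(0))

    x = jnp.array(0.5)

    # With convert_to_numpy disabled for the jax interfaces, the device execution is traced
    # directly, so no jax.pure_callback appears in the jaxpr.
    assert "pure_callback" not in str(jax.make_jaxpr(circuit)(x))

    # The QNode and its parameter-shift gradient can be jitted end-to-end.
    value = jax.jit(circuit)(x)
    grad = jax.jit(jax.grad(circuit))(x)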