Skip to content

Commit

Permalink
testing and changelog
Browse files Browse the repository at this point in the history
  • Loading branch information
albi3ro committed Jan 8, 2025
1 parent cbe1f20 commit cef0e31
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 22 deletions.
5 changes: 5 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@

<h3>Improvements 🛠</h3>

* Devices can now configure whether or not the data is converted to numpy and `jax.pure_callback`
is used by the new `ExecutionConfig.convert_to_numpy` property. Finite shot executions
on `default.qubit` can now be jitted end-to-end, even with parameter shift.
[(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788)

<h3>Breaking changes 💔</h3>

<h3>Deprecations 👋</h3>
Expand Down
2 changes: 1 addition & 1 deletion pennylane/devices/execution_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class ExecutionConfig:

convert_to_numpy: bool = True
"""Whether or not to convert parameters to numpy before execution.
If ``False`` and using the jax-jit, no pure callback will occur and the device
execution itself will be jitted.
"""
Expand Down
4 changes: 2 additions & 2 deletions pennylane/devices/qubit/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,6 @@ def _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key=None,
_, key = jax_random_split(prng_key)
samples = jax.random.choice(key, basis_states, shape=(shots,), p=probs)

powers_of_two = 1 << jnp.arange(num_wires, dtype=jnp.int64)[::-1]
powers_of_two = 1 << jnp.arange(num_wires, dtype=int)[::-1]
states_sampled_base_ten = samples[..., None] & powers_of_two
return (states_sampled_base_ten > 0).astype(jnp.int64)
return (states_sampled_base_ten > 0).astype(int)
6 changes: 1 addition & 5 deletions tests/devices/default_qubit/test_default_qubit_native_mcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,11 +389,7 @@ def func(x, y, z):
results1 = func1(*params)

jaxpr = str(jax.make_jaxpr(func)(*params))
if diff_method == "best":
assert "pure_callback" in jaxpr
pytest.xfail("QNode with diff_method='best' cannot be compiled with jax.jit.")
else:
assert "pure_callback" not in jaxpr
assert "pure_callback" not in jaxpr

func2 = jax.jit(func)
results2 = func2(*params)
Expand Down
26 changes: 26 additions & 0 deletions tests/devices/default_qubit/test_default_qubit_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,32 @@ def circuit(x):

assert dev.tracker.totals["execute_and_derivative_batches"] == 1

@pytest.mark.parametrize("interface", ("jax", "jax-jit"))
def test_not_convert_to_numpy_with_jax(self, interface):
"""Test that we will not convert to numpy when working with jax."""

dev = qml.device("default.qubit")
config = qml.devices.ExecutionConfig(
gradient_method=qml.gradients.param_shift, interface=interface
)
processed = dev.setup_execution_config(config)
assert not processed.convert_to_numpy

def test_convert_to_numpy_with_adjoint(self):
"""Test that we will convert to numpy with adjoint."""
config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface="jax-jit")
dev = qml.device("default.qubit")
processed = dev.setup_execution_config(config)
assert processed.convert_to_numpy

@pytest.mark.parametrize("interface", ("autograd", "torch", "tf"))
def test_convert_to_numpy_non_jax(self, interface):
"""Test that other interfaces are still converted to numpy."""
config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface=interface)
dev = qml.device("default.qubit")
processed = dev.setup_execution_config(config)
assert processed.convert_to_numpy


# pylint: disable=too-few-public-methods
class TestPreprocessing:
Expand Down
4 changes: 2 additions & 2 deletions tests/workflow/interfaces/test_jacobian_products.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def test_device_jacobians_repr(self):
r" use_device_jacobian_product=None,"
r" gradient_method='adjoint', gradient_keyword_arguments={},"
r" device_options={}, interface=<Interface.NUMPY: 'numpy'>, derivative_order=1,"
r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>"
r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None), convert_to_numpy=True)>"
)

assert repr(jpc) == expected
Expand All @@ -155,7 +155,7 @@ def test_device_jacobian_products_repr(self):
r" use_device_jacobian_product=None,"
r" gradient_method='adjoint', gradient_keyword_arguments={}, device_options={},"
r" interface=<Interface.NUMPY: 'numpy'>, derivative_order=1,"
r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>"
r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None), convert_to_numpy=True)>"
)

assert repr(jpc) == expected
Expand Down
23 changes: 11 additions & 12 deletions tests/workflow/test_setup_transform_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,7 @@ def test_prune_dynamic_transform_warning_raised():

def test_interface_data_not_supported():
"""Test that convert_to_numpy_parameters transform is correctly added."""
config = ExecutionConfig()
config.interface = "autograd"
config.gradient_method = "adjoint"
config = ExecutionConfig(interface="autograd", gradient_method="adjoint")
device = qml.device("default.qubit")

user_transform_program = TransformProgram()
Expand All @@ -154,39 +152,40 @@ def test_interface_data_not_supported():

def test_interface_data_supported():
"""Test that convert_to_numpy_parameters transform is not added for these cases."""
config = ExecutionConfig()
config = ExecutionConfig(interface="autograd", gradient_method=None)

config.interface = "autograd"
config.gradient_method = None
device = qml.device("default.mixed", wires=1)

user_transform_program = TransformProgram()
_, inner_tp = _setup_transform_program(user_transform_program, device, config)

assert qml.transforms.convert_to_numpy_parameters not in inner_tp

config = ExecutionConfig()
config = ExecutionConfig(interface="autograd", gradient_method="backprop")

config.interface = "autograd"
config.gradient_method = "backprop"
device = qml.device("default.qubit")

user_transform_program = TransformProgram()
_, inner_tp = _setup_transform_program(user_transform_program, device, config)

assert qml.transforms.convert_to_numpy_parameters not in inner_tp

config = ExecutionConfig()
config = ExecutionConfig(interface=None, gradient_method="backprop")

config.interface = None
config.gradient_method = "backprop"
device = qml.device("default.qubit")

user_transform_program = TransformProgram()
_, inner_tp = _setup_transform_program(user_transform_program, device, config)

assert qml.transforms.convert_to_numpy_parameters not in inner_tp

config = ExecutionConfig(
convert_to_numpy=False, interface="jax", gradient_method=qml.gradients.param_shift
)

_, inner_tp = _setup_transform_program(TransformProgram(), device, config)
assert qml.transforms.convert_to_numpy_parameters not in inner_tp


def test_cache_handling():
"""Test that caching is handled correctly."""
Expand Down

0 comments on commit cef0e31

Please sign in to comment.