diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index dc4a79e20da..ffff841ad73 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -26,7 +26,7 @@ import warnings from collections.abc import Callable, MutableMapping, Sequence from functools import partial -from typing import Optional, Union +from typing import Literal, Optional, Union, get_args from cachetools import Cache, LRUCache @@ -61,23 +61,42 @@ "tensorflow", } -INTERFACE_MAP = { - None: "Numpy", - "auto": "auto", - "autograd": "autograd", - "numpy": "autograd", - "scipy": "numpy", - "jax": "jax", - "jax-jit": "jax", - "jax-python": "jax", - "JAX": "jax", - "torch": "torch", - "pytorch": "torch", - "tf": "tf", - "tensorflow": "tf", - "tensorflow-autograph": "tf", - "tf-autograph": "tf", -} +SupportedInterfaceUserInput = Literal[ + None, + "auto", + "autograd", + "numpy", + "scipy", + "jax", + "jax-jit", + "jax-python", + "JAX", + "torch", + "pytorch", + "tf", + "tensorflow", + "tensorflow-autograph", + "tf-autograph", +] + +_mapping_output = ( + "Numpy", + "auto", + "autograd", + "autograd", + "numpy", + "jax", + "jax", + "jax", + "jax", + "torch", + "torch", + "tf", + "tf", + "tf", + "tf", +) +INTERFACE_MAP = dict(zip(get_args(SupportedInterfaceUserInput), _mapping_output)) """dict[str, str]: maps an allowed interface specification to its canonical name.""" #: list[str]: allowed interface strings diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 91432dafc54..92b1818cff4 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -20,8 +20,8 @@ import inspect import logging import warnings -from collections.abc import Sequence -from typing import Optional, Union +from collections.abc import Callable, MutableMapping, Sequence +from typing import Any, Literal, Optional, Union, get_args import pennylane as qml from pennylane import Device @@ -29,12 +29,25 @@ from pennylane.logging import debug_logger from pennylane.measurements import CountsMP, MidMeasureMP, Shots from pennylane.tape import QuantumScript, QuantumTape +from pennylane.transforms.core import TransformContainer, TransformDispatcher, TransformProgram -from .execution import INTERFACE_MAP, SUPPORTED_INTERFACES +from .execution import INTERFACE_MAP, SUPPORTED_INTERFACES, SupportedInterfaceUserInput logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) +SupportedDiffMethods = Literal[ + None, + "best", + "device", + "backprop", + "adjoint", + "parameter-shift", + "hadamard", + "finite-diff", + "spsa", +] + def _convert_to_interface(res, interface): """ @@ -119,7 +132,7 @@ class QNode: device (~.Device): a PennyLane-compatible device interface (str): The interface that will be used for classical backpropagation. This affects the types of objects that can be passed to/returned from the QNode. See - ``qml.workflow.SUPPORTED_INTERFACES`` for a list of all accepted strings. + ``qml.workflow.SUPPORTED_INTERFACE_USER_INPUT`` for a list of all accepted strings. * ``"autograd"``: Allows autograd to backpropagate through the QNode. The QNode accepts default Python types @@ -451,19 +464,19 @@ def circuit_unpacking(x): def __init__( self, - func, + func: Callable, device: Union[Device, "qml.devices.Device"], - interface="auto", - diff_method="best", - expansion_strategy=None, - max_expansion=None, - grad_on_execution="best", - cache="auto", - cachesize=10000, - max_diff=1, - device_vjp=False, - postselect_mode=None, - mcm_method=None, + interface: SupportedInterfaceUserInput = "auto", + diff_method: Union[TransformDispatcher, SupportedDiffMethods] = "best", + expansion_strategy: Literal[None, "device", "gradient"] = None, + max_expansion: Optional[int] = None, + grad_on_execution: Literal[True, False, "best"] = "best", + cache: Union[MutableMapping, Literal["auto", True, False]] = "auto", + cachesize: int = 10000, + max_diff: int = 1, + device_vjp: Union[None, bool] = False, + postselect_mode: Literal[None, "hw-like", "fill-shots"] = None, + mcm_method: Literal[None, "deferred", "one-shot", "tree-traversal"] = None, **gradient_kwargs, ): # Moving it here since the old default value is checked on debugging @@ -580,11 +593,11 @@ def __init__( self.gradient_kwargs = {} self._tape_cached = False - self._transform_program = qml.transforms.core.TransformProgram() + self._transform_program = TransformProgram() self._update_gradient_fn() functools.update_wrapper(self, func) - def __copy__(self): + def __copy__(self) -> "QNode": copied_qnode = QNode.__new__(QNode) for attr, value in vars(self).items(): if attr not in {"execute_kwargs", "_transform_program", "gradient_kwargs"}: @@ -597,7 +610,7 @@ def __copy__(self): copied_qnode.gradient_kwargs = dict(self.gradient_kwargs) return copied_qnode - def __repr__(self): + def __repr__(self) -> str: """String representation.""" if isinstance(self.device, qml.devices.Device): return f"" @@ -611,13 +624,14 @@ def __repr__(self): ) @property - def interface(self): + def interface(self) -> str: """The interface used by the QNode""" return self._interface @interface.setter - def interface(self, value): + def interface(self, value: SupportedInterfaceUserInput): if value not in SUPPORTED_INTERFACES: + raise qml.QuantumFunctionError( f"Unknown interface {value}. Interface must be one of {SUPPORTED_INTERFACES}." ) @@ -626,19 +640,19 @@ def interface(self, value): self._update_gradient_fn(shots=self.device.shots) @property - def transform_program(self): + def transform_program(self) -> TransformProgram: """The transform program used by the QNode.""" return self._transform_program @debug_logger - def add_transform(self, transform_container): + def add_transform(self, transform_container: TransformContainer): """Add a transform (container) to the transform program. .. warning:: This is a developer facing feature and is called when a transform is applied on a QNode. """ self._transform_program.push_back(transform_container=transform_container) - def _update_gradient_fn(self, shots=None, tape=None): + def _update_gradient_fn(self, shots=None, tape: Optional["qml.tape.QuantumTape"] = None): if self.diff_method is None: self._interface = None self.gradient_fn = None @@ -681,7 +695,10 @@ def _update_original_device(self): @staticmethod @debug_logger def get_gradient_fn( - device, interface, diff_method="best", tape: Optional["qml.tape.QuantumTape"] = None + device: Union[Device, "qml.devices.Device"], + interface, + diff_method: Union[TransformDispatcher, SupportedDiffMethods] = "best", + tape: Optional["qml.tape.QuantumTape"] = None, ): """Determine the best differentiation method, interface, and device for a requested device, interface, and diff method. @@ -739,8 +756,7 @@ def get_gradient_fn( if isinstance(diff_method, str): raise qml.QuantumFunctionError( f"Differentiation method {diff_method} not recognized. Allowed " - "options are ('best', 'parameter-shift', 'backprop', 'finite-diff', " - "'device', 'adjoint', 'spsa', 'hadamard')." + f"options are {tuple(get_args(SupportedDiffMethods))}." ) if isinstance(diff_method, qml.transforms.core.TransformDispatcher): @@ -752,7 +768,15 @@ def get_gradient_fn( @staticmethod @debug_logger - def get_best_method(device, interface, tape=None): + def get_best_method( + device: Union[Device, "qml.devices.Device"], + interface, + tape: Optional["qml.tape.QuantumTape"] = None, + ) -> tuple[ + Union[TransformDispatcher, Literal["device", "backprop", "parameter-shift", "finite-diff"]], + dict[str, Any], + Union[Device, "qml.devices.Device"], + ]: """Returns the 'best' differentiation method for a particular device and interface combination. @@ -799,7 +823,7 @@ def get_best_method(device, interface, tape=None): @staticmethod @debug_logger - def best_method_str(device, interface): + def best_method_str(device: Union[Device, "qml.devices.Device"], interface) -> str: """Similar to :meth:`~.get_best_method`, except return the 'best' differentiation method in human-readable format. @@ -838,7 +862,7 @@ def best_method_str(device, interface): @staticmethod @debug_logger - def _validate_backprop_method(device, interface, tape=None): + def _validate_backprop_method(device, interface, tape: Optional["qml.tape.QuantumTape"] = None): if isinstance(device, qml.devices.Device): raise ValueError( "QNode._validate_backprop_method only applies to the qml.Device interface." diff --git a/tests/devices/test_default_qubit_torch.py b/tests/devices/test_default_qubit_torch.py index 3a132070038..b423da6eaea 100644 --- a/tests/devices/test_default_qubit_torch.py +++ b/tests/devices/test_default_qubit_torch.py @@ -2234,12 +2234,10 @@ def circuit(x, w=None): qml.RZ(x, wires=w) return qml.expval(qml.PauliX(w)) - with pytest.raises(Exception) as e: + with pytest.raises( + qml.QuantumFunctionError, match="Differentiation method autograd not recognized" + ): assert qml.qnode(dev, diff_method="autograd", interface=interface)(circuit) - assert str(e.value) == ( - "Differentiation method autograd not recognized. Allowed options are ('best', " - "'parameter-shift', 'backprop', 'finite-diff', 'device', 'adjoint', 'spsa', 'hadamard')." - ) @pytest.mark.torch