Skip to content

Commit

Permalink
Added type hints to QNode public functions (#6084)
Browse files Browse the repository at this point in the history
**Description of the Change:**
- Added `Literal` type `SUPPORTED_INTERFACE_USER_INPUTS` to
`execution.py` and use that to populate the `INTERFACE_MAP` through the
`get_args` method.
- Added `SUPPORTED_DIFF_METHODS` `Literal` type in the `QNode` file.
- Added type hints for the rest of the arguments.

[[sc-66109](https://app.shortcut.com/xanaduai/story/66109)]

---------

Co-authored-by: Mudit Pandey <[email protected]>
Co-authored-by: Christina Lee <[email protected]>
  • Loading branch information
3 people authored Aug 13, 2024
1 parent 9af4c37 commit afd8433
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 53 deletions.
55 changes: 37 additions & 18 deletions pennylane/workflow/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import warnings
from collections.abc import Callable, MutableMapping, Sequence
from functools import partial
from typing import Optional, Union
from typing import Literal, Optional, Union, get_args

from cachetools import Cache, LRUCache

Expand Down Expand Up @@ -61,23 +61,42 @@
"tensorflow",
}

INTERFACE_MAP = {
None: "Numpy",
"auto": "auto",
"autograd": "autograd",
"numpy": "autograd",
"scipy": "numpy",
"jax": "jax",
"jax-jit": "jax",
"jax-python": "jax",
"JAX": "jax",
"torch": "torch",
"pytorch": "torch",
"tf": "tf",
"tensorflow": "tf",
"tensorflow-autograph": "tf",
"tf-autograph": "tf",
}
SupportedInterfaceUserInput = Literal[
None,
"auto",
"autograd",
"numpy",
"scipy",
"jax",
"jax-jit",
"jax-python",
"JAX",
"torch",
"pytorch",
"tf",
"tensorflow",
"tensorflow-autograph",
"tf-autograph",
]

_mapping_output = (
"Numpy",
"auto",
"autograd",
"autograd",
"numpy",
"jax",
"jax",
"jax",
"jax",
"torch",
"torch",
"tf",
"tf",
"tf",
"tf",
)
INTERFACE_MAP = dict(zip(get_args(SupportedInterfaceUserInput), _mapping_output))
"""dict[str, str]: maps an allowed interface specification to its canonical name."""

#: list[str]: allowed interface strings
Expand Down
84 changes: 54 additions & 30 deletions pennylane/workflow/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,34 @@
import inspect
import logging
import warnings
from collections.abc import Sequence
from typing import Optional, Union
from collections.abc import Callable, MutableMapping, Sequence
from typing import Any, Literal, Optional, Union, get_args

import pennylane as qml
from pennylane import Device
from pennylane.debugging import pldb_device_manager
from pennylane.logging import debug_logger
from pennylane.measurements import CountsMP, MidMeasureMP, Shots
from pennylane.tape import QuantumScript, QuantumTape
from pennylane.transforms.core import TransformContainer, TransformDispatcher, TransformProgram

from .execution import INTERFACE_MAP, SUPPORTED_INTERFACES
from .execution import INTERFACE_MAP, SUPPORTED_INTERFACES, SupportedInterfaceUserInput

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

SupportedDiffMethods = Literal[
None,
"best",
"device",
"backprop",
"adjoint",
"parameter-shift",
"hadamard",
"finite-diff",
"spsa",
]


def _convert_to_interface(res, interface):
"""
Expand Down Expand Up @@ -119,7 +132,7 @@ class QNode:
device (~.Device): a PennyLane-compatible device
interface (str): The interface that will be used for classical backpropagation.
This affects the types of objects that can be passed to/returned from the QNode. See
``qml.workflow.SUPPORTED_INTERFACES`` for a list of all accepted strings.
``qml.workflow.SUPPORTED_INTERFACE_USER_INPUT`` for a list of all accepted strings.
* ``"autograd"``: Allows autograd to backpropagate
through the QNode. The QNode accepts default Python types
Expand Down Expand Up @@ -451,19 +464,19 @@ def circuit_unpacking(x):

def __init__(
self,
func,
func: Callable,
device: Union[Device, "qml.devices.Device"],
interface="auto",
diff_method="best",
expansion_strategy=None,
max_expansion=None,
grad_on_execution="best",
cache="auto",
cachesize=10000,
max_diff=1,
device_vjp=False,
postselect_mode=None,
mcm_method=None,
interface: SupportedInterfaceUserInput = "auto",
diff_method: Union[TransformDispatcher, SupportedDiffMethods] = "best",
expansion_strategy: Literal[None, "device", "gradient"] = None,
max_expansion: Optional[int] = None,
grad_on_execution: Literal[True, False, "best"] = "best",
cache: Union[MutableMapping, Literal["auto", True, False]] = "auto",
cachesize: int = 10000,
max_diff: int = 1,
device_vjp: Union[None, bool] = False,
postselect_mode: Literal[None, "hw-like", "fill-shots"] = None,
mcm_method: Literal[None, "deferred", "one-shot", "tree-traversal"] = None,
**gradient_kwargs,
):
# Moving it here since the old default value is checked on debugging
Expand Down Expand Up @@ -580,11 +593,11 @@ def __init__(
self.gradient_kwargs = {}
self._tape_cached = False

self._transform_program = qml.transforms.core.TransformProgram()
self._transform_program = TransformProgram()
self._update_gradient_fn()
functools.update_wrapper(self, func)

def __copy__(self):
def __copy__(self) -> "QNode":
copied_qnode = QNode.__new__(QNode)
for attr, value in vars(self).items():
if attr not in {"execute_kwargs", "_transform_program", "gradient_kwargs"}:
Expand All @@ -597,7 +610,7 @@ def __copy__(self):
copied_qnode.gradient_kwargs = dict(self.gradient_kwargs)
return copied_qnode

def __repr__(self):
def __repr__(self) -> str:
"""String representation."""
if isinstance(self.device, qml.devices.Device):
return f"<QNode: device='{self.device}', interface='{self.interface}', diff_method='{self.diff_method}'>"
Expand All @@ -611,13 +624,14 @@ def __repr__(self):
)

@property
def interface(self):
def interface(self) -> str:
"""The interface used by the QNode"""
return self._interface

@interface.setter
def interface(self, value):
def interface(self, value: SupportedInterfaceUserInput):
if value not in SUPPORTED_INTERFACES:

raise qml.QuantumFunctionError(
f"Unknown interface {value}. Interface must be one of {SUPPORTED_INTERFACES}."
)
Expand All @@ -626,19 +640,19 @@ def interface(self, value):
self._update_gradient_fn(shots=self.device.shots)

@property
def transform_program(self):
def transform_program(self) -> TransformProgram:
"""The transform program used by the QNode."""
return self._transform_program

@debug_logger
def add_transform(self, transform_container):
def add_transform(self, transform_container: TransformContainer):
"""Add a transform (container) to the transform program.
.. warning:: This is a developer facing feature and is called when a transform is applied on a QNode.
"""
self._transform_program.push_back(transform_container=transform_container)

def _update_gradient_fn(self, shots=None, tape=None):
def _update_gradient_fn(self, shots=None, tape: Optional["qml.tape.QuantumTape"] = None):
if self.diff_method is None:
self._interface = None
self.gradient_fn = None
Expand Down Expand Up @@ -681,7 +695,10 @@ def _update_original_device(self):
@staticmethod
@debug_logger
def get_gradient_fn(
device, interface, diff_method="best", tape: Optional["qml.tape.QuantumTape"] = None
device: Union[Device, "qml.devices.Device"],
interface,
diff_method: Union[TransformDispatcher, SupportedDiffMethods] = "best",
tape: Optional["qml.tape.QuantumTape"] = None,
):
"""Determine the best differentiation method, interface, and device
for a requested device, interface, and diff method.
Expand Down Expand Up @@ -739,8 +756,7 @@ def get_gradient_fn(
if isinstance(diff_method, str):
raise qml.QuantumFunctionError(
f"Differentiation method {diff_method} not recognized. Allowed "
"options are ('best', 'parameter-shift', 'backprop', 'finite-diff', "
"'device', 'adjoint', 'spsa', 'hadamard')."
f"options are {tuple(get_args(SupportedDiffMethods))}."
)

if isinstance(diff_method, qml.transforms.core.TransformDispatcher):
Expand All @@ -752,7 +768,15 @@ def get_gradient_fn(

@staticmethod
@debug_logger
def get_best_method(device, interface, tape=None):
def get_best_method(
device: Union[Device, "qml.devices.Device"],
interface,
tape: Optional["qml.tape.QuantumTape"] = None,
) -> tuple[
Union[TransformDispatcher, Literal["device", "backprop", "parameter-shift", "finite-diff"]],
dict[str, Any],
Union[Device, "qml.devices.Device"],
]:
"""Returns the 'best' differentiation method
for a particular device and interface combination.
Expand Down Expand Up @@ -799,7 +823,7 @@ def get_best_method(device, interface, tape=None):

@staticmethod
@debug_logger
def best_method_str(device, interface):
def best_method_str(device: Union[Device, "qml.devices.Device"], interface) -> str:
"""Similar to :meth:`~.get_best_method`, except return the
'best' differentiation method in human-readable format.
Expand Down Expand Up @@ -838,7 +862,7 @@ def best_method_str(device, interface):

@staticmethod
@debug_logger
def _validate_backprop_method(device, interface, tape=None):
def _validate_backprop_method(device, interface, tape: Optional["qml.tape.QuantumTape"] = None):
if isinstance(device, qml.devices.Device):
raise ValueError(
"QNode._validate_backprop_method only applies to the qml.Device interface."
Expand Down
8 changes: 3 additions & 5 deletions tests/devices/test_default_qubit_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2234,12 +2234,10 @@ def circuit(x, w=None):
qml.RZ(x, wires=w)
return qml.expval(qml.PauliX(w))

with pytest.raises(Exception) as e:
with pytest.raises(
qml.QuantumFunctionError, match="Differentiation method autograd not recognized"
):
assert qml.qnode(dev, diff_method="autograd", interface=interface)(circuit)
assert str(e.value) == (
"Differentiation method autograd not recognized. Allowed options are ('best', "
"'parameter-shift', 'backprop', 'finite-diff', 'device', 'adjoint', 'spsa', 'hadamard')."
)


@pytest.mark.torch
Expand Down

0 comments on commit afd8433

Please sign in to comment.