From f0adb760ce70fe24be03fa58a61506402d5e38d1 Mon Sep 17 00:00:00 2001 From: Gabriel Matos Date: Thu, 16 Nov 2023 11:25:19 +0000 Subject: [PATCH] Improvements and fixes to type annotations (#114) * Update jnp.ndarray to jax.Array types * Specify jax>=0.4.1 as package requirement * Add missing `Optional` qualifier on some type annotations * Update `test_gates.py` to reflect change in array type * Refactor types into `typing.py` * Introduce several new type aliases * Fix several type warnings raised by `mypy` and `pylance` * Add `typing_extensions` to package requirements --- docs/conf.py | 1 - qujax/__init__.py | 3 +- qujax/densitytensor.py | 57 ++++++++------ qujax/densitytensor_observable.py | 34 +++++---- qujax/gates.py | 37 ++++----- qujax/statetensor.py | 50 ++++++++----- qujax/statetensor_observable.py | 41 +++++----- qujax/typing.py | 45 +++++++++++ qujax/utils.py | 120 +++++++++++++++++------------- setup.py | 2 +- tests/test_densitytensor.py | 2 +- tests/test_gates.py | 3 +- 12 files changed, 240 insertions(+), 155 deletions(-) create mode 100644 qujax/typing.py diff --git a/docs/conf.py b/docs/conf.py index 9dd0111..2798d50 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -40,7 +40,6 @@ autodoc_typehints = "description" autodoc_type_aliases = { - "jnp.ndarray": "ndarray", "random.PRNGKeyArray": "jax.random.PRNGKeyArray", "UnionCallableOptionalArray": "Union[Callable[[ndarray, Optional[ndarray]], ndarray], " "Callable[[Optional[ndarray]], ndarray]]", diff --git a/qujax/__init__.py b/qujax/__init__.py index 1b565b4..5d75a16 100644 --- a/qujax/__init__.py +++ b/qujax/__init__.py @@ -28,7 +28,6 @@ from qujax.densitytensor_observable import densitytensor_to_measurement_probabilities from qujax.densitytensor_observable import densitytensor_to_measured_densitytensor -from qujax.utils import UnionCallableOptionalArray from qujax.utils import check_unitary from qujax.utils import check_hermitian from qujax.utils import check_circuit @@ -40,6 +39,8 @@ from qujax.utils import sample_bitstrings from qujax.utils import statetensor_to_densitytensor +import qujax.typing + # pylint: disable=undefined-variable del version del statetensor diff --git a/qujax/densitytensor.py b/qujax/densitytensor.py index 80f7522..ef35102 100644 --- a/qujax/densitytensor.py +++ b/qujax/densitytensor.py @@ -1,25 +1,31 @@ from __future__ import annotations -from typing import Callable, Iterable, Sequence, Tuple, Union +from typing import Iterable, Sequence, Tuple, Union, Optional import jax from jax import numpy as jnp +from jax.typing import ArrayLike from jax.lax import scan from jax._src.dtypes import canonicalize_dtype from jax._src.typing import DTypeLike from qujax.statetensor import ( - UnionCallableOptionalArray, _arrayify_inds, _gate_func_to_unitary, _to_gate_func, apply_gate, ) -from qujax.utils import KrausOp, check_circuit +from qujax.utils import check_circuit +from qujax.typing import ( + MixedCircuitFunction, + KrausOp, + GateFunction, + GateParameterIndices, +) def _kraus_single( - densitytensor: jnp.ndarray, array: jnp.ndarray, qubit_inds: Sequence[int] -) -> jnp.ndarray: + densitytensor: jax.Array, array: jax.Array, qubit_inds: Sequence[int] +) -> jax.Array: r""" Performs single Kraus operation @@ -44,8 +50,8 @@ def _kraus_single( def kraus( - densitytensor: jnp.ndarray, arrays: Iterable[jnp.ndarray], qubit_inds: Sequence[int] -) -> jnp.ndarray: + densitytensor: jax.Array, arrays: Iterable[jax.Array], qubit_inds: Sequence[int] +) -> jax.Array: r""" Performs Kraus operation. @@ -78,8 +84,9 @@ def kraus( def _to_kraus_operator_seq_funcs( - kraus_op: KrausOp, param_inds: Union[None, Sequence[int], Sequence[Sequence[int]]] -) -> Tuple[Sequence[Callable[[jnp.ndarray], jnp.ndarray]], Sequence[jnp.ndarray]]: + kraus_op: KrausOp, + param_inds: Optional[Union[GateParameterIndices, Sequence[GateParameterIndices]]], +) -> Tuple[Sequence[GateFunction], Sequence[jax.Array]]: """ Ensures Kraus operators are a sequence of functions that map (possibly empty) parameters to tensors and that each element of param_inds_seq is a sequence of arrays that correspond to the @@ -96,20 +103,21 @@ def _to_kraus_operator_seq_funcs( and sequence of arrays with parameter indices """ - if param_inds is None: - param_inds = [None for _ in kraus_op] - if isinstance(kraus_op, (list, tuple)): kraus_op_funcs = [_to_gate_func(ko) for ko in kraus_op] - else: + if param_inds is None: + param_inds = [None for _ in kraus_op] + elif isinstance(kraus_op, (str, jax.Array)) or callable(kraus_op): kraus_op_funcs = [_to_gate_func(kraus_op)] param_inds = [param_inds] + else: + raise ValueError(f"Invalid Kraus operator specification: {kraus_op}") return kraus_op_funcs, _arrayify_inds(param_inds) def partial_trace( - densitytensor: jnp.ndarray, indices_to_trace: Sequence[int] -) -> jnp.ndarray: + densitytensor: jax.Array, indices_to_trace: Sequence[int] +) -> jax.Array: """ Traces out (discards) specified qubits, resulting in a densitytensor representing the mixed quantum state on the remaining qubits. @@ -149,9 +157,11 @@ def all_zeros_densitytensor(n_qubits: int, dtype: DTypeLike = complex) -> jax.Ar def get_params_to_densitytensor_func( kraus_ops_seq: Sequence[KrausOp], qubit_inds_seq: Sequence[Sequence[int]], - param_inds_seq: Sequence[Union[None, Sequence[int], Sequence[Sequence[int]]]], - n_qubits: int = None, -) -> UnionCallableOptionalArray: + param_inds_seq: Sequence[ + Union[GateParameterIndices, Sequence[GateParameterIndices]] + ], + n_qubits: Optional[int] = None, +) -> MixedCircuitFunction: """ Creates a function that maps circuit parameters to a density tensor (a density matrix in tensor form). @@ -195,8 +205,8 @@ def get_params_to_densitytensor_func( param_inds_array_seq = [ko_pi[1] for ko_pi in kraus_ops_seq_callable_and_param_inds] def params_to_densitytensor_func( - params: jnp.ndarray, densitytensor_in: jnp.ndarray = None - ) -> jnp.ndarray: + params: ArrayLike, densitytensor_in: Optional[jax.Array] = None + ) -> jax.Array: """ Applies parameterised circuit (series of gates) to a densitytensor_in (default is |0>^N <0|^N). @@ -216,6 +226,9 @@ def params_to_densitytensor_func( else: densitytensor = densitytensor_in params = jnp.atleast_1d(params) + # Guarantee `params` has the right type for type-checking purposes + if not isinstance(params, jax.Array): + raise ValueError("This should not happen. Please open an issue on GitHub.") for gate_func_single_seq, qubit_inds, param_inds_single_seq in zip( kraus_ops_seq_callable, qubit_inds_seq, param_inds_array_seq ): @@ -232,8 +245,8 @@ def params_to_densitytensor_func( if non_parameterised: def no_params_to_densitytensor_func( - densitytensor_in: jnp.ndarray = None, - ) -> jnp.ndarray: + densitytensor_in: Optional[jax.Array] = None, + ) -> jax.Array: """ Applies circuit (series of gates with no parameters) to a densitytensor_in (default is |0>^N <0|^N). diff --git a/qujax/densitytensor_observable.py b/qujax/densitytensor_observable.py index 9e0ca62..8219db0 100644 --- a/qujax/densitytensor_observable.py +++ b/qujax/densitytensor_observable.py @@ -2,6 +2,8 @@ from typing import Callable, Sequence, Union +import jax +from jax.typing import ArrayLike from jax import numpy as jnp from jax import random @@ -11,8 +13,8 @@ def densitytensor_to_single_expectation( - densitytensor: jnp.ndarray, hermitian: jnp.ndarray, qubit_inds: Sequence[int] -) -> float: + densitytensor: jax.Array, hermitian: jax.Array, qubit_inds: Sequence[int] +) -> jax.Array: """ Evaluates expectation value of an observable represented by a Hermitian matrix (in tensor form). @@ -35,10 +37,10 @@ def densitytensor_to_single_expectation( def get_densitytensor_to_expectation_func( - hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + hermitian_seq_seq: Sequence[Sequence[Union[str, jax.Array]]], qubits_seq_seq: Sequence[Sequence[int]], - coefficients: Union[Sequence[float], jnp.ndarray], -) -> Callable[[jnp.ndarray], float]: + coefficients: Union[Sequence[float], jax.Array], +) -> Callable[[jax.Array], float]: """ Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and a list of coefficients and returns a function that converts a densitytensor into an @@ -46,7 +48,7 @@ def get_densitytensor_to_expectation_func( Args: hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. - Each Hermitian matrix is either represented by a tensor (jnp.ndarray) + Each Hermitian matrix is either represented by a tensor (jax.Array) or by a list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices. E.g. [['Z', 'Z'], ['X']] qubits_seq_seq: Sequence of sequences of integer qubit indices. @@ -66,10 +68,10 @@ def get_densitytensor_to_expectation_func( def get_densitytensor_to_sampled_expectation_func( - hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + hermitian_seq_seq: Sequence[Sequence[Union[str, jax.Array]]], qubits_seq_seq: Sequence[Sequence[int]], - coefficients: Union[Sequence[float], jnp.ndarray], -) -> Callable[[jnp.ndarray, random.PRNGKeyArray, int], float]: + coefficients: Union[Sequence[float], jax.Array], +) -> Callable[[jax.Array, random.PRNGKeyArray, int], float]: """ Converts strings (or arrays) representing Hermitian matrices, qubit indices and coefficients into a function that converts a densitytensor into a sampled expected value. @@ -85,7 +87,7 @@ def get_densitytensor_to_sampled_expectation_func( Args: hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. - Each Hermitian is either a tensor (jnp.ndarray) or a string in ('X', 'Y', 'Z'). + Each Hermitian is either a tensor (jax.Array) or a string in ('X', 'Y', 'Z'). E.g. [['Z', 'Z'], ['X']] qubits_seq_seq: Sequence of sequences of integer qubit indices. E.g. [[0,1], [2]] @@ -104,7 +106,7 @@ def get_densitytensor_to_sampled_expectation_func( check_hermitian(h, check_z_commutes=True) def densitytensor_to_sampled_expectation_func( - densitytensor: jnp.ndarray, random_key: random.PRNGKeyArray, n_samps: int + densitytensor: jax.Array, random_key: random.PRNGKeyArray, n_samps: int ) -> float: """ Maps densitytensor to sampled expected value. @@ -131,8 +133,8 @@ def densitytensor_to_sampled_expectation_func( def densitytensor_to_measurement_probabilities( - densitytensor: jnp.ndarray, qubit_inds: Sequence[int] -) -> jnp.ndarray: + densitytensor: jax.Array, qubit_inds: Sequence[int] +) -> jax.Array: """ Extract array of measurement probabilities given a densitytensor and some qubit indices to measure (in the computational basis). @@ -157,10 +159,10 @@ def densitytensor_to_measurement_probabilities( def densitytensor_to_measured_densitytensor( - densitytensor: jnp.ndarray, + densitytensor: jax.Array, qubit_inds: Sequence[int], - measurement: Union[int, jnp.ndarray], -) -> jnp.ndarray: + measurement: ArrayLike, +) -> jax.Array: """ Returns the post-measurement densitytensor assuming that qubit_inds are measured (in the computational basis) and the given measurement (integer or bitstring) is observed. diff --git a/qujax/gates.py b/qujax/gates.py index ae88fcf..fea9754 100644 --- a/qujax/gates.py +++ b/qujax/gates.py @@ -1,3 +1,4 @@ +import jax from jax import numpy as jnp I = jnp.eye(2) @@ -64,42 +65,42 @@ ) -def Rx(param: float) -> jnp.ndarray: +def Rx(param: float) -> jax.Array: param_pi_2 = param * jnp.pi / 2 return jnp.cos(param_pi_2) * I - jnp.sin(param_pi_2) * X * 1.0j -def Ry(param: float) -> jnp.ndarray: +def Ry(param: float) -> jax.Array: param_pi_2 = param * jnp.pi / 2 return jnp.cos(param_pi_2) * I - jnp.sin(param_pi_2) * Y * 1.0j -def Rz(param: float) -> jnp.ndarray: +def Rz(param: float) -> jax.Array: param_pi_2 = param * jnp.pi / 2 return jnp.cos(param_pi_2) * I - jnp.sin(param_pi_2) * Z * 1.0j -def CRx(param: float) -> jnp.ndarray: +def CRx(param: float) -> jax.Array: return jnp.block([[I, _0], [_0, Rx(param)]]).reshape((2,) * 4) -def CRy(param: float) -> jnp.ndarray: +def CRy(param: float) -> jax.Array: return jnp.block([[I, _0], [_0, Ry(param)]]).reshape((2,) * 4) -def CRz(param: float) -> jnp.ndarray: +def CRz(param: float) -> jax.Array: return jnp.block([[I, _0], [_0, Rz(param)]]).reshape((2,) * 4) -def U1(param: float) -> jnp.ndarray: +def U1(param: float) -> jax.Array: return U3(0, 0, param) -def U2(param0: float, param1: float) -> jnp.ndarray: +def U2(param0: float, param1: float) -> jax.Array: return U3(0.5, param0, param1) -def U3(param0: float, param1: float, param2: float) -> jnp.ndarray: +def U3(param0: float, param1: float, param2: float) -> jax.Array: return ( jnp.exp((param1 + param2) * jnp.pi * 1.0j / 2) * Rz(param1) @@ -108,19 +109,19 @@ def U3(param0: float, param1: float, param2: float) -> jnp.ndarray: ) -def CU1(param: float) -> jnp.ndarray: +def CU1(param: float) -> jax.Array: return jnp.block([[I, _0], [_0, U1(param)]]).reshape((2,) * 4) -def CU2(param0: float, param1: float) -> jnp.ndarray: +def CU2(param0: float, param1: float) -> jax.Array: return jnp.block([[I, _0], [_0, U2(param0, param1)]]).reshape((2,) * 4) -def CU3(param0: float, param1: float, param2: float) -> jnp.ndarray: +def CU3(param0: float, param1: float, param2: float) -> jax.Array: return jnp.block([[I, _0], [_0, U3(param0, param1, param2)]]).reshape((2,) * 4) -def ISWAP(param: float) -> jnp.ndarray: +def ISWAP(param: float) -> jax.Array: param_pi_2 = param * jnp.pi / 2 c = jnp.cos(param_pi_2) i_s = 1.0j * jnp.sin(param_pi_2) @@ -134,7 +135,7 @@ def ISWAP(param: float) -> jnp.ndarray: ).reshape((2,) * 4) -def PhasedISWAP(param0: float, param1: float) -> jnp.ndarray: +def PhasedISWAP(param0: float, param1: float) -> jax.Array: param1_pi_2 = param1 * jnp.pi / 2 c = jnp.cos(param1_pi_2) i_s = 1.0j * jnp.sin(param1_pi_2) @@ -148,7 +149,7 @@ def PhasedISWAP(param0: float, param1: float) -> jnp.ndarray: ).reshape((2,) * 4) -def XXPhase(param: float) -> jnp.ndarray: +def XXPhase(param: float) -> jax.Array: param_pi_2 = param * jnp.pi / 2 c = jnp.cos(param_pi_2) i_s = 1.0j * jnp.sin(param_pi_2) @@ -162,7 +163,7 @@ def XXPhase(param: float) -> jnp.ndarray: ).reshape((2,) * 4) -def YYPhase(param: float) -> jnp.ndarray: +def YYPhase(param: float) -> jax.Array: param_pi_2 = param * jnp.pi / 2 c = jnp.cos(param_pi_2) i_s = 1.0j * jnp.sin(param_pi_2) @@ -176,7 +177,7 @@ def YYPhase(param: float) -> jnp.ndarray: ).reshape((2,) * 4) -def ZZPhase(param: float) -> jnp.ndarray: +def ZZPhase(param: float) -> jax.Array: param_pi_2 = param * jnp.pi / 2 e_m = jnp.exp(-1.0j * param_pi_2) e_p = jnp.exp(1.0j * param_pi_2) @@ -186,5 +187,5 @@ def ZZPhase(param: float) -> jnp.ndarray: ZZMax = ZZPhase(0.5) -def PhasedX(param0: float, param1: float) -> jnp.ndarray: +def PhasedX(param0: float, param1: float) -> jax.Array: return Rz(param1) @ Rx(param0) @ Rz(-param1) diff --git a/qujax/statetensor.py b/qujax/statetensor.py index 0796b9b..ee95299 100644 --- a/qujax/statetensor.py +++ b/qujax/statetensor.py @@ -1,20 +1,23 @@ from __future__ import annotations from functools import partial -from typing import Callable, Sequence, Union +from typing import Callable, Sequence, Optional import jax from jax import numpy as jnp +from jax.typing import ArrayLike from jax._src.dtypes import canonicalize_dtype from jax._src.typing import DTypeLike from qujax import gates -from qujax.utils import Gate, UnionCallableOptionalArray, _arrayify_inds, check_circuit +from qujax.utils import _arrayify_inds, check_circuit + +from qujax.typing import Gate, PureCircuitFunction, GateFunction, GateParameterIndices def apply_gate( - statetensor: jnp.ndarray, gate_unitary: jnp.ndarray, qubit_inds: Sequence[int] -) -> jnp.ndarray: + statetensor: jax.Array, gate_unitary: jax.Array, qubit_inds: Sequence[int] +) -> jax.Array: """ Applies gate to statetensor and returns updated statetensor. Gate is represented by a unitary matrix in tensor form. @@ -36,7 +39,9 @@ def apply_gate( return statetensor -def _to_gate_func(gate: Gate) -> Callable[[jnp.ndarray], jnp.ndarray]: +def _to_gate_func( + gate: Gate, +) -> GateFunction: """ Ensures a gate_seq element is a function that map (possibly empty) parameters to a unitary tensor. @@ -50,7 +55,7 @@ def _to_gate_func(gate: Gate) -> Callable[[jnp.ndarray], jnp.ndarray]: Gate parameter to unitary functions """ - def _array_to_callable(arr: jnp.ndarray) -> Callable[[], jnp.ndarray]: + def _array_to_callable(arr: jax.Array) -> Callable[[], jax.Array]: return lambda: arr if isinstance(gate, str): @@ -69,11 +74,11 @@ def _array_to_callable(arr: jnp.ndarray) -> Callable[[], jnp.ndarray]: def _gate_func_to_unitary( - gate_func: Callable[[jnp.ndarray], jnp.ndarray], + gate_func: GateFunction, qubit_inds: Sequence[int], - param_inds: jnp.ndarray, - params: jnp.ndarray, -) -> jnp.ndarray: + param_inds: jax.Array, + params: jax.Array, +) -> jax.Array: """ Extract gate unitary. @@ -114,9 +119,9 @@ def all_zeros_statetensor(n_qubits: int, dtype: DTypeLike = complex) -> jax.Arra def get_params_to_statetensor_func( gate_seq: Sequence[Gate], qubit_inds_seq: Sequence[Sequence[int]], - param_inds_seq: Sequence[Union[None, Sequence[int]]], - n_qubits: int = None, -) -> UnionCallableOptionalArray: + param_inds_seq: Sequence[GateParameterIndices], + n_qubits: Optional[int] = None, +) -> PureCircuitFunction: """ Creates a function that maps circuit parameters to a statetensor. @@ -151,8 +156,8 @@ def get_params_to_statetensor_func( param_inds_array_seq = _arrayify_inds(param_inds_seq) def params_to_statetensor_func( - params: jnp.ndarray, statetensor_in: jnp.ndarray = None - ) -> jnp.ndarray: + params: ArrayLike, statetensor_in: Optional[jax.Array] = None + ) -> jax.Array: """ Applies parameterised circuit (series of gates) to a statetensor_in (default is |0>^N). @@ -169,7 +174,12 @@ def params_to_statetensor_func( statetensor = all_zeros_statetensor(n_qubits) else: statetensor = statetensor_in + params = jnp.atleast_1d(params) + # Guarantee `params` has the right type for type-checking purposes + if not isinstance(params, jax.Array): + raise ValueError("This should not happen. Please open an issue on GitHub.") + for gate_func, qubit_inds, param_inds in zip( gate_seq_callable, qubit_inds_seq, param_inds_array_seq ): @@ -183,8 +193,8 @@ def params_to_statetensor_func( if non_parameterised: def no_params_to_statetensor_func( - statetensor_in: jnp.ndarray = None, - ) -> jnp.ndarray: + statetensor_in: Optional[jax.Array] = None, + ) -> jax.Array: """ Applies circuit (series of gates with no parameters) to a statetensor_in (default is |0>^N). @@ -208,9 +218,9 @@ def no_params_to_statetensor_func( def get_params_to_unitarytensor_func( gate_seq: Sequence[Gate], qubit_inds_seq: Sequence[Sequence[int]], - param_inds_seq: Sequence[Union[None, Sequence[int]]], - n_qubits: int = None, -) -> Union[Callable[[], jnp.ndarray], Callable[[jnp.ndarray], jnp.ndarray]]: + param_inds_seq: Sequence[GateParameterIndices], + n_qubits: Optional[int] = None, +) -> PureCircuitFunction: """ Creates a function that maps circuit parameters to a unitarytensor. The unitarytensor is an array with shape (2,) * 2 * n_qubits diff --git a/qujax/statetensor_observable.py b/qujax/statetensor_observable.py index 4bce068..3b7fd6a 100644 --- a/qujax/statetensor_observable.py +++ b/qujax/statetensor_observable.py @@ -2,6 +2,7 @@ from typing import Callable, Sequence, Union +import jax from jax import numpy as jnp from jax import random from jax.lax import fori_loop @@ -11,8 +12,8 @@ def statetensor_to_single_expectation( - statetensor: jnp.ndarray, hermitian: jnp.ndarray, qubit_inds: Sequence[int] -) -> float: + statetensor: jax.Array, hermitian: jax.Array, qubit_inds: Sequence[int] +) -> jax.Array: """ Evaluates expectation value of an observable represented by a Hermitian matrix (in tensor form). @@ -33,9 +34,7 @@ def statetensor_to_single_expectation( ).real -def get_hermitian_tensor( - hermitian_seq: Sequence[Union[str, jnp.ndarray]] -) -> jnp.ndarray: +def get_hermitian_tensor(hermitian_seq: Sequence[Union[str, jax.Array]]) -> jax.Array: """ Convert a sequence of observables represented by Pauli strings or Hermitian matrices in tensor form into single array (in tensor form). @@ -62,11 +61,11 @@ def get_hermitian_tensor( def _get_tensor_to_expectation_func( - hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + hermitian_seq_seq: Sequence[Sequence[Union[str, jax.Array]]], qubits_seq_seq: Sequence[Sequence[int]], - coefficients: Union[Sequence[float], jnp.ndarray], + coefficients: Union[Sequence[float], jax.Array], contraction_function: Callable, -) -> Callable[[jnp.ndarray], float]: +) -> Callable[[jax.Array], float]: """ Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and a list of coefficients and returns a function that converts a tensor into an expected value. @@ -75,7 +74,7 @@ def _get_tensor_to_expectation_func( Args: hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. - Each Hermitian matrix is either represented by a tensor (jnp.ndarray) or by a + Each Hermitian matrix is either represented by a tensor (jax.Array) or by a list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices. E.g. [['Z', 'Z'], ['X']] qubits_seq_seq: Sequence of sequences of integer qubit indices. @@ -89,7 +88,7 @@ def _get_tensor_to_expectation_func( hermitian_tensors = [get_hermitian_tensor(h_seq) for h_seq in hermitian_seq_seq] - def tensor_to_expectation_func(tensor: jnp.ndarray) -> float: + def tensor_to_expectation_func(tensor: jax.Array) -> float: """ Maps tensor to expected value. @@ -110,10 +109,10 @@ def tensor_to_expectation_func(tensor: jnp.ndarray) -> float: def get_statetensor_to_expectation_func( - hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + hermitian_seq_seq: Sequence[Sequence[Union[str, jax.Array]]], qubits_seq_seq: Sequence[Sequence[int]], - coefficients: Union[Sequence[float], jnp.ndarray], -) -> Callable[[jnp.ndarray], float]: + coefficients: Union[Sequence[float], jax.Array], +) -> Callable[[jax.Array], float]: """ Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and a list of coefficients and returns a function that converts a statetensor into an expected @@ -121,7 +120,7 @@ def get_statetensor_to_expectation_func( Args: hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. - Each Hermitian matrix is either represented by a tensor (jnp.ndarray) + Each Hermitian matrix is either represented by a tensor (jax.Array) or by a list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices. E.g. [['Z', 'Z'], ['X']] qubits_seq_seq: Sequence of sequences of integer qubit indices. @@ -141,10 +140,10 @@ def get_statetensor_to_expectation_func( def get_statetensor_to_sampled_expectation_func( - hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + hermitian_seq_seq: Sequence[Sequence[Union[str, jax.Array]]], qubits_seq_seq: Sequence[Sequence[int]], - coefficients: Union[Sequence[float], jnp.ndarray], -) -> Callable[[jnp.ndarray, random.PRNGKeyArray, int], float]: + coefficients: Union[Sequence[float], jax.Array], +) -> Callable[[jax.Array, random.PRNGKeyArray, int], float]: """ Converts strings (or arrays) representing Hermitian matrices, qubit indices and coefficients into a function that converts a statetensor into a sampled expected value. @@ -160,7 +159,7 @@ def get_statetensor_to_sampled_expectation_func( Args: hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. - Each Hermitian is either a tensor (jnp.ndarray) or a string in ('X', 'Y', 'Z'). + Each Hermitian is either a tensor (jax.Array) or a string in ('X', 'Y', 'Z'). E.g. [['Z', 'Z'], ['X']] qubits_seq_seq: Sequence of sequences of integer qubit indices. E.g. [[0,1], [2]] @@ -179,7 +178,7 @@ def get_statetensor_to_sampled_expectation_func( check_hermitian(h, check_z_commutes=True) def statetensor_to_sampled_expectation_func( - statetensor: jnp.ndarray, random_key: random.PRNGKeyArray, n_samps: int + statetensor: jax.Array, random_key: random.PRNGKeyArray, n_samps: int ) -> float: """ Maps statetensor to sampled expected value. @@ -201,7 +200,7 @@ def statetensor_to_sampled_expectation_func( def sample_probs( - measure_probs: jnp.ndarray, random_key: random.PRNGKeyArray, n_samps: int + measure_probs: jax.Array, random_key: random.PRNGKeyArray, n_samps: int ): """ Generate an empirical distribution from a probability distribution. @@ -212,7 +211,7 @@ def sample_probs( n_samps: Number of samples contributing to empirical distribution. Returns: - Empirical distribution (jnp.ndarray). + Empirical distribution (jax.Array). """ measure_probs_flat = measure_probs.flatten() sampled_integers = random.choice( diff --git a/qujax/typing.py b/qujax/typing.py new file mode 100644 index 0000000..3d78369 --- /dev/null +++ b/qujax/typing.py @@ -0,0 +1,45 @@ +from typing import Union, Optional, Protocol, Callable, Iterable, Sequence + +# Backwards compatibility with Python <3.10 +from typing_extensions import TypeVarTuple, Unpack + +import jax +from jax.typing import ArrayLike + + +class PureParameterizedCircuit(Protocol): + def __call__( + self, params: ArrayLike, statetensor_in: Optional[jax.Array] = None + ) -> jax.Array: + ... + + +class PureUnparameterizedCircuit(Protocol): + def __call__(self, statetensor_in: Optional[jax.Array] = None) -> jax.Array: + ... + + +class MixedParameterizedCircuit(Protocol): + def __call__( + self, params: ArrayLike, densitytensor_in: Optional[jax.Array] = None + ) -> jax.Array: + ... + + +class MixedUnparameterizedCircuit(Protocol): + def __call__(self, densitytensor_in: Optional[jax.Array] = None) -> jax.Array: + ... + + +GateArgs = TypeVarTuple("GateArgs") +# Function that takes arbitrary nr. of parameters and returns an array representing the gate +# Currently Python does not allow us to restrict the type of the arguments using a TypeVarTuple +GateFunction = Callable[[Unpack[GateArgs]], jax.Array] +GateParameterIndices = Optional[Sequence[int]] + +PureCircuitFunction = Union[PureUnparameterizedCircuit, PureParameterizedCircuit] +MixedCircuitFunction = Union[MixedUnparameterizedCircuit, MixedParameterizedCircuit] + +Gate = Union[str, jax.Array, GateFunction] + +KrausOp = Union[Gate, Iterable[Gate]] diff --git a/qujax/utils.py b/qujax/utils.py index 0078c47..e76f390 100644 --- a/qujax/utils.py +++ b/qujax/utils.py @@ -2,40 +2,31 @@ import collections.abc from inspect import signature -from typing import Callable, Iterable, List, Optional, Protocol, Sequence, Tuple, Union +from typing import Callable, List, Optional, Sequence, Tuple, Union from warnings import warn import jax +from jax.typing import ArrayLike from jax import numpy as jnp from jax import random from qujax import gates -paulis = {"X": gates.X, "Y": gates.Y, "Z": gates.Z} - - -class CallableArrayAndOptionalArray(Protocol): - def __call__( - self, params: jnp.ndarray, statetensor_in: jnp.ndarray = None - ) -> jnp.ndarray: - ... - +from qujax.typing import ( + Gate, + KrausOp, + GateParameterIndices, + PureParameterizedCircuit, + MixedParameterizedCircuit, +) -class CallableOptionalArray(Protocol): - def __call__(self, statetensor_in: jnp.ndarray = None) -> jnp.ndarray: - ... - - -UnionCallableOptionalArray = Union[CallableArrayAndOptionalArray, CallableOptionalArray] -Gate = Union[ - str, jnp.ndarray, Callable[[jnp.ndarray], jnp.ndarray], Callable[[], jnp.ndarray] -] -KrausOp = Union[Gate, Iterable[Gate]] +paulis = {"X": gates.X, "Y": gates.Y, "Z": gates.Z} -def check_unitary(gate: Gate): +def check_unitary(gate: Gate) -> None: """ - Checks whether a matrix or tensor is unitary. + Checks whether a qujax Gate is unitary. + Throws a TypeError if this is found not to be the case. Args: gate: array containing potentially unitary string, array @@ -54,7 +45,7 @@ def check_unitary(gate: Gate): if callable(gate): num_args = len(signature(gate).parameters) gate_arr = gate(*jnp.ones(num_args) * 0.1) - elif hasattr(gate, "__array__"): + elif isinstance(gate, jax.Array): gate_arr = gate else: raise TypeError( @@ -71,7 +62,7 @@ def check_unitary(gate: Gate): raise TypeError(f"Gate not unitary: {gate}") -def check_hermitian(hermitian: Union[str, jnp.ndarray], check_z_commutes: bool = False): +def check_hermitian(hermitian: Union[str, jax.Array], check_z_commutes: bool = False): """ Checks whether a matrix or tensor is Hermitian. @@ -109,8 +100,8 @@ def check_hermitian(hermitian: Union[str, jnp.ndarray], check_z_commutes: bool = def _arrayify_inds( - param_inds_seq: Sequence[Union[None, Sequence[int]]] -) -> Sequence[jnp.ndarray]: + param_inds_seq: Optional[Sequence[GateParameterIndices]], +) -> Sequence[jax.Array]: """ Ensure each element of param_inds_seq is an array (and therefore valid for jnp.take) @@ -125,19 +116,21 @@ def _arrayify_inds( """ if param_inds_seq is None: param_inds_seq = [None] - param_inds_seq = [jnp.array(p) for p in param_inds_seq] - param_inds_seq = [ + array_param_inds = [jnp.array(p) for p in param_inds_seq] + array_param_inds = [ jnp.array([]) if jnp.any(jnp.isnan(p)) else p.astype(int) - for p in param_inds_seq + for p in array_param_inds ] - return param_inds_seq + return array_param_inds def check_circuit( gate_seq: Sequence[KrausOp], qubit_inds_seq: Sequence[Sequence[int]], - param_inds_seq: Sequence[Sequence[int]], - n_qubits: int = None, + param_inds_seq: Sequence[ + Union[GateParameterIndices, Sequence[GateParameterIndices]] + ], + n_qubits: Optional[int] = None, check_unitaries: bool = True, ): """ @@ -208,7 +201,8 @@ def check_circuit( def _get_gate_str( - gate_obj: KrausOp, param_inds: Union[None, Sequence[int], Sequence[Sequence[int]]] + gate_obj: KrausOp, + param_inds: Union[GateParameterIndices, Sequence[GateParameterIndices]], ) -> str: """ Maps single gate object to a four character string representation @@ -303,13 +297,15 @@ def extend_row(row: str, qubit_row: bool) -> str: def print_circuit( gate_seq: Sequence[KrausOp], qubit_inds_seq: Sequence[Sequence[int]], - param_inds_seq: Sequence[Sequence[int]], + param_inds_seq: Sequence[ + Union[GateParameterIndices, Sequence[GateParameterIndices]] + ], n_qubits: Optional[int] = None, - qubit_min: Optional[int] = 0, - qubit_max: Optional[int] = jnp.inf, - gate_ind_min: Optional[int] = 0, - gate_ind_max: Optional[int] = jnp.inf, - sep_length: Optional[int] = 1, + qubit_min: int = 0, + qubit_max: Optional[int] = None, + gate_ind_min: int = 0, + gate_ind_max: Optional[int] = None, + sep_length: int = 1, ) -> List[str]: """ Returns and prints basic string representation of circuit. @@ -338,13 +334,21 @@ def print_circuit( """ check_circuit(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits, False) - gate_ind_max = min(len(gate_seq) - 1, gate_ind_max) + if gate_ind_max is None: + gate_ind_max = len(gate_seq) - 1 + else: + gate_ind_max = min(len(gate_seq) - 1, gate_ind_max) + if gate_ind_min > gate_ind_max: raise TypeError("gate_ind_max must be larger or equal to gate_ind_min") if n_qubits is None: n_qubits = max([max(qi) for qi in qubit_inds_seq]) + 1 - qubit_max = min(n_qubits - 1, qubit_max) + + if qubit_max is None: + qubit_max = n_qubits - 1 + else: + qubit_max = min(n_qubits - 1, qubit_max) if qubit_min > qubit_max: raise TypeError("qubit_max must be larger or equal to qubit_min") @@ -392,8 +396,8 @@ def print_circuit( def integers_to_bitstrings( - integers: Union[int, jnp.ndarray], nbits: int = None -) -> jnp.ndarray: + integers: Union[int, jax.Array], nbits: Optional[int] = None +) -> jax.Array: """ Convert integer or array of integers into their binary expansion(s). @@ -406,15 +410,19 @@ def integers_to_bitstrings( Array of binary expansion(s). """ integers = jnp.atleast_1d(integers) + # Guarantee `bitstrings` has the right type for type-checking purposes + if not isinstance(integers, jax.Array): + raise ValueError("This should not happen. Please open an issue on GitHub.") + if nbits is None: - nbits = (jnp.ceil(jnp.log2(jnp.maximum(integers.max(), 1)) + 1e-5)).astype(int) + nbits = int(jnp.ceil(jnp.log2(jnp.maximum(integers.max(), 1)) + 1e-5).item()) return jnp.squeeze( ((integers[:, None] & (1 << jnp.arange(nbits - 1, -1, -1))) > 0).astype(int) ) -def bitstrings_to_integers(bitstrings: jnp.ndarray) -> Union[int, jnp.ndarray]: +def bitstrings_to_integers(bitstrings: ArrayLike) -> jax.Array: """ Convert binary expansion(s) into integers. @@ -425,15 +433,20 @@ def bitstrings_to_integers(bitstrings: jnp.ndarray) -> Union[int, jnp.ndarray]: Array of integers. """ bitstrings = jnp.atleast_2d(bitstrings) + + # Guarantee `bitstrings` has the right type for type-checking purposes + if not isinstance(bitstrings, jax.Array): + raise ValueError("This should not happen. Please open an issue on GitHub.") + convarr = 2 ** jnp.arange(bitstrings.shape[-1] - 1, -1, -1) return jnp.squeeze(bitstrings.dot(convarr)).astype(int) def sample_integers( random_key: random.PRNGKeyArray, - statetensor: jnp.ndarray, - n_samps: Optional[int] = 1, -) -> jnp.ndarray: + statetensor: jax.Array, + n_samps: int = 1, +) -> jax.Array: """ Generate random integer samples according to statetensor. @@ -455,9 +468,9 @@ def sample_integers( def sample_bitstrings( random_key: random.PRNGKeyArray, - statetensor: jnp.ndarray, - n_samps: Optional[int] = 1, -) -> jnp.ndarray: + statetensor: jax.Array, + n_samps: int = 1, +) -> jax.Array: """ Generate random bitstring samples according to statetensor. @@ -474,7 +487,7 @@ def sample_bitstrings( ) -def statetensor_to_densitytensor(statetensor: jnp.ndarray) -> jnp.ndarray: +def statetensor_to_densitytensor(statetensor: jax.Array) -> jax.Array: """ Computes a densitytensor representation of a pure quantum state from its statetensor representaton @@ -494,7 +507,8 @@ def statetensor_to_densitytensor(statetensor: jnp.ndarray) -> jnp.ndarray: def repeat_circuit( - circuit: Callable[[jax.Array, jax.Array], jax.Array], nr_of_parameters: int + circuit: Union[PureParameterizedCircuit, MixedParameterizedCircuit], + nr_of_parameters: int, ) -> Callable[[jax.Array, jax.Array], jax.Array]: """ Repeats circuit encoded by `circuit` an arbitrary number of times. diff --git a/setup.py b/setup.py index 851571d..37ecb4f 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ license="Apache 2", packages=find_packages(), python_requires=">=3.8", - install_requires=["jax", "jaxlib"], + install_requires=["jax>=0.4.1", "jaxlib", "typing_extensions"], classifiers=[ "Programming Language :: Python", "Intended Audience :: Developers", diff --git a/tests/test_densitytensor.py b/tests/test_densitytensor.py index 46ead45..7107629 100644 --- a/tests/test_densitytensor.py +++ b/tests/test_densitytensor.py @@ -289,7 +289,7 @@ def test_measure(): measured_dt = qujax.densitytensor_to_measured_densitytensor(dt, qubit_inds, 0) measured_dt_bits = qujax.densitytensor_to_measured_densitytensor( - dt, qubit_inds, (0,) * n_qubits + dt, qubit_inds, jnp.zeros(n_qubits) ) assert jnp.allclose(measured_dt_true, measured_dt) assert jnp.allclose(measured_dt_true, measured_dt_bits) diff --git a/tests/test_gates.py b/tests/test_gates.py index 7e65d00..af044c9 100644 --- a/tests/test_gates.py +++ b/tests/test_gates.py @@ -3,6 +3,7 @@ def test_gates(): for g_str, g in gates.__dict__.items(): - if g_str[0] != "_" and g_str != "jnp": + # Exclude elements in jax.gates namespace which are not gates + if g_str[0] != "_" and g_str not in ("jax", "jnp"): check_unitary(g_str) check_unitary(g)