Skip to content

Commit

Permalink
Improvements and fixes to type annotations (#114)
Browse files Browse the repository at this point in the history
* Update jnp.ndarray to jax.Array types
* Specify jax>=0.4.1 as package requirement
* Add missing `Optional` qualifier on some type annotations
* Update `test_gates.py` to reflect change in array type
* Refactor types into `typing.py`
* Introduce several new type aliases
* Fix several type warnings raised by `mypy` and `pylance`
* Add `typing_extensions` to package requirements
  • Loading branch information
gamatos authored Nov 16, 2023
1 parent 3832018 commit f0adb76
Show file tree
Hide file tree
Showing 12 changed files with 240 additions and 155 deletions.
1 change: 0 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
autodoc_typehints = "description"

autodoc_type_aliases = {
"jnp.ndarray": "ndarray",
"random.PRNGKeyArray": "jax.random.PRNGKeyArray",
"UnionCallableOptionalArray": "Union[Callable[[ndarray, Optional[ndarray]], ndarray], "
"Callable[[Optional[ndarray]], ndarray]]",
Expand Down
3 changes: 2 additions & 1 deletion qujax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from qujax.densitytensor_observable import densitytensor_to_measurement_probabilities
from qujax.densitytensor_observable import densitytensor_to_measured_densitytensor

from qujax.utils import UnionCallableOptionalArray
from qujax.utils import check_unitary
from qujax.utils import check_hermitian
from qujax.utils import check_circuit
Expand All @@ -40,6 +39,8 @@
from qujax.utils import sample_bitstrings
from qujax.utils import statetensor_to_densitytensor

import qujax.typing

# pylint: disable=undefined-variable
del version
del statetensor
Expand Down
57 changes: 35 additions & 22 deletions qujax/densitytensor.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,31 @@
from __future__ import annotations

from typing import Callable, Iterable, Sequence, Tuple, Union
from typing import Iterable, Sequence, Tuple, Union, Optional

import jax
from jax import numpy as jnp
from jax.typing import ArrayLike
from jax.lax import scan
from jax._src.dtypes import canonicalize_dtype
from jax._src.typing import DTypeLike
from qujax.statetensor import (
UnionCallableOptionalArray,
_arrayify_inds,
_gate_func_to_unitary,
_to_gate_func,
apply_gate,
)
from qujax.utils import KrausOp, check_circuit
from qujax.utils import check_circuit
from qujax.typing import (
MixedCircuitFunction,
KrausOp,
GateFunction,
GateParameterIndices,
)


def _kraus_single(
densitytensor: jnp.ndarray, array: jnp.ndarray, qubit_inds: Sequence[int]
) -> jnp.ndarray:
densitytensor: jax.Array, array: jax.Array, qubit_inds: Sequence[int]
) -> jax.Array:
r"""
Performs single Kraus operation
Expand All @@ -44,8 +50,8 @@ def _kraus_single(


def kraus(
densitytensor: jnp.ndarray, arrays: Iterable[jnp.ndarray], qubit_inds: Sequence[int]
) -> jnp.ndarray:
densitytensor: jax.Array, arrays: Iterable[jax.Array], qubit_inds: Sequence[int]
) -> jax.Array:
r"""
Performs Kraus operation.
Expand Down Expand Up @@ -78,8 +84,9 @@ def kraus(


def _to_kraus_operator_seq_funcs(
kraus_op: KrausOp, param_inds: Union[None, Sequence[int], Sequence[Sequence[int]]]
) -> Tuple[Sequence[Callable[[jnp.ndarray], jnp.ndarray]], Sequence[jnp.ndarray]]:
kraus_op: KrausOp,
param_inds: Optional[Union[GateParameterIndices, Sequence[GateParameterIndices]]],
) -> Tuple[Sequence[GateFunction], Sequence[jax.Array]]:
"""
Ensures Kraus operators are a sequence of functions that map (possibly empty) parameters to
tensors and that each element of param_inds_seq is a sequence of arrays that correspond to the
Expand All @@ -96,20 +103,21 @@ def _to_kraus_operator_seq_funcs(
and sequence of arrays with parameter indices
"""
if param_inds is None:
param_inds = [None for _ in kraus_op]

if isinstance(kraus_op, (list, tuple)):
kraus_op_funcs = [_to_gate_func(ko) for ko in kraus_op]
else:
if param_inds is None:
param_inds = [None for _ in kraus_op]
elif isinstance(kraus_op, (str, jax.Array)) or callable(kraus_op):
kraus_op_funcs = [_to_gate_func(kraus_op)]
param_inds = [param_inds]
else:
raise ValueError(f"Invalid Kraus operator specification: {kraus_op}")
return kraus_op_funcs, _arrayify_inds(param_inds)


def partial_trace(
densitytensor: jnp.ndarray, indices_to_trace: Sequence[int]
) -> jnp.ndarray:
densitytensor: jax.Array, indices_to_trace: Sequence[int]
) -> jax.Array:
"""
Traces out (discards) specified qubits, resulting in a densitytensor
representing the mixed quantum state on the remaining qubits.
Expand Down Expand Up @@ -149,9 +157,11 @@ def all_zeros_densitytensor(n_qubits: int, dtype: DTypeLike = complex) -> jax.Ar
def get_params_to_densitytensor_func(
kraus_ops_seq: Sequence[KrausOp],
qubit_inds_seq: Sequence[Sequence[int]],
param_inds_seq: Sequence[Union[None, Sequence[int], Sequence[Sequence[int]]]],
n_qubits: int = None,
) -> UnionCallableOptionalArray:
param_inds_seq: Sequence[
Union[GateParameterIndices, Sequence[GateParameterIndices]]
],
n_qubits: Optional[int] = None,
) -> MixedCircuitFunction:
"""
Creates a function that maps circuit parameters to a density tensor (a density matrix in
tensor form).
Expand Down Expand Up @@ -195,8 +205,8 @@ def get_params_to_densitytensor_func(
param_inds_array_seq = [ko_pi[1] for ko_pi in kraus_ops_seq_callable_and_param_inds]

def params_to_densitytensor_func(
params: jnp.ndarray, densitytensor_in: jnp.ndarray = None
) -> jnp.ndarray:
params: ArrayLike, densitytensor_in: Optional[jax.Array] = None
) -> jax.Array:
"""
Applies parameterised circuit (series of gates) to a densitytensor_in
(default is |0>^N <0|^N).
Expand All @@ -216,6 +226,9 @@ def params_to_densitytensor_func(
else:
densitytensor = densitytensor_in
params = jnp.atleast_1d(params)
# Guarantee `params` has the right type for type-checking purposes
if not isinstance(params, jax.Array):
raise ValueError("This should not happen. Please open an issue on GitHub.")
for gate_func_single_seq, qubit_inds, param_inds_single_seq in zip(
kraus_ops_seq_callable, qubit_inds_seq, param_inds_array_seq
):
Expand All @@ -232,8 +245,8 @@ def params_to_densitytensor_func(
if non_parameterised:

def no_params_to_densitytensor_func(
densitytensor_in: jnp.ndarray = None,
) -> jnp.ndarray:
densitytensor_in: Optional[jax.Array] = None,
) -> jax.Array:
"""
Applies circuit (series of gates with no parameters) to a densitytensor_in
(default is |0>^N <0|^N).
Expand Down
34 changes: 18 additions & 16 deletions qujax/densitytensor_observable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from typing import Callable, Sequence, Union

import jax
from jax.typing import ArrayLike
from jax import numpy as jnp
from jax import random

Expand All @@ -11,8 +13,8 @@


def densitytensor_to_single_expectation(
densitytensor: jnp.ndarray, hermitian: jnp.ndarray, qubit_inds: Sequence[int]
) -> float:
densitytensor: jax.Array, hermitian: jax.Array, qubit_inds: Sequence[int]
) -> jax.Array:
"""
Evaluates expectation value of an observable represented by a Hermitian matrix (in tensor form).
Expand All @@ -35,18 +37,18 @@ def densitytensor_to_single_expectation(


def get_densitytensor_to_expectation_func(
hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]],
hermitian_seq_seq: Sequence[Sequence[Union[str, jax.Array]]],
qubits_seq_seq: Sequence[Sequence[int]],
coefficients: Union[Sequence[float], jnp.ndarray],
) -> Callable[[jnp.ndarray], float]:
coefficients: Union[Sequence[float], jax.Array],
) -> Callable[[jax.Array], float]:
"""
Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and
a list of coefficients and returns a function that converts a densitytensor into an
expected value.
Args:
hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors.
Each Hermitian matrix is either represented by a tensor (jnp.ndarray)
Each Hermitian matrix is either represented by a tensor (jax.Array)
or by a list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices.
E.g. [['Z', 'Z'], ['X']]
qubits_seq_seq: Sequence of sequences of integer qubit indices.
Expand All @@ -66,10 +68,10 @@ def get_densitytensor_to_expectation_func(


def get_densitytensor_to_sampled_expectation_func(
hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]],
hermitian_seq_seq: Sequence[Sequence[Union[str, jax.Array]]],
qubits_seq_seq: Sequence[Sequence[int]],
coefficients: Union[Sequence[float], jnp.ndarray],
) -> Callable[[jnp.ndarray, random.PRNGKeyArray, int], float]:
coefficients: Union[Sequence[float], jax.Array],
) -> Callable[[jax.Array, random.PRNGKeyArray, int], float]:
"""
Converts strings (or arrays) representing Hermitian matrices, qubit indices and
coefficients into a function that converts a densitytensor into a sampled expected value.
Expand All @@ -85,7 +87,7 @@ def get_densitytensor_to_sampled_expectation_func(
Args:
hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors.
Each Hermitian is either a tensor (jnp.ndarray) or a string in ('X', 'Y', 'Z').
Each Hermitian is either a tensor (jax.Array) or a string in ('X', 'Y', 'Z').
E.g. [['Z', 'Z'], ['X']]
qubits_seq_seq: Sequence of sequences of integer qubit indices.
E.g. [[0,1], [2]]
Expand All @@ -104,7 +106,7 @@ def get_densitytensor_to_sampled_expectation_func(
check_hermitian(h, check_z_commutes=True)

def densitytensor_to_sampled_expectation_func(
densitytensor: jnp.ndarray, random_key: random.PRNGKeyArray, n_samps: int
densitytensor: jax.Array, random_key: random.PRNGKeyArray, n_samps: int
) -> float:
"""
Maps densitytensor to sampled expected value.
Expand All @@ -131,8 +133,8 @@ def densitytensor_to_sampled_expectation_func(


def densitytensor_to_measurement_probabilities(
densitytensor: jnp.ndarray, qubit_inds: Sequence[int]
) -> jnp.ndarray:
densitytensor: jax.Array, qubit_inds: Sequence[int]
) -> jax.Array:
"""
Extract array of measurement probabilities given a densitytensor and some qubit indices to
measure (in the computational basis).
Expand All @@ -157,10 +159,10 @@ def densitytensor_to_measurement_probabilities(


def densitytensor_to_measured_densitytensor(
densitytensor: jnp.ndarray,
densitytensor: jax.Array,
qubit_inds: Sequence[int],
measurement: Union[int, jnp.ndarray],
) -> jnp.ndarray:
measurement: ArrayLike,
) -> jax.Array:
"""
Returns the post-measurement densitytensor assuming that qubit_inds are measured
(in the computational basis) and the given measurement (integer or bitstring) is observed.
Expand Down
37 changes: 19 additions & 18 deletions qujax/gates.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax
from jax import numpy as jnp

I = jnp.eye(2)
Expand Down Expand Up @@ -64,42 +65,42 @@
)


def Rx(param: float) -> jnp.ndarray:
def Rx(param: float) -> jax.Array:
param_pi_2 = param * jnp.pi / 2
return jnp.cos(param_pi_2) * I - jnp.sin(param_pi_2) * X * 1.0j


def Ry(param: float) -> jnp.ndarray:
def Ry(param: float) -> jax.Array:
param_pi_2 = param * jnp.pi / 2
return jnp.cos(param_pi_2) * I - jnp.sin(param_pi_2) * Y * 1.0j


def Rz(param: float) -> jnp.ndarray:
def Rz(param: float) -> jax.Array:
param_pi_2 = param * jnp.pi / 2
return jnp.cos(param_pi_2) * I - jnp.sin(param_pi_2) * Z * 1.0j


def CRx(param: float) -> jnp.ndarray:
def CRx(param: float) -> jax.Array:
return jnp.block([[I, _0], [_0, Rx(param)]]).reshape((2,) * 4)


def CRy(param: float) -> jnp.ndarray:
def CRy(param: float) -> jax.Array:
return jnp.block([[I, _0], [_0, Ry(param)]]).reshape((2,) * 4)


def CRz(param: float) -> jnp.ndarray:
def CRz(param: float) -> jax.Array:
return jnp.block([[I, _0], [_0, Rz(param)]]).reshape((2,) * 4)


def U1(param: float) -> jnp.ndarray:
def U1(param: float) -> jax.Array:
return U3(0, 0, param)


def U2(param0: float, param1: float) -> jnp.ndarray:
def U2(param0: float, param1: float) -> jax.Array:
return U3(0.5, param0, param1)


def U3(param0: float, param1: float, param2: float) -> jnp.ndarray:
def U3(param0: float, param1: float, param2: float) -> jax.Array:
return (
jnp.exp((param1 + param2) * jnp.pi * 1.0j / 2)
* Rz(param1)
Expand All @@ -108,19 +109,19 @@ def U3(param0: float, param1: float, param2: float) -> jnp.ndarray:
)


def CU1(param: float) -> jnp.ndarray:
def CU1(param: float) -> jax.Array:
return jnp.block([[I, _0], [_0, U1(param)]]).reshape((2,) * 4)


def CU2(param0: float, param1: float) -> jnp.ndarray:
def CU2(param0: float, param1: float) -> jax.Array:
return jnp.block([[I, _0], [_0, U2(param0, param1)]]).reshape((2,) * 4)


def CU3(param0: float, param1: float, param2: float) -> jnp.ndarray:
def CU3(param0: float, param1: float, param2: float) -> jax.Array:
return jnp.block([[I, _0], [_0, U3(param0, param1, param2)]]).reshape((2,) * 4)


def ISWAP(param: float) -> jnp.ndarray:
def ISWAP(param: float) -> jax.Array:
param_pi_2 = param * jnp.pi / 2
c = jnp.cos(param_pi_2)
i_s = 1.0j * jnp.sin(param_pi_2)
Expand All @@ -134,7 +135,7 @@ def ISWAP(param: float) -> jnp.ndarray:
).reshape((2,) * 4)


def PhasedISWAP(param0: float, param1: float) -> jnp.ndarray:
def PhasedISWAP(param0: float, param1: float) -> jax.Array:
param1_pi_2 = param1 * jnp.pi / 2
c = jnp.cos(param1_pi_2)
i_s = 1.0j * jnp.sin(param1_pi_2)
Expand All @@ -148,7 +149,7 @@ def PhasedISWAP(param0: float, param1: float) -> jnp.ndarray:
).reshape((2,) * 4)


def XXPhase(param: float) -> jnp.ndarray:
def XXPhase(param: float) -> jax.Array:
param_pi_2 = param * jnp.pi / 2
c = jnp.cos(param_pi_2)
i_s = 1.0j * jnp.sin(param_pi_2)
Expand All @@ -162,7 +163,7 @@ def XXPhase(param: float) -> jnp.ndarray:
).reshape((2,) * 4)


def YYPhase(param: float) -> jnp.ndarray:
def YYPhase(param: float) -> jax.Array:
param_pi_2 = param * jnp.pi / 2
c = jnp.cos(param_pi_2)
i_s = 1.0j * jnp.sin(param_pi_2)
Expand All @@ -176,7 +177,7 @@ def YYPhase(param: float) -> jnp.ndarray:
).reshape((2,) * 4)


def ZZPhase(param: float) -> jnp.ndarray:
def ZZPhase(param: float) -> jax.Array:
param_pi_2 = param * jnp.pi / 2
e_m = jnp.exp(-1.0j * param_pi_2)
e_p = jnp.exp(1.0j * param_pi_2)
Expand All @@ -186,5 +187,5 @@ def ZZPhase(param: float) -> jnp.ndarray:
ZZMax = ZZPhase(0.5)


def PhasedX(param0: float, param1: float) -> jnp.ndarray:
def PhasedX(param0: float, param1: float) -> jax.Array:
return Rz(param1) @ Rx(param0) @ Rz(-param1)
Loading

0 comments on commit f0adb76

Please sign in to comment.