Skip to content

Commit

Permalink
Merge pull request #61 from CQCL/fix-print
Browse files Browse the repository at this point in the history
Bug fix in print_circuit, rename gate_type to Gate
  • Loading branch information
SamDuffield authored Jan 24, 2023
2 parents 1a52f20 + a3db326 commit cea2124
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 78 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@ param_to_st = qujax.get_params_to_statetensor_func(circuit_gates,
We now have a pure JAX function that generates the statetensor for given parameters
```python
param_to_st(jnp.array([0.1]))
# DeviceArray([[0.58778524+0.j, 0. +0.j],
# [0.80901706+0.j, 0. +0.j]], dtype=complex64)
# Array([[0.58778524+0.j, 0. +0.j],
# [0.80901706+0.j, 0. +0.j]], dtype=complex64)
```

The statevector can be obtained from the statetensor via ```.flatten()```.
```python
param_to_st(jnp.array([0.1])).flatten()
# DeviceArray([0.58778524+0.j, 0.+0.j, 0.80901706+0.j, 0.+0.j], dtype=complex64)
# Array([0.58778524+0.j, 0.+0.j, 0.80901706+0.j, 0.+0.j], dtype=complex64)
```

We can also use qujax to map the statetensor to an expected value
Expand All @@ -70,8 +70,8 @@ from jax import value_and_grad
param_to_expectation = lambda param: st_to_expectation(param_to_st(param))
expectation_and_grad = value_and_grad(param_to_expectation)
expectation_and_grad(jnp.array([0.1]))
# (DeviceArray(-0.3090171, dtype=float32),
# DeviceArray([-2.987832], dtype=float32))
# (Array(-0.3090171, dtype=float32),
# Array([-2.987832], dtype=float32))
```

## Densitytensor simulations with qujax
Expand All @@ -91,7 +91,7 @@ Expectations can also be evaluated through the densitytensor
```python
dt_to_expectation = qujax.get_densitytensor_to_expectation_func([['Z']], [[0]], [1.])
dt_to_expectation(dt)
# DeviceArray(-0.3090171, dtype=float32)
# Array(-0.3090171, dtype=float32)
```
Again everything is differentiable, jit-able and can be composed with other JAX code.

Expand Down
26 changes: 11 additions & 15 deletions examples/classification.ipynb

Large diffs are not rendered by default.

60 changes: 26 additions & 34 deletions examples/heisenberg_vqe.ipynb

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions qujax/densitytensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from qujax.statetensor import apply_gate, UnionCallableOptionalArray
from qujax.statetensor import _to_gate_func, _arrayify_inds, _gate_func_to_unitary
from qujax.utils import check_circuit, kraus_op_type
from qujax.utils import check_circuit, KrausOp


def _kraus_single(densitytensor: jnp.ndarray,
Expand Down Expand Up @@ -61,7 +61,7 @@ def kraus(densitytensor: jnp.ndarray,
return new_densitytensor


def _to_kraus_operator_seq_funcs(kraus_op: kraus_op_type,
def _to_kraus_operator_seq_funcs(kraus_op: KrausOp,
param_inds: Union[None, Sequence[int], Sequence[Sequence[int]]]) \
-> Tuple[Sequence[Callable[[jnp.ndarray], jnp.ndarray]],
Sequence[jnp.ndarray]]:
Expand All @@ -71,8 +71,8 @@ def _to_kraus_operator_seq_funcs(kraus_op: kraus_op_type,
of each Kraus operator.
Args:
kraus_op: Either a normal gate_type or a sequence of gate_types representing Kraus operators.
param_inds: If kraus_op is a normal gate_type then a sequence of parameter indices,
kraus_op: Either a normal Gate or a sequence of Gates representing Kraus operators.
param_inds: If kraus_op is a normal Gate then a sequence of parameter indices,
if kraus_op is a sequence of Kraus operators then a sequence of sequences of parameter indices
Returns:
Expand Down Expand Up @@ -113,7 +113,7 @@ def partial_trace(densitytensor: jnp.ndarray,
return densitytensor


def get_params_to_densitytensor_func(kraus_ops_seq: Sequence[kraus_op_type],
def get_params_to_densitytensor_func(kraus_ops_seq: Sequence[KrausOp],
qubit_inds_seq: Sequence[Sequence[int]],
param_inds_seq: Sequence[Union[None, Sequence[int], Sequence[Sequence[int]]]],
n_qubits: int = None) -> UnionCallableOptionalArray:
Expand Down
8 changes: 4 additions & 4 deletions qujax/statetensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jax import numpy as jnp

from qujax import gates
from qujax.utils import check_circuit, _arrayify_inds, UnionCallableOptionalArray, gate_type
from qujax.utils import check_circuit, _arrayify_inds, UnionCallableOptionalArray, Gate


def apply_gate(statetensor: jnp.ndarray, gate_unitary: jnp.ndarray, qubit_inds: Sequence[int]) -> jnp.ndarray:
Expand All @@ -28,7 +28,7 @@ def apply_gate(statetensor: jnp.ndarray, gate_unitary: jnp.ndarray, qubit_inds:
return statetensor


def _to_gate_func(gate: gate_type) -> Callable[[jnp.ndarray], jnp.ndarray]:
def _to_gate_func(gate: Gate) -> Callable[[jnp.ndarray], jnp.ndarray]:
"""
Ensures a gate_seq element is a function that map (possibly empty) parameters
to a unitary tensor.
Expand Down Expand Up @@ -80,7 +80,7 @@ def _gate_func_to_unitary(gate_func: Callable[[jnp.ndarray], jnp.ndarray],
return gate_unitary


def get_params_to_statetensor_func(gate_seq: Sequence[gate_type],
def get_params_to_statetensor_func(gate_seq: Sequence[Gate],
qubit_inds_seq: Sequence[Sequence[int]],
param_inds_seq: Sequence[Union[None, Sequence[int]]],
n_qubits: int = None) -> UnionCallableOptionalArray:
Expand Down Expand Up @@ -161,7 +161,7 @@ def no_params_to_statetensor_func(statetensor_in: jnp.ndarray = None) -> jnp.nda
return params_to_statetensor_func


def get_params_to_unitarytensor_func(gate_seq: Sequence[gate_type],
def get_params_to_unitarytensor_func(gate_seq: Sequence[Gate],
qubit_inds_seq: Sequence[Sequence[int]],
param_inds_seq: Sequence[Union[None, Sequence[int]]],
n_qubits: int = None)\
Expand Down
25 changes: 12 additions & 13 deletions qujax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,13 @@ def __call__(self, statetensor_in: jnp.ndarray = None) -> jnp.ndarray:


UnionCallableOptionalArray = Union[CallableArrayAndOptionalArray, CallableOptionalArray]
gate_type = Union[str,
jnp.ndarray,
Callable[[jnp.ndarray], jnp.ndarray],
Callable[[], jnp.ndarray]]
kraus_op_type = Union[gate_type, Iterable[gate_type]]
Gate = Union[
str, jnp.ndarray, Callable[[jnp.ndarray], jnp.ndarray], Callable[[], jnp.ndarray]
]
KrausOp = Union[Gate, Iterable[Gate]]


def check_unitary(gate: gate_type):
def check_unitary(gate: Gate):
"""
Checks whether a matrix or tensor is unitary.
Expand Down Expand Up @@ -97,7 +96,7 @@ def _arrayify_inds(param_inds_seq: Sequence[Union[None, Sequence[int]]]) -> Sequ
return param_inds_seq


def check_circuit(gate_seq: Sequence[kraus_op_type],
def check_circuit(gate_seq: Sequence[KrausOp],
qubit_inds_seq: Sequence[Sequence[int]],
param_inds_seq: Sequence[Sequence[int]],
n_qubits: int = None,
Expand Down Expand Up @@ -143,7 +142,7 @@ def check_circuit(gate_seq: Sequence[kraus_op_type],
check_unitary(g)


def _get_gate_str(gate_obj: kraus_op_type,
def _get_gate_str(gate_obj: KrausOp,
param_inds: Union[None, Sequence[int], Sequence[Sequence[int]]]) -> str:
"""
Maps single gate object to a four character string representation
Expand Down Expand Up @@ -230,7 +229,7 @@ def extend_row(row: str, qubit_row: bool) -> str:
return out_rows, [True] * len(rows)


def print_circuit(gate_seq: Sequence[kraus_op_type],
def print_circuit(gate_seq: Sequence[KrausOp],
qubit_inds_seq: Sequence[Sequence[int]],
param_inds_seq: Sequence[Sequence[int]],
n_qubits: Optional[int] = None,
Expand Down Expand Up @@ -266,7 +265,7 @@ def print_circuit(gate_seq: Sequence[kraus_op_type],
check_circuit(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits, False)

gate_ind_max = min(len(gate_seq) - 1, gate_ind_max)
if gate_ind_max < gate_ind_min:
if gate_ind_min > gate_ind_max:
raise TypeError('gate_ind_max must be larger or equal to gate_ind_min')

if n_qubits is None:
Expand All @@ -292,13 +291,13 @@ def print_circuit(gate_seq: Sequence[kraus_op_type],

qi_min = min(qi)
qi_max = max(qi)
ri_min = 2 * qi_min
ri_max = 2 * qi_max + 1
ri_min = 2 * qi_min # index of top row used by gate
ri_max = 2 * qi_max # index of bottom row used by gate

if not all([rows_free[i] for i in range(ri_min, ri_max)]):
rows, rows_free = _pad_rows(rows)

for row_ind in range(ri_min, ri_max):
for row_ind in range(ri_min, ri_max + 1):
if row_ind == 2 * qi[-1]:
rows[row_ind] += '-' * sep_length + g
elif row_ind % 2 == 1:
Expand Down
2 changes: 1 addition & 1 deletion qujax/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.3.1'
__version__ = '0.3.2'

0 comments on commit cea2124

Please sign in to comment.