Skip to content

Commit

Permalink
Merge pull request #34 from CQCL/develop
Browse files Browse the repository at this point in the history
Add more gates and check_unitary
  • Loading branch information
SamDuffield authored Sep 14, 2022
2 parents 871d340 + aedfc73 commit 3ed090f
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 24 deletions.
1 change: 1 addition & 0 deletions qujax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from qujax.observable import sample_integers
from qujax.observable import sample_bitstrings

from qujax.circuit_tools import check_unitary
from qujax.circuit_tools import check_circuit
from qujax.circuit_tools import print_circuit

Expand Down
10 changes: 3 additions & 7 deletions qujax/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,7 @@ def _array_to_callable(arr: jnp.ndarray) -> Callable[[], jnp.ndarray]:
gate_seq_callable = []
for gate in gate_seq:
if isinstance(gate, str):
if gate in gates.__dict__:
gate = gates.__dict__[gate]
else:
raise KeyError(f'Gate string \'{gate}\' not found in qujax.gates '
f'- consider changing input to an array or callable')
gate = gates.__dict__[gate]

if callable(gate):
gate_func = gate
Expand All @@ -109,8 +105,8 @@ def _array_to_callable(arr: jnp.ndarray) -> Callable[[], jnp.ndarray]:
gate = gate_arr.reshape((2,) * int(jnp.log2(gate_size)))
gate_func = _array_to_callable(gate)
else:
raise TypeError('Unsupported gate type'
'- gate must be either a string in qujax.gates, an array or callable')
raise TypeError(f'Unsupported gate type - gate must be either a string in qujax.gates, an array or '
f'callable: {gate}')
gate_seq_callable.append(gate_func)

apply_gate_seq = [_get_apply_gate(g, q) for g, q in zip(gate_seq_callable, qubit_inds_seq)]
Expand Down
33 changes: 33 additions & 0 deletions qujax/circuit_tools.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,39 @@
from __future__ import annotations
from typing import Sequence, Union, Callable, List, Tuple, Optional
import collections.abc
from inspect import signature

from jax import numpy as jnp

from qujax import gates


def check_unitary(gate: Union[str,
jnp.ndarray,
Callable[[jnp.ndarray], jnp.ndarray],
Callable[[], jnp.ndarray]]):
if isinstance(gate, str):
if gate in gates.__dict__:
gate = gates.__dict__[gate]
else:
raise KeyError(f'Gate string \'{gate}\' not found in qujax.gates '
f'- consider changing input to an array or callable')

if callable(gate):
num_args = len(signature(gate).parameters)
gate_arr = gate(*jnp.ones(num_args) * 0.1)
elif hasattr(gate, '__array__'):
gate_arr = gate
else:
raise TypeError(f'Unsupported gate type - gate must be either a string in qujax.gates, an array or '
f'callable: {gate}')

gate_square_dim = int(jnp.sqrt(gate_arr.size))
gate_arr = gate_arr.reshape(gate_square_dim, gate_square_dim)

if jnp.any(jnp.abs(gate_arr @ jnp.conjugate(gate_arr).T - jnp.eye(gate_square_dim)) > 1e-3):
raise TypeError(f'Gate not unitary: {gate}')


def check_circuit(gate_seq: Sequence[Union[str,
jnp.ndarray,
Expand Down Expand Up @@ -45,6 +75,9 @@ def check_circuit(gate_seq: Sequence[Union[str,
if n_qubits is not None and n_qubits < max([max(qi) for qi in qubit_inds_seq]) + 1:
raise TypeError('n_qubits must be larger than largest qubit index in qubit_inds_seq')

for g in gate_seq:
check_unitary(g)


def _get_gate_str(gate_obj: Union[str,
jnp.ndarray,
Expand Down
135 changes: 119 additions & 16 deletions qujax/gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

I = jnp.eye(2)

_0 = jnp.zeros((2, 2))

X = jnp.array([[0., 1.],
[1., 0.]])

Expand Down Expand Up @@ -38,20 +40,45 @@
SXdg = jnp.array([[1. - 1.j, 1. + 1.j],
[1. + 1.j, 1. - 1.j]]) / 2

CX = jnp.array([[1., 0., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1.],
[0., 0., 1., 0.]]).reshape((2,) * 4)
CX = jnp.block([[I, _0],
[_0, X]]).reshape((2,) * 4)

CY = jnp.block([[I, _0],
[_0, Y]]).reshape((2,) * 4)

CZ = jnp.block([[I, _0],
[_0, Z]]).reshape((2,) * 4)

CH = jnp.block([[I, _0],
[_0, H]]).reshape((2,) * 4)

CV = jnp.block([[I, _0],
[_0, V]]).reshape((2,) * 4)

CVdg = jnp.block([[I, _0],
[_0, Vdg]]).reshape((2,) * 4)

CSX = jnp.block([[I, _0],
[_0, SX]]).reshape((2,) * 4)

CSXdg = jnp.block([[I, _0],
[_0, SXdg]]).reshape((2,) * 4)

CCX = jnp.block([[I, _0, _0, _0], # Toffoli gate
[_0, I, _0, _0],
[_0, _0, I, _0],
[_0, _0, _0, X]]).reshape((2,) * 6)

CY = jnp.array([[1., 0., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., -1.j],
[0., 0., 1.j, 0.]]).reshape((2,) * 4)
ECR = jnp.block([[_0, Vdg],
[V, _0]]).reshape((2,) * 4)

CZ = jnp.array([[1., 0., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 1., 0.],
[0., 0., 0., -1.]]).reshape((2,) * 4)
SWAP = jnp.array([[1., 0., 0., 0.],
[0., 0., 1., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1]])

CSWAP = jnp.block([[jnp.eye(4), jnp.zeros((4, 4))],
[jnp.zeros((4, 4)), SWAP]]).reshape((2,) * 6)


def Rx(param: float) -> jnp.ndarray:
Expand All @@ -69,14 +96,90 @@ def Rz(param: float) -> jnp.ndarray:
return jnp.cos(param_pi_2) * I - jnp.sin(param_pi_2) * Z * 1.j


def CRx(param: float) -> jnp.ndarray:
return jnp.block([[I, _0],
[_0, Rx(param)]]).reshape((2,) * 4)


def CRy(param: float) -> jnp.ndarray:
return jnp.block([[I, _0],
[_0, Ry(param)]]).reshape((2,) * 4)


def CRz(param: float) -> jnp.ndarray:
return jnp.block([[I, _0],
[_0, Rz(param)]]).reshape((2,) * 4)


def U1(param: float) -> jnp.ndarray:
return U3(0, 0, param)


def U2(param1: float, param2: float) -> jnp.ndarray:
return U3(0.5, param1, param2)
def U2(param0: float, param1: float) -> jnp.ndarray:
return U3(0.5, param0, param1)


def U3(param0: float, param1: float, param2: float) -> jnp.ndarray:
return jnp.exp((param1 + param2) * jnp.pi * 1.j / 2) * Rz(param1) @ Ry(param0) @ Rz(param2)


def U3(param1: float, param2: float, param3: float) -> jnp.ndarray:
return jnp.exp((param2 + param3) * jnp.pi * 1.j / 2) * Rz(param2) @ Ry(param1) @ Rz(param3)
def CU1(param: float) -> jnp.ndarray:
return jnp.block([[I, _0],
[_0, U1(param)]]).reshape((2,) * 4)


def CU2(param0: float, param1: float) -> jnp.ndarray:
return jnp.block([[I, _0],
[_0, U2(param0, param1)]]).reshape((2,) * 4)


def CU3(param0: float, param1: float, param2: float) -> jnp.ndarray:
return jnp.block([[I, _0],
[_0, U3(param0, param1, param2)]]).reshape((2,) * 4)


def ISWAP(param: float) -> jnp.ndarray:
param_pi_2 = param * jnp.pi / 2
c = jnp.cos(param_pi_2)
i_s = 1.j * jnp.sin(param_pi_2)
return jnp.array([[1., 0., 0., 0.],
[0., c, i_s, 0.],
[0., i_s, c, 0.],
[0., 0., 0., 1.]]).reshape((2,) * 4)


def PhasedISWAP(param0: float, param1: float) -> jnp.ndarray:
param1_pi_2 = param1 * jnp.pi / 2
c = jnp.cos(param1_pi_2)
i_s = 1.j * jnp.sin(param1_pi_2)
return jnp.array([[1., 0., 0., 0.],
[0., c, i_s * jnp.exp(2.j * jnp.pi * param0), 0.],
[0., i_s * jnp.exp(-2.j * jnp.pi * param0), c, 0.],
[0., 0., 0., 1.]]).reshape((2,) * 4)


def XXPhase(param: float) -> jnp.ndarray:
param_pi_2 = param * jnp.pi / 2
c = jnp.cos(param_pi_2)
i_s = 1.j * jnp.sin(param_pi_2)
return jnp.array([[c, 0., 0., -i_s],
[0., c, -i_s, 0.],
[0., -i_s, c, 0.],
[-i_s, 0., 0., c]]).reshape((2,) * 4)


def YYPhase(param: float) -> jnp.ndarray:
param_pi_2 = param * jnp.pi / 2
c = jnp.cos(param_pi_2)
i_s = 1.j * jnp.sin(param_pi_2)
return jnp.array([[c, 0., 0., i_s],
[0., c, -i_s, 0.],
[0., -i_s, c, 0.],
[i_s, 0., 0., c]]).reshape((2,) * 4)


def ZZPhase(param: float) -> jnp.ndarray:
param_pi_2 = param * jnp.pi / 2
e_m = jnp.exp(-1.j * param_pi_2)
e_p = jnp.exp(1.j * param_pi_2)
return jnp.diag(jnp.array([e_m, e_p, e_p, e_m])).reshape((2,) * 4)
2 changes: 1 addition & 1 deletion qujax/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.2.7'
__version__ = '0.2.8'
9 changes: 9 additions & 0 deletions tests/test_gates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from qujax import gates
from qujax.circuit_tools import check_unitary


def test_gates():
for g_str, g in gates.__dict__.items():
if g_str[0] != '_' and g_str != 'jnp':
check_unitary(g_str)
check_unitary(g)

0 comments on commit 3ed090f

Please sign in to comment.