Skip to content

Commit

Permalink
Merge pull request #38 from CQCL/develop
Browse files Browse the repository at this point in the history
Refactor circuit + add apply_gate function
  • Loading branch information
SamDuffield authored Oct 21, 2022
2 parents 8d82bb6 + a65ba88 commit dc18c4b
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 58 deletions.
5 changes: 5 additions & 0 deletions docs/apply_gate.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
apply_gate
=======================

.. autofunction:: qujax.apply_gate

1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Docs

.. toctree::

apply_gate
get_params_to_statetensor_func
get_statetensor_to_expectation_func
get_statetensor_to_sampled_expectation_func
Expand Down
1 change: 1 addition & 0 deletions qujax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from qujax import gates

from qujax.circuit import UnionCallableOptionalArray
from qujax.circuit import apply_gate
from qujax.circuit import get_params_to_statetensor_func

from qujax.observable import get_statetensor_to_expectation_func
Expand Down
125 changes: 68 additions & 57 deletions qujax/circuit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations
from typing import Sequence, Union, Callable, Protocol

from jax import numpy as jnp

from qujax import gates
Expand All @@ -20,41 +19,70 @@ def __call__(self, statetensor_in: jnp.ndarray = None) -> jnp.ndarray:
UnionCallableOptionalArray = Union[CallableArrayAndOptionalArray, CallableOptionalArray]


def _get_apply_gate(gate_func: Callable,
qubit_inds: Sequence[int]) -> Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
def apply_gate(statetensor: jnp.ndarray, gate_unitary: jnp.ndarray, qubit_inds: Sequence[int]) -> jnp.ndarray:
"""
Creates a function that applies a given gate_func to given qubit_inds of a statetensor.
Applies gate to statetensor and returns updated statetensor.
Gate is represented by a unitary matrix (i.e. not parameterised).
Args:
gate_func: Function that takes any gate parameters and returns the gate unitary (in tensor form).
statetensor: Input statetensor.
gate_unitary: Unitary array representing gate
must be in tensor form with shape (2,2,...).
qubit_inds: Sequence of indices for gate to be applied to.
len(qubit_inds) is equal to the dimension of the gate unitary tensor.
2 * len(qubit_inds) is equal to the dimension of the gate unitary tensor.
Returns:
Function that takes statetensor and gate parameters, returns updated statetensor.
Updated statetensor.
"""
statetensor = jnp.tensordot(gate_unitary, statetensor,
axes=(list(range(-len(qubit_inds), 0)), qubit_inds))
statetensor = jnp.moveaxis(statetensor, list(range(len(qubit_inds))), qubit_inds)
return statetensor


def _to_gate_funcs(gate_seq: Sequence[Union[str,
jnp.ndarray,
Callable[[jnp.ndarray], jnp.ndarray],
Callable[[], jnp.ndarray]]])\
-> Sequence[Callable[[jnp.ndarray], jnp.ndarray]]:
"""
Ensures all gate_seq elements are functions that map (possibly empty) parameters
to a unitary tensor.
def apply_gate(statetensor: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray:
"""
Applies {} gate to statetensor.
Args:
gate_seq: Sequence of gates.
Each element is either a string matching an array or function in qujax.gates,
a unitary array (which will be reshaped into a tensor of shape (2,2,2,...) )
or a function taking parameters and returning gate unitary in tensor form.
Args:
statetensor: Input statetensor.
params: Gate parameters if any.
Returns:
Sequence of gate parameter to unitary functions
Returns:
Updated statetensor.
"""
gate_unitary = gate_func(*params)
statetensor = jnp.tensordot(gate_unitary, statetensor,
axes=(list(range(-len(qubit_inds), 0)), qubit_inds))
statetensor = jnp.moveaxis(statetensor, list(range(len(qubit_inds))), qubit_inds)
return statetensor
"""
def _array_to_callable(arr: jnp.ndarray) -> Callable[[], jnp.ndarray]:
return lambda: arr

gate_seq_callable = []
for gate in gate_seq:
if isinstance(gate, str):
gate = gates.__dict__[gate]

if callable(gate):
gate_func = gate
elif hasattr(gate, '__array__'):
gate_func = _array_to_callable(jnp.array(gate))
else:
raise TypeError(f'Unsupported gate type - gate must be either a string in qujax.gates, an array or '
f'callable: {gate}')
gate_seq_callable.append(gate_func)

apply_gate.__doc__ = apply_gate.__doc__.format(gate_func.__name__)
return gate_seq_callable

return apply_gate

def _arrayify_inds(param_inds_seq: Sequence[Sequence[int]]) -> Sequence[jnp.ndarray]:
param_inds_seq = [jnp.array(p) for p in param_inds_seq]
param_inds_seq = [jnp.array([]) if jnp.any(jnp.isnan(p)) else p.astype(int) for p in param_inds_seq]
return param_inds_seq


def get_params_to_statetensor_func(gate_seq: Sequence[Union[str,
Expand All @@ -69,13 +97,16 @@ def get_params_to_statetensor_func(gate_seq: Sequence[Union[str,
Args:
gate_seq: Sequence of gates.
Each element is either a string matching an array or function in qujax.gates,
a unitary array (which will be reshaped into a tensor of shape (2,2,2,...) )
or a function taking parameters and returning gate unitary in tensor form.
qubit_inds_seq: Sequences of qubits (ints) that gates are acting on.
param_inds_seq: Sequence of parameter indices that gates are using,
i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the first parameter,
the second gate is not parameterised and the third gates used the fifth and second parameters.
Each element is either a string matching a unitary array or function in qujax.gates,
a custom unitary array or a custom function taking parameters and returning a unitary array.
Unitary arrays will be reshaped into tensor form (2, 2,...)
qubit_inds_seq: Sequences of sequences representing qubit indices (ints) that gates are acting on.
i.e. [[0], [0,1], [1]] tells qujax the first gate is a single qubit gate acting on the zeroth qubit,
the second gate is a two qubit gate acting on the zeroth and first qubit etc.
param_inds_seq: Sequence of sequences representing parameter indices that gates are using,
i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter
(the float at position zero in the parameter vector/array), the second gate is not parameterised
and the third gates used the parameters at position five and two.
n_qubits: Number of qubits, if fixed.
Returns:
Expand All @@ -89,29 +120,8 @@ def get_params_to_statetensor_func(gate_seq: Sequence[Union[str,
if n_qubits is None:
n_qubits = max([max(qi) for qi in qubit_inds_seq]) + 1

def _array_to_callable(arr: jnp.ndarray) -> Callable[[], jnp.ndarray]:
return lambda: arr

gate_seq_callable = []
for gate in gate_seq:
if isinstance(gate, str):
gate = gates.__dict__[gate]

if callable(gate):
gate_func = gate
elif hasattr(gate, '__array__'):
gate_arr = jnp.array(gate)
gate_size = gate_arr.size
gate = gate_arr.reshape((2,) * int(jnp.log2(gate_size)))
gate_func = _array_to_callable(gate)
else:
raise TypeError(f'Unsupported gate type - gate must be either a string in qujax.gates, an array or '
f'callable: {gate}')
gate_seq_callable.append(gate_func)

apply_gate_seq = [_get_apply_gate(g, q) for g, q in zip(gate_seq_callable, qubit_inds_seq)]
param_inds_seq = [jnp.array(p) for p in param_inds_seq]
param_inds_seq = [jnp.array([]) if jnp.any(jnp.isnan(p)) else p.astype(int) for p in param_inds_seq]
gate_seq_callable = _to_gate_funcs(gate_seq)
param_inds_seq = _arrayify_inds(param_inds_seq)

def params_to_statetensor_func(params: jnp.ndarray,
statetensor_in: jnp.ndarray = None) -> jnp.ndarray:
Expand All @@ -133,9 +143,11 @@ def params_to_statetensor_func(params: jnp.ndarray,
else:
statetensor = statetensor_in
params = jnp.atleast_1d(params)
for gate_ind, apply_gate in enumerate(apply_gate_seq):
gate_params = jnp.take(params, param_inds_seq[gate_ind])
statetensor = apply_gate(statetensor, gate_params)
for gate_func, qubit_inds, param_inds in zip(gate_seq_callable, qubit_inds_seq, param_inds_seq):
gate_params = jnp.take(params, param_inds)
gate_unitary = gate_func(*gate_params)
gate_unitary = gate_unitary.reshape((2,) * (2 * len(qubit_inds))) # Ensure gate is in tensor form
statetensor = apply_gate(statetensor, gate_unitary, qubit_inds)
return statetensor

if all([pi.size == 0 for pi in param_inds_seq]):
Expand All @@ -156,4 +168,3 @@ def no_params_to_statetensor_func(statetensor_in: jnp.ndarray = None) -> jnp.nda
return no_params_to_statetensor_func

return params_to_statetensor_func

2 changes: 1 addition & 1 deletion qujax/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.2.8'
__version__ = '0.2.9'

0 comments on commit dc18c4b

Please sign in to comment.