diff --git a/docs/apply_gate.rst b/docs/apply_gate.rst new file mode 100644 index 0000000..5947aa0 --- /dev/null +++ b/docs/apply_gate.rst @@ -0,0 +1,5 @@ +apply_gate +======================= + +.. autofunction:: qujax.apply_gate + diff --git a/docs/index.rst b/docs/index.rst index 9772edd..c52f038 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -22,6 +22,7 @@ Docs .. toctree:: + apply_gate get_params_to_statetensor_func get_statetensor_to_expectation_func get_statetensor_to_sampled_expectation_func diff --git a/qujax/__init__.py b/qujax/__init__.py index 141067b..9af7a26 100644 --- a/qujax/__init__.py +++ b/qujax/__init__.py @@ -3,6 +3,7 @@ from qujax import gates from qujax.circuit import UnionCallableOptionalArray +from qujax.circuit import apply_gate from qujax.circuit import get_params_to_statetensor_func from qujax.observable import get_statetensor_to_expectation_func diff --git a/qujax/circuit.py b/qujax/circuit.py index 1c4bd49..259c6d5 100644 --- a/qujax/circuit.py +++ b/qujax/circuit.py @@ -1,6 +1,5 @@ from __future__ import annotations from typing import Sequence, Union, Callable, Protocol - from jax import numpy as jnp from qujax import gates @@ -20,41 +19,70 @@ def __call__(self, statetensor_in: jnp.ndarray = None) -> jnp.ndarray: UnionCallableOptionalArray = Union[CallableArrayAndOptionalArray, CallableOptionalArray] -def _get_apply_gate(gate_func: Callable, - qubit_inds: Sequence[int]) -> Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]: +def apply_gate(statetensor: jnp.ndarray, gate_unitary: jnp.ndarray, qubit_inds: Sequence[int]) -> jnp.ndarray: """ - Creates a function that applies a given gate_func to given qubit_inds of a statetensor. + Applies gate to statetensor and returns updated statetensor. + Gate is represented by a unitary matrix (i.e. not parameterised). Args: - gate_func: Function that takes any gate parameters and returns the gate unitary (in tensor form). + statetensor: Input statetensor. + gate_unitary: Unitary array representing gate + must be in tensor form with shape (2,2,...). qubit_inds: Sequence of indices for gate to be applied to. - len(qubit_inds) is equal to the dimension of the gate unitary tensor. + 2 * len(qubit_inds) is equal to the dimension of the gate unitary tensor. Returns: - Function that takes statetensor and gate parameters, returns updated statetensor. + Updated statetensor. + """ + statetensor = jnp.tensordot(gate_unitary, statetensor, + axes=(list(range(-len(qubit_inds), 0)), qubit_inds)) + statetensor = jnp.moveaxis(statetensor, list(range(len(qubit_inds))), qubit_inds) + return statetensor + +def _to_gate_funcs(gate_seq: Sequence[Union[str, + jnp.ndarray, + Callable[[jnp.ndarray], jnp.ndarray], + Callable[[], jnp.ndarray]]])\ + -> Sequence[Callable[[jnp.ndarray], jnp.ndarray]]: """ + Ensures all gate_seq elements are functions that map (possibly empty) parameters + to a unitary tensor. - def apply_gate(statetensor: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray: - """ - Applies {} gate to statetensor. + Args: + gate_seq: Sequence of gates. + Each element is either a string matching an array or function in qujax.gates, + a unitary array (which will be reshaped into a tensor of shape (2,2,2,...) ) + or a function taking parameters and returning gate unitary in tensor form. - Args: - statetensor: Input statetensor. - params: Gate parameters if any. + Returns: + Sequence of gate parameter to unitary functions - Returns: - Updated statetensor. - """ - gate_unitary = gate_func(*params) - statetensor = jnp.tensordot(gate_unitary, statetensor, - axes=(list(range(-len(qubit_inds), 0)), qubit_inds)) - statetensor = jnp.moveaxis(statetensor, list(range(len(qubit_inds))), qubit_inds) - return statetensor + """ + def _array_to_callable(arr: jnp.ndarray) -> Callable[[], jnp.ndarray]: + return lambda: arr + + gate_seq_callable = [] + for gate in gate_seq: + if isinstance(gate, str): + gate = gates.__dict__[gate] + + if callable(gate): + gate_func = gate + elif hasattr(gate, '__array__'): + gate_func = _array_to_callable(jnp.array(gate)) + else: + raise TypeError(f'Unsupported gate type - gate must be either a string in qujax.gates, an array or ' + f'callable: {gate}') + gate_seq_callable.append(gate_func) - apply_gate.__doc__ = apply_gate.__doc__.format(gate_func.__name__) + return gate_seq_callable - return apply_gate + +def _arrayify_inds(param_inds_seq: Sequence[Sequence[int]]) -> Sequence[jnp.ndarray]: + param_inds_seq = [jnp.array(p) for p in param_inds_seq] + param_inds_seq = [jnp.array([]) if jnp.any(jnp.isnan(p)) else p.astype(int) for p in param_inds_seq] + return param_inds_seq def get_params_to_statetensor_func(gate_seq: Sequence[Union[str, @@ -69,13 +97,16 @@ def get_params_to_statetensor_func(gate_seq: Sequence[Union[str, Args: gate_seq: Sequence of gates. - Each element is either a string matching an array or function in qujax.gates, - a unitary array (which will be reshaped into a tensor of shape (2,2,2,...) ) - or a function taking parameters and returning gate unitary in tensor form. - qubit_inds_seq: Sequences of qubits (ints) that gates are acting on. - param_inds_seq: Sequence of parameter indices that gates are using, - i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the first parameter, - the second gate is not parameterised and the third gates used the fifth and second parameters. + Each element is either a string matching a unitary array or function in qujax.gates, + a custom unitary array or a custom function taking parameters and returning a unitary array. + Unitary arrays will be reshaped into tensor form (2, 2,...) + qubit_inds_seq: Sequences of sequences representing qubit indices (ints) that gates are acting on. + i.e. [[0], [0,1], [1]] tells qujax the first gate is a single qubit gate acting on the zeroth qubit, + the second gate is a two qubit gate acting on the zeroth and first qubit etc. + param_inds_seq: Sequence of sequences representing parameter indices that gates are using, + i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter + (the float at position zero in the parameter vector/array), the second gate is not parameterised + and the third gates used the parameters at position five and two. n_qubits: Number of qubits, if fixed. Returns: @@ -89,29 +120,8 @@ def get_params_to_statetensor_func(gate_seq: Sequence[Union[str, if n_qubits is None: n_qubits = max([max(qi) for qi in qubit_inds_seq]) + 1 - def _array_to_callable(arr: jnp.ndarray) -> Callable[[], jnp.ndarray]: - return lambda: arr - - gate_seq_callable = [] - for gate in gate_seq: - if isinstance(gate, str): - gate = gates.__dict__[gate] - - if callable(gate): - gate_func = gate - elif hasattr(gate, '__array__'): - gate_arr = jnp.array(gate) - gate_size = gate_arr.size - gate = gate_arr.reshape((2,) * int(jnp.log2(gate_size))) - gate_func = _array_to_callable(gate) - else: - raise TypeError(f'Unsupported gate type - gate must be either a string in qujax.gates, an array or ' - f'callable: {gate}') - gate_seq_callable.append(gate_func) - - apply_gate_seq = [_get_apply_gate(g, q) for g, q in zip(gate_seq_callable, qubit_inds_seq)] - param_inds_seq = [jnp.array(p) for p in param_inds_seq] - param_inds_seq = [jnp.array([]) if jnp.any(jnp.isnan(p)) else p.astype(int) for p in param_inds_seq] + gate_seq_callable = _to_gate_funcs(gate_seq) + param_inds_seq = _arrayify_inds(param_inds_seq) def params_to_statetensor_func(params: jnp.ndarray, statetensor_in: jnp.ndarray = None) -> jnp.ndarray: @@ -133,9 +143,11 @@ def params_to_statetensor_func(params: jnp.ndarray, else: statetensor = statetensor_in params = jnp.atleast_1d(params) - for gate_ind, apply_gate in enumerate(apply_gate_seq): - gate_params = jnp.take(params, param_inds_seq[gate_ind]) - statetensor = apply_gate(statetensor, gate_params) + for gate_func, qubit_inds, param_inds in zip(gate_seq_callable, qubit_inds_seq, param_inds_seq): + gate_params = jnp.take(params, param_inds) + gate_unitary = gate_func(*gate_params) + gate_unitary = gate_unitary.reshape((2,) * (2 * len(qubit_inds))) # Ensure gate is in tensor form + statetensor = apply_gate(statetensor, gate_unitary, qubit_inds) return statetensor if all([pi.size == 0 for pi in param_inds_seq]): @@ -156,4 +168,3 @@ def no_params_to_statetensor_func(statetensor_in: jnp.ndarray = None) -> jnp.nda return no_params_to_statetensor_func return params_to_statetensor_func - diff --git a/qujax/version.py b/qujax/version.py index 14e974f..cd9b137 100644 --- a/qujax/version.py +++ b/qujax/version.py @@ -1 +1 @@ -__version__ = '0.2.8' +__version__ = '0.2.9'