From 8b8aa7f146772f4b45a50592373580630e2d7508 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Thu, 20 Oct 2022 14:03:33 +0100 Subject: [PATCH 1/4] arrayify inds --- qujax/circuit.py | 75 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 51 insertions(+), 24 deletions(-) diff --git a/qujax/circuit.py b/qujax/circuit.py index 1c4bd49..6e3e5ab 100644 --- a/qujax/circuit.py +++ b/qujax/circuit.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Sequence, Union, Callable, Protocol +from typing import Sequence, Union, Callable, Protocol, Tuple from jax import numpy as jnp @@ -57,38 +57,25 @@ def apply_gate(statetensor: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray: return apply_gate -def get_params_to_statetensor_func(gate_seq: Sequence[Union[str, - jnp.ndarray, - Callable[[jnp.ndarray], jnp.ndarray], - Callable[[], jnp.ndarray]]], - qubit_inds_seq: Sequence[Sequence[int]], - param_inds_seq: Sequence[Sequence[int]], - n_qubits: int = None) -> UnionCallableOptionalArray: +def _to_gate_funcs(gate_seq: Sequence[Union[str, + jnp.ndarray, + Callable[[jnp.ndarray], jnp.ndarray], + Callable[[], jnp.ndarray]]])\ + -> Sequence[Callable[[jnp.ndarray], jnp.ndarray]]: """ - Creates a function that maps circuit parameters to a statetensor. + Ensures all gate_seq elements are functions that map (possibly empty) parameters + to a unitary tensor. Args: gate_seq: Sequence of gates. Each element is either a string matching an array or function in qujax.gates, a unitary array (which will be reshaped into a tensor of shape (2,2,2,...) ) or a function taking parameters and returning gate unitary in tensor form. - qubit_inds_seq: Sequences of qubits (ints) that gates are acting on. - param_inds_seq: Sequence of parameter indices that gates are using, - i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the first parameter, - the second gate is not parameterised and the third gates used the fifth and second parameters. - n_qubits: Number of qubits, if fixed. Returns: - Function which maps parameters (and optional statetensor_in) to a statetensor. - If no parameters are found then the function only takes optional statetensor_in. + Sequence of gate parameter to unitary functions """ - - check_circuit(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) - - if n_qubits is None: - n_qubits = max([max(qi) for qi in qubit_inds_seq]) + 1 - def _array_to_callable(arr: jnp.ndarray) -> Callable[[], jnp.ndarray]: return lambda: arr @@ -109,9 +96,50 @@ def _array_to_callable(arr: jnp.ndarray) -> Callable[[], jnp.ndarray]: f'callable: {gate}') gate_seq_callable.append(gate_func) - apply_gate_seq = [_get_apply_gate(g, q) for g, q in zip(gate_seq_callable, qubit_inds_seq)] + return gate_seq_callable + + +def _arrayify_inds(param_inds_seq: Sequence[Sequence[int]]) -> Sequence[jnp.ndarray]: param_inds_seq = [jnp.array(p) for p in param_inds_seq] param_inds_seq = [jnp.array([]) if jnp.any(jnp.isnan(p)) else p.astype(int) for p in param_inds_seq] + return param_inds_seq + + +def get_params_to_statetensor_func(gate_seq: Sequence[Union[str, + jnp.ndarray, + Callable[[jnp.ndarray], jnp.ndarray], + Callable[[], jnp.ndarray]]], + qubit_inds_seq: Sequence[Sequence[int]], + param_inds_seq: Sequence[Sequence[int]], + n_qubits: int = None) -> UnionCallableOptionalArray: + """ + Creates a function that maps circuit parameters to a statetensor. + + Args: + gate_seq: Sequence of gates. + Each element is either a string matching an array or function in qujax.gates, + a unitary array (which will be reshaped into a tensor of shape (2,2,2,...) ) + or a function taking parameters and returning gate unitary in tensor form. + qubit_inds_seq: Sequences of qubits (ints) that gates are acting on. + param_inds_seq: Sequence of parameter indices that gates are using, + i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the first parameter, + the second gate is not parameterised and the third gates used the fifth and second parameters. + n_qubits: Number of qubits, if fixed. + + Returns: + Function which maps parameters (and optional statetensor_in) to a statetensor. + If no parameters are found then the function only takes optional statetensor_in. + + """ + + check_circuit(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + + if n_qubits is None: + n_qubits = max([max(qi) for qi in qubit_inds_seq]) + 1 + + gate_seq_callable = _to_gate_funcs(gate_seq) + param_inds_seq = _arrayify_inds(param_inds_seq) + apply_gate_seq = [_get_apply_gate(g, q) for g, q in zip(gate_seq_callable, qubit_inds_seq)] def params_to_statetensor_func(params: jnp.ndarray, statetensor_in: jnp.ndarray = None) -> jnp.ndarray: @@ -156,4 +184,3 @@ def no_params_to_statetensor_func(statetensor_in: jnp.ndarray = None) -> jnp.nda return no_params_to_statetensor_func return params_to_statetensor_func - From 5bbfb8c509a52282a22aa61d9551871a439a4f91 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Thu, 20 Oct 2022 14:31:54 +0100 Subject: [PATCH 2/4] refactor apply_gate --- qujax/__init__.py | 1 + qujax/circuit.py | 50 +++++++++++++++-------------------------------- qujax/version.py | 2 +- 3 files changed, 18 insertions(+), 35 deletions(-) diff --git a/qujax/__init__.py b/qujax/__init__.py index 141067b..9af7a26 100644 --- a/qujax/__init__.py +++ b/qujax/__init__.py @@ -3,6 +3,7 @@ from qujax import gates from qujax.circuit import UnionCallableOptionalArray +from qujax.circuit import apply_gate from qujax.circuit import get_params_to_statetensor_func from qujax.observable import get_statetensor_to_expectation_func diff --git a/qujax/circuit.py b/qujax/circuit.py index 6e3e5ab..8ad7e17 100644 --- a/qujax/circuit.py +++ b/qujax/circuit.py @@ -1,6 +1,5 @@ from __future__ import annotations -from typing import Sequence, Union, Callable, Protocol, Tuple - +from typing import Sequence, Union, Callable, Protocol from jax import numpy as jnp from qujax import gates @@ -20,41 +19,25 @@ def __call__(self, statetensor_in: jnp.ndarray = None) -> jnp.ndarray: UnionCallableOptionalArray = Union[CallableArrayAndOptionalArray, CallableOptionalArray] -def _get_apply_gate(gate_func: Callable, - qubit_inds: Sequence[int]) -> Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]: +def apply_gate(statetensor: jnp.ndarray, gate_unitary: jnp.ndarray, qubit_inds: Sequence[int]) -> jnp.ndarray: """ - Creates a function that applies a given gate_func to given qubit_inds of a statetensor. + Applies gate to statetensor and returns updated statetensor. + Gate is represented by a unitary matrix (i.e. not parameterised). Args: - gate_func: Function that takes any gate parameters and returns the gate unitary (in tensor form). + statetensor: Input statetensor. + gate_unitary: Unitary array representing gate + must be in tensor form with shape (2,2,...) qubit_inds: Sequence of indices for gate to be applied to. len(qubit_inds) is equal to the dimension of the gate unitary tensor. Returns: - Function that takes statetensor and gate parameters, returns updated statetensor. - + Updated statetensor. """ - - def apply_gate(statetensor: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray: - """ - Applies {} gate to statetensor. - - Args: - statetensor: Input statetensor. - params: Gate parameters if any. - - Returns: - Updated statetensor. - """ - gate_unitary = gate_func(*params) - statetensor = jnp.tensordot(gate_unitary, statetensor, - axes=(list(range(-len(qubit_inds), 0)), qubit_inds)) - statetensor = jnp.moveaxis(statetensor, list(range(len(qubit_inds))), qubit_inds) - return statetensor - - apply_gate.__doc__ = apply_gate.__doc__.format(gate_func.__name__) - - return apply_gate + statetensor = jnp.tensordot(gate_unitary, statetensor, + axes=(list(range(-len(qubit_inds), 0)), qubit_inds)) + statetensor = jnp.moveaxis(statetensor, list(range(len(qubit_inds))), qubit_inds) + return statetensor def _to_gate_funcs(gate_seq: Sequence[Union[str, @@ -90,7 +73,7 @@ def _array_to_callable(arr: jnp.ndarray) -> Callable[[], jnp.ndarray]: gate_arr = jnp.array(gate) gate_size = gate_arr.size gate = gate_arr.reshape((2,) * int(jnp.log2(gate_size))) - gate_func = _array_to_callable(gate) + gate_func = _array_to_callable(jnp.array(gate)) else: raise TypeError(f'Unsupported gate type - gate must be either a string in qujax.gates, an array or ' f'callable: {gate}') @@ -139,7 +122,6 @@ def get_params_to_statetensor_func(gate_seq: Sequence[Union[str, gate_seq_callable = _to_gate_funcs(gate_seq) param_inds_seq = _arrayify_inds(param_inds_seq) - apply_gate_seq = [_get_apply_gate(g, q) for g, q in zip(gate_seq_callable, qubit_inds_seq)] def params_to_statetensor_func(params: jnp.ndarray, statetensor_in: jnp.ndarray = None) -> jnp.ndarray: @@ -161,9 +143,9 @@ def params_to_statetensor_func(params: jnp.ndarray, else: statetensor = statetensor_in params = jnp.atleast_1d(params) - for gate_ind, apply_gate in enumerate(apply_gate_seq): - gate_params = jnp.take(params, param_inds_seq[gate_ind]) - statetensor = apply_gate(statetensor, gate_params) + for gate_func, qubit_inds, param_inds in zip(gate_seq_callable, qubit_inds_seq, param_inds_seq): + gate_params = jnp.take(params, param_inds) + statetensor = apply_gate(statetensor, gate_func(*gate_params), qubit_inds) return statetensor if all([pi.size == 0 for pi in param_inds_seq]): diff --git a/qujax/version.py b/qujax/version.py index 14e974f..cd9b137 100644 --- a/qujax/version.py +++ b/qujax/version.py @@ -1 +1 @@ -__version__ = '0.2.8' +__version__ = '0.2.9' From 44d9569ec3c89efbe2a708b06c4d05ec52dea10a Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Thu, 20 Oct 2022 15:08:54 +0100 Subject: [PATCH 3/4] reshape --- qujax/circuit.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/qujax/circuit.py b/qujax/circuit.py index 8ad7e17..7655519 100644 --- a/qujax/circuit.py +++ b/qujax/circuit.py @@ -29,7 +29,7 @@ def apply_gate(statetensor: jnp.ndarray, gate_unitary: jnp.ndarray, qubit_inds: gate_unitary: Unitary array representing gate must be in tensor form with shape (2,2,...) qubit_inds: Sequence of indices for gate to be applied to. - len(qubit_inds) is equal to the dimension of the gate unitary tensor. + 2 * len(qubit_inds) is equal to the dimension of the gate unitary tensor. Returns: Updated statetensor. @@ -70,9 +70,6 @@ def _array_to_callable(arr: jnp.ndarray) -> Callable[[], jnp.ndarray]: if callable(gate): gate_func = gate elif hasattr(gate, '__array__'): - gate_arr = jnp.array(gate) - gate_size = gate_arr.size - gate = gate_arr.reshape((2,) * int(jnp.log2(gate_size))) gate_func = _array_to_callable(jnp.array(gate)) else: raise TypeError(f'Unsupported gate type - gate must be either a string in qujax.gates, an array or ' @@ -100,13 +97,16 @@ def get_params_to_statetensor_func(gate_seq: Sequence[Union[str, Args: gate_seq: Sequence of gates. - Each element is either a string matching an array or function in qujax.gates, - a unitary array (which will be reshaped into a tensor of shape (2,2,2,...) ) - or a function taking parameters and returning gate unitary in tensor form. - qubit_inds_seq: Sequences of qubits (ints) that gates are acting on. - param_inds_seq: Sequence of parameter indices that gates are using, - i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the first parameter, - the second gate is not parameterised and the third gates used the fifth and second parameters. + Each element is either a string matching a unitary array or function in qujax.gates, + a custom unitary array or a custom function taking parameters and returning a unitary array. + Unitary arrays will be reshaped into tensor form (2, 2,...) + qubit_inds_seq: Sequences of sequences representing qubit indices (ints) that gates are acting on. + i.e. [[0], [0,1], [1]] tells qujax the first gate is a single qubit gate acting on the zeroth qubit, + the second gate is a two qubit gate acting on the zeroth and first qubit etc. + param_inds_seq: Sequence of sequences representing parameter indices that gates are using, + i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter + (the float at position zero in the parameter vector/array), the second gate is not parameterised + and the third gates used the parameters at position five and two. n_qubits: Number of qubits, if fixed. Returns: @@ -145,7 +145,9 @@ def params_to_statetensor_func(params: jnp.ndarray, params = jnp.atleast_1d(params) for gate_func, qubit_inds, param_inds in zip(gate_seq_callable, qubit_inds_seq, param_inds_seq): gate_params = jnp.take(params, param_inds) - statetensor = apply_gate(statetensor, gate_func(*gate_params), qubit_inds) + gate_unitary = gate_func(*gate_params) + gate_unitary = gate_unitary.reshape((2,) * (2 * len(qubit_inds))) # Ensure gate is in tensor form + statetensor = apply_gate(statetensor, gate_unitary, qubit_inds) return statetensor if all([pi.size == 0 for pi in param_inds_seq]): From 2ba252a9da973150a69e27824f80707810b86a3c Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Thu, 20 Oct 2022 15:59:56 +0100 Subject: [PATCH 4/4] docs --- docs/apply_gate.rst | 5 +++++ docs/index.rst | 1 + qujax/circuit.py | 2 +- 3 files changed, 7 insertions(+), 1 deletion(-) create mode 100644 docs/apply_gate.rst diff --git a/docs/apply_gate.rst b/docs/apply_gate.rst new file mode 100644 index 0000000..5947aa0 --- /dev/null +++ b/docs/apply_gate.rst @@ -0,0 +1,5 @@ +apply_gate +======================= + +.. autofunction:: qujax.apply_gate + diff --git a/docs/index.rst b/docs/index.rst index 9772edd..c52f038 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -22,6 +22,7 @@ Docs .. toctree:: + apply_gate get_params_to_statetensor_func get_statetensor_to_expectation_func get_statetensor_to_sampled_expectation_func diff --git a/qujax/circuit.py b/qujax/circuit.py index 7655519..259c6d5 100644 --- a/qujax/circuit.py +++ b/qujax/circuit.py @@ -27,7 +27,7 @@ def apply_gate(statetensor: jnp.ndarray, gate_unitary: jnp.ndarray, qubit_inds: Args: statetensor: Input statetensor. gate_unitary: Unitary array representing gate - must be in tensor form with shape (2,2,...) + must be in tensor form with shape (2,2,...). qubit_inds: Sequence of indices for gate to be applied to. 2 * len(qubit_inds) is equal to the dimension of the gate unitary tensor.