Skip to content

Commit

Permalink
Merge pull request #58 from CQCL/develop
Browse files Browse the repository at this point in the history
Add parameter to unitary support, add ZZMax and PhasedX gates, bugfix in print_circuit
  • Loading branch information
SamDuffield authored Dec 5, 2022
2 parents b1cbc04 + 1d85bb9 commit 67a5730
Show file tree
Hide file tree
Showing 10 changed files with 100 additions and 20 deletions.
6 changes: 6 additions & 0 deletions docs/get_params_to_unitarytensor_func.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
get_params_to_unitarytensor_func
=================================

.. autofunction:: qujax.get_params_to_unitarytensor_func


1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Docs

apply_gate
get_params_to_statetensor_func
get_params_to_unitarytensor_func
get_statetensor_to_expectation_func
get_statetensor_to_sampled_expectation_func
integers_to_bitstrings
Expand Down
4 changes: 3 additions & 1 deletion examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,6 @@ In this directory, you can find a selection of notebooks demonstrating some simp
- [`qaoa.ipynb`](https://github.com/CQCL/qujax/blob/main/examples/qaoa.ipynb) - uses a problem inspired QAOA ansatz to find the ground state of a quantum Hamiltonian. Demonstrates how to encode more sophisticated parameters that control multiple gates.
- [`variational_inference.ipynb`](https://github.com/CQCL/qujax/blob/main/examples/variational_inference.ipynb) - uses a parameterised quantum circuit as a variational distribution to fit to a target probability mass function. Uses Adam via [`optax`](https://github.com/deepmind/optax) to minimise the KL divergence between circuit and target distributions.

The Heisenberg notebook with a `tk_to_qujax` implementation can be found in the [`pytket`](https://github.com/CQCL/pytket) repository at [`pytket-qujax_heisenberg_vqe.ipynb`](https://github.com/CQCL/pytket/blob/main/examples/pytket-qujax_heisenberg_vqe.ipynb).
The [`pytket`](https://github.com/CQCL/pytket) repository also contains `tk_to_qujax` implementations for some of the above at [`pytket-qujax_classification.ipynb`](https://github.com/CQCL/pytket/blob/main/examples/pytket-qujax-classification.ipynb),
[`pytket-qujax_heisenberg_vqe.ipynb`](https://github.com/CQCL/pytket/blob/main/examples/pytket-qujax_heisenberg_vqe.ipynb)
and [`pytket-qujax_qaoa.ipynb`](https://github.com/CQCL/pytket/blob/main/examples/pytket-qujax_qaoa.ipynb).
1 change: 1 addition & 0 deletions qujax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from qujax.statetensor import apply_gate
from qujax.statetensor import get_params_to_statetensor_func
from qujax.statetensor import get_params_to_unitarytensor_func

from qujax.statetensor_observable import statetensor_to_single_expectation
from qujax.statetensor_observable import get_statetensor_to_expectation_func
Expand Down
2 changes: 1 addition & 1 deletion qujax/densitytensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def get_params_to_densitytensor_func(kraus_ops_seq: Sequence[kraus_op_type],
param_inds_seq: Sequence of sequences representing parameter indices that gates are using,
i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter
(the float at position zero in the parameter vector/array), the second gate is not parameterised
and the third gates used the parameters at position five and two.
and the third gate uses the parameters at position five and two.
n_qubits: Number of qubits, if fixed.
Returns:
Expand Down
7 changes: 7 additions & 0 deletions qujax/gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,10 @@ def ZZPhase(param: float) -> jnp.ndarray:
e_m = jnp.exp(-1.j * param_pi_2)
e_p = jnp.exp(1.j * param_pi_2)
return jnp.diag(jnp.array([e_m, e_p, e_p, e_m])).reshape((2,) * 4)


ZZMax = ZZPhase(0.5)


def PhasedX(param0: float, param1: float) -> jnp.ndarray:
return Rz(param1) @ Rx(param0) @ Rz(-param1)
41 changes: 40 additions & 1 deletion qujax/statetensor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
from typing import Sequence, Union, Callable
from functools import partial
from jax import numpy as jnp

from qujax import gates
Expand Down Expand Up @@ -97,7 +98,7 @@ def get_params_to_statetensor_func(gate_seq: Sequence[gate_type],
param_inds_seq: Sequence of sequences representing parameter indices that gates are using,
i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter
(the float at position zero in the parameter vector/array), the second gate is not parameterised
and the third gates used the parameters at position five and two.
and the third gate uses the parameters at position five and two.
n_qubits: Number of qubits, if fixed.
Returns:
Expand Down Expand Up @@ -158,3 +159,41 @@ def no_params_to_statetensor_func(statetensor_in: jnp.ndarray = None) -> jnp.nda
return no_params_to_statetensor_func

return params_to_statetensor_func


def get_params_to_unitarytensor_func(gate_seq: Sequence[gate_type],
qubit_inds_seq: Sequence[Sequence[int]],
param_inds_seq: Sequence[Union[None, Sequence[int]]],
n_qubits: int = None)\
-> Union[Callable[[], jnp.ndarray], Callable[[jnp.ndarray], jnp.ndarray]]:
"""
Creates a function that maps circuit parameters to a unitarytensor.
The unitarytensor is an array with shape (2,) * 2 * n_qubits
representing the full unitary matrix of the circuit.
Args:
gate_seq: Sequence of gates.
Each element is either a string matching a unitary array or function in qujax.gates,
a custom unitary array or a custom function taking parameters and returning a unitary array.
Unitary arrays will be reshaped into tensor form (2, 2,...)
qubit_inds_seq: Sequences of sequences representing qubit indices (ints) that gates are acting on.
i.e. [[0], [0,1], [1]] tells qujax the first gate is a single qubit gate acting on the zeroth qubit,
the second gate is a two qubit gate acting on the zeroth and first qubit etc.
param_inds_seq: Sequence of sequences representing parameter indices that gates are using,
i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter
(the float at position zero in the parameter vector/array), the second gate is not parameterised
and the third gate uses the parameters at position five and two.
n_qubits: Number of qubits, if fixed.
Returns:
Function which maps any parameters to a unitarytensor.
"""

if n_qubits is None:
n_qubits = max([max(qi) for qi in qubit_inds_seq]) + 1

param_to_st = get_params_to_statetensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits)
identity_unitarytensor = jnp.eye(2 ** n_qubits).reshape((2,) * 2 * n_qubits)
return partial(param_to_st, statetensor_in=identity_unitarytensor)

15 changes: 8 additions & 7 deletions qujax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,30 +284,31 @@ def print_circuit(gate_seq: Sequence[kraus_op_type],
if n_qubits_disp > 1:
for i in range(qubit_min + 1, qubit_max + 1):
rows += [' ', f'q{i}: '.ljust(3) + '-' * sep_length]
rows, qubits_free = _pad_rows(rows)
rows, rows_free = _pad_rows(rows)

for gate_ind in range(gate_ind_min, gate_ind_max + 1):
g = gate_str_seq[gate_ind]
qi = qubit_inds_seq[gate_ind]

qi_min = min(qi)
qi_max = max(qi)
ri_min = 2 * qi_min
ri_max = 2 * qi_max + 1

if not all([qubits_free[i] for i in range(qi_min, qi_max)]):
rows, qubits_free = _pad_rows(rows)
if not all([rows_free[i] for i in range(ri_min, ri_max)]):
rows, rows_free = _pad_rows(rows)

for row_ind in range(2 * qi_min, 2 * qi_max + 1):
for row_ind in range(ri_min, ri_max):
if row_ind == 2 * qi[-1]:
rows[row_ind] += '-' * sep_length + g
qubits_free[row_ind // 2] = False
elif row_ind % 2 == 1:
rows[row_ind] += ' ' * sep_length + ' ' + '|' + ' '
elif row_ind / 2 in qi:
rows[row_ind] += '-' * sep_length + '---' + '◯' + '---'
qubits_free[row_ind // 2] = False
else:
rows[row_ind] += '-' * sep_length + '---' + '|' + '---'
qubits_free[row_ind // 2] = False

rows_free[row_ind] = False

rows, _ = _pad_rows(rows)

Expand Down
2 changes: 1 addition & 1 deletion qujax/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.3.0'
__version__ = '0.3.1'
41 changes: 32 additions & 9 deletions tests/test_circuits.py → tests/test_statetensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,38 @@ def test_H():
true_sv = jnp.array([0.70710678 + 0.j, 0.70710678 + 0.j])

assert st.size == true_sv.size
assert jnp.all(jnp.abs(st.flatten() - true_sv) < 1e-5)
assert jnp.all(jnp.abs(st_jit.flatten() - true_sv) < 1e-5)
assert jnp.allclose(st.flatten(), true_sv)
assert jnp.allclose(st_jit.flatten(), true_sv)

param_to_unitary = qujax.get_params_to_unitarytensor_func(gates, qubits, param_inds)
unitary = param_to_unitary().reshape(2, 2)
unitary_jit = jit(param_to_unitary)().reshape(2, 2)
zero_sv = jnp.zeros(2).at[0].set(1)
assert jnp.allclose(unitary @ zero_sv, true_sv)
assert jnp.allclose(unitary_jit @ zero_sv, true_sv)


def test_H_redundant_qubits():
gates = ['H']
qubits = [[0]]
param_inds = [[]]
n_qubits = 3

param_to_st = qujax.get_params_to_statetensor_func(gates, qubits, param_inds, 3)
param_to_st = qujax.get_params_to_statetensor_func(gates, qubits, param_inds, n_qubits)
st = param_to_st(statetensor_in=None)

true_sv = jnp.array([0.70710678, 0., 0., 0.,
0.70710678, 0., 0., 0.])

assert st.size == true_sv.size
assert jnp.all(jnp.abs(st.flatten() - true_sv) < 1e-5)
assert jnp.allclose(st.flatten(), true_sv)

param_to_unitary = qujax.get_params_to_unitarytensor_func(gates, qubits, param_inds, n_qubits)
unitary = param_to_unitary().reshape(2 ** n_qubits, 2 ** n_qubits)
unitary_jit = jit(param_to_unitary)().reshape(2 ** n_qubits, 2 ** n_qubits)
zero_sv = jnp.zeros(2 ** n_qubits).at[0].set(1)
assert jnp.allclose(unitary @ zero_sv, true_sv)
assert jnp.allclose(unitary_jit @ zero_sv, true_sv)


def test_CX_Rz_CY():
Expand All @@ -40,15 +55,24 @@ def test_CX_Rz_CY():
param_inds = [[], [], [], None, [0], []]

param_to_st = qujax.get_params_to_statetensor_func(gates, qubits, param_inds)
st = param_to_st(jnp.array(0.1))
param = jnp.array(0.1)
st = param_to_st(param)

true_sv = jnp.array([0.34920055 - 0.05530793j, 0.34920055 - 0.05530793j,
0.05530793 - 0.34920055j, -0.05530793 + 0.34920055j,
0.34920055 - 0.05530793j, 0.34920055 - 0.05530793j,
0.05530793 - 0.34920055j, -0.05530793 + 0.34920055j], dtype='complex64')

assert st.size == true_sv.size
assert jnp.all(jnp.abs(st.flatten() - true_sv) < 1e-5)
assert jnp.allclose(st.flatten(), true_sv)

n_qubits = 3
param_to_unitary = qujax.get_params_to_unitarytensor_func(gates, qubits, param_inds, n_qubits)
unitary = param_to_unitary(param).reshape(2 ** n_qubits, 2 ** n_qubits)
unitary_jit = jit(param_to_unitary)(param).reshape(2 ** n_qubits, 2 ** n_qubits)
zero_sv = jnp.zeros(2 ** n_qubits).at[0].set(1)
assert jnp.allclose(unitary @ zero_sv, true_sv)
assert jnp.allclose(unitary_jit @ zero_sv, true_sv)


def test_stacked_circuits():
Expand All @@ -65,7 +89,6 @@ def test_stacked_circuits():

all_zeros_sv = jnp.array(jnp.arange(st2.size) == 0, dtype=int)

assert jnp.all(jnp.abs(st2.flatten() - all_zeros_sv) < 1e-5)
assert jnp.all(jnp.abs(st2_2.flatten() - all_zeros_sv) < 1e-5)

assert jnp.allclose(st2.flatten(), all_zeros_sv, atol=1e-7)
assert jnp.allclose(st2_2.flatten(), all_zeros_sv, atol=1e-7)

0 comments on commit 67a5730

Please sign in to comment.