Skip to content

Commit

Permalink
Merge pull request #55 from CQCL/develop
Browse files Browse the repository at this point in the history
Update README and add classifier notebook
  • Loading branch information
SamDuffield authored Nov 15, 2022
2 parents f47cf3d + f26c87b commit b1cbc04
Show file tree
Hide file tree
Showing 3 changed files with 480 additions and 11 deletions.
59 changes: 48 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
# qujax

Represent a (parameterised) quantum circuit as a pure [JAX](https://github.com/google/jax) function that
takes as input any parameters of the circuit and outputs a _statetensor_. The statetensor encodes all $2^N$ amplitudes
of the quantum state and can then be used downstream for exact expectations, gradients or sampling.
takes as input any parameters of the circuit and outputs either a _statetensor_ or a _densitytensor_ depending on
the choice of simulator.
- The statetensor encodes all $2^N$ amplitudes of the quantum state in a tensor version
of the statevector, for $N$ qubits.
- The densitytensor represents a tensor version of the
$2^N \times 2^N$ density matrix (allowing for mixed states and generic Kraus operators).

qujax also supports densitytensor simulations. A densitytensor is a tensor representation of the density matrix and allows for mixed states and generic Kraus operators.

A JAX implementation of a quantum circuit is useful for runtime speedups, automatic differentiation and support
for GPUs/TPUs.
Either representation can then be used downstream for exact expectations, gradients or sampling. A JAX implementation
of a quantum circuit is useful for runtime speedups, automatic differentiation, support for GPUs/TPUs and compatibility
with other JAX code and packages.

Some useful links:
- [Documentation](https://cqcl.github.io/qujax/api/)
Expand All @@ -21,7 +24,7 @@ Some useful links:
pip install qujax
```

## Parameterised quantum circuits with qujax
## Statetensor simulations with qujax
```python
from jax import numpy as jnp
import qujax
Expand Down Expand Up @@ -71,16 +74,39 @@ expectation_and_grad(jnp.array([0.1]))
# DeviceArray([-2.987832], dtype=float32))
```

## Densitytensor simulations with qujax
```python
param_to_dt = qujax.get_params_to_densitytensor_func(circuit_gates,
circuit_qubit_inds,
circuit_params_inds)
dt = param_to_dt(jnp.array([0.1]))
dt.shape
# (2, 2, 2, 2)
```
The densitytensor has shape ```(2,) * 2 * N``` and the density matrix can be obtained
with ```.reshape(2 * N, 2 * N)```.

Expectations can also be evaluated through the densitytensor

```python
dt_to_expectation = qujax.get_densitytensor_to_expectation_func([['Z']], [[0]], [1.])
dt_to_expectation(dt)
# DeviceArray(-0.3090171, dtype=float32)
```
Again everything is differentiable, jit-able and can be composed with other JAX code.



## Notes
+ We use the convention where parameters are given in units of π (i.e. in [0,2] rather than [0, 2π]).
+ By default the parameter to statetensor function initiates in the all 0 state, however there is an optional ```statetensor_in``` argument to initiate in an arbitrary state.
+ By default, the simulators are initiated in the all 0 state, however the optional ```statetensor_in```
or ```densitytensor_in``` argument can be used for arbitrary initialisations and combining circuits.


## pytket-qujax
You can also generate the parameter to statetensor function from a [`pytket`](https://cqcl.github.io/tket/pytket/api/)
circuit using the [`pytket-qujax`](https://github.com/CQCL/pytket-qujax) extension.
In particular, the
You can also generate the parameter to statetensor/densitytensor functions from
a [`pytket`](https://cqcl.github.io/tket/pytket/api/) circuit using the
[`pytket-qujax`](https://github.com/CQCL/pytket-qujax) extension. In particular, the
[`tk_to_qujax`](https://cqcl.github.io/pytket-qujax/api/api.html#pytket.extensions.qujax.qujax_convert.tk_to_qujax) and
[`tk_to_qujax_symbolic`](https://cqcl.github.io/pytket-qujax/api/api.html#pytket.extensions.qujax.qujax_convert.tk_to_qujax_symbolic)
functions.
Expand All @@ -99,3 +125,14 @@ Pull requests are welcomed!

New commits on [`develop`](https://github.com/CQCL/qujax/tree/develop) will then be merged into
[`main`](https://github.com/CQCL/qujax/tree/main) on the next release.


## Cite
```
@software{qujax2022,
author = {Samuel Duffield and Kirill Plekhanov and Gabriel Matos and Melf Johannsen},
title = {qujax: Simulating quantum circuits with JAX},
url = {https://github.com/CQCL/qujax},
year = {2022},
}
```
1 change: 1 addition & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

In this directory, you can find a selection of notebooks demonstrating some simple use cases of `qujax`

- [`classification.ipynb`](https://github.com/CQCL/qujax/blob/main/examples/classification.ipynb) - train a quantum circuit for binary classification using data re-uploading.
- [`generative_modelling.ipynb`](https://github.com/CQCL/qujax/blob/main/examples/generative_modelling.ipynb) - uses a parameterised quantum circuit as a generative model for a real life dataset. Trains via stochastic gradient Langevin dynamics on the maximum mean discrepancy between statetensor and dataset.
- [`heisenberg_vqe.ipynb`](https://github.com/CQCL/qujax/blob/main/examples/heisenberg_vqe.ipynb) - an implementation of the variational quantum eigensolver to find the ground state of a quantum Hamiltonian.
- [`maxcut_vqe.ipynb`](https://github.com/CQCL/qujax/blob/main/examples/maxcut_vqe.ipynb) - an implementation of the variational quantum eigensolver to solve a maxcut problem. Trains with Adam via [`optax`](https://github.com/deepmind/optax) and uses more realistic stochastic parameter shift gradients.
Expand Down
431 changes: 431 additions & 0 deletions examples/classification.ipynb

Large diffs are not rendered by default.

0 comments on commit b1cbc04

Please sign in to comment.