diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..120c689 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,6 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" \ No newline at end of file diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..80a2b9c --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,33 @@ +name: Lint python projects + +on: + pull_request: + branches: + - main + - develop + push: + branches: + - main + - develop + +jobs: + lint: + + runs-on: ubuntu-22.04 + + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.x + uses: actions/setup-python@v4 + with: + python-version: '3.x' + - name: Update pip + run: pip install --upgrade pip + - name: Install black and pylint + run: pip install black pylint + - name: Check files are formatted with black + run: | + black --check . + - name: Run pylint + run: | + pylint */ \ No newline at end of file diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..02da98b --- /dev/null +++ b/.pylintrc @@ -0,0 +1,53 @@ +[MASTER] +output-format=colorized +disable=all +enable= + anomalous-backslash-in-string, + assert-on-tuple, + bad-indentation, + bad-option-value, + bad-reversed-sequence, + bad-super-call, + consider-merging-isinstance, + continue-in-finally, + dangerous-default-value, + duplicate-argument-name, + expression-not-assigned, + function-redefined, + inconsistent-mro, + init-is-generator, + line-too-long, + lost-exception, + missing-kwoa, + mixed-line-endings, + not-callable, + no-value-for-parameter, + nonexistent-operator, + not-in-loop, + pointless-statement, + redefined-builtin, + return-arg-in-generator, + return-in-init, + return-outside-function, + simplifiable-if-statement, + syntax-error, + too-many-function-args, + trailing-whitespace, + undefined-variable, + unexpected-keyword-arg, + unhashable-dict-key, + unnecessary-pass, + unreachable, + unrecognized-inline-option, + unused-import, + unnecessary-semicolon, + unused-variable, + unused-wildcard-import, + wildcard-import, + wrong-import-order, + wrong-import-position, + yield-outside-function + + +# Ignore long lines containing URLs or pylint. +ignore-long-lines=^(.*#\w*pylint: disable.*|\s*(# )??)$ diff --git a/README.md b/README.md index 7b21654..5a24488 100644 --- a/README.md +++ b/README.md @@ -121,7 +121,8 @@ Pull requests are welcomed! 2. Add your code. 3. Add your tests. 4. Update the documentation if required. -5. Issue a pull request into [`develop`](https://github.com/CQCL/qujax/tree/develop). +5. Check the code lints (run `black . --check` and `pylint */`) +6. Issue a pull request into [`develop`](https://github.com/CQCL/qujax/tree/develop). New commits on [`develop`](https://github.com/CQCL/qujax/tree/develop) will then be merged into [`main`](https://github.com/CQCL/qujax/tree/main) on the next release. diff --git a/docs/conf.py b/docs/conf.py index 00bb712..6edd8db 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,42 +1,46 @@ import os import sys -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath("..")) # pylint: disable=wrong-import-position from qujax.version import __version__ # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -project = 'qujax' -copyright = '2022, Sam Duffield' -author = 'Sam Duffield' +project = "qujax" +project_copyright = "2022, Sam Duffield" +author = "Sam Duffield" version = __version__ release = __version__ # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration -extensions = ['sphinx.ext.autodoc', 'sphinx_rtd_theme', 'sphinx.ext.napoleon', 'sphinx.ext.mathjax'] +extensions = [ + "sphinx.ext.autodoc", + "sphinx_rtd_theme", + "sphinx.ext.napoleon", + "sphinx.ext.mathjax", +] -templates_path = ['_templates'] -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +templates_path = ["_templates"] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" autodoc_typehints = "description" html4_writer = True autodoc_type_aliases = { - 'jnp.ndarray': 'ndarray', - 'random.PRNGKeyArray': 'jax.random.PRNGKeyArray', - 'UnionCallableOptionalArray': 'Union[Callable[[ndarray, Optional[ndarray]], ndarray], ' - 'Callable[[Optional[ndarray]], ndarray]]' - + "jnp.ndarray": "ndarray", + "random.PRNGKeyArray": "jax.random.PRNGKeyArray", + "UnionCallableOptionalArray": "Union[Callable[[ndarray, Optional[ndarray]], ndarray], " + "Callable[[Optional[ndarray]], ndarray]]", } -latex_engine = 'pdflatex' +latex_engine = "pdflatex" diff --git a/qujax/__init__.py b/qujax/__init__.py index 9f9522f..c273b63 100644 --- a/qujax/__init__.py +++ b/qujax/__init__.py @@ -1,3 +1,8 @@ +""" +Simulating quantum circuits with JAX +""" + + from qujax.version import __version__ from qujax import gates @@ -32,10 +37,10 @@ from qujax.utils import sample_bitstrings from qujax.utils import statetensor_to_densitytensor +# pylint: disable=undefined-variable del version del statetensor del statetensor_observable del densitytensor del densitytensor_observable del utils - diff --git a/qujax/densitytensor.py b/qujax/densitytensor.py index 1f2b8da..b355ae9 100644 --- a/qujax/densitytensor.py +++ b/qujax/densitytensor.py @@ -1,16 +1,23 @@ from __future__ import annotations -from typing import Sequence, Union, Callable, Iterable, Tuple + +from typing import Callable, Iterable, Sequence, Tuple, Union + from jax import numpy as jnp from jax.lax import scan -from qujax.statetensor import apply_gate, UnionCallableOptionalArray -from qujax.statetensor import _to_gate_func, _arrayify_inds, _gate_func_to_unitary -from qujax.utils import check_circuit, KrausOp +from qujax.statetensor import ( + UnionCallableOptionalArray, + _arrayify_inds, + _gate_func_to_unitary, + _to_gate_func, + apply_gate, +) +from qujax.utils import KrausOp, check_circuit -def _kraus_single(densitytensor: jnp.ndarray, - array: jnp.ndarray, - qubit_inds: Sequence[int]) -> jnp.ndarray: +def _kraus_single( + densitytensor: jnp.ndarray, array: jnp.ndarray, qubit_inds: Sequence[int] +) -> jnp.ndarray: r""" Performs single Kraus operation @@ -28,13 +35,15 @@ def _kraus_single(densitytensor: jnp.ndarray, """ n_qubits = densitytensor.ndim // 2 densitytensor = apply_gate(densitytensor, array, qubit_inds) - densitytensor = apply_gate(densitytensor, array.conj(), [n_qubits + i for i in qubit_inds]) + densitytensor = apply_gate( + densitytensor, array.conj(), [n_qubits + i for i in qubit_inds] + ) return densitytensor -def kraus(densitytensor: jnp.ndarray, - arrays: Iterable[jnp.ndarray], - qubit_inds: Sequence[int]) -> jnp.ndarray: +def kraus( + densitytensor: jnp.ndarray, arrays: Iterable[jnp.ndarray], qubit_inds: Sequence[int] +) -> jnp.ndarray: r""" Performs Kraus operation. @@ -55,25 +64,30 @@ def kraus(densitytensor: jnp.ndarray, # ensure first dimensions indexes different kraus operators arrays = arrays.reshape((arrays.shape[0],) + (2,) * 2 * len(qubit_inds)) - new_densitytensor, _ = scan(lambda dt, arr: (dt + _kraus_single(densitytensor, arr, qubit_inds), None), - init=jnp.zeros_like(densitytensor) * 0.j, xs=arrays) - # i.e. new_densitytensor = vmap(_kraus_single, in_axes=(None, 0, None))(densitytensor, arrays, qubit_inds).sum(0) + new_densitytensor, _ = scan( + lambda dt, arr: (dt + _kraus_single(densitytensor, arr, qubit_inds), None), + init=jnp.zeros_like(densitytensor) * 0.0j, + xs=arrays, + ) + # new_densitytensor = vmap(_kraus_single, in_axes=(None, 0, None))( + # densitytensor, arrays, qubit_inds + # ).sum(0) return new_densitytensor -def _to_kraus_operator_seq_funcs(kraus_op: KrausOp, - param_inds: Union[None, Sequence[int], Sequence[Sequence[int]]]) \ - -> Tuple[Sequence[Callable[[jnp.ndarray], jnp.ndarray]], - Sequence[jnp.ndarray]]: +def _to_kraus_operator_seq_funcs( + kraus_op: KrausOp, param_inds: Union[None, Sequence[int], Sequence[Sequence[int]]] +) -> Tuple[Sequence[Callable[[jnp.ndarray], jnp.ndarray]], Sequence[jnp.ndarray]]: """ - Ensures Kraus operators are a sequence of functions that map (possibly empty) parameters to tensors - and that each element of param_inds_seq is a sequence of arrays that correspond to the parameter indices - of each Kraus operator. + Ensures Kraus operators are a sequence of functions that map (possibly empty) parameters to + tensors and that each element of param_inds_seq is a sequence of arrays that correspond to the + parameter indices of each Kraus operator. Args: kraus_op: Either a normal Gate or a sequence of Gates representing Kraus operators. param_inds: If kraus_op is a normal Gate then a sequence of parameter indices, - if kraus_op is a sequence of Kraus operators then a sequence of sequences of parameter indices + if kraus_op is a sequence of Kraus operators then a sequence of sequences of + parameter indices Returns: Tuple containing sequence of functions mapping to Kraus operators @@ -91,8 +105,9 @@ def _to_kraus_operator_seq_funcs(kraus_op: KrausOp, return kraus_op_funcs, _arrayify_inds(param_inds) -def partial_trace(densitytensor: jnp.ndarray, - indices_to_trace: Sequence[int]) -> jnp.ndarray: +def partial_trace( + densitytensor: jnp.ndarray, indices_to_trace: Sequence[int] +) -> jnp.ndarray: """ Traces out (discards) specified qubits, resulting in a densitytensor representing the mixed quantum state on the remaining qubits. @@ -113,27 +128,32 @@ def partial_trace(densitytensor: jnp.ndarray, return densitytensor -def get_params_to_densitytensor_func(kraus_ops_seq: Sequence[KrausOp], - qubit_inds_seq: Sequence[Sequence[int]], - param_inds_seq: Sequence[Union[None, Sequence[int], Sequence[Sequence[int]]]], - n_qubits: int = None) -> UnionCallableOptionalArray: +def get_params_to_densitytensor_func( + kraus_ops_seq: Sequence[KrausOp], + qubit_inds_seq: Sequence[Sequence[int]], + param_inds_seq: Sequence[Union[None, Sequence[int], Sequence[Sequence[int]]]], + n_qubits: int = None, +) -> UnionCallableOptionalArray: """ - Creates a function that maps circuit parameters to a density tensor (a density matrix in tensor form). + Creates a function that maps circuit parameters to a density tensor (a density matrix in + tensor form). densitytensor = densitymatrix.reshape((2,) * 2 * n_qubits) densitymatrix = densitytensor.reshape(2 ** n_qubits, 2 ** n_qubits) Args: kraus_ops_seq: Sequence of gates. Each element is either a string matching a unitary array or function in qujax.gates, - a custom unitary array or a custom function taking parameters and returning a unitary array. - Unitary arrays will be reshaped into tensor form (2, 2,...) - qubit_inds_seq: Sequences of sequences representing qubit indices (ints) that gates are acting on. - i.e. [[0], [0,1], [1]] tells qujax the first gate is a single qubit gate acting on the zeroth qubit, - the second gate is a two qubit gate acting on the zeroth and first qubit etc. + a custom unitary array or a custom function taking parameters and returning a unitary + array. Unitary arrays will be reshaped into tensor form (2, 2,...) + qubit_inds_seq: Sequences of sequences representing qubit indices (ints) that gates are + acting on. + i.e. [[0], [0,1], [1]] tells qujax the first gate is a single qubit gate acting on the + zeroth qubit, the second gate is a two qubit gate acting on the zeroth and + first qubit etc. param_inds_seq: Sequence of sequences representing parameter indices that gates are using, i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter - (the float at position zero in the parameter vector/array), the second gate is not parameterised - and the third gate uses the parameters at position five and two. + (the float at position zero in the parameter vector/array), the second gate is not + parameterised and the third gate uses the parameters at position five and two. n_qubits: Number of qubits, if fixed. Returns: @@ -147,20 +167,27 @@ def get_params_to_densitytensor_func(kraus_ops_seq: Sequence[KrausOp], if n_qubits is None: n_qubits = max([max(qi) for qi in qubit_inds_seq]) + 1 - kraus_ops_seq_callable_and_param_inds = [_to_kraus_operator_seq_funcs(ko, param_inds) - for ko, param_inds in zip(kraus_ops_seq, param_inds_seq)] - kraus_ops_seq_callable = [ko_pi[0] for ko_pi in kraus_ops_seq_callable_and_param_inds] + kraus_ops_seq_callable_and_param_inds = [ + _to_kraus_operator_seq_funcs(ko, param_inds) + for ko, param_inds in zip(kraus_ops_seq, param_inds_seq) + ] + kraus_ops_seq_callable = [ + ko_pi[0] for ko_pi in kraus_ops_seq_callable_and_param_inds + ] param_inds_array_seq = [ko_pi[1] for ko_pi in kraus_ops_seq_callable_and_param_inds] - def params_to_densitytensor_func(params: jnp.ndarray, - densitytensor_in: jnp.ndarray = None) -> jnp.ndarray: + def params_to_densitytensor_func( + params: jnp.ndarray, densitytensor_in: jnp.ndarray = None + ) -> jnp.ndarray: """ - Applies parameterised circuit (series of gates) to a densitytensor_in (default is |0>^N <0|^N). + Applies parameterised circuit (series of gates) to a densitytensor_in + (default is |0>^N <0|^N). Args: params: Parameters of the circuit. densitytensor_in: Optional. Input densitytensor. - Defaults to |0>^N <0|^N (tensor of size 2^(2*N) with all zeroes except one in [0]*(2*N) index). + Defaults to |0>^N <0|^N (tensor of size 2^(2*N) with all zeroes except one in + the [0]*(2*N) index). Returns: Updated densitytensor. @@ -168,26 +195,36 @@ def params_to_densitytensor_func(params: jnp.ndarray, """ if densitytensor_in is None: densitytensor = jnp.zeros((2,) * 2 * n_qubits) - densitytensor = densitytensor.at[(0,) * 2 * n_qubits].set(1.) + densitytensor = densitytensor.at[(0,) * 2 * n_qubits].set(1.0) else: densitytensor = densitytensor_in params = jnp.atleast_1d(params) - for gate_func_single_seq, qubit_inds, param_inds_single_seq in zip(kraus_ops_seq_callable, qubit_inds_seq, - param_inds_array_seq): - kraus_operators = [_gate_func_to_unitary(gf, qubit_inds, pi, params) - for gf, pi in zip(gate_func_single_seq, param_inds_single_seq)] + for gate_func_single_seq, qubit_inds, param_inds_single_seq in zip( + kraus_ops_seq_callable, qubit_inds_seq, param_inds_array_seq + ): + kraus_operators = [ + _gate_func_to_unitary(gf, qubit_inds, pi, params) + for gf, pi in zip(gate_func_single_seq, param_inds_single_seq) + ] densitytensor = kraus(densitytensor, kraus_operators, qubit_inds) return densitytensor - non_parameterised = all([all([pi.size == 0 for pi in pi_seq]) for pi_seq in param_inds_array_seq]) + non_parameterised = all( + [all([pi.size == 0 for pi in pi_seq]) for pi_seq in param_inds_array_seq] + ) if non_parameterised: - def no_params_to_densitytensor_func(densitytensor_in: jnp.ndarray = None) -> jnp.ndarray: + + def no_params_to_densitytensor_func( + densitytensor_in: jnp.ndarray = None, + ) -> jnp.ndarray: """ - Applies circuit (series of gates with no parameters) to a densitytensor_in (default is |0>^N <0|^N). + Applies circuit (series of gates with no parameters) to a densitytensor_in + (default is |0>^N <0|^N). Args: densitytensor_in: Optional. Input densitytensor. - Defaults to |0>^N <0|^N (tensor of size 2^(2*N) with all zeroes except one in [0]*(2*N) index). + Defaults to |0>^N <0|^N (tensor of size 2^(2*N) with all zeroes except one in + the [0]*(2*N) index). Returns: Updated densitytensor. diff --git a/qujax/densitytensor_observable.py b/qujax/densitytensor_observable.py index 2de7079..a28179a 100644 --- a/qujax/densitytensor_observable.py +++ b/qujax/densitytensor_observable.py @@ -1,16 +1,23 @@ from __future__ import annotations -from typing import Sequence, Union, Callable -from jax import numpy as jnp, random + +from typing import Callable, Sequence, Union + +from jax import numpy as jnp +from jax import random from jax.lax import fori_loop from qujax.densitytensor import _kraus_single, partial_trace from qujax.statetensor_observable import _get_tensor_to_expectation_func -from qujax.utils import sample_integers, statetensor_to_densitytensor, bitstrings_to_integers +from qujax.utils import ( + bitstrings_to_integers, + sample_integers, + statetensor_to_densitytensor, +) -def densitytensor_to_single_expectation(densitytensor: jnp.ndarray, - hermitian: jnp.ndarray, - qubit_inds: Sequence[int]) -> float: +def densitytensor_to_single_expectation( + densitytensor: jnp.ndarray, hermitian: jnp.ndarray, qubit_inds: Sequence[int] +) -> float: """ Evaluates expectation value of an observable represented by a Hermitian matrix (in tensor form). @@ -32,13 +39,15 @@ def densitytensor_to_single_expectation(densitytensor: jnp.ndarray, return jnp.einsum(densitytensor, dt_indices, hermitian, hermitian_indices).real -def get_densitytensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], - qubits_seq_seq: Sequence[Sequence[int]], - coefficients: Union[Sequence[float], jnp.ndarray]) \ - -> Callable[[jnp.ndarray], float]: +def get_densitytensor_to_expectation_func( + hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + qubits_seq_seq: Sequence[Sequence[int]], + coefficients: Union[Sequence[float], jnp.ndarray], +) -> Callable[[jnp.ndarray], float]: """ Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and - a list of coefficients and returns a function that converts a densitytensor into an expected value. + a list of coefficients and returns a function that converts a densitytensor into an + expected value. Args: hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. @@ -53,14 +62,19 @@ def get_densitytensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[U Function that takes densitytensor and returns expected value (float). """ - return _get_tensor_to_expectation_func(hermitian_seq_seq, qubits_seq_seq, coefficients, - densitytensor_to_single_expectation) + return _get_tensor_to_expectation_func( + hermitian_seq_seq, + qubits_seq_seq, + coefficients, + densitytensor_to_single_expectation, + ) -def get_densitytensor_to_sampled_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], - qubits_seq_seq: Sequence[Sequence[int]], - coefficients: Union[Sequence[float], jnp.ndarray]) \ - -> Callable[[jnp.ndarray, random.PRNGKeyArray, int], float]: +def get_densitytensor_to_sampled_expectation_func( + hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + qubits_seq_seq: Sequence[Sequence[int]], + coefficients: Union[Sequence[float], jnp.ndarray], +) -> Callable[[jnp.ndarray, random.PRNGKeyArray, int], float]: """ Converts strings (or arrays) representing Hermitian matrices, qubit indices and coefficients into a function that converts a densitytensor into a sampled expected value. @@ -77,13 +91,13 @@ def get_densitytensor_to_sampled_expectation_func(hermitian_seq_seq: Sequence[Se Function that takes densitytensor, random key and integer number of shots and returns sampled expected value (float). """ - densitytensor_to_expectation_func = get_densitytensor_to_expectation_func(hermitian_seq_seq, - qubits_seq_seq, - coefficients) + densitytensor_to_expectation_func = get_densitytensor_to_expectation_func( + hermitian_seq_seq, qubits_seq_seq, coefficients + ) - def densitytensor_to_sampled_expectation_func(statetensor: jnp.ndarray, - random_key: random.PRNGKeyArray, - n_samps: int) -> float: + def densitytensor_to_sampled_expectation_func( + statetensor: jnp.ndarray, random_key: random.PRNGKeyArray, n_samps: int + ) -> float: """ Maps statetensor to sampled expected value. @@ -97,22 +111,28 @@ def densitytensor_to_sampled_expectation_func(statetensor: jnp.ndarray, """ sampled_integers = sample_integers(random_key, statetensor, n_samps) - sampled_probs = fori_loop(0, n_samps, - lambda i, sv: sv.at[sampled_integers[i]].add(1), - jnp.zeros(statetensor.size)) + sampled_probs = fori_loop( + 0, + n_samps, + lambda i, sv: sv.at[sampled_integers[i]].add(1), + jnp.zeros(statetensor.size), + ) sampled_probs /= n_samps - sampled_dt = statetensor_to_densitytensor(jnp.sqrt(sampled_probs).reshape(statetensor.shape)) + sampled_dt = statetensor_to_densitytensor( + jnp.sqrt(sampled_probs).reshape(statetensor.shape) + ) return densitytensor_to_expectation_func(sampled_dt) return densitytensor_to_sampled_expectation_func -def densitytensor_to_measurement_probabilities(densitytensor: jnp.ndarray, - qubit_inds: Sequence[int]) -> jnp.ndarray: +def densitytensor_to_measurement_probabilities( + densitytensor: jnp.ndarray, qubit_inds: Sequence[int] +) -> jnp.ndarray: """ - Extract array of measurement probabilities given a densitytensor and some qubit indices to measure - (in the computational basis). + Extract array of measurement probabilities given a densitytensor and some qubit indices to + measure (in the computational basis). I.e. the ith element of the array corresponds to the probability of observing the bitstring represented by the integer i on the measured qubits. @@ -126,13 +146,18 @@ def densitytensor_to_measurement_probabilities(densitytensor: jnp.ndarray, n_qubits = densitytensor.ndim // 2 n_qubits_measured = len(qubit_inds) qubit_inds_trace_out = [i for i in range(n_qubits) if i not in qubit_inds] - return jnp.diag(partial_trace(densitytensor, qubit_inds_trace_out).reshape(2 * n_qubits_measured, - 2 * n_qubits_measured)).real - - -def densitytensor_to_measured_densitytensor(densitytensor: jnp.ndarray, - qubit_inds: Sequence[int], - measurement: Union[int, jnp.ndarray]) -> jnp.ndarray: + return jnp.diag( + partial_trace(densitytensor, qubit_inds_trace_out).reshape( + 2 * n_qubits_measured, 2 * n_qubits_measured + ) + ).real + + +def densitytensor_to_measured_densitytensor( + densitytensor: jnp.ndarray, + qubit_inds: Sequence[int], + measurement: Union[int, jnp.ndarray], +) -> jnp.ndarray: """ Returns the post-measurement densitytensor assuming that qubit_inds are measured (in the computational basis) and the given measurement (integer or bitstring) is observed. @@ -146,12 +171,19 @@ def densitytensor_to_measured_densitytensor(densitytensor: jnp.ndarray, Post-measurement densitytensor (same shape as input densitytensor). """ measurement = jnp.array(measurement) - measured_int = bitstrings_to_integers(measurement) if measurement.ndim == 1 else measurement + measured_int = ( + bitstrings_to_integers(measurement) if measurement.ndim == 1 else measurement + ) n_qubits = densitytensor.ndim // 2 n_qubits_measured = len(qubit_inds) - qubit_inds_projector = jnp.diag(jnp.zeros(2 ** n_qubits_measured).at[measured_int].set(1)) \ - .reshape((2,) * 2 * n_qubits_measured) - unnorm_densitytensor = _kraus_single(densitytensor, qubit_inds_projector, qubit_inds) - norm_const = jnp.trace(unnorm_densitytensor.reshape(2 ** n_qubits, 2 ** n_qubits)).real + qubit_inds_projector = jnp.diag( + jnp.zeros(2**n_qubits_measured).at[measured_int].set(1) + ).reshape((2,) * 2 * n_qubits_measured) + unnorm_densitytensor = _kraus_single( + densitytensor, qubit_inds_projector, qubit_inds + ) + norm_const = jnp.trace( + unnorm_densitytensor.reshape(2**n_qubits, 2**n_qubits) + ).real return unnorm_densitytensor / norm_const diff --git a/qujax/gates.py b/qujax/gates.py index 70c1e18..ae88fcf 100644 --- a/qujax/gates.py +++ b/qujax/gates.py @@ -4,111 +4,91 @@ _0 = jnp.zeros((2, 2)) -X = jnp.array([[0., 1.], - [1., 0.]]) +X = jnp.array([[0.0, 1.0], [1.0, 0.0]]) -Y = jnp.array([[0., -1.j], - [1.j, 0.]]) +Y = jnp.array([[0.0, -1.0j], [1.0j, 0.0]]) -Z = jnp.array([[1., 0.], - [0., -1.]]) +Z = jnp.array([[1.0, 0.0], [0.0, -1.0]]) -H = jnp.array([[1., 1.], - [1., -1]]) / jnp.sqrt(2) +H = jnp.array([[1.0, 1.0], [1.0, -1]]) / jnp.sqrt(2) -S = jnp.array([[1., 0.], - [0., 1.j]]) +S = jnp.array([[1.0, 0.0], [0.0, 1.0j]]) -Sdg = jnp.array([[1., 0.], - [0., -1.j]]) +Sdg = jnp.array([[1.0, 0.0], [0.0, -1.0j]]) -T = jnp.array([[1., 0.], - [0., jnp.exp(jnp.pi * 1.j / 4)]]) +T = jnp.array([[1.0, 0.0], [0.0, jnp.exp(jnp.pi * 1.0j / 4)]]) -Tdg = jnp.array([[1., 0.], - [0., jnp.exp(-jnp.pi * 1.j / 4)]]) +Tdg = jnp.array([[1.0, 0.0], [0.0, jnp.exp(-jnp.pi * 1.0j / 4)]]) -V = jnp.array([[1., -1.j], - [-1.j, 1.]]) / jnp.sqrt(2) +V = jnp.array([[1.0, -1.0j], [-1.0j, 1.0]]) / jnp.sqrt(2) -Vdg = jnp.array([[1., 1.j], - [1.j, 1.]]) / jnp.sqrt(2) +Vdg = jnp.array([[1.0, 1.0j], [1.0j, 1.0]]) / jnp.sqrt(2) -SX = jnp.array([[1. + 1.j, 1. - 1.j], - [1. - 1.j, 1. + 1.j]]) / 2 +SX = jnp.array([[1.0 + 1.0j, 1.0 - 1.0j], [1.0 - 1.0j, 1.0 + 1.0j]]) / 2 -SXdg = jnp.array([[1. - 1.j, 1. + 1.j], - [1. + 1.j, 1. - 1.j]]) / 2 +SXdg = jnp.array([[1.0 - 1.0j, 1.0 + 1.0j], [1.0 + 1.0j, 1.0 - 1.0j]]) / 2 -CX = jnp.block([[I, _0], - [_0, X]]).reshape((2,) * 4) +CX = jnp.block([[I, _0], [_0, X]]).reshape((2,) * 4) -CY = jnp.block([[I, _0], - [_0, Y]]).reshape((2,) * 4) +CY = jnp.block([[I, _0], [_0, Y]]).reshape((2,) * 4) -CZ = jnp.block([[I, _0], - [_0, Z]]).reshape((2,) * 4) +CZ = jnp.block([[I, _0], [_0, Z]]).reshape((2,) * 4) -CH = jnp.block([[I, _0], - [_0, H]]).reshape((2,) * 4) +CH = jnp.block([[I, _0], [_0, H]]).reshape((2,) * 4) -CV = jnp.block([[I, _0], - [_0, V]]).reshape((2,) * 4) +CV = jnp.block([[I, _0], [_0, V]]).reshape((2,) * 4) -CVdg = jnp.block([[I, _0], - [_0, Vdg]]).reshape((2,) * 4) +CVdg = jnp.block([[I, _0], [_0, Vdg]]).reshape((2,) * 4) -CSX = jnp.block([[I, _0], - [_0, SX]]).reshape((2,) * 4) +CSX = jnp.block([[I, _0], [_0, SX]]).reshape((2,) * 4) -CSXdg = jnp.block([[I, _0], - [_0, SXdg]]).reshape((2,) * 4) +CSXdg = jnp.block([[I, _0], [_0, SXdg]]).reshape((2,) * 4) -CCX = jnp.block([[I, _0, _0, _0], # Toffoli gate - [_0, I, _0, _0], - [_0, _0, I, _0], - [_0, _0, _0, X]]).reshape((2,) * 6) +CCX = jnp.block( + [[I, _0, _0, _0], [_0, I, _0, _0], [_0, _0, I, _0], [_0, _0, _0, X]] # Toffoli gate +).reshape((2,) * 6) -ECR = jnp.block([[_0, Vdg], - [V, _0]]).reshape((2,) * 4) +ECR = jnp.block([[_0, Vdg], [V, _0]]).reshape((2,) * 4) -SWAP = jnp.array([[1., 0., 0., 0.], - [0., 0., 1., 0.], - [0., 1., 0., 0.], - [0., 0., 0., 1]]) +SWAP = jnp.array( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1], + ] +) -CSWAP = jnp.block([[jnp.eye(4), jnp.zeros((4, 4))], - [jnp.zeros((4, 4)), SWAP]]).reshape((2,) * 6) +CSWAP = jnp.block([[jnp.eye(4), jnp.zeros((4, 4))], [jnp.zeros((4, 4)), SWAP]]).reshape( + (2,) * 6 +) def Rx(param: float) -> jnp.ndarray: param_pi_2 = param * jnp.pi / 2 - return jnp.cos(param_pi_2) * I - jnp.sin(param_pi_2) * X * 1.j + return jnp.cos(param_pi_2) * I - jnp.sin(param_pi_2) * X * 1.0j def Ry(param: float) -> jnp.ndarray: param_pi_2 = param * jnp.pi / 2 - return jnp.cos(param_pi_2) * I - jnp.sin(param_pi_2) * Y * 1.j + return jnp.cos(param_pi_2) * I - jnp.sin(param_pi_2) * Y * 1.0j def Rz(param: float) -> jnp.ndarray: param_pi_2 = param * jnp.pi / 2 - return jnp.cos(param_pi_2) * I - jnp.sin(param_pi_2) * Z * 1.j + return jnp.cos(param_pi_2) * I - jnp.sin(param_pi_2) * Z * 1.0j def CRx(param: float) -> jnp.ndarray: - return jnp.block([[I, _0], - [_0, Rx(param)]]).reshape((2,) * 4) + return jnp.block([[I, _0], [_0, Rx(param)]]).reshape((2,) * 4) def CRy(param: float) -> jnp.ndarray: - return jnp.block([[I, _0], - [_0, Ry(param)]]).reshape((2,) * 4) + return jnp.block([[I, _0], [_0, Ry(param)]]).reshape((2,) * 4) def CRz(param: float) -> jnp.ndarray: - return jnp.block([[I, _0], - [_0, Rz(param)]]).reshape((2,) * 4) + return jnp.block([[I, _0], [_0, Rz(param)]]).reshape((2,) * 4) def U1(param: float) -> jnp.ndarray: @@ -120,68 +100,86 @@ def U2(param0: float, param1: float) -> jnp.ndarray: def U3(param0: float, param1: float, param2: float) -> jnp.ndarray: - return jnp.exp((param1 + param2) * jnp.pi * 1.j / 2) * Rz(param1) @ Ry(param0) @ Rz(param2) + return ( + jnp.exp((param1 + param2) * jnp.pi * 1.0j / 2) + * Rz(param1) + @ Ry(param0) + @ Rz(param2) + ) def CU1(param: float) -> jnp.ndarray: - return jnp.block([[I, _0], - [_0, U1(param)]]).reshape((2,) * 4) + return jnp.block([[I, _0], [_0, U1(param)]]).reshape((2,) * 4) def CU2(param0: float, param1: float) -> jnp.ndarray: - return jnp.block([[I, _0], - [_0, U2(param0, param1)]]).reshape((2,) * 4) + return jnp.block([[I, _0], [_0, U2(param0, param1)]]).reshape((2,) * 4) def CU3(param0: float, param1: float, param2: float) -> jnp.ndarray: - return jnp.block([[I, _0], - [_0, U3(param0, param1, param2)]]).reshape((2,) * 4) + return jnp.block([[I, _0], [_0, U3(param0, param1, param2)]]).reshape((2,) * 4) def ISWAP(param: float) -> jnp.ndarray: param_pi_2 = param * jnp.pi / 2 c = jnp.cos(param_pi_2) - i_s = 1.j * jnp.sin(param_pi_2) - return jnp.array([[1., 0., 0., 0.], - [0., c, i_s, 0.], - [0., i_s, c, 0.], - [0., 0., 0., 1.]]).reshape((2,) * 4) + i_s = 1.0j * jnp.sin(param_pi_2) + return jnp.array( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, c, i_s, 0.0], + [0.0, i_s, c, 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + ).reshape((2,) * 4) def PhasedISWAP(param0: float, param1: float) -> jnp.ndarray: param1_pi_2 = param1 * jnp.pi / 2 c = jnp.cos(param1_pi_2) - i_s = 1.j * jnp.sin(param1_pi_2) - return jnp.array([[1., 0., 0., 0.], - [0., c, i_s * jnp.exp(2.j * jnp.pi * param0), 0.], - [0., i_s * jnp.exp(-2.j * jnp.pi * param0), c, 0.], - [0., 0., 0., 1.]]).reshape((2,) * 4) + i_s = 1.0j * jnp.sin(param1_pi_2) + return jnp.array( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, c, i_s * jnp.exp(2.0j * jnp.pi * param0), 0.0], + [0.0, i_s * jnp.exp(-2.0j * jnp.pi * param0), c, 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + ).reshape((2,) * 4) def XXPhase(param: float) -> jnp.ndarray: param_pi_2 = param * jnp.pi / 2 c = jnp.cos(param_pi_2) - i_s = 1.j * jnp.sin(param_pi_2) - return jnp.array([[c, 0., 0., -i_s], - [0., c, -i_s, 0.], - [0., -i_s, c, 0.], - [-i_s, 0., 0., c]]).reshape((2,) * 4) + i_s = 1.0j * jnp.sin(param_pi_2) + return jnp.array( + [ + [c, 0.0, 0.0, -i_s], + [0.0, c, -i_s, 0.0], + [0.0, -i_s, c, 0.0], + [-i_s, 0.0, 0.0, c], + ] + ).reshape((2,) * 4) def YYPhase(param: float) -> jnp.ndarray: param_pi_2 = param * jnp.pi / 2 c = jnp.cos(param_pi_2) - i_s = 1.j * jnp.sin(param_pi_2) - return jnp.array([[c, 0., 0., i_s], - [0., c, -i_s, 0.], - [0., -i_s, c, 0.], - [i_s, 0., 0., c]]).reshape((2,) * 4) + i_s = 1.0j * jnp.sin(param_pi_2) + return jnp.array( + [ + [c, 0.0, 0.0, i_s], + [0.0, c, -i_s, 0.0], + [0.0, -i_s, c, 0.0], + [i_s, 0.0, 0.0, c], + ] + ).reshape((2,) * 4) def ZZPhase(param: float) -> jnp.ndarray: param_pi_2 = param * jnp.pi / 2 - e_m = jnp.exp(-1.j * param_pi_2) - e_p = jnp.exp(1.j * param_pi_2) + e_m = jnp.exp(-1.0j * param_pi_2) + e_p = jnp.exp(1.0j * param_pi_2) return jnp.diag(jnp.array([e_m, e_p, e_p, e_m])).reshape((2,) * 4) diff --git a/qujax/statetensor.py b/qujax/statetensor.py index cb32710..a72c2ed 100644 --- a/qujax/statetensor.py +++ b/qujax/statetensor.py @@ -1,13 +1,17 @@ from __future__ import annotations -from typing import Sequence, Union, Callable + from functools import partial +from typing import Callable, Sequence, Union + from jax import numpy as jnp from qujax import gates -from qujax.utils import check_circuit, _arrayify_inds, UnionCallableOptionalArray, Gate +from qujax.utils import Gate, UnionCallableOptionalArray, _arrayify_inds, check_circuit -def apply_gate(statetensor: jnp.ndarray, gate_unitary: jnp.ndarray, qubit_inds: Sequence[int]) -> jnp.ndarray: +def apply_gate( + statetensor: jnp.ndarray, gate_unitary: jnp.ndarray, qubit_inds: Sequence[int] +) -> jnp.ndarray: """ Applies gate to statetensor and returns updated statetensor. Gate is represented by a unitary matrix in tensor form. @@ -22,8 +26,9 @@ def apply_gate(statetensor: jnp.ndarray, gate_unitary: jnp.ndarray, qubit_inds: Returns: Updated statetensor. """ - statetensor = jnp.tensordot(gate_unitary, statetensor, - axes=(list(range(-len(qubit_inds), 0)), qubit_inds)) + statetensor = jnp.tensordot( + gate_unitary, statetensor, axes=(list(range(-len(qubit_inds), 0)), qubit_inds) + ) statetensor = jnp.moveaxis(statetensor, list(range(len(qubit_inds))), qubit_inds) return statetensor @@ -50,24 +55,29 @@ def _array_to_callable(arr: jnp.ndarray) -> Callable[[], jnp.ndarray]: if callable(gate): gate_func = gate - elif hasattr(gate, '__array__'): + elif hasattr(gate, "__array__"): gate_func = _array_to_callable(jnp.array(gate)) else: - raise TypeError(f'Unsupported gate type - gate must be either a string in qujax.gates, an array or ' - f'callable: {gate}') + raise TypeError( + f"Unsupported gate type - gate must be either a string in qujax.gates, an array or " + f"callable: {gate}" + ) return gate_func -def _gate_func_to_unitary(gate_func: Callable[[jnp.ndarray], jnp.ndarray], - qubit_inds: Sequence[int], - param_inds: jnp.ndarray, - params: jnp.ndarray) -> jnp.ndarray: +def _gate_func_to_unitary( + gate_func: Callable[[jnp.ndarray], jnp.ndarray], + qubit_inds: Sequence[int], + param_inds: jnp.ndarray, + params: jnp.ndarray, +) -> jnp.ndarray: """ Extract gate unitary. Args: gate_func: Function that maps a (possibly empty) parameter array to a unitary tensor (array) - qubit_inds: Indices of qubits to apply gate to (only needed to ensure gate is in tensor form) + qubit_inds: Indices of qubits to apply gate to + (only needed to ensure gate is in tensor form) param_inds: Indices of full parameter to extract gate specific parameters params: Full parameter vector @@ -76,29 +86,35 @@ def _gate_func_to_unitary(gate_func: Callable[[jnp.ndarray], jnp.ndarray], """ gate_params = jnp.take(params, param_inds) gate_unitary = gate_func(*gate_params) - gate_unitary = gate_unitary.reshape((2,) * (2 * len(qubit_inds))) # Ensure gate is in tensor form + gate_unitary = gate_unitary.reshape( + (2,) * (2 * len(qubit_inds)) + ) # Ensure gate is in tensor form return gate_unitary -def get_params_to_statetensor_func(gate_seq: Sequence[Gate], - qubit_inds_seq: Sequence[Sequence[int]], - param_inds_seq: Sequence[Union[None, Sequence[int]]], - n_qubits: int = None) -> UnionCallableOptionalArray: +def get_params_to_statetensor_func( + gate_seq: Sequence[Gate], + qubit_inds_seq: Sequence[Sequence[int]], + param_inds_seq: Sequence[Union[None, Sequence[int]]], + n_qubits: int = None, +) -> UnionCallableOptionalArray: """ Creates a function that maps circuit parameters to a statetensor. Args: gate_seq: Sequence of gates. Each element is either a string matching a unitary array or function in qujax.gates, - a custom unitary array or a custom function taking parameters and returning a unitary array. - Unitary arrays will be reshaped into tensor form (2, 2,...) - qubit_inds_seq: Sequences of sequences representing qubit indices (ints) that gates are acting on. - i.e. [[0], [0,1], [1]] tells qujax the first gate is a single qubit gate acting on the zeroth qubit, - the second gate is a two qubit gate acting on the zeroth and first qubit etc. + a custom unitary array or a custom function taking parameters and returning a + unitary array. Unitary arrays will be reshaped into tensor form (2, 2,...) + qubit_inds_seq: Sequences of sequences representing qubit indices (ints) that gates are + acting on. + i.e. [[0], [0,1], [1]] tells qujax the first gate is a single qubit gate acting on the + zeroth qubit, the second gate is a two qubit gate acting on the zeroth and first qubit + etc. param_inds_seq: Sequence of sequences representing parameter indices that gates are using, i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter - (the float at position zero in the parameter vector/array), the second gate is not parameterised - and the third gate uses the parameters at position five and two. + (the float at position zero in the parameter vector/array), the second gate is not + parameterised and the third gate uses the parameters at position five and two. n_qubits: Number of qubits, if fixed. Returns: @@ -115,8 +131,9 @@ def get_params_to_statetensor_func(gate_seq: Sequence[Gate], gate_seq_callable = [_to_gate_func(g) for g in gate_seq] param_inds_array_seq = _arrayify_inds(param_inds_seq) - def params_to_statetensor_func(params: jnp.ndarray, - statetensor_in: jnp.ndarray = None) -> jnp.ndarray: + def params_to_statetensor_func( + params: jnp.ndarray, statetensor_in: jnp.ndarray = None + ) -> jnp.ndarray: """ Applies parameterised circuit (series of gates) to a statetensor_in (default is |0>^N). @@ -131,24 +148,33 @@ def params_to_statetensor_func(params: jnp.ndarray, """ if statetensor_in is None: statetensor = jnp.zeros((2,) * n_qubits) - statetensor = statetensor.at[(0,) * n_qubits].set(1.) + statetensor = statetensor.at[(0,) * n_qubits].set(1.0) else: statetensor = statetensor_in params = jnp.atleast_1d(params) - for gate_func, qubit_inds, param_inds in zip(gate_seq_callable, qubit_inds_seq, param_inds_array_seq): - gate_unitary = _gate_func_to_unitary(gate_func, qubit_inds, param_inds, params) + for gate_func, qubit_inds, param_inds in zip( + gate_seq_callable, qubit_inds_seq, param_inds_array_seq + ): + gate_unitary = _gate_func_to_unitary( + gate_func, qubit_inds, param_inds, params + ) statetensor = apply_gate(statetensor, gate_unitary, qubit_inds) return statetensor non_parameterised = all([pi.size == 0 for pi in param_inds_array_seq]) if non_parameterised: - def no_params_to_statetensor_func(statetensor_in: jnp.ndarray = None) -> jnp.ndarray: + + def no_params_to_statetensor_func( + statetensor_in: jnp.ndarray = None, + ) -> jnp.ndarray: """ - Applies circuit (series of gates with no parameters) to a statetensor_in (default is |0>^N). + Applies circuit (series of gates with no parameters) to a statetensor_in + (default is |0>^N). Args: statetensor_in: Optional. Input statetensor. - Defaults to |0>^N (tensor of size 2^n with all zeroes except one in [0]*N index). + Defaults to |0>^N (tensor of size 2^n with all zeroes except one in + the [0]*N index). Returns: Updated statetensor. @@ -161,11 +187,12 @@ def no_params_to_statetensor_func(statetensor_in: jnp.ndarray = None) -> jnp.nda return params_to_statetensor_func -def get_params_to_unitarytensor_func(gate_seq: Sequence[Gate], - qubit_inds_seq: Sequence[Sequence[int]], - param_inds_seq: Sequence[Union[None, Sequence[int]]], - n_qubits: int = None)\ - -> Union[Callable[[], jnp.ndarray], Callable[[jnp.ndarray], jnp.ndarray]]: +def get_params_to_unitarytensor_func( + gate_seq: Sequence[Gate], + qubit_inds_seq: Sequence[Sequence[int]], + param_inds_seq: Sequence[Union[None, Sequence[int]]], + n_qubits: int = None, +) -> Union[Callable[[], jnp.ndarray], Callable[[jnp.ndarray], jnp.ndarray]]: """ Creates a function that maps circuit parameters to a unitarytensor. The unitarytensor is an array with shape (2,) * 2 * n_qubits @@ -174,15 +201,17 @@ def get_params_to_unitarytensor_func(gate_seq: Sequence[Gate], Args: gate_seq: Sequence of gates. Each element is either a string matching a unitary array or function in qujax.gates, - a custom unitary array or a custom function taking parameters and returning a unitary array. - Unitary arrays will be reshaped into tensor form (2, 2,...) - qubit_inds_seq: Sequences of sequences representing qubit indices (ints) that gates are acting on. - i.e. [[0], [0,1], [1]] tells qujax the first gate is a single qubit gate acting on the zeroth qubit, - the second gate is a two qubit gate acting on the zeroth and first qubit etc. + a custom unitary array or a custom function taking parameters and returning a unitary + array. Unitary arrays will be reshaped into tensor form (2, 2,...) + qubit_inds_seq: Sequences of sequences representing qubit indices (ints) that gates are + acting on. + i.e. [[0], [0,1], [1]] tells qujax the first gate is a single qubit gate acting on the + zeroth qubit, the second gate is a two qubit gate acting on the zeroth and first qubit + etc. param_inds_seq: Sequence of sequences representing parameter indices that gates are using, i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter - (the float at position zero in the parameter vector/array), the second gate is not parameterised - and the third gate uses the parameters at position five and two. + (the float at position zero in the parameter vector/array), the second gate is not + parameterised and the third gate uses the parameters at position five and two. n_qubits: Number of qubits, if fixed. Returns: @@ -193,7 +222,8 @@ def get_params_to_unitarytensor_func(gate_seq: Sequence[Gate], if n_qubits is None: n_qubits = max([max(qi) for qi in qubit_inds_seq]) + 1 - param_to_st = get_params_to_statetensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) - identity_unitarytensor = jnp.eye(2 ** n_qubits).reshape((2,) * 2 * n_qubits) + param_to_st = get_params_to_statetensor_func( + gate_seq, qubit_inds_seq, param_inds_seq, n_qubits + ) + identity_unitarytensor = jnp.eye(2**n_qubits).reshape((2,) * 2 * n_qubits) return partial(param_to_st, statetensor_in=identity_unitarytensor) - diff --git a/qujax/statetensor_observable.py b/qujax/statetensor_observable.py index ab1a08f..44ada59 100644 --- a/qujax/statetensor_observable.py +++ b/qujax/statetensor_observable.py @@ -1,15 +1,18 @@ from __future__ import annotations -from typing import Sequence, Callable, Union -from jax import numpy as jnp, random + +from typing import Callable, Sequence, Union + +from jax import numpy as jnp +from jax import random from jax.lax import fori_loop from qujax.statetensor import apply_gate -from qujax.utils import check_hermitian, sample_integers, paulis +from qujax.utils import check_hermitian, paulis, sample_integers -def statetensor_to_single_expectation(statetensor: jnp.ndarray, - hermitian: jnp.ndarray, - qubit_inds: Sequence[int]) -> float: +def statetensor_to_single_expectation( + statetensor: jnp.ndarray, hermitian: jnp.ndarray, qubit_inds: Sequence[int] +) -> float: """ Evaluates expectation value of an observable represented by a Hermitian matrix (in tensor form). @@ -25,13 +28,17 @@ def statetensor_to_single_expectation(statetensor: jnp.ndarray, """ statetensor_new = apply_gate(statetensor, hermitian, qubit_inds) axes = tuple(range(statetensor.ndim)) - return jnp.tensordot(statetensor.conjugate(), statetensor_new, axes=(axes, axes)).real + return jnp.tensordot( + statetensor.conjugate(), statetensor_new, axes=(axes, axes) + ).real -def get_hermitian_tensor(hermitian_seq: Sequence[Union[str, jnp.ndarray]]) -> jnp.ndarray: +def get_hermitian_tensor( + hermitian_seq: Sequence[Union[str, jnp.ndarray]] +) -> jnp.ndarray: """ - Convert a sequence of observables represented by Pauli strings or Hermitian matrices in tensor form - into single array (in tensor form). + Convert a sequence of observables represented by Pauli strings or Hermitian matrices + in tensor form into single array (in tensor form). Args: hermitian_seq: Sequence of Hermitian strings or arrays. @@ -43,7 +50,9 @@ def get_hermitian_tensor(hermitian_seq: Sequence[Union[str, jnp.ndarray]]) -> jn check_hermitian(h) single_arrs = [paulis[h] if isinstance(h, str) else h for h in hermitian_seq] - single_arrs = [h_arr.reshape((2,) * int(jnp.log2(h_arr.size))) for h_arr in single_arrs] + single_arrs = [ + h_arr.reshape((2,) * int(jnp.log2(h_arr.size))) for h_arr in single_arrs + ] full_mat = single_arrs[0] for single_matrix in single_arrs[1:]: @@ -52,20 +61,22 @@ def get_hermitian_tensor(hermitian_seq: Sequence[Union[str, jnp.ndarray]]) -> jn return full_mat -def _get_tensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], - qubits_seq_seq: Sequence[Sequence[int]], - coefficients: Union[Sequence[float], jnp.ndarray], - contraction_function: Callable) \ - -> Callable[[jnp.ndarray], float]: +def _get_tensor_to_expectation_func( + hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + qubits_seq_seq: Sequence[Sequence[int]], + coefficients: Union[Sequence[float], jnp.ndarray], + contraction_function: Callable, +) -> Callable[[jnp.ndarray], float]: """ Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and a list of coefficients and returns a function that converts a tensor into an expected value. - The contraction function performs the tensor contraction according to the type of tensor provided - (i.e. whether it is a statetensor or a densitytensor). + The contraction function performs the tensor contraction according to the type of tensor + provided (i.e. whether it is a statetensor or a densitytensor). Args: hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. - Each Hermitian matrix is either represented by a tensor (jnp.ndarray) or by a list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices. + Each Hermitian matrix is either represented by a tensor (jnp.ndarray) or by a + list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices. E.g. [['Z', 'Z'], ['X']] qubits_seq_seq: Sequence of sequences of integer qubit indices. E.g. [[0,1], [2]] @@ -89,20 +100,24 @@ def statetensor_to_expectation_func(statetensor: jnp.ndarray) -> float: Expected value (float). """ out = 0 - for hermitian, qubit_inds, coeff in zip(hermitian_tensors, qubits_seq_seq, coefficients): + for hermitian, qubit_inds, coeff in zip( + hermitian_tensors, qubits_seq_seq, coefficients + ): out += coeff * contraction_function(statetensor, hermitian, qubit_inds) return out return statetensor_to_expectation_func -def get_statetensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], - qubits_seq_seq: Sequence[Sequence[int]], - coefficients: Union[Sequence[float], jnp.ndarray]) \ - -> Callable[[jnp.ndarray], float]: +def get_statetensor_to_expectation_func( + hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + qubits_seq_seq: Sequence[Sequence[int]], + coefficients: Union[Sequence[float], jnp.ndarray], +) -> Callable[[jnp.ndarray], float]: """ Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and - a list of coefficients and returns a function that converts a statetensor into an expected value. + a list of coefficients and returns a function that converts a statetensor into an expected + value. Args: hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. @@ -117,14 +132,19 @@ def get_statetensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[Uni Function that takes statetensor and returns expected value (float). """ - return _get_tensor_to_expectation_func(hermitian_seq_seq, qubits_seq_seq, coefficients, - statetensor_to_single_expectation) + return _get_tensor_to_expectation_func( + hermitian_seq_seq, + qubits_seq_seq, + coefficients, + statetensor_to_single_expectation, + ) -def get_statetensor_to_sampled_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], - qubits_seq_seq: Sequence[Sequence[int]], - coefficients: Union[Sequence[float], jnp.ndarray]) \ - -> Callable[[jnp.ndarray, random.PRNGKeyArray, int], float]: +def get_statetensor_to_sampled_expectation_func( + hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + qubits_seq_seq: Sequence[Sequence[int]], + coefficients: Union[Sequence[float], jnp.ndarray], +) -> Callable[[jnp.ndarray, random.PRNGKeyArray, int], float]: """ Converts strings (or arrays) representing Hermitian matrices, qubit indices and coefficients into a function that converts a statetensor into a sampled expected value. @@ -141,13 +161,13 @@ def get_statetensor_to_sampled_expectation_func(hermitian_seq_seq: Sequence[Sequ Function that takes statetensor, random key and integer number of shots and returns sampled expected value (float). """ - statetensor_to_expectation_func = get_statetensor_to_expectation_func(hermitian_seq_seq, - qubits_seq_seq, - coefficients) + statetensor_to_expectation_func = get_statetensor_to_expectation_func( + hermitian_seq_seq, qubits_seq_seq, coefficients + ) - def statetensor_to_sampled_expectation_func(statetensor: jnp.ndarray, - random_key: random.PRNGKeyArray, - n_samps: int) -> float: + def statetensor_to_sampled_expectation_func( + statetensor: jnp.ndarray, random_key: random.PRNGKeyArray, n_samps: int + ) -> float: """ Maps statetensor to sampled expected value. @@ -160,9 +180,12 @@ def statetensor_to_sampled_expectation_func(statetensor: jnp.ndarray, Sampled expected value (float). """ sampled_integers = sample_integers(random_key, statetensor, n_samps) - sampled_probs = fori_loop(0, n_samps, - lambda i, sv: sv.at[sampled_integers[i]].add(1), - jnp.zeros(statetensor.size)) + sampled_probs = fori_loop( + 0, + n_samps, + lambda i, sv: sv.at[sampled_integers[i]].add(1), + jnp.zeros(statetensor.size), + ) sampled_probs /= n_samps sampled_st = jnp.sqrt(sampled_probs).reshape(statetensor.shape) diff --git a/qujax/utils.py b/qujax/utils.py index 1f10df8..36b34c4 100644 --- a/qujax/utils.py +++ b/qujax/utils.py @@ -1,16 +1,21 @@ from __future__ import annotations -from typing import Sequence, Union, Callable, List, Tuple, Optional, Protocol, Iterable + import collections.abc from inspect import signature -from jax import numpy as jnp, random +from typing import Callable, Iterable, List, Optional, Protocol, Sequence, Tuple, Union + +from jax import numpy as jnp +from jax import random from qujax import gates -paulis = {'X': gates.X, 'Y': gates.Y, 'Z': gates.Z} +paulis = {"X": gates.X, "Y": gates.Y, "Z": gates.Z} class CallableArrayAndOptionalArray(Protocol): - def __call__(self, params: jnp.ndarray, statetensor_in: jnp.ndarray = None) -> jnp.ndarray: + def __call__( + self, params: jnp.ndarray, statetensor_in: jnp.ndarray = None + ) -> jnp.ndarray: ... @@ -39,23 +44,29 @@ def check_unitary(gate: Gate): if gate in gates.__dict__: gate = gates.__dict__[gate] else: - raise KeyError(f'Gate string \'{gate}\' not found in qujax.gates ' - f'- consider changing input to an array or callable') + raise KeyError( + f"Gate string '{gate}' not found in qujax.gates " + f"- consider changing input to an array or callable" + ) if callable(gate): num_args = len(signature(gate).parameters) gate_arr = gate(*jnp.ones(num_args) * 0.1) - elif hasattr(gate, '__array__'): + elif hasattr(gate, "__array__"): gate_arr = gate else: - raise TypeError(f'Unsupported gate type - gate must be either a string in qujax.gates, an array or ' - f'callable: {gate}') + raise TypeError( + f"Unsupported gate type - gate must be either a string in qujax.gates, an array or " + f"callable: {gate}" + ) gate_square_dim = int(jnp.sqrt(gate_arr.size)) gate_arr = gate_arr.reshape(gate_square_dim, gate_square_dim) - if jnp.any(jnp.abs(gate_arr @ jnp.conjugate(gate_arr).T - jnp.eye(gate_square_dim)) > 1e-3): - raise TypeError(f'Gate not unitary: {gate}') + if jnp.any( + jnp.abs(gate_arr @ jnp.conjugate(gate_arr).T - jnp.eye(gate_square_dim)) > 1e-3 + ): + raise TypeError(f"Gate not unitary: {gate}") def check_hermitian(hermitian: Union[str, jnp.ndarray]): @@ -68,23 +79,28 @@ def check_hermitian(hermitian: Union[str, jnp.ndarray]): """ if isinstance(hermitian, str): if hermitian not in paulis: - raise TypeError(f'qujax only accepts {tuple(paulis.keys())} as Hermitian strings, received: {hermitian}') + raise TypeError( + f"qujax only accepts {tuple(paulis.keys())} as Hermitian strings," + "received: {hermitian}" + ) else: n_qubits = hermitian.ndim // 2 hermitian_mat = hermitian.reshape(2 * n_qubits, 2 * n_qubits) if not jnp.allclose(hermitian_mat, hermitian_mat.T.conj()): - raise TypeError(f'Array not Hermitian: {hermitian}') + raise TypeError(f"Array not Hermitian: {hermitian}") -def _arrayify_inds(param_inds_seq: Sequence[Union[None, Sequence[int]]]) -> Sequence[jnp.ndarray]: +def _arrayify_inds( + param_inds_seq: Sequence[Union[None, Sequence[int]]] +) -> Sequence[jnp.ndarray]: """ Ensure each element of param_inds_seq is an array (and therefore valid for jnp.take) Args: param_inds_seq: Sequence of sequences representing parameter indices that gates are using, i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter - (the float at position zero in the parameter vector/array), the second gate is not parameterised - and the third gates used the parameters at position five and two. + (the float at position zero in the parameter vector/array), the second gate is not + parameterised and the third gates used the parameters at position five and two. Returns: Sequence of arrays representing parameter indices. @@ -92,15 +108,20 @@ def _arrayify_inds(param_inds_seq: Sequence[Union[None, Sequence[int]]]) -> Sequ if param_inds_seq is None: param_inds_seq = [None] param_inds_seq = [jnp.array(p) for p in param_inds_seq] - param_inds_seq = [jnp.array([]) if jnp.any(jnp.isnan(p)) else p.astype(int) for p in param_inds_seq] + param_inds_seq = [ + jnp.array([]) if jnp.any(jnp.isnan(p)) else p.astype(int) + for p in param_inds_seq + ] return param_inds_seq -def check_circuit(gate_seq: Sequence[KrausOp], - qubit_inds_seq: Sequence[Sequence[int]], - param_inds_seq: Sequence[Sequence[int]], - n_qubits: int = None, - check_unitaries: bool = True): +def check_circuit( + gate_seq: Sequence[KrausOp], + qubit_inds_seq: Sequence[Sequence[int]], + param_inds_seq: Sequence[Sequence[int]], + n_qubits: int = None, + check_unitaries: bool = True, +): """ Basic checks that circuit arguments conform. @@ -113,7 +134,8 @@ def check_circuit(gate_seq: Sequence[KrausOp], qubit_inds_seq: Sequences of qubits (ints) that gates are acting on. param_inds_seq: Sequence of parameter indices that gates are using, i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the first parameter, - the second gate is not parameterised and the third gates used the fifth and second parameters. + the second gate is not parameterised and the third gates used the fifth and second + parameters. n_qubits: Number of qubits, if fixed. check_unitaries: boolean on whether to check if each gate represents a unitary matrix @@ -121,63 +143,93 @@ def check_circuit(gate_seq: Sequence[KrausOp], if not isinstance(gate_seq, collections.abc.Sequence): raise TypeError("gate_seq must be Sequence e.g. ['H', 'Rx', 'CX']") - if (not isinstance(qubit_inds_seq, collections.abc.Sequence)) or \ - (any([not (isinstance(q, collections.abc.Sequence) or hasattr(q, '__array__')) for q in qubit_inds_seq])): - raise TypeError('qubit_inds_seq must be Sequence of Sequences e.g. [[0,1], [0], []]') - - if (not isinstance(param_inds_seq, collections.abc.Sequence)) or \ - (any([not (isinstance(p, collections.abc.Sequence) or hasattr(p, '__array__') or p is None) - for p in param_inds_seq])): - raise TypeError('param_inds_seq must be Sequence of Sequences e.g. [[0,1], [0], []]') - - if len(gate_seq) != len(qubit_inds_seq) or len(param_inds_seq) != len(param_inds_seq): - raise TypeError(f'gate_seq ({len(gate_seq)}), qubit_inds_seq ({len(qubit_inds_seq)})' - f'and param_inds_seq ({len(param_inds_seq)}) must have matching lengths') + if (not isinstance(qubit_inds_seq, collections.abc.Sequence)) or ( + any( + [ + not (isinstance(q, collections.abc.Sequence) or hasattr(q, "__array__")) + for q in qubit_inds_seq + ] + ) + ): + raise TypeError( + "qubit_inds_seq must be Sequence of Sequences e.g. [[0,1], [0], []]" + ) + + if (not isinstance(param_inds_seq, collections.abc.Sequence)) or ( + any( + [ + not ( + isinstance(p, collections.abc.Sequence) + or hasattr(p, "__array__") + or p is None + ) + for p in param_inds_seq + ] + ) + ): + raise TypeError( + "param_inds_seq must be Sequence of Sequences e.g. [[0,1], [0], []]" + ) + + if len(gate_seq) != len(qubit_inds_seq) or len(param_inds_seq) != len( + param_inds_seq + ): + raise TypeError( + f"gate_seq ({len(gate_seq)}), qubit_inds_seq ({len(qubit_inds_seq)})" + f"and param_inds_seq ({len(param_inds_seq)}) must have matching lengths" + ) if n_qubits is not None and n_qubits < max([max(qi) for qi in qubit_inds_seq]) + 1: - raise TypeError('n_qubits must be larger than largest qubit index in qubit_inds_seq') + raise TypeError( + "n_qubits must be larger than largest qubit index in qubit_inds_seq" + ) if check_unitaries: for g in gate_seq: check_unitary(g) -def _get_gate_str(gate_obj: KrausOp, - param_inds: Union[None, Sequence[int], Sequence[Sequence[int]]]) -> str: +def _get_gate_str( + gate_obj: KrausOp, param_inds: Union[None, Sequence[int], Sequence[Sequence[int]]] +) -> str: """ Maps single gate object to a four character string representation Args: gate_obj: Either a string matching a function in qujax.gates, a unitary array (which will be reshaped into a tensor of shape e.g. (2,2,2,...) ) - or a function taking parameters (can be empty) and returning gate unitary in tensor form. - Or alternatively, a sequence of Krause operators represented by strings, arrays or functions. + or a function taking parameters (can be empty) and returning gate unitary + in tensor form. + Or alternatively, a sequence of Krause operators represented by strings, arrays or + functions. param_inds: Parameter indices that gates are using, i.e. gate uses 1st and 5th parameter. Returns: Four character string representation of the gate """ - if isinstance(gate_obj, (tuple, list)) or (hasattr(gate_obj, '__array__') and gate_obj.ndim % 2 == 1): + if isinstance(gate_obj, (tuple, list)) or ( + hasattr(gate_obj, "__array__") and gate_obj.ndim % 2 == 1 + ): # Kraus operators - gate_obj = 'Kr' + gate_obj = "Kr" param_inds = jnp.unique(jnp.concatenate(_arrayify_inds(param_inds), axis=0)) if isinstance(gate_obj, str): gate_str = gate_obj - elif hasattr(gate_obj, '__array__'): - gate_str = 'Arr' + elif hasattr(gate_obj, "__array__"): + gate_str = "Arr" elif callable(gate_obj): - gate_str = 'Func' + gate_str = "Func" else: - if hasattr(gate_obj, '__name__'): + if hasattr(gate_obj, "__name__"): gate_str = gate_obj.__name__ - elif hasattr(gate_obj, '__class__') and hasattr(gate_obj.__class__, '__name__'): + elif hasattr(gate_obj, "__class__") and hasattr(gate_obj.__class__, "__name__"): gate_str = gate_obj.__class__.__name__ else: - gate_str = 'Other' + gate_str = "Other" - if hasattr(param_inds, 'tolist'): + if hasattr(param_inds, "tolist"): param_inds = param_inds.tolist() if isinstance(param_inds, tuple): @@ -185,26 +237,27 @@ def _get_gate_str(gate_obj: KrausOp, if param_inds == [] or param_inds == [None] or param_inds is None: if len(gate_str) > 7: - gate_str = gate_str[:6] + '.' + gate_str = gate_str[:6] + "." else: - param_str = str(param_inds).replace(' ', '') + param_str = str(param_inds).replace(" ", "") if len(param_str) > 5: - param_str = '[.]' + param_str = "[.]" if (len(gate_str) + len(param_str)) > 7: - gate_str = gate_str[:1] + '.' + gate_str = gate_str[:1] + "." gate_str += param_str - gate_str = gate_str.center(7, '-') + gate_str = gate_str.center(7, "-") return gate_str def _pad_rows(rows: List[str]) -> Tuple[List[str], List[bool]]: """ - Pad string representation of circuit to be rectangular. Fills qubit rows with '-' and between-qubit rows with ' '. + Pad string representation of circuit to be rectangular. + Fills qubit rows with '-' and between-qubit rows with ' '. Args: rows: String representation of circuit @@ -220,24 +273,26 @@ def extend_row(row: str, qubit_row: bool) -> str: lr = len(row) if lr < max_len: if qubit_row: - row += '-' * (max_len - lr) + row += "-" * (max_len - lr) else: - row += ' ' * (max_len - lr) + row += " " * (max_len - lr) return row out_rows = [extend_row(r, i % 2 == 0) for i, r in enumerate(rows)] return out_rows, [True] * len(rows) -def print_circuit(gate_seq: Sequence[KrausOp], - qubit_inds_seq: Sequence[Sequence[int]], - param_inds_seq: Sequence[Sequence[int]], - n_qubits: Optional[int] = None, - qubit_min: Optional[int] = 0, - qubit_max: Optional[int] = jnp.inf, - gate_ind_min: Optional[int] = 0, - gate_ind_max: Optional[int] = jnp.inf, - sep_length: Optional[int] = 1) -> List[str]: +def print_circuit( + gate_seq: Sequence[KrausOp], + qubit_inds_seq: Sequence[Sequence[int]], + param_inds_seq: Sequence[Sequence[int]], + n_qubits: Optional[int] = None, + qubit_min: Optional[int] = 0, + qubit_max: Optional[int] = jnp.inf, + gate_ind_min: Optional[int] = 0, + gate_ind_max: Optional[int] = jnp.inf, + sep_length: Optional[int] = 1, +) -> List[str]: """ Returns and prints basic string representation of circuit. @@ -250,7 +305,8 @@ def print_circuit(gate_seq: Sequence[KrausOp], qubit_inds_seq: Sequences of qubits (ints) that gates are acting on. param_inds_seq: Sequence of parameter indices that gates are using, i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the first parameter, - the second gate is not parameterised and the third gates used the fifth and second parameters. + the second gate is not parameterised and the third gates used the fifth and + second parameters. n_qubits: Number of qubits, if fixed. qubit_min: Index of first qubit to display. qubit_max: Index of final qubit to display. @@ -266,23 +322,23 @@ def print_circuit(gate_seq: Sequence[KrausOp], gate_ind_max = min(len(gate_seq) - 1, gate_ind_max) if gate_ind_min > gate_ind_max: - raise TypeError('gate_ind_max must be larger or equal to gate_ind_min') + raise TypeError("gate_ind_max must be larger or equal to gate_ind_min") if n_qubits is None: n_qubits = max([max(qi) for qi in qubit_inds_seq]) + 1 qubit_max = min(n_qubits - 1, qubit_max) if qubit_min > qubit_max: - raise TypeError('qubit_max must be larger or equal to qubit_min') + raise TypeError("qubit_max must be larger or equal to qubit_min") gate_str_seq = [_get_gate_str(g, p) for g, p in zip(gate_seq, param_inds_seq)] n_qubits_disp = qubit_max - qubit_min + 1 - rows = [f'q{qubit_min}: '.ljust(3) + '-' * sep_length] + rows = [f"q{qubit_min}: ".ljust(3) + "-" * sep_length] if n_qubits_disp > 1: for i in range(qubit_min + 1, qubit_max + 1): - rows += [' ', f'q{i}: '.ljust(3) + '-' * sep_length] + rows += [" ", f"q{i}: ".ljust(3) + "-" * sep_length] rows, rows_free = _pad_rows(rows) for gate_ind in range(gate_ind_min, gate_ind_max + 1): @@ -291,21 +347,21 @@ def print_circuit(gate_seq: Sequence[KrausOp], qi_min = min(qi) qi_max = max(qi) - ri_min = 2 * qi_min # index of top row used by gate - ri_max = 2 * qi_max # index of bottom row used by gate + ri_min = 2 * qi_min # index of top row used by gate + ri_max = 2 * qi_max # index of bottom row used by gate if not all([rows_free[i] for i in range(ri_min, ri_max)]): rows, rows_free = _pad_rows(rows) for row_ind in range(ri_min, ri_max + 1): if row_ind == 2 * qi[-1]: - rows[row_ind] += '-' * sep_length + g + rows[row_ind] += "-" * sep_length + g elif row_ind % 2 == 1: - rows[row_ind] += ' ' * sep_length + ' ' + '|' + ' ' + rows[row_ind] += " " * sep_length + " " + "|" + " " elif row_ind / 2 in qi: - rows[row_ind] += '-' * sep_length + '---' + '◯' + '---' + rows[row_ind] += "-" * sep_length + "---" + "◯" + "---" else: - rows[row_ind] += '-' * sep_length + '---' + '|' + '---' + rows[row_ind] += "-" * sep_length + "---" + "|" + "---" rows_free[row_ind] = False @@ -317,8 +373,9 @@ def print_circuit(gate_seq: Sequence[KrausOp], return rows -def integers_to_bitstrings(integers: Union[int, jnp.ndarray], - nbits: int = None) -> jnp.ndarray: +def integers_to_bitstrings( + integers: Union[int, jnp.ndarray], nbits: int = None +) -> jnp.ndarray: """ Convert integer or array of integers into their binary expansion(s). @@ -334,7 +391,9 @@ def integers_to_bitstrings(integers: Union[int, jnp.ndarray], if nbits is None: nbits = (jnp.ceil(jnp.log2(jnp.maximum(integers.max(), 1)) + 1e-5)).astype(int) - return jnp.squeeze(((integers[:, None] & (1 << jnp.arange(nbits - 1, -1, -1))) > 0).astype(int)) + return jnp.squeeze( + ((integers[:, None] & (1 << jnp.arange(nbits - 1, -1, -1))) > 0).astype(int) + ) def bitstrings_to_integers(bitstrings: jnp.ndarray) -> Union[int, jnp.ndarray]: @@ -352,9 +411,11 @@ def bitstrings_to_integers(bitstrings: jnp.ndarray) -> Union[int, jnp.ndarray]: return jnp.squeeze(bitstrings.dot(convarr)).astype(int) -def sample_integers(random_key: random.PRNGKeyArray, - statetensor: jnp.ndarray, - n_samps: Optional[int] = 1) -> jnp.ndarray: +def sample_integers( + random_key: random.PRNGKeyArray, + statetensor: jnp.ndarray, + n_samps: Optional[int] = 1, +) -> jnp.ndarray: """ Generate random integer samples according to statetensor. @@ -368,13 +429,17 @@ def sample_integers(random_key: random.PRNGKeyArray, """ sv_probs = jnp.square(jnp.abs(statetensor.flatten())) - sampled_inds = random.choice(random_key, a=jnp.arange(statetensor.size), shape=(n_samps,), p=sv_probs) + sampled_inds = random.choice( + random_key, a=jnp.arange(statetensor.size), shape=(n_samps,), p=sv_probs + ) return sampled_inds -def sample_bitstrings(random_key: random.PRNGKeyArray, - statetensor: jnp.ndarray, - n_samps: Optional[int] = 1) -> jnp.ndarray: +def sample_bitstrings( + random_key: random.PRNGKeyArray, + statetensor: jnp.ndarray, + n_samps: Optional[int] = 1, +) -> jnp.ndarray: """ Generate random bitstring samples according to statetensor. @@ -387,7 +452,9 @@ def sample_bitstrings(random_key: random.PRNGKeyArray, Array with sampled bitstrings, shape=(n_samps, statetensor.ndim). """ - return integers_to_bitstrings(sample_integers(random_key, statetensor, n_samps), statetensor.ndim) + return integers_to_bitstrings( + sample_integers(random_key, statetensor, n_samps), statetensor.ndim + ) def statetensor_to_densitytensor(statetensor: jnp.ndarray) -> jnp.ndarray: @@ -403,5 +470,7 @@ def statetensor_to_densitytensor(statetensor: jnp.ndarray) -> jnp.ndarray: """ n_qubits = statetensor.ndim st = statetensor - dt = (st.reshape(-1, 1) @ st.reshape(1, -1).conj()).reshape(2 for _ in range(2 * n_qubits)) + dt = (st.reshape(-1, 1) @ st.reshape(1, -1).conj()).reshape( + 2 for _ in range(2 * n_qubits) + ) return dt diff --git a/qujax/version.py b/qujax/version.py index 73e3bb4..e19434e 100644 --- a/qujax/version.py +++ b/qujax/version.py @@ -1 +1 @@ -__version__ = '0.3.2' +__version__ = "0.3.3" diff --git a/setup.py b/setup.py index dc1c0c9..851571d 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup -exec(open('qujax/version.py').read()) +exec(open("qujax/version.py").read()) setup( name="qujax", @@ -22,5 +22,5 @@ ], include_package_data=True, platforms="any", - version=__version__ + version=__version__, ) diff --git a/tests/test_densitytensor.py b/tests/test_densitytensor.py index 8a79345..46ead45 100644 --- a/tests/test_densitytensor.py +++ b/tests/test_densitytensor.py @@ -1,20 +1,24 @@ from itertools import combinations -from jax import numpy as jnp, jit + +from jax import jit +from jax import numpy as jnp import qujax def test_kraus_single(): n_qubits = 3 - dim = 2 ** n_qubits - density_matrix = jnp.arange(dim ** 2).reshape(dim, dim) + dim = 2**n_qubits + density_matrix = jnp.arange(dim**2).reshape(dim, dim) density_tensor = density_matrix.reshape((2,) * 2 * n_qubits) kraus_operator = qujax.gates.Rx(0.2) qubit_inds = (1,) unitary_matrix = jnp.kron(jnp.eye(2 * qubit_inds[0]), kraus_operator) - unitary_matrix = jnp.kron(unitary_matrix, jnp.eye(2 * (n_qubits - qubit_inds[-1] - 1))) + unitary_matrix = jnp.kron( + unitary_matrix, jnp.eye(2 * (n_qubits - qubit_inds[-1] - 1)) + ) check_kraus_dm = unitary_matrix @ density_matrix @ unitary_matrix.conj().T # qujax._kraus_single @@ -23,7 +27,9 @@ def test_kraus_single(): assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) - qujax_kraus_dt_jit = jit(qujax._kraus_single, static_argnums=(2,))(density_tensor, kraus_operator, qubit_inds) + qujax_kraus_dt_jit = jit(qujax._kraus_single, static_argnums=(2,))( + density_tensor, kraus_operator, qubit_inds + ) qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) @@ -32,15 +38,17 @@ def test_kraus_single(): qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) - qujax_kraus_dt_jit = jit(qujax.kraus, static_argnums=(2,))(density_tensor, kraus_operator, qubit_inds) + qujax_kraus_dt_jit = jit(qujax.kraus, static_argnums=(2,))( + density_tensor, kraus_operator, qubit_inds + ) qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) def test_kraus_single_2qubit(): n_qubits = 4 - dim = 2 ** n_qubits - density_matrix = jnp.arange(dim ** 2).reshape(dim, dim) + dim = 2**n_qubits + density_matrix = jnp.arange(dim**2).reshape(dim, dim) density_tensor = density_matrix.reshape((2,) * 2 * n_qubits) kraus_operator_tensor = qujax.gates.ZZPhase(0.1) kraus_operator = qujax.gates.ZZPhase(0.1).reshape(4, 4) @@ -48,18 +56,22 @@ def test_kraus_single_2qubit(): qubit_inds = (1, 2) unitary_matrix = jnp.kron(jnp.eye(2 * qubit_inds[0]), kraus_operator) - unitary_matrix = jnp.kron(unitary_matrix, jnp.eye(2 * (n_qubits - qubit_inds[-1] - 1))) + unitary_matrix = jnp.kron( + unitary_matrix, jnp.eye(2 * (n_qubits - qubit_inds[-1] - 1)) + ) check_kraus_dm = unitary_matrix @ density_matrix @ unitary_matrix.conj().T # qujax._kraus_single - qujax_kraus_dt = qujax._kraus_single(density_tensor, kraus_operator_tensor, qubit_inds) + qujax_kraus_dt = qujax._kraus_single( + density_tensor, kraus_operator_tensor, qubit_inds + ) qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) - qujax_kraus_dt_jit = jit(qujax._kraus_single, static_argnums=(2,))(density_tensor, - kraus_operator_tensor, - qubit_inds) + qujax_kraus_dt_jit = jit(qujax._kraus_single, static_argnums=(2,))( + density_tensor, kraus_operator_tensor, qubit_inds + ) qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) @@ -68,27 +80,40 @@ def test_kraus_single_2qubit(): qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) - qujax_kraus_dt = qujax.kraus(density_tensor, kraus_operator, qubit_inds) # check reshape kraus_operator correctly + qujax_kraus_dt = qujax.kraus( + density_tensor, kraus_operator, qubit_inds + ) # check reshape kraus_operator correctly qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) - qujax_kraus_dt_jit = jit(qujax.kraus, static_argnums=(2,))(density_tensor, kraus_operator_tensor, qubit_inds) + qujax_kraus_dt_jit = jit(qujax.kraus, static_argnums=(2,))( + density_tensor, kraus_operator_tensor, qubit_inds + ) qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) def test_kraus_multiple(): n_qubits = 3 - dim = 2 ** n_qubits - density_matrix = jnp.arange(dim ** 2).reshape(dim, dim) + dim = 2**n_qubits + density_matrix = jnp.arange(dim**2).reshape(dim, dim) density_tensor = density_matrix.reshape((2,) * 2 * n_qubits) - kraus_operators = [0.25 * qujax.gates.H, 0.25 * qujax.gates.Rx(0.3), 0.5 * qujax.gates.Ry(0.1)] + kraus_operators = [ + 0.25 * qujax.gates.H, + 0.25 * qujax.gates.Rx(0.3), + 0.5 * qujax.gates.Ry(0.1), + ] qubit_inds = (1,) - unitary_matrices = [jnp.kron(jnp.eye(2 * qubit_inds[0]), ko) for ko in kraus_operators] - unitary_matrices = [jnp.kron(um, jnp.eye(2 * (n_qubits - qubit_inds[0] - 1))) for um in unitary_matrices] + unitary_matrices = [ + jnp.kron(jnp.eye(2 * qubit_inds[0]), ko) for ko in kraus_operators + ] + unitary_matrices = [ + jnp.kron(um, jnp.eye(2 * (n_qubits - qubit_inds[0] - 1))) + for um in unitary_matrices + ] check_kraus_dm = jnp.zeros_like(density_matrix) for um in unitary_matrices: @@ -99,7 +124,9 @@ def test_kraus_multiple(): assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) - qujax_kraus_dt_jit = jit(qujax.kraus, static_argnums=(2,))(density_tensor, kraus_operators, qubit_inds) + qujax_kraus_dt_jit = jit(qujax.kraus, static_argnums=(2,))( + density_tensor, kraus_operators, qubit_inds + ) qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) @@ -115,10 +142,14 @@ def test_params_to_densitytensor_func(): qubit_inds_seq += [(i, i + 1) for i in range(n_qubits - 1)] param_inds_seq += [() for _ in range(n_qubits - 1)] - params_to_dt = qujax.get_params_to_densitytensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) - params_to_st = qujax.get_params_to_statetensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + params_to_dt = qujax.get_params_to_densitytensor_func( + gate_seq, qubit_inds_seq, param_inds_seq, n_qubits + ) + params_to_st = qujax.get_params_to_statetensor_func( + gate_seq, qubit_inds_seq, param_inds_seq, n_qubits + ) - params = jnp.arange(n_qubits) / 10. + params = jnp.arange(n_qubits) / 10.0 st = params_to_st(params) dt_test = qujax.statetensor_to_densitytensor(st) @@ -142,7 +173,9 @@ def test_params_to_densitytensor_func_with_bit_flip(): qubit_inds_seq += [(i, i + 1) for i in range(n_qubits - 1)] param_inds_seq += [() for _ in range(n_qubits - 1)] - params_to_pre_bf_st = qujax.get_params_to_statetensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + params_to_pre_bf_st = qujax.get_params_to_statetensor_func( + gate_seq, qubit_inds_seq, param_inds_seq, n_qubits + ) kraus_ops = [[0.3 * jnp.eye(2), 0.7 * qujax.gates.X]] kraus_qubit_inds = [(0,)] @@ -154,12 +187,16 @@ def test_params_to_densitytensor_func_with_bit_flip(): _ = qujax.print_circuit(gate_seq, qubit_inds_seq, param_inds_seq) - params_to_dt = qujax.get_params_to_densitytensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + params_to_dt = qujax.get_params_to_densitytensor_func( + gate_seq, qubit_inds_seq, param_inds_seq, n_qubits + ) - params = jnp.arange(n_qubits) / 10. + params = jnp.arange(n_qubits) / 10.0 pre_bf_st = params_to_pre_bf_st(params) - pre_bf_dt = (pre_bf_st.reshape(-1, 1) @ pre_bf_st.reshape(1, -1).conj()).reshape(2 for _ in range(2 * n_qubits)) + pre_bf_dt = (pre_bf_st.reshape(-1, 1) @ pre_bf_st.reshape(1, -1).conj()).reshape( + 2 for _ in range(2 * n_qubits) + ) dt_test = qujax.kraus(pre_bf_dt, kraus_ops[0], kraus_qubit_inds[0]) dt = params_to_dt(params) @@ -171,7 +208,7 @@ def test_params_to_densitytensor_func_with_bit_flip(): def test_partial_trace_1(): - state1 = 1 / jnp.sqrt(2) * jnp.array([1., 1.]) + state1 = 1 / jnp.sqrt(2) * jnp.array([1.0, 1.0]) state2 = jnp.kron(state1, state1) state3 = jnp.kron(state1, state2) @@ -197,9 +234,11 @@ def test_partial_trace_2(): qubit_inds_seq += [(i, i + 1) for i in range(n_qubits - 1)] param_inds_seq += [() for _ in range(n_qubits - 1)] - params_to_dt = qujax.get_params_to_densitytensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + params_to_dt = qujax.get_params_to_densitytensor_func( + gate_seq, qubit_inds_seq, param_inds_seq, n_qubits + ) - params = jnp.arange(1, n_qubits + 1) / 10. + params = jnp.arange(1, n_qubits + 1) / 10.0 dt = params_to_dt(params) dt_discard_test = jnp.trace(dt, axis1=0, axis2=n_qubits) @@ -219,25 +258,28 @@ def test_measure(): qubit_inds_seq += [(i, i + 1) for i in range(n_qubits - 1)] param_inds_seq += [() for _ in range(n_qubits - 1)] - params_to_dt = qujax.get_params_to_densitytensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + params_to_dt = qujax.get_params_to_densitytensor_func( + gate_seq, qubit_inds_seq, param_inds_seq, n_qubits + ) - params = jnp.arange(1, n_qubits + 1) / 10. + params = jnp.arange(1, n_qubits + 1) / 10.0 dt = params_to_dt(params) qubit_inds = [0] - all_probs = jnp.diag(dt.reshape(2 ** n_qubits, 2 ** n_qubits)).real - all_probs_marginalise \ - = all_probs.reshape((2,) * n_qubits).sum(axis=[i for i in range(n_qubits) if i not in qubit_inds]) + all_probs = jnp.diag(dt.reshape(2**n_qubits, 2**n_qubits)).real + all_probs_marginalise = all_probs.reshape((2,) * n_qubits).sum( + axis=[i for i in range(n_qubits) if i not in qubit_inds] + ) probs = qujax.densitytensor_to_measurement_probabilities(dt, qubit_inds) - assert jnp.isclose(probs.sum(), 1.) - assert jnp.isclose(all_probs.sum(), 1.) + assert jnp.isclose(probs.sum(), 1.0) + assert jnp.isclose(all_probs.sum(), 1.0) assert jnp.allclose(probs, all_probs_marginalise) - dm = dt.reshape(2 ** n_qubits, 2 ** n_qubits) + dm = dt.reshape(2**n_qubits, 2**n_qubits) projector = jnp.array([[1, 0], [0, 0]]) for _ in range(n_qubits - 1): projector = jnp.kron(projector, jnp.eye(2)) @@ -246,6 +288,8 @@ def test_measure(): measured_dt_true = measured_dm.reshape((2,) * 2 * n_qubits) measured_dt = qujax.densitytensor_to_measured_densitytensor(dt, qubit_inds, 0) - measured_dt_bits = qujax.densitytensor_to_measured_densitytensor(dt, qubit_inds, (0,)*n_qubits) + measured_dt_bits = qujax.densitytensor_to_measured_densitytensor( + dt, qubit_inds, (0,) * n_qubits + ) assert jnp.allclose(measured_dt_true, measured_dt) assert jnp.allclose(measured_dt_true, measured_dt_bits) diff --git a/tests/test_expectations.py b/tests/test_expectations.py index 0f168a5..3581219 100644 --- a/tests/test_expectations.py +++ b/tests/test_expectations.py @@ -1,10 +1,12 @@ -from jax import numpy as jnp, jit, grad, random, config +from jax import config, grad, jit +from jax import numpy as jnp +from jax import random import qujax def test_pauli_hermitian(): - for p_str in ('X', 'Y', 'Z'): + for p_str in ("X", "Y", "Z"): qujax.check_hermitian(p_str) qujax.check_hermitian(qujax.gates.__dict__[p_str]) @@ -14,8 +16,8 @@ def test_single_expectation(): st1 = jnp.zeros((2, 2, 2)) st2 = jnp.zeros((2, 2, 2)) - st1 = st1.at[(0, 0, 0)].set(1.) - st2 = st2.at[(1, 0, 0)].set(1.) + st1 = st1.at[(0, 0, 0)].set(1.0) + st2 = st2.at[(1, 0, 0)].set(1.0) dt1 = qujax.statetensor_to_densitytensor(st1) dt2 = qujax.statetensor_to_densitytensor(st2) ZZ = jnp.kron(Z, Z).reshape(2, 2, 2, 2) @@ -32,24 +34,32 @@ def test_single_expectation(): def test_bitstring_expectation(): n_qubits = 4 - gates = ['H'] * n_qubits \ - + ['Ry'] * n_qubits + ['Rz'] * n_qubits \ - + ['CX'] * (n_qubits - 1) \ - + ['Ry'] * n_qubits + ['Rz'] * n_qubits - qubits = [[i] for i in range(n_qubits)] * 3 \ - + [[i, i + 1] for i in range(n_qubits - 1)] \ - + [[i] for i in range(n_qubits)] * 2 - param_inds = [[]] * n_qubits \ - + [[i] for i in range(n_qubits * 2)] \ - + [[]] * (n_qubits - 1) \ - + [[i] for i in range(n_qubits * 2, n_qubits * 4)] + gates = ( + ["H"] * n_qubits + + ["Ry"] * n_qubits + + ["Rz"] * n_qubits + + ["CX"] * (n_qubits - 1) + + ["Ry"] * n_qubits + + ["Rz"] * n_qubits + ) + qubits = ( + [[i] for i in range(n_qubits)] * 3 + + [[i, i + 1] for i in range(n_qubits - 1)] + + [[i] for i in range(n_qubits)] * 2 + ) + param_inds = ( + [[]] * n_qubits + + [[i] for i in range(n_qubits * 2)] + + [[]] * (n_qubits - 1) + + [[i] for i in range(n_qubits * 2, n_qubits * 4)] + ) param_to_st = qujax.get_params_to_statetensor_func(gates, qubits, param_inds) n_params = n_qubits * 4 params = random.uniform(random.PRNGKey(0), shape=(n_params,)) - costs = random.normal(random.PRNGKey(1), shape=(2 ** n_qubits,)) + costs = random.normal(random.PRNGKey(1), shape=(2**n_qubits,)) def st_to_expectation(statetensor): probs = jnp.square(jnp.abs(statetensor.flatten())) @@ -67,7 +77,7 @@ def brute_force_param_to_exp(p): expectation_jit = jit(param_to_expectation)(params) assert expectation.shape == () - assert expectation.dtype.name[:5] == 'float' + assert expectation.dtype.name[:5] == "float" assert jnp.isclose(true_expectation, expectation) assert jnp.isclose(true_expectation, expectation_jit) @@ -76,7 +86,7 @@ def brute_force_param_to_exp(p): expectation_grad_jit = jit(grad(param_to_expectation))(params) assert expectation_grad.shape == (n_params,) - assert expectation_grad.dtype.name[:5] == 'float' + assert expectation_grad.dtype.name[:5] == "float" assert jnp.allclose(true_expectation_grad, expectation_grad, atol=1e-5) assert jnp.allclose(true_expectation_grad, expectation_grad_jit, atol=1e-5) @@ -86,18 +96,20 @@ def test_ZZ_Y(): n_qubits = 4 - hermitian_str_seq_seq = [['Z', 'Z']] * (n_qubits - 1) + [['Y']] * n_qubits + hermitian_str_seq_seq = [["Z", "Z"]] * (n_qubits - 1) + [["Y"]] * n_qubits coefs = random.normal(random.PRNGKey(0), shape=(len(hermitian_str_seq_seq),)) - qubit_inds_seq = [[i, i + 1] for i in range(n_qubits - 1)] + [[i] for i in range(n_qubits)] - st_to_exp = qujax.get_statetensor_to_expectation_func(hermitian_str_seq_seq, - qubit_inds_seq, - coefs) - dt_to_exp = qujax.get_statetensor_to_expectation_func(hermitian_str_seq_seq, - qubit_inds_seq, - coefs) - - state = random.uniform(random.PRNGKey(0), shape=(2 ** n_qubits,)) * 2 + qubit_inds_seq = [[i, i + 1] for i in range(n_qubits - 1)] + [ + [i] for i in range(n_qubits) + ] + st_to_exp = qujax.get_statetensor_to_expectation_func( + hermitian_str_seq_seq, qubit_inds_seq, coefs + ) + dt_to_exp = qujax.get_statetensor_to_expectation_func( + hermitian_str_seq_seq, qubit_inds_seq, coefs + ) + + state = random.uniform(random.PRNGKey(0), shape=(2**n_qubits,)) * 2 state /= jnp.linalg.norm(state) st_in = state.reshape((2,) * n_qubits) dt_in = qujax.statetensor_to_densitytensor(st_in) @@ -118,9 +130,11 @@ def big_hermitian_matrix(hermitian_str_seq, qubit_inds): big_h = jnp.kron(big_h, hermitian_arrs[k]) return big_h - sum_big_hs = jnp.zeros((2 ** n_qubits, 2 ** n_qubits), dtype='complex') + sum_big_hs = jnp.zeros((2**n_qubits, 2**n_qubits), dtype="complex") for i in range(len(hermitian_str_seq_seq)): - sum_big_hs += coefs[i] * big_hermitian_matrix(hermitian_str_seq_seq[i], qubit_inds_seq[i]) + sum_big_hs += coefs[i] * big_hermitian_matrix( + hermitian_str_seq_seq[i], qubit_inds_seq[i] + ) assert jnp.allclose(sum_big_hs, sum_big_hs.conj().T) @@ -133,24 +147,28 @@ def big_hermitian_matrix(hermitian_str_seq, qubit_inds): qujax_dt_exp_jit = jit(dt_to_exp)(dt_in) assert jnp.array(qujax_exp).shape == () - assert jnp.array(qujax_exp).dtype.name[:5] == 'float' + assert jnp.array(qujax_exp).dtype.name[:5] == "float" assert jnp.isclose(true_exp, qujax_exp) assert jnp.isclose(true_exp, qujax_dt_exp) assert jnp.isclose(true_exp, qujax_exp_jit) assert jnp.isclose(true_exp, qujax_dt_exp_jit) - st_to_samp_exp = qujax.get_statetensor_to_sampled_expectation_func(hermitian_str_seq_seq, - qubit_inds_seq, - coefs) - dt_to_samp_exp = qujax.get_statetensor_to_sampled_expectation_func(hermitian_str_seq_seq, - qubit_inds_seq, - coefs) + st_to_samp_exp = qujax.get_statetensor_to_sampled_expectation_func( + hermitian_str_seq_seq, qubit_inds_seq, coefs + ) + dt_to_samp_exp = qujax.get_statetensor_to_sampled_expectation_func( + hermitian_str_seq_seq, qubit_inds_seq, coefs + ) qujax_samp_exp = st_to_samp_exp(st_in, random.PRNGKey(1), 1000000) - qujax_samp_exp_jit = jit(st_to_samp_exp, static_argnums=2)(st_in, random.PRNGKey(2), 1000000) + qujax_samp_exp_jit = jit(st_to_samp_exp, static_argnums=2)( + st_in, random.PRNGKey(2), 1000000 + ) qujax_samp_exp_dt = dt_to_samp_exp(st_in, random.PRNGKey(1), 1000000) - qujax_samp_exp_dt_jit = jit(dt_to_samp_exp, static_argnums=2)(st_in, random.PRNGKey(2), 1000000) + qujax_samp_exp_dt_jit = jit(dt_to_samp_exp, static_argnums=2)( + st_in, random.PRNGKey(2), 1000000 + ) assert jnp.array(qujax_samp_exp).shape == () - assert jnp.array(qujax_samp_exp).dtype.name[:5] == 'float' + assert jnp.array(qujax_samp_exp).dtype.name[:5] == "float" assert jnp.isclose(true_exp, qujax_samp_exp, rtol=1e-2) assert jnp.isclose(true_exp, qujax_samp_exp_jit, rtol=1e-2) assert jnp.isclose(true_exp, qujax_samp_exp_dt, rtol=1e-2) diff --git a/tests/test_gates.py b/tests/test_gates.py index 3d29f0b..7e65d00 100644 --- a/tests/test_gates.py +++ b/tests/test_gates.py @@ -1,8 +1,8 @@ -from qujax import gates, check_unitary +from qujax import check_unitary, gates def test_gates(): for g_str, g in gates.__dict__.items(): - if g_str[0] != '_' and g_str != 'jnp': + if g_str[0] != "_" and g_str != "jnp": check_unitary(g_str) check_unitary(g) diff --git a/tests/test_statetensor.py b/tests/test_statetensor.py index 8117034..284306d 100644 --- a/tests/test_statetensor.py +++ b/tests/test_statetensor.py @@ -1,10 +1,11 @@ -from jax import numpy as jnp, jit +from jax import jit +from jax import numpy as jnp import qujax def test_H(): - gates = ['H'] + gates = ["H"] qubits = [[0]] param_inds = [[]] @@ -12,7 +13,7 @@ def test_H(): st = param_to_st() st_jit = jit(param_to_st)() - true_sv = jnp.array([0.70710678 + 0.j, 0.70710678 + 0.j]) + true_sv = jnp.array([0.70710678 + 0.0j, 0.70710678 + 0.0j]) assert st.size == true_sv.size assert jnp.allclose(st.flatten(), true_sv) @@ -27,30 +28,33 @@ def test_H(): def test_H_redundant_qubits(): - gates = ['H'] + gates = ["H"] qubits = [[0]] param_inds = [[]] n_qubits = 3 - param_to_st = qujax.get_params_to_statetensor_func(gates, qubits, param_inds, n_qubits) + param_to_st = qujax.get_params_to_statetensor_func( + gates, qubits, param_inds, n_qubits + ) st = param_to_st(statetensor_in=None) - true_sv = jnp.array([0.70710678, 0., 0., 0., - 0.70710678, 0., 0., 0.]) + true_sv = jnp.array([0.70710678, 0.0, 0.0, 0.0, 0.70710678, 0.0, 0.0, 0.0]) assert st.size == true_sv.size assert jnp.allclose(st.flatten(), true_sv) - param_to_unitary = qujax.get_params_to_unitarytensor_func(gates, qubits, param_inds, n_qubits) - unitary = param_to_unitary().reshape(2 ** n_qubits, 2 ** n_qubits) - unitary_jit = jit(param_to_unitary)().reshape(2 ** n_qubits, 2 ** n_qubits) - zero_sv = jnp.zeros(2 ** n_qubits).at[0].set(1) + param_to_unitary = qujax.get_params_to_unitarytensor_func( + gates, qubits, param_inds, n_qubits + ) + unitary = param_to_unitary().reshape(2**n_qubits, 2**n_qubits) + unitary_jit = jit(param_to_unitary)().reshape(2**n_qubits, 2**n_qubits) + zero_sv = jnp.zeros(2**n_qubits).at[0].set(1) assert jnp.allclose(unitary @ zero_sv, true_sv) assert jnp.allclose(unitary_jit @ zero_sv, true_sv) def test_CX_Rz_CY(): - gates = ['H', 'H', 'H', 'CX', 'Rz', 'CY'] + gates = ["H", "H", "H", "CX", "Rz", "CY"] qubits = [[0], [1], [2], [0, 1], [1], [1, 2]] param_inds = [[], [], [], None, [0], []] @@ -58,25 +62,36 @@ def test_CX_Rz_CY(): param = jnp.array(0.1) st = param_to_st(param) - true_sv = jnp.array([0.34920055 - 0.05530793j, 0.34920055 - 0.05530793j, - 0.05530793 - 0.34920055j, -0.05530793 + 0.34920055j, - 0.34920055 - 0.05530793j, 0.34920055 - 0.05530793j, - 0.05530793 - 0.34920055j, -0.05530793 + 0.34920055j], dtype='complex64') + true_sv = jnp.array( + [ + 0.34920055 - 0.05530793j, + 0.34920055 - 0.05530793j, + 0.05530793 - 0.34920055j, + -0.05530793 + 0.34920055j, + 0.34920055 - 0.05530793j, + 0.34920055 - 0.05530793j, + 0.05530793 - 0.34920055j, + -0.05530793 + 0.34920055j, + ], + dtype="complex64", + ) assert st.size == true_sv.size assert jnp.allclose(st.flatten(), true_sv) n_qubits = 3 - param_to_unitary = qujax.get_params_to_unitarytensor_func(gates, qubits, param_inds, n_qubits) - unitary = param_to_unitary(param).reshape(2 ** n_qubits, 2 ** n_qubits) - unitary_jit = jit(param_to_unitary)(param).reshape(2 ** n_qubits, 2 ** n_qubits) - zero_sv = jnp.zeros(2 ** n_qubits).at[0].set(1) - assert jnp.allclose(unitary @ zero_sv, true_sv) + param_to_unitary = qujax.get_params_to_unitarytensor_func( + gates, qubits, param_inds, n_qubits + ) + unitary = param_to_unitary(param).reshape(2**n_qubits, 2**n_qubits) + unitary_jit = jit(param_to_unitary)(param).reshape(2**n_qubits, 2**n_qubits) + zero_sv = jnp.zeros(2**n_qubits).at[0].set(1) + assert jnp.allclose(unitary @ zero_sv, true_sv) assert jnp.allclose(unitary_jit @ zero_sv, true_sv) def test_stacked_circuits(): - gates = ['H'] + gates = ["H"] qubits = [[0]] param_inds = [[]] @@ -91,4 +106,3 @@ def test_stacked_circuits(): assert jnp.allclose(st2.flatten(), all_zeros_sv, atol=1e-7) assert jnp.allclose(st2_2.flatten(), all_zeros_sv, atol=1e-7) -