Skip to content

Commit

Permalink
Merge pull request #99 from CQCL/release/1.21.0
Browse files Browse the repository at this point in the history
Release/1.21.0
  • Loading branch information
cqc-melf authored Oct 17, 2023
2 parents cb22c70 + bfc4d2c commit bb67a34
Show file tree
Hide file tree
Showing 12 changed files with 57 additions and 53 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
os: ['ubuntu-22.04', 'macos-12']
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
with:
fetch-depth: '0'
- run: git fetch --depth=1 origin +refs/tags/*:refs/tags/* +refs/heads/*:refs/remotes/origin/*
Expand Down Expand Up @@ -107,7 +107,7 @@ jobs:
needs: publish_to_pypi
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
with:
fetch-depth: '0'
- name: Set up Python 3.10
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
name: build docs
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python 3.9
uses: actions/setup-python@v4
with:
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def correct_signature(
signature: str,
return_annotation: str,
) -> (str, str):

new_signature = signature
new_return_annotation = return_annotation
for k, v in app.config.custom_internal_mapping.items():
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-22.04

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python 3.x
uses: actions/setup-python@v4
with:
Expand Down
2 changes: 1 addition & 1 deletion _metadata.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__extension_version__ = "0.13.0"
__extension_version__ = "0.14.0"
__extension_name__ = "pytket-qujax"
5 changes: 5 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Changelog
~~~~~~~~~

0.14.0 (October 2023)
---------------------

* Updated pytket version requirement to 1.21.

0.13.0 (August 2023)
--------------------

Expand Down
2 changes: 1 addition & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace_packages = True
check_untyped_defs = True

warn_redundant_casts = True
warn_unused_ignores = False
warn_unused_ignores = True
warn_no_return = False
warn_return_any = True
warn_unreachable = True
Expand Down
2 changes: 1 addition & 1 deletion pytket/extensions/qujax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""

# _metadata.py is copied to the folder after installation.
from ._metadata import __extension_version__, __extension_name__ # type: ignore
from ._metadata import __extension_version__, __extension_name__
from .qujax_convert import (
tk_to_qujax,
tk_to_qujax_args,
Expand Down
10 changes: 5 additions & 5 deletions pytket/extensions/qujax/qujax_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@

import qujax # type: ignore
from jax import numpy as jnp
from sympy import lambdify, Symbol # type: ignore
from sympy import lambdify, Symbol
from pytket import Qubit, Circuit # type: ignore
from pytket._tket.circuit import Command # type: ignore
from pytket._tket.circuit import Command


def _tk_qubits_to_inds(tk_qubits: Sequence[Qubit]) -> Tuple[int, ...]:
Expand Down Expand Up @@ -96,11 +96,11 @@ def _symbolic_command_to_gate_and_param_inds(
gate = gate_str
else:
if len(free_symbols) == 0:
gate = jnp.array(command.op.get_unitary()) # type: ignore
gate = jnp.array(command.op.get_unitary())
else:
raise TypeError(f"Parameterised gate {gate_str} not found in qujax.gates")

param_inds = tuple(symbol_map[symbol] for symbol in free_symbols) # type: ignore
param_inds = tuple(symbol_map[symbol] for symbol in free_symbols)
return gate, param_inds


Expand Down Expand Up @@ -362,7 +362,7 @@ def qujax_args_to_tk(

if param is None:
n_params = max([max(p) + 1 if len(p) > 0 else 0 for p in param_inds_seq])
param = jnp.zeros(n_params) # type: ignore
param = jnp.zeros(n_params)

param_inds_seq = [jnp.array(p, dtype="int32") for p in param_inds_seq] # type: ignore

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
packages=find_namespace_packages(include=["pytket.*"]),
include_package_data=True,
install_requires=[
"pytket ~= 1.18",
"pytket ~= 1.21",
"qujax ~= 1.0",
],
classifiers=[
Expand Down
62 changes: 31 additions & 31 deletions tests/test_tket.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
# limitations under the License.

from typing import Union, Any
from jax import numpy as jnp, jit, grad, random # type: ignore
from jax import numpy as jnp, jit, grad, random
import qujax # type: ignore
import pytest

from pytket.circuit import Circuit, Qubit # type: ignore
from pytket.pauli import Pauli, QubitPauliString # type: ignore
from pytket.utils import QubitPauliOperator # type: ignore
from pytket.circuit import Circuit, Qubit
from pytket.pauli import Pauli, QubitPauliString
from pytket.utils import QubitPauliOperator
from pytket.extensions.qujax import (
tk_to_qujax,
tk_to_qujax_args,
Expand All @@ -46,17 +46,17 @@ def _test_circuit(

test_dt = apply_circuit_dt()
n_qubits = test_dt.ndim // 2
test_dm_diag = jnp.diag(test_dt.reshape(2**n_qubits, 2**n_qubits)) # type: ignore
test_jit_dm_diag = jnp.diag( # type: ignore
test_dm_diag = jnp.diag(test_dt.reshape(2**n_qubits, 2**n_qubits))
test_jit_dm_diag = jnp.diag(
jit_apply_circuit_dt().reshape(2**n_qubits, 2**n_qubits)
)
else:
test_sv = apply_circuit(param).flatten()
test_jit_sv = jit_apply_circuit(param).flatten()
test_dt = apply_circuit_dt(param)
n_qubits = test_dt.ndim // 2
test_dm_diag = jnp.diag(test_dt.reshape(2**n_qubits, 2**n_qubits)) # type: ignore
test_jit_dm_diag = jnp.diag( # type: ignore
test_dm_diag = jnp.diag(test_dt.reshape(2**n_qubits, 2**n_qubits))
test_jit_dm_diag = jnp.diag(
jit_apply_circuit_dt(param).reshape(2**n_qubits, 2**n_qubits)
)

Expand Down Expand Up @@ -99,18 +99,18 @@ def test_H() -> None:


def test_CX() -> None:
param = jnp.array([0.25]) # type: ignore
param = jnp.array([0.25])

circuit = Circuit(2)
circuit.H(0)
circuit.Rz(param[0], 0)
circuit.Rz(float(param[0]), 0)
circuit.CX(0, 1)

_test_circuit(circuit, param, True)


def test_CX_callable() -> None:
param = jnp.array([0.25]) # type: ignore
param = jnp.array([0.25])

def H() -> Any:
return qujax.gates.H
Expand All @@ -131,7 +131,7 @@ def CX() -> Any:

circuit = Circuit(2)
circuit.H(0)
circuit.Rz(param[0], 0)
circuit.Rz(float(param[0]), 0)
circuit.CX(0, 1)
true_sv = circuit.get_statevector()

Expand All @@ -145,46 +145,46 @@ def CX() -> Any:


def test_CX_qrev() -> None:
param = jnp.array([0.2, 0.8]) # type: ignore
param = jnp.array([0.2, 0.8])

circuit = Circuit(2)
circuit.Rx(param[0], 0)
circuit.Rx(param[1], 1)
circuit.Rx(float(param[0]), 0)
circuit.Rx(float(param[1]), 1)
circuit.CX(1, 0)

_test_circuit(circuit, param, True)


def test_CZ() -> None:
param = jnp.array([0.25]) # type: ignore
param = jnp.array([0.25])

circuit = Circuit(2)
circuit.H(0)
circuit.Rz(param[0], 0)
circuit.Rz(float(param[0]), 0)
circuit.CZ(0, 1)

_test_circuit(circuit, param, True)


def test_CZ_qrev() -> None:
param = jnp.array([0.25]) # type: ignore
param = jnp.array([0.25])

circuit = Circuit(2)
circuit.H(0)
circuit.Rz(param[0], 0)
circuit.Rz(float(param[0]), 0)
circuit.CZ(1, 0)

_test_circuit(circuit, param, True)


def test_CX_Barrier_Rx() -> None:
param = jnp.array([0, 1 / jnp.pi]) # type: ignore
param = jnp.array([0, 1 / jnp.pi])

circuit = Circuit(3)
circuit.CX(0, 1)
circuit.add_barrier([0, 2])
circuit.Rx(param[0], 0)
circuit.Rx(param[1], 2)
circuit.Rx(float(param[0]), 0)
circuit.Rx(float(param[1]), 2)

_test_circuit(circuit, param)

Expand All @@ -199,17 +199,17 @@ def test_circuit1() -> None:

k = 0
for i in range(n_qubits):
circuit.Ry(param[k], i)
circuit.Ry(float(param[k]), i)
k += 1

for _ in range(depth):
for i in range(0, n_qubits - 1, 2):
circuit.CX(i, i + 1)
for i in range(1, n_qubits - 1, 2):
circuit.CX(i, i + 1)
circuit.add_barrier(range(0, n_qubits))
circuit.add_barrier(list(range(0, n_qubits)))
for i in range(n_qubits):
circuit.Ry(param[k], i)
circuit.Ry(float(param[k]), i)
k += 1

_test_circuit(circuit, param)
Expand All @@ -227,21 +227,21 @@ def test_circuit2() -> None:
for i in range(n_qubits):
circuit.H(i)
for i in range(n_qubits):
circuit.Rz(param[k], i)
circuit.Rz(float(param[k]), i)
k += 1
for i in range(n_qubits):
circuit.Rx(param[k], i)
circuit.Rx(float(param[k]), i)
k += 1

for _ in range(depth):
for i in range(0, n_qubits - 1):
circuit.CZ(i, i + 1)
circuit.add_barrier(range(0, n_qubits))
circuit.add_barrier(list(range(0, n_qubits)))
for i in range(n_qubits):
circuit.Rz(param[k], i)
circuit.Rz(float(param[k]), i)
k += 1
for i in range(n_qubits):
circuit.Rx(param[k], i)
circuit.Rx(float(param[k]), i)
k += 1

_test_circuit(circuit, param)
Expand All @@ -256,7 +256,7 @@ def test_HH() -> None:
st1 = apply_circuit()
st2 = apply_circuit(st1)

all_zeros_sv = jnp.array(jnp.arange(st2.size) == 0, dtype=int) # type: ignore
all_zeros_sv = jnp.array(jnp.arange(st2.size) == 0, dtype=int)

assert jnp.all(jnp.abs(st2.flatten() - all_zeros_sv) < 1e-5)

Expand Down
16 changes: 8 additions & 8 deletions tests/test_tket_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@

from typing import Sequence
import pytest
from sympy import Symbol # type: ignore
from sympy import Symbol
from jax import numpy as jnp, jit, grad, random

from pytket.circuit import Circuit, OpType # type: ignore
from pytket.circuit import Circuit, OpType
from pytket.extensions.qujax import (
tk_to_qujax,
tk_to_qujax_args,
Expand Down Expand Up @@ -49,17 +49,17 @@ def _test_circuit(

test_dt = apply_circuit_dt()
n_qubits = test_dt.ndim // 2
test_dm_diag = jnp.diag(test_dt.reshape(2**n_qubits, 2**n_qubits)) # type: ignore
test_jit_dm_diag = jnp.diag( # type: ignore
test_dm_diag = jnp.diag(test_dt.reshape(2**n_qubits, 2**n_qubits))
test_jit_dm_diag = jnp.diag(
jit_apply_circuit_dt().reshape(2**n_qubits, 2**n_qubits)
)
else:
test_sv = apply_circuit(params).flatten()
test_jit_sv = jit_apply_circuit(params).flatten()
test_dt = apply_circuit_dt(params)
n_qubits = test_dt.ndim // 2
test_dm_diag = jnp.diag(test_dt.reshape(2**n_qubits, 2**n_qubits)) # type: ignore
test_jit_dm_diag = jnp.diag( # type: ignore
test_dm_diag = jnp.diag(test_dt.reshape(2**n_qubits, 2**n_qubits))
test_jit_dm_diag = jnp.diag(
jit_apply_circuit_dt(params).reshape(2**n_qubits, 2**n_qubits)
)

Expand Down Expand Up @@ -222,7 +222,7 @@ def test_circuit1() -> None:
circuit.CX(i, i + 1)
for i in range(1, n_qubits - 1, 2):
circuit.CX(i, i + 1)
circuit.add_barrier(range(0, n_qubits))
circuit.add_barrier(list(range(0, n_qubits)))
for i in range(n_qubits):
circuit.Ry(symbols[k], i)
k += 1
Expand All @@ -248,7 +248,7 @@ def test_circuit2() -> None:
for _ in range(depth):
for i in range(0, n_qubits - 1):
circuit.CZ(i, i + 1)
circuit.add_barrier(range(0, n_qubits))
circuit.add_barrier(list(range(0, n_qubits)))
for i in range(n_qubits):
circuit.Rz(symbols[k], i)
k += 1
Expand Down

0 comments on commit bb67a34

Please sign in to comment.