diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index f921029..902bfbb 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -27,7 +27,7 @@ jobs: os: ['ubuntu-22.04', 'macos-12'] runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: '0' - run: git fetch --depth=1 origin +refs/tags/*:refs/tags/* +refs/heads/*:refs/remotes/origin/* @@ -107,7 +107,7 @@ jobs: needs: publish_to_pypi runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: '0' - name: Set up Python 3.10 diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 3cbfe03..b678d3f 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -15,7 +15,7 @@ jobs: name: build docs runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python 3.9 uses: actions/setup-python@v4 with: diff --git a/.github/workflows/docs/conf.py b/.github/workflows/docs/conf.py index 2a6b938..cbea5b1 100644 --- a/.github/workflows/docs/conf.py +++ b/.github/workflows/docs/conf.py @@ -116,7 +116,6 @@ def correct_signature( signature: str, return_annotation: str, ) -> (str, str): - new_signature = signature new_return_annotation = return_annotation for k, v in app.config.custom_internal_mapping.items(): diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 41835f1..55055f4 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -16,7 +16,7 @@ jobs: runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python 3.x uses: actions/setup-python@v4 with: diff --git a/_metadata.py b/_metadata.py index 8c3f8db..d5c7894 100644 --- a/_metadata.py +++ b/_metadata.py @@ -1,2 +1,2 @@ -__extension_version__ = "0.13.0" +__extension_version__ = "0.14.0" __extension_name__ = "pytket-qujax" diff --git a/docs/changelog.rst b/docs/changelog.rst index f587f43..dd9fb9e 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,6 +1,11 @@ Changelog ~~~~~~~~~ +0.14.0 (October 2023) +--------------------- + +* Updated pytket version requirement to 1.21. + 0.13.0 (August 2023) -------------------- diff --git a/mypy.ini b/mypy.ini index 485c9c7..3daa6c7 100644 --- a/mypy.ini +++ b/mypy.ini @@ -14,7 +14,7 @@ namespace_packages = True check_untyped_defs = True warn_redundant_casts = True -warn_unused_ignores = False +warn_unused_ignores = True warn_no_return = False warn_return_any = True warn_unreachable = True diff --git a/pytket/extensions/qujax/__init__.py b/pytket/extensions/qujax/__init__.py index 6a25d24..b0e613f 100644 --- a/pytket/extensions/qujax/__init__.py +++ b/pytket/extensions/qujax/__init__.py @@ -16,7 +16,7 @@ """ # _metadata.py is copied to the folder after installation. -from ._metadata import __extension_version__, __extension_name__ # type: ignore +from ._metadata import __extension_version__, __extension_name__ from .qujax_convert import ( tk_to_qujax, tk_to_qujax_args, diff --git a/pytket/extensions/qujax/qujax_convert.py b/pytket/extensions/qujax/qujax_convert.py index 8bbb555..346b186 100644 --- a/pytket/extensions/qujax/qujax_convert.py +++ b/pytket/extensions/qujax/qujax_convert.py @@ -21,9 +21,9 @@ import qujax # type: ignore from jax import numpy as jnp -from sympy import lambdify, Symbol # type: ignore +from sympy import lambdify, Symbol from pytket import Qubit, Circuit # type: ignore -from pytket._tket.circuit import Command # type: ignore +from pytket._tket.circuit import Command def _tk_qubits_to_inds(tk_qubits: Sequence[Qubit]) -> Tuple[int, ...]: @@ -96,11 +96,11 @@ def _symbolic_command_to_gate_and_param_inds( gate = gate_str else: if len(free_symbols) == 0: - gate = jnp.array(command.op.get_unitary()) # type: ignore + gate = jnp.array(command.op.get_unitary()) else: raise TypeError(f"Parameterised gate {gate_str} not found in qujax.gates") - param_inds = tuple(symbol_map[symbol] for symbol in free_symbols) # type: ignore + param_inds = tuple(symbol_map[symbol] for symbol in free_symbols) return gate, param_inds @@ -362,7 +362,7 @@ def qujax_args_to_tk( if param is None: n_params = max([max(p) + 1 if len(p) > 0 else 0 for p in param_inds_seq]) - param = jnp.zeros(n_params) # type: ignore + param = jnp.zeros(n_params) param_inds_seq = [jnp.array(p, dtype="int32") for p in param_inds_seq] # type: ignore diff --git a/setup.py b/setup.py index 037a9eb..a1d366d 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,7 @@ packages=find_namespace_packages(include=["pytket.*"]), include_package_data=True, install_requires=[ - "pytket ~= 1.18", + "pytket ~= 1.21", "qujax ~= 1.0", ], classifiers=[ diff --git a/tests/test_tket.py b/tests/test_tket.py index 17b3d7a..6d63ec4 100644 --- a/tests/test_tket.py +++ b/tests/test_tket.py @@ -13,13 +13,13 @@ # limitations under the License. from typing import Union, Any -from jax import numpy as jnp, jit, grad, random # type: ignore +from jax import numpy as jnp, jit, grad, random import qujax # type: ignore import pytest -from pytket.circuit import Circuit, Qubit # type: ignore -from pytket.pauli import Pauli, QubitPauliString # type: ignore -from pytket.utils import QubitPauliOperator # type: ignore +from pytket.circuit import Circuit, Qubit +from pytket.pauli import Pauli, QubitPauliString +from pytket.utils import QubitPauliOperator from pytket.extensions.qujax import ( tk_to_qujax, tk_to_qujax_args, @@ -46,8 +46,8 @@ def _test_circuit( test_dt = apply_circuit_dt() n_qubits = test_dt.ndim // 2 - test_dm_diag = jnp.diag(test_dt.reshape(2**n_qubits, 2**n_qubits)) # type: ignore - test_jit_dm_diag = jnp.diag( # type: ignore + test_dm_diag = jnp.diag(test_dt.reshape(2**n_qubits, 2**n_qubits)) + test_jit_dm_diag = jnp.diag( jit_apply_circuit_dt().reshape(2**n_qubits, 2**n_qubits) ) else: @@ -55,8 +55,8 @@ def _test_circuit( test_jit_sv = jit_apply_circuit(param).flatten() test_dt = apply_circuit_dt(param) n_qubits = test_dt.ndim // 2 - test_dm_diag = jnp.diag(test_dt.reshape(2**n_qubits, 2**n_qubits)) # type: ignore - test_jit_dm_diag = jnp.diag( # type: ignore + test_dm_diag = jnp.diag(test_dt.reshape(2**n_qubits, 2**n_qubits)) + test_jit_dm_diag = jnp.diag( jit_apply_circuit_dt(param).reshape(2**n_qubits, 2**n_qubits) ) @@ -99,18 +99,18 @@ def test_H() -> None: def test_CX() -> None: - param = jnp.array([0.25]) # type: ignore + param = jnp.array([0.25]) circuit = Circuit(2) circuit.H(0) - circuit.Rz(param[0], 0) + circuit.Rz(float(param[0]), 0) circuit.CX(0, 1) _test_circuit(circuit, param, True) def test_CX_callable() -> None: - param = jnp.array([0.25]) # type: ignore + param = jnp.array([0.25]) def H() -> Any: return qujax.gates.H @@ -131,7 +131,7 @@ def CX() -> Any: circuit = Circuit(2) circuit.H(0) - circuit.Rz(param[0], 0) + circuit.Rz(float(param[0]), 0) circuit.CX(0, 1) true_sv = circuit.get_statevector() @@ -145,46 +145,46 @@ def CX() -> Any: def test_CX_qrev() -> None: - param = jnp.array([0.2, 0.8]) # type: ignore + param = jnp.array([0.2, 0.8]) circuit = Circuit(2) - circuit.Rx(param[0], 0) - circuit.Rx(param[1], 1) + circuit.Rx(float(param[0]), 0) + circuit.Rx(float(param[1]), 1) circuit.CX(1, 0) _test_circuit(circuit, param, True) def test_CZ() -> None: - param = jnp.array([0.25]) # type: ignore + param = jnp.array([0.25]) circuit = Circuit(2) circuit.H(0) - circuit.Rz(param[0], 0) + circuit.Rz(float(param[0]), 0) circuit.CZ(0, 1) _test_circuit(circuit, param, True) def test_CZ_qrev() -> None: - param = jnp.array([0.25]) # type: ignore + param = jnp.array([0.25]) circuit = Circuit(2) circuit.H(0) - circuit.Rz(param[0], 0) + circuit.Rz(float(param[0]), 0) circuit.CZ(1, 0) _test_circuit(circuit, param, True) def test_CX_Barrier_Rx() -> None: - param = jnp.array([0, 1 / jnp.pi]) # type: ignore + param = jnp.array([0, 1 / jnp.pi]) circuit = Circuit(3) circuit.CX(0, 1) circuit.add_barrier([0, 2]) - circuit.Rx(param[0], 0) - circuit.Rx(param[1], 2) + circuit.Rx(float(param[0]), 0) + circuit.Rx(float(param[1]), 2) _test_circuit(circuit, param) @@ -199,7 +199,7 @@ def test_circuit1() -> None: k = 0 for i in range(n_qubits): - circuit.Ry(param[k], i) + circuit.Ry(float(param[k]), i) k += 1 for _ in range(depth): @@ -207,9 +207,9 @@ def test_circuit1() -> None: circuit.CX(i, i + 1) for i in range(1, n_qubits - 1, 2): circuit.CX(i, i + 1) - circuit.add_barrier(range(0, n_qubits)) + circuit.add_barrier(list(range(0, n_qubits))) for i in range(n_qubits): - circuit.Ry(param[k], i) + circuit.Ry(float(param[k]), i) k += 1 _test_circuit(circuit, param) @@ -227,21 +227,21 @@ def test_circuit2() -> None: for i in range(n_qubits): circuit.H(i) for i in range(n_qubits): - circuit.Rz(param[k], i) + circuit.Rz(float(param[k]), i) k += 1 for i in range(n_qubits): - circuit.Rx(param[k], i) + circuit.Rx(float(param[k]), i) k += 1 for _ in range(depth): for i in range(0, n_qubits - 1): circuit.CZ(i, i + 1) - circuit.add_barrier(range(0, n_qubits)) + circuit.add_barrier(list(range(0, n_qubits))) for i in range(n_qubits): - circuit.Rz(param[k], i) + circuit.Rz(float(param[k]), i) k += 1 for i in range(n_qubits): - circuit.Rx(param[k], i) + circuit.Rx(float(param[k]), i) k += 1 _test_circuit(circuit, param) @@ -256,7 +256,7 @@ def test_HH() -> None: st1 = apply_circuit() st2 = apply_circuit(st1) - all_zeros_sv = jnp.array(jnp.arange(st2.size) == 0, dtype=int) # type: ignore + all_zeros_sv = jnp.array(jnp.arange(st2.size) == 0, dtype=int) assert jnp.all(jnp.abs(st2.flatten() - all_zeros_sv) < 1e-5) diff --git a/tests/test_tket_symbolic.py b/tests/test_tket_symbolic.py index 76933cb..580af79 100644 --- a/tests/test_tket_symbolic.py +++ b/tests/test_tket_symbolic.py @@ -14,10 +14,10 @@ from typing import Sequence import pytest -from sympy import Symbol # type: ignore +from sympy import Symbol from jax import numpy as jnp, jit, grad, random -from pytket.circuit import Circuit, OpType # type: ignore +from pytket.circuit import Circuit, OpType from pytket.extensions.qujax import ( tk_to_qujax, tk_to_qujax_args, @@ -49,8 +49,8 @@ def _test_circuit( test_dt = apply_circuit_dt() n_qubits = test_dt.ndim // 2 - test_dm_diag = jnp.diag(test_dt.reshape(2**n_qubits, 2**n_qubits)) # type: ignore - test_jit_dm_diag = jnp.diag( # type: ignore + test_dm_diag = jnp.diag(test_dt.reshape(2**n_qubits, 2**n_qubits)) + test_jit_dm_diag = jnp.diag( jit_apply_circuit_dt().reshape(2**n_qubits, 2**n_qubits) ) else: @@ -58,8 +58,8 @@ def _test_circuit( test_jit_sv = jit_apply_circuit(params).flatten() test_dt = apply_circuit_dt(params) n_qubits = test_dt.ndim // 2 - test_dm_diag = jnp.diag(test_dt.reshape(2**n_qubits, 2**n_qubits)) # type: ignore - test_jit_dm_diag = jnp.diag( # type: ignore + test_dm_diag = jnp.diag(test_dt.reshape(2**n_qubits, 2**n_qubits)) + test_jit_dm_diag = jnp.diag( jit_apply_circuit_dt(params).reshape(2**n_qubits, 2**n_qubits) ) @@ -222,7 +222,7 @@ def test_circuit1() -> None: circuit.CX(i, i + 1) for i in range(1, n_qubits - 1, 2): circuit.CX(i, i + 1) - circuit.add_barrier(range(0, n_qubits)) + circuit.add_barrier(list(range(0, n_qubits))) for i in range(n_qubits): circuit.Ry(symbols[k], i) k += 1 @@ -248,7 +248,7 @@ def test_circuit2() -> None: for _ in range(depth): for i in range(0, n_qubits - 1): circuit.CZ(i, i + 1) - circuit.add_barrier(range(0, n_qubits)) + circuit.add_barrier(list(range(0, n_qubits))) for i in range(n_qubits): circuit.Rz(symbols[k], i) k += 1