diff --git a/.github/codecov.yml b/.github/codecov.yml index bfdc9877..0d964013 100644 --- a/.github/codecov.yml +++ b/.github/codecov.yml @@ -3,6 +3,7 @@ coverage: project: default: informational: true + patch: default: informational: true diff --git a/.github/workflows/check_py.yml b/.github/workflows/check_py.yml index 78bef732..90a90728 100644 --- a/.github/workflows/check_py.yml +++ b/.github/workflows/check_py.yml @@ -8,8 +8,12 @@ on: permissions: contents: read +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: - ruff_check: + ruff: runs-on: ubuntu-latest steps: @@ -19,7 +23,7 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: "3.8" - name: Setup Ruff run: | @@ -31,7 +35,15 @@ jobs: pip install ruff ruff check - ruff_format: + - name: Run ruff isort + run: | + ruff check --select I --fix + + - name: Run ruff format + run: | + ruff format + + staticcheck: runs-on: ubuntu-latest steps: @@ -41,17 +53,16 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: "3.8" - - name: Setup Ruff - run: | + - run: | python3 -m pip install -U pip - pip install ruff + pip install .[dev] - - name: Run ruff format + - name: Run mypy run: | - ruff format + mypy - - name: Run ruff isort + - name: Run pyright run: | - ruff check --select I --fix + pyright diff --git a/.github/workflows/check_rs.yml b/.github/workflows/check_rs.yml index f9926370..e04f8076 100644 --- a/.github/workflows/check_rs.yml +++ b/.github/workflows/check_rs.yml @@ -8,44 +8,30 @@ on: permissions: contents: read +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: - rustfmt: + lint_rs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - run: rustup update + - run: | + rustup update + # Use nightly for rustfmt + rustup toolchain install --allow-downgrade nightly + rustup component add --toolchain nightly rustfmt - uses: 
actions/setup-python@v5 with: - python-version: 3.8 - - # Necessary for nektos/act - # - uses: actions-rust-lang/setup-rust-toolchain@v1 - # with: - # components: rustfmt + python-version: "3.8" - name: Run rustfmt run: | - cargo fmt -- --check - - clippy: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v4 - - - run: rustup update - - - uses: actions/setup-python@v5 - with: - python-version: 3.8 - - # Necessary for nektos/act - # - uses: actions-rust-lang/setup-rust-toolchain@v1 - # with: - # components: clippy + cargo +nightly fmt -- --check - name: Run clippy run: | diff --git a/.github/workflows/cov.yml b/.github/workflows/cov.yml index 6fb68ef4..7c62cf12 100644 --- a/.github/workflows/cov.yml +++ b/.github/workflows/cov.yml @@ -24,7 +24,7 @@ jobs: pip install pipenv # Use development build pipenv install -e .[dev] - pipenv run pytest --cov=./python --cov-report=xml + pipenv run pytest --cov-report=xml - uses: codecov/codecov-action@v4 with: diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml new file mode 100644 index 00000000..90572c26 --- /dev/null +++ b/.github/workflows/doc.yml @@ -0,0 +1,31 @@ +name: doc + +on: [pull_request] + +permissions: + contents: read + +jobs: + build_docs: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - run: rustup update + + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - run: | + python3 -m pip install -U pip + pip install -e .[dev,doc] + + - name: Run examples + run: | + for file in examples/*.py; do + python3 ${file} + done + + - run: sphinx-build docs/source docs/build --fail-on-warning diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml deleted file mode 100644 index 5ccb030c..00000000 --- a/.github/workflows/mypy.yml +++ /dev/null @@ -1,32 +0,0 @@ -name: mypy - -on: - pull_request: - branches: ["master"] - workflow_dispatch: - -permissions: - contents: read - -jobs: - typecheck: - runs-on: ubuntu-latest - - steps: - - uses: 
actions/checkout@v4 - - - run: rustup update - - - uses: actions/setup-python@v5 - with: - python-version: 3.8 - - # Necessary for nektos/act - # - uses: actions-rust-lang/setup-rust-toolchain@v1 - - - name: Run mypy - run: | - python3 -m pip install -U pip - pip install mypy - pip install . - mypy python/ diff --git a/.github/workflows/test_py.yml b/.github/workflows/test_py.yml index 8dfb81f7..f4807f81 100644 --- a/.github/workflows/test_py.yml +++ b/.github/workflows/test_py.yml @@ -8,13 +8,17 @@ on: permissions: contents: read +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: pytest: strategy: fail-fast: false matrix: os: ["ubuntu-latest", "windows-latest", "macos-latest"] - python: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] name: "py${{ matrix.python }}_${{ matrix.os }}" runs-on: ${{ matrix.os }} @@ -28,9 +32,6 @@ jobs: with: python-version: ${{ matrix.python }} - # Necessary for nektos/act - # - uses: actions-rust-lang/setup-rust-toolchain@v1 - - name: Run Python tests # MEMO: DO NOT use `pip install .` # We need to test with debug build diff --git a/.github/workflows/test_rs.yml b/.github/workflows/test_rs.yml index 4bcd0102..79b173c2 100644 --- a/.github/workflows/test_rs.yml +++ b/.github/workflows/test_rs.yml @@ -8,6 +8,10 @@ on: permissions: contents: read +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: cargo_test: runs-on: ubuntu-latest @@ -19,11 +23,28 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: 3.8 - - # Necessary for nektos/act - # - uses: actions-rust-lang/setup-rust-toolchain@v1 + python-version: "3.8" - name: Run Rust tests run: | cargo test + + cargo_miri: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - run: | + rustup update + # Use nightly for miri + rustup toolchain install --allow-downgrade nightly + rustup component add --toolchain nightly miri + 
+ - uses: actions/setup-python@v5 + with: + python-version: "3.8" + + - name: Run Rust tests with miri + run: | + cargo +nightly miri test diff --git a/.github/workflows/wheel.yml b/.github/workflows/wheel.yml index 27d8fe8c..9cbc3e56 100644 --- a/.github/workflows/wheel.yml +++ b/.github/workflows/wheel.yml @@ -5,6 +5,13 @@ on: branches: ["master"] workflow_dispatch: +permissions: + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: windows: runs-on: windows-latest @@ -21,7 +28,7 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: "3.8" architecture: ${{ matrix.target }} - run: | @@ -58,7 +65,7 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: "3.8" architecture: x64 - run: | @@ -102,7 +109,7 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: "3.8" architecture: x64 - run: | @@ -139,7 +146,7 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: "3.8" architecture: x64 - run: | @@ -186,7 +193,7 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: "3.8" - run: | python3 -m pip install -U pip @@ -232,7 +239,7 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: "3.8" - run: | python3 -m pip install -U pip diff --git a/.nojekyll b/.nojekyll new file mode 100644 index 00000000..e69de29b diff --git a/Cargo.toml b/Cargo.toml index 4eb24668..2b7605a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,13 +18,14 @@ crate-type = ["cdylib"] [dependencies] anyhow = "1" fixedbitset = "0.5" -hashbrown = "0.14" +hashbrown = "0.14" # Cannot use 0.15 for compatibility with pyo3 itertools = "0.13" pyo3 = { version = "0.22", features = ["abi3-py38", "hashbrown"] } +thiserror = "1" tracing = "0.1" [dev-dependencies] rand = "0.8" -rstest = "0.22" +rstest = "0.23" rstest_reuse = "0.7" test-log = { version = "0.2", 
features = ["trace"] } diff --git a/Pipfile b/Pipfile index d320eee1..1f790d26 100644 --- a/Pipfile +++ b/Pipfile @@ -4,7 +4,7 @@ verify_ssl = true name = "pypi" [packages] -fastflow = {extras = ["dev"], file = ".", editable = true} +fastflow = {extras = ["dev", "doc"], file = ".", editable = true} [dev-packages] maturin = "<2,>=1" diff --git a/README.md b/README.md index cce9927b..737e98ad 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,13 @@ # fastflow +[![License](https://img.shields.io/github/license/TeamGraphix/fastflow)](https://github.com/TeamGraphix/fastflow?tab=Apache-2.0-1-ov-file) +[![test_rs](https://github.com/TeamGraphix/fastflow/actions/workflows/test_rs.yml/badge.svg)](https://github.com/TeamGraphix/fastflow/actions/workflows/test_rs.yml) +[![test_py](https://github.com/TeamGraphix/fastflow/actions/workflows/test_py.yml/badge.svg)](https://github.com/TeamGraphix/fastflow/actions/workflows/test_py.yml) +[![codecov](https://codecov.io/github/TeamGraphix/fastflow/graph/badge.svg?token=JPLJWWPNF4)](https://codecov.io/github/TeamGraphix/fastflow) +[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) + Rust binding of generalized and pauli flow finding algorithms. + +## License + +This project is licensed under the Apache-2.0 License. diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 00000000..d0c3cbf1 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". 
+help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 00000000..747ffb7b --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 00000000..db8b56d1 --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,53 @@ +# Configuration file for the Sphinx documentation builder. 
+# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +from __future__ import annotations + +import sys +from pathlib import Path + +ROOT_DIR = Path("../../python/fastflow").resolve() + +sys.path.insert(0, str(ROOT_DIR)) + +project = "fastflow" +copyright = "2024, TeamGraphix" # noqa: A001 +author = "S.S." +release = "0.1.0" + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinxcontrib.bibtex", +] + +templates_path = ["_templates"] +# exclude_patterns = [] + + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_theme = "furo" +# html_static_path = [] + +intersphinx_mapping = { + "networkx": ("https://networkx.org/documentation/stable", None), + "python": ("https://docs.python.org/3", None), +} + +default_role = "any" + +autodoc_typehints = "description" +autodoc_typehints_description_target = "documented" +bibtex_bibfiles = ["ref.bib"] diff --git a/docs/source/fastflow.rst b/docs/source/fastflow.rst new file mode 100644 index 00000000..2c3fb350 --- /dev/null +++ b/docs/source/fastflow.rst @@ -0,0 +1,75 @@ +fastflow package +================ + +fastflow.common module +---------------------- + +.. automodule:: fastflow.common + :members: + :exclude-members: Plane, PPlane, V, P + + .. autoclass:: fastflow.common.Plane + + .. py:attribute:: XY + + Arbitrary-angle measurement on the XY plane. + + .. 
py:attribute:: YZ + + Arbitrary-angle measurement on the YZ plane. + + .. py:attribute:: XZ + + Arbitrary-angle measurement on the XZ plane. + + .. autoclass:: fastflow.common.PPlane + + .. py:attribute:: XY + + Arbitrary-angle measurement on the XY plane. + + .. py:attribute:: YZ + + Arbitrary-angle measurement on the YZ plane. + + .. py:attribute:: XZ + + Arbitrary-angle measurement on the XZ plane. + + .. py:attribute:: X + + Pauli-X measurement. + + .. py:attribute:: Y + + Pauli-Y measurement. + + .. py:attribute:: Z + + Pauli-Z measurement. + +fastflow.flow module +-------------------- + +.. automodule:: fastflow.flow + :members: + :undoc-members: + +fastflow.gflow module +--------------------- + +.. automodule:: fastflow.gflow + :members: + :undoc-members: + +fastflow.pflow module +--------------------- + +.. automodule:: fastflow.pflow + :members: + :undoc-members: + +References +---------- + +.. footbibliography:: diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 00000000..d361aa0c --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,8 @@ +fastflow documentation +====================== + +.. 
toctree:: + :maxdepth: 2 + :caption: API Reference: + + fastflow diff --git a/docs/source/ref.bib b/docs/source/ref.bib new file mode 100644 index 00000000..ab42af2a --- /dev/null +++ b/docs/source/ref.bib @@ -0,0 +1,33 @@ +@article{Backens2021, + author = {Backens, Miriam and Miller-Bakewell, Hector and de Felice, Giovanni and Lobski, Leo and van de Wetering, John}, + doi = {10.22331/q-2021-03-25-421}, + issn = {2521-327X}, + journal = {{Quantum}}, + month = mar, + pages = {421}, + publisher = {{Verein zur F{\"{o}}rderung des Open Access Publizierens in den Quantenwissenschaften}}, + title = {There and back again: {A} circuit extraction tale}, + url = {https://doi.org/10.22331/q-2021-03-25-421}, + volume = {5}, + year = {2021} +} +@inproceedings{Mhalla2008, + address = {Berlin, Heidelberg}, + author = {Mhalla, Mehdi and Perdrix, Simon}, + booktitle = {Automata, Languages and Programming}, + isbn = {978-3-540-70575-8}, + pages = {857--868}, + publisher = {Springer Berlin Heidelberg}, + title = {Finding Optimal Flows Efficiently}, + year = {2008} +} +@article{Simons2021, + author = {Simmons, Will}, + doi = {10.4204/EPTCS.343.4}, + journal = {Electronic Proceedings in Theoretical Computer Science}, + month = {09}, + pages = {50-101}, + title = {Relating Measurement Patterns to Circuits via Pauli Flow}, + volume = {343}, + year = {2021} +} diff --git a/examples/flow.py b/examples/flow.py new file mode 100644 index 00000000..b4a7296f --- /dev/null +++ b/examples/flow.py @@ -0,0 +1,40 @@ +"""Example code for finding causal flow.""" + +# %% + +from __future__ import annotations + +import networkx as nx +from fastflow import flow + +g: nx.Graph[int] + +# %% + +# 1 - 3 - 5 +# | +# 2 - 4 - 6 +g = nx.Graph([(1, 3), (2, 4), (3, 5), (4, 6)]) +iset = {1, 2} +oset = {5, 6} + +result = flow.find(g, iset, oset) + +# Found +assert result is not None + +# %% + +# 1 - 3 +# \ / +# X +# / \ +# 2 - 4 +g = nx.Graph([(1, 3), (1, 4), (2, 3), (2, 4)]) +iset = {1, 2} +oset = {3, 4} + +# Not 
found +result = flow.find(g, iset, oset) + +assert result is None diff --git a/examples/gflow.py b/examples/gflow.py new file mode 100644 index 00000000..7c63aad7 --- /dev/null +++ b/examples/gflow.py @@ -0,0 +1,45 @@ +"""Example code for finding generalized flow.""" + +# %% + +from __future__ import annotations + +import networkx as nx +from fastflow import gflow +from fastflow.common import Plane + +g: nx.Graph[int] + +# %% + +# 0 - 1 +# /| | +# 4 | | +# \| | +# 2 - 5 - 3 +g = nx.Graph([(0, 1), (0, 2), (0, 4), (1, 5), (2, 4), (2, 5), (3, 5)]) +iset = {0, 1} +oset = {4, 5} +planes = {0: Plane.XY, 1: Plane.XY, 2: Plane.XZ, 3: Plane.YZ} + +result = gflow.find(g, iset, oset, planes) + +# Found +assert result is not None + +# %% + +# 1 - 3 +# \ / +# X +# / \ +# 2 - 4 +g = nx.Graph([(1, 3), (1, 4), (2, 3), (2, 4)]) +iset = {1, 2} +oset = {3, 4} +# Omitting planes (all Plane.XY) + +result = gflow.find(g, iset, oset) + +# Not found +assert result is None diff --git a/examples/pflow.py b/examples/pflow.py new file mode 100644 index 00000000..3d3be1bf --- /dev/null +++ b/examples/pflow.py @@ -0,0 +1,44 @@ +"""Example code for finding Pauli flow.""" + +# %% + +from __future__ import annotations + +import networkx as nx +from fastflow import pflow +from fastflow.common import PPlane + +g: nx.Graph[int] + +# %% + +# 1 2 3 +# | / | +# 0 - - - 4 +g = nx.Graph([(0, 1), (0, 2), (0, 4), (3, 4)]) +iset = {0} +oset = {4} +pplanes = {0: PPlane.Z, 1: PPlane.Z, 2: PPlane.Y, 3: PPlane.Y} + +result = pflow.find(g, iset, oset, pplanes) + +# Found +assert result is not None + +# %% + +# 1 - 3 +# \ / +# X +# / \ +# 2 - 4 +g = nx.Graph([(1, 3), (1, 4), (2, 3), (2, 4)]) +iset = {1, 2} +oset = {3, 4} +# Omitting pplanes (all PPlane.XY) + +# NOTE: This results in warning (use gflow.find if pplanes has no Pauli measurements) +result = pflow.find(g, iset, oset) + +# Not found +assert result is None diff --git a/pyproject.toml b/pyproject.toml index 60b7017c..e563c96c 100644 --- a/pyproject.toml 
+++ b/pyproject.toml @@ -8,6 +8,7 @@ name = "fastflow" version = "0.1.0" description = "Rust binding of generalized and pauli flow finding algorithms." license = { file = "LICENSE" } +readme = "README.md" authors = [ { name = "S.S.", email = "66886825+EarlMilktea@users.noreply.github.com" }, ] @@ -16,7 +17,6 @@ maintainers = [ { name = "thierry-martinez", email = "thierry.martinez@inria.fr" }, { name = "Shinichi Sunami", email = "shinichi.sunami@gmail.com" }, ] -readme = "README.md" classifiers = [ "Development Status :: 3 - Alpha", "Environment :: Console", @@ -32,23 +32,35 @@ classifiers = [ "Topic :: Scientific/Engineering :: Physics", "Typing :: Typed", ] -requires-python = ">=3.8,<3.13" -dependencies = ["networkx", "types-networkx"] +requires-python = ">=3.8,<3.14" +dependencies = ["networkx"] [project.optional-dependencies] -dev = ["mypy", "pytest", "pytest-cov", "ruff"] +dev = ["mypy", "pyright", "pytest", "pytest-cov", "ruff", "types-networkx"] +doc = ["furo", "sphinxcontrib-bibtex", "sphinx"] [tool.maturin] +exclude = [ + "Pipfile", + "docs/**", + "rustfmt.toml", + "tests/**", + ".github/**", + ".gitignore", + ".nojekyll", +] features = ["pyo3/extension-module"] -python-source = "python" module-name = "fastflow._impl" +python-source = "python" [tool.mypy] python_version = "3.8" +strict = true +files = ["docs/source/conf.py", "python", "tests"] [tool.ruff] -line-length = 120 extend-include = ["*.ipynb"] +line-length = 120 [tool.ruff.format] docstring-code-format = true @@ -78,19 +90,40 @@ ignore = [ "DJ", "PD", - # Manually disabled - "ANN10", # `self`/`cls` not annotated - "CPY", # copyright missing - "D105", # undocumented magic method - "ERA", # commented-out code - "FBT", # boolean-trap - "FIX", # fixme + # Manually-disabled (lints) + "CPY", + "ERA", + "FBT", + "FIX", + + # Manually-disabled (rules) + "ANN10", # missing-type-self/cls (deprecated) + "D105", # undocumented-magic-method ] +[tool.ruff.lint.isort] +required-imports = ["from __future__ import 
annotations"] + [tool.ruff.lint.pydocstyle] convention = "numpy" [tool.ruff.lint.per-file-ignores] +"docs/**/*.py" = [ + "D1", # undocumented-XXX + "INP", +] +"examples/**/*.py" = [ + "INP", + "S101", # assert +] "tests/*.py" = [ - "S101", # `assert` detected + "D1", # undocumented-XXX + "PLR2004", # magic-value-comparison + "S101", # assert ] + +[tool.pytest.ini_options] +addopts = ["--cov=./python", "--cov-report=term"] + +[tool.coverage.report] +exclude_also = ["if TYPE_CHECKING:"] diff --git a/python/fastflow/_common.py b/python/fastflow/_common.py index e5e25e23..6bf1d183 100644 --- a/python/fastflow/_common.py +++ b/python/fastflow/_common.py @@ -1,21 +1,17 @@ -"""Private common functionalities for the fastflow package.""" +"""Private common functionalities.""" from __future__ import annotations -from collections.abc import Hashable, Mapping +import re +from collections.abc import Iterable, Mapping from collections.abc import Set as AbstractSet -from typing import Generic, TypeVar +from typing import Generic import networkx as nx -from fastflow.common import Plane, PPlane +from fastflow.common import P, V -# Vertex type -V = TypeVar("V", bound=Hashable) - - -# Plane-like -P = TypeVar("P", Plane, PPlane) +MSG_RE = re.compile(r"^([- a-z]*) \(((?:\d+, )*\d+)\)$") def check_graph(g: nx.Graph[V], iset: AbstractSet[V], oset: AbstractSet[V]) -> None: @@ -26,7 +22,7 @@ def check_graph(g: nx.Graph[V], iset: AbstractSet[V], oset: AbstractSet[V]) -> N TypeError If input types are incorrect. ValueError - If the graph is empty, has self-loops, or iset/oset are not subsets of the vertices. + If the graph is empty, not simple, or `iset`/`oset` is not a subset of nodes. """ if not isinstance(g, nx.Graph): msg = "g must be a networkx.Graph." @@ -46,30 +42,39 @@ def check_graph(g: nx.Graph[V], iset: AbstractSet[V], oset: AbstractSet[V]) -> N raise ValueError(msg) vset = set(g.nodes) if not (iset <= vset): - msg = "iset must be a subset of the vertices." 
+ msg = "iset must be a subset of the nodes." raise ValueError(msg) if not (oset <= vset): - msg = "oset must be a subset of the vertices." + msg = "oset must be a subset of the nodes." raise ValueError(msg) def check_planelike(vset: AbstractSet[V], oset: AbstractSet[V], plike: Mapping[V, P]) -> None: - r"""Check if measurement config. is valid. + r"""Check if measurement description is valid. + + Parameters + ---------- + vset : `collections.abc.Set` + All nodes. + oset : `collections.abc.Set` + Output nodes. + plike : `collections.abc.Mapping` + Measurement plane or Pauli index for each node in :math:`V \setminus O`. Raises ------ TypeError If input types are incorrect. ValueError - If plike is not a subset of the vertices, or measurement planes are not specified for all u in V\O. + If `plike` is not a subset of `vset`, or `plike` does not cover all :math:`V \setminus O`. """ if not isinstance(plike, Mapping): msg = "Measurement planes must be passed as a mapping." raise TypeError(msg) - if plike.keys() > vset: - msg = "Cannot find corresponding vertices in the graph." + if not (plike.keys() <= vset): + msg = "Cannot find corresponding nodes in the graph." raise ValueError(msg) - if plike.keys() < vset - oset: + if not (vset - oset <= plike.keys()): msg = "Measurement planes should be specified for all u in V\\O." raise ValueError(msg) @@ -81,7 +86,14 @@ class IndexMap(Generic[V]): __i2v: dict[int, V] def __init__(self, vset: AbstractSet[V]) -> None: - """Initialize the map from `vset`.""" + """Initialize the map from `vset`. + + Parameters + ---------- + vset : `collections.abc.Set` + Set of nodes. + Can be any hashable type. + """ self.__v2i = {v: i for i, v in enumerate(vset)} self.__i2v = {i: v for v, i in self.__v2i.items()} @@ -90,7 +102,8 @@ def encode(self, v: V) -> int: Returns ------- - Index of `v`. + `int` + Index of `v`. 
Raises ------ @@ -108,7 +121,7 @@ def encode_graph(self, g: nx.Graph[V]) -> list[set[int]]: Returns ------- - Input graph with vertices encoded to indices. + `g` with transformed nodes. """ n = len(g) g_: list[set[int]] = [set() for _ in range(n)] @@ -118,12 +131,7 @@ def encode_graph(self, g: nx.Graph[V]) -> list[set[int]]: return g_ def encode_set(self, vset: AbstractSet[V]) -> set[int]: - """Encode set. - - Returns - ------- - Transformed set. - """ + """Encode set.""" return {self.encode(v) for v in vset} def encode_dictkey(self, mapping: Mapping[V, P]) -> dict[int, P]: @@ -131,10 +139,47 @@ def encode_dictkey(self, mapping: Mapping[V, P]) -> dict[int, P]: Returns ------- - Dict with transformed keys. + `mapping` with transformed keys. """ return {self.encode(k): v for k, v in mapping.items()} + def encode_flow(self, f: Mapping[V, V]) -> dict[int, int]: + """Encode flow. + + Returns + ------- + `f` with both keys and values transformed. + """ + return {self.encode(i): self.encode(j) for i, j in f.items()} + + def encode_gflow(self, f: Mapping[V, AbstractSet[V]]) -> dict[int, set[int]]: + """Encode gflow. + + Returns + ------- + `f` with both keys and values transformed. + """ + return {self.encode(i): self.encode_set(si) for i, si in f.items()} + + def encode_layer(self, layer: Mapping[V, int]) -> list[int]: + """Encode layer. + + Returns + ------- + `layer` values transformed. + + Notes + ----- + `list` is used instead of `dict` here because no missing values are allowed here. + """ + # Use -1 as sentinel + layer_ = [-1 for _ in range(len(self.__v2i))] + for v, li in layer.items(): + layer_[self.encode(v)] = li + if any(li == -1 for li in layer_): + raise RuntimeError # pragma: no cover + return layer_ + def decode(self, i: int) -> V: """Decode the index. @@ -154,37 +199,55 @@ def decode(self, i: int) -> V: return v def decode_set(self, iset: AbstractSet[int]) -> set[V]: - """Decode set. - - Returns - ------- - Transformed set. 
- """ + """Decode set.""" return {self.decode(i) for i in iset} - def decode_flow(self, f_: dict[int, int]) -> dict[V, V]: + def decode_flow(self, f_: Mapping[int, int]) -> dict[V, V]: """Decode MBQC flow. Returns ------- - Transformed flow. + `f_` with both keys and values transformed. """ return {self.decode(i): self.decode(j) for i, j in f_.items()} - def decode_gflow(self, f_: dict[int, set[int]]) -> dict[V, set[V]]: + def decode_gflow(self, f_: Mapping[int, AbstractSet[int]]) -> dict[V, set[V]]: """Decode MBQC gflow. Returns ------- - Transformed gflow. + `f_` with both keys and values transformed. """ return {self.decode(i): self.decode_set(si) for i, si in f_.items()} - def decode_layer(self, layer_: list[int]) -> dict[V, int]: + def decode_layer(self, layer_: Iterable[int]) -> dict[V, int]: """Decode MBQC layer. Returns ------- - Transformed layer as dict. + `layer_` transformed. + + Notes + ----- + `list` (generalized as `Iterable`) is used instead of `dict` here because no missing values are allowed here. """ return {self.decode(i): li for i, li in enumerate(layer_)} + + def decode_errmsg(self, err: str) -> str: + """Decode error message.""" + m = MSG_RE.match(err) + if m is None: + msg = f"Cannot parse message: {err}." + raise ValueError(msg) + body: str = m.group(1) + body = body.capitalize() + + def _mapfunc(i: str) -> str: + return str(self.decode(int(i))) + + info = [_mapfunc(i) for i in m.group(2).split(", ")] + return f"{body} (check {', '.join(info)})." + + def decode_err(self, err: Exception) -> Exception: + """Decode error directly.""" + return type(err)(self.decode_errmsg(str(err))) diff --git a/python/fastflow/_impl/flow.pyi b/python/fastflow/_impl/flow.pyi index 8970cbc7..1fe37a58 100644 --- a/python/fastflow/_impl/flow.pyi +++ b/python/fastflow/_impl/flow.pyi @@ -1 +1,2 @@ def find(g: list[set[int]], iset: set[int], oset: set[int]) -> tuple[dict[int, int], list[int]] | None: ... 
+def verify(flow: tuple[dict[int, int], list[int]], g: list[set[int]], iset: set[int], oset: set[int]) -> None: ... diff --git a/python/fastflow/_impl/gflow.pyi b/python/fastflow/_impl/gflow.pyi index bbf96664..2770059c 100644 --- a/python/fastflow/_impl/gflow.pyi +++ b/python/fastflow/_impl/gflow.pyi @@ -6,3 +6,10 @@ class Plane: def find( g: list[set[int]], iset: set[int], oset: set[int], plane: dict[int, Plane] ) -> tuple[dict[int, set[int]], list[int]] | None: ... +def verify( + gflow: tuple[dict[int, set[int]], list[int]], + g: list[set[int]], + iset: set[int], + oset: set[int], + plane: dict[int, Plane], +) -> None: ... diff --git a/python/fastflow/_impl/pflow.pyi b/python/fastflow/_impl/pflow.pyi index 0348780b..8765895b 100644 --- a/python/fastflow/_impl/pflow.pyi +++ b/python/fastflow/_impl/pflow.pyi @@ -9,3 +9,10 @@ class PPlane: def find( g: list[set[int]], iset: set[int], oset: set[int], pplane: dict[int, PPlane] ) -> tuple[dict[int, set[int]], list[int]] | None: ... +def verify( + pflow: tuple[dict[int, set[int]], list[int]], + g: list[set[int]], + iset: set[int], + oset: set[int], + pplane: dict[int, PPlane], +) -> None: ... diff --git a/python/fastflow/common.py b/python/fastflow/common.py index 0527468d..085194f2 100644 --- a/python/fastflow/common.py +++ b/python/fastflow/common.py @@ -1,4 +1,4 @@ -"""Common functionalities for the fastflow package.""" +"""Common functionalities.""" from __future__ import annotations @@ -9,40 +9,32 @@ from fastflow._impl import gflow, pflow Plane = gflow.Plane + PPlane = pflow.PPlane -_V = TypeVar("_V", bound=Hashable) +V = TypeVar("V", bound=Hashable) #: Node type. + + +P = TypeVar("P", Plane, PPlane) #: Measurement plane or Pauli index. @dataclasses.dataclass(frozen=True) -class FlowResult(Generic[_V]): - """Causal flow [Danos and Kashefi, Phys. Rev. A 74, 052310] of an open graph. - - Attributes - ---------- - f : `dict[V, V]` - Flow function. 
- layer : `dict[V, int]` - Layer of each vertex representing the partial order. - (u -> v iff `layer[u] > layer[v]`). - """ +class FlowResult(Generic[V]): + r"""Causal flow of an open graph.""" - f: dict[_V, _V] - layer: dict[_V, int] + f: dict[V, V] + """Flow map as a dictionary. :math:`f(u)` is stored in :py:obj:`f[u]`.""" + layer: dict[V, int] + r"""Layer of each node representing the partial order. :math:`layer(u) > layer(v)` implies :math:`u \prec v`. + """ @dataclasses.dataclass(frozen=True) -class GFlowResult(Generic[_V]): - """Generalized flow [Browne et al., NJP 9, 250 (2007)] of an open graph. - - Attributes - ---------- - f : `dict[V, set[V]]` - Gflow function. - layer : `dict[V, int]` - Layer of each vertex representing the partial order. - (u -> v iff `layer[u] > layer[v]`). - """ +class GFlowResult(Generic[V]): + r"""Generalized flow of an open graph.""" - f: dict[_V, set[_V]] - layer: dict[_V, int] + f: dict[V, set[V]] + """Generalized flow map as a dictionary. :math:`f(u)` is stored in :py:obj:`f[u]`.""" + layer: dict[V, int] + r"""Layer of each node representing the partial order. :math:`layer(u) > layer(v)` implies :math:`u \prec v`. + """ diff --git a/python/fastflow/flow.py b/python/fastflow/flow.py index 32a39c97..01c3bab9 100644 --- a/python/fastflow/flow.py +++ b/python/fastflow/flow.py @@ -1,8 +1,7 @@ """Maximally-delayed flow algorithm. -For given undirected graph, input nodes, and output nodes, compute the causal flow having \ -the minimum number of layers. -See [Mhalla and Perdrix, Proc. of 35th ICALP, 857 (2008)] for more details. +This module provides functions to compute and verify maximally-delayed causal flow. +See :footcite:t:`Mhalla2008` for details. 
""" from __future__ import annotations @@ -10,9 +9,9 @@ from typing import TYPE_CHECKING from fastflow import _common -from fastflow._common import IndexMap, V -from fastflow._impl import flow -from fastflow.common import FlowResult +from fastflow._common import IndexMap +from fastflow._impl import flow as flow_bind +from fastflow.common import FlowResult, V if TYPE_CHECKING: from collections.abc import Set as AbstractSet @@ -21,24 +20,23 @@ def find(g: nx.Graph[V], iset: AbstractSet[V], oset: AbstractSet[V]) -> FlowResult[V] | None: - """Compute the maximally-delayed causal flow, if any. + """Compute causal flow. + + If it returns a flow, it is guaranteed to be maximally-delayed, i.e., the number of layers is minimized. Parameters ---------- - g : `nx.Graph[V]` - Undirected graph representing MBQC pattern. - Cannot have self-loops. - iset : `AbstractSet[V]` + g : `networkx.Graph` + Simple graph representing MBQC pattern. + iset : `collections.abc.Set` Input nodes. - Must be a subset of `g.nodes`. - oset : `AbstractSet[V]` + oset : `collections.abc.Set` Output nodes. - Must be a subset of `g.nodes`. Returns ------- - If a flow exists, return a `FlowResult[V]` object. - Otherwise, return `None`. + `FlowResult` or `None` + Return the flow if any, otherwise `None`. """ _common.check_graph(g, iset, oset) vset = g.nodes @@ -46,9 +44,42 @@ def find(g: nx.Graph[V], iset: AbstractSet[V], oset: AbstractSet[V]) -> FlowResu g_ = codec.encode_graph(g) iset_ = codec.encode_set(iset) oset_ = codec.encode_set(oset) - if ret_ := flow.find(g_, iset_, oset_): + if ret_ := flow_bind.find(g_, iset_, oset_): f_, layer_ = ret_ f = codec.decode_flow(f_) layer = codec.decode_layer(layer_) return FlowResult(f, layer) return None + + +def verify(flow: FlowResult[V], g: nx.Graph[V], iset: AbstractSet[V], oset: AbstractSet[V]) -> None: + """Verify maximally-delayed causal flow. + + Parameters + ---------- + flow : `FlowResult` + Flow to verify. 
+ g : `networkx.Graph` + Simple graph representing MBQC pattern. + iset : `collections.abc.Set` + Input nodes. + oset : `collections.abc.Set` + Output nodes. + + Raises + ------ + ValueError + If the graph is invalid or verification fails. + """ + _common.check_graph(g, iset, oset) + vset = g.nodes + codec = IndexMap(vset) + g_ = codec.encode_graph(g) + iset_ = codec.encode_set(iset) + oset_ = codec.encode_set(oset) + f_ = codec.encode_flow(flow.f) + layer_ = codec.encode_layer(flow.layer) + try: + flow_bind.verify((f_, layer_), g_, iset_, oset_) + except ValueError as e: + raise codec.decode_err(e) from None diff --git a/python/fastflow/gflow.py b/python/fastflow/gflow.py index 1a42246d..8c0f94a0 100644 --- a/python/fastflow/gflow.py +++ b/python/fastflow/gflow.py @@ -1,21 +1,21 @@ """Maximally-delayed gflow algorithm. -For given undirected graph, input nodes, and output nodes, compute the generalized flow having \ -the minimum number of layers. -See [Mhalla and Perdrix, Proc. of 35th ICALP, 857 (2008)] and [Backens et al., Quantum 5, 421 (2021)] for more details. +This module provides functions to compute and verify maximally-delayed generalized flow. +See :footcite:t:`Mhalla2008` and :footcite:t:`Backens2021` for details. """ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Mapping +from typing import TYPE_CHECKING from fastflow import _common -from fastflow._common import IndexMap, V -from fastflow._impl import gflow -from fastflow.common import GFlowResult, Plane +from fastflow._common import IndexMap +from fastflow._impl import gflow as gflow_bind +from fastflow.common import GFlowResult, Plane, V if TYPE_CHECKING: + from collections.abc import Mapping from collections.abc import Set as AbstractSet import networkx as nx @@ -27,44 +27,90 @@ def find( oset: AbstractSet[V], plane: Mapping[V, Plane] | None = None, ) -> GFlowResult[V] | None: - r"""Compute the maximally-delayed generalized flow, if any. 
+ r"""Compute generalized flow. + + If it returns a gflow, it is guaranteed to be maximally-delayed, i.e., the number of layers is minimized. Parameters ---------- - g : `nx.Graph[V]` - Undirected graph representing MBQC pattern. - Cannot have self-loops. - iset : `AbstractSet[V]` + g : `networkx.Graph` + Simple graph representing MBQC pattern. + iset : `collections.abc.Set` Input nodes. - Must be a subset of `g.nodes`. - oset : `AbstractSet[V]` + oset : `collections.abc.Set` Output nodes. - Must be a subset of `g.nodes`. - plane : `Mapping[V, Plane] | None`, optional - Measurement planes of each vertex in V\O. - If `None`, defaults to all `Plane.XY`. + plane : `collections.abc.Mapping` + Measurement plane for each node in :math:`V \setminus O`. + Defaults to `Plane.XY`. Returns ------- - If a gflow exists, return a `GFlowResult[V]` object. - Otherwise, return `None`. + `GFlowResult` or `None` + Return the gflow if any, otherwise `None`. """ _common.check_graph(g, iset, oset) vset = g.nodes if plane is None: plane = dict.fromkeys(vset - oset, Plane.XY) _common.check_planelike(vset, oset, plane) + ignore = plane.keys() & oset + if len(ignore) != 0: + msg = "Ignoring plane[v] where v in oset." + warnings.warn(msg, stacklevel=1) + plane = {k: v for k, v in plane.items() if k not in ignore} codec = IndexMap(vset) g_ = codec.encode_graph(g) iset_ = codec.encode_set(iset) oset_ = codec.encode_set(oset) plane_ = codec.encode_dictkey(plane) - if len(plane_) != len(plane): - msg = "Ignoring plane[v] where v in oset." 
- warnings.warn(msg, stacklevel=1) - if ret_ := gflow.find(g_, iset_, oset_, plane_): + if ret_ := gflow_bind.find(g_, iset_, oset_, plane_): f_, layer_ = ret_ f = codec.decode_gflow(f_) layer = codec.decode_layer(layer_) return GFlowResult(f, layer) return None + + +def verify( + gflow: GFlowResult[V], + g: nx.Graph[V], + iset: AbstractSet[V], + oset: AbstractSet[V], + plane: Mapping[V, Plane] | None = None, +) -> None: + r"""Verify maximally-delayed generalized flow. + + Parameters + ---------- + gflow : `GFlowResult` + Generalized flow to verify. + g : `networkx.Graph` + Simple graph representing MBQC pattern. + iset : `collections.abc.Set` + Input nodes. + oset : `collections.abc.Set` + Output nodes. + plane : `collections.abc.Mapping` + Measurement plane for each node in :math:`V \setminus O`. + Defaults to `Plane.XY`. + + Raises + ------ + ValueError + If the graph is invalid or verification fails. + """ + _common.check_graph(g, iset, oset) + vset = g.nodes + if plane is None: + plane = dict.fromkeys(vset - oset, Plane.XY) + codec = IndexMap(vset) + g_ = codec.encode_graph(g) + iset_ = codec.encode_set(iset) + oset_ = codec.encode_set(oset) + plane_ = codec.encode_dictkey(plane) + f_ = codec.encode_gflow(gflow.f) + layer_ = codec.encode_layer(gflow.layer) + try: + gflow_bind.verify((f_, layer_), g_, iset_, oset_, plane_) + except ValueError as e: + raise codec.decode_err(e) from None diff --git a/python/fastflow/pflow.py b/python/fastflow/pflow.py index 84d3146f..a9c3598c 100644 --- a/python/fastflow/pflow.py +++ b/python/fastflow/pflow.py @@ -1,16 +1,21 @@ -"""Maximally-delayed Pauli flow algorithm.""" +"""Maximally-delayed Pauli flow algorithm. + +This module provides functions to compute and verify maximally-delayed Pauli flow. +See :footcite:t:`Simons2021` for details. 
+""" from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Mapping +from typing import TYPE_CHECKING from fastflow import _common -from fastflow._common import IndexMap, V -from fastflow._impl import pflow -from fastflow.common import GFlowResult, PPlane +from fastflow._common import IndexMap +from fastflow._impl import pflow as pflow_bind +from fastflow.common import GFlowResult, PPlane, V if TYPE_CHECKING: + from collections.abc import Mapping from collections.abc import Set as AbstractSet import networkx as nx @@ -22,32 +27,30 @@ def find( oset: AbstractSet[V], pplane: Mapping[V, PPlane] | None = None, ) -> GFlowResult[V] | None: - r"""Compute the maximally-delayed Pauli flow, if any. + r"""Compute Pauli flow. + + If it returns a Pauli flow, it is guaranteed to be maximally-delayed, i.e., the number of layers is minimized. Parameters ---------- - g : `nx.Graph[V]` - Undirected graph representing MBQC pattern. - Cannot have self-loops. - iset : `AbstractSet[V]` + g : `networkx.Graph` + Simple graph representing MBQC pattern. + iset : `collections.abc.Set` Input nodes. - Must be a subset of `g.nodes`. - oset : `AbstractSet[V]` + oset : `collections.abc.Set` Output nodes. - Must be a subset of `g.nodes`. - pplane : `Mapping[V, PPlane] | None`, optional - Measurement planes or Pauli indices of each vertex in V\O. - If `None`, defaults to all `PPlane.XY`. + pplane : `collections.abc.Mapping` + Measurement plane or Pauli index for each node in :math:`V \setminus O`. + Defaults to `PPlane.XY`. Returns ------- - If a Pauli flow exists, return it as `GFlowResult[V]` object. - Otherwise, return `None`. + `GFlowResult` or `None` + Return the Pauli flow if any, otherwise `None`. Notes ----- - Do not pass `None` to `pplane`. - For that case, use `gflow.find` instead. + Use `gflow.find` whenever possible for better performance. 
""" _common.check_graph(g, iset, oset) vset = g.nodes @@ -57,17 +60,64 @@ def find( if all(pp not in {PPlane.X, PPlane.Y, PPlane.Z} for pp in pplane.values()): msg = "No Pauli measurement found. Use gflow.find instead." warnings.warn(msg, stacklevel=1) + ignore = pplane.keys() & oset + if len(ignore) != 0: + msg = "Ignoring pplane[v] where v in oset." + warnings.warn(msg, stacklevel=1) + pplane = {k: v for k, v in pplane.items() if k not in ignore} codec = IndexMap(vset) g_ = codec.encode_graph(g) iset_ = codec.encode_set(iset) oset_ = codec.encode_set(oset) pplane_ = codec.encode_dictkey(pplane) - if len(pplane_) != len(pplane): - msg = "Ignoring pplane[v] where v in oset." - warnings.warn(msg, stacklevel=1) - if ret_ := pflow.find(g_, iset_, oset_, pplane_): + if ret_ := pflow_bind.find(g_, iset_, oset_, pplane_): f_, layer_ = ret_ f = codec.decode_gflow(f_) layer = codec.decode_layer(layer_) return GFlowResult(f, layer) return None + + +def verify( + pflow: GFlowResult[V], + g: nx.Graph[V], + iset: AbstractSet[V], + oset: AbstractSet[V], + pplane: Mapping[V, PPlane] | None = None, +) -> None: + r"""Verify maximally-delayed Pauli flow. + + Parameters + ---------- + pflow : `GFlowResult` + Pauli flow to verify. + g : `networkx.Graph` + Simple graph representing MBQC pattern. + iset : `collections.abc.Set` + Input nodes. + oset : `collections.abc.Set` + Output nodes. + pplane : `collections.abc.Mapping` + Measurement plane or Pauli index for each node in :math:`V \setminus O`. + Defaults to `PPlane.XY`. + + Raises + ------ + ValueError + If the graph is invalid or verification fails. 
+ """ + _common.check_graph(g, iset, oset) + vset = g.nodes + if pplane is None: + pplane = dict.fromkeys(vset - oset, PPlane.XY) + codec = IndexMap(vset) + g_ = codec.encode_graph(g) + iset_ = codec.encode_set(iset) + oset_ = codec.encode_set(oset) + pplane_ = codec.encode_dictkey(pplane) + f_ = codec.encode_gflow(pflow.f) + layer_ = codec.encode_layer(pflow.layer) + try: + pflow_bind.verify((f_, layer_), g_, iset_, oset_, pplane_) + except ValueError as e: + raise codec.decode_err(e) from None diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 00000000..44b6aab5 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,2 @@ +group_imports = "StdExternalCrate" +imports_granularity = "Crate" diff --git a/src/common.rs b/src/common.rs index cd9b222a..e7713d02 100644 --- a/src/common.rs +++ b/src/common.rs @@ -2,6 +2,8 @@ use std::collections::BTreeSet; +use thiserror::Error; + /// Set of nodes indexed by 0-based integers. pub type Nodes = hashbrown::HashSet; /// Simple graph encoded as list of neighbors. @@ -15,3 +17,25 @@ pub type Layer = Vec; /// /// Used only when iteration order matters. pub(crate) type OrderedNodes = BTreeSet; + +/// Error type for flow validation. 
+#[derive(Debug, Error)] +pub enum FlowValidationError { + // MEMO: Need to match fastflow._common.MSG_RE + #[error("non-zero-layer node inside output nodes ({0})")] + ExcessiveNonZeroLayer(usize), + #[error("zero-layer node outside output nodes ({0})")] + ExcessiveZeroLayer(usize), + #[error("flow function has invalid codomain ({0})")] + InvalidFlowCodomain(usize), + #[error("flow function has invalid domain ({0})")] + InvalidFlowDomain(usize), + #[error("measurement specification is excessive or insufficient ({0})")] + InvalidMeasurementSpec(usize), + #[error("flow function and partial order are inconsistent ({0}, {1})")] + InconsistentFlowOrder(usize, usize), + #[error("flow function and measurement specification are inconsistent ({0})")] + InconsistentFlowPlane(usize), + #[error("flow function and measurement specification are inconsistent ({0})")] + InconsistentFlowPPlane(usize), +} diff --git a/src/flow.rs b/src/flow.rs index ae0a6b0c..e3b1536e 100644 --- a/src/flow.rs +++ b/src/flow.rs @@ -1,10 +1,10 @@ //! Maximally-delayed causal flow algorithm. 
use hashbrown; -use pyo3::prelude::*; +use pyo3::{exceptions::PyValueError, prelude::*}; use crate::{ - common::{Graph, Layer, Nodes}, + common::{FlowValidationError::InconsistentFlowOrder, Graph, Layer, Nodes}, internal::{utils::InPlaceSetDiff, validate}, }; @@ -18,19 +18,17 @@ type Flow = hashbrown::HashMap; fn check_definition(f: &Flow, layer: &Layer, g: &Graph) -> anyhow::Result<()> { for (&i, &fi) in f { if layer[i] <= layer[fi] { - let err = anyhow::anyhow!("layer check failed").context(format!("must be {i} -> {fi}")); + let err = anyhow::Error::from(InconsistentFlowOrder(i, fi)); return Err(err); } for &j in &g[fi] { if i != j && layer[i] <= layer[j] { - let err = anyhow::anyhow!("layer check failed") - .context(format!("neither {i} == {j} nor {i} -> {j}")); + let err = anyhow::Error::from(InconsistentFlowOrder(i, j)); return Err(err); } } if !(g[fi].contains(&i) && g[i].contains(&fi)) { - let err = anyhow::anyhow!("graph check failed") - .context(format!("{i} and {fi} not connected")); + let err = anyhow::Error::from(InconsistentFlowOrder(i, fi)); return Err(err); } } @@ -57,7 +55,6 @@ fn check_definition(f: &Flow, layer: &Layer, g: &Graph) -> anyhow::Result<()> { #[tracing::instrument] #[allow(clippy::needless_pass_by_value, clippy::must_use_candidate)] pub fn find(g: Graph, iset: Nodes, mut oset: Nodes) -> Option<(Flow, Layer)> { - validate::check_graph(&g, &iset, &oset).unwrap(); let n = g.len(); let vset = (0..n).collect::(); let mut cset = &oset - &iset; @@ -117,6 +114,30 @@ pub fn find(g: Graph, iset: Nodes, mut oset: Nodes) -> Option<(Flow, Layer)> { } } +/// Validates flow. +/// +/// # Errors +/// +/// - If `flow` is invalid. +/// - If `flow` is inconsistent with `g`. 
+#[pyfunction] +#[allow(clippy::needless_pass_by_value)] +pub fn verify(flow: (Flow, Layer), g: Graph, iset: Nodes, oset: Nodes) -> PyResult<()> { + let (f, layer) = flow; + let n = g.len(); + let vset = (0..n).collect::(); + if let Err(e) = validate::check_domain(f.iter(), &vset, &iset, &oset) { + return Err(PyValueError::new_err(e.to_string())); + } + if let Err(e) = validate::check_initial(&layer, &oset, true) { + return Err(PyValueError::new_err(e.to_string())); + } + if let Err(e) = check_definition(&f, &layer, &g) { + return Err(PyValueError::new_err(e.to_string())); + } + Ok(()) +} + #[cfg(test)] mod tests { use test_log; @@ -128,35 +149,38 @@ mod tests { fn test_find_case0() { let TestCase { g, iset, oset } = test_utils::CASE0.clone(); let flen = g.len() - oset.len(); - let (f, layer) = find(g, iset, oset).unwrap(); + let (f, layer) = find(g.clone(), iset.clone(), oset.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(layer, vec![0, 0]); + verify((f, layer), g, iset, oset).unwrap(); } #[test_log::test] fn test_find_case1() { let TestCase { g, iset, oset } = test_utils::CASE1.clone(); let flen = g.len() - oset.len(); - let (f, layer) = find(g, iset, oset).unwrap(); + let (f, layer) = find(g.clone(), iset.clone(), oset.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(f[&0], 1); assert_eq!(f[&1], 2); assert_eq!(f[&2], 3); assert_eq!(f[&3], 4); assert_eq!(layer, vec![4, 3, 2, 1, 0]); + verify((f, layer), g, iset, oset).unwrap(); } #[test_log::test] fn test_find_case2() { let TestCase { g, iset, oset } = test_utils::CASE2.clone(); let flen = g.len() - oset.len(); - let (f, layer) = find(g, iset, oset).unwrap(); + let (f, layer) = find(g.clone(), iset.clone(), oset.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(f[&0], 2); assert_eq!(f[&1], 3); assert_eq!(f[&2], 4); assert_eq!(f[&3], 5); assert_eq!(layer, vec![2, 2, 1, 1, 0, 0]); + verify((f, layer), g, iset, oset).unwrap(); } #[test_log::test] diff --git a/src/gflow.rs b/src/gflow.rs 
index feb2ab2a..f6ab100d 100644 --- a/src/gflow.rs +++ b/src/gflow.rs @@ -4,10 +4,15 @@ use std::iter; use fixedbitset::FixedBitSet; use hashbrown; -use pyo3::prelude::*; +use pyo3::{exceptions::PyValueError, prelude::*}; use crate::{ - common::{Graph, Layer, Nodes, OrderedNodes}, + common::{ + FlowValidationError::{ + InconsistentFlowOrder, InconsistentFlowPlane, InvalidMeasurementSpec, + }, + Graph, Layer, Nodes, OrderedNodes, + }, internal::{ gf2_linalg::GF2Solver, utils::{self, InPlaceSetDiff}, @@ -17,13 +22,10 @@ use crate::{ #[pyclass(eq, hash, frozen)] #[derive(PartialEq, Eq, Hash, Clone, Copy, Debug)] -/// Measurement plane. +/// Enum-like class for measurement planes. pub enum Plane { - /// Measurement on the XY plane. XY, - /// Measurement on the YZ plane. YZ, - /// Measurement on the XZ plane. XZ, } @@ -38,46 +40,39 @@ type GFlow = hashbrown::HashMap; /// 4. i in g(i) and in Odd(g(i)) if plane(i) == YZ /// 5. i in g(i) and not in Odd(g(i)) if plane(i) == XZ fn check_definition(f: &GFlow, layer: &Layer, g: &Graph, planes: &Planes) -> anyhow::Result<()> { - anyhow::ensure!( - f.len() == planes.len(), - "f and planes must have the same codomain" - ); + for &i in itertools::chain(f.keys(), planes.keys()) { + if f.contains_key(&i) != planes.contains_key(&i) { + let err = anyhow::Error::from(InvalidMeasurementSpec(i)); + return Err(err); + } + } for (&i, fi) in f { let pi = planes[&i]; for &fij in fi { if i != fij && layer[i] <= layer[fij] { - let err = anyhow::anyhow!("layer check failed") - .context(format!("neither {i} == {fij} nor {i} -> {fij}: fi")); + let err = anyhow::Error::from(InconsistentFlowOrder(i, fij)); return Err(err); } } let odd_fi = utils::odd_neighbors(g, fi); for &j in &odd_fi { if i != j && layer[i] <= layer[j] { - let err = anyhow::anyhow!("layer check failed").context(format!( - "neither {i} == {j} nor {i} -> {j}: odd_neighbors(g, fi)" - )); + let err = anyhow::Error::from(InconsistentFlowOrder(i, j)); return Err(err); } } let in_info = 
(fi.contains(&i), odd_fi.contains(&i)); match pi { Plane::XY if in_info != (false, true) => { - let err = anyhow::anyhow!("plane check failed").context(format!( - "must satisfy ({i} in f({i}), {i} in Odd(f({i})) = (false, true): XY" - )); + let err = anyhow::Error::from(InconsistentFlowPlane(i)); return Err(err); } Plane::YZ if in_info != (true, false) => { - let err = anyhow::anyhow!("plane check failed").context(format!( - "must satisfy ({i} in f({i}), {i} in Odd(f({i})) = (true, false): YZ" - )); + let err = anyhow::Error::from(InconsistentFlowPlane(i)); return Err(err); } Plane::XZ if in_info != (true, true) => { - let err = anyhow::anyhow!("plane check failed").context(format!( - "must satisfy ({i} in f({i}), {i} in Odd(f({i})) = (true, true): XZ" - )); + let err = anyhow::Error::from(InconsistentFlowPlane(i)); return Err(err); } _ => {} @@ -148,7 +143,6 @@ fn init_work( #[tracing::instrument] #[allow(clippy::needless_pass_by_value, clippy::must_use_candidate)] pub fn find(g: Graph, iset: Nodes, oset: Nodes, planes: Planes) -> Option<(GFlow, Layer)> { - validate::check_graph(&g, &iset, &oset).unwrap(); let n = g.len(); let vset = (0..n).collect::(); let mut cset = Nodes::new(); @@ -234,6 +228,39 @@ pub fn find(g: Graph, iset: Nodes, oset: Nodes, planes: Planes) -> Option<(GFlow } } +/// Validates generalized flow. +/// +/// # Errors +/// +/// - If `gflow` is invalid. +/// - If `gflow` is inconsistent with `g`. 
+#[pyfunction] +#[allow(clippy::needless_pass_by_value)] +pub fn verify( + gflow: (GFlow, Layer), + g: Graph, + iset: Nodes, + oset: Nodes, + planes: Planes, +) -> PyResult<()> { + let (f, layer) = gflow; + let n = g.len(); + let vset = (0..n).collect::(); + let f_flatiter = f + .iter() + .flat_map(|(i, fi)| Iterator::zip(iter::repeat(i), fi.iter())); + if let Err(e) = validate::check_domain(f_flatiter, &vset, &iset, &oset) { + return Err(PyValueError::new_err(e.to_string())); + } + if let Err(e) = validate::check_initial(&layer, &oset, true) { + return Err(PyValueError::new_err(e.to_string())); + } + if let Err(e) = check_definition(&f, &layer, &g, &planes) { + return Err(PyValueError::new_err(e.to_string())); + } + Ok(()) +} + #[cfg(test)] mod tests { use test_log; @@ -246,9 +273,10 @@ mod tests { let TestCase { g, iset, oset } = test_utils::CASE0.clone(); let planes = measurements! {}; let flen = g.len() - oset.len(); - let (f, layer) = find(g, iset, oset, planes).unwrap(); + let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), planes.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(layer, vec![0, 0]); + verify((f, layer), g, iset, oset, planes).unwrap(); } #[test_log::test] @@ -261,13 +289,14 @@ mod tests { 3: Plane::XY }; let flen = g.len() - oset.len(); - let (f, layer) = find(g, iset, oset, planes).unwrap(); + let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), planes.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(f[&0], Nodes::from([1])); assert_eq!(f[&1], Nodes::from([2])); assert_eq!(f[&2], Nodes::from([3])); assert_eq!(f[&3], Nodes::from([4])); assert_eq!(layer, vec![4, 3, 2, 1, 0]); + verify((f, layer), g, iset, oset, planes).unwrap(); } #[test_log::test] @@ -280,13 +309,14 @@ mod tests { 3: Plane::XY }; let flen = g.len() - oset.len(); - let (f, layer) = find(g, iset, oset, planes).unwrap(); + let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), planes.clone()).unwrap(); assert_eq!(f.len(), flen); 
assert_eq!(f[&0], Nodes::from([2])); assert_eq!(f[&1], Nodes::from([3])); assert_eq!(f[&2], Nodes::from([4])); assert_eq!(f[&3], Nodes::from([5])); assert_eq!(layer, vec![2, 2, 1, 1, 0, 0]); + verify((f, layer), g, iset, oset, planes).unwrap(); } #[test_log::test] @@ -298,12 +328,13 @@ mod tests { 2: Plane::XY }; let flen = g.len() - oset.len(); - let (f, layer) = find(g, iset, oset, planes).unwrap(); + let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), planes.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(f[&0], Nodes::from([4, 5])); assert_eq!(f[&1], Nodes::from([3, 4, 5])); assert_eq!(f[&2], Nodes::from([3, 5])); assert_eq!(layer, vec![1, 1, 1, 0, 0, 0]); + verify((f, layer), g, iset, oset, planes).unwrap(); } #[test_log::test] @@ -316,13 +347,14 @@ mod tests { 3: Plane::YZ }; let flen = g.len() - oset.len(); - let (f, layer) = find(g, iset, oset, planes).unwrap(); + let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), planes.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(f[&0], Nodes::from([2])); assert_eq!(f[&1], Nodes::from([5])); assert_eq!(f[&2], Nodes::from([2, 4])); assert_eq!(f[&3], Nodes::from([3])); assert_eq!(layer, vec![2, 2, 1, 1, 0, 0]); + verify((f, layer), g, iset, oset, planes).unwrap(); } #[test_log::test] diff --git a/src/internal/gf2_linalg.rs b/src/internal/gf2_linalg.rs index a32ead4b..c5f20811 100644 --- a/src/internal/gf2_linalg.rs +++ b/src/internal/gf2_linalg.rs @@ -52,7 +52,11 @@ impl<'a> GF2Solver<'a> { /// /// # Panics /// - /// - If similar conditions to `try_new_from` are not met. + /// - If `neqs` is zero. + /// - If `work` is empty (no rows). + /// - If `work` is jagged, i.e., the number of columns is not uniform. + /// - If `work[...]` is empty (no columns). + /// - If `neqs` is so large that there is no room for the coefficient matrix. 
pub fn attach(work: &'a mut GF2Matrix, neqs: usize) -> Self { if let Err(e) = Self::attach_check(work, neqs) { panic!("invalid argument detected: {e}"); @@ -314,6 +318,9 @@ mod tests { assert_eq!(format!("{:}", sol.work[0]), "1000111"); assert_eq!(format!("{:}", sol.work[1]), "0100011"); assert_eq!(format!("{:}", sol.work[2]), "0010001"); + // Call Debug + let ex = format!("{sol:?}"); + assert!(!ex.is_empty()); } /// Helper function to create a solver storage from the coefficient matrix and the right-hand side. @@ -382,17 +389,119 @@ mod tests { rhs } + #[test] + #[should_panic = "neqs is zero"] + fn test_attach_noeq() { + GF2Solver::attach(&mut [], 0); + } + + #[test] + #[should_panic = "work is empty"] + fn test_attach_empty_rows() { + GF2Solver::attach(&mut [], 1); + } + + #[test] + #[should_panic = "zero-length columns"] + fn test_attach_empty_cols() { + let mut work = vec![FixedBitSet::with_capacity(0); 3]; + GF2Solver::attach(&mut work, 1); + } + + #[test] + #[should_panic = "work is jagged"] + fn test_attach_empty_jagged() { + let mut work = vec![FixedBitSet::with_capacity(3), FixedBitSet::with_capacity(4)]; + GF2Solver::attach(&mut work, 1); + } + + #[test] + #[should_panic = "neqs too large"] + fn test_attach_neqs_large() { + let mut work = vec![FixedBitSet::with_capacity(3); 3]; + GF2Solver::attach(&mut work, 4); + } + + #[test] + #[should_panic = "output size mismatch:"] + fn test_solve_invalid_size() { + let mut work = vec![ + // 1000111 + FixedBitSet::with_capacity_and_blocks(7, vec![0b111_0001]), + // 0100011 + FixedBitSet::with_capacity_and_blocks(7, vec![0b110_0010]), + // 0010001 + FixedBitSet::with_capacity_and_blocks(7, vec![0b100_0100]), + ]; + let mut sol = GF2Solver::attach(&mut work, 3); + let mut out = FixedBitSet::with_capacity(5); + sol.solve_in_place(&mut out, 0); + } + + #[test] + #[should_panic = "equation index out of range:"] + fn test_solve_invalid_index() { + let mut work = vec![ + // 1000111 + 
FixedBitSet::with_capacity_and_blocks(7, vec![0b111_0001]), + // 0100011 + FixedBitSet::with_capacity_and_blocks(7, vec![0b110_0010]), + // 0010001 + FixedBitSet::with_capacity_and_blocks(7, vec![0b100_0100]), + ]; + let mut sol = GF2Solver::attach(&mut work, 3); + let mut out = FixedBitSet::with_capacity(4); + sol.solve_in_place(&mut out, 9); + } + + #[cfg(not(miri))] const REP: usize = 1000; + #[cfg(miri)] + const REP: usize = 1; + #[cfg(not(miri))] #[template] #[rstest] fn template_tests( - #[values(1, 2, 7, 12, 23, 36)] rows: usize, - #[values(1, 2, 7, 12, 23, 36)] cols: usize, + #[values(1, 2, 7, 12, 23)] rows: usize, + #[values(1, 2, 7, 12, 23)] cols: usize, #[values(1, 2, 7, 12)] neqs: usize, ) { } + #[cfg(miri)] + #[template] + #[rstest] + fn template_tests( + #[values(1, 2, 7)] rows: usize, + #[values(1, 2, 7)] cols: usize, + #[values(1, 2)] neqs: usize, + ) { + } + + #[test] + fn test_solve_simple() { + let mut work = vec![ + // 1001 + FixedBitSet::with_capacity_and_blocks(4, vec![0b1001]), + // 0101 + FixedBitSet::with_capacity_and_blocks(4, vec![0b1010]), + // 0000 + FixedBitSet::with_capacity_and_blocks(4, vec![0b0000]), + ]; + let mut sol = GF2Solver::attach(&mut work, 1); + let mut x = FixedBitSet::with_capacity(3); + assert_eq!(sol.rank, None); + assert!(sol.solve_in_place(&mut x, 0)); + assert_eq!(sol.rank, Some(2)); + let x_orig = x.clone(); + assert!(sol.solve_in_place(&mut x, 0)); + assert_eq!(x, x_orig); + assert!(x[0]); + assert!(x[1]); + assert!(!x[2]); + } + #[apply(template_tests)] fn test_eliminate_lower_random(rows: usize, cols: usize, neqs: usize) { let mut rng = thread_rng(); @@ -474,9 +583,9 @@ mod tests { assert!(sol.rank.unwrap() < sol.rows); continue; } - for i in sol.rank.unwrap()..sol.rows { - assert!(sol.work[i].count_ones(..sol.cols) == 0); - assert!(!sol.work[i][cols + ieq]); + for row in &sol.work[sol.rank.unwrap()..sol.rows] { + assert!(row.count_ones(..sol.cols) == 0); + assert!(!row[cols + ieq]); } let b = compute_lhs(&co, 
&x); assert_eq!(&b, rhsi); @@ -505,9 +614,9 @@ mod tests { assert!(sol.rank.unwrap() < sol.rows); continue; } - for i in sol.rank.unwrap()..sol.rows { - assert!(sol.work[i].count_ones(..sol.cols) == 0); - assert!(!sol.work[i][cols + ieq]); + for row in &sol.work[sol.rank.unwrap()..sol.rows] { + assert!(row.count_ones(..sol.cols) == 0); + assert!(!row[cols + ieq]); } let b = compute_lhs(&co, &x); assert_eq!(&b, rhsi); diff --git a/src/internal/test_utils.rs b/src/internal/test_utils.rs index 03173248..6ef1cb88 100644 --- a/src/internal/test_utils.rs +++ b/src/internal/test_utils.rs @@ -146,3 +146,62 @@ pub static CASE8: LazyLock = LazyLock::new(|| { oset: Nodes::from([3, 4]), } }); + +#[cfg(test)] +mod tests { + use rstest::rstest; + use rstest_reuse::{apply, template}; + + use super::*; + + #[template] + #[rstest] + fn template_tests( + #[values(&*CASE0, &*CASE1, &*CASE2, &*CASE3, &*CASE4, &*CASE5, &*CASE6, &*CASE7, &*CASE8)] + input: &TestCase, + ) { + } + + /// Checks if the graph is valid. + /// + /// In production code, this check should be done in the Python layer.
+ fn check_graph(g: &Graph, iset: &Nodes, oset: &Nodes) -> anyhow::Result<()> { + let n = g.len(); + if n == 0 { + anyhow::bail!("empty graph"); + } + for (u, gu) in g.iter().enumerate() { + if gu.contains(&u) { + anyhow::bail!("self-loop detected: {u}"); + } + gu.iter().try_for_each(|&v| { + if v >= n { + anyhow::bail!("node index out of range: {v}"); + } + if !g[v].contains(&u) { + anyhow::bail!("g must be undirected: needs {v} -> {u}"); + } + Ok(()) + })?; + } + iset.iter().try_for_each(|&u| { + if !(0..n).contains(&u) { + anyhow::bail!("unknown node in iset: {u}"); + } + Ok(()) + })?; + oset.iter().try_for_each(|&u| { + if !(0..n).contains(&u) { + anyhow::bail!("unknown node in oset: {u}"); + } + Ok(()) + })?; + Ok(()) + } + + #[apply(template_tests)] + fn test_input(input: &TestCase) { + let TestCase { g, iset, oset } = input; + check_graph(g, iset, oset).unwrap(); + } +} diff --git a/src/internal/utils.rs b/src/internal/utils.rs index 52fced91..32dd3a47 100644 --- a/src/internal/utils.rs +++ b/src/internal/utils.rs @@ -10,7 +10,7 @@ use fixedbitset::FixedBitSet; use crate::common::{Graph, Nodes, OrderedNodes}; -/// Computes the odd neighbors of the vertices in `kset`. +/// Computes the odd neighbors of the nodes in `kset`. 
/// /// # Note /// @@ -150,6 +150,8 @@ impl Drop for ScopedExclude<'_> { #[cfg(test)] mod tests { + use std::collections::BTreeMap; + use super::*; use crate::internal::test_utils::{TestCase, CASE3}; @@ -178,4 +180,63 @@ mod tests { Nodes::from([1, 5]) ); } + + #[test] + fn test_zerofill() { + let mut mat = vec![FixedBitSet::new(), FixedBitSet::new(), FixedBitSet::new()]; + zerofill(&mut mat, 10); + for row in &mat { + assert_eq!(row.len(), 10); + assert!(row.is_clear()); + } + } + + #[test] + fn test_difference_with_hashset() { + let mut set = hashbrown::HashSet::from([1, 2, 3]); + set.difference_with(&[2, 3, 4]); + assert_eq!(set, hashbrown::HashSet::from([1])); + } + + #[test] + fn test_difference_with_btreeset() { + let mut set = BTreeSet::from([1, 2, 3]); + set.difference_with(&[2, 3, 4]); + assert_eq!(set, BTreeSet::from([1])); + } + + #[test] + fn test_indexmap() { + let set = OrderedNodes::from([8, 1, 0]); + let imap = indexmap::>(&set); + assert_eq!(imap[&0], 0); + assert_eq!(imap[&1], 1); + assert_eq!(imap[&8], 2); + } + + #[test] + fn test_scoped_include() { + let mut set = OrderedNodes::new(); + { + let mut guard = ScopedInclude::new(&mut set, 0); + // Mutable borrow + guard.insert(1); + // Immutable borrow + assert_eq!(*guard, OrderedNodes::from([0, 1])); + } + assert_eq!(set, OrderedNodes::from([1])); + } + + #[test] + fn test_scoped_exclude() { + let mut set = OrderedNodes::from([0]); + { + let mut guard = ScopedExclude::new(&mut set, 0); + // Mutable borrow + guard.insert(1); + // Immutable borrow + assert_eq!(*guard, OrderedNodes::from([1])); + } + assert_eq!(set, OrderedNodes::from([0, 1])); + } } diff --git a/src/internal/validate.rs b/src/internal/validate.rs index 463a7398..0b38894c 100644 --- a/src/internal/validate.rs +++ b/src/internal/validate.rs @@ -4,52 +4,12 @@ //! //! - Internal module for testing. -use crate::common::{Graph, Layer, Nodes}; - -/// Checks if the graph is valid. 
-/// -/// # Returns -/// -/// Returns `Err` if any of the following conditions are met: -/// -/// - `g` is empty. -/// - `g` contains self-loops. -/// - `g` is not symmetric. -/// - `g` contains nodes other than `0..g.len()`. -/// - `iset`/`oset` contains inconsistent nodes. -pub fn check_graph(g: &Graph, iset: &Nodes, oset: &Nodes) -> anyhow::Result<()> { - let n = g.len(); - if n == 0 { - anyhow::bail!("empty graph"); - } - for (u, gu) in g.iter().enumerate() { - if gu.contains(&u) { - anyhow::bail!("self-loop detected: {u}"); - } - gu.iter().try_for_each(|&v| { - if v >= n { - anyhow::bail!("node index out of range: {v}"); - } - if !g[v].contains(&u) { - anyhow::bail!("g must be undirected: needs {v} -> {u}"); - } - Ok(()) - })?; - } - iset.iter().try_for_each(|&u| { - if !(0..n).contains(&u) { - anyhow::bail!("unknown node in iset: {u}"); - } - Ok(()) - })?; - oset.iter().try_for_each(|&u| { - if !(0..n).contains(&u) { - anyhow::bail!("unknown node in oset: {u}"); - } - Ok(()) - })?; - Ok(()) -} +use crate::common::{ + FlowValidationError::{ + ExcessiveNonZeroLayer, ExcessiveZeroLayer, InvalidFlowCodomain, InvalidFlowDomain, + }, + Layer, Nodes, +}; /// Checks if the layer-zero nodes are correctly chosen. 
/// @@ -62,13 +22,11 @@ pub fn check_initial(layer: &Layer, oset: &Nodes, iff: bool) -> anyhow::Result<( for (u, &lu) in layer.iter().enumerate() { match (oset.contains(&u), lu == 0) { (true, false) => { - let err = anyhow::anyhow!("initial check failed") - .context(format!("layer({u}) != 0 && {u} in O")); + let err = anyhow::Error::from(ExcessiveNonZeroLayer(u)); return Err(err); } (false, true) if iff => { - let err = anyhow::anyhow!("initial check failed") - .context(format!("layer({u}) == 0 && {u} not in O")); + let err = anyhow::Error::from(ExcessiveZeroLayer(u)); return Err(err); } _ => {} @@ -101,14 +59,104 @@ pub fn check_domain<'a, 'b>( for (&i, &fi) in f_flatiter { dom.insert(i); if i != fi && !icset.contains(&fi) { - let err = anyhow::anyhow!("domain check failed").context(format!("{fi} not in V\\I")); + let err = anyhow::Error::from(InvalidFlowCodomain(i)); return Err(err); } } - if dom != ocset { - let err = anyhow::anyhow!("domain check failed") - .context(format!("invalid domain: {dom:?} != V\\O")); + if let Some(&i) = dom.symmetric_difference(&ocset).next() { + let err = anyhow::Error::from(InvalidFlowDomain(i)); return Err(err); } Ok(()) } + +#[cfg(test)] +mod tests { + use std::iter; + + use super::*; + use crate::common::Nodes; + + #[test] + fn test_check_initial() { + let layer = vec![0, 0, 0, 1, 1, 1]; + let oset = Nodes::from([0, 1]); + check_initial(&layer, &oset, false).unwrap(); + } + + #[test] + fn test_check_initial_ng() { + let layer = vec![0, 0, 0, 1, 1, 1]; + let oset = Nodes::from([0, 1, 2, 3]); + assert!(check_initial(&layer, &oset, false).is_err()); + } + + #[test] + fn test_check_initial_iff() { + let layer = vec![0, 0, 0, 1, 1, 1]; + let oset = Nodes::from([0, 1, 2]); + check_initial(&layer, &oset, true).unwrap(); + } + + #[test] + fn test_check_initial_iff_ng() { + let layer = vec![0, 0, 0, 1, 1, 1]; + let oset = Nodes::from([0, 1]); + assert!(check_initial(&layer, &oset, true).is_err()); + } + + #[test] + fn 
test_check_domain_flow() { + let f = hashbrown::HashMap::::from([(0, 1), (1, 2)]); + let vset = Nodes::from([0, 1, 2]); + let iset = Nodes::from([0]); + let oset = Nodes::from([2]); + check_domain(f.iter(), &vset, &iset, &oset).unwrap(); + } + + #[test] + fn test_check_domain_gflow() { + let f = hashbrown::HashMap::::from([ + // OK: 0 in f(0) + (0, Nodes::from([0, 1])), + (1, Nodes::from([2])), + ]); + let vset = Nodes::from([0, 1, 2]); + let iset = Nodes::from([0]); + let oset = Nodes::from([2]); + let f_flatiter = f + .iter() + .flat_map(|(i, fi)| Iterator::zip(iter::repeat(i), fi.iter())); + check_domain(f_flatiter, &vset, &iset, &oset).unwrap(); + } + + #[test] + fn test_check_domain_ng_iset() { + let f = hashbrown::HashMap::::from([ + (0, Nodes::from([0, 1])), + (2, Nodes::from([2])), + ]); + let vset = Nodes::from([0, 1, 2]); + let iset = Nodes::from([0]); + let oset = Nodes::from([2]); + let f_flatiter = f + .iter() + .flat_map(|(i, fi)| Iterator::zip(iter::repeat(i), fi.iter())); + assert!(check_domain(f_flatiter, &vset, &iset, &oset).is_err()); + } + + #[test] + fn test_check_domain_ng_oset() { + let f = hashbrown::HashMap::::from([ + (0, Nodes::from([1])), + (1, Nodes::from([0])), + ]); + let vset = Nodes::from([0, 1, 2]); + let iset = Nodes::from([0]); + let oset = Nodes::from([2]); + let f_flatiter = f + .iter() + .flat_map(|(i, fi)| Iterator::zip(iter::repeat(i), fi.iter())); + assert!(check_domain(f_flatiter, &vset, &iset, &oset).is_err()); + } +} diff --git a/src/lib.rs b/src/lib.rs index 37e5a7ae..d2d20427 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,16 +25,19 @@ fn entrypoint(m: &Bound<'_, PyModule>) -> PyResult<()> { // fastflow._impl.flow let mod_flow = PyModule::new_bound(m.py(), "flow")?; mod_flow.add_function(wrap_pyfunction!(flow::find, &mod_flow)?)?; + mod_flow.add_function(wrap_pyfunction!(flow::verify, &mod_flow)?)?; m.add_submodule(&mod_flow)?; // fastflow._impl.gflow let mod_gflow = PyModule::new_bound(m.py(), "gflow")?; 
mod_gflow.add_class::()?; mod_gflow.add_function(wrap_pyfunction!(gflow::find, &mod_gflow)?)?; + mod_gflow.add_function(wrap_pyfunction!(gflow::verify, &mod_gflow)?)?; m.add_submodule(&mod_gflow)?; // fastflow._impl.pflow let mod_pflow = PyModule::new_bound(m.py(), "pflow")?; mod_pflow.add_class::()?; mod_pflow.add_function(wrap_pyfunction!(pflow::find, &mod_pflow)?)?; + mod_pflow.add_function(wrap_pyfunction!(pflow::verify, &mod_pflow)?)?; m.add_submodule(&mod_pflow)?; Ok(()) } diff --git a/src/pflow.rs b/src/pflow.rs index 9f2d6e52..bbbde9ab 100644 --- a/src/pflow.rs +++ b/src/pflow.rs @@ -4,10 +4,15 @@ use std::iter; use fixedbitset::FixedBitSet; use hashbrown; -use pyo3::prelude::*; +use pyo3::{exceptions::PyValueError, prelude::*}; use crate::{ - common::{Graph, Layer, Nodes, OrderedNodes}, + common::{ + FlowValidationError::{ + InconsistentFlowOrder, InconsistentFlowPlane, InvalidMeasurementSpec, + }, + Graph, Layer, Nodes, OrderedNodes, + }, internal::{ gf2_linalg::GF2Solver, utils::{self, InPlaceSetDiff, ScopedExclude, ScopedInclude}, @@ -17,19 +22,13 @@ use crate::{ #[pyclass(eq, hash, frozen)] #[derive(PartialEq, Eq, Hash, Clone, Copy, Debug)] -/// Measurement plane or Pauli index. +/// Enum-like class for measurement planes or Pauli measurements. pub enum PPlane { - /// Arbitrary measurement on the XY plane. XY, - /// Arbitrary measurement on the YZ plane. YZ, - /// Arbitrary measurement on the XZ plane. XZ, - /// Pauli X measurement. X, - /// Pauli Y measurement. Y, - /// Pauli Z measurement. Z, } @@ -38,17 +37,18 @@ type PFlow = hashbrown::HashMap; /// Checks the definition of Pauli flow. 
fn check_definition(f: &PFlow, layer: &Layer, g: &Graph, pplanes: &PPlanes) -> anyhow::Result<()> { - anyhow::ensure!( - f.len() == pplanes.len(), - "f and pplanes must have the same codomain" - ); + for &i in itertools::chain(f.keys(), pplanes.keys()) { + if f.contains_key(&i) != pplanes.contains_key(&i) { + let err = anyhow::Error::from(InvalidMeasurementSpec(i)); + return Err(err); + } + } for (&i, fi) in f { let pi = pplanes[&i]; for &fij in fi { match (i != fij, layer[i] <= layer[fij]) { (true, true) if !matches!(pplanes[&fij], PPlane::X | PPlane::Y) => { - let err = anyhow::anyhow!("layer check failed") - .context(format!("neither {i} == {fij} nor {i} -> {fij}: fi")); + let err = anyhow::Error::from(InconsistentFlowOrder(i, fij)); return Err(err); } (false, false) => unreachable!("layer[i] == layer[i]"), @@ -59,9 +59,7 @@ fn check_definition(f: &PFlow, layer: &Layer, g: &Graph, pplanes: &PPlanes) -> a for &j in &odd_fi { match (i != j, layer[i] <= layer[j]) { (true, true) if !matches!(pplanes[&j], PPlane::Y | PPlane::Z) => { - let err = anyhow::anyhow!("layer check failed").context(format!( - "neither {i} == {j} nor {i} -> {j}: odd_neighbors(g, fi)" - )); + let err = anyhow::Error::from(InconsistentFlowOrder(i, j)); return Err(err); } (false, false) => unreachable!("layer[i] == layer[i]"), @@ -70,45 +68,34 @@ fn check_definition(f: &PFlow, layer: &Layer, g: &Graph, pplanes: &PPlanes) -> a } for &j in fi.symmetric_difference(&odd_fi) { if pplanes.get(&j) == Some(&PPlane::Y) && i != j && layer[i] <= layer[j] { - let err = anyhow::anyhow!("Y correction check failed") - .context(format!("{j} must be corrected by f({i}) xor Odd(f({i}))")); + let err = anyhow::Error::from(InconsistentFlowPlane(j)); return Err(err); } } let in_info = (fi.contains(&i), odd_fi.contains(&i)); match pi { PPlane::XY if in_info != (false, true) => { - let err = anyhow::anyhow!("pplane check failed").context(format!( - "must satisfy ({i} in f({i}), {i} in Odd(f({i})) = (false, true): XY" - 
)); + let err = anyhow::Error::from(InconsistentFlowPlane(i)); return Err(err); } PPlane::YZ if in_info != (true, false) => { - let err = anyhow::anyhow!("pplane check failed").context(format!( - "must satisfy ({i} in f({i}), {i} in Odd(f({i})) = (true, false): YZ" - )); + let err = anyhow::Error::from(InconsistentFlowPlane(i)); return Err(err); } PPlane::XZ if in_info != (true, true) => { - let err = anyhow::anyhow!("pplane check failed").context(format!( - "must satisfy ({i} in f({i}), {i} in Odd(f({i})) = (true, true): XZ" - )); + let err = anyhow::Error::from(InconsistentFlowPlane(i)); return Err(err); } PPlane::X if !in_info.1 => { - let err = anyhow::anyhow!("pplane check failed") - .context(format!("{i} must be in Odd(f({i})): X")); + let err = anyhow::Error::from(InconsistentFlowPlane(i)); return Err(err); } PPlane::Y if !(in_info.0 ^ in_info.1) => { - let err = anyhow::anyhow!("pplane check failed").context(format!( - "{i} must be in either f({i}) or Odd(f({i})), not both: Y" - )); + let err = anyhow::Error::from(InconsistentFlowPlane(i)); return Err(err); } PPlane::Z if !in_info.0 => { - let err = anyhow::anyhow!("pplane check failed") - .context(format!("{i} must be in f({i}): Z")); + let err = anyhow::Error::from(InconsistentFlowPlane(i)); return Err(err); } _ => {} @@ -321,7 +308,6 @@ fn find_impl(ctx: &mut PFlowContext) -> bool { #[tracing::instrument] #[allow(clippy::needless_pass_by_value, clippy::must_use_candidate)] pub fn find(g: Graph, iset: Nodes, oset: Nodes, pplanes: PPlanes) -> Option<(PFlow, Layer)> { - validate::check_graph(&g, &iset, &oset).unwrap(); let yset = matching_nodes(&pplanes, |pp| matches!(pp, PPlane::Y)); let xyset = matching_nodes(&pplanes, |pp| matches!(pp, PPlane::X | PPlane::Y)); let yzset = matching_nodes(&pplanes, |pp| matches!(pp, PPlane::Y | PPlane::Z)); @@ -422,6 +408,39 @@ pub fn find(g: Graph, iset: Nodes, oset: Nodes, pplanes: PPlanes) -> Option<(PFl } } +/// Validates Pauli flow. 
+/// +/// # Errors +/// +/// - If `pflow` is invalid. +/// - If `pflow` is inconsistent with `g`. +#[pyfunction] +#[allow(clippy::needless_pass_by_value)] +pub fn verify( + pflow: (PFlow, Layer), + g: Graph, + iset: Nodes, + oset: Nodes, + pplanes: PPlanes, +) -> PyResult<()> { + let (f, layer) = pflow; + let n = g.len(); + let vset = (0..n).collect::(); + let f_flatiter = f + .iter() + .flat_map(|(i, fi)| Iterator::zip(iter::repeat(i), fi.iter())); + if let Err(e) = validate::check_domain(f_flatiter, &vset, &iset, &oset) { + return Err(PyValueError::new_err(e.to_string())); + } + if let Err(e) = validate::check_initial(&layer, &oset, false) { + return Err(PyValueError::new_err(e.to_string())); + } + if let Err(e) = check_definition(&f, &layer, &g, &pplanes) { + return Err(PyValueError::new_err(e.to_string())); + } + Ok(()) +} + #[cfg(test)] mod tests { use test_log; @@ -434,9 +453,10 @@ mod tests { let TestCase { g, iset, oset } = test_utils::CASE0.clone(); let pplanes = measurements! {}; let flen = g.len() - oset.len(); - let (f, layer) = find(g, iset, oset, pplanes).unwrap(); + let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(layer, vec![0, 0]); + verify((f, layer), g, iset, oset, pplanes).unwrap(); } #[test_log::test] @@ -449,13 +469,14 @@ mod tests { 3: PPlane::XY }; let flen = g.len() - oset.len(); - let (f, layer) = find(g, iset, oset, pplanes).unwrap(); + let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(f[&0], Nodes::from([1])); assert_eq!(f[&1], Nodes::from([2])); assert_eq!(f[&2], Nodes::from([3])); assert_eq!(f[&3], Nodes::from([4])); assert_eq!(layer, vec![4, 3, 2, 1, 0]); + verify((f, layer), g, iset, oset, pplanes).unwrap(); } #[test_log::test] @@ -468,13 +489,14 @@ mod tests { 3: PPlane::XY }; let flen = g.len() - oset.len(); - let (f, layer) = find(g, iset, oset, pplanes).unwrap(); + let (f, layer) 
= find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(f[&0], Nodes::from([2])); assert_eq!(f[&1], Nodes::from([3])); assert_eq!(f[&2], Nodes::from([4])); assert_eq!(f[&3], Nodes::from([5])); assert_eq!(layer, vec![2, 2, 1, 1, 0, 0]); + verify((f, layer), g, iset, oset, pplanes).unwrap(); } #[test_log::test] @@ -486,12 +508,13 @@ mod tests { 2: PPlane::XY }; let flen = g.len() - oset.len(); - let (f, layer) = find(g, iset, oset, pplanes).unwrap(); + let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(f[&0], Nodes::from([4, 5])); assert_eq!(f[&1], Nodes::from([3, 4, 5])); assert_eq!(f[&2], Nodes::from([3, 5])); assert_eq!(layer, vec![1, 1, 1, 0, 0, 0]); + verify((f, layer), g, iset, oset, pplanes).unwrap(); } #[test_log::test] @@ -504,13 +527,14 @@ mod tests { 3: PPlane::YZ }; let flen = g.len() - oset.len(); - let (f, layer) = find(g, iset, oset, pplanes).unwrap(); + let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(f[&0], Nodes::from([2])); assert_eq!(f[&1], Nodes::from([5])); assert_eq!(f[&2], Nodes::from([2, 4])); assert_eq!(f[&3], Nodes::from([3])); assert_eq!(layer, vec![2, 2, 1, 1, 0, 0]); + verify((f, layer), g, iset, oset, pplanes).unwrap(); } #[test_log::test] @@ -533,13 +557,14 @@ mod tests { 3: PPlane::X }; let flen = g.len() - oset.len(); - let (f, layer) = find(g, iset, oset, pplanes).unwrap(); + let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(f[&0], Nodes::from([1])); assert_eq!(f[&1], Nodes::from([4])); assert_eq!(f[&2], Nodes::from([3])); assert_eq!(f[&3], Nodes::from([2, 4])); assert_eq!(layer, vec![1, 1, 0, 1, 0]); + verify((f, layer), g, iset, oset, pplanes).unwrap(); } #[test_log::test] @@ -552,7 +577,7 @@ mod tests { 3: PPlane::Y }; let flen = g.len() - oset.len(); 
- let (f, layer) = find(g, iset, oset, pplanes).unwrap(); + let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); assert_eq!(f.len(), flen); // Graphix // assert_eq!(f[&0], Nodes::from([0, 1])); @@ -561,6 +586,7 @@ mod tests { assert_eq!(f[&2], Nodes::from([2])); assert_eq!(f[&3], Nodes::from([4])); assert_eq!(layer, vec![1, 0, 0, 1, 0]); + verify((f, layer), g, iset, oset, pplanes).unwrap(); } #[test_log::test] @@ -572,7 +598,7 @@ mod tests { 2: PPlane::Y }; let flen = g.len() - oset.len(); - let (f, layer) = find(g, iset, oset, pplanes).unwrap(); + let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); assert_eq!(f.len(), flen); // Graphix // assert_eq!(f[&0], Nodes::from([0, 3, 4])); @@ -580,5 +606,6 @@ mod tests { assert_eq!(f[&1], Nodes::from([1, 2])); assert_eq!(f[&2], Nodes::from([4])); assert_eq!(layer, vec![1, 1, 1, 0, 0]); + verify((f, layer), g, iset, oset, pplanes).unwrap(); } } diff --git a/tests/__init__.py b/tests/__init__.py index ca13e7f8..9d48db4f 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1 @@ -"""Initialize the tests module.""" +from __future__ import annotations diff --git a/tests/assets.py b/tests/assets.py index 2f8248c6..fa12cc72 100644 --- a/tests/assets.py +++ b/tests/assets.py @@ -10,8 +10,6 @@ @dataclasses.dataclass(frozen=True) class FlowTestCase: - """Test case for flow/gflow.""" - g: nx.Graph[int] iset: set[int] oset: set[int] @@ -165,4 +163,4 @@ class FlowTestCase: GFlowResult({0: {0, 2, 4}, 1: {1, 2}, 2: {4}}, {0: 1, 1: 1, 2: 1, 3: 0, 4: 0}), ) -CASES = [CASE0, CASE1, CASE2, CASE3, CASE4, CASE5, CASE6, CASE7, CASE8] +CASES: tuple[FlowTestCase, ...] 
= (CASE0, CASE1, CASE2, CASE3, CASE4, CASE5, CASE6, CASE7, CASE8) diff --git a/tests/test_common.py b/tests/test_common.py new file mode 100644 index 00000000..a4c7d0c7 --- /dev/null +++ b/tests/test_common.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import networkx as nx +import pytest +from fastflow import _common +from fastflow._common import IndexMap +from fastflow.common import Plane + + +def test_check_graph_ng_g() -> None: + with pytest.raises(TypeError): + _common.check_graph("hoge", set(), set()) # type: ignore[arg-type] + + with pytest.raises(ValueError, match="Graph is empty."): + _common.check_graph(nx.Graph(), set(), set()) + + with pytest.raises(ValueError, match="Self-loop detected."): + _common.check_graph(nx.Graph([("a", "a"), ("a", "b")]), set(), set()) + + with pytest.raises(ValueError, match="iset must be a subset of the nodes."): + _common.check_graph(nx.Graph([("a", "b")]), {"x"}, set()) + + with pytest.raises(ValueError, match="oset must be a subset of the nodes."): + _common.check_graph(nx.Graph([("a", "b")]), set(), {"x"}) + + +def test_check_graph_ng_set() -> None: + with pytest.raises(TypeError): + _common.check_graph(nx.Graph(), "hoge", set()) # type: ignore[arg-type] + + with pytest.raises(TypeError): + _common.check_graph(nx.Graph(), set(), "hoge") # type: ignore[arg-type] + + +def test_check_planelike_ng() -> None: + with pytest.raises(TypeError): + _common.check_planelike(set(), set(), "hoge") # type: ignore[arg-type] + + with pytest.raises(ValueError, match="Cannot find corresponding nodes in the graph."): + _common.check_planelike({"a"}, set(), {"x": Plane.XY}) + + with pytest.raises(ValueError, match=r"Measurement planes should be specified for all u in V\\O."): + _common.check_planelike({"a", "b"}, {"b"}, {}) + + +@pytest.fixture +def fx_indexmap() -> IndexMap[str]: + return IndexMap({"a", "b", "c"}) + + +class TestIndexMap: + def test_encode(self, fx_indexmap: IndexMap[str]) -> None: + assert { + 
fx_indexmap.encode("a"), + fx_indexmap.encode("b"), + fx_indexmap.encode("c"), + } == {0, 1, 2} + + with pytest.raises(ValueError, match="x not found."): + fx_indexmap.encode("x") + + def test_decode(self, fx_indexmap: IndexMap[str]) -> None: + assert { + fx_indexmap.decode(0), + fx_indexmap.decode(1), + fx_indexmap.decode(2), + } == {"a", "b", "c"} + + with pytest.raises(ValueError, match="3 not found."): + fx_indexmap.decode(3) + + def test_encdec(self, fx_indexmap: IndexMap[str]) -> None: + assert fx_indexmap.decode(fx_indexmap.encode("a")) == "a" + assert fx_indexmap.decode(fx_indexmap.encode("b")) == "b" + assert fx_indexmap.decode(fx_indexmap.encode("c")) == "c" + + def test_errmsg(self, fx_indexmap: IndexMap[str]) -> None: + c0 = fx_indexmap.decode(0) + c1 = fx_indexmap.decode(1) + c2 = fx_indexmap.decode(2) + assert fx_indexmap.decode_errmsg("hoge-fuga piyo (0)") == f"Hoge-fuga piyo (check {c0})." + assert fx_indexmap.decode_errmsg("hoge-fuga piyo (1)") == f"Hoge-fuga piyo (check {c1})." + assert fx_indexmap.decode_errmsg("hoge-fuga piyo (2)") == f"Hoge-fuga piyo (check {c2})." + assert fx_indexmap.decode_errmsg("hoge (0, 1)") == f"Hoge (check {c0}, {c1})." + assert fx_indexmap.decode_errmsg("hoge (0, 1, 2)") == f"Hoge (check {c0}, {c1}, {c2})." + with pytest.raises(ValueError, match=r"3 not found."): + fx_indexmap.decode_errmsg("hoge (3)") + with pytest.raises(ValueError, match=r"Cannot parse message:"): + fx_indexmap.decode_errmsg("piyo") + with pytest.raises(ValueError, match=r"Cannot parse message:"): + fx_indexmap.decode_errmsg("Hoge (0)") + with pytest.raises(ValueError, match=r"Cannot parse message:"): + fx_indexmap.decode_errmsg("hoge. 
(0)") + with pytest.raises(ValueError, match=r"Cannot parse message:"): + fx_indexmap.decode_errmsg("hoge (0).") diff --git a/tests/test_flow.py b/tests/test_flow.py index ddc6a1e0..ff379098 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -1,4 +1,4 @@ -"""Test flow.""" +from __future__ import annotations import pytest from fastflow import flow @@ -8,6 +8,7 @@ @pytest.mark.parametrize("c", CASES) def test_flow_graphix(c: FlowTestCase) -> None: - """Compare the results with the graphix package.""" result = flow.find(c.g, c.iset, c.oset) assert result == c.flow + if result is not None: + flow.verify(result, c.g, c.iset, c.oset) diff --git a/tests/test_gflow.py b/tests/test_gflow.py index 80583193..a4dc830e 100644 --- a/tests/test_gflow.py +++ b/tests/test_gflow.py @@ -1,13 +1,25 @@ -"""Test gflow.""" +from __future__ import annotations +import networkx as nx import pytest from fastflow import gflow +from fastflow.common import Plane from tests.assets import CASES, FlowTestCase @pytest.mark.parametrize("c", CASES) def test_gflow_graphix(c: FlowTestCase) -> None: - """Compare the results with the graphix package.""" result = gflow.find(c.g, c.iset, c.oset, c.plane) assert result == c.gflow + if result is not None: + gflow.verify(result, c.g, c.iset, c.oset, c.plane) + + +def test_gflow_redundant() -> None: + g: nx.Graph[int] = nx.Graph([(0, 1)]) + iset = {0} + oset = {1} + planes = {0: Plane.XY, 1: Plane.XY} + with pytest.warns(UserWarning, match=r".*Ignoring plane\[v\] where v in oset\..*"): + gflow.find(g, iset, oset, planes) diff --git a/tests/test_pflow.py b/tests/test_pflow.py index 3e7d35cf..d7e8c2a7 100644 --- a/tests/test_pflow.py +++ b/tests/test_pflow.py @@ -1,13 +1,35 @@ -"""Test Pauli flow.""" +from __future__ import annotations +import networkx as nx import pytest from fastflow import pflow +from fastflow.common import PPlane from tests.assets import CASES, FlowTestCase +@pytest.mark.filterwarnings("ignore:No Pauli measurement found") 
@pytest.mark.parametrize("c", CASES) def test_pflow_graphix(c: FlowTestCase) -> None: - """Compare the results with the graphix package.""" result = pflow.find(c.g, c.iset, c.oset, c.pplane) assert result == c.pflow + if result is not None: + pflow.verify(result, c.g, c.iset, c.oset, c.pplane) + + +def test_pflow_nopauli() -> None: + g: nx.Graph[int] = nx.Graph([(0, 1)]) + iset = {0} + oset = {1} + planes = {0: PPlane.XY} + with pytest.warns(UserWarning, match=r".*No Pauli measurement found\. Use gflow\.find instead\..*"): + pflow.find(g, iset, oset, planes) + + +def test_pflow_redundant() -> None: + g: nx.Graph[int] = nx.Graph([(0, 1)]) + iset = {0} + oset = {1} + planes = {0: PPlane.X, 1: PPlane.Y} + with pytest.warns(UserWarning, match=r".*Ignoring pplane\[v\] where v in oset\..*"): + pflow.find(g, iset, oset, planes)