diff --git a/hugr-py/tests/conftest.py b/hugr-py/tests/conftest.py index 909d0a8bfd..e8fdf9dd7b 100644 --- a/hugr-py/tests/conftest.py +++ b/hugr-py/tests/conftest.py @@ -10,7 +10,7 @@ from typing_extensions import Self from hugr import ext, tys -from hugr.envelope import EnvelopeConfig +from hugr.envelope import EnvelopeConfig, EnvelopeFormat from hugr.hugr import Hugr from hugr.ops import AsExtOp, Command, Const, Custom, DataflowOp, ExtOp, RegisteredOp from hugr.package import Package @@ -181,20 +181,26 @@ def validate( if os.environ.get("HUGR_RENDER_DOT"): dot.pipe("svg") - # Encoding formats to test, indexed by the format name as used by - # `hugr convert --format`. + # Encoding formats to test. Note that these include other formats than + # those supported by `hugr convert`. FORMATS = { "json": EnvelopeConfig.TEXT, + "json-compressed": EnvelopeConfig(format=EnvelopeFormat.JSON, zstd=0), "model-exts": EnvelopeConfig.BINARY, + "model-exts-no-compression": EnvelopeConfig( + format=EnvelopeFormat.MODEL_WITH_EXTS, zstd=None + ), } # Envelope formats used when exporting test hugrs. - WRITE_FORMATS = ["json", "model-exts"] - # Envelope formats used as target for `hugr convert` before loading back the - # test hugrs. - # - # Model envelopes cannot currently be loaded from python. - # TODO: Add model envelope loading to python, and add it to the list. - LOAD_FORMATS = ["json"] + WRITE_FORMATS = [ + "json", + "json-compressed", + "model-exts", + "model-exts-no-compression", + ] + # Envelope formats used as target before loading back the test hugrs. + # These should correspond to the formats supported by `hugr convert`. + LOAD_FORMATS = ["json", "model-exts"] cmd = [*_base_command(), "validate", "-"] @@ -324,7 +330,7 @@ def _run_hugr_cmd(serial: bytes, cmd: list[str]) -> subprocess.CompletedProcess[ The `serial` argument is the serialized HUGR to pass to the command via stdin. """ try: - return subprocess.run(cmd, check=True, input=serial, capture_output=True) # noqa: S603 + return subprocess.run(cmd, check=True, input=serial, capture_output=True) except subprocess.CalledProcessError as e: error = e.stderr.decode() raise RuntimeError(error) from e diff --git a/hugr-py/tests/test_envelope.py b/hugr-py/tests/test_envelope.py index b42d4729fc..eb517bde3a 100644 --- a/hugr-py/tests/test_envelope.py +++ b/hugr-py/tests/test_envelope.py @@ -30,18 +30,27 @@ def package() -> Package: return Package([mod.hugr, mod2.hugr]) -def test_envelope(package: Package): - # Binary compression roundtrip - for format in [ +@pytest.mark.parametrize( + "compression", [None, 0], ids=["compression:None", "compression:0"] +) +@pytest.mark.parametrize( + "format", + [ EnvelopeFormat.JSON, EnvelopeFormat.MODEL, EnvelopeFormat.MODEL_WITH_EXTS, - ]: - for compression in [None, 0]: - encoded = package.to_bytes(EnvelopeConfig(format=format, zstd=compression)) - decoded = Package.from_bytes(encoded) - assert decoded == package + ], +) +def test_envelope_binary( + package: Package, compression: int | None, format: EnvelopeFormat +): + # Binary compression roundtrip + encoded = package.to_bytes(EnvelopeConfig(format=format, zstd=compression)) + decoded = Package.from_bytes(encoded) + assert decoded == package + +def test_envelope_text(package: Package): # String roundtrip encoded_str = package.to_str(EnvelopeConfig.TEXT) decoded = Package.from_str(encoded_str)