Skip to content

Commit

Permalink
black and pyright fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Nov 8, 2024
1 parent a243cad commit bdef5f6
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 49 deletions.
12 changes: 8 additions & 4 deletions example/dataset_creation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@
" \"2018 Data Science Bowl sponsored by Booz Allen Hamilton with cash prizes. The image set was a testing ground \"\n",
" \"for the application of novel and cutting edge approaches in computer vision and machine learning to the \"\n",
" \"segmentation of the nuclei belonging to cells from a breadth of biological contexts.\",\n",
" documentation=HttpUrl(\"https://uk1s3.embassy.ebi.ac.uk/public-datasets/examples.bioimage.io/dsb-2018.md\"),\n",
" documentation=HttpUrl(\n",
" \"https://uk1s3.embassy.ebi.ac.uk/public-datasets/examples.bioimage.io/dsb-2018.md\"\n",
" ),\n",
" covers=[\n",
" HttpUrl(\n",
" \"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage1.png\"\n",
Expand Down Expand Up @@ -118,7 +120,9 @@
"\n",
"from bioimageio.spec import save_bioimageio_package\n",
"\n",
"exported = save_bioimageio_package(dataset, output_path=Path(\"my_bioimageio_dataset.zip\"))\n",
"exported = save_bioimageio_package(\n",
" dataset, output_path=Path(\"my_bioimageio_dataset.zip\")\n",
")\n",
"print(f\"exported dataset description to {exported.absolute()}\")"
]
},
Expand Down Expand Up @@ -167,8 +171,8 @@
" img: NDArray[Any] = imread(downloaded.path.read_bytes())\n",
" _ = plt.imshow(img)\n",
" _ = plt.title(downloaded.original_file_name)\n",
" _ = plt.axis('off')\n",
" _ = plt.show()\n"
" _ = plt.axis(\"off\")\n",
" _ = plt.show()"
]
}
],
Expand Down
139 changes: 99 additions & 40 deletions example/load_model_and_create_your_own.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@
"MODEL_DRAFT = \"emotional-cricket/draft\"\n",
"\n",
"# version specific ID/DOI\n",
"MODEL_VERSION_ID = \"emotional-cricket/1.1\" # recommended to preserve backward compatibility\n",
"MODEL_VERSION_ID = (\n",
" \"emotional-cricket/1.1\" # recommended to preserve backward compatibility\n",
")\n",
"MODEL_VERSION_DOI = \"10.5281/zenodo.7768142\" # version DOI of backup on zenodo.org\n",
"\n",
"# version unspecific (implicitly refering to the latest version):\n",
Expand Down Expand Up @@ -212,6 +214,7 @@
" img: NDArray[Any] = imageio.v3.imread(download(src).path)\n",
" return img\n",
"\n",
"\n",
"print(f\"The model is named '{model.name}'\")\n",
"print(f\"Description:\\n{model.description}\")\n",
"print(f\"License: {model.license}\")"
Expand Down Expand Up @@ -296,14 +299,25 @@
"metadata": {},
"outputs": [],
"source": [
"for w in [(weights := model.weights).onnx, weights.keras_hdf5, weights.tensorflow_js, weights.tensorflow_saved_model_bundle, weights.torchscript,weights.pytorch_state_dict]:\n",
" if w is None:\n",
"for w in [\n",
" (weights := model.weights).onnx,\n",
" weights.keras_hdf5,\n",
" weights.tensorflow_js,\n",
" weights.tensorflow_saved_model_bundle,\n",
" weights.torchscript,\n",
" weights.pytorch_state_dict,\n",
"]:\n",
" if w is None:\n",
" continue\n",
"\n",
" print(w.weights_format_name)\n",
" print(f\"weights are available at {w.source.absolute()}\")\n",
" print(f\"and have a SHA-256 value of {w.sha256}\")\n",
" details = {k: v for k, v in w.model_dump(mode=\"json\", exclude_none=True).items() if k not in (\"source\", \"sha256\")}\n",
" details = {\n",
" k: v\n",
" for k, v in w.model_dump(mode=\"json\", exclude_none=True).items()\n",
" if k not in (\"source\", \"sha256\")\n",
" }\n",
" if details:\n",
" print(f\"additonal metadata for {w.weights_format_name}:\")\n",
" pprint(details)\n",
Expand All @@ -324,7 +338,9 @@
"metadata": {},
"outputs": [],
"source": [
"print(f\"Model '{model.name}' requires {len(model.inputs)} input(s) with the following features:\")\n",
"print(\n",
" f\"Model '{model.name}' requires {len(model.inputs)} input(s) with the following features:\"\n",
")\n",
"for ipt in model.inputs:\n",
" print(f\"\\ninput '{ipt.id}' with axes:\")\n",
" pprint(ipt.axes)\n",
Expand All @@ -335,9 +351,13 @@
" for p in ipt.preprocessing:\n",
" print(p)\n",
"\n",
"print(\"\\n-------------------------------------------------------------------------------\")\n",
"print(\n",
" \"\\n-------------------------------------------------------------------------------\"\n",
")\n",
"# # and what the model outputs are\n",
"print(f\"Model '{model.name}' requires {len(model.outputs)} output(s) with the following features:\")\n",
"print(\n",
" f\"Model '{model.name}' requires {len(model.outputs)} output(s) with the following features:\"\n",
")\n",
"for out in model.outputs:\n",
" print(f\"\\noutput '{out.id}' with axes:\")\n",
" pprint(out.axes)\n",
Expand Down Expand Up @@ -372,7 +392,7 @@
")\n",
"\n",
"assert isinstance(model, ModelDescr)\n",
"if (w:=model.weights.pytorch_state_dict) is not None:\n",
"if (w := model.weights.pytorch_state_dict) is not None:\n",
" arch = w.architecture\n",
" print(f\"callable: {arch.callable}\")\n",
" if isinstance(arch, ArchitectureFromFileDescr):\n",
Expand Down Expand Up @@ -459,27 +479,32 @@
" WeightsDescr,\n",
")\n",
"\n",
"input_axes = [\n",
" BatchAxis(),\n",
" ChannelAxis(channel_names=[Identifier(\"raw\")])]\n",
"if len(model.inputs[0].axes)==5: # e.g. impartial-shrimp\n",
"input_axes = [BatchAxis(), ChannelAxis(channel_names=[Identifier(\"raw\")])]\n",
"if len(model.inputs[0].axes) == 5: # e.g. impartial-shrimp\n",
" input_axes += [\n",
" SpaceInputAxis(id=AxisId(\"z\"), size=ParameterizedSize(min=16, step=8)),\n",
" SpaceInputAxis(id=AxisId('y'), size=ParameterizedSize(min=144, step=72)),\n",
" SpaceInputAxis(id=AxisId('x'), size=ParameterizedSize(min=144, step=72)),\n",
" SpaceInputAxis(id=AxisId(\"y\"), size=ParameterizedSize(min=144, step=72)),\n",
" SpaceInputAxis(id=AxisId(\"x\"), size=ParameterizedSize(min=144, step=72)),\n",
" ]\n",
" data_descr = IntervalOrRatioDataDescr(type=\"float32\")\n",
"elif len(model.inputs[0].axes)==4: # e.g. pioneering-rhino\n",
"elif len(model.inputs[0].axes) == 4: # e.g. pioneering-rhino\n",
" input_axes += [\n",
" SpaceInputAxis(id=AxisId('y'), size=ParameterizedSize(min=256, step=8)),\n",
" SpaceInputAxis(id=AxisId('x'), size=ParameterizedSize(min=256, step=8)),\n",
" SpaceInputAxis(id=AxisId(\"y\"), size=ParameterizedSize(min=256, step=8)),\n",
" SpaceInputAxis(id=AxisId(\"x\"), size=ParameterizedSize(min=256, step=8)),\n",
" ]\n",
" data_descr = IntervalOrRatioDataDescr(type=\"float32\")\n",
"else:\n",
" raise NotImplementedError(f\"Recreating inputs for {example_model_id} is not implemented\")\n",
" raise NotImplementedError(\n",
" f\"Recreating inputs for {example_model_id} is not implemented\"\n",
" )\n",
"\n",
"test_input_path = model.inputs[0].test_tensor.download().path\n",
"input_descr = InputTensorDescr(id=TensorId(\"raw\"), axes=input_axes, test_tensor=FileDescr(source=test_input_path), data=data_descr)"
"input_descr = InputTensorDescr(\n",
" id=TensorId(\"raw\"),\n",
" axes=input_axes,\n",
" test_tensor=FileDescr(source=test_input_path),\n",
" data=data_descr,\n",
")"
]
},
{
Expand All @@ -500,24 +525,47 @@
"assert isinstance(model.outputs[0].axes[1], ChannelAxis)\n",
"output_axes = [\n",
" BatchAxis(),\n",
" ChannelAxis(channel_names=[Identifier(n) for n in model.outputs[0].axes[1].channel_names])\n",
" ChannelAxis(\n",
" channel_names=[Identifier(n) for n in model.outputs[0].axes[1].channel_names]\n",
" ),\n",
"]\n",
"if len(model.outputs[0].axes) == 5: # e.g. impartial-shrimp\n",
"if len(model.outputs[0].axes) == 5: # e.g. impartial-shrimp\n",
" output_axes += [\n",
" SpaceOutputAxis(id=AxisId(\"z\"), size=SizeReference(tensor_id=TensorId(\"raw\"), axis_id=AxisId(\"z\"))), # same size as input (tensor `raw`) axis `z`\n",
" SpaceOutputAxis(id=AxisId('y'), size=SizeReference(tensor_id=TensorId(\"raw\"), axis_id=AxisId(\"y\"))),\n",
" SpaceOutputAxis(id=AxisId('x'), size=SizeReference(tensor_id=TensorId(\"raw\"), axis_id=AxisId(\"x\")))\n",
" SpaceOutputAxis(\n",
" id=AxisId(\"z\"),\n",
" size=SizeReference(tensor_id=TensorId(\"raw\"), axis_id=AxisId(\"z\")),\n",
" ), # same size as input (tensor `raw`) axis `z`\n",
" SpaceOutputAxis(\n",
" id=AxisId(\"y\"),\n",
" size=SizeReference(tensor_id=TensorId(\"raw\"), axis_id=AxisId(\"y\")),\n",
" ),\n",
" SpaceOutputAxis(\n",
" id=AxisId(\"x\"),\n",
" size=SizeReference(tensor_id=TensorId(\"raw\"), axis_id=AxisId(\"x\")),\n",
" ),\n",
" ]\n",
"elif len(model.outputs[0].axes) == 4: # e.g. pioneering-rhino\n",
"elif len(model.outputs[0].axes) == 4: # e.g. pioneering-rhino\n",
" output_axes += [\n",
" SpaceOutputAxis(id=AxisId(\"y\"), size=SizeReference(tensor_id=TensorId('raw'), axis_id=AxisId('y'))), # same size as input (tensor `raw`) axis `y`\n",
" SpaceOutputAxis(id=AxisId(\"x\"), size=SizeReference(tensor_id=TensorId('raw'), axis_id=AxisId('x'))),\n",
" SpaceOutputAxis(\n",
" id=AxisId(\"y\"),\n",
" size=SizeReference(tensor_id=TensorId(\"raw\"), axis_id=AxisId(\"y\")),\n",
" ), # same size as input (tensor `raw`) axis `y`\n",
" SpaceOutputAxis(\n",
" id=AxisId(\"x\"),\n",
" size=SizeReference(tensor_id=TensorId(\"raw\"), axis_id=AxisId(\"x\")),\n",
" ),\n",
" ]\n",
"else:\n",
" raise NotImplementedError(f\"Recreating outputs for {example_model_id} is not implemented\")\n",
" raise NotImplementedError(\n",
" f\"Recreating outputs for {example_model_id} is not implemented\"\n",
" )\n",
"\n",
"test_output_path = model.outputs[0].test_tensor.download().path\n",
"output_descr = OutputTensorDescr(id=TensorId(\"prob\"), axes=output_axes, test_tensor=FileDescr(source=test_output_path))"
"output_descr = OutputTensorDescr(\n",
" id=TensorId(\"prob\"),\n",
" axes=output_axes,\n",
" test_tensor=FileDescr(source=test_output_path),\n",
")"
]
},
{
Expand Down Expand Up @@ -561,7 +609,7 @@
" source=arch_file_path,\n",
" sha256=arch_file_sha256,\n",
" callable=arch_name,\n",
" kwargs=arch_kwargs\n",
" kwargs=arch_kwargs,\n",
" )\n",
"else:\n",
" # For a model architecture that is published in a Python package\n",
Expand All @@ -570,7 +618,7 @@
" callable=arch.callable,\n",
" kwargs=arch.kwargs,\n",
" import_from=arch.import_from,\n",
" )\n"
" )"
]
},
{
Expand Down Expand Up @@ -601,11 +649,19 @@
"my_model_descr = ModelDescr(\n",
" name=\"My cool model\",\n",
" description=\"A test model for demonstration purposes only\",\n",
" authors=[Author(name=\"me\", affiliation=\"my institute\", github_user=\"bioimageiobot\")], # change github_user to your GitHub account name\n",
" cite=[CiteEntry(text=\"for model training see my paper\", doi=Doi(\"10.1234something\"))],\n",
" authors=[\n",
" Author(name=\"me\", affiliation=\"my institute\", github_user=\"bioimageiobot\")\n",
" ], # change github_user to your GitHub account name\n",
" cite=[\n",
" CiteEntry(text=\"for model training see my paper\", doi=Doi(\"10.1234something\"))\n",
" ],\n",
" license=LicenseId(\"MIT\"),\n",
" documentation=HttpUrl(\"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/README.md\"),\n",
" git_repo=HttpUrl(\"https://github.com/bioimage-io/spec-bioimage-io\"), # change to repo where your model is developed\n",
" documentation=HttpUrl(\n",
" \"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/README.md\"\n",
" ),\n",
" git_repo=HttpUrl(\n",
" \"https://github.com/bioimage-io/spec-bioimage-io\"\n",
" ), # change to repo where your model is developed\n",
" inputs=model.inputs,\n",
" # inputs=[input_descr], # try out our recreated input description\n",
" outputs=model.outputs,\n",
Expand All @@ -615,17 +671,17 @@
" source=model.weights.pytorch_state_dict.source,\n",
" sha256=model.weights.pytorch_state_dict.sha256,\n",
" architecture=pytorch_architecture,\n",
" pytorch_version=pytorch_version\n",
" pytorch_version=pytorch_version,\n",
" ),\n",
" torchscript=TorchscriptWeightsDescr(\n",
" source=model.weights.torchscript.source,\n",
" sha256=model.weights.torchscript.sha256,\n",
" pytorch_version=pytorch_version,\n",
" parent=\"pytorch_state_dict\", # these weights were converted from the pytorch_state_dict weights ones.\n",
" parent=\"pytorch_state_dict\", # these weights were converted from the pytorch_state_dict weights ones.\n",
" ),\n",
" ),\n",
" )\n",
"print(f\"created '{my_model_descr.name}'\")\n"
")\n",
"print(f\"created '{my_model_descr.name}'\")"
]
},
{
Expand Down Expand Up @@ -721,7 +777,10 @@
"\n",
"from bioimageio.spec import save_bioimageio_package\n",
"\n",
"print(\"package path:\", save_bioimageio_package(my_model_descr, output_path=Path('my_model.zip')))"
"print(\n",
" \"package path:\",\n",
" save_bioimageio_package(my_model_descr, output_path=Path(\"my_model.zip\")),\n",
")"
]
}
],
Expand Down
13 changes: 11 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
[tool.black]
line-length = 88
extend_exclude = "^/scripts/pdoc/original.py | ^/scripts/pdoc/patched.py"
target-version = ["py38", "py39", "py310", "py311", "py312"]
preview = true

[tool.pyright]
exclude = ["**/node_modules", "**/__pycache__", "tests/old_*", "tests/cache"]
exclude = [
"**/__pycache__",
"**/node_modules",
"scripts/pdoc/original.py",
"scripts/pdoc/patched.py",
"tests/cache",
"tests/old_*",
]
include = ["bioimageio", "scripts", "tests"]
pythonPlatform = "All"
pythonVersion = "3.12"
Expand Down Expand Up @@ -41,8 +49,9 @@ testpaths = ["bioimageio/spec", "tests", "scripts"]

[tool.ruff]
line-length = 88
include = ["*.py", "*.pyi", "**/pyproject.toml", "*.ipynb"]
target-version = "py312"
include = ["*.py", "*.pyi", "**/pyproject.toml", "*.ipynb"]
exclude = ["scripts/pdoc/original.py", "scripts/pdoc/patched.py"]

[tool.coverage.report]
exclude_also = ["if TYPE_CHECKING:", "assert_never\\("]
1 change: 1 addition & 0 deletions scripts/generate_version_submodule_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def process(info: Info, check: bool):
for tv in black_config.pop("target_version")
)
)
black_config.pop("extend_exclude")
updated = black.format_str(updated, mode=black.mode.Mode(**black_config))
if check:
if init_content == updated:
Expand Down
6 changes: 3 additions & 3 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ def check_bioimageio_yaml(
) -> None:
downloaded_source = download(source)
root = downloaded_source.original_root
data: Dict[Any, Any] = yaml.load(
StringIO(downloaded_source.path.read_bytes().decode(encoding="utf-8"))
)
raw = downloaded_source.path.read_text(encoding="utf-8")
assert isinstance(raw, str)
data: Dict[Any, Any] = yaml.load(StringIO(raw))

assert isinstance(data, dict), type(data)
format_version = "latest" if as_latest else "discover"
Expand Down

0 comments on commit bdef5f6

Please sign in to comment.