Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨Fully support decathlon datalist #63

Merged
merged 15 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions .github/workflows/deploy_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,6 @@ jobs:
os: windows-latest,
python-version: "3.9"
}
- {
name: "ubuntu-latest - Python 3.8",
os: ubuntu-latest,
python-version: "3.8"
}
- {
name: "ubuntu-latest - Python 3.9",
os: ubuntu-latest,
Expand Down
25 changes: 15 additions & 10 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ ci:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
Expand All @@ -26,34 +26,34 @@ repos:
- id: requirements-txt-fixer

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.277
rev: v0.3.2
hooks:
- id: ruff
name: ruff
args: [ --fix, --exit-non-zero-on-fix ]

- repo: https://github.com/psf/black
rev: 23.3.0
rev: 24.2.0
hooks:
- id: black
name: format code

- repo: https://github.com/PyCQA/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
name: format imports
args: ["--profile", "black"]

- repo: https://github.com/asottile/pyupgrade
rev: v3.8.0
rev: v3.15.1
hooks:
- id: pyupgrade
name: upgrade code
args: ["--py38-plus"]
args: ["--py39-plus"]

- repo: https://github.com/executablebooks/mdformat
rev: 0.7.16
rev: 0.7.17
hooks:
- id: mdformat
name: format markdown
Expand All @@ -63,25 +63,30 @@ repos:
exclude: CHANGELOG.md

- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
rev: 7.0.0
hooks:
- id: flake8
name: check PEP8
args: ["--ignore=E501,W503,E203"]

- repo: https://github.com/hadialqattan/pycln
rev: v2.1.5
rev: v2.4.0
hooks:
- id: pycln
name: prune imports
args: [--expand-stars]

- repo: https://github.com/nbQA-dev/nbQA
rev: 1.7.0
rev: 1.8.4
hooks:
- id: nbqa-black
additional_dependencies: [black]
name: format notebooks
- id: nbqa-mypy
additional_dependencies: [mypy]
name: static analysis for notebooks

- repo: https://github.com/crate-ci/typos
rev: v1.19.0
hooks:
- id: typos
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ segmantic-unet train-config -c config.yml

## What is this tissue_list?

The example above included a tissue_list option. This is a path to a text file specifying the labels contained in a segmented image. By convention the 'label=0' is the background and is ommited from the the format. A segmentation with three tissues 'Bone'=1, 'Fat'=2, and 'Skin'=3 would be specified as follows:
The example above included a tissue_list option. This is a path to a text file specifying the labels contained in a segmented image. By convention the 'label=0' is the background and is omitted from the format. A segmentation with three tissues 'Bone'=1, 'Fat'=2, and 'Skin'=3 would be specified as follows:

```
V7
Expand All @@ -96,7 +96,7 @@ Instead of providing the 'image_dir'/'labels_dir' pair, the training data can al

```json
{
"dataset": ["/dataA/dataset.json", "/dataB/dataset.json"],
"datalist": ["/dataA/dataset.json", "/dataB/dataset.json"],
"output_dir": "<path where trained model and logs are saved>",
"Etc": "etc"
}
Expand Down
17 changes: 11 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@ build-backend = "setuptools.build_meta"
name = "segmantic"
authors = [{name = "Bryn Lloyd", email = "[email protected]"}]
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.9"
license = {file = "LICENSE"}
dynamic = ["version"]
classifiers = [
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
Expand Down Expand Up @@ -72,7 +71,7 @@ profile = "black"
disallow_untyped_defs = false
warn_unused_configs = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_unused_ignores = false
warn_return_any = true
strict_equality = true
no_implicit_optional = false
Expand All @@ -90,9 +89,15 @@ module = [
disallow_untyped_defs = true

[[tool.mypy.overrides]]
module = "itk.*,PIL,matplotlib.*,torch,torchvision.*,numba,setuptools,pytest,typer.*,click,colorama,nibabel,sklearn.*,yaml,scipy.*,adabelief_pytorch,h5py,SimpleITK,vtk.*,sitk_cli"
module = "itk.*,PIL,matplotlib.*,torchvision.*,numba,setuptools,pytest,typer.*,click,colorama,nibabel,sklearn.*,yaml,scipy.*,adabelief_pytorch,h5py,SimpleITK,vtk.*,sitk_cli"
ignore_missing_imports = true

[tool.ruff]
select = ["E", "F"]
ignore = ["E501"]
lint.select = ["E", "F"]
lint.ignore = ["E501"]

[tool.typos]

[tool.typos.default.extend-identifiers]
# *sigh* monai adds 'd' to dictionary transforms
SpatialPadd = "SpatialPadd"
15 changes: 14 additions & 1 deletion scripts/check_masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,18 @@ def fix_binary_masks(directory: Path, file_glob: str = "*.nii.gz"):
)


def round_half_up(input_dir: Path, output_dir: Path = None) -> None:
    """Report .nii.gz label files with suspicious value ranges or pixel types.

    Prints the file name and intensity range of any image whose values fall
    outside [0, 3], and the pixel type of any image stored as float32/float64
    (both hint that a label map was resampled with interpolation).

    NOTE(review): despite its name, this function performs no rounding and
    writes no output; `output_dir` is unused, and its annotation should be
    `Optional[Path]` since it defaults to None -- TODO confirm intent.
    """
    import SimpleITK as sitk

    for f in input_dir.glob("*.nii.gz"):
        img = sitk.ReadImage(f)
        # read-only numpy view -- avoids copying the voxel data
        img_np = sitk.GetArrayViewFromImage(img)
        imin, imax = np.min(img_np), np.max(img_np)
        if imin < 0 or imax > 3:
            # values outside the expected label range [0, 3]
            print(f"{f.name}: [{imin}, {imax}]")
        if img.GetPixelID() in (sitk.sitkFloat32, sitk.sitkFloat64):
            # label maps should be stored as an integer type, not float
            print(f"{f.name}: {img.GetPixelIDTypeAsString()}")


if __name__ == "__main__":
typer.run(fix_binary_masks)
typer.run(round_half_up)
31 changes: 31 additions & 0 deletions scripts/check_training_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from pathlib import Path

import numpy as np
import SimpleITK as sitk
import typer

from segmantic.utils.file_iterators import find_matching_files


def check_training_data(
    image_dir: Path, labels_dir: Path, copy_image_information: bool = False
):
    """Verify that paired image/label volumes agree in size and geometry.

    Matching image/label files are found by name in *image_dir* and
    *labels_dir*. Pairs with differing voxel sizes are reported and skipped.
    With ``copy_image_information`` set, each label volume is stamped with the
    image's geometry and rewritten as uint8; otherwise any spacing/origin
    difference must be negligible (equal to 2 decimals) or an assertion fires.
    """
    pairs = find_matching_files([image_dir / "*.nii.gz", labels_dir / "*.nii.gz"])
    for pair in pairs:
        image = sitk.ReadImage(pair[0])
        label = sitk.ReadImage(pair[1])

        # A size mismatch cannot be repaired by copying meta-data: report and skip.
        if image.GetSize() != label.GetSize():
            print(
                f"Size mismatch {pair[0].name}: {image.GetSize()} != {label.GetSize()}"
            )
            continue

        if copy_image_information:
            # Force label geometry to match the image, rewrite as uint8 in place.
            label.CopyInformation(image)
            sitk.WriteImage(sitk.Cast(label, sitk.sitkUInt8), pair[1])
        elif (
            image.GetSpacing() != label.GetSpacing()
            or image.GetOrigin() != label.GetOrigin()
        ):
            # Tolerate rounding-level differences; anything larger raises here.
            np.testing.assert_almost_equal(
                image.GetSpacing(), label.GetSpacing(), decimal=2
            )
            np.testing.assert_almost_equal(
                image.GetOrigin(), label.GetOrigin(), decimal=2
            )


if __name__ == "__main__":
    typer.run(check_training_data)
2 changes: 0 additions & 2 deletions scripts/evaluate_segmentations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from __future__ import annotations

from pathlib import Path

import SimpleITK as sitk
Expand Down
22 changes: 22 additions & 0 deletions scripts/extract_unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from pathlib import Path
from typing import Optional

import typer


def extract_unet(input_file: Path, output_file: Optional[Path] = None):
    """Load segmantic unet lightning module and export inner monai UNet

    The checkpoint at *input_file* is loaded as a lightning ``Net`` and the
    wrapped UNet's state dict is saved to *output_file*, which defaults to the
    input path with its suffix replaced by ``.pth``.
    """
    import torch

    from segmantic.seg.monai_unet import Net

    # Default target: same path as the checkpoint, with a .pth extension.
    target = input_file.with_suffix(".pth") if output_file is None else output_file

    # Refuse to clobber the very checkpoint we are about to read.
    if target.exists() and target.samefile(input_file):
        raise RuntimeError("Input and output file are identical")

    module = Net.load_from_checkpoint(input_file)
    # NOTE: reaches into the private `_model` attribute holding the monai UNet.
    torch.save(module._model.state_dict(), target)


if __name__ == "__main__":
    typer.run(extract_unet)
36 changes: 0 additions & 36 deletions scripts/generate_dataset.py

This file was deleted.

81 changes: 81 additions & 0 deletions scripts/make_datalist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import json
import random
from pathlib import Path
from typing import Optional

import typer

from segmantic.image.labels import load_tissue_list
from segmantic.utils.file_iterators import find_matching_files


def make_datalist(
    data_dir: Path = typer.Option(
        ...,
        help="root data directory. Paths in datalist will be relative to this directory",
    ),
    image_dir: Path = typer.Option(..., help="Directory containing images"),
    labels_dir: Optional[Path] = typer.Option(None, help="Directory containing labels"),
    datalist_path: Path = typer.Option(..., help="Filename of output datalist"),
    num_channels: int = 1,
    num_classes: int = -1,
    tissuelist_path: Optional[Path] = None,
    percent: float = 1.0,
    description: str = "",
    image_glob: str = "*.nii.gz",
    labels_glob: str = "*.nii.gz",
    test_only: bool = False,
    seed: int = 104,
) -> int:
    """Write a decathlon-style datalist JSON describing a training dataset.

    The label mapping comes either from *tissuelist_path* or is synthesized
    from *num_classes*; one of the two must be provided. With *test_only* all
    images are listed as test data, otherwise matching image/label pairs are
    shuffled deterministically (*seed*), 10 pairs are held out as test data
    and the rest is split into training/validation (controlled by *percent*).

    Returns the number of characters written to *datalist_path*.

    Raises:
        ValueError: if neither *tissuelist_path* nor *num_classes* is given.
    """
    # Build the label-id -> tissue-name mapping; id 0 (background) is omitted.
    if tissuelist_path is not None:
        tissuelist = load_tissue_list(tissuelist_path)
        labels = {
            str(tissue_id): name
            for name, tissue_id in tissuelist.items()
            if tissue_id != 0
        }
    elif num_classes > 0:
        labels = {
            str(tissue_id): f"tissue{tissue_id:02d}"
            for tissue_id in range(1, num_classes + 1)
        }
    else:
        raise ValueError("Either specify 'tissuelist_path' or 'num_classes'")

    data_config: dict = {
        "description": description,
        "num_channels": num_channels,
        "labels": labels,
    }

    # add all files as test files
    if test_only:
        test_files = (data_dir / image_dir).glob(image_glob)
        data_config["training"] = []
        data_config["validation"] = []
        data_config["test"] = [str(f.relative_to(data_dir)) for f in test_files]

    # build proper datalist with training/validation/test split
    else:
        matches = find_matching_files(
            [data_dir / image_dir / image_glob, data_dir / labels_dir / labels_glob]
        )
        pairs = [
            (p[0].relative_to(data_dir), p[1].relative_to(data_dir)) for p in matches
        ]

        # Deterministic shuffle, then hold out the first 10 pairs for testing.
        random.Random(seed).shuffle(pairs)
        test, pairs = pairs[:10], pairs[10:]
        num_valid = int(percent * 0.2 * len(pairs))
        # percent < 1.0 trains on a reduced subset with a fixed 4:1 split.
        num_training = len(pairs) - num_valid if percent >= 1.0 else 4 * num_valid

        data_config["training"] = [
            {"image": str(im), "label": str(lbl)} for im, lbl in pairs[:num_training]
        ]
        # Guard against num_valid == 0: pairs[-0:] is the WHOLE list, which
        # would wrongly put every pair into the validation set.
        data_config["validation"] = (
            [{"image": str(im), "label": str(lbl)} for im, lbl in pairs[-num_valid:]]
            if num_valid > 0
            else []
        )
        # Fix: was wrapped in a 1-tuple, which json-serializes as a nested
        # array ([[...]]) inconsistent with the test_only branch above.
        data_config["test"] = [str(im) for im, _ in test]

    return datalist_path.write_text(json.dumps(data_config, indent=2))


def main():
    # Entry point: expose make_datalist as a typer CLI.
    typer.run(make_datalist)


if __name__ == "__main__":
    main()
5 changes: 2 additions & 3 deletions scripts/visualize_label_surfaces.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from pathlib import Path
from typing import Dict, List

import itk
import numpy as np
Expand All @@ -15,11 +14,11 @@ def extract_surfaces(
file_path: Path,
output_dir: Path,
tissuelist_path: Path,
selected_tissues: List[int] = [],
selected_tissues: list[int] = [],
):
image = itk.imread(f"{file_path}", pixel_type=itk.US)

tissues: Dict[int, str] = {}
tissues: dict[int, str] = {}
if tissuelist_path.exists():
name_id_map = load_tissue_list(tissuelist_path)
tissues = {id: name for name, id in name_id_map.items()}
Expand Down
2 changes: 1 addition & 1 deletion src/segmantic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
""" ML-based segmentation for medical images
"""

__version__ = "0.3.0"
__version__ = "0.4.0"
Loading
Loading