Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for preprocessing inputs in training CLI #879

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ examples/data/*
# Torch-em stuff
checkpoints/
logs/
*.pth
*.pt

# And some other stuff to avoid tracking as well.
gpu_jobs/
Expand All @@ -189,6 +191,10 @@ iterative_prompting_results/
*.sh
*.svg
*.csv
*.tiff
*.tif
*.zip
*MACOSX

# Related to i2k workshop folders.
data/
Expand Down
13 changes: 10 additions & 3 deletions micro_sam/training/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ..instance_segmentation import get_unetr
from . import joint_sam_trainer as joint_trainers
from ..util import get_device, get_model_names, export_custom_sam_model
from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit
from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit, get_raw_transform


FilePath = Union[str, os.PathLike]
Expand Down Expand Up @@ -487,6 +487,7 @@ def default_sam_dataset(
with_channels=with_channels,
ndim=2,
is_seg_dataset=is_seg_dataset,
raw_transform=raw_transform,
**kwargs
)
n_samples = max(len(loader), 100 if is_train else 5)
Expand Down Expand Up @@ -779,6 +780,9 @@ def main():
"--batch_size", type=int, default=1,
help="The choice of batch size for training the Segment Anything Model. By default, trains on batch size 1."
)
parser.add_argument(
"--preprocess", type=str, default=None, help="Whether to normalize the raw inputs."
)

args = parser.parse_args()

Expand All @@ -802,6 +806,9 @@ def main():

# 2. Prepare the dataloaders.

# If the user wants to preprocess the inputs, we allow the possibility to do so.
_raw_transform = get_raw_transform(args.preprocess)

# Get the dataset with files for training.
dataset = default_sam_dataset(
raw_paths=train_images,
Expand All @@ -810,6 +817,7 @@ def main():
label_key=train_gt_key,
patch_shape=patch_shape,
with_segmentation_decoder=with_segmentation_decoder,
raw_transform=_raw_transform,
)

# If val images are not exclusively provided, we create a val split from the training data.
Expand All @@ -828,6 +836,7 @@ def main():
label_key=val_gt_key,
patch_shape=patch_shape,
with_segmentation_decoder=with_segmentation_decoder,
raw_transform=_raw_transform,
)

# Get the dataloaders from the datasets.
Expand All @@ -845,8 +854,6 @@ def main():
if model_type is None: # If user does not specify the model, we use the default model corresponding to the config.
model_type = CONFIGURATIONS[config]["model_type"]

print(model_type, config)

train_sam_for_configuration(
name=checkpoint_name,
configuration=config,
Expand Down
31 changes: 30 additions & 1 deletion micro_sam/training/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from math import ceil, floor
from typing import Dict, List, Optional, Union, Tuple
from functools import partial
from typing import Dict, List, Optional, Union, Tuple, Callable

import numpy as np

Expand Down Expand Up @@ -39,6 +40,34 @@ def require_8bit(x):
return x


def _raw_transform(image: np.ndarray, raw_trafo: Callable) -> np.ndarray:
return raw_trafo(image) * 255


def get_raw_transform(preprocess: Optional[str] = None) -> Optional[Callable]:
    """Transformation functions to normalize inputs.

    Args:
        preprocess: By default, the transformation function is set to 'None'.
            The user can choose from 'normalize_minmax' / 'normalize_percentile'.

    Returns:
        The transformation function.
    """
    # No preprocessing requested: the dataset will use its default raw transform.
    if preprocess is None:
        return None

    if preprocess == "normalize_minmax":
        trafo = normalize
    elif preprocess == "normalize_percentile":
        trafo = normalize_percentile
    else:
        raise ValueError(f"'{preprocess}' is not a supported preprocessing.")

    # Bind the chosen normalization into the shared wrapper, which also
    # rescales the normalized output by 255.
    return partial(_raw_transform, raw_trafo=trafo)


def get_trainable_sam_model(
model_type: str = _DEFAULT_MODEL,
device: Optional[Union[str, torch.device]] = None,
Expand Down