From a75d581ec21670249a5ea24168e1d5767f57d29b Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Wed, 12 Jun 2024 15:47:28 +0200 Subject: [PATCH] Minor update to default rescaling params in resizerawtrafo (#635) --- .../training/light_microscopy/obtain_lm_datasets.py | 3 ++- .../training/light_microscopy/tissuenet_finetuning.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py index 083f8c21..8ac629b4 100644 --- a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py +++ b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py @@ -61,7 +61,8 @@ def get_ctc_datasets( datasets.get_tissuenet_dataset( path=os.path.join(input_path, "tissuenet"), split=split_choice, download=True, patch_shape=patch_shape, raw_channel="rgb", label_channel="cell", sampler=sampler, label_dtype=label_dtype, - raw_transform=ResizeRawTrafo(patch_shape), label_transform=ResizeLabelTrafo(patch_shape, min_size=0), + raw_transform=ResizeRawTrafo(patch_shape, do_rescaling=True), + label_transform=ResizeLabelTrafo(patch_shape, min_size=0), n_samples=1000 if split_choice == "train" else 100 ), datasets.get_livecell_dataset( diff --git a/finetuning/specialists/training/light_microscopy/tissuenet_finetuning.py b/finetuning/specialists/training/light_microscopy/tissuenet_finetuning.py index 5b3b9cc9..b6f75290 100644 --- a/finetuning/specialists/training/light_microscopy/tissuenet_finetuning.py +++ b/finetuning/specialists/training/light_microscopy/tissuenet_finetuning.py @@ -25,7 +25,7 @@ def get_dataloaders(patch_shape, data_path): I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID. Important: the ID 0 is reseved for background, and the IDs must be consecutive """ - raw_transform = ResizeRawTrafo(patch_shape) + raw_transform = ResizeRawTrafo(patch_shape, do_rescaling=True) label_transform = ResizeLabelTrafo(patch_shape) sampler = MinInstanceSampler() label_dtype = torch.float32