Skip to content

Commit

Permalink
Minor update to default rescaling params in resizerawtrafo (computati…
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 authored Jun 12, 2024
1 parent e21006e commit a75d581
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def get_ctc_datasets(
datasets.get_tissuenet_dataset(
path=os.path.join(input_path, "tissuenet"), split=split_choice, download=True, patch_shape=patch_shape,
raw_channel="rgb", label_channel="cell", sampler=sampler, label_dtype=label_dtype,
raw_transform=ResizeRawTrafo(patch_shape), label_transform=ResizeLabelTrafo(patch_shape, min_size=0),
raw_transform=ResizeRawTrafo(patch_shape, do_rescaling=True),
label_transform=ResizeLabelTrafo(patch_shape, min_size=0),
n_samples=1000 if split_choice == "train" else 100
),
datasets.get_livecell_dataset(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def get_dataloaders(patch_shape, data_path):
I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID.
Important: the ID 0 is reseved for background, and the IDs must be consecutive
"""
raw_transform = ResizeRawTrafo(patch_shape)
raw_transform = ResizeRawTrafo(patch_shape, do_rescaling=True)
label_transform = ResizeLabelTrafo(patch_shape)
sampler = MinInstanceSampler()
label_dtype = torch.float32
Expand Down

0 comments on commit a75d581

Please sign in to comment.