add dropout and z-normalization
dsethz committed Nov 18, 2024
1 parent 949cea1 commit 7b5dd90
Showing 2 changed files with 63 additions and 47 deletions.
14 changes: 13 additions & 1 deletion src/nuclai/rep_learning.py
@@ -84,6 +84,13 @@ def _get_args(mode: str) -> argparse.Namespace:
help="Learning rate of the optimizer. Default is 1e-4.",
)

parser.add_argument(
"--dropout",
type=float,
default=0.1,
help="Dropout used during training. Default is 0.1.",
)

parser.add_argument(
"--shape",
type=int,
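
For reference, a minimal standalone sketch of the new option (the parser wiring mirrors the hunk above; the parse_args inputs are illustrative):

import argparse

# Reproduces the new CLI option from this commit.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--dropout",
    type=float,
    default=0.1,
    help="Dropout used during training. Default is 0.1.",
)

print(parser.parse_args(["--dropout", "0.2"]).dropout)  # 0.2, coerced by type=float
print(parser.parse_args([]).dropout)                    # 0.1, the default
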
@@ -238,6 +245,7 @@ def train():
epochs = args.epochs
batch_size = args.batch_size
lr = args.lr
dropout = args.dropout
shape = args.shape
log_frequency = args.log_frequency
multiprocessing = args.multiprocessing
@@ -273,6 +281,10 @@ def train():
isinstance(lr, float) and lr > 0
), "Learning rate must be a positive float."

assert (
isinstance(dropout, float) and dropout >= 0 and dropout < 1
), "Dropout must be a float in range [0, 1)."

assert (
isinstance(shape, list) and len(shape) == 3
), "Shape must be a list of 3 integers."
@@ -383,7 +395,7 @@ def train():
commitment_cost=0.25,
decay=0.5,
epsilon=1e-5,
dropout=0.0,
dropout=dropout,
ddp_sync=True,
use_checkpointing=False,
shape=shape,
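The keyword set here (commitment_cost, decay, epsilon, ddp_sync, use_checkpointing) resembles MONAI's VQVAE constructor, though the import is not shown in this diff. Wherever it lands, a dropout probability is normally consumed as an nn.Dropout layer that is active in train() mode and an identity in eval() mode; an illustrative sketch, not the repository's actual model code:

import torch
from torch import nn

class ResBlock(nn.Module):
    # Illustrative: how a dropout argument typically threads into a network.
    def __init__(self, channels: int, dropout: float = 0.1):
        super().__init__()
        self.conv = nn.Conv3d(channels, channels, kernel_size=3, padding=1)
        self.drop = nn.Dropout(p=dropout)  # zeroes random activations during training

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.drop(torch.relu(self.conv(x)))

block = ResBlock(8, dropout=0.1)
x = torch.randn(1, 8, 4, 16, 16)
block.train()  # dropout active
_ = block(x)
block.eval()   # dropout disabled for validation and inference
_ = block(x)
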
96 changes: 50 additions & 46 deletions src/nuclai/utils/datamodule.py
@@ -181,15 +181,15 @@ class DataSet:
path_data: path to CSV file containing image paths and header "image".
trans: Compose of transforms to apply to each image.
shape: shape of the input image.
bit_depth: bit depth of the input image.
(remove) bit_depth: bit depth of the input image.
"""

def __init__(
self,
path_data: Union[str, pathlib.PosixPath, pathlib.WindowsPath],
trans: Optional[transforms.Compose] = None,
shape: tuple[int, ...] = (30, 300, 300),
bit_depth: int = 8,
# bit_depth: int = 8,
):
super().__init__()

@@ -219,9 +219,9 @@ def __init__(
isinstance(i, int) for i in shape
), "values of shape should be of type integer."

assert isinstance(
bit_depth, int
), f'type of bit_depth should be int instead it is of type: "{type(bit_depth)}".'
# assert isinstance(
# bit_depth, int
# ), f'type of bit_depth should be int instead it is of type: "{type(bit_depth)}".'

self.path_data = path_data
self.data = pd.read_csv(path_data)
@@ -233,15 +233,15 @@ def __init__(
"image" in self.data.columns
), 'The input file requires "image" as header.'

if bit_depth == 8:
self.bit_depth = np.uint8
elif bit_depth == 16:
self.bit_depth = np.int32
else:
self.bit_depth = np.uint8
raise Warning(
f'bit_depth must be in {8, 16}, but is "{bit_depth}". It will be handled as 8bit and may create an integer overflow.'
)
# if bit_depth == 8:
# self.bit_depth = np.uint8
# elif bit_depth == 16:
# self.bit_depth = np.int32
# else:
# self.bit_depth = np.uint8
# raise Warning(
# f'bit_depth must be in {8, 16}, but is "{bit_depth}". It will be handled as 8bit and may create an integer overflow.'
# )

def __len__(self):
return len(self.data)
@@ -285,7 +285,7 @@ def _preprocess(self, img: np.array) -> tuple[torch.Tensor, torch.Tensor]:
len(img.shape) == 3
), f'images are expected to be grayscale and len(img.shape)==3, here it is: "{len(img.shape)}".'

img = img.astype(self.bit_depth)
# img = img.astype(self.bit_depth)

img_t = torch.from_numpy(img).type(torch.FloatTensor)
img_t = torch.unsqueeze(img_t, 0)
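
With the cast commented out, _preprocess now hands the raw array straight to torch; condensed, the remaining path is (shape illustrative):

import numpy as np
import torch

img = np.random.randint(0, 256, size=(30, 300, 300), dtype=np.int32)

img_t = torch.from_numpy(img).type(torch.FloatTensor)  # one cast, to float32
img_t = torch.unsqueeze(img_t, 0)                       # channel dim: (1, D, H, W)
print(img_t.shape)                                      # torch.Size([1, 30, 300, 300])

One caveat worth flagging: older PyTorch versions reject uint16 arrays in torch.from_numpy, a case the removed astype happened to paper over by mapping 16-bit inputs to int32, so uint16 TIFFs may need an explicit upstream cast.
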
@@ -334,21 +334,21 @@ def setup(self, stage: Optional[str] = None):
"""

# catch image data type
tmp = pd.read_csv(self.path_data)
img = tifffile.imread(tmp.loc[0, "image"])

if img.dtype == np.uint8:
max_intensity = 255.0
bit_depth = 8
elif img.dtype == np.uint16:
max_intensity = 65535.0
bit_depth = 16
else:
max_intensity = 255.0
bit_depth = 8
raise Warning(
f'Image type "{img.dtype}" is currently not supported and will be converted to "uint8".'
)
# tmp = pd.read_csv(self.path_data)
# img = tifffile.imread(tmp.loc[0, "image"])

# if img.dtype == np.uint8:
# max_intensity = 255.0
# bit_depth = 8
# elif img.dtype == np.uint16:
# max_intensity = 65535.0
# bit_depth = 16
# else:
# max_intensity = 255.0
# bit_depth = 8
# raise Warning(
# f'Image type "{img.dtype}" is currently not supported and will be converted to "uint8".'
# )
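
The dtype sniffing above becomes unnecessary because z-normalization (introduced below) is invariant to positive rescaling and constant offsets of the intensities: 8-bit and 16-bit encodings of the same image normalize to the same tensor. A quick self-contained check (mine, not part of the commit):

import numpy as np

rng = np.random.default_rng(0)
img8 = rng.integers(0, 256, size=(4, 8, 8)).astype(np.uint8)
img16 = img8.astype(np.uint16) * 257  # same image, stretched to the 16-bit range

z8 = (img8 - img8.mean()) / img8.std()
z16 = (img16 - img16.mean()) / img16.std()
assert np.allclose(z8, z16)  # identical after z-scoring
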

if stage == "fit" or stage is None:
assert self.path_data_val is not None, "path_data_val is missing."
@@ -357,9 +357,10 @@ def setup(self, stage: Optional[str] = None):
# TODO: use translation to not only have centered images (e.g. RandAffine)
trans = transforms.Compose(
[
transforms.NormalizeIntensity(
subtrahend=0, divisor=max_intensity
),
# transforms.NormalizeIntensity(
# subtrahend=0, divisor=max_intensity
# ),
transforms.NormalizeIntensity(),
transforms.RandZoom(keep_size=True),
transforms.RandAxisFlip(),
transforms.RandAdjustContrast(),
@@ -371,57 +372,60 @@ def setup(self, stage: Optional[str] = None):

trans_val = transforms.Compose(
[
transforms.NormalizeIntensity(
subtrahend=0, divisor=max_intensity
),
# transforms.NormalizeIntensity(
# subtrahend=0, divisor=max_intensity
# ),
transforms.NormalizeIntensity(),
]
)
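
In both the training and validation pipelines, NormalizeIntensity() with no arguments derives the subtrahend and divisor from the image itself, i.e. a per-image z-score; this replaces the fixed division by max_intensity. Assuming transforms is monai.transforms (the RandAxisFlip and RandAdjustContrast names point that way), the equivalence is:

import numpy as np
from monai import transforms

img = np.random.rand(1, 8, 16, 16).astype(np.float32)  # (C, D, H, W)

out = transforms.NormalizeIntensity()(img)   # (img - img.mean()) / img.std()
expected = (img - img.mean()) / img.std()
assert np.allclose(np.asarray(out), expected, atol=1e-5)
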

self.data = DataSet(
self.path_data,
trans=trans,
shape=self.shape,
bit_depth=bit_depth,
# bit_depth=bit_depth,
)
self.data_val = DataSet(
self.path_data_val,
trans=trans_val,
shape=self.shape,
bit_depth=bit_depth,
# bit_depth=bit_depth,
)

if stage == "test" or stage is None:
# instantiate transforms and datasets
trans = transforms.Compose(
[
transforms.NormalizeIntensity(
subtrahend=0, divisor=max_intensity
),
# transforms.NormalizeIntensity(
# subtrahend=0, divisor=max_intensity
# ),
transforms.NormalizeIntensity(),
]
)

self.data_test = DataSet(
self.path_data,
trans=trans,
shape=self.shape,
bit_depth=bit_depth,
# bit_depth=bit_depth,
)

if stage == "predict" or stage is None:
# instantiate transforms and datasets
trans = transforms.Compose(
[
transforms.NormalizeIntensity(
subtrahend=0, divisor=max_intensity
),
# transforms.NormalizeIntensity(
# subtrahend=0, divisor=max_intensity
# ),
transforms.NormalizeIntensity(),
]
)

self.data_predict = DataSet(
self.path_data,
trans=trans,
shape=self.shape,
bit_depth=bit_depth,
# bit_depth=bit_depth,
)
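
After these edits, DataSet is constructed without a bit_depth argument throughout. A hedged usage sketch (the CSV path and its contents are hypothetical; the file needs an "image" column of TIFF paths):

from nuclai.utils.datamodule import DataSet

ds = DataSet("data/train.csv", trans=None, shape=(30, 300, 300))
print(len(ds))  # number of rows in the CSV
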

def train_dataloader(self):
