Commit

add initial version of mlp
dsethz committed Oct 30, 2024
1 parent e8718a9 commit c99174a
Showing 3 changed files with 550 additions and 450 deletions.
133 changes: 76 additions & 57 deletions src/nuclai/cls.py
@@ -17,9 +17,9 @@
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger

from nuclai.models.vqvae import LitVQVAE
from nuclai.models.mlp import LitMLP
from nuclai.utils.callbacks import CheckpointCallback
from nuclai.utils.datamodule import DataModule
from nuclai.utils.datamodule import DataModuleCls


def _get_args(mode: str) -> argparse.Namespace:
@@ -50,7 +50,7 @@ def _get_args(mode: str) -> argparse.Namespace:
parser.add_argument(
"--model",
type=str,
default="vqvae",
default="mlp",
help="Model type to train. Default is vqvae",
)

@@ -61,6 +61,34 @@ def _get_args(mode: str) -> argparse.Namespace:
help="Path to checkpoint file of trained pl.LightningModule. Default is None.",
)

parser.add_argument(
"--hidden_dim",
type=int,
nargs="+",
default=None,
help="Neurons per hidden layer provided as whitespace-separated int (e.g. 256 128). Default is None.",
)
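
Side note: argparse's nargs="+" collects the whitespace-separated values into a Python list of ints, which is what the assertion in train() below checks. A minimal, self-contained illustration:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--hidden_dim", type=int, nargs="+", default=None)

    args = parser.parse_args(["--hidden_dim", "256", "128"])
    print(args.hidden_dim)  # [256, 128] -- one width per hidden layer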

parser.add_argument(
"--bias",
action="store_true",
help="If flag is used, linear layers of MLP will use a bias term. Default is False.",
)

parser.add_argument(
"--dropout",
type=float,
default=0.0,
help="Dropout ratio to use in MLP. Must be in [0, 1]. Default is 0.",
)

parser.add_argument(
"--loss_weight",
type=float,
default=1.0,
help="Weight of the positive class compared for the binary cross entropy loss. Default is 1.",
)
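
For reference, a positive-class weight like this is commonly passed to torch.nn.BCEWithLogitsLoss as pos_weight, which scales the loss contribution of positive targets. A minimal sketch, assuming LitMLP (defined in src/nuclai/models/mlp.py, not shown on this page) wires it through that way:

    import torch

    loss_weight = 2.0  # e.g. --loss_weight 2.0 upweights the positive class
    criterion = torch.nn.BCEWithLogitsLoss(
        pos_weight=torch.tensor([loss_weight])  # applied to positive targets only
    )

    logits = torch.randn(4, 1)                     # raw model outputs
    targets = torch.randint(0, 2, (4, 1)).float()  # binary labels
    loss = criterion(logits, targets)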

parser.add_argument(
"--epochs",
type=int,
@@ -124,13 +152,6 @@ def _get_args(mode: str) -> argparse.Namespace:
help="Path to checkpoint file of trained pl.LightningModule.",
)

parser.add_argument(
"--suffix",
type=str,
default="",
help="Suffix to append to all mask file names.",
)

parser.add_argument(
"--output_base_dir",
type=str,
@@ -154,7 +175,6 @@ def _initialise_inferrence(
model: str,
devices: list[str],
output_base_dir: str,
suffix: str,
) -> tuple[L.Trainer, L.LightningModule, L.LightningDataModule]:
"""
Construct trainer, model, and data module for testing/predicting
@@ -170,8 +190,6 @@
output_base_dir, str
), "Output base directory must be a string."

assert isinstance(suffix, str), "Suffix must be a string."

# create directories
os.makedirs(output_base_dir, exist_ok=True)

@@ -188,16 +206,14 @@

# load model
if os.path.isfile(model):
model = LitVQVAE.load_from_checkpoint(model)
model.suffix = suffix
model = LitMLP.load_from_checkpoint(model)
else:
raise FileNotFoundError(f'The file "{model}" does not exist.')

# set up data
data_module = DataModule(
data_module = DataModuleCls(
path_data=data,
batch_size=1,
shape=model.shape,
)

# test model
@@ -223,10 +239,13 @@ def train():
path_data_val = args.data_val
model_type = args.model
path_checkpoint = args.checkpoint
hidden_dim = args.hidden_dim
bias = args.bias
dropout = args.dropout
loss_weight = args.loss_weight
epochs = args.epochs
batch_size = args.batch_size
lr = args.lr
shape = args.shape
log_frequency = args.log_frequency
multiprocessing = args.multiprocessing
retrain = args.retrain
@@ -241,14 +260,30 @@ def train():
path_data_val
), f"File {path_data_val} does not exist."

assert model_type in [
"vqvae"
], f"Model type {model_type} is not supported."
assert model_type in ["mlp"], f"Model type {model_type} is not supported."

assert path_checkpoint is None or os.path.isfile(
path_checkpoint
), f"File {path_checkpoint} does not exist."

assert hidden_dim is None or all(
isinstance(x, int) for x in hidden_dim
), "Hidden dimensions must be a list of integers or None"

assert isinstance(
bias, bool
), f"Bias must be a boolean, but is {type(bias)}."

assert isinstance(
dropout, float
), f"Dropout must be a float, but is {type(dropout)}."

assert 0 <= dropout <= 1, "Dropout must be in [0, 1]."

assert isinstance(
loss_weight, float
), f"Loss weight must be a float, but is {type(loss_weight)}."

assert (
isinstance(epochs, int) and epochs > 0
), "Epochs must be a positive integer."
@@ -261,10 +296,6 @@ def train():
isinstance(lr, float) and lr > 0
), "Learning rate must be a positive float."

assert (
isinstance(shape, list) and len(shape) == 3
), "Shape must be a list of 3 integers."

assert (
isinstance(log_frequency, int) and log_frequency > 0
), "Log frequency must be a positive integer."
@@ -283,8 +314,7 @@ def train():
output_base_dir, str
), "Output base directory must be a string."

# reformat shape to tuple
shape = tuple(shape)
assert isinstance(devices, list), "Devices must be a list."

# create directories
d = date.today()
@@ -318,6 +348,7 @@ def train():
strategy = "auto"
n_devices = 1
precision = "32-true"
sync_batchnorm = False
else:
accelerator = "gpu"
n_devices = len([int(device) for device in devices])
@@ -334,53 +365,38 @@

batch_size = int(batch_size / n_devices)
strategy = "ddp"
sync_batchnorm = True
elif accelerator == "gpu":
strategy = "auto"
n_devices = 1
sync_batchnorm = False
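
As a worked example of the device handling above: the CLI batch size is treated as global, so under DDP it is split evenly across replicas, while sync_batchnorm=True keeps any BatchNorm statistics consistent between them:

    devices = ["0", "1", "2", "3"]            # four GPUs requested
    batch_size = 32                           # global batch size from the CLI
    n_devices = len(devices)                  # 4
    batch_size = int(batch_size / n_devices)  # 8 samples per GPU per step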

# set up data
data_module = DataModule(
data_module = DataModuleCls(
path_data=path_data,
path_data_val=path_data_val,
batch_size=batch_size,
shape=shape,
)
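
LitMLP below takes its input width from data_module.feature_dim. A hypothetical sketch of how a datamodule can expose such an attribute; the real DataModuleCls lives in src/nuclai/utils/datamodule.py, which is also changed in this commit but not rendered here:

    import numpy as np
    import lightning as L

    class TinyDataModuleSketch(L.LightningDataModule):
        """Toy stand-in: exposes the width of one feature vector."""

        def __init__(self, features: np.ndarray):
            super().__init__()
            self.features = features
            self.feature_dim = features.shape[1]  # consumed as LitMLP's input_dim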

# random seeding
if seed is not None:
L.pytorch.seed_everything(seed, workers=True)

# set up model
if model_type == "vqvae":
model = LitVQVAE(
spatial_dims=3,
in_channels=1,
out_channels=1,
num_channels=(256, 256, 256),
num_res_channels=256,
num_res_layers=2,
downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1)),
upsample_parameters=(
(2, 4, 1, 1, 0),
(2, 4, 1, 1, 0),
(2, 4, 1, 1, 0),
),
num_embeddings=256,
embedding_dim=32,
embedding_init="normal",
commitment_cost=0.25,
decay=0.5,
epsilon=1e-5,
dropout=0.0,
ddp_sync=True,
use_checkpointing=False,
shape=shape,
if model_type == "mlp":
model = LitMLP(
input_dim=data_module.feature_dim,
output_dim=1,
hidden_dim=hidden_dim,
bias=bias,
dropout=dropout,
loss_weight=loss_weight,
learning_rate=lr,
suffix="",
)
else:
raise ValueError(f"Model type {model_type} is not supported.")
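
The LitMLP implementation itself is in src/nuclai/models/mlp.py, one of the three changed files and not rendered on this page. A minimal sketch of the interface implied by the constructor call above -- hypothetical, not the committed code:

    import torch
    import torch.nn as nn
    import lightning as L

    class LitMLPSketch(L.LightningModule):
        def __init__(self, input_dim, output_dim=1, hidden_dim=None, bias=False,
                     dropout=0.0, loss_weight=1.0, learning_rate=1e-4, suffix=""):
            super().__init__()
            self.save_hyperparameters()
            # e.g. input_dim=32, hidden_dim=[256, 128] -> layers 32-256-128-1
            dims = [input_dim] + list(hidden_dim or []) + [output_dim]
            layers = []
            for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
                layers.append(nn.Linear(d_in, d_out, bias=bias))
                if i < len(dims) - 2:  # no activation/dropout after the output layer
                    layers += [nn.ReLU(), nn.Dropout(dropout)]
            self.net = nn.Sequential(*layers)
            self.criterion = nn.BCEWithLogitsLoss(
                pos_weight=torch.tensor([loss_weight])
            )

        def forward(self, x):
            return self.net(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = self.criterion(self(x), y)
            self.log("loss", loss)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)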

# TODO: add callback for f1 on validation set
# set up callback for best model
checkpoint_best_loss = ModelCheckpoint(
monitor="loss_val",
@@ -401,7 +417,7 @@ def train():
old_epoch = int(epoch_pattern.search(path_checkpoint)[1])
epochs += old_epoch
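
A short illustration of the resume arithmetic: Lightning's max_epochs is an absolute target, so the epoch already completed by the checkpoint is parsed from its file name and added, making --epochs mean additional epochs. The epoch_pattern definition sits outside this hunk, so the regex below is an assumption:

    import re

    epoch_pattern = re.compile(r"epoch=(\d+)")  # assumed; defined elsewhere in cls.py
    path_checkpoint = "lightning_logs/version_0/checkpoints/epoch=12-step=5200.ckpt"
    epochs = 10  # --epochs 10

    old_epoch = int(epoch_pattern.search(path_checkpoint)[1])  # 12
    epochs += old_epoch  # the trainer now runs up to epoch 22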

# train model
# set up trainer
logger = CSVLogger(output_base_dir, name="lightning_logs")
trainer = L.Trainer(
max_epochs=epochs,
@@ -417,7 +433,12 @@ def train():
],
log_every_n_steps=log_frequency,
precision=precision,
sync_batchnorm=sync_batchnorm,
)

# TODO: add lr tuner

# train model
trainer.fit(model, data_module, ckpt_path=path_checkpoint)


@@ -432,7 +453,6 @@ def test():
model=args.model,
devices=args.devices,
output_base_dir=args.output_base_dir,
suffix=args.suffix,
)
trainer.test(model, data_module)

@@ -448,6 +468,5 @@ def predict():
model=args.model,
devices=args.devices,
output_base_dir=args.output_base_dir,
suffix=args.suffix,
)
trainer.predict(model, data_module)
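
Putting the new flags together, a hypothetical training invocation; the console-script name and the data paths are placeholders (entry points live in the packaging metadata, not in this diff):

    # Hypothetical -- script name and file names assumed, flags taken from
    # the argparse definitions above:
    #
    #   nuclai-cls --data train.json --data_val val.json --model mlp \
    #       --hidden_dim 256 128 --bias --dropout 0.2 --loss_weight 2.0 \
    #       --epochs 20 --batch_size 32 --lr 1e-4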
(Diffs for the remaining two changed files are not rendered on this page.)