-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
68 lines (55 loc) · 2.65 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import argparse
import warnings
import pytorch_lightning as pl
import wandb
import yaml
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from datasets.data_module import SegmentationDataModule
from utils.train_utils import EarlyStoppingMinEpochs, get_training_module, get_project_name
def main(config):
# Create the segmentation module
seg_module = get_training_module(config)
# Setup Data
data_module = SegmentationDataModule(config)
# Setup W&B logging
wandb_logger = None
if config.get('logging', True):
try:
wandb_logger = WandbLogger(project=get_project_name(config['logging']),
name=config['name'],
tags=config.get('tags', []))
wandb_logger.experiment.config.update(config)
print('W&B logging connected.')
except Exception:
wandb_logger = None
warnings.warn('Skipping wandb logging')
else:
print('Logging disabled.')
# Create the PyTorch Lightning trainer
save_intermediate_checkpoints = config.get('save_checkpoints_every', 0)
trainer = pl.Trainer(max_epochs=config['epochs'],
callbacks=[
EarlyStoppingMinEpochs(config.get('min_epochs', 0),
monitor=config['early_stopping'].get('metric', 'val_loss'),
mode=config['early_stopping'].get('mode', 'min'),
patience=config['early_stopping']['patience'],
min_delta=config['early_stopping']['min_delta']),
ModelCheckpoint(save_last=True,
save_top_k=1 if save_intermediate_checkpoints > 0 else 0,
monitor=config['early_stopping'].get('metric', 'val_loss'),
mode="min",
every_n_epochs=save_intermediate_checkpoints)],
logger=wandb_logger)
# Train the model
trainer.fit(seg_module, data_module)
# Finish W&B logging
wandb.finish()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--config', '-C', '-c', type=str, default='config.yaml',
help='Path to the config file. Default: config.yaml')
args = parser.parse_args()
with open(args.config, 'r') as f:
config = yaml.safe_load(f)
main(config)