# NOTE: use --debug 1 to disable both WANDB integration and logging.
# Sample outputs (logs, models, visuals) are available in the samples/ directory.
# Search the workspace for the "INSERT PATH HERE" comment to change the dataset path etc.
# Demo usage (recommended; runs a 1-epoch demonstration by default, with no WANDB):
# python main.py --demo 1 | noise = vanilla, mask_sampling = random
# python main.py --demo 2 | noise = vanilla, mask_sampling = block
# python main.py --demo 3 | noise = gaussian, mask_sampling = grid
# python main.py --demo 4 | BASELINE FULLY SUPERVISED
# Other typical usage:
# Baseline        | python main.py --debug 1 --baseline 1
# Other example 1 | python main.py --noise gaussian --mask_sampling random --fine_tune_size 2 --epochs 30
# Other example 2 | python main.py --noise vanilla --mask_sampling block --block_mask_ratio 0.5 --fine_tune_size 1 --epochs 30
import builtins
import logging
import random
import sys

import numpy as np
import torch
import wandb

# Explicit (non-wildcard) imports of the project modules, as recommended by PEP 8
from data_loader import getPetDataset
from evaluate import evaluate_fine_tuned_model
from parse_and_log import start_logging, print_command_line_arguments, get_args_parser
from train import get_pretrain_data_loaders, start_pretrain, transfer_model, start_fine_tune


def set_seed(SEED):
    """Seed all relevant RNGs (PyTorch CPU/CUDA, NumPy, Python) for reproducibility."""
    print("Setting seed value: " + str(SEED))
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    np.random.seed(SEED)
    random.seed(SEED)
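    # Optional (not enabled here, to keep the original speed/behaviour): for fully
    # deterministic runs one could additionally set
    #   torch.backends.cudnn.deterministic = True
    #   torch.backends.cudnn.benchmark = False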


def print_command_line_args():
    print("Number of arguments:", len(sys.argv))
    print("Argument List:", sys.argv)


def main(args):
    args_dict = vars(args)
    if args.debug == 0:
        # Initialise logging with a unique identifier
        # Structure: YYYY-MM-DD_HH-MM-SS.log (and corresponding images/models)
        log_file_name = start_logging()
        print("Log file generated: " + log_file_name)
        # Override print() so every message is printed AND logged
        original_print = print

        def custom_print(*print_args, **kwargs):
            original_print(*print_args, **kwargs)
            logging.info(' '.join(map(str, print_args)))

        builtins.print = custom_print
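    # NOTE: when debug == 0, builtins.print stays overridden for the rest of the
    # process, so prints from the imported training/evaluation modules are also
    # written to the log file.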
    if args.debug == 0 and args.demo == 4:
        print("STARTED IN DEMO MODE: NO WANDB INIT")
        print("Enabling baseline mode with 1 epoch")
        args.baseline = 1
        args.epochs = 1
    elif args.debug == 0 and args.demo in [1, 2, 3]:
        print("STARTED IN DEMO MODE: NO WANDB INIT")
        # Demo presets (matching the usage notes at the top of this file)
        demo_presets = {1: ("vanilla", "random"), 2: ("vanilla", "block"), 3: ("gaussian", "grid")}
        args.noise, args.mask_sampling = demo_presets[args.demo]
        args.fine_tune_size = 1
        args.epochs = 1
    # Regular run (not a demo, not in debug mode)
    elif args.debug == 0 and args.demo == 0:
        # Initialise the wandb run for the pretraining phase
        init_config = {
            "batch_size": args.batch_size,
            "image_size": args.image_size,
            "patch_size": args.patch_size,
            "weight_decay": args.weight_decay,
            "learning_rate": args.lr,
            "epochs": args.epochs,
            "encoder_dim": args.enc_projection_dim,
            "enc_heads": args.enc_num_heads,
            "enc_layers": args.enc_layers,
            "decoder_dim": args.dec_projection_dim,
            "dec_heads": args.dec_num_heads,
            "dec_layers": args.dec_layers,
            "train_size": args.fine_tune_size,
            "noise": args.noise,
            "mask_sampling": args.mask_sampling,
            "mask_ratio": args.mask_ratio,
            "block_mask_ratio": args.block_mask_ratio,
            "log_file": log_file_name
        }
        # Copy so the shared init_config is not mutated between phases
        pretrain_config = init_config.copy()
        pretrain_config['phase'] = "pretrain"
        pretrain_run = wandb.init(
            # set the wandb project where this run will be logged
            project=args.wandb_project,
            # track hyperparameters and run metadata
            config=pretrain_config)
    elif args.debug == 1:
        # Debug mode (applies regardless of --demo): set lightweight debugging arguments
        print("STARTED IN DEBUG MODE: NO LOG FILE GENERATED")
        print("STARTED IN DEBUG MODE: NO WANDB INIT")
        args.noise = "gaussian"
        args.mask_sampling = "grid"
        args.fine_tune_size = 0.1
        args.epochs = 1
        log_file_name = 'debug.log'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Seed for reproducibility; the default value can be seen in the argument parser
    set_seed(args.seed)
    # Print (and log) all parameters for the experiment
    print_command_line_arguments(args, args_dict, log_file_name, args.epochs)
    # Get the pretrain loaders
    # (IF APPLICABLE) Data augmentation happens in this function
    trainloader, valloader, testloader = get_pretrain_data_loaders(args)
    # Pretrain the model and move it to CPU / GPU
    model, pretrain_metrics = start_pretrain(log_file_name, trainloader, valloader, testloader, args)
    model = model.to(device)
    if args.debug == 0 and args.demo == 0:
        wandb.finish()
    # Transfer the pretrained model into the fine-tuning model
    fine_tune_model = transfer_model(model, args)
    # Get the pet train, validation and test loaders
    # Note: the train split scales with the --fine_tune_size argument
    pet_train_loader, pet_validation_loader, pet_test_loader = getPetDataset(args)
    if args.debug == 0 and args.demo == 0:
        # Start a separate wandb run for the fine-tuning phase
        finetune_config = init_config.copy()
        finetune_config['phase'] = "finetune"
        finetune_run = wandb.init(
            # set the wandb project where this run will be logged
            project=args.wandb_project,
            # track hyperparameters and run metadata
            config=finetune_config)
    fine_tune_model, finetune_metrics = start_fine_tune(fine_tune_model, pet_train_loader, pet_validation_loader, log_file_name, args)
    test_metrics = evaluate_fine_tuned_model(model=fine_tune_model, test_loader=pet_test_loader)
    if args.debug == 0 and args.demo == 0:
        # Record the final test metrics in the wandb run summary
        for key, value in test_metrics.items():
            wandb.summary[key] = value
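        # Close the fine-tune run explicitly; otherwise wandb finishes it
        # automatically when the process exits.
        wandb.finish()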


if __name__ == '__main__':
    args = get_args_parser()
    args = args.parse_args()
    main(args)