forked from 91097luke/phileo-bench
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtraining_script.py
566 lines (467 loc) · 35.3 KB
/
training_script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
import os
import torch
from functools import partial
from torchinfo import summary
import torch.nn as nn
from datetime import date
import argparse
import sys; sys.path.append("../")
from models.model_Baseline import BaselineNet
from models.model_CoreCNN_versions import CoreUnet_nano, CoreUnet_tiny, CoreUnet_base, CoreUnet_large, CoreUnet_huge, Core_nano
from models.model_Mixer_versions import Mixer_nano, Mixer_tiny, Mixer_base, Mixer_large, Mixer_huge
from models.model_LinearViT_versions import LinearViT_base, LinearViT_large, LinearViT_huge
from models.model_AutoEncoderViT_versions import AutoencoderViT_base, AutoencoderViT_large, AutoencoderViT_huge
from models.model_GeoAwarePretrained import MixerGeoPretrained, get_mixer_kwargs, get_core_encoder_kwargs, CoreEncoderGeoPretrained, CoreEncoderGeoPretrained_combined, CoreEncoderGeoAutoEncoder
from models.model_GeoAwarePretrained_classifier import CoreEncoderGeoPretrained_Classifier
from models.model_AutoEncoderViTPretrained import vit_cnn, vit_cnn_gc, vit_large, get_core_decoder_kwargs
from models.model_AutoEncoderViTPretrained_wSkip import vit_cnn_wSkip, vit_cnn_gc_wSkip, vit_large_wSkip
from models.model_AutoEncoderViTPretrained_classifier import vit_cnn_classifier, vit_cnn_gc_classifier
from models.model_CoreVAE import CoreVAE_nano
from models.model_SatMAE import satmae_vit_cnn
from models.models_Prithvi import prithvi
from models.model_Seco import seasonal_contrast
from models.model_Resnet50 import resnet
from utils import data_protocol
from utils import load_data
from utils import training_loops
from utils.training_utils import read_yaml
# Fix the global RNG seed so weight init and data shuffling are reproducible.
torch.manual_seed(123456)

# Model-name registries. They drive argparse `choices` validation and the
# selection of trainer class / dataloader preprocessing per model family.
CNN_LIST = ['baseline_cnn', 'core_unet_nano','core_unet_tiny','core_unet_base', 'core_unet_large', 'core_unet_huge',
            'core_vae_nano', 'resnet_imagenet', 'resnet', 'core_encoder_nano', 'resnet_imagenet_classifier']
VIT_CNN_LIST = ['vit_cnn_base', 'vit_cnn_base_wSkip']
MIXER_LIST = ['mixer_nano', 'mixer_tiny', 'mixer_base', 'mixer_large', 'mixer_huge']
# BUG FIX: 'linear_vit_larger' -> 'linear_vit_large'. get_models() only knows
# 'linear_vit_large', so the misspelled name was a valid CLI choice that
# silently produced no model (get_models returned None).
VIT_LIST = ['linear_vit_base', 'linear_vit_large', 'linear_vit_huge',
            'autoencoder_vit_base', 'autoencoder_vit_large', 'autoencoder_vit_huge']
CNN_PRETRAINED_LIST = ['GeoAware_core_nano', 'GeoAware_core_tiny', 'GeoAware_mixer_nano', 'GeoAware_mixer_tiny',
                       'GeoAware_contrastive_core_nano', 'GeoAware_mh_pred_core_nano', 'GeoAware_combined_core_nano',
                       'GeoAware_core_autoencoder_nano', 'seasonal_contrast',
                       'GeoAware_core_nano_classifier', 'GeoAware_contrastive_core_nano_classifier',
                       'GeoAware_mh_pred_core_nano_classifier', 'seasonal_contrast_classifier'
                       ]
VIT_CNN_PRETRAINED_LIST = ['prithvi', 'vit_cnn', 'vit_cnn_gc', 'SatMAE', 'SatMAE_classifier', 'vit_cnn_gc_classifier',
                           'vit_cnn_classifier', 'prithvi_classifier', 'vit_cnn_wSkip', 'vit_cnn_gc_wSkip']
# Models that require 224x224 inputs (10 m resolution dataset).
MODELS_224 = ['seasonal_contrast', 'resnet_imagenet', 'resnet', 'seasonal_contrast_classifier', 'resnet_imagenet_classifier']
# Models that require 224x224 inputs at 30 m resolution (HLS dataset).
MODELS_224_r30 = ['prithvi', 'prithvi_classifier']
MODEL_LIST = CNN_LIST + MIXER_LIST + VIT_LIST + CNN_PRETRAINED_LIST + VIT_CNN_LIST + VIT_CNN_PRETRAINED_LIST
DOWNSTREAM_LIST = ['lc', 'building', 'roads', 'lc_classification', 'building_classification', 'roads_classification']
def get_trainer(model_name, downstream_task, epochs, lr, model, device, lr_scheduler, warmup, early_stop, dl_train,
                dl_val, dl_test, NAME, OUTPUT_FOLDER, vis_val, warmup_steps, warmup_gamma):
    """Construct the training loop matching the model family and downstream task.

    All trainer classes share the same constructor signature, so the kwargs are
    built once and only the class is dispatched.

    Raises
    ------
    ValueError: if no trainer exists for the (model_name, downstream_task)
        combination. (The original code fell through and crashed later with
        an UnboundLocalError on `trainer`.)
    """
    common_kwargs = dict(epochs=epochs, lr=lr, model=model, device=device,
                         lr_scheduler=lr_scheduler, warmup=warmup, early_stop=early_stop,
                         train_loader=dl_train, val_loader=dl_val, test_loader=dl_test,
                         name=NAME, out_folder=OUTPUT_FOLDER, visualise_validation=vis_val,
                         warmup_steps=warmup_steps, warmup_gamma=warmup_gamma)

    # core_vae_nano is also listed in CNN_LIST, but it always needs the VAE
    # loop; the original code achieved this by overriding the CNN-branch
    # assignment afterwards, so we check it first.
    if model_name == 'core_vae_nano':
        return training_loops.TrainVAE(**common_kwargs)

    if model_name in (CNN_LIST + MIXER_LIST + VIT_CNN_LIST + CNN_PRETRAINED_LIST + VIT_CNN_PRETRAINED_LIST):
        trainer_cls = {
            'roads': training_loops.TrainBase,
            'building': training_loops.TrainBase,
            'lc': training_loops.TrainLandCover,
            'building_classification': training_loops.TrainClassificationBuildings,
            'lc_classification': training_loops.TrainClassificationLC,
            'roads_classification': training_loops.TrainClassificationRoads,
        }.get(downstream_task)
    elif model_name in VIT_LIST:
        # Pure-ViT models only support the segmentation-style tasks.
        trainer_cls = {
            'roads': training_loops.TrainViT,
            'building': training_loops.TrainViT,
            'lc': training_loops.TrainViTLandCover,
        }.get(downstream_task)
    else:
        trainer_cls = None

    if trainer_cls is None:
        raise ValueError(
            f"No trainer available for model '{model_name}' with downstream task '{downstream_task}'")
    return trainer_cls(**common_kwargs)
def get_models(model_name, input_channels, output_channels, input_size):
    """Instantiate a randomly-initialised (non-pretrained) model by name.

    Parameters
    ----------
    model_name (str): one of CNN_LIST / MIXER_LIST / VIT_LIST / VIT_CNN_LIST.
    input_channels (int): number of input bands.
    output_channels (int): number of output channels (classes or maps).
    input_size (int): spatial size of the square input patches.

    Returns
    -------
    torch.nn.Module

    Raises
    ------
    ValueError: if model_name is unknown. (The original silently returned
        None, which crashed later at `model.to(...)`.)
    """
    # Families grouped by constructor signature.
    dim_models = {
        'baseline_cnn': BaselineNet,
        'core_unet_nano': CoreUnet_nano,
        'core_encoder_nano': Core_nano,
        'core_unet_tiny': CoreUnet_tiny,
        'core_unet_base': CoreUnet_base,
        'core_unet_large': CoreUnet_large,
        'core_unet_huge': CoreUnet_huge,
    }
    chw_models = {
        'mixer_nano': Mixer_nano,
        'mixer_tiny': Mixer_tiny,
        'mixer_base': Mixer_base,
        'mixer_large': Mixer_large,
        'mixer_huge': Mixer_huge,
        'linear_vit_base': LinearViT_base,
        'linear_vit_large': LinearViT_large,
        'linear_vit_huge': LinearViT_huge,
        'autoencoder_vit_base': AutoencoderViT_base,
        'autoencoder_vit_large': AutoencoderViT_large,
        'autoencoder_vit_huge': AutoencoderViT_huge,
        'vit_cnn_base': vit_large,
        'vit_cnn_base_wSkip': vit_large_wSkip,
    }

    if model_name in dim_models:
        return dim_models[model_name](input_dim=input_channels, output_dim=output_channels)
    if model_name in chw_models:
        return chw_models[model_name](chw=(input_channels, input_size, input_size),
                                      output_dim=output_channels)
    if model_name == 'core_vae_nano':
        # output_dim is fixed at 10 — presumably the VAE reconstructs the 10
        # input bands rather than predicting `output_channels`; TODO confirm.
        return CoreVAE_nano(input_dim=input_channels, output_dim=10)
    if model_name == 'resnet_imagenet':
        resnet_kwargs = get_core_decoder_kwargs(output_dim=output_channels, core_size='core_nano')
        return resnet(imagenet_weights=True, **resnet_kwargs)
    if model_name == 'resnet_imagenet_classifier':
        resnet_kwargs = get_core_decoder_kwargs(output_dim=output_channels, core_size='core_nano')
        return resnet(imagenet_weights=True, classifier=True, **resnet_kwargs)
    if model_name == 'resnet':
        resnet_kwargs = get_core_decoder_kwargs(output_dim=output_channels, core_size='core_nano')
        return resnet(imagenet_weights=False, **resnet_kwargs)

    raise ValueError(f"Unknown model name: {model_name}")
def get_models_pretrained(model_name, input_channels, output_channels, input_size, path_model_weights=None, freeze=False, device='cuda'):
    """Instantiate a pretrained model by name and load its checkpoint.

    Parameters
    ----------
    model_name (str): one of CNN_PRETRAINED_LIST + VIT_CNN_PRETRAINED_LIST.
    input_channels (int): number of input bands.
    output_channels (int): number of output channels (classes or maps).
    input_size (int): spatial size of the square input patches.
    path_model_weights (str | list, optional): checkpoint path; for
        'GeoAware_combined_core_nano' a pair [path_1, path_2].
    freeze (bool, optional): freeze the pretrained encoder body.
    device (str | torch.device, optional): map_location used when loading
        checkpoints, so CUDA-trained weights load on CPU-only machines too.

    Returns
    -------
    torch.nn.Module

    Raises
    ------
    ValueError: if model_name is not a known pretrained model. (The original
        silently returned None.)
    """
    # Dummy batch used for a single forward pass: the GeoAware wrappers are
    # exercised once before being returned (as in the original code).
    test_input = torch.rand((2, input_channels, input_size, input_size))

    if model_name in ('GeoAware_core_nano', 'GeoAware_contrastive_core_nano', 'GeoAware_mh_pred_core_nano'):
        # BUG FIX: map_location added — the original loaded these checkpoints
        # without it, which fails for CUDA-saved weights on CPU-only hosts.
        sd = torch.load(path_model_weights, map_location=device)
        core_kwargs = get_core_encoder_kwargs(output_dim=output_channels, input_dim=input_channels, core_size='core_nano', full_unet=True)
        model = CoreEncoderGeoPretrained(output_channels, checkpoint=sd, core_encoder_kwargs=core_kwargs, freeze_body=freeze)
        model(test_input)
        return model
    elif model_name in ('GeoAware_core_nano_classifier', 'GeoAware_contrastive_core_nano_classifier',
                        'GeoAware_mh_pred_core_nano_classifier'):
        sd = torch.load(path_model_weights, map_location=device)
        core_kwargs = get_core_encoder_kwargs(output_dim=output_channels, input_dim=input_channels, core_size='core_nano', full_unet=False)
        model = CoreEncoderGeoPretrained_Classifier(checkpoint=sd, core_encoder_kwargs=core_kwargs, freeze_body=freeze)
        model(test_input)
        return model
    elif model_name == 'GeoAware_core_autoencoder_nano':
        sd = torch.load(path_model_weights, map_location=device)
        core_kwargs = get_core_encoder_kwargs(output_dim=output_channels, input_dim=input_channels, core_size='core_nano', full_unet=True)
        model = CoreEncoderGeoAutoEncoder(output_channels, checkpoint=sd, core_encoder_kwargs=core_kwargs, freeze_body=freeze)
        model(test_input)
        return model
    elif model_name == 'GeoAware_combined_core_nano':
        # Two checkpoints are combined into a single encoder.
        sd_1 = torch.load(path_model_weights[0], map_location=device)
        sd_2 = torch.load(path_model_weights[1], map_location=device)
        core_kwargs = get_core_encoder_kwargs(output_dim=output_channels, input_dim=input_channels, core_size='core_nano')
        model = CoreEncoderGeoPretrained_combined(output_channels, checkpoint_1=sd_1, checkpoint_2=sd_2,
                                                  core_encoder_kwargs=core_kwargs)
        model(test_input)
        return model
    elif model_name == 'GeoAware_core_tiny':
        sd = torch.load(path_model_weights, map_location=device)
        core_kwargs = get_core_encoder_kwargs(output_dim=output_channels, input_dim=input_channels, core_size='core_tiny', full_unet=True)
        model = CoreEncoderGeoPretrained(output_channels, checkpoint=sd, core_encoder_kwargs=core_kwargs, freeze_body=freeze)
        model(test_input)
        return model
    elif model_name == 'GeoAware_mixer_nano':
        sd = torch.load(path_model_weights, map_location=device)
        mixer_kwargs = get_mixer_kwargs(chw=(input_channels, input_size, input_size), output_dim=output_channels, mixer_size='mixer_nano')
        model = MixerGeoPretrained(output_dim=output_channels, checkpoint=sd, mixer_kwargs=mixer_kwargs, freeze_body=freeze)
        model(test_input)
        return model
    elif model_name == 'GeoAware_mixer_tiny':
        sd = torch.load(path_model_weights, map_location=device)
        mixer_kwargs = get_mixer_kwargs(chw=(input_channels, input_size, input_size), output_dim=output_channels, mixer_size='mixer_tiny')
        model = MixerGeoPretrained(output_dim=output_channels, checkpoint=sd, mixer_kwargs=mixer_kwargs, freeze_body=freeze)
        model(test_input)
        return model
    elif model_name == 'SatMAE':
        sd = torch.load(path_model_weights, map_location=device)
        satmae_kwargs = get_core_decoder_kwargs(output_dim=output_channels, core_size='core_nano')
        # SatMAE has a fixed 96x96 / patch-8 input geometry.
        return satmae_vit_cnn(img_size=96, patch_size=8, in_chans=input_channels,
                              checkpoint=sd, freeze_body=freeze, classifier=False, **satmae_kwargs)
    elif model_name == 'SatMAE_classifier':
        sd = torch.load(path_model_weights, map_location=device)
        satmae_kwargs = get_core_decoder_kwargs(output_dim=output_channels, core_size='core_nano')
        return satmae_vit_cnn(img_size=96, patch_size=8, in_chans=input_channels,
                              checkpoint=sd, freeze_body=freeze, classifier=True, **satmae_kwargs)
    elif model_name == 'prithvi':
        sd = torch.load(path_model_weights, map_location=device)
        prithvi_kwargs = get_core_decoder_kwargs(output_dim=output_channels, core_size='core_nano')
        return prithvi(checkpoint=sd, freeze_body=freeze, **prithvi_kwargs)
    elif model_name == 'prithvi_classifier':
        sd = torch.load(path_model_weights, map_location=device)
        prithvi_kwargs = get_core_decoder_kwargs(output_dim=output_channels, core_size='core_nano')
        return prithvi(checkpoint=sd, freeze_body=freeze, classifier=True, **prithvi_kwargs)
    elif model_name == 'vit_cnn':
        sd = torch.load(path_model_weights, map_location=device)
        vit_kwargs = get_core_decoder_kwargs(output_dim=output_channels, core_size='core_nano')
        return vit_cnn(checkpoint=sd, freeze_body=freeze, **vit_kwargs)
    elif model_name == 'vit_cnn_wSkip':
        sd = torch.load(path_model_weights, map_location=device)
        vit_kwargs = get_core_decoder_kwargs(output_dim=output_channels, core_size='core_nano')
        return vit_cnn_wSkip(checkpoint=sd, freeze_body=freeze, **vit_kwargs)
    elif model_name == 'vit_cnn_classifier':
        sd = torch.load(path_model_weights, map_location=device)
        return vit_cnn_classifier(checkpoint=sd, freeze_body=freeze, output_dim=output_channels)
    elif model_name == 'vit_cnn_gc':
        sd = torch.load(path_model_weights, map_location=device)
        vit_kwargs = get_core_decoder_kwargs(output_dim=output_channels, core_size='core_nano')
        return vit_cnn_gc(checkpoint=sd, freeze_body=freeze, **vit_kwargs)
    elif model_name == 'vit_cnn_gc_wSkip':
        sd = torch.load(path_model_weights, map_location=device)
        vit_kwargs = get_core_decoder_kwargs(output_dim=output_channels, core_size='core_nano')
        return vit_cnn_gc_wSkip(checkpoint=sd, freeze_body=freeze, **vit_kwargs)
    elif model_name == 'vit_cnn_gc_classifier':
        sd = torch.load(path_model_weights, map_location=device)
        return vit_cnn_gc_classifier(checkpoint=sd, freeze_body=freeze, output_dim=output_channels)
    elif model_name == 'seasonal_contrast':
        # seasonal_contrast loads its own checkpoint from the path.
        seco_kwargs = get_core_decoder_kwargs(output_dim=output_channels, core_size='core_nano')
        return seasonal_contrast(checkpoint=path_model_weights, freeze_body=freeze,
                                 **seco_kwargs)
    elif model_name == 'seasonal_contrast_classifier':
        seco_kwargs = get_core_decoder_kwargs(output_dim=output_channels, core_size='core_nano')
        return seasonal_contrast(checkpoint=path_model_weights, freeze_body=freeze, classifier=True,
                                 **seco_kwargs)

    raise ValueError(f"Unknown pretrained model name: {model_name}")
def get_args():
    """Build the CLI parsers.

    Returns
    -------
    (argparse.ArgumentParser, argparse.ArgumentParser):
        `parser` with all experiment options, and `parser_yaml` which only
        recognises `--read_yaml` (a YAML config that overrides everything).
    """
    parser_yaml = argparse.ArgumentParser(description='Experiment TestBed for Phi-Leo Foundation Model Project')
    parser_yaml.add_argument('--read_yaml', type=str, help='take parameters from yaml path', default=None)

    parser = argparse.ArgumentParser(description='Experiment TestBed for Phi-Leo Foundation Model Project')
    parser.add_argument('--experiment_name', type=str, default=f'{date.today().strftime("%d%m%Y")}_experiment',
                        help='Experiment folder name')
    parser.add_argument('--model_name', type=str, choices=MODEL_LIST, required=True,
                        help='Select appropriate model')
    parser.add_argument('--lr', type=float, default=0.001, help='Set learning rate')
    parser.add_argument('--batch_size', type=int, default=16, help='Set batch size')
    parser.add_argument('--epochs', type=int, default=250, help='Set training epochs')
    parser.add_argument('--early_stop', type=int, default=50, help='set training loop patience for early stopping')
    parser.add_argument('--lr_scheduler', type=str, default=None,
                        choices=[None, 'reduce_on_plateau', 'cosine_annealing'], help='select learning rate scheduler')
    parser.add_argument('--warmup', action="store_true", help='Enables linear 5 epoch warmup scheduler')
    parser.add_argument('--model_device', default=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
                        help='select training device')
    parser.add_argument('--generator_device', default=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
                        help='select training device')
    parser.add_argument('--num_workers', type=int, default=0, help='set number of workers')
    parser.add_argument('--vis_val', action="store_true", help='enable saving of intermediate visualization plots')
    parser.add_argument('--downstream_task', type=str, choices=DOWNSTREAM_LIST, required=True,
                        help='select downstream task')
    parser.add_argument('--input_channels', type=int, required=False, default=10, help='Define Number of input channels')
    parser.add_argument('--input_size', type=int, required=True, default=128, help='Define input size')
    parser.add_argument('--output_channels', type=int, required=True, default=1, help='Define Number of output channels')
    # BUG FIX: was `type=list`, which argparse applies to the raw argument
    # string — splitting it into single characters, so `choices` could never
    # match. `nargs='+'` collects one or more region names instead.
    parser.add_argument('--regions', type=str, nargs='+', default=None, help='select regions to be included',
                        choices=['denmark-1', 'denmark-2', 'east-africa', 'egypt-1', 'eq-guinea', 'europe', 'ghana-1',
                                 'isreal-1', 'isreal-2', 'japan', 'nigeria', 'north-america', 'senegal', 'south-america',
                                 'tanzania-1', 'tanzania-2', 'tanzania-3', 'tanzania-4', 'tanzania-5', 'uganda-1'])
    parser.add_argument('--n_shot', type=int, default=None,
                        help='Loads n-samples of data from specified geographic regions')
    parser.add_argument('--split_ratio', type=float, default=None,
                        help='Loads a percentage of the data from specified geographic regions.')
    parser.add_argument('--augmentations', action="store_true", help='enables augmentations')
    parser.add_argument('--pretrained_model_path', type=str, default=None, help='path to pretrained weights')
    parser.add_argument('--freeze_pretrained', action="store_true", help='freeze pretrained model weights')
    parser.add_argument('--data_path_128_10m', type=str, default='/home/phimultigpu/phileo_NFS/phileo_data/downstream/downstream_dataset_patches_np/')
    parser.add_argument('--data_path_224_10m', type=str, default='/home/phimultigpu/phileo_NFS/phileo_data/downstream/downstream_dataset_patches_np_224/')
    parser.add_argument('--data_path_224_30m', type=str, default='/home/phimultigpu/phileo_NFS/phileo_data/downstream/downstream_dataset_patches_np_HLS/')
    # BUG FIX: renamed from '--C'. main() has no 'C' parameter (it expects
    # 'output_path'), so `main(**vars(args))` raised TypeError on the CLI path.
    parser.add_argument('--output_path', type=str, default='/home/phimultigpu/phileo_NFS/phileo_data/experiments')
    # BUG FIX: was `type=bool`, which is True for ANY non-empty string
    # (including the literal "False"); a store_true flag is unambiguous.
    parser.add_argument('--data_parallel', action="store_true")
    # BUG FIX: was `type=list` (same character-splitting problem as --regions).
    parser.add_argument('--device_ids', type=int, nargs='+', default=[0, 1, 2, 3])
    # NOTE: 'warmp_steps' (sic) is kept — main() expects this misspelled kwarg.
    parser.add_argument('--warmp_steps', type=int, default=5)
    parser.add_argument('--warmup_gamma', type=int, default=10)
    return parser, parser_yaml
def main(experiment_name:str, downstream_task:str, model_name:str, augmentations:bool=False, batch_size:int=16,
         model_device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), generator_device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), num_workers:int=4,
         early_stop:int=25, epochs:int=250, input_channels:int=10, output_channels:int=1, input_size:int=128, lr:float=0.001, lr_scheduler:str=None,
         n_shot:int=None, split_ratio:float=0.1, regions:list=None, vis_val:bool=True, warmup:bool=False, warmp_steps:int=5, warmup_gamma:int=10, pretrained_model_path:str=None, freeze_pretrained:bool=None,
         data_path_128_10m:str=None, data_path_224_10m:str=None, data_path_224_30m:str=None, output_path:str=None, data_parallel:bool=False, device_ids:list=None):
    """ main script for PhilEO Bench. Used to run model training experiments with randomly initialized and pre-trained models on a number of downstream tasks.
    The script handles dataset creation (based on data protocol options selected), data preprocessing (based on downstream task & model type) & model, training, validation and testing.
    Parameters
    ----------
    experiment_name (str): Experiment name
    downstream_task (str): Select downstream task to test, validate and test on. Options: {DOWNSTREAM_LIST}
    model_name (str): Select model. Options:{MODEL_LIST}
    augmentations (bool, optional): Toggle on/off basic data augmentations (Rotation, Mirror, Noise). Defaults to False.
    batch_size (int, optional): Define training batch size. Defaults to 16.
    model_device (_type_, optional): Select model device. Defaults to torch.device('cuda' if torch.cuda.is_available() else 'cpu').
    generator_device (_type_, optional): Select dataloader device. Defaults to torch.device('cuda' if torch.cuda.is_available() else 'cpu').
    num_workers (int, optional): Select number of workers for dataloader. Defaults to 4.
    early_stop (int, optional): Define early stopping patience. Defaults to 25.
    epochs (int, optional): Define number of training epochs. Defaults to 250.
    input_channels (int, optional): Define number of data input channels. Defaults to 10.
    output_channels (int, optional): Define number of model output channels. Defaults to 1.
    input_size (int, optional): Define data input size. Defaults to 128.
    lr (float, optional): Define optimizer learning rate. Defaults to 0.001.
    lr_scheduler (str, optional): Define learning rate scheduler. Options: [None, 'reduce_on_plateau', 'cosine_annealing']. Defaults to None.
    n_shot (int, optional): Define dataset protocol - n samples per region. Defaults to None.
    split_ratio (float, optional): Define dataset protocol - percentage of full dataset. Defaults to 0.1.
    regions (list, optional): Select regions to include in training and test sets. If no regions are defined (None) all available regions will be included.
    vis_val (bool, optional): If set to True data visualisations will be generated at each validation step. Defaults to True.
    warmup (bool, optional): If set to True a linear optimizer warmup phase will occur. Defaults to False.
    warmp_steps (int, optional): Define number of steps for linear warmup phase. Defaults to 5.
    warmup_gamma (int, optional): Define learning rate increase per step in linear warmup phase - new_lr = lr*gamma. Defaults to 10. N.B. initial lr is calculated as follows init_lr = lr/(gamma**warmup_steps)
    pretrained_model_path (str, optional): For pretrained models define the model weights path. Defaults to None.
    freeze_pretrained (bool, optional): If True pretrained encoder weights will be frozen during training. Defaults to None.
    data_path_128_10m (str, optional): Define data path for 128x128 10m resolution dataset. Defaults to None.
    data_path_224_10m (str, optional): Define data path for 224x224 10m resolution dataset. Defaults to None.
    data_path_224_30m (str, optional): Define data path for 224x224 30m resolution dataset. Defaults to None.
    output_path (str, optional): Define folder to save artifacts in. Defaults to None.
    data_parallel (bool, optional): If set to True model training will be parallelized on multiple GPUs. Defaults to False.
    device_ids (list, optional): Define GPU IDs to use for parallelization. Defaults to None.
    """
    init_lr = lr  # remember the requested lr for logging; `lr` itself may be rescaled for warmup
    torch.set_default_device(model_device)
    print('DEVICE', model_device)

    # ---- configuration sanity checks (kept as asserts like the original;
    # note asserts are stripped under `python -O`) ----
    assert not (n_shot == None) or not (split_ratio == None), 'Please define data partition protocol!'
    assert isinstance(n_shot, int) ^ isinstance(split_ratio, float), 'n_shot cannot be used with split_ratio!'
    if (downstream_task == 'lc') or (downstream_task == 'lc_classification'):
        assert (output_channels == 11), 'land cover tasks should have 11 output channels'
    if (downstream_task == 'roads') or (downstream_task == 'building'):
        assert output_channels == 1, 'road and building density estimation tasks should have a single output channel'
    if downstream_task == 'building_classification':
        assert output_channels == 5, 'building classification tasks should have 5 output channels'
    if downstream_task == 'roads_classification':
        # BUG FIX: the original message claimed 5 channels while the check requires 2.
        assert output_channels == 2, 'road classification tasks should have 2 output channels'

    # ---- model construction ----
    if pretrained_model_path is not None:
        print(model_name)
        assert model_name in (CNN_PRETRAINED_LIST + VIT_CNN_PRETRAINED_LIST), f"Pretrained weights were given but model {model_name} not found in list of pretrained models: {(CNN_PRETRAINED_LIST + VIT_CNN_PRETRAINED_LIST)}"
        assert freeze_pretrained is not None, f"When supplying a pretrained model 'freeze_pretrained' must be either True or False"
        # BUG FIX: forward the selected device so checkpoints are loaded with
        # the right map_location (previously defaulted to 'cuda' even on CPU hosts).
        model = get_models_pretrained(model_name, input_channels, output_channels, input_size,
                                      path_model_weights=pretrained_model_path, freeze=freeze_pretrained,
                                      device=model_device)
        # Encode pretraining flavour and freeze state into the run name.
        if model_name == 'GeoAware_contrastive_core_nano' or model_name == 'GeoAware_contrastive_core_nano_classifier':
            NAME = model.__class__.__name__ + '_contrastive_frozen' if freeze_pretrained else model.__class__.__name__ + '_contrastive_unfrozen'
        elif model_name == 'GeoAware_mh_pred_core_nano' or model_name == 'GeoAware_mh_pred_core_nano_classifier':
            NAME = model.__class__.__name__ + '_mh_pred_frozen' if freeze_pretrained else model.__class__.__name__ + '_mh_pred_unfrozen'
        else:
            NAME = model.__class__.__name__ + '_frozen' if freeze_pretrained else model.__class__.__name__ + '_unfrozen'
    else:
        if freeze_pretrained:
            print(f"Ignoring freeze_pretrained set to {freeze_pretrained} as no pretrained model was supplied")
        model = get_models(model_name, input_channels, output_channels, input_size)
        NAME = model.__class__.__name__

    # ---- experiment output folder (suffixed with scheduler / protocol details below) ----
    OUTPUT_FOLDER = f'{output_path}/{experiment_name}/{downstream_task}/{date.today().strftime("%d%m%Y")}_{NAME}_{downstream_task}'
    if lr_scheduler is not None:
        OUTPUT_FOLDER = f'{output_path}/{experiment_name}/{downstream_task}/{date.today().strftime("%d%m%Y")}_{NAME}_{downstream_task}_{lr_scheduler}'
    if warmup:
        # Warmup starts from lr / gamma**steps and multiplies by gamma each
        # step, reaching the requested lr after `warmp_steps` steps.
        lr = lr / int((warmup_gamma)**(warmp_steps))  # for warmup start

    # ---- dataset selection: pick the patch size / resolution each model family expects ----
    dataset_folder = data_path_128_10m
    dataset_name = '128_10m'
    if model_name in MODELS_224_r30:
        dataset_folder = data_path_224_30m
        dataset_name = '224_30m'
    if model_name in MODELS_224:
        dataset_folder = data_path_224_10m
        dataset_name = '224_10m'

    # ---- data protocol: pretraining / n-shot / percentage split ----
    if downstream_task == 'pretraining':
        OUTPUT_FOLDER = f'{OUTPUT_FOLDER}'
        x_train, y_train, x_val, y_val = data_protocol.protocol_minifoundation(
            folder='/home/phimultigpu/phileo_NFS/phileo_data/mini_foundation/mini_foundation_patches_np/patches_labeled/',
            y='geo')
        downstream_task = 'geo'
    elif isinstance(n_shot, int):
        OUTPUT_FOLDER = f'{OUTPUT_FOLDER}_{n_shot}'
        x_train, y_train, x_val, y_val = data_protocol.protocol_fewshot_memmapped(
            folder=dataset_folder,
            dst='/home/phimultigpu/phileo_NFS/phileo_data/downstream/downstream_datasets_nshot/',
            n=n_shot,
            regions=regions,
            y=downstream_task,
            data_selection='create',
            name=dataset_name)
    elif isinstance(split_ratio, float):
        OUTPUT_FOLDER = f'{OUTPUT_FOLDER}_{split_ratio}'
        x_train, y_train, x_val, y_val = data_protocol.protocol_split(
            dataset_folder,
            split_percentage=split_ratio,
            regions=regions,
            y=downstream_task)

    x_test, y_test = data_protocol.get_testset(folder=dataset_folder,
                                               y=downstream_task)

    dl_train, dl_test, dl_val = load_data.load_data(x_train, y_train, x_val, y_val, x_test, y_test,
                                                    with_augmentations=augmentations,
                                                    num_workers=num_workers,
                                                    batch_size=batch_size,
                                                    downstream_task=downstream_task,
                                                    model_name=model_name.split('_')[0],
                                                    device=generator_device
                                                    )
    print(f'Training on: {model_name}')
    print('--'*10)

    if data_parallel:
        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
            model = nn.DataParallel(model, device_ids=device_ids)
    model.to(model_device)

    # ---- model summary: some models have fixed input geometry regardless of `input_size` ----
    if model_name == 'SatMAE' or model_name == 'SatMAE_classifier':
        model_summary = summary(model,
                                input_size=(batch_size, input_channels, 96, 96), )
    elif model_name == 'prithvi' or model_name == 'prithvi_classifier':
        model_summary = summary(model,
                                input_size=(batch_size, 6, 224, 224), dtypes=[torch.float32])
    elif model_name in ['seasonal_contrast', 'resnet_imagenet', 'resnet', 'seasonal_contrast_classifier']:
        model_summary = summary(model,
                                input_size=(batch_size, input_channels, 224, 224), )
    else:
        model_summary = summary(model, input_size=(batch_size, input_channels, input_size, input_size))

    trainer = get_trainer(model_name, downstream_task, epochs, lr, model, model_device, lr_scheduler, warmup, early_stop, dl_train,
                          dl_val, dl_test, NAME, OUTPUT_FOLDER, vis_val, warmp_steps, warmup_gamma)
    trainer.train()
    trainer.test()
    trainer.save_info(model_summary=model_summary, n_shot=n_shot, p_split=split_ratio, warmup=warmup,
                      lr=init_lr)
if __name__ == "__main__":
    # Two-stage parsing: first check (without erroring on unknown flags)
    # whether a YAML config was supplied; if so, ALL other CLI parameters are
    # ignored in favour of the YAML contents.
    parser, parser_yaml = get_args()
    args_yaml, remainder = parser_yaml.parse_known_args()
    if args_yaml.read_yaml is not None:
        print(f"WARNING: overwriting all parameters with defaults stored in {args_yaml.read_yaml}")
        # NOTE(review): read_yaml must return an argparse.Namespace-like object
        # for `vars(args)` below to work — confirm against utils.training_utils.
        args = read_yaml(args_yaml.read_yaml)
    else:
        args = parser.parse_args()
    # Unpack every parsed option as a keyword argument to main(); attribute
    # names must therefore match main()'s parameter names exactly.
    main(**vars(args))