add multiple gpu for basic branch
layumi committed May 11, 2023
1 parent f22dc9e commit 99dfb15
Showing 1 changed file with 52 additions and 16 deletions.
train.py: 68 changes (52 additions, 16 deletions)
@@ -92,10 +92,11 @@
     gid = int(str_id)
     if gid >=0:
         gpu_ids.append(gid)
 
 opt.gpu_ids = gpu_ids
 # set gpu ids
 if len(gpu_ids)>0:
-    torch.cuda.set_device(gpu_ids[0])
+    #torch.cuda.set_device(gpu_ids[0])
+    cudnn.enabled = True
     cudnn.benchmark = True
 ######################################################################
 # Load Data
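For context, a minimal self-contained sketch of the logic around this hunk (not part of the commit): it assumes the script's comma-separated --gpu_ids option and shows the change in spirit, i.e. the single-device pin is dropped once several ids are parsed, while cuDNN autotuning stays on.

# Sketch only: parse --gpu_ids "0,1" into a list of ints and enable cuDNN autotuning.
import argparse
import torch
import torch.backends.cudnn as cudnn

parser = argparse.ArgumentParser()
parser.add_argument('--gpu_ids', default='0', type=str, help='e.g. "0" or "0,1"')
opt = parser.parse_args()

gpu_ids = [int(s) for s in opt.gpu_ids.split(',') if int(s) >= 0]
opt.gpu_ids = gpu_ids

if len(gpu_ids) > 0 and torch.cuda.is_available():
    # With several ids the device is no longer pinned here; a DataParallel
    # wrapper (as in the optimizer hunk further down) handles placement instead.
    cudnn.enabled = True
    cudnn.benchmark = True  # let cuDNN pick the fastest kernels for fixed input sizes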
@@ -390,10 +391,13 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
             y_loss[phase].append(epoch_loss)
             y_err[phase].append(1.0-epoch_acc)
             # deep copy the model
-            if phase == 'val':
+            if phase == 'val' and epoch%10 == 9:
                 last_model_wts = model.state_dict()
-                if epoch%10 == 9:
-                    save_network(model, epoch)
+                if len(opt.gpu_ids)>1:
+                    save_network(model.module, epoch+1)
+                else:
+                    save_network(model, epoch+1)
+
                 draw_curve(epoch)
             if phase == 'train':
                 scheduler.step()
@@ -409,7 +413,11 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
 
     # load best model weights
     model.load_state_dict(last_model_wts)
-    save_network(model, 'last')
+    if len(opt.gpu_ids)>1:
+        save_network(model.module, 'last')
+    else:
+        save_network(model, 'last')
+
     return model
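The save calls above follow the usual DataParallel convention: checkpoint the wrapped network's .module so the saved state_dict keys carry no "module." prefix and load cleanly on a single GPU. Below is a minimal sketch of that pattern; save_network here is a simplified stand-in with an assumed net_<label>.pth naming scheme, not the repo's actual helper.

# Sketch only: unwrap a DataParallel model before saving its weights.
import os
import torch
import torch.nn as nn

def save_network(network, epoch_label, save_dir='.'):
    # assumed layout: one file per checkpoint, named net_<label>.pth
    save_path = os.path.join(save_dir, 'net_%s.pth' % epoch_label)
    torch.save(network.cpu().state_dict(), save_path)
    if torch.cuda.is_available():
        network.cuda()

model = nn.Linear(8, 2)  # placeholder network
gpu_ids = [0, 1] if torch.cuda.device_count() > 1 else [0]
if len(gpu_ids) > 1:
    model = nn.DataParallel(model, device_ids=gpu_ids).cuda()

if len(gpu_ids) > 1:
    save_network(model.module, 'last')  # unwrap before saving
else:
    save_network(model, 'last')

A checkpoint written this way can later be loaded into a plain, unwrapped model with model.load_state_dict(torch.load(path)).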


@@ -482,17 +490,45 @@ def save_network(network, epoch_label):
 if opt.FSGD: # apex is needed
     optim_name = FusedSGD
 
-if not opt.PCB:
-    ignored_params = list(map(id, model.classifier.parameters() ))
-    base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
-    classifier_params = model.classifier.parameters()
-    optimizer_ft = optim_name([
+if len(opt.gpu_ids)>1:
+    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids).cuda()
+    if not opt.PCB:
+        ignored_params = list(map(id, model.module.classifier.parameters() ))
+        base_params = filter(lambda p: id(p) not in ignored_params, model.module.parameters())
+        classifier_params = model.module.classifier.parameters()
+        optimizer_ft = optim_name([
              {'params': base_params, 'lr': 0.1*opt.lr},
              {'params': classifier_params, 'lr': opt.lr}
          ], weight_decay=opt.weight_decay, momentum=0.9, nesterov=True)
+    else:
+        ignored_params = list(map(id, model.module.model.fc.parameters() ))
+        ignored_params += (list(map(id, model.module.classifier0.parameters() ))
+        +list(map(id, model.module.classifier1.parameters() ))
+        +list(map(id, model.module.classifier2.parameters() ))
+        +list(map(id, model.module.classifier3.parameters() ))
+        +list(map(id, model.module.classifier4.parameters() ))
+        +list(map(id, model.module.classifier5.parameters() ))
+        #+list(map(id, model.module.classifier6.parameters() ))
+        #+list(map(id, model.module.classifier7.parameters() ))
+        )
+        base_params = filter(lambda p: id(p) not in ignored_params, model.module.parameters())
+        classifier_params = filter(lambda p: id(p) in ignored_params, model.module.parameters())
+        optimizer_ft = optim_name([
+             {'params': base_params, 'lr': 0.1*opt.lr},
+             {'params': classifier_params, 'lr': opt.lr}
+         ], weight_decay=opt.weight_decay, momentum=0.9, nesterov=True)
 else:
-    ignored_params = list(map(id, model.model.fc.parameters() ))
-    ignored_params += (list(map(id, model.classifier0.parameters() ))
+    if not opt.PCB:
+        ignored_params = list(map(id, model.classifier.parameters() ))
+        base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
+        classifier_params = model.classifier.parameters()
+        optimizer_ft = optim_name([
+             {'params': base_params, 'lr': 0.1*opt.lr},
+             {'params': classifier_params, 'lr': opt.lr}
+         ], weight_decay=opt.weight_decay, momentum=0.9, nesterov=True)
+    else:
+        ignored_params = list(map(id, model.model.fc.parameters() ))
+        ignored_params += (list(map(id, model.classifier0.parameters() ))
         +list(map(id, model.classifier1.parameters() ))
         +list(map(id, model.classifier2.parameters() ))
         +list(map(id, model.classifier3.parameters() ))
@@ -501,9 +537,9 @@ def save_network(network, epoch_label):
         #+list(map(id, model.classifier6.parameters() ))
         #+list(map(id, model.classifier7.parameters() ))
         )
-    base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
-    classifier_params = filter(lambda p: id(p) in ignored_params, model.parameters())
-    optimizer_ft = optim_name([
+        base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
+        classifier_params = filter(lambda p: id(p) in ignored_params, model.parameters())
+        optimizer_ft = optim_name([
              {'params': base_params, 'lr': 0.1*opt.lr},
              {'params': classifier_params, 'lr': opt.lr}
          ], weight_decay=opt.weight_decay, momentum=0.9, nesterov=True)
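Taken together, the optimizer hunks wrap the model in torch.nn.DataParallel when more than one GPU id is given and then build the parameter groups through .module, keeping the backbone at one tenth of the classifier's learning rate. Below is a minimal, self-contained sketch of that pattern; Net, its layer names, and the hyperparameter values are placeholders for illustration, not the repo's actual model or defaults.

# Sketch only: DataParallel wrapping plus a two-group SGD optimizer.
import torch
import torch.nn as nn
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(32, 16)     # stands in for the pretrained trunk
        self.classifier = nn.Linear(16, 751)  # stands in for the re-ID classifier
    def forward(self, x):
        return self.classifier(self.backbone(x))

lr, weight_decay = 0.05, 5e-4                 # placeholder values
gpu_ids = list(range(torch.cuda.device_count()))
model = Net()

if len(gpu_ids) > 1:
    model = nn.DataParallel(model, device_ids=gpu_ids).cuda()
    net = model.module      # submodules live under .module after wrapping
else:
    net = model

classifier_params = list(net.classifier.parameters())
ignored = set(map(id, classifier_params))
base_params = [p for p in net.parameters() if id(p) not in ignored]

optimizer_ft = optim.SGD([
    {'params': base_params, 'lr': 0.1 * lr},     # backbone: smaller LR
    {'params': classifier_params, 'lr': lr},     # new classifier: full LR
], weight_decay=weight_decay, momentum=0.9, nesterov=True)

The split into two parameter groups is what makes reaching through .module necessary: the classifier has to be addressed by name, and after wrapping that name only exists on the inner module.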
