From d6304b6cc787528c42251d73a961b2a891a482f4 Mon Sep 17 00:00:00 2001 From: Zhedong Zheng Date: Tue, 2 Jul 2024 15:44:35 +1200 Subject: [PATCH] Update train.py --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index ed592b9..3300271 100755 --- a/train.py +++ b/train.py @@ -221,7 +221,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25): #best_acc = 0.0 warm_up = 0.1 # We start from the 0.1*lrRate warm_iteration = round(dataset_sizes['train']/opt.batchsize)*opt.warm_epoch # first 5 epoch - embedding_size = model.classifier.linear.linear_num + embedding_size = model.classifier.linear_num if opt.arcface: criterion_arcface = losses.ArcFaceLoss(num_classes=opt.nclasses, embedding_size=embedding_size) if opt.cosface: