diff --git a/train.py b/train.py index aeedcc4..2d9a848 100644 --- a/train.py +++ b/train.py @@ -38,7 +38,7 @@ assert args.backbone in ['resnet50', 'resnet34', 'resnet18', 'densenet121'] dataset_name = dataset_dict[args.dataset] -model_name = '{}_nfc_softmax'.format(args.backbone) if args.use_id else '{}_nfc'.format(args.backbone) +model_name = '{}_nfc_id'.format(args.backbone) if args.use_id else '{}_nfc'.format(args.backbone) data_dir = args.data_path model_dir = os.path.join('./checkpoints', args.dataset, model_name)