-
-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
uts9
committed
May 11, 2023
1 parent
a64c5b6
commit 515592a
Showing
1 changed file
with
330 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,330 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
from __future__ import print_function, division | ||
|
||
import argparse | ||
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
from torch.optim import lr_scheduler | ||
from torch.autograd import Variable | ||
import torch.backends.cudnn as cudnn | ||
import numpy as np | ||
import torchvision | ||
from torchvision import datasets, models, transforms | ||
import time | ||
import os | ||
import scipy.io | ||
import yaml | ||
import math | ||
from model import ft_net, ft_net_dense, ft_net_hr, ft_net_swin, ft_net_swinv2, ft_net_efficient, ft_net_NAS, ft_net_convnext, PCB, PCB_test | ||
from utils import fuse_all_conv_bn | ||
#fp16 | ||
try: | ||
from apex.fp16_utils import * | ||
except ImportError: # will be 3.x series | ||
print('This is not an error. If you want to use low precision, i.e., fp16, please install the apex with cuda support (https://github.com/NVIDIA/apex) and update pytorch to 1.0') | ||
###################################################################### | ||
# Options | ||
# -------- | ||
|
||
parser = argparse.ArgumentParser(description='Test') | ||
parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0 0,1,2 0,2') | ||
parser.add_argument('--which_epoch',default='last', type=str, help='0,1,2,3...or last') | ||
parser.add_argument('--test_dir',default='../MSMT/pytorch',type=str, help='./test_data') | ||
parser.add_argument('--name', default='ft_ResNet50', type=str, help='save model path') | ||
parser.add_argument('--batchsize', default=256, type=int, help='batchsize') | ||
parser.add_argument('--linear_num', default=512, type=int, help='feature dimension: 512 or default or 0 (linear=False)') | ||
parser.add_argument('--use_dense', action='store_true', help='use densenet121' ) | ||
parser.add_argument('--use_efficient', action='store_true', help='use efficient-b4' ) | ||
parser.add_argument('--use_hr', action='store_true', help='use hr18 net' ) | ||
parser.add_argument('--PCB', action='store_true', help='use PCB' ) | ||
parser.add_argument('--multi', action='store_true', help='use multiple query' ) | ||
parser.add_argument('--fp16', action='store_true', help='use fp16.' ) | ||
parser.add_argument('--ibn', action='store_true', help='use ibn.' ) | ||
parser.add_argument('--ms',default='1', type=str,help='multiple_scale: e.g. 1 1,1.1 1,1.1,1.2') | ||
|
||
opt = parser.parse_args() | ||
###load config### | ||
# load the training config | ||
config_path = os.path.join('./model',opt.name,'opts.yaml') | ||
with open(config_path, 'r') as stream: | ||
config = yaml.load(stream, Loader=yaml.FullLoader) # for the new pyyaml via 'conda install pyyaml' | ||
opt.fp16 = config['fp16'] | ||
opt.PCB = config['PCB'] | ||
opt.use_dense = config['use_dense'] | ||
opt.use_NAS = config['use_NAS'] | ||
opt.stride = config['stride'] | ||
if 'use_swin' in config: | ||
opt.use_swin = config['use_swin'] | ||
if 'use_swinv2' in config: | ||
opt.use_swinv2 = config['use_swinv2'] | ||
if 'use_convnext' in config: | ||
opt.use_convnext = config['use_convnext'] | ||
if 'use_efficient' in config: | ||
opt.use_efficient = config['use_efficient'] | ||
if 'use_hr' in config: | ||
opt.use_hr = config['use_hr'] | ||
|
||
if 'nclasses' in config: # tp compatible with old config files | ||
opt.nclasses = config['nclasses'] | ||
else: | ||
opt.nclasses = 751 | ||
|
||
if 'ibn' in config: | ||
opt.ibn = config['ibn'] | ||
if 'linear_num' in config: | ||
opt.linear_num = config['linear_num'] | ||
|
||
str_ids = opt.gpu_ids.split(',') | ||
#which_epoch = opt.which_epoch | ||
name = opt.name | ||
test_dir = opt.test_dir | ||
|
||
gpu_ids = [] | ||
for str_id in str_ids: | ||
id = int(str_id) | ||
if id >=0: | ||
gpu_ids.append(id) | ||
|
||
print('We use the scale: %s'%opt.ms) | ||
str_ms = opt.ms.split(',') | ||
ms = [] | ||
for s in str_ms: | ||
s_f = float(s) | ||
ms.append(math.sqrt(s_f)) | ||
|
||
# set gpu ids | ||
if len(gpu_ids)>0: | ||
torch.cuda.set_device(gpu_ids[0]) | ||
cudnn.benchmark = True | ||
|
||
###################################################################### | ||
# Load Data | ||
# --------- | ||
# | ||
# We will use torchvision and torch.utils.data packages for loading the | ||
# data. | ||
# | ||
if opt.use_swin: | ||
h, w = 224, 224 | ||
else: | ||
h, w = 256, 128 | ||
|
||
data_transforms = transforms.Compose([ | ||
transforms.Resize((h, w), interpolation=3), | ||
transforms.ToTensor(), | ||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | ||
############### Ten Crop | ||
#transforms.TenCrop(224), | ||
#transforms.Lambda(lambda crops: torch.stack( | ||
# [transforms.ToTensor()(crop) | ||
# for crop in crops] | ||
# )), | ||
#transforms.Lambda(lambda crops: torch.stack( | ||
# [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(crop) | ||
# for crop in crops] | ||
# )) | ||
]) | ||
|
||
if opt.PCB: | ||
data_transforms = transforms.Compose([ | ||
transforms.Resize((384,192), interpolation=3), | ||
transforms.ToTensor(), | ||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | ||
]) | ||
h, w = 384, 192 | ||
|
||
|
||
data_dir = test_dir | ||
|
||
if opt.multi: | ||
image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query','multi-query']} | ||
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize, | ||
shuffle=False, num_workers=16) for x in ['gallery','query','multi-query']} | ||
else: | ||
image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query']} | ||
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize, | ||
shuffle=False, num_workers=16) for x in ['gallery','query']} | ||
class_names = image_datasets['query'].classes | ||
use_gpu = torch.cuda.is_available() | ||
|
||
###################################################################### | ||
# Load model | ||
#--------------------------- | ||
def load_network(network): | ||
save_path = os.path.join('./model',name,'net_%s.pth'%opt.which_epoch) | ||
network.load_state_dict(torch.load(save_path)) | ||
return network | ||
|
||
|
||
###################################################################### | ||
# Extract feature | ||
# ---------------------- | ||
# | ||
# Extract feature from a trained model. | ||
# | ||
def fliplr(img): | ||
'''flip horizontal''' | ||
inv_idx = torch.arange(img.size(3)-1,-1,-1).long() # N x C x H x W | ||
img_flip = img.index_select(3,inv_idx) | ||
return img_flip | ||
|
||
def extract_feature(model,dataloaders): | ||
#features = torch.FloatTensor() | ||
count = 0 | ||
if opt.linear_num <= 0: | ||
if opt.use_swin or opt.use_swinv2 or opt.use_dense or opt.use_convnext: | ||
opt.linear_num = 1024 | ||
elif opt.use_efficient: | ||
opt.linear_num = 1792 | ||
elif opt.use_NAS: | ||
opt.linear_num = 4032 | ||
else: | ||
opt.linear_num = 2048 | ||
|
||
for iter, data in enumerate(dataloaders): | ||
img, label = data | ||
n, c, h, w = img.size() | ||
count += n | ||
print(count) | ||
ff = torch.FloatTensor(n,opt.linear_num).zero_().cuda() | ||
|
||
if opt.PCB: | ||
ff = torch.FloatTensor(n,2048,6).zero_().cuda() # we have six parts | ||
|
||
for i in range(2): | ||
if(i==1): | ||
img = fliplr(img) | ||
input_img = Variable(img.cuda()) | ||
for scale in ms: | ||
if scale != 1: | ||
# bicubic is only available in pytorch>= 1.1 | ||
input_img = nn.functional.interpolate(input_img, scale_factor=scale, mode='bicubic', align_corners=False) | ||
outputs = model(input_img) | ||
ff += outputs | ||
# norm feature | ||
if opt.PCB: | ||
# feature size (n,2048,6) | ||
# 1. To treat every part equally, I calculate the norm for every 2048-dim part feature. | ||
# 2. To keep the cosine score==1, sqrt(6) is added to norm the whole feature (2048*6). | ||
fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * np.sqrt(6) | ||
ff = ff.div(fnorm.expand_as(ff)) | ||
ff = ff.view(ff.size(0), -1) | ||
else: | ||
fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) | ||
ff = ff.div(fnorm.expand_as(ff)) | ||
|
||
|
||
if iter == 0: | ||
features = torch.FloatTensor( len(dataloaders.dataset), ff.shape[1]) | ||
#features = torch.cat((features,ff.data.cpu()), 0) | ||
start = iter*opt.batchsize | ||
end = min( (iter+1)*opt.batchsize, len(dataloaders.dataset)) | ||
features[ start:end, :] = ff | ||
return features | ||
|
||
def get_id(img_path): | ||
camera_id = [] | ||
labels = [] | ||
for path, v in img_path: | ||
#filename = path.split('/')[-1] | ||
filename = os.path.basename(path) | ||
label = filename[0:4] | ||
camera = filename.split('_')[2][0:2] | ||
if label[0:2]=='-1': | ||
labels.append(-1) | ||
else: | ||
labels.append(int(label)) | ||
camera_id.append(int(camera)) | ||
return camera_id, labels | ||
|
||
gallery_path = image_datasets['gallery'].imgs | ||
query_path = image_datasets['query'].imgs | ||
|
||
gallery_cam,gallery_label = get_id(gallery_path) | ||
query_cam,query_label = get_id(query_path) | ||
|
||
if opt.multi: | ||
mquery_path = image_datasets['multi-query'].imgs | ||
mquery_cam,mquery_label = get_id(mquery_path) | ||
|
||
###################################################################### | ||
# Load Collected data Trained model | ||
print('-------test-----------') | ||
if opt.use_dense: | ||
model_structure = ft_net_dense(opt.nclasses, stride = opt.stride, linear_num=opt.linear_num) | ||
elif opt.use_NAS: | ||
model_structure = ft_net_NAS(opt.nclasses, linear_num=opt.linear_num) | ||
elif opt.use_swin: | ||
model_structure = ft_net_swin(opt.nclasses, linear_num=opt.linear_num) | ||
elif opt.use_swinv2: | ||
model_structure = ft_net_swinv2(opt.nclasses, (h,w), linear_num=opt.linear_num) | ||
elif opt.use_convnext: | ||
model_structure = ft_net_convnext(opt.nclasses, linear_num=opt.linear_num) | ||
elif opt.use_efficient: | ||
model_structure = ft_net_efficient(opt.nclasses, linear_num=opt.linear_num) | ||
elif opt.use_hr: | ||
model_structure = ft_net_hr(opt.nclasses, linear_num=opt.linear_num) | ||
else: | ||
model_structure = ft_net(opt.nclasses, stride = opt.stride, ibn = opt.ibn, linear_num=opt.linear_num) | ||
|
||
if opt.PCB: | ||
model_structure = PCB(opt.nclasses) | ||
|
||
#if opt.fp16: | ||
# model_structure = network_to_half(model_structure) | ||
|
||
model = load_network(model_structure) | ||
|
||
# Remove the final fc layer and classifier layer | ||
if opt.PCB: | ||
#if opt.fp16: | ||
# model = PCB_test(model[1]) | ||
#else: | ||
model = PCB_test(model) | ||
else: | ||
#if opt.fp16: | ||
#model[1].model.fc = nn.Sequential() | ||
#model[1].classifier = nn.Sequential() | ||
#else: | ||
model.classifier.classifier = nn.Sequential() | ||
|
||
# Change to test mode | ||
model = model.eval() | ||
if use_gpu: | ||
model = model.cuda() | ||
|
||
|
||
print('Here I fuse conv and bn for faster inference, and it does not work for transformers. Comment out this following line if you do not want to fuse conv&bn.') | ||
model = fuse_all_conv_bn(model) | ||
|
||
# We can optionally trace the forward method with PyTorch JIT so it runs faster. | ||
# To do so, we can call `.trace` on the reparamtrized module with dummy inputs | ||
# expected by the module. | ||
# Comment out this following line if you do not want to trace. | ||
dummy_forward_input = torch.rand(opt.batchsize, 3, h, w).cuda() | ||
model = torch.jit.trace(model, dummy_forward_input) | ||
|
||
print(model) | ||
# Extract feature | ||
since = time.time() | ||
with torch.no_grad(): | ||
gallery_feature = extract_feature(model,dataloaders['gallery']) | ||
query_feature = extract_feature(model,dataloaders['query']) | ||
if opt.multi: | ||
mquery_feature = extract_feature(model,dataloaders['multi-query']) | ||
time_elapsed = time.time() - since | ||
print('Training complete in {:.0f}m {:.2f}s'.format( | ||
time_elapsed // 60, time_elapsed % 60)) | ||
# Save to Matlab for check | ||
result = {'gallery_f':gallery_feature.numpy(),'gallery_label':gallery_label,'gallery_cam':gallery_cam,'query_f':query_feature.numpy(),'query_label':query_label,'query_cam':query_cam} | ||
scipy.io.savemat('pytorch_result.mat',result) | ||
|
||
print(opt.name) | ||
result = './model/%s/result.txt'%opt.name | ||
os.system('python evaluate_gpu.py | tee -a %s'%result) | ||
|
||
if opt.multi: | ||
result = {'mquery_f':mquery_feature.numpy(),'mquery_label':mquery_label,'mquery_cam':mquery_cam} | ||
scipy.io.savemat('multi_query.mat',result) |