-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathextract_features.py
74 lines (58 loc) · 2.33 KB
/
extract_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import cv2
import numpy as np
import os
from models.torchmodel.model_irse import IR_50
import matplotlib.pyplot as plt
def l2_norm(input, axis = 1):
norm = torch.norm(input, 2, axis, True)
output = torch.div(input, norm)
return output
def _preprocess(crop):
    """Convert a 112x112 RGB HWC uint8 crop to a normalized NCHW float tensor.

    Transposes HWC -> CHW, adds a batch dimension, and scales pixel values
    to roughly [-1, 1] via (x - 127.5) / 128.0 (the ArcFace convention).
    """
    chw = crop.swapaxes(1, 2).swapaxes(0, 1)        # HWC -> CHW
    batch = np.reshape(chw, [1, 3, 112, 112])       # add batch dim of 1
    batch = np.array(batch, dtype=np.float32)       # also copies, fixing the
                                                    # negative strides left by
                                                    # the BGR->RGB slice
    batch = (batch - 127.5) / 128.0
    return torch.from_numpy(batch)

def extract_feature(img_root, backbone = None, model_root='models/torchmodel/backbone_ir50_asia.pth', device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), tta = True):
    """Extract an L2-normalized face embedding for a single image.

    Args:
        img_root: path to the input image file.
        backbone: optional pre-built IR_50 model; when None, an IR_50((112, 112))
            is constructed here. (The previous default built the model at import
            time and shared one mutated instance across all calls.)
        model_root: path to the backbone checkpoint to load.
        device: torch device to run inference on (CUDA when available).
        tta: when True, sum the embeddings of the image and its horizontal
            mirror (test-time augmentation) before normalizing.

    Returns:
        A CPU float tensor of shape (1, D) with unit L2 norm, where D is the
        backbone's embedding size.

    Raises:
        FileNotFoundError: if the image or checkpoint path does not exist.
        IOError: if the image exists but cannot be decoded.
    """
    # pre-requisites — real exceptions, not asserts (asserts vanish under -O)
    if not os.path.exists(img_root):
        raise FileNotFoundError("image not found: {}".format(img_root))
    print('Testing Data Root:', img_root)
    if not os.path.exists(model_root):
        raise FileNotFoundError("checkpoint not found: {}".format(model_root))
    print('Backbone Model Root:', model_root)
    # build the backbone lazily instead of in the default argument
    if backbone is None:
        backbone = IR_50((112, 112))
    # load image — cv2.imread signals failure by returning None, not raising
    img = cv2.imread(img_root)
    if img is None:
        raise IOError("could not decode image: {}".format(img_root))
    # resize to 128x128, then center-crop the 112x112 region the model expects
    resized = cv2.resize(img, (128, 128))
    margin = (128 - 112) // 2
    ccropped = resized[margin:margin + 112, margin:margin + 112]
    ccropped = ccropped[..., ::-1]          # OpenCV loads BGR; model wants RGB
    # horizontal mirror for test-time augmentation
    flipped = cv2.flip(ccropped, 1)
    # numpy -> normalized NCHW tensors
    ccropped = _preprocess(ccropped)
    flipped = _preprocess(flipped)
    # load backbone from a checkpoint; map_location lets a GPU-saved
    # checkpoint load on a CPU-only host (matching the device fallback above)
    print("Loading Backbone Checkpoint '{}'".format(model_root))
    backbone.load_state_dict(torch.load(model_root, map_location=device))
    backbone.to(device)
    # extract features
    backbone.eval() # evaluation mode: freezes batch-norm / dropout behavior
    with torch.no_grad():
        if tta:
            # sum the embeddings of the image and its mirror, then normalize
            emb_batch = backbone(ccropped.to(device)).cpu() + backbone(flipped.to(device)).cpu()
            features = l2_norm(emb_batch)
        else:
            features = l2_norm(backbone(ccropped.to(device)).cpu())
    return features
if __name__=='__main__':
    # Smoke test: embed one sample image and report the embedding's shape.
    sample_path = 'Dataset/processed/13521/280820-0012.jpg'
    features = extract_feature(img_root=sample_path)
    print(features.shape)