-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
70 lines (54 loc) · 2.08 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import time
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision.transforms import transforms
from src import u2net_full
def time_synchronized():
    """Return the current time, waiting for pending CUDA work first.

    When CUDA is available, ``torch.cuda.synchronize()`` is called before
    the clock is read so that GPU work launched earlier has finished and
    is included in any interval measured with this function. Without CUDA
    the clock is read directly.

    Returns:
        The value of ``time.time()``.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()
def main(weights_path="save_weights/model_best.pth",
         img_path="test.jpg",
         threshold=0.2,
         output_path="pred_result.png"):
    """Segment one image with U2-Net, save the masked result and show it.

    The image is read with OpenCV, converted to RGB, pre-processed and fed
    to the network. The prediction is resized back to the original image
    size, binarised with ``threshold`` and multiplied into the image so
    that pixels at or below the threshold become black.

    Args:
        weights_path: Checkpoint file passed to ``torch.load``. Either a
            state dict, or a dict holding the state dict under ``"model"``.
        img_path: Path of the image to segment.
        threshold: Prediction values strictly greater than this are kept.
        output_path: Where the masked image is written.

    Raises:
        FileNotFoundError: If ``img_path`` does not exist.
        ValueError: If OpenCV cannot decode ``img_path``.
        OSError: If the result cannot be written to ``output_path``.
    """
    # Raise explicitly instead of using ``assert``: assertions are removed
    # when Python runs with ``-O``, which would skip the check entirely.
    if not os.path.exists(img_path):
        raise FileNotFoundError(f"image file {img_path} does not exist.")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose([
        transforms.ToTensor(),
        # With an int size torchvision scales the shorter side to 320 and
        # keeps the aspect ratio.
        transforms.Resize(320),
        # NOTE(review): presumably must match the normalisation used when
        # the checkpoint was trained - confirm against the training script.
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225)),
    ])

    # cv2.imread signals failure by returning None rather than raising, so
    # check before handing the result to cvtColor.
    bgr_img = cv2.imread(img_path, flags=cv2.IMREAD_COLOR)
    if bgr_img is None:
        raise ValueError(f"image file {img_path} could not be read as an image.")
    origin_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
    h, w = origin_img.shape[:2]

    img = data_transform(origin_img)
    img = torch.unsqueeze(img, 0).to(device)  # [C, H, W] -> [1, C, H, W]

    model = u2net_full()
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    weights = torch.load(weights_path, map_location="cpu")
    state_dict = weights["model"] if "model" in weights else weights
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    with torch.no_grad():
        # Warm-up pass on a dummy input of the same size so that one-time
        # initialisation cost is not counted in the timing below.
        img_height, img_width = img.shape[-2:]
        init_img = torch.zeros((1, 3, img_height, img_width), device=device)
        model(init_img)

        t_start = time_synchronized()
        pred = model(img)
        t_end = time_synchronized()
        print("inference time: {}".format(t_end - t_start))

    pred = torch.squeeze(pred).to("cpu").numpy()  # [1, 1, H, W] -> [H, W]
    # cv2.resize takes dsize as (width, height).
    pred = cv2.resize(pred, dsize=(w, h), interpolation=cv2.INTER_LINEAR)

    # Build the 0/1 mask directly as uint8 so the product with the uint8
    # image stays uint8 and needs no later cast.
    pred_mask = (pred > threshold).astype(np.uint8)
    seg_img = origin_img * pred_mask[..., None]

    # Save before showing: plt.show() can block or fail (e.g. no display),
    # and the result should be on disk regardless. cv2.imwrite reports
    # failure through its return value, so check it.
    saved = cv2.imwrite(output_path, cv2.cvtColor(seg_img, cv2.COLOR_RGB2BGR))
    if not saved:
        raise OSError(f"failed to write result image to {output_path}.")

    plt.imshow(seg_img)
    plt.show()
# Script entry point: run the prediction only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()