-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdemo_single.py
72 lines (49 loc) · 1.8 KB
/
demo_single.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""
Visualize the segmentation of a single image with demo_single.py.
Arguments (parsed by arg_parser.demo_single; see it for the exact flag names):
saved_model (compulsory) - path of the saved model checkpoint
img_path (optional) - image to evaluate, defaults to "images/demo.png"
The input image and its colorized prediction are shown side by side;
press any key to close the demo.
"""
import torch
import torchvision.transforms as transforms
import numpy as np
import cv2
import imutils
from PIL import Image
from torch.utils.data import DataLoader
from cityscapes import CityScapes
from model import model
from arg_parser import demo_single
def to_tensor(img):
    """Turn an image into a normalized float tensor for the network.

    The image (a PIL RGB image at the call site in this file) is converted
    with ``ToTensor`` and then normalized channel-wise with the mean/std
    constants below. These match the widely used ImageNet statistics;
    presumably the network was trained with the same values — confirm.
    """
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)
    pipeline = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]
    )
    return pipeline(img)
def main(args):
    """Segment a single image and show it next to the colorized prediction.

    Args:
        args: parsed command-line namespace from ``demo_single()``; this
            function reads ``args.saved_model`` (checkpoint path) and
            ``args.img_path`` (image to segment).
    """
    # Choose the device once instead of re-querying CUDA at every step.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    scale = 1
    cropsize = [int(2048 * scale), int(1024 * scale)]
    # The dataset object is only used for its class count and its
    # label-visualization helper. It is built with an empty root and
    # demo=True (presumably so no files are scanned — confirm in cityscapes).
    ds = CityScapes("", cropsize=cropsize, mode='val', demo=True)
    n_classes = ds.n_classes

    net = model.get_network(n_classes)
    checkpoint = torch.load(args.saved_model, map_location=device)
    # strict=False tolerates key mismatches. Report them, so a wrong
    # checkpoint does not silently yield a partially initialized network.
    incompatible = net.load_state_dict(checkpoint['state_dict'], strict=False)
    missing = getattr(incompatible, 'missing_keys', [])
    unexpected = getattr(incompatible, 'unexpected_keys', [])
    if missing or unexpected:
        print(
            f"Warning: checkpoint does not match the network "
            f"({len(missing)} missing, {len(unexpected)} unexpected keys)"
        )
    net.to(device)
    net.eval()

    img = Image.open(args.img_path).convert('RGB')
    im = to_tensor(img).unsqueeze(0).to(device)  # add a batch dimension
    with torch.no_grad():
        # Per-pixel class index: argmax over dim 1, then drop the batch dim.
        pred = net(im).argmax(dim=1).squeeze(0).cpu().numpy()
    # Presumably maps class indices to a color image — confirm in cityscapes.
    pred = ds.vis_label(pred)

    # PIL yields RGB and OpenCV works in BGR, so reverse the channel axis.
    # Copy into a contiguous buffer rather than passing a negative-stride view.
    image = np.ascontiguousarray(np.asarray(img)[:, :, ::-1])
    # NOTE(review): hconcat requires both images to have the same height and
    # dtype; this assumes vis_label returns an image matching the input.
    cv2.imshow('demo', imutils.resize(cv2.hconcat([image, pred]), width=1920))
    cv2.waitKey(0)
    # Release the HighGUI window once the user has pressed a key.
    cv2.destroyAllWindows()
if __name__ == '__main__':
    # Script entry point: parse the command line, then run the demo.
    args = demo_single()
    main(args)