forked from DiaoXY/CSRnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Inference_B_test.py
71 lines (60 loc) · 1.93 KB
/
Inference_B_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
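# -*- coding: utf-8 -*-
"""
Run crowd-count inference with the part-B CSRNet model on a test image and
compare the prediction with the ground-truth density map.

Created on Sun Nov 4 11:11:52 2018
@author: lenovo
"""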
import cv2
import h5py
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm as c
from keras.models import model_from_json
def load_model(json_path='models/Model.json',
               weights_path='weights/model_B_weights.h5'):
    """Build the Keras model from its JSON description and load its weights.

    Args:
        json_path: Path of the JSON file holding the model architecture
            (defaults to the path the script has always used).
        weights_path: Path of the weights file passed to ``load_weights``
            (defaults to the part-B weights).

    Returns:
        The model returned by ``model_from_json`` with the weights loaded.
    """
    # The context manager closes the file even if read() raises.
    with open(json_path, 'r') as json_file:
        loaded_model_json = json_file.read()
    loaded_model = model_from_json(loaded_model_json)
    loaded_model.load_weights(weights_path)
    return loaded_model
def create_img(path, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Load an image and normalise it into a single-image batch for the model.

    The image is converted to RGB and scaled to [0, 1]. Each channel c is then
    standardised as ``(x - mean[c]) / std[c]``. The defaults are the constants
    the script always hard-coded; they appear to be the usual ImageNet
    statistics, which is presumably what the network was trained with.

    Args:
        path: Path of the image file; it is also printed to stdout.
        mean: Per-channel (R, G, B) values subtracted after scaling.
        std: Per-channel (R, G, B) values divided by after subtracting ``mean``.

    Returns:
        float64 array of shape (1, H, W, 3).
    """
    print(path)
    # Use a context manager so the underlying file handle is released
    # instead of waiting for garbage collection.
    with Image.open(path) as pil_img:
        im = np.array(pil_img.convert('RGB'))
    im = im / 255.0
    # Broadcast over the trailing channel axis rather than normalising each
    # channel separately; the arithmetic (and result) is identical.
    im = (im - np.asarray(mean)) / np.asarray(std)
    return np.expand_dims(im, axis=0)
def predict(path, model=None):
    """Predict the density map and crowd count for one image.

    Args:
        path: Path of the image file to evaluate.
        model: Optional already-loaded model. When omitted, ``load_model()``
            reads the architecture and weights from disk on every call, which
            is slow. Pass a model explicitly when predicting several images.

    Returns:
        Tuple ``(count, image, ans)``:
        - ``count``: the predicted count, i.e. the sum of the density map;
        - ``image``: the normalised input batch produced by ``create_img``;
        - ``ans``: the raw output of ``model.predict``.
    """
    if model is None:
        model = load_model()
    image = create_img(path)
    ans = model.predict(image)
    # The network outputs a density map; its total is the estimated count.
    count = np.sum(ans)
    return count, image, ans
"""
#Inference_B_test 只能预测
ans,img,hmap = predict('data/test_data/test36.jpg')
print("Predict Count:",ans)
plt.imshow(img.reshape(img.shape[1],img.shape[2],img.shape[3]))
plt.show()
plt.imshow(hmap.reshape(hmap.shape[1],hmap.shape[2]) , cmap = c.jet )
plt.show()
"""
#预测2,3,4
ans,img,hmap = predict('data/part_B_final/test_data/images/IMG_4.jpg')
print("Predict Count:",ans)
#Print count, image, heat map
plt.imshow(img.reshape(img.shape[1],img.shape[2],img.shape[3]))
plt.show()
plt.imshow(hmap.reshape(hmap.shape[1],hmap.shape[2]) , cmap = c.jet )
plt.show()
temp = h5py.File('data/part_B_final/test_data/ground/IMG_4.h5' , 'r')
temp_1 = np.asarray(temp['density'])
#plt.imshow(temp_1,cmap = c.jet)
print("Original Count : ",int(np.sum(temp_1)) + 1)