forked from uwm-bigdata/wound-segmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluate_with_post_processing.py
65 lines (56 loc) · 2.7 KB
/
evaluate_with_post_processing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import cv2
import numpy as np
import multiprocessing
from tqdm import tqdm
from utils.io.data import get_png_filename_list
from utils.postprocessing.hole_filling import fill_holes
from utils.postprocessing.remove_small_noise import remove_small_areas
def _post_process_predictions(threshold, file_list, src_dir, filled_dir, out_dir):
    """Binarise, hole-fill and denoise every prediction image, writing two image sets.

    For each name in ``file_list``:
    1. Read ``src_dir + name``.
    2. Threshold it with ``cv2.THRESH_BINARY`` (maxval 255).
    3. Pass it through ``fill_holes`` and then ``remove_small_areas``.
    4. Write the hole-filled image to ``filled_dir + name`` and the denoised
       image to ``out_dir + name``.

    Raises:
        FileNotFoundError: a prediction image cannot be read (``cv2.imread``
            returns ``None`` instead of raising).
        OSError: an output image cannot be written, e.g. because the directory
            does not exist (``cv2.imwrite`` only reports this via its return value).
    """
    for img_name in tqdm(file_list):
        src = src_dir + img_name
        img = cv2.imread(src)
        if img is None:
            raise FileNotFoundError('Could not read prediction image: ' + src)
        _, threshed = cv2.threshold(img, threshold, 255, type=cv2.THRESH_BINARY)
        # NOTE(review): the 0.1 / 0.05 arguments are forwarded unchanged; their
        # meaning is defined in utils.postprocessing, which is not visible here.
        filled = fill_holes(threshed, threshold, 0.1)
        denoised = remove_small_areas(filled, threshold, 0.05)
        for dst, image in ((filled_dir + img_name, filled), (out_dir + img_name, denoised)):
            if not cv2.imwrite(dst, image):
                raise OSError('Could not write image: ' + dst)


def _confusion_counts(label, prediction, threshold):
    """Return ``(true_positives, false_negatives, false_positives)`` for one image pair.

    Both images are binarised with the same rule (``value > threshold``), so
    every pixel lands in at most one bucket. Previously the TP test only checked
    the prediction for non-zero, so a pixel could be counted as both TP and FN.

    Raises:
        ValueError: the two images do not have the same shape.
    """
    if label.shape != prediction.shape:
        raise ValueError(
            'Label and prediction shapes differ: %s vs %s' % (label.shape, prediction.shape))
    truth = label > threshold
    predicted = prediction > threshold
    true_pos = int(np.count_nonzero(truth & predicted))
    false_neg = int(np.count_nonzero(truth & ~predicted))
    false_pos = int(np.count_nonzero(~truth & predicted))
    return true_pos, false_neg, false_pos


def evaluate(threshold, file_list, label_path, post_prosecced_path,
             prediction_path=None, filled_path='whatever/filled/'):
    """Post-process predicted masks, then print pixel-level IoU and Dice against labels.

    Args:
        threshold: grey-level cut-off. It is used both to binarise the raw
            predictions and to decide which label / post-processed pixels
            count as foreground (``value > threshold``).
        file_list: image file names shared by the prediction, label and
            post-processed directories.
        label_path: directory prefix of the ground-truth label images.
        post_prosecced_path: directory prefix the post-processed (denoised)
            masks are written to and then evaluated from.
        prediction_path: directory prefix of the raw predictions. Defaults to
            the module-level ``pred_dir`` for backward compatibility.
        filled_path: directory prefix for the intermediate hole-filled masks.

    Prints the TP/FN/FP pixel totals plus IoU and Dice over all images.
    IoU and Dice are ``nan`` when no pixel is foreground in either mask.
    Returns ``None``.

    Raises:
        FileNotFoundError: an input image cannot be read.
        OSError: an output image cannot be written.
        ValueError: a label and its prediction differ in shape.
    """
    src_dir = pred_dir if prediction_path is None else prediction_path
    # Write denoised masks to the same directory they are read back from, so the
    # metrics always reflect this run rather than stale files elsewhere.
    _post_process_predictions(threshold, file_list, src_dir, filled_path, post_prosecced_path)

    true_positives = 0
    false_negatives = 0
    false_positives = 0
    for filename in tqdm(file_list):
        label = cv2.imread(label_path + filename, cv2.IMREAD_GRAYSCALE)
        if label is None:
            raise FileNotFoundError('Could not read label image: ' + label_path + filename)
        post_prosecced = cv2.imread(post_prosecced_path + filename, cv2.IMREAD_GRAYSCALE)
        if post_prosecced is None:
            raise FileNotFoundError(
                'Could not read post-processed image: ' + post_prosecced_path + filename)
        tp, fn, fp = _confusion_counts(label, post_prosecced, threshold)
        true_positives += tp
        false_negatives += fn
        false_positives += fp

    union = true_positives + false_negatives + false_positives
    if union == 0:
        # Neither mask has any foreground pixel: the ratios are undefined.
        IOU = Dice = float('nan')
    else:
        IOU = float(true_positives) / union
        Dice = 2 * float(true_positives) / (2 * true_positives + false_negatives + false_positives)

    print("--------------------------------------------------------")
    print("Weight file: ", post_prosecced_path.rsplit("/")[1])
    print("--------------------------------------------------------")
    print("Threshold: ", threshold)
    print("True pos = " + str(true_positives))
    print("False neg = " + str(false_negatives))
    print("False pos = " + str(false_positives))
    print("IOU = " + str(IOU))
    print("Dice = " + str(Dice))
# change to your own folder names
# NOTE: pred_dir must stay module-level; evaluate() reads it as a global
# default for the raw-prediction directory.
pred_dir = './whatever/'
label_path = './data/azh_wound_care_center_dataset_patches/test/labels/'
post_path = './whatever/post_processed/'
# test your own threshold
threshold = 120

if __name__ == "__main__":
    # Only run the (slow, file-writing) evaluation when executed as a script,
    # not as a side effect of importing this module.
    img_filename_list = get_png_filename_list(pred_dir)
    print(img_filename_list)
    evaluate(threshold, img_filename_list, label_path, post_path)