# Mask R-CNN model for lesion segmentation in chest CT scans
# Torchvision detection package is locally re-implemented
# by Alex Ter-Sarkisov@City, University of London
# 2020
import copy
import os
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Rectangle
from PIL import Image as PILImage
from torchvision import transforms

import config_segmentation as config
import models.mask_net as mask_net
from models.mask_net.covid_mask_net import MaskRCNNHeads, MaskRCNNPredictor
from models.mask_net.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from models.mask_net.rpn_segmentation import AnchorGenerator
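
# The config object returned by config_segmentation.get_config_pars("test")
# is expected to provide the attributes read below: device, backbone_name,
# truncation, mask_type, ckpt, model_name, confidence_th, mask_logits_th,
# save_dir, test_data_dir, test_imgs_dir, gt_dir, rpn_nms_th, roi_nms_th.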

# main method
def main(config, step):
    devices = ['cpu', 'cuda']
    mask_classes = ['both', 'ggo', 'merge']
    backbones = ['resnet50', 'resnet34', 'resnet18']
    truncation_levels = ['0', '1', '2']
    # sanity-check the configuration values
    assert config.device in devices
    assert config.backbone_name in backbones
    assert config.truncation in truncation_levels
    assert config.mask_type in mask_classes
    # fall back to CPU if CUDA was requested but is unavailable
    if config.device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    # get the configuration and the thresholds
    confidence_threshold, mask_threshold, save_dir, data_dir, img_dir, gt_dir, mask_type, rpn_nms, roi_nms, backbone_name, truncation = \
        config.confidence_th, config.mask_logits_th, config.save_dir, config.test_data_dir, config.test_imgs_dir, \
        config.gt_dir, config.mask_type, config.rpn_nms_th, config.roi_nms_th, config.backbone_name, config.truncation
    # 3 classes (background, GGO, CL) for 'both', 2 classes (background, lesion) otherwise
    if mask_type == "both":
        n_c = 3
    else:
        n_c = 2
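    # The checkpoint is a dict saved at training time; the keys used here are
    # 'model_weights' (the state dict), 'anchor_generator' (whose sizes and
    # aspect ratios are re-used below) and, optionally, 'model_name'.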
    ckpt = torch.load(config.ckpt, map_location=device)
    model_name = None
    if 'model_name' in ckpt:
        model_name = ckpt['model_name']
    # re-create the anchor generator from the sizes and aspect ratios stored in the checkpoint
    sizes = ckpt['anchor_generator'].sizes
    aspect_ratios = ckpt['anchor_generator'].aspect_ratios
    anchor_generator = AnchorGenerator(sizes, aspect_ratios)
    print("Anchors: ", anchor_generator.sizes, anchor_generator.aspect_ratios)
    # create the modules; this assumes an FPN with 256 output channels
    box_head = TwoMLPHead(in_channels=7 * 7 * 256, representation_size=128)
    if backbone_name == 'resnet50':
        maskrcnn_heads = None
        box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
        mask_predictor = MaskRCNNPredictor(in_channels=256, dim_reduced=256, num_classes=n_c)
    else:
        # Backbone -> FPN -> box head -> box predictor
        box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
        maskrcnn_heads = MaskRCNNHeads(in_channels=256, layers=(128,), dilation=1)
        mask_predictor = MaskRCNNPredictor(in_channels=128, dim_reduced=128, num_classes=n_c)
    # keyword arguments
    maskrcnn_args = {'num_classes': None, 'min_size': 512, 'max_size': 1024, 'box_detections_per_img': 100,
                     'box_nms_thresh': roi_nms, 'box_score_thresh': confidence_threshold, 'rpn_nms_thresh': rpn_nms,
                     'box_head': box_head, 'rpn_anchor_generator': anchor_generator, 'mask_head': maskrcnn_heads,
                     'mask_predictor': mask_predictor, 'box_predictor': box_predictor}
    # instantiate the segmentation model
    maskrcnn_model = mask_net.maskrcnn_resnet_fpn(backbone_name, truncation, pretrained_backbone=False, **maskrcnn_args)
    # load the weights and set the model to evaluation mode
    maskrcnn_model.load_state_dict(ckpt['model_weights'])
    print(maskrcnn_model)
    maskrcnn_model.eval().to(device)
    start_time = time.time()
    # get the correct class names and mask colors
    if mask_type == "ggo":
        ct_classes = {0: '__bgr', 1: 'GGO'}
        ct_colors = {1: 'red', 'mask_cols': np.array([[255, 0, 0]])}
    elif mask_type == "merge":
        ct_classes = {0: '__bgr', 1: 'Lesion'}
        ct_colors = {1: 'red', 'mask_cols': np.array([[255, 0, 0]])}
    elif mask_type == "both":
        ct_classes = {0: '__bgr', 1: 'GGO', 2: 'CL'}
        ct_colors = {1: 'red', 2: 'blue', 'mask_cols': np.array([[255, 0, 0], [0, 0, 255]])}
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    # the model name from the config overrides the one stored in the checkpoint
    if model_name is None:
        model_name = "maskrcnn_segmentation"
    elif config.model_name != model_name:
        print("Using model name from the config.")
        model_name = config.model_name
    # run the inference with the provided hyperparameters
    test_ims = os.listdir(os.path.join(data_dir, img_dir))
    for j, ims in enumerate(test_ims):
        step(os.path.join(data_dir, img_dir, ims), device, maskrcnn_model, model_name,
             confidence_threshold, mask_threshold, save_dir, ct_classes, ct_colors, j)
    end_time = time.time()
    print("Inference took {0:.1f} seconds".format(end_time - start_time))

# inference on a single image: draw boxes, class labels and scores, overlay the predicted masks
def test_step(image, device, model, model_name, theta_conf, theta_mask, save_dir, cls, cols, num):
    im = PILImage.open(image)
    # convert the image to RGB (drops the alpha channel if present)
    if im.mode != 'RGB':
        im = im.convert(mode='RGB')
    img = np.array(im)
    # copy the image to use as the background for plotting
    bgr_img = copy.deepcopy(img)
    if img.shape[2] > 3:
        img = img[:, :, :3]
    # torchvision transforms; resizing and normalization Mask R-CNN does internally
    t_ = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor()])
    img = t_(img).to(device)
    # no gradients are needed at inference time
    with torch.no_grad():
        out = model([img])
    # scores + bounding boxes + labels + masks
    scores = out[0]['scores']
    bboxes = out[0]['boxes']
    classes = out[0]['labels']
    mask = out[0]['masks']
    # keep only the detections with confidence above the threshold
    best_scores = scores[scores > theta_conf]
    if len(best_scores):
        best_idx = torch.where(scores > theta_conf)[0]
        best_bboxes = bboxes[best_idx].cpu()
        best_classes = classes[best_idx]
        best_masks = mask[best_idx]
        print('bm', best_masks.shape)
        # RGB canvas that accumulates the masks of all accepted detections
        mask_array = np.zeros([best_masks[0].shape[1], best_masks[0].shape[2], 3], dtype=np.uint8)
        fig, ax = plt.subplots(1, 1)
        fig.set_size_inches(12, 6)
        ax.axis("off")
        # plot predictions
        for idx, dets in enumerate(best_bboxes):
            # masks must be moved to the CPU before converting to numpy
            found_masks = best_masks[idx][0].detach().cpu().numpy()
            pred_class = best_classes[idx].item()
            pred_col_n = cols[pred_class]
            pred_class_txt = cls[pred_class]
            pred_col = cols['mask_cols'][pred_class - 1]
            # pixels above the mask threshold get the class color
            mask_array[found_masks > theta_mask] = pred_col
            rect = Rectangle((dets[0], dets[1]), dets[2] - dets[0], dets[3] - dets[1], linewidth=1,
                             edgecolor=pred_col_n, facecolor='none', linestyle="--")
            ax.text(dets[0] + 40, dets[1], '{0:}'.format(pred_class_txt), fontsize=10, color=pred_col_n)
            ax.text(dets[0], dets[1], '{0:.2f}'.format(best_scores[idx]), fontsize=10, color=pred_col_n)
            ax.add_patch(rect)
        # overlay the mask canvas on the background image and save the figure
        added_image = cv2.addWeighted(bgr_img, 0.5, mask_array, 0.75, gamma=0)
        ax.imshow(added_image)
        fig.savefig(os.path.join(save_dir, model_name + "_" + str(num) + ".png"),
                    bbox_inches='tight', pad_inches=0.0)
        plt.close(fig)
    else:
        print("No detections")

# run the inference
if __name__ == '__main__':
    config_test = config.get_config_pars("test")
    main(config_test, test_step)
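
# A minimal usage sketch (hedged: assumes config_segmentation exposes the
# attributes listed at the top of this file; the device override below is a
# hypothetical illustration, not part of the repo):
#
#   import config_segmentation as config
#   cfg = config.get_config_pars("test")
#   cfg.device = "cpu"       # hypothetical override: force CPU inference
#   main(cfg, test_step)     # saves <model_name>_<idx>.png into cfg.save_dir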