-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
99 lines (69 loc) · 2.97 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import os.path as osp
import json
from argparse import ArgumentParser
from glob import glob
import torch
import cv2
from torch import cuda
from model import EAST
from tqdm import tqdm
from datetime import datetime
from detect import detect
# Recognized model-checkpoint file extensions.
# NOTE(review): not referenced anywhere in this file — presumably used by a
# sibling script (or left over from one); confirm before removing.
CHECKPOINT_EXTENSIONS = ['.pth', '.ckpt']
def parse_args():
    """Parse command-line options for the inference run.

    Defaults fall back to SageMaker channel environment variables when set.

    Returns:
        argparse.Namespace with data_dir, model_dir, output_dir, device,
        input_size and batch_size.

    Raises:
        ValueError: if ``--input_size`` is not a multiple of 32.
    """
    parser = ArgumentParser()

    # Conventional args — SageMaker env vars take precedence over local paths.
    parser.add_argument('--data_dir',
                        default=os.environ.get('SM_CHANNEL_EVAL', '../data/medical'))
    parser.add_argument('--model_dir',
                        default=os.environ.get('SM_CHANNEL_MODEL', 'trained_models'))
    parser.add_argument('--output_dir',
                        default=os.environ.get('SM_OUTPUT_DATA_DIR', 'predictions'))
    parser.add_argument('--device',
                        default='cuda' if cuda.is_available() else 'cpu')
    parser.add_argument('--input_size', type=int, default=1024)  # 2048
    parser.add_argument('--batch_size', type=int, default=5)

    parsed = parser.parse_args()

    # The detection network downsamples by powers of two; sizes not divisible
    # by 32 would break the feature-map geometry.
    if parsed.input_size % 32:
        raise ValueError('`input_size` must be a multiple of 32')

    return parsed
def do_inference(model, ckpt_fpath, data_dir, input_size, batch_size, split='test'):
    """Run text detection over every image in ``{data_dir}/img/{split}``.

    Args:
        model: EAST network instance; weights are loaded from ``ckpt_fpath``.
        ckpt_fpath: path to a state-dict checkpoint.
        data_dir: dataset root containing an ``img/{split}`` directory.
        input_size: side length images are resized to inside ``detect``.
        batch_size: number of images sent to ``detect`` per call.
        split: dataset split sub-directory name.

    Returns:
        UFO-format dict: ``{'images': {fname: {'words': {idx: {'points': [...]}}}}}``.
    """
    model.load_state_dict(torch.load(ckpt_fpath, map_location='cpu'))
    model.eval()

    image_fnames, by_sample_bboxes = [], []

    images = []
    for image_fpath in tqdm(glob(osp.join(data_dir, 'img/{}/*'.format(split)))):
        img = cv2.imread(image_fpath)
        if img is None:
            # Skip unreadable files entirely. Previously the filename was
            # recorded *before* this check, so a failed read left image_fnames
            # one entry longer than by_sample_bboxes and the final zip() paired
            # every subsequent filename with the wrong detections.
            print(f"Warning: Failed to read image {image_fpath}. Skipping.")
            continue

        image_fnames.append(osp.basename(image_fpath))
        images.append(img[:, :, ::-1])  # OpenCV loads BGR; model expects RGB

        if len(images) == batch_size:
            by_sample_bboxes.extend(detect(model, images, input_size))
            images = []

    # Flush the final partial batch, if any.
    if len(images):
        by_sample_bboxes.extend(detect(model, images, input_size))

    ufo_result = dict(images=dict())
    for image_fname, bboxes in zip(image_fnames, by_sample_bboxes):
        words_info = {idx: dict(points=bbox.tolist()) for idx, bbox in enumerate(bboxes)}
        ufo_result['images'][image_fname] = dict(words=words_info)

    return ufo_result
def main(args):
    """Load the latest checkpoint, run inference, and write UFO results.

    Args:
        args: parsed namespace from ``parse_args`` (data_dir, model_dir,
            output_dir, device, input_size, batch_size).
    """
    # Initialize model (weights are loaded inside do_inference).
    model = EAST(pretrained=False).to(args.device)

    # Get path to the checkpoint file
    ckpt_fpath = osp.join(args.model_dir, 'latest.pth')

    # exist_ok avoids the check-then-create race of the old
    # `if not osp.exists(...)` guard.
    os.makedirs(args.output_dir, exist_ok=True)

    print('Inference in progress')

    ufo_result = dict(images=dict())
    split_result = do_inference(model, ckpt_fpath, args.data_dir, args.input_size,
                                args.batch_size, split='test')
    ufo_result['images'].update(split_result['images'])

    # Add current timestamp to the output file name.
    # NOTE(review): the file is named .csv but contains JSON — presumably a
    # submission-format requirement; confirm before changing the extension.
    current_time = datetime.now().strftime("%m%d%H%M")
    output_fname = f'output_{current_time}.csv'
    with open(osp.join(args.output_dir, output_fname), 'w') as f:
        json.dump(ufo_result, f, indent=4)
# Script entry point: parse CLI options and run inference.
if __name__ == '__main__':
    args = parse_args()
    main(args)