prediction.py
import argparse
import os
import pickle
import warnings

import cv2
import numpy as np
import pandas as pd
import torch
from catalyst.dl.utils import load_checkpoint
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm

from config.base import load_config
from datasets import make_loader
from losses import depth_transform
from models import CenterNetFPN, load_model
from transforms import get_transforms
from utils import coords2str, str2coords
from utils.functions import predict_batch
from utils.postprocess import extract_coords, postprocess

warnings.filterwarnings("ignore")

def run(config_file, fold=0, device_id=0, ensemble=False):
    os.environ['CUDA_VISIBLE_DEVICES'] = str(device_id)

    config = load_config(config_file)
    if '_fold' not in config.work_dir and not ensemble:
        config.work_dir = config.work_dir + '_fold{}'.format(fold)

    # test-time loader built from the sample submission file
    testloader = make_loader(
        data_dir=config.data.test_dir,
        df_path=config.data.sample_submission_path,
        features=config.data.features,
        phase='test',
        img_size=(config.data.height, config.data.width),
        batch_size=config.test.batch_size,
        num_workers=config.num_workers,
        transforms=get_transforms(config.transforms.test),
    )

    if ensemble:
        # load one model per config and fold; model_config_paths,
        # load_fold_model and MultiModels are not defined in this file
        # and are expected to be provided elsewhere in the project
        models = []
        for c in model_config_paths:
            for i in range(5):
                models.append(load_fold_model(c, i))
        model = MultiModels(models, tta=False)
    else:
        checkpoint_path = config.work_dir + '/checkpoints/best.pth'
        model = load_model(config_file, checkpoint_path, fold)

    predictions = []
    z_pos = config.data.z_pos[0]
    with torch.no_grad():
        for i, (batch_fnames, batch_images) in enumerate(tqdm(testloader)):
            batch_images = batch_images.to(config.device)
            batch_preds = model(batch_images)
            # confidence channel -> probability, depth channel -> metric depth
            batch_preds[:, 0] = torch.sigmoid(batch_preds[:, 0])
            batch_preds[:, z_pos] = depth_transform(batch_preds[:, z_pos])
            batch_preds = batch_preds.data.cpu().numpy()

            for preds in batch_preds:
                coords = extract_coords(
                    preds,
                    features=config.data.features,
                    img_size=(config.data.height, config.data.width),
                    confidence_threshold=config.test.confidence_threshold,
                    distance_threshold=config.test.distance_threshold,
                )
                s = coords2str(coords)
                predictions.append(s)

    # -------------------------------------------------------------------------
    # submission
    # -------------------------------------------------------------------------
    test = pd.read_csv(config.data.sample_submission_path)
    test['PredictionString'] = predictions
    out_path = config.work_dir + '/submission.csv'
    test.to_csv(out_path, index=False)
    postprocess(out_path)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', dest='config_file',
                        default=None, type=str)
    parser.add_argument('--device_id', '-d', default='0', type=str)
    parser.add_argument('--fold', '-f', default=0, type=int)
    parser.add_argument('--ensemble', action='store_true')
    return parser.parse_args()
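

# Example invocation (a sketch based on the flags defined above; the config
# path shown here is hypothetical and depends on your local config/ layout):
#
#   python prediction.py --config config/centernet_fpn.yml --fold 0 --device_id 0
#
# Passing --ensemble instead takes the MultiModels path in run() above.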


def main():
    print('predict model.')
    args = parse_args()
    if args.config_file is None:
        raise Exception('no configuration file')
    print('load config from {}'.format(args.config_file))
    run(args.config_file, args.fold, args.device_id, args.ensemble)


if __name__ == '__main__':
    main()