forked from SamsungLabs/ritm_interactive_segmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo.py
57 lines (40 loc) · 1.79 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import argparse
import tkinter as tk
import torch
from isegm.utils import exp
from isegm.inference import utils
from interactive_demo.app import InteractiveDemoApp
def main():
    """Start the interactive demo application.

    Reads the command-line arguments and config via ``parse_args``, loads the
    model from the located checkpoint onto ``args.device``, builds the Tk root
    window and runs ``InteractiveDemoApp`` on it.
    """
    args, cfg = parse_args()

    torch.backends.cudnn.deterministic = True

    # Resolve the checkpoint against the configured models directory and
    # load the model in a single step.
    model = utils.load_is_model(
        utils.find_checkpoint(cfg.INTERACTIVE_MODELS_PATH, args.checkpoint),
        args.device,
        cpu_dist_maps=True,
    )

    window = tk.Tk()
    window.minsize(960, 480)

    demo = InteractiveDemoApp(window, args, model)
    window.deiconify()
    demo.mainloop()
def parse_args():
    """Parse the command line and load the config file.

    Returns:
        A tuple ``(args, cfg)``. ``args`` is the argparse namespace with an
        extra ``device`` attribute: ``torch.device('cpu')`` when ``--cpu`` is
        given, otherwise ``torch.device('cuda:<gpu>')``. ``cfg`` is whatever
        ``exp.load_config_file(args.cfg, return_edict=True)`` returns.
    """
    cli = argparse.ArgumentParser()

    cli.add_argument(
        '--checkpoint', type=str, required=True,
        help='The path to the checkpoint. '
             'This can be a relative path (relative to cfg.INTERACTIVE_MODELS_PATH) '
             'or an absolute path. The file extension can be omitted.',
    )
    cli.add_argument(
        '--gpu', type=int, default=0,
        help='Id of GPU to use.',
    )
    cli.add_argument(
        '--cpu', action='store_true', default=False,
        help='Use only CPU for inference.',
    )
    cli.add_argument(
        '--limit-longest-size', type=int, default=800,
        help='If the largest side of an image exceeds this value, '
             'it is resized so that its largest side is equal to this value.',
    )
    cli.add_argument(
        '--cfg', type=str, default="config.yml",
        help='The path to the config file.',
    )

    options = cli.parse_args()

    # The --cpu flag takes precedence over any --gpu index.
    device_name = 'cpu' if options.cpu else f'cuda:{options.gpu}'
    options.device = torch.device(device_name)

    config = exp.load_config_file(options.cfg, return_edict=True)
    return options, config
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()