-
Notifications
You must be signed in to change notification settings - Fork 1
/
plot_ds.py
executable file
·39 lines (32 loc) · 1.03 KB
/
plot_ds.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#! /usr/bin/env python
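"""
Produces overlays of the target curves loaded through dataset.BaselineSet,
drawn as opaque red point markers on the resized input images, for visual
inspection. Each result is saved next to its input as <name>.overlay.png.
"""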
import click
@click.command()
@click.argument('files', nargs=-1)
def cli(files):
    """
    Draw the target curves of each image in FILES and save an overlay.

    FILES are loaded through dataset.BaselineSet with a torchvision
    Resize(800, max_size=1333) as im_transforms. Each entry of
    target['curves'] is sampled at 20 positions and every sample is marked
    with a small red square. The result is written next to the input as
    <name>.overlay.png.
    """
    import numpy as np
    import torch
    import dataset
    from PIL import ImageDraw
    from os.path import splitext
    import torchvision.transforms as tf

    torch.set_num_threads(1)
    transforms = tf.Resize(800, max_size=1333)
    ds = dataset.BaselineSet(files, im_transforms=transforms)

    # Sampling positions and the basis matrix built from them do not depend
    # on the image or the curve, so compute them once.
    # NOTE(review): assumes dataset.BezierCoeff is a pure function of its
    # argument (presumably a cubic Bezier basis, one row per sample) - confirm.
    samples = np.linspace(0, 1, 20)
    coeffs = np.array(dataset.BezierCoeff(samples))
    half = 2  # half the side length, in pixels, of each marker square

    for idx, (im, target) in enumerate(ds):
        path = ds.imgs[idx]
        print(path)
        im = im.convert('RGB')
        draw = ImageDraw.Draw(im)
        for curve in target['curves']:
            # Four (x, y) control points; the scaling by im.size suggests
            # they are normalized to [0, 1] - confirm against the dataset.
            # reshape (rather than the in-place ndarray.resize used before)
            # raises on a wrong number of values instead of silently
            # zero-padding or truncating. Multiplying the (4, 2) array by
            # (w, h) scales x by width and y by height.
            ctrl = np.asarray(curve).reshape(4, 2) * im.size
            for x, y in coeffs.dot(ctrl):
                draw.rectangle((x - half, y - half, x + half, y + half), fill='red')
        im.save(splitext(path)[0] + '.overlay.png')
if __name__ == "__main__":
    # Script entry point: hand control to the click command, which parses
    # the command-line arguments itself.
    cli()