forked from NVlabs/dlinputs
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dli-transform
executable file
·81 lines (68 loc) · 2.32 KB
/
dli-transform
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#!/usr/bin/python3
import sys
import argparse
import imp
import os
import dlinputs as dli
import numpy as np
from dlinputs import filters, gopen, utils, zcom
import multiprocessing
# Command-line interface.
# The summary is passed as ``description=``: the first positional parameter of
# ArgumentParser is ``prog``, so passing the sentence positionally makes it
# show up as the program name in the usage line instead of as a description.
parser = argparse.ArgumentParser(
    description="Transform samples from an input and optionally write them "
                "to an output.")
parser.add_argument("-P", "--pipeline", default="info()",
                    help="pipeline expression, evaluated as compose(PIPELINE)")
parser.add_argument("-C", "--precode", default=[], nargs="*",
                    help="Python statement(s) executed before the pipeline "
                         "is built")
parser.add_argument("-c", "--code", default=None,
                    help="Python code executed for every sample")
parser.add_argument("-F", "--format", default="rgb",
                    help="format passed to utils.autodecoder")
parser.add_argument("-l", "--load", default=None,
                    help="Python source file providing transform(sample)")
parser.add_argument("-e", "--every", type=int, default=1000,
                    help="print a progress line every N samples")
parser.add_argument("-p", "--parallel", type=int, default=0,
                    help="size of the worker pool (0 = no pool)")
parser.add_argument("-O", "--ordered", action="store_true",
                    help="keep input order when a worker pool is used")
parser.add_argument("input", help="input passed to gopen.sharditerator")
parser.add_argument("output", nargs="?",
                    help="output passed to gopen.open_sink (optional)")
args = parser.parse_args()
# Module-level names in this section are deliberately left as they are:
# code supplied through --precode / --pipeline / --code is executed with
# access to the module globals and may refer to them.

# Run user-supplied setup statements before anything else is built.
# NOTE: exec()/eval() below run arbitrary code taken from the command line;
# only invoke this tool with trusted arguments.
for c in args.precode:
    exec(c)

# Input: iterate once (epochs=1) over the input, decoding with the decoder
# selected by --format.
decode = utils.autodecoder(args.format)
source = gopen.sharditerator(args.input, decode=decode, epochs=1)

# Output is optional; without an output argument there is no sink.
sink = gopen.open_sink(args.output) if args.output is not None else None

# Optionally load a Python source file as module "mapper"; its transform()
# is applied to every sample by process_sample.
# NOTE(review): the `imp` module is deprecated and was removed in
# Python 3.12 -- migrating to importlib also requires changing the
# file-level imports.
if args.load is None:
    mapper = None
else:
    fname = os.path.basename(args.load)
    name, ext = os.path.splitext(fname)
    with open(args.load) as stream:
        mapper = imp.load_module(
            "mapper", stream, args.load, (ext, "r", imp.PY_SOURCE)
        )

# Build the pipeline from the --pipeline expression and apply it to the
# source.  The expression is evaluated inside the dli.inputs context.
# NOTE(review): presumably that context is what makes `compose` and the
# pipeline stage names resolvable -- confirm against dlinputs.  Indentation
# was not recoverable for the last statement; it is kept inside the context.
with dli.inputs:
    f = eval("compose({})".format(args.pipeline))
    transformed = f(source)
def process_sample(sample, mapper=mapper, code=args.code):
    """Apply the optional mapper and the optional per-sample code to one item.

    Args:
        sample: An ``(index, sample)`` pair as produced by ``enumerate``.
        mapper: Object whose ``transform(sample)`` is applied first, or
            ``None`` to skip that step.  Defaults to the module loaded via
            ``--load``.
        code: Python source executed for the sample, or ``None`` to skip that
            step.  Defaults to the value of ``--code``.  It runs with the
            module globals and with the local names ``i``, ``sample``, ``_``
            (an alias of ``sample``), ``mapper`` and ``code``.

    Returns:
        The ``(index, sample)`` pair after processing.  If the executed code
        rebinds ``sample``, the rebound value is returned.
    """
    i, sample = sample
    if mapper is not None:
        sample = mapper.transform(sample)
    if code is not None:
        # In Python 3, exec() inside a function cannot rebind the function's
        # local variables, so an assignment such as "sample = ..." in the
        # user code would be silently lost.  Execute in an explicit namespace
        # and read the result back instead.  In-place mutation of the sample
        # keeps working as before.
        # NOTE: this runs arbitrary code taken from the command line.
        scope = {
            "i": i,
            "sample": sample,
            "_": sample,
            "mapper": mapper,
            "code": code,
        }
        # Use the ``code`` parameter (not args.code) so an explicitly passed
        # value is honoured.
        exec(code, globals(), scope)
        # Fall back to the current sample if the user code deleted the name.
        sample = scope.get("sample", sample)
    return i, sample
# Choose how samples are pushed through process_sample.  Every item handed to
# it is an (index, sample) pair from enumerate().
if args.parallel != 0:
    # Worker pool: imap keeps the input order, imap_unordered yields results
    # as they become available.
    pool = multiprocessing.Pool(args.parallel)
    processed = (pool.imap if args.ordered else pool.imap_unordered)(
        process_sample, enumerate(transformed)
    )
else:
    # No pool requested: process lazily in this process.
    processed = map(process_sample, enumerate(transformed))
# Consume the processed stream: j counts results in arrival order, i is the
# index the sample had in the input stream (they can differ when results
# arrive out of order).
for j, (i, sample) in enumerate(processed):
    # Normalise to a list of records; a single dict becomes a one-item list.
    if isinstance(sample, dict):
        sample = [sample]
    elif not isinstance(sample, list):
        print("sample must be either dict or list, got:")
        print(sample)
        sys.exit(1)
    if i % args.every == 0:
        # A transform may return an empty list; there is then no first
        # record to take the key from, so report None rather than failing
        # with IndexError.
        print(i, j, sample[0].get("__key__") if sample else None)
    if sink is not None:
        for s in sample:
            sink.write(s)
# NOTE(review): the sink is never explicitly closed or flushed in this
# script -- confirm that the object returned by gopen.open_sink finalises
# its output on interpreter exit.