forked from NVlabs/dlinputs
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dli-index
executable file
·119 lines (94 loc) · 3.13 KB
/
dli-index
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#!/usr/bin/python3
from __future__ import print_function
import argparse
import collections
import imp
import pickle
import time
import dlinputs as dli
import matplotlib
import numpy as np
import pylab
import simplejson
from dlinputs import gopen, tarrecords, utils, zcom
from pylab import *
# matplotlib.use("GTK")
# Plotting defaults (rc comes from the `from pylab import *` star import).
rc("figure", figsize="12,6")
rc("image", cmap="gray")

# Command-line interface.  The summary text is passed as `description=`:
# the first positional parameter of ArgumentParser is `prog`, so passing the
# text positionally would replace the program name in the usage line.
parser = argparse.ArgumentParser(
    description="Output key (and optionally other fields) for each input sample.")
parser.add_argument("input",
                    nargs="*",
                    help="input file, usually a .tgz file/shard")
parser.add_argument("-f", "--fields", default=None,
                    help="list of input fields to display")
parser.add_argument("-e", "--expression", default=None,
                    help="expression to be evaluated for each `sample` (also `_`)")
parser.add_argument("-d", "--decode", default="True",
                    help="decode sample before evaluation")
parser.add_argument("-k", "--keys", action="store_true",
                    help="display sample size and keys")
parser.add_argument("-s", "--shapes", action="store_true",
                    help="display shapes of all images")
args = parser.parse_args()

# --fields and --expression cannot be combined.  Report this as a usage error
# rather than with `assert`, which is stripped when Python runs with -O.
if args.expression is not None and args.fields is not None:
    parser.error("--fields and --expression are mutually exclusive")

# NOTE(review): --decode is evaluated as a Python expression (e.g. "True",
# "False").  eval() executes arbitrary code, so this option must only ever
# receive trusted input.
args.decode = eval(args.decode)

# With neither fields nor an expression only the source/key columns are
# printed, so decoding is switched off.
if args.fields is None and args.expression is None:
    args.decode = False

# `fields` is only defined when --fields was given; the main loop reads it
# only while args.fields is not None.
if args.fields is not None:
    fields = args.fields.split(",")

# --keys and --shapes are shortcuts that override any fields/expression
# (and the decode setting) with a predefined expression.
if args.keys:
    args.fields = None
    args.expression = "xks(_)"
    args.decode = False
if args.shapes:
    args.fields = None
    args.expression = "xshapes(_)"
    args.decode = True
def _fixtype(x):
if isinstance(x, (collections.abc.KeysView, collections.abc.ValuesView)):
x = list(x)
return x
def _info(x):
if isinstance(x, np.ndarray):
shape = ",".join([str(l) for l in x.shape])
lo, med, hi = amin(x), median(x), amax(x)
return "@{}[{:.2g}:{:.2g}:{:.2g}]".format(shape, lo, med, hi)
if not isinstance(x, str):
x = str(x)
return x
def cshape(a, sep=","):
    """Format ``a.shape`` as its dimensions joined by `sep` (default: comma)."""
    dims = map(str, a.shape)
    return sep.join(dims)
def sshape(a):
    """Format ``a.shape`` as its dimensions separated by single spaces."""
    return " ".join(str(dim) for dim in a.shape)
def xshapes(x):
    """List ``"<key>:<shape>"`` for every NumPy array value in `x`.

    Keys are visited in sorted order; values that are not ``np.ndarray``
    instances are skipped.  The shape is written as comma-separated
    dimensions.
    """
    return [
        name + ":" + ",".join(map(str, x[name].shape))
        for name in sorted(x.keys())
        if isinstance(x[name], np.ndarray)
    ]
def xkeys(x):
    """Return the keys of `x` that do not start with an underscore.

    Keys are returned in the iteration order of ``x.keys()``.
    """
    # k[:1] rather than k[0]: an empty key has no first character and would
    # raise IndexError; sliced, it simply counts as "not underscore-prefixed".
    return [k for k in x.keys() if k[:1] != "_"]
def xsize(x):
    """Return the number of bytes `x` occupies when pickled.

    The highest available pickle protocol is used.
    """
    payload = pickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
    return len(payload)
def xks(sample):
    """Return ``[size, key, key, ...]`` for `sample`.

    The first element is ``xsize(sample)``; the remaining elements are the
    keys reported by ``xkeys(sample)``.
    """
    summary = [xsize(sample)]
    return summary + xkeys(sample)
# Print one tab-separated line per sample of every input file.
#
# NOTE(review): the expression is run with eval() in module scope, so it can
# reference any module-level name, including the loop variables below
# (`sample`, `_`, `i`, `key`, `source`, ...).  Keep these names stable.  The
# expression comes straight from the command line and must be trusted.
for fname in args.input:
    # presumably iterates once over the samples of the shard as dict-like
    # objects -- confirm against dlinputs.gopen
    inputs = gopen.sharditerator_once(fname, decode=args.decode)
    for i, sample in enumerate(inputs):
        source = sample.get("__source__")
        key = sample.get("__key__")
        # Every output row starts with the sample's source and key.
        out = [source, key]
        if args.fields is not None:
            # Field mode: append the value of each requested field
            # (whatever sample.get() yields for a missing one).
            for k in fields:
                out += [sample.get(k)]
        elif args.expression is not None:
            # Expression mode: `_` is made available as an alias of `sample`.
            _ = sample
            result = eval(args.expression)
            result = _fixtype(result)
            # A result that is not a list becomes a single output column.
            if not isinstance(result, list):
                result = [result]
            out += result
        # Convert every column to text so the row can be joined with tabs.
        out = [_info(x) for x in out]
        print("\t".join(out))