-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_toy_example.py
executable file
·70 lines (54 loc) · 2.44 KB
/
generate_toy_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/env python
from __future__ import division

import argparse
try:
    import cPickle as pickle
except ImportError:
    import pickle
import random
import sys

import numpy
# --- Patch / line geometry -------------------------------------------------
# Each sample is a square PATCH_SIZE x PATCH_SIZE image; NUM_LINES_PER_PATCH
# square LINE_LENGTH x LINE_LENGTH line templates are stamped onto it.
PATCH_SIZE = 17
LINE_LENGTH = 7
NUM_LINES_PER_PATCH = 2

# Width of the central band of positions from which line centres are drawn,
# and the width of the border left on either side of that band.
LINE_CENTER_SIZE = 5
LINE_PERIPHERY_SIZE = (PATCH_SIZE - LINE_CENTER_SIZE) // 2

# Offset from a template's middle pixel to its edge (3 for LINE_LENGTH == 7).
HALF_LINE_LENGTH = (LINE_LENGTH - 1) // 2

# One-hot float32 vector of length LINE_LENGTH with its 1.0 in the middle slot.
CENTERED_ONE = numpy.zeros(LINE_LENGTH, dtype=numpy.float32)
CENTERED_ONE[HALF_LINE_LENGTH] = 1.0

# The four float32 line templates, each LINE_LENGTH x LINE_LENGTH with 0/1 entries.
_ONES = numpy.ones(LINE_LENGTH, dtype=numpy.float32)
LINE_TYPES = [
    numpy.eye(LINE_LENGTH, dtype=numpy.float32),                # diagonal: top-left to bottom-right
    numpy.fliplr(numpy.eye(LINE_LENGTH, dtype=numpy.float32)),  # diagonal: bottom-left to top-right
    numpy.outer(CENTERED_ONE, _ONES),                           # horizontal, through the middle row
    numpy.outer(_ONES, CENTERED_ONE),                           # vertical, through the middle column
]
del _ONES  # scratch value only; keep it out of the module namespace
def generate_sample_patch():
    """Return one PATCH_SIZE x PATCH_SIZE float32 patch containing random lines.

    NUM_LINES_PER_PATCH distinct templates are drawn without replacement
    from LINE_TYPES. Each template is stamped so that its middle pixel sits
    at a position chosen uniformly from the central LINE_CENTER_SIZE x
    LINE_CENTER_SIZE square of the patch. Overlapping stamps are merged with
    an element-wise maximum, so with the 0/1 templates in LINE_TYPES every
    pixel of the result is 0.0 or 1.0.
    """
    patch = numpy.zeros((PATCH_SIZE, PATCH_SIZE), dtype=numpy.float32)
    for line in random.sample(LINE_TYPES, NUM_LINES_PER_PATCH):
        # randrange is half-open, giving exactly LINE_CENTER_SIZE candidate
        # centres per axis, symmetric about the patch middle. The previous
        # inclusive randint(lo, lo + LINE_CENTER_SIZE) allowed one extra,
        # off-centre position.
        line_center = [
            random.randrange(LINE_PERIPHERY_SIZE,
                             LINE_PERIPHERY_SIZE + LINE_CENTER_SIZE)
            for _ in range(2)  # range, not xrange: xrange is Python 2 only
        ]
        # NumPy needs a *tuple* of slices for multi-dimensional indexing;
        # a list of slices is rejected by NumPy >= 1.23.
        patch_area = tuple(
            slice(center - HALF_LINE_LENGTH, center + HALF_LINE_LENGTH + 1)
            for center in line_center
        )
        patch[patch_area] = numpy.maximum(patch[patch_area], line)
    return patch
def main(argv=None):
    """Command-line entry point: pickle an array of randomly generated patches.

    Parses ``argv``, seeds the ``random`` module, builds ``--num-samples``
    patches with generate_sample_patch() and writes them, pickled with the
    highest available protocol, to ``output_filename`` as one float32 array
    of shape (num_samples, PATCH_SIZE, PATCH_SIZE).

    Args:
        argv: command-line arguments excluding the program name; defaults
            to ``sys.argv[1:]``.

    Returns:
        None, so ``sys.exit(main())`` exits with status 0.

    Raises:
        SystemExit: from argparse on invalid arguments (including a negative
            ``--num-samples``).
    """
    if argv is None:
        argv = sys.argv[1:]
    parser = argparse.ArgumentParser(description="Generate toy examples for reproducing Ranzato et al 07.")
    parser.add_argument("output_filename",
                        help="output filename")
    parser.add_argument("-n", "--num-samples",
                        type=int,
                        required=True,
                        help="number of samples to produce")
    parser.add_argument("-r", "--random-seed",
                        type=int,
                        help="rng seed, for reproducibility")
    # Parse the argv we were given; a bare parse_args() would silently read
    # sys.argv and ignore the caller-supplied arguments.
    args = parser.parse_args(argv)
    if args.num_samples < 0:
        parser.error("--num-samples must be non-negative")
    # With no -r, args.random_seed is None and random.seed(None) seeds from
    # OS randomness (or the system time), so each run differs.
    random.seed(args.random_seed)
    # Pre-allocate so the output always has shape (n, PATCH_SIZE, PATCH_SIZE)
    # and dtype float32, even for n == 0 (numpy.asarray([]) would instead
    # give a 1-D float64 array).
    sample_patches = numpy.empty((args.num_samples, PATCH_SIZE, PATCH_SIZE),
                                 dtype=numpy.float32)
    for sample_num in range(args.num_samples):
        sample_patches[sample_num] = generate_sample_patch()
    with open(args.output_filename, "wb") as outfile:
        pickle.dump(sample_patches, outfile, protocol=pickle.HIGHEST_PROTOCOL)
if __name__ == "__main__":
    # Equivalent to sys.exit(main()): main() returns None, which exits with status 0.
    raise SystemExit(main())