-
Notifications
You must be signed in to change notification settings - Fork 1
/
gen_multnist_data.py
executable file
·126 lines (103 loc) · 4.78 KB
/
gen_multnist_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import collections
import os
from torchvision import datasets, transforms
import numpy as np
import math
def train_test_filter(op, combs, ratio):
    """Split digit combinations into train and test lists, label by label.

    Combinations are grouped by the label ``op(i, j, k)``. Within each group,
    ``int(len(group) * ratio)`` combinations are drawn at random without
    replacement for training; the remaining ones go to the test list. Groups
    holding a single combination are omitted from both lists.

    Args:
        op: Callable mapping three digits to a (hashable) label.
        combs: Iterable of ``(i, j, k)`` digit triples.
        ratio: Fraction of each group assigned to the training list.

    Returns:
        ``(train_combs, test_combs)``, two lists of ``(i, j, k)`` tuples.
    """
    groups = collections.defaultdict(list)
    for i, j, k in combs:
        groups[op(i, j, k)].append((i, j, k))

    train_combs, test_combs = [], []
    for members in groups.values():
        if len(members) <= 1:
            continue
        idxs = np.arange(len(members))
        train_idxs = np.random.choice(idxs, size=int(len(members) * ratio), replace=False)
        # A set gives O(1) membership tests when collecting the leftovers.
        chosen = set(train_idxs.tolist())
        train_combs += [members[i] for i in train_idxs]
        test_combs += [members[i] for i in range(len(members)) if i not in chosen]
    return train_combs, test_combs
def generate_examples(op, combs, weights, nums, n):
    """Compose three-channel examples from randomly chosen digit images.

    For every digit triple in ``combs`` the label ``op(n1, n2, n3)`` is
    computed and ``int(n * weights[label])`` examples are produced. Each
    example stacks one randomly picked image of ``n1``, ``n2`` and ``n3``
    (in that order) with ``np.vstack``.

    Args:
        op: Callable mapping three digits to a label.
        combs: Iterable of ``(n1, n2, n3)`` digit triples.
        weights: Mapping from label to the fraction of ``n`` generated per
            combination carrying that label.
        nums: Mapping from digit to a sequence of images of that digit.
        n: Nominal dataset size that the weights are scaled by.

    Returns:
        ``(x, y, metainfo)``: lists of stacked images, labels, and the
        ``[n1, n2, n3]`` digits behind each example.
    """
    images, labels, sources = [], [], []
    for n1, n2, n3 in combs:
        label = op(n1, n2, n3)
        first_range = np.arange(len(nums[n1]))
        second_range = np.arange(len(nums[n2]))
        third_range = np.arange(len(nums[n3]))
        repeats = int(n * weights[label])
        for _ in range(repeats):
            # Draw order (first, second, third digit) is kept so the random
            # stream is consumed exactly as before.
            first = nums[n1][np.random.choice(first_range)]
            second = nums[n2][np.random.choice(second_range)]
            third = nums[n3][np.random.choice(third_range)]
            images.append(np.vstack([first, second, third]))
            labels.append(label)
            sources.append([n1, n2, n3])
    return images, labels, sources
def proc_weights(op, combs):
    """Compute a per-label weight from how many combinations share each label.

    Each label that occurs ``count`` times among ``combs`` receives the
    weight ``1 / (number_of_labels * count)``.

    Args:
        op: Callable mapping three digits to a label.
        combs: Iterable of ``(i, j, k)`` digit triples.

    Returns:
        Dict mapping each label to its weight.
    """
    counts = collections.Counter()
    for i, j, k in combs:
        counts[op(i, j, k)] += 1
    n_labels = len(counts)
    return {label: 1/(n_labels*count) for label, count in counts.items()}
def class_balanced_truncater(x, y, total, nclasses, metainfo = None):
    """Resample a dataset so every class has the same number of examples.

    For each label ``c`` in ``range(nclasses)``, ``floor(total / nclasses)``
    examples are drawn at random from those with ``y == c``. Examples are
    drawn without replacement when the class has enough members, so the
    output contains no needless duplicates. Classes that are too small are
    sampled with replacement instead.

    Args:
        x: Array of examples, indexable by an integer index array.
        y: Array of integer labels aligned with ``x``.
        total: Target overall number of examples.
        nclasses: Number of classes; labels are taken to be
            ``0 .. nclasses - 1``.
        metainfo: Optional list aligned with ``x`` whose entries are carried
            along with the selected examples.

    Returns:
        ``(x_out, y_out, meta_out)``, with examples grouped by class in
        ascending label order. ``meta_out`` is an empty list when
        ``metainfo`` is ``None``.

    Raises:
        ValueError: If a class has no examples but samples are requested.
    """
    per_class = math.floor(total/nclasses)
    outx, outy, outm = [], [], []
    for c in range(nclasses):
        members = np.where(y == c)[0]
        if per_class and members.size == 0:
            raise ValueError(
                "no examples with label {}; cannot draw {} samples".format(c, per_class))
        # Only fall back to sampling with replacement when the class cannot
        # supply enough distinct examples.
        idxs = np.random.choice(members, size=per_class,
                                replace=members.size < per_class)
        outx.append(x[idxs])
        outy.append(y[idxs])
        if metainfo is not None:
            outm += [metainfo[i] for i in idxs]
    return np.concatenate(outx), np.concatenate(outy), outm
def _load_mnist_by_digit(train, download):
    """Load one MNIST split and bucket its transformed images by digit."""
    data = datasets.MNIST('raw_data/MNIST',
                          train=train,
                          download=download,
                          transform=transforms.Compose([
                              transforms.ToTensor(),
                              transforms.Normalize((0.1307,), (0.3081,))]
                          ))
    by_digit = {i: [] for i in range(10)}
    for image, number in data:
        by_digit[number].append(image)
    return by_digit


def generate_data(op, lb, ub, train_n=60000, test_n=10000, train_ratio=.75):
    """Build a three-digit MNIST-derived dataset labelled by ``op``.

    Every digit triple ``(i, j, k)`` with ``lb <= op(i, j, k) <= ub`` is a
    candidate combination. Examples stack one image per digit along the first
    axis, are resampled to equal class sizes and then shuffled.

    Args:
        op: Callable mapping three digits to an integer label.
        lb: Smallest label kept (inclusive).
        ub: Largest label kept (inclusive).
        train_n: Nominal (and maximum) number of training examples.
        test_n: Nominal (and maximum) number of test examples.
        train_ratio: Fraction of each label's combinations assigned to the
            training side by ``train_test_filter``.

    Returns:
        ``((train_x, train_y, metainfo), (test_x, test_y))`` where
        ``metainfo`` lists the ``[n1, n2, n3]`` digits of each training
        example.
    """
    # os.path.exists copes with a missing 'raw_data' directory, whereas
    # os.listdir('raw_data') would raise FileNotFoundError.
    download = not os.path.exists(os.path.join('raw_data', 'MNIST'))
    train_nums = _load_mnist_by_digit(True, download)
    test_nums = _load_mnist_by_digit(False, download)

    combs = [(i, j, k)
             for i in range(10) for j in range(10) for k in range(10)
             if lb <= op(i, j, k) <= ub]
    train_combs, test_combs = train_test_filter(op, combs, train_ratio)
    train_weights = proc_weights(op, train_combs)
    test_weights = proc_weights(op, test_combs)

    # NOTE(review): examples are generated from *all* of ``combs``; the
    # train/test combination split only feeds the per-class weights. Confirm
    # whether train_combs / test_combs were meant to be passed here instead.
    train_x, train_y, metainfo = generate_examples(op,
                                                   combs,
                                                   train_weights,
                                                   train_nums,
                                                   train_n)
    test_x, test_y, _ = generate_examples(op,
                                          combs,
                                          test_weights,
                                          test_nums,
                                          test_n)

    train_x, train_y = np.array(train_x, dtype=np.float32).squeeze(), np.array(train_y).squeeze()
    test_x, test_y = np.array(test_x, dtype=np.float32).squeeze(), np.array(test_y).squeeze()

    # The number of classes is taken from the weight tables, i.e. the labels
    # that survived train_test_filter.
    train_x, train_y, metainfo = class_balanced_truncater(train_x, train_y, train_n, len(train_weights), metainfo)
    test_x, test_y, _ = class_balanced_truncater(test_x, test_y, test_n, len(test_weights))

    train_shuff = np.arange(len(train_y))
    np.random.shuffle(train_shuff)
    test_shuff = np.arange(len(test_y))
    np.random.shuffle(test_shuff)

    train_x, train_y, metainfo = train_x[train_shuff], train_y[train_shuff], [metainfo[i] for i in train_shuff]
    test_x, test_y = test_x[test_shuff], test_y[test_shuff]

    return (train_x[:train_n], train_y[:train_n], metainfo[:train_n]), (test_x[:test_n], test_y[:test_n])
def load_multnist_data():
    """Generate the dataset whose label is the last digit of the product.

    The label of a digit triple is ``(i * j * k) % 10`` and labels 0 through
    9 are kept.

    Returns:
        Whatever ``generate_data`` returns for that labelling.
    """
    def product_last_digit(i, j, k):
        return (i * j * k) % 10

    return generate_data(product_last_digit, 0, 9)
def load_addnist_data():
    """Generate the dataset whose label is the digit sum minus one.

    The label of a digit triple is ``(i + j + k) - 1`` and only triples whose
    label lies in 0 through 19 are kept.

    Returns:
        Whatever ``generate_data`` returns for that labelling.
    """
    def sum_minus_one(i, j, k):
        return (i + j + k) - 1

    return generate_data(sum_minus_one, 0, 19)