-
Notifications
You must be signed in to change notification settings - Fork 68
/
data_iter.py
executable file
·66 lines (55 loc) · 2.28 KB
/
data_iter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import numpy as np
from collections import defaultdict
from itertools import izip
class DataIterator(object):
    """Yield mini-batches of tune indices as ``int32`` NumPy arrays.

    Two batching modes are supported:

    * ``random_lens=False`` (default): each batch is filled from tunes that
      share one length; when that length runs out in the middle of a batch,
      the rest of the batch is taken from an adjacent length (lengths are
      kept in sorted order).
    * ``random_lens=True``: indices are drawn without replacement from all
      tunes, ignoring their lengths.

    In both modes every yielded batch holds exactly ``batch_size`` indices;
    a trailing group that cannot fill a batch is dropped.

    Args:
        tune_lens: sequence with the length of every tune.
        tune_idxs: sequence of tune indices, paired positionally with
            ``tune_lens``.
        batch_size: number of indices per batch; must be at least 1.
        random_lens: selects the batching mode described above.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """

    def __init__(self, tune_lens, tune_idxs, batch_size, random_lens=False):
        if batch_size < 1:
            # A non-positive batch size would make both iteration modes
            # yield empty batches forever.
            raise ValueError(
                "batch_size must be a positive integer, got %r" % (batch_size,))

        self.batch_size = batch_size
        self.ntunes = len(tune_lens)
        self.tune_idxs = tune_idxs

        # Map each tune length to the list of tune indices having it.
        self.len2idx = defaultdict(list)
        for tune_len, tune_idx in zip(tune_lens, tune_idxs):
            self.len2idx[tune_len].append(tune_idx)

        self.random_lens = random_lens
        # Fixed seed; the state is shared by every pass over the iterator.
        self.rng = np.random.RandomState(42)

    def __iter__(self):
        """Yield each batch of indices converted to an ``int32`` array."""
        if self.random_lens:
            batches = self.__iter_random_lens()
        else:
            batches = self.__iter_homogeneous_lens()
        for batch_idxs in batches:
            yield np.int32(batch_idxs)

    def __iter_random_lens(self):
        """Yield batches sampled without replacement from all tune indices."""
        available_idxs = np.copy(self.tune_idxs)
        while len(available_idxs) >= self.batch_size:
            # An int population is sampled as if it were np.arange(n), so
            # no positions list has to be built on every pass.
            rand_idx = self.rng.choice(
                len(available_idxs), size=self.batch_size, replace=False)
            yield available_idxs[rand_idx]
            available_idxs = np.delete(available_idxs, rand_idx)

    def __iter_homogeneous_lens(self):
        """Yield batches (lists of indices) grouped by tune length.

        The per-length index lists in ``self.len2idx`` are shuffled in place
        at the start of every pass.
        """
        # Sorted so that neighbours in this list really are the closest
        # lengths; plain dict order gives no such guarantee.
        available_lengths = sorted(self.len2idx)
        if not available_lengths:
            # Nothing to iterate; rng.choice would raise on an empty list.
            return

        for length in available_lengths:
            self.rng.shuffle(self.len2idx[length])

        progress = defaultdict(int)  # length -> count of indices consumed
        batch_idxs = []
        b_size = self.batch_size  # how many indices the batch still needs
        k = self.rng.choice(available_lengths)

        while available_lengths:
            start = progress[k]
            batch_idxs.extend(self.len2idx[k][start:start + b_size])
            progress[k] = start + b_size

            if len(batch_idxs) == self.batch_size:
                yield batch_idxs
                batch_idxs = []
                b_size = self.batch_size
                k = self.rng.choice(available_lengths)
                continue

            # Length k ran out before the batch was full: retire it and
            # keep filling the same batch from a neighbouring length.
            b_size = self.batch_size - len(batch_idxs)
            i = available_lengths.index(k)
            del available_lengths[i]
            if not available_lengths:
                # The unfinished batch is dropped.
                break
            if i == 0:
                k = available_lengths[0]
            elif i >= len(available_lengths) - 1:
                k = available_lengths[-1]
            else:
                # After the deletion, i - 1 and i are the lengths that sat
                # on either side of the retired one.
                k = available_lengths[i + self.rng.choice([-1, 0])]