utils.py
import h5py
import yaml
from fuel.datasets import H5PYDataset
from fuel.streams import DataStream
from fuel.schemes import ShuffledScheme
from fuel.transformers import Mapping
from fuel.utils import do_not_pickle_attributes
from blocks.extensions import saveload, predicates
from blocks.extensions.training import TrackTheBest
from blocks import main_loop


# Skip serialization of the extensions when pickling the main loop.
@do_not_pickle_attributes('extensions')
class MainLoop(main_loop.MainLoop):

    def __init__(self, **kwargs):
        super(MainLoop, self).__init__(**kwargs)

    def load(self):
        self.extensions = []


def transpose_stream(data):
    # Recurrent bricks expect input shaped [sequence, batch, features].
    return (data[0].T, data[1].T)


def track_best(channel, save_path):
    """Track the minimum of `channel` and checkpoint whenever it improves."""
    tracker = TrackTheBest(channel, choose_best=min)
    checkpoint = saveload.Checkpoint(
        save_path, after_training=False, use_cpickle=True)
    checkpoint.add_condition(
        ["after_epoch"],
        predicate=predicates.OnLogRecord('{0}_best_so_far'.format(channel)))
    return [tracker, checkpoint]


def get_metadata(hdf5_file):
    # Open read-only; the vocabulary mappings are stored as YAML strings
    # in the attributes of the 'targets' source.
    with h5py.File(hdf5_file, 'r') as f:
        ix_to_char = yaml.safe_load(f['targets'].attrs['ix_to_char'])
        char_to_ix = yaml.safe_load(f['targets'].attrs['char_to_ix'])
        vocab_size = len(ix_to_char)
    return ix_to_char, char_to_ix, vocab_size


def get_stream(hdf5_file, which_set, batch_size=None):
    dataset = H5PYDataset(
        hdf5_file, which_sets=(which_set,), load_in_memory=True)
    if batch_size is None:
        batch_size = dataset.num_examples
    stream = DataStream(
        dataset=dataset,
        iteration_scheme=ShuffledScheme(
            examples=dataset.num_examples, batch_size=batch_size))
    # Transpose because recurrent bricks receive input as
    # [sequence, batch, features].
    return Mapping(stream, transpose_stream)
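

# A minimal usage sketch, not part of the original file: it assumes a Fuel
# HDF5 file named 'dataset.hdf5' containing a 'train' split and a training
# log channel called 'train_cost'. The file name, split name, channel name,
# and checkpoint path below are illustrative placeholders.
if __name__ == '__main__':
    ix_to_char, char_to_ix, vocab_size = get_metadata('dataset.hdf5')
    print('vocabulary size:', vocab_size)
    train_stream = get_stream('dataset.hdf5', 'train', batch_size=100)
    # The extensions returned by track_best would be passed to MainLoop
    # together with a training algorithm, e.g.:
    # extensions = track_best('train_cost', 'best_model.tar')
    # loop = MainLoop(algorithm=..., data_stream=train_stream,
    #                 extensions=extensions)
    # loop.run()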