-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathdata_loader.py
57 lines (46 loc) · 1.72 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import numpy as np
import six.moves.cPickle as pickle
import torch
import torch.nn as nn
import torch.utils.data as data
class PolyphonicDataset(data.Dataset):
    """Dataset of sequences loaded from a pickle file.

    The pickle must contain a mapping with the keys ``'sequences'`` and
    ``'seq_lens'``.  Both are indexed by integer position, and each element
    must support ``.astype`` (presumably NumPy arrays / NumPy integer
    scalars -- confirm against the script that writes the pickle).
    """

    def __init__(self, filepath):
        """Read the sequences and their lengths from the pickle at *filepath*.

        Args:
            filepath: Path of the pickle file to load.
        """
        print("loading data...")
        # NOTE(security): unpickling can execute arbitrary code; only load
        # files that come from a trusted source.
        # The context manager guarantees the handle is closed even when
        # unpickling raises.
        with open(filepath, "rb") as handle:
            payload = pickle.load(handle)
        self.seqs = payload['sequences']
        self.seqlens = payload['seq_lens']
        self.data_len = len(self.seqs)
        print("{} entries".format(self.data_len))

    def __getitem__(self, offset):
        """Return ``(seq, rev_seq, seq_len)`` for the entry at *offset*.

        ``seq`` is the stored sequence cast to float32, ``rev_seq`` is the
        same data with its first axis reversed over the full stored length
        (``seq_len`` is not used to limit the reversal), and ``seq_len`` is
        the stored length cast to int64.
        """
        seq = self.seqs[offset].astype('float32')
        # Reversed slicing yields a negatively-strided view; copy it so the
        # caller gets an independent, contiguous array.
        rev_seq = seq[::-1, :].copy()
        seq_len = self.seqlens[offset].astype('int64')
        return seq, rev_seq, seq_len

    def __len__(self):
        """Return the number of sequences in the dataset."""
        return self.data_len
class SyntheticDataset(data.Dataset):
    """Dataset of sequences plus a per-entry ``z`` value, loaded from a pickle.

    The pickle must contain a mapping with the keys ``'sequences'``,
    ``'seq_lens'`` and ``'z'``.  All three are indexed by integer position;
    elements of the first two must support ``.astype`` (presumably NumPy
    arrays / NumPy integer scalars -- confirm against the script that
    writes the pickle).  ``z`` is returned untouched; presumably it holds
    the generating latent values -- TODO confirm.
    """

    def __init__(self, filepath):
        """Read sequences, lengths and ``z`` from the pickle at *filepath*.

        Args:
            filepath: Path of the pickle file to load.
        """
        print("loading data...")
        # NOTE(security): unpickling can execute arbitrary code; only load
        # files that come from a trusted source.
        # The context manager guarantees the handle is closed even when
        # unpickling raises.
        with open(filepath, "rb") as handle:
            payload = pickle.load(handle)
        self.seqs = payload['sequences']
        self.seqlens = payload['seq_lens']
        self.z = payload['z']
        self.data_len = len(self.seqs)
        print("{} entries".format(self.data_len))

    def __getitem__(self, offset):
        """Return ``(seq, rev_seq, seq_len, z)`` for the entry at *offset*.

        ``seq`` is the stored sequence cast to float32, ``rev_seq`` is the
        same data with its first axis reversed over the full stored length
        (``seq_len`` is not used to limit the reversal), ``seq_len`` is the
        stored length cast to int64, and ``z`` is the stored ``z`` element
        returned as-is.
        """
        seq = self.seqs[offset].astype('float32')
        # Reversed slicing yields a negatively-strided view; copy it so the
        # caller gets an independent, contiguous array.
        rev_seq = seq[::-1, :].copy()
        seq_len = self.seqlens[offset].astype('int64')
        z = self.z[offset]
        return seq, rev_seq, seq_len, z

    def __len__(self):
        """Return the number of sequences in the dataset."""
        return self.data_len