-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path dataset.py
28 lines (24 loc) · 871 Bytes
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import generate_captch
import tensorflow as tf
class CaptchData:
    """Build TensorFlow input pipelines from a synthetic CAPTCHA generator.

    ``config`` must expose ``char_set``, ``text_lengths`` and
    ``image_shape`` (forwarded to ``generate_captch.CaptchaGenerator``).
    It must also expose ``batch_size``, which is used to batch the
    training stream.
    """

    def __init__(self, config):
        """Store ``config`` and create the CAPTCHA sample generator."""
        self.config = config
        # NOTE(review): ``self.generator`` is handed to
        # ``tf.data.Dataset.from_generator``, which requires a callable that
        # returns an iterator. Presumably CaptchaGenerator implements
        # ``__call__``; confirm in generate_captch.
        self.generator = generate_captch.CaptchaGenerator(
            char_set=self.config.char_set,
            lengths=self.config.text_lengths,
            shape=self.config.image_shape)

    @staticmethod
    def _make_one_shot_iterator(dataset):
        """Return a one-shot iterator over ``dataset`` on TF 1.x and TF 2.x.

        The ``Dataset.make_one_shot_iterator()`` method was deprecated in
        TF 1.13 and is absent from the TF 2.x ``tf.data.Dataset`` class.
        ``tf.compat.v1.data.make_one_shot_iterator`` is the supported
        replacement. This falls back to the method only on releases that
        predate ``tf.compat.v1``.
        """
        try:
            make_iterator = tf.compat.v1.data.make_one_shot_iterator
        except AttributeError:
            # TF < 1.13: tf.compat.v1 does not exist yet.
            return dataset.make_one_shot_iterator()
        return make_iterator(dataset)

    def train_input_fn(self):
        """Return one batch of training tensors as a feature dict.

        Builds a dataset from ``self.generator``, whose elements are
        3-tuples typed (float32, int32, int32). It batches them by
        ``config.batch_size`` and pulls the next batch from a one-shot
        iterator.

        Returns:
            dict with keys ``'image'`` (rescaled float32 images),
            ``'label'`` (int32) and ``'seq_len'`` (int32).
        """
        dataset = tf.data.Dataset.from_generator(
            self.generator,
            output_types=(tf.float32, tf.int32, tf.int32))
        dataset = dataset.batch(self.config.batch_size)
        images, labels, seq_lens = (
            self._make_one_shot_iterator(dataset).get_next())
        # Affine rescale x -> x * 2/255 - 1, which maps [0, 255] onto
        # [-1, 1]. This assumes the generator emits raw 8-bit pixel values
        # (TODO confirm).
        images = images * (2. / 255) - 1
        return {
            'image': images,
            'label': labels,
            'seq_len': seq_lens,
        }

    def test_input_fn(self):
        """Placeholder for an evaluation pipeline; not implemented.

        Currently does nothing and returns ``None``.
        """
        # TODO: build the evaluation/test input pipeline.
        return None