forked from aws-samples/aws-research-workshops
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pipemode.py
46 lines (37 loc) · 1.4 KB
/
pipemode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import numpy as np
import os
import tensorflow as tf
from sagemaker_tensorflow import PipeModeDataset
from tensorflow.contrib.data import map_and_batch
# Buffer size passed to ``Dataset.prefetch`` in ``_input_fn``.
PREFETCH_SIZE = 10
# ``batch_size`` given to ``map_and_batch`` in ``_input_fn``.
BATCH_SIZE = 64
# ``num_parallel_batches`` given to ``map_and_batch`` in ``_input_fn``.
NUM_PARALLEL_BATCHES = 2
# Length of the 'data' numeric feature column declared in ``estimator_fn``.
DIMENSION = 1024
# Number of passes over the channel; ``_input_fn`` only calls ``repeat`` when this exceeds 1.
EPOCHS = 1
def estimator_fn(run_config, params):
    """Create the linear classifier used for training and evaluation.

    Args:
        run_config: Run configuration forwarded to the estimator as ``config``.
        params: Hyperparameters supplied by the caller; not read by this function.

    Returns:
        A ``tf.estimator.LinearClassifier`` over a single numeric feature
        column named 'data' with shape ``(DIMENSION,)``.
    """
    feature_columns = [
        tf.feature_column.numeric_column('data', shape=(DIMENSION, )),
    ]
    return tf.estimator.LinearClassifier(
        feature_columns=feature_columns,
        config=run_config,
    )
def train_input_fn(training_dir, params):
    """Provide the training input by reading the 'train' PipeMode channel.

    Args:
        training_dir: Accepted for interface compatibility; not read here.
        params: Accepted for interface compatibility; not read here.

    Returns:
        Whatever ``_input_fn('train')`` produces.
    """
    train_dataset = _input_fn('train')
    return train_dataset
def eval_input_fn(training_dir, params):
    """Provide the evaluation input by reading the 'eval' PipeMode channel.

    Args:
        training_dir: Accepted for interface compatibility; not read here.
        params: Accepted for interface compatibility; not read here.

    Returns:
        Whatever ``_input_fn('eval')`` produces.
    """
    eval_dataset = _input_fn('eval')
    return eval_dataset
def _input_fn(channel):
    """Build the batched Dataset that reads from a SageMaker PipeMode channel.

    Args:
        channel: Channel name handed to ``PipeModeDataset``.

    Returns:
        The dataset after optional repetition, prefetching, and a fused
        parse-and-batch step; each element is ``({'data': ...}, labels)``.
    """
    # Every serialized example holds the raw feature bytes and an int64 label.
    feature_spec = {
        'data': tf.FixedLenFeature([], tf.string),
        'labels': tf.FixedLenFeature([], tf.int64),
    }

    def _decode_example(serialized):
        """Parse one serialized example into a (features, label) pair."""
        example = tf.parse_single_example(serialized, feature_spec)
        # The 'data' bytes are reinterpreted as float64 values.
        values = tf.decode_raw(example['data'], tf.float64)
        return {'data': values}, example['labels']

    dataset = PipeModeDataset(channel)
    # Repeating is only needed when more than one pass is requested.
    if EPOCHS > 1:
        dataset = dataset.repeat(EPOCHS)
    dataset = dataset.prefetch(PREFETCH_SIZE)

    batch_transform = map_and_batch(
        _decode_example,
        batch_size=BATCH_SIZE,
        num_parallel_batches=NUM_PARALLEL_BATCHES,
    )
    return dataset.apply(batch_transform)