-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathapi.py
executable file
·49 lines (41 loc) · 1.58 KB
/
api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#!/usr/local/bin/python3
import numpy as np
import tensorflow as tf
from functools import wraps
from flask import Flask, request, jsonify
"""
Load a tensorflow model and make it available as a REST service
"""
# Flask application object; the view in this file is registered on it
# through the @app.route decorator.
app = Flask(__name__)
def parse_postget(f):
    """Decorate a view so it receives the request parameters as a dict.

    The wrapper reads every key of ``request.values`` and builds a plain
    dict that is passed to ``f`` as its first positional argument.  A key
    sent once maps to its single value; a key sent several times maps to
    the list of all its values.

    Args:
        f: View function whose first argument is the parameter dict.

    Returns:
        The wrapping function, with ``f``'s metadata copied onto it by
        ``functools.wraps``.

    Raises:
        Exception: if reading the parameters fails with Werkzeug's
            ``BadRequest``; the original error is chained as the cause.
    """
    @wraps(f)
    def wrapper(*args, **kw):
        # Imported locally: the module-level imports do not provide
        # BadRequest, and without it the ``except`` clause below would
        # itself fail with NameError instead of handling the error.
        from werkzeug.exceptions import BadRequest

        try:
            values = request.values
            d = {}
            for key in values.keys():
                # Look the key up once instead of once per use.
                items = values.getlist(key)
                d[key] = items if len(items) > 1 else items[0]
        except BadRequest as e:
            raise Exception(
                "Payload must be a valid json. {}".format(e)) from e
        # Forward any arguments Flask supplies (e.g. URL variables) rather
        # than silently dropping them.
        return f(d, *args, **kw)
    return wrapper
@app.route('/model', methods=['GET', 'POST'])
@parse_postget
def apply_model(d):
    """Evaluate the saved linear model ``y = m * x + b`` for one input.

    The graph is rebuilt on every request, the variables ``m`` and ``b``
    are restored from the checkpoint ``linear.chk`` and the prediction for
    the supplied ``x_in`` is returned.

    Args:
        d: Request parameters as a dict.  ``x_in`` is required and must be
            a single value that ``float`` can parse.  ``y_star`` is only
            meaningful for training; it is not needed for a prediction and
            is ignored if present.

    Returns:
        A JSON response ``{"output": <float>}`` on success, or a JSON
        ``{"error": ...}`` body with HTTP status 400 when ``x_in`` is
        missing, repeated, or not numeric.
    """
    try:
        x_value = float(d['x_in'])
    except (KeyError, TypeError, ValueError):
        # KeyError: parameter absent; TypeError: parameter repeated (a
        # list); ValueError: not a number.  All are client errors.
        return jsonify(error="'x_in' must be a single numeric value"), 400

    # Start from an empty default graph so the variables below are created
    # under exactly the names 'm' and 'b' on every request (presumably the
    # names stored in the checkpoint -- verify against the training script).
    tf.reset_default_graph()
    with tf.Session() as session:
        n = 1
        x = tf.placeholder(tf.float32, [n], name='x')
        m = tf.Variable([1.0], name='m')
        b = tf.Variable([1.0], name='b')
        # Operator overloads instead of tf.mul, which was removed in
        # TensorFlow 1.0 (renamed tf.multiply).
        y = m * x + b  # y_i = m * x_i + b
        # Only inference happens here, so no loss/optimizer is built: an
        # optimizer would add its own variables to the graph and the Saver
        # would then require them to be present in the checkpoint too.
        saver = tf.train.Saver()
        saver.restore(session, 'linear.chk')
        y_i = session.run(y, {x: np.array([x_value])})
    # y_i has shape [n] == [1]; convert its single element explicitly.
    return jsonify(output=float(y_i[0]))
# Script entry point: start Flask's built-in server when this file is run
# directly rather than imported.
# NOTE(review): debug=True is passed to app.run -- confirm this server is
# only ever used for local development.
if __name__ == "__main__":
    app.run(debug=True)