forked from kootenpv/neural_complete
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathserve.py
49 lines (33 loc) · 1.27 KB
/
serve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
from cors import crossdomain
from flask import Flask, jsonify, request
from neural_complete import neural_complete
from neural_complete import get_model
def read_models(base_path="models/"):
    """Return the set of model names available under *base_path*.

    A model name is the part of a file name before its first ``.``, so
    ``char.h5`` and ``char.json`` both contribute ``"char"``.  Entries whose
    name starts with ``.`` (e.g. ``.gitkeep``, ``.DS_Store``) are skipped:
    they would otherwise yield the empty name ``""``, which module setup
    would then try to load as a model.

    :param base_path: directory to scan (non-recursively).
    :returns: set of model-name strings.
    :raises OSError: if *base_path* cannot be listed.
    """
    return {
        entry.split(".")[0]
        for entry in os.listdir(base_path)
        if not entry.startswith(".")
    }
app = Flask(__name__)

# Eagerly load every model found on disk, keyed by model name; ``predict``
# adds further entries lazily when a client asks for an unloaded model.
models = {
    name: get_model(name)
    for name in read_models()
}
def get_args(req):
    """Return the request parameters of *req* as a mapping.

    POST parameters are read from the JSON body, GET parameters from the
    query string.  Any other method (e.g. OPTIONS), a POST whose body is
    missing or not valid JSON, or a JSON body that is not an object all
    yield an empty dict, so callers can always use ``.get(key, default)``
    and fall back to their defaults.

    :param req: the request object (callers pass Flask's ``request``).
    :returns: a dict-like object supporting ``.get``.
    """
    if req.method == "POST":
        # silent=True returns None instead of raising on a non-JSON or
        # malformed body; a None or non-object payload would otherwise
        # crash callers on ``args.get``.
        payload = req.get_json(silent=True)
        return payload if isinstance(payload, dict) else {}
    if req.method == "GET":
        return req.args
    # Any other method: no parameters (previously an UnboundLocalError).
    return {}
def _is_valid_model_name(name):
    """Return True if a client-supplied *name* is acceptable as a model name.

    Names produced by ``read_models`` are file-name prefixes taken before the
    first ``.``, so they never contain path separators or start with ``.``.
    Rejecting those (plus empty and non-string values) therefore excludes no
    model found on disk.  NOTE(review): presumably ``get_model`` resolves the
    name to a file on disk — this check keeps untrusted input from steering it.
    """
    return (
        isinstance(name, str)
        and name != ""
        and not name.startswith(".")
        and not any(ch in name for ch in ("/", "\\", "\x00"))
    )


@app.route("/predict", methods=["GET", "POST", "OPTIONS"])
@crossdomain(origin='*', headers="Content-Type")
def predict():
    """Return completions of ``keyword`` from the model named ``model``.

    Parameters come from the JSON body (POST) or the query string (GET) via
    ``get_args``: ``keyword`` defaults to ``"from "`` and ``model`` to
    ``"char"``.  A model that is not loaded yet is loaded on first use and
    cached in ``models``.  ``neural_complete`` is called with the settings
    ``[0.2, 0.5, 1]`` (presumably sampling temperatures — confirm), and each
    suggestion is stripped of surrounding whitespace.

    Responds with status 400 and an ``error`` message when ``model`` is not
    a valid model name.
    """
    # get_args may yield None for a POST without a JSON body.
    args = get_args(request) or {}
    sentence = args.get("keyword", "from ")
    model_name = args.get("model", "char")
    # Validate before the dict lookup: a non-string (e.g. a JSON list) would
    # raise TypeError there, and the name is fed to get_model below.
    if not _is_valid_model_name(model_name):
        response = jsonify({"error": "invalid model name"})
        response.status_code = 400
        return response
    if model_name not in models:
        models[model_name] = get_model(model_name)
    suggestions = neural_complete(models[model_name], sentence, [0.2, 0.5, 1])
    return jsonify({"data": {"results": [x.strip() for x in suggestions]}})
@app.route("/get_models", methods=["GET", "POST", "OPTIONS"])
@crossdomain(origin='*', headers="Content-Type")
def get_models():
    """Respond with the names of every model currently held in ``models``.

    This includes models loaded at startup and any loaded lazily by
    ``predict`` since then.
    """
    loaded_names = [name for name in models]
    payload = {"data": {"results": loaded_names}}
    return jsonify(payload)
def main(host="127.0.0.1", port=9078, debug=True):
    """Run the Flask development server.

    :param host: interface to bind; the default loopback address only
        accepts local connections.
    :param port: TCP port to listen on.
    :param debug: enable Flask debug mode.  It stays ``True`` by default for
        backward compatibility, but debug mode turns on the Werkzeug
        interactive debugger, which allows arbitrary code execution from
        the browser.  Pass ``debug=False`` whenever *host* is reachable from
        other machines.
    """
    app.run(host=host, port=port, debug=debug)
if __name__ == "__main__":
    # Executed as a script rather than imported: serve with main()'s defaults.
    main()