-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
69 lines (50 loc) · 1.71 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from flask import Flask
from flask import jsonify
from flask import request
from flask_cors import CORS, cross_origin
from transformers import pipeline
from config import config
def create_app(arg_environment):
    """Build and return a Flask application configured from ``arg_environment``.

    ``arg_environment`` is passed unchanged to Flask's ``config.from_object``.
    """
    flask_app = Flask(__name__)
    settings = flask_app.config
    settings.from_object(arg_environment)
    return flask_app
def _build_qa_pipeline(model_id):
    """Create a "question-answering" pipeline whose model and tokenizer share one id."""
    return pipeline(
        "question-answering",
        model=model_id,
        tokenizer=model_id,
    )


# Pipeline backing the /v1/predict endpoint (loaded once at import time).
qa_pipeline_v1 = _build_qa_pipeline("mrm8488/bert-multi-cased-finetuned-xquadv1")
# Pipeline backing the /v2/predict endpoint (loaded once at import time).
qa_pipeline_v2 = _build_qa_pipeline("LeoAngel/bert-finetuned-crossxquadv1_25sbl")
# NOTE(review): the configuration profile is hard-coded to 'development';
# consider selecting it from an environment variable for deployments.
environment = config['development']
app = create_app(environment)
# Enable CORS on every route. The flask_cors option is spelled
# ``supports_credentials`` (as in the per-route ``cross_origin`` decorators);
# the previous ``support_credentials`` keyword is not a flask_cors option and
# was silently ignored, so credentialed requests were never allowed.
CORS(app, supports_credentials=True)
@app.route('/ping', methods=['GET'])
def get_ping():
    """Answer ``GET /ping`` with the plain-text body ``pong`` (presumably a health check)."""
    reply = 'pong'
    return reply
def _predict_answers(qa_pipeline):
    """Answer every question in the current JSON request with ``qa_pipeline``.

    Expects a request body of the form
    ``{"text": "<context>", "questions": ["<question>", ...]}``. Each question
    is asked separately against ``text``, and the pipeline's raw result is
    returned next to the question that produced it.

    Args:
        qa_pipeline: A callable accepting ``{'context': ..., 'question': ...}``
            (one of the module-level question-answering pipelines).

    Returns:
        ``{"predictions": [{"question": ..., "answer": ...}, ...]}`` as JSON,
        or a ``({"error": ...}, 400)`` JSON response when the body is missing
        or malformed.
    """
    # silent=True: a missing or non-JSON body yields None instead of raising,
    # so it can be reported as a client error (400) rather than a crash (500).
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'request body must be a JSON object'}), 400

    context = data.get('text')
    questions = data.get('questions')
    if not isinstance(context, str):
        return jsonify({'error': "'text' must be a string"}), 400
    # A bare string would otherwise be iterated one character at a time.
    if not isinstance(questions, list) or not all(
            isinstance(question, str) for question in questions):
        return jsonify({'error': "'questions' must be a list of strings"}), 400

    predictions = [
        {
            'question': question,
            'answer': qa_pipeline({'context': context, 'question': question}),
        }
        for question in questions
    ]
    return jsonify({'predictions': predictions})


# @app.route must be the OUTERMOST decorator: Flask registers whatever function
# it receives, so with the order reversed the cross_origin wrapper was never
# registered and had no effect on these routes.
@app.route('/v1/predict', methods=['POST'])
@cross_origin(supports_credentials=True)
def post_v1_predict():
    """Answer the posted questions with the v1 question-answering pipeline."""
    return _predict_answers(qa_pipeline_v1)


@app.route('/v2/predict', methods=['POST'])
@cross_origin(supports_credentials=True)
def post_v2_predict():
    """Answer the posted questions with the v2 question-answering pipeline."""
    return _predict_answers(qa_pipeline_v2)
if __name__ == '__main__':
    # Listen on every network interface ('0.0.0.0'), with the debugger disabled.
    listen_host = '0.0.0.0'
    app.run(host=listen_host, debug=False)