-
Notifications
You must be signed in to change notification settings - Fork 2
/
api_onnx.py
102 lines (89 loc) · 3.7 KB
/
api_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from fastapi import FastAPI, File, UploadFile, HTTPException
import logging
import sys
from order import OrderPolygons
from paddleocr import PaddleOCR
from PIL import Image
import io
import numpy as np
import cv2
# Logging setup; for the available options see
# https://docs.python.org/3/library/logging.html
# filemode='w' truncates the log file each time the process starts.
logging.basicConfig(
    filename='api_log_onnx.log',
    filemode='w',
    format='%(asctime)s %(message)s',
    datefmt='%d/%m/%Y %H:%M:%S',
    level=logging.INFO,
)

# Shared box-ordering helper, used by predict() when reorder_texts is set.
order = OrderPolygons()

# Paths to the exported ONNX model *files* (not directories) for the
# recognition, detection and angle-classification stages.
onnx_rec_model = './onnx_model/rec_onnx/model.onnx'
onnx_det_model = './onnx_model/det_onnx/model.onnx'
onnx_cls_model = './onnx_model/cls_onnx/model.onnx'

# Create the API server object; terminate the process if that fails.
try:
    app = FastAPI()
except Exception as err:
    logging.error('Failed to start the API server: %s', err)
    sys.exit(1)
# Startup hook: runs once, before the application starts serving requests.
@app.on_event("startup")
async def load_model():
    """
    Load the OCR model once and attach it to the application.

    Builds a CPU-only ``PaddleOCR`` instance backed by the ONNX detection,
    angle-classification and recognition models configured at module level,
    and stores it as ``app.package["model"`` + ``]``-keyed entry that
    ``predict`` reads on every request.

    Raises:
        RuntimeError: if the model cannot be constructed. The error propagates
            out of the startup hook (the original cause is chained), so the
            failure surfaces during application startup rather than later.
    """
    try:
        model = PaddleOCR(lang='latin', show_log=False, det=True,
                          use_angle_cls=True,
                          rec_model_dir=onnx_rec_model,
                          det_model_dir=onnx_det_model,
                          cls_model_dir=onnx_cls_model,
                          use_gpu=False,
                          use_onnx=True)
        # NOTE(review): `det` is normally an argument of PaddleOCR.ocr();
        # as a constructor kwarg it may have no effect — confirm.
    except Exception as e:
        # logging.exception also records the traceback of the failure.
        logging.exception('Failed to load the model file: %s', e)
        # HTTPException only makes sense while handling a request; at startup
        # there is no client to receive it, so raise a plain error instead.
        raise RuntimeError('Failed to load the model file: %s' % e) from e
    # Add model to app state (read back by predict()).
    app.package = {"model": model}
def predict(path, use_angle_cls, reorder_texts):
    """
    Run OCR on one input and return the results for its first page.

    Args:
        path: any input ``PaddleOCR.ocr`` accepts. The endpoints in this
            module pass either a path string or a BGR ``numpy`` image.
        use_angle_cls: forwarded to ``model.ocr`` as ``cls`` (text-angle
            classification).
        reorder_texts: when true, reorder the detected entries using the
            index order returned by ``order.order``.

    Returns:
        ``res[0]`` from ``model.ocr``, optionally reordered. Each entry's
        first element is its box (a sequence of corner points). It may be
        ``None`` when nothing is detected — presumably PaddleOCR's "no text"
        result; verify for the pinned version.
    """
    # Get model from app state (stored by the startup hook).
    model = app.package["model"]
    res = model.ocr(path, cls=use_angle_cls)
    page = res[0]
    # With no detections the page can be None (or empty); there is nothing
    # to reorder, and iterating None would raise TypeError.
    if reorder_texts and page:
        boxes = [entry[0] for entry in page]
        # Reduce each 4-corner box to the four scalars OrderPolygons consumes.
        # Assuming corners are ordered TL, TR, BR, BL, this is
        # (x_right, x_left, y_bottom, y_top) — TODO confirm against order.py.
        new_boxes = [[box[1][0], box[0][0], box[2][1], box[0][1]] for box in boxes]
        new_order = order.order(new_boxes)
        page = [page[i] for i in new_order]
    return page
# GET endpoint: the client sends the path of an image for the server to read.
@app.get("/predict_path")
async def predict_path(path: str, use_angle_cls: bool, reorder_texts: bool):
    """
    Run OCR on the image at ``path`` and return the first-page results.

    NOTE(review): ``path`` comes straight from the query string and is handed
    to PaddleOCR unchecked, so any client can make the server open any file
    the process can read (PaddleOCR may also fetch http(s) URLs). Restrict it
    to an allowed directory before exposing this endpoint publicly.

    NOTE(review): ``predict`` is synchronous and CPU-bound; calling it from
    this ``async def`` blocks the event loop for the whole OCR run.

    Raises:
        HTTPException: status 500 if OCR fails for any reason.
    """
    try:
        result = predict(path, use_angle_cls=use_angle_cls, reorder_texts=reorder_texts)
    except Exception as err:
        message = 'Failed to analyze the input image file: %s' % err
        logging.error(message)
        raise HTTPException(status_code=500, detail=message)
    else:
        return result
# POST endpoint: the image itself is uploaded with the request.
@app.post("/predict_image")
async def predict_image(use_angle_cls: bool, reorder_texts: bool, file: UploadFile = File(...)):
    """
    Decode an uploaded image and return its first-page OCR results.

    The upload is read fully into memory, decoded with PIL, forced to RGB
    and then converted to BGR channel order before inference.

    NOTE(review): ``predict`` is synchronous and CPU-bound; calling it from
    this ``async def`` blocks the event loop for the whole OCR run.

    Raises:
        HTTPException: status 400 if the upload cannot be decoded as an image,
            status 500 if OCR fails.
    """
    try:
        raw = await file.read()
        rgb_image = Image.open(io.BytesIO(raw)).convert('RGB')
        # PIL yields RGB; swap to BGR (OpenCV order), which the OCR model is
        # presumably expecting, as with cv2.imread output.
        image = cv2.cvtColor(np.array(rgb_image), cv2.COLOR_RGB2BGR)
    except Exception as err:
        message = 'Failed to load the input image file: %s' % err
        logging.error(message)
        raise HTTPException(status_code=400, detail=message)

    try:
        result = predict(image, use_angle_cls=use_angle_cls, reorder_texts=reorder_texts)
    except Exception as err:
        message = 'Failed to analyze the input image file: %s' % err
        logging.error(message)
        raise HTTPException(status_code=500, detail=message)
    else:
        return result