-
Notifications
You must be signed in to change notification settings - Fork 3
/
cars.py
78 lines (60 loc) · 2.09 KB
/
cars.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import functools
import sys
from io import BytesIO
from pathlib import Path

import aiohttp
import numpy as np
import scipy.io
import uvicorn
from fastai.vision import (
    load_learner,
    open_image,
)
from starlette.applications import Starlette
from starlette.responses import JSONResponse, HTMLResponse, RedirectResponse
async def get_bytes(url):
    """Download *url* with an HTTP GET and return the response body.

    Args:
        url: Address of the resource to fetch.

    Returns:
        The fully read response body as ``bytes``.

    Raises:
        aiohttp.ClientResponseError: If the server answers with a 4xx/5xx
            status.
        aiohttp.ClientError: For other client-side failures raised by
            aiohttp while connecting or reading.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            # Without this check the body of an error page would be handed
            # back to the caller as if it were the requested resource.
            response.raise_for_status()
            return await response.read()
# Application object; the ``@app.route`` decorators in this file register
# their handlers on it, and it is what ``uvicorn.run`` serves.
app = Starlette()
# Directory handed to ``load_learner`` as the location of the exported
# learner file (here the filesystem root).
car_images_path = Path("/")
# Learner loaded once at import time and reused by predict_image_from_bytes
# for every request.
car_learner = load_learner(car_images_path, file='export-rn101_train_stage2-50e.pkl')
@app.route("/upload", methods=["POST"])
async def upload(request):
    """Classify an image submitted in the ``file`` field of a POSTed form.

    Args:
        request: Incoming request whose form data is expected to contain an
            uploaded file under the key ``file``.

    Returns:
        The response produced by ``predict_image_from_bytes`` for the
        uploaded content, or a JSON error response with status 400 when the
        ``file`` field is missing, is not a file upload, or is empty.
    """
    form_data = await request.form()
    uploaded = form_data.get("file")
    # A plain text field arrives as ``str`` and has no ``read`` coroutine;
    # treat it the same as a missing upload instead of failing with a 500.
    if uploaded is None or isinstance(uploaded, str):
        return JSONResponse(
            {"error": "Form field 'file' must contain an uploaded image."},
            status_code=400,
        )
    image_bytes = await uploaded.read()
    if not image_bytes:
        # Submitting the form without choosing a file yields an empty body.
        return JSONResponse(
            {"error": "Uploaded file is empty."},
            status_code=400,
        )
    return predict_image_from_bytes(image_bytes)
@app.route("/classify-url", methods=["GET"])
async def classify_url(request):
    """Fetch the image named by the ``url`` query parameter and classify it.

    Args:
        request: Incoming request; its ``url`` query parameter is the
            address of the image to download.

    Returns:
        The response produced by ``predict_image_from_bytes`` for the
        downloaded content, a JSON error response with status 400 when
        ``url`` is missing or empty, or one with status 502 when the
        download fails.
    """
    image_url = request.query_params.get("url")
    if not image_url:
        return JSONResponse(
            {"error": "Query parameter 'url' is required."},
            status_code=400,
        )
    # NOTE(review): this downloads a caller-supplied URL from the server
    # side (SSRF risk) -- consider restricting the allowed hosts.
    try:
        image_bytes = await get_bytes(image_url)
    except aiohttp.ClientError:
        # Covers invalid URLs, connection failures and HTTP error statuses
        # reported by aiohttp; report them as an upstream failure.
        return JSONResponse(
            {"error": "Could not fetch an image from the given URL."},
            status_code=502,
        )
    return predict_image_from_bytes(image_bytes)
@functools.lru_cache(maxsize=None)
def _load_class_names():
    """Read the class-name table from ``devkit/cars_meta`` once and cache it.

    Returns:
        A tuple of class-name strings in the order they appear in the
        ``class_names`` entry of the MAT file.
    """
    cars_meta = scipy.io.loadmat('devkit/cars_meta')
    # ``class_names`` has shape (1, 196): a single row whose cells each wrap
    # one name, so ``cell[0]`` unwraps the string.
    return tuple(str(cell[0]) for cell in cars_meta['class_names'][0])


def predict_image_from_bytes(bytes):
    """Classify an encoded image and return the result as JSON.

    Args:
        bytes: Raw encoded image content. The parameter name shadows the
            builtin and is kept only for interface compatibility.

    Returns:
        A ``JSONResponse`` with ``prediction_class`` (the class name looked
        up by the predicted index) and ``confidence`` (the entry of the
        learner's third output at that same index).
    """
    img = open_image(BytesIO(bytes))
    _pred_class, pred_idx, confidence = car_learner.predict(img)
    class_idx = pred_idx.item()
    return JSONResponse({
        # The name table is loaded on first use and reused afterwards rather
        # than re-reading the MAT file on every request.
        "prediction_class": _load_class_names()[class_idx],
        "confidence": confidence[class_idx].item(),
    })
@app.route("/")
def form(request):
    """Serve the landing page with the two submission forms.

    The page offers a file-upload form that posts to ``/upload`` and a URL
    form that issues a GET to ``/classify-url``.
    """
    page = """
<form action="/upload" method="post" enctype="multipart/form-data">
Select image to upload:
<input type="file" name="file">
<input type="submit" value="Upload Image">
</form>
Or submit a URL:
<form action="/classify-url" method="get">
<input type="url" name="url">
<input type="submit" value="Fetch and analyze image">
</form>
"""
    return HTMLResponse(page)
@app.route("/form")
def redirect_to_homepage(request):
    """Send visitors of ``/form`` to the root page."""
    homepage = "/"
    return RedirectResponse(homepage)
# Start the server only when run as a script with "serve" among the
# command-line arguments.
if __name__ == "__main__" and "serve" in sys.argv:
    uvicorn.run(app, host="0.0.0.0", port=8008)