-
Notifications
You must be signed in to change notification settings - Fork 0
/
web.py
73 lines (56 loc) · 2.15 KB
/
web.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import json
import glob
import numpy as np
from PIL import Image
import tensorflow as tf
from flask import Flask, request, jsonify, send_file
from model import efficientnetv2_s as create_model
from flask_cors import CORS
# Model / input configuration.
num_classes = 17  # total number of output classes
# Square input resolution (pixels) for each EfficientNetV2 variant key.
img_size = {
    "s": 384,
    "m": 480,
    "l": 480,
}
num_model = "s"  # which variant's input size to use
im_height = img_size[num_model]
im_width = im_height  # inputs are square

# Flask application with CORS enabled on all routes.
app = Flask(__name__)
CORS(app)
# Load the class-label dictionary. The prediction route looks labels up as
# ``class_indict[str(index)]``, so keys are class indices rendered as strings.
json_path = './class_indices.json'
if not os.path.exists(json_path):
    # Explicit raise rather than ``assert``: asserts are stripped under
    # ``python -O`` and should not guard runtime preconditions.
    raise FileNotFoundError("文件: '{}' 不存在。".format(json_path))
# JSON text is UTF-8 by specification; do not depend on the platform locale.
with open(json_path, "r", encoding="utf-8") as f:
    class_indict = json.load(f)
# Build the network and restore its trained weights.
model = create_model(num_classes=num_classes)
weights_path = './save_weights/efficientnetv2.ckpt'
# ``weights_path`` is used as a checkpoint prefix, so existence is checked by
# globbing for any file that starts with it rather than the exact path.
if not glob.glob(weights_path + "*"):
    # Explicit raise rather than ``assert`` so the check survives ``python -O``.
    raise FileNotFoundError("找不到 {}".format(weights_path))
model.load_weights(weights_path)
def preprocess_image(image):
    """Turn a PIL image into a model-ready batch containing one image.

    The image is converted to 3-channel RGB and resized to
    ``(im_width, im_height)``. Pixel values are scaled from [0, 255] to
    [-1, 1], and a leading batch axis is added.

    Args:
        image: A ``PIL.Image.Image`` in any mode PIL can convert to "RGB"
            (e.g. RGB, L, RGBA, P, LA, CMYK).

    Returns:
        A float32 ``np.ndarray`` of shape ``(1, im_height, im_width, 3)``.
    """
    # Normalise the mode before converting to an array. Checking the array
    # shape afterwards misses palette ("P") images, whose array holds palette
    # indices rather than grey levels. It also misses 2-channel "LA" images
    # and 4-channel CMYK images, where dropping the last channel is wrong.
    # convert("RGB") handles all of these, and still replicates grey levels
    # for "L" and drops alpha for "RGBA" as before.
    rgb = image.convert("RGB").resize((im_width, im_height))
    array = np.asarray(rgb, dtype=np.float32)
    # Map [0, 255] -> [-1, 1].
    array = (array / 255. - 0.5) / 0.5
    return np.expand_dims(array, 0)
@app.route('/')
def index():
    """Respond to the site root with the ``index.html`` file."""
    page_path = 'index.html'
    return send_file(page_path)
@app.route('/predict', methods=['POST'])
def predict():
    """Classify the image uploaded in the ``image`` multipart form field.

    Returns:
        On success, JSON ``{'class': <label>, 'probability': <float>}``. The
        label comes from ``class_indict``, and the probability is the softmax
        score of the top class.

        On failure, JSON ``{'error': <message>}`` with HTTP 400. This happens
        when the field is missing or the upload cannot be decoded as an image.
    """
    if 'image' not in request.files:
        # 400 rather than the implicit 200 so clients can detect the failure
        # from the status code alone.
        return jsonify({'error': '请求中没有找到图像'}), 400
    image_file = request.files['image']
    try:
        # PIL decodes lazily, so preprocessing stays inside the ``try`` to
        # also catch truncated or corrupt data. ``with`` releases PIL's
        # reference to the upload once we are done with it.
        with Image.open(image_file) as image:
            image_data = preprocess_image(image)
    except (OSError, ValueError):
        # UnidentifiedImageError (non-image upload) subclasses OSError.
        # Truncated data raises OSError, and an unsupported mode conversion
        # raises ValueError.
        return jsonify({'error': '无法识别上传的图像文件'}), 400

    # The model output is treated as unnormalised scores; softmax turns it
    # into probabilities. tf.nn.softmax avoids constructing a new Keras layer
    # on every request.
    scores = np.squeeze(model.predict(image_data))
    probs = tf.nn.softmax(scores).numpy()
    predict_class = int(np.argmax(probs))
    app.logger.debug('predicted class index: %d', predict_class)
    prediction = {
        'class': class_indict[str(predict_class)],
        'probability': float(probs[predict_class]),
    }
    return jsonify(prediction)
def _run_server():
    """Start Flask's built-in server using ``app.run()`` defaults."""
    app.run()


if __name__ == '__main__':
    _run_server()