Skip to content

Commit

Permalink
add: 物体検出を行うための処理
Browse files Browse the repository at this point in the history
  • Loading branch information
takahashitom committed Sep 8, 2024
1 parent ba2ee31 commit ccc30fd
Show file tree
Hide file tree
Showing 19 changed files with 4,239 additions and 12 deletions.
8 changes: 7 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
**/__pycache__/
.coverage
coverage*
coverage*

# 推論
*.jpeg
fig_image
*.pt
*.flag
9 changes: 9 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ help:
@echo " $$ make check_style"
@echo "カバレッジレポートの表示"
@echo " $$ make coverage"
@echo "サーバの立ち上げ"
@echo " $$ make server"
@echo フラグ管理用のファイルを全て削除する
@echo " $$ make flag-delete"

run:
poetry run python src
Expand All @@ -27,3 +31,8 @@ coverage:
poetry run coverage run -m pytest
poetry run coverage report

server: flag-delete
poetry run python -m src.server.flask_server

flag-delete:
find ./ -type f -name "*.flag" -exec rm {} +
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ opencv-python = "^4.8.0.74"
requests = "^2.31.0"
flask = "^3.0.3"
pillow = "^10.4.0"
torch = ">=2.0.0, !=2.0.1, !=2.1.0"
torchvision = ">=0.15.3"
torchaudio = "^2.0.2"
ultralytics = "^8.0.163"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
Expand Down
235 changes: 235 additions & 0 deletions src/detect_object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
"""物体検出を行うモジュール.
ベストショット画像を選択するための物体検出を行う。
NOTE:
重みファイル'exp17_best.pt'(現状)は以下にあります。
https://drive.google.com/drive/folders/1FZPZu1xNMaarVyKfaLlm3HZPFgYp5RCZ
etrobocon2024-camera-system/yolo/の中にダウンロードしてください
参考コード:
https://github.com/ultralytics/yolov5
@author: takahashitom
"""

import torch
from pathlib import Path
import os
import numpy as np
import sys
from ultralytics.utils.plotting import Annotator, colors

script_dir = os.path.dirname(os.path.abspath(__file__)) # noqa
YOLO_PATH = os.path.join(script_dir, "..", "yolo") # noqa
sys.path.append(YOLO_PATH) # noqa
YOLO_PATH = Path(YOLO_PATH) # noqa
from models.common import DetectMultiBackend
from utils.general import (
check_img_size, cv2, non_max_suppression, scale_boxes)
from utils.torch_utils import select_device
from utils.augmentations import letterbox

PROJECT_DIR_PATH = os.path.dirname(script_dir)
IMAGE_DIR_PATH = Path(os.path.join(PROJECT_DIR_PATH, "fig_image"))


class DetectObject():
"""yolov5(物体検出)をロボコン用に編集したクラス."""

__DEVICE = 'cpu'
__IMG_SIZE = (640, 480)

def __init__(self,
weights=YOLO_PATH/'exp17_best.pt',
label_data=YOLO_PATH/'label_data.yaml',
conf_thres=0.6,
iou_thres=0.45,
max_det=1,
line_thickness=1,
stride=32):
"""コンストラクタ.
Args:
weights (str): 重みファイルパス
label_data (str): ラベルを記述したファイルパス
conf_thres (float): 信頼度閾値
iou_thres (float): NMS IOU 閾値
max_det (int): 最大検出数
line_thickness (int): バウンディングボックスの太さ
stride (int): ストライド
"""
self.check_exist(weights)
self.check_exist(label_data)
self.weights = str(weights)
self.label_data = str(label_data)
self.conf_thres = conf_thres
self.iou_thres = iou_thres
self.max_det = max_det
self.line_thickness = line_thickness
self.stride = stride

@staticmethod
def check_exist(path: str) -> None:
"""ファイル, ディレクトリが存在するかの確認.
Args:
path (str): ファイルまたはディレクトリのパス
raise:
FileNotFoundError: ファイルがない場合に発生
"""
try:
if not os.path.exists(path):
raise FileNotFoundError(f"'{path}' is not found")

except FileNotFoundError as e:
print("Error:", e)

def detect_object(self,
img_path=IMAGE_DIR_PATH/'test_image.png',
save_path=None) -> list:
"""物体の検出を行う関数.
Args:
img_path(str): 物体検出を行う画像パス
save_path(str): 検出結果の画像保存パス
Noneの場合、保存しない
Returns:
list: 検出したオブジェクト
"""
self.check_exist(img_path)

# cpuを指定
device = select_device(self.__DEVICE)

# モデルの読み込み
model = DetectMultiBackend(self.weights,
device=device,
dnn=False,
data=self.label_data,
fp16=False)

stride, labels, pt = model.stride, model.names, model.pt

# 画像のサイズを指定されたストライド(ステップ)の倍数に合わせるための関数
img_size = check_img_size(
self.__IMG_SIZE, s=stride) # >> [640, 640]

# モデルの初期化
batch_size = 1
model.warmup(
imgsz=(1 if pt or model.triton else batch_size, 3, *img_size))

# 画像の読み込み
original_img = cv2.imread(img_path) # BGR

# パディング処理
img = letterbox(original_img,
self.__IMG_SIZE,
stride=self.stride,
auto=True)[0]
img = img.transpose((2, 0, 1))[::-1] # BGR -> RGB
img = np.ascontiguousarray(img) # 連続したメモリ領域に変換
img = torch.from_numpy(img).to(model.device) # PyTorchのテンソルに変換
img = img.half() if model.fp16 else img.float() # uint8 to fp16/32

# スケーリング
img /= 255 # 0 - 255 to 0.0 - 1.0

# torch.Size([3, 640, 640]) >> torch.Size([1, 3, 640, 640])
if len(img.shape) == 3:
img = img[None]

# 検出
pred = model(img, augment=False, visualize=False)

# 非最大値抑制 (NMS) により重複検出を拒否
pred = non_max_suppression(pred,
self.conf_thres, # 信頼度の閾値
self.iou_thres, # IoUの閾値
max_det=self.max_det, # 保持する最大検出数
classes=None, # 検出するクラスのリスト
agnostic=False # Trueの場合、クラスを無視してNMSを実行
)

# 検出結果を画像に描画
objects = pred[0]
print(Path(img_path).name, " 検出数", len(objects))

save_img = original_img.copy()

# 画像にバウンディングボックスやラベルなどのアノテーションを追加
annotator = Annotator(save_img,
line_width=self.line_thickness,
example=str(labels))

if len(objects):
# バウンディングボックスをimgサイズからsave_imgサイズに再スケールします
objects[:, :4] = scale_boxes(
img.shape[2:], objects[:, :4], save_img.shape).round()

if save_path is not None:
# xyxy: バウンディングボックスの座標([x_min, y_min, x_max, y_max] 形式)
# conf: 信頼度
# cls: クラスID
for *xyxy, conf, cls in reversed(objects):
c = int(cls)
label = f'{labels[int(cls)]} {conf:.2f}'
# 画像にバウンディングボックスとラベルを追加
annotator.box_label(xyxy, label, color=colors(c, True))

# 検出結果を含む画像を保存
save_img = annotator.result()
cv2.imwrite(save_path, save_img)

"""
NOTE:
objectsの型について
行数:検出数
列数:6列([x_min, y_min, x_max, y_max, conf, cls])
"""
return objects.tolist()


if __name__ == '__main__':
"""作業用.
$ poetry run python ./src/detect_object.py
--img_path fig_image/test_image.png
"""
import argparse
save_path = os.path.join(str(IMAGE_DIR_PATH), "detect_test_image.png")

parser = argparse.ArgumentParser(description="リアカメラに関するプログラム")

parser.add_argument("-wpath", "--weights", type=str,
default=YOLO_PATH/'exp17_best.pt',
help='重みファイルパス')
parser.add_argument("-label", "--label_data", type=str,
default=YOLO_PATH/'label_data.yaml',
help='ラベルを記述したファイルパス')
parser.add_argument("-conf", "--conf_thres", type=int,
default=0.6, help='信頼度閾値')
parser.add_argument("-iou", "--iou_thres", type=int,
default=0.45, help='IOU 閾値')
parser.add_argument("--max_det", type=int, default=10, help='最大検出数')
parser.add_argument("--line_thickness", type=int,
default=1, help='バウンディングボックスの太さ')
parser.add_argument("--stride", type=int, default=32, help='ストライド')
parser.add_argument("-img", "--img_path", type=str,
default=IMAGE_DIR_PATH/'FigA_1.png', help='入力画像')
parser.add_argument("-spath", "--save_path", type=str,
default=save_path, help='検出画像の保存先. Noneの場合保存しない')
args = parser.parse_args()

d = DetectObject(args.weights,
args.label_data,
args.conf_thres,
args.iou_thres,
args.max_det,
args.line_thickness,
args.stride)

objects = d.detect_object(args.img_path, args.save_path)
print("objects\n", objects)

print("完了")
78 changes: 67 additions & 11 deletions src/server/flask_server.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""
走行体から画像ファイルを受信するWebサーバー.
@author Keiya121 CHIHAYATAKU
@author Keiya121 CHIHAYATAKU takahashitom
"""

import os
import socket
import platform
from ..detect_object import DetectObject

from flask import Flask, request, jsonify
from flask import Flask, request, jsonify, send_file

app = Flask(__name__)

Expand All @@ -18,8 +19,8 @@
# '/upload'へのPOSTリクエストに対する操作


@app.route('/upload', methods=['POST'])
def getImageFile() -> jsonify:
@app.route('/images', methods=['POST'])
def get_image() -> jsonify:
"""走行体から、画像ファイルを取得するための関数."""
# curlコマンドのエラーハンドリング
if 'file' not in request.files:
Expand All @@ -30,12 +31,66 @@ def getImageFile() -> jsonify:
if file.filename == '':
return jsonify({"error": "No selected file"}), 400

fileName = file.filename
# src/server/datafilesに、受信したファイルを保存する。
filePath = os.path.join(UPLOAD_FOLDER, fileName)
file.save(filePath)
return jsonify({"message": "File uploaded successfully",
"filePath": filePath}), 200
file_name = file.filename

upload_folder = os.path.join(os.path.dirname(__file__), 'image_data')
os.makedirs(upload_folder, exist_ok=True)

# src/server/image_dataに、受信したファイルを保存する。
file_path = os.path.join(upload_folder, file_name)
file.save(file_path)

return jsonify({"message": "File uploaded successfully"}), 200

# '/detect'へのPOSTリクエストに対する操作


@app.route('/detect', methods=['POST'])
def get_detection_image() -> jsonify:
"""走行体から画像ファイルを取得し、物体検出した結果を送信するための関数."""
# curlコマンドのエラーハンドリング
if 'file' not in request.files:
return jsonify({"error": "No file part"}), 400

file = request.files['file']

if file.filename == '':
return jsonify({"error": "No selected file"}), 400

file_name = file.filename

upload_folder = os.path.join(os.path.dirname(__file__), 'image_data')
os.makedirs(upload_folder, exist_ok=True)

# src/server/image_dataに、受信したファイルを保存する
file_path = os.path.join(upload_folder, file_name)
file.save(file_path)

# 取得した画像に対し物体検出を行う
d = DetectObject()
detected_img_path = os.path.join(upload_folder, "detected_"+file_name)

try:
objects = d.detect_object(img_path=file_path,
save_path=detected_img_path)
print(objects)

cls = int(objects[0][5])
empty_file = os.path.abspath(f"{cls}_skip_camera_action.flag")

# 空のフラグ管理用ファイルを作成
with open(empty_file, 'w') as file:
pass

return send_file(empty_file,
as_attachment=True,
download_name=empty_file,
mimetype='text/plain'), 200
except Exception:
print("Error: detect failed")
objects = []
return jsonify({"message": "File uploaded successfully",
"detect_results": "detect failed"}), 200


# ポート番号の設定
Expand All @@ -47,7 +102,8 @@ def getImageFile() -> jsonify:
else:
host = os.uname()[1]

if host == "KatLabLaptop":
# if host == "KatLabLaptop":
if host == "LAPTOP-UNI0BH6G":
# ソケットを作成し、GoogleのDNSサーバ("8.8.8.8:80")
# に接続することで、IPアドレスを取得する。
# 参考: https://qiita.com/suzu12/items/b5c3d16aae55effb67c0
Expand Down
Loading

0 comments on commit ccc30fd

Please sign in to comment.