From b1eff0c4cbb3c41f7289554f1b6789d0a67f0757 Mon Sep 17 00:00:00 2001 From: guk98 Date: Sun, 30 Jul 2023 07:09:09 +0000 Subject: [PATCH] =?UTF-8?q?refactor:=20collect=5Fgallery=5Fdata=20?= =?UTF-8?q?=ED=95=A8=EC=88=98=EB=A5=BC=20ReId=20=EB=82=B4=EB=B6=80=20?= =?UTF-8?q?=EB=A9=94=EC=86=8C=EB=93=9C=EB=A1=9C=20=EC=9D=B4=EB=8F=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- re_id/module/reid.py | 26 ++++++++++++++++++++++---- utils/visualize.py | 27 +-------------------------- 2 files changed, 23 insertions(+), 30 deletions(-) diff --git a/re_id/module/reid.py b/re_id/module/reid.py index 475e733..de88f93 100644 --- a/re_id/module/reid.py +++ b/re_id/module/reid.py @@ -2,6 +2,7 @@ import numpy as np import faiss import os +import cv2 from torch.nn.functional import cosine_similarity from torchvision.transforms import functional as F @@ -62,8 +63,7 @@ def __init__(self, model, checkpoint, gallery_path=None, person_thr=0.6, cosine_ self.player_dict_init(10) self.gallery_dataset = GalleryDataset(self.gallery_path, self.tf) self.faiss_index_init() - - + def shot_re_id_inference(self, frame, results): person_img_lst = self.shot_person_query_lst(frame, results) detected_query = person_img_lst @@ -171,8 +171,6 @@ def shot_person_query_lst(self, frame, results): person_img_lst.append(tf_person_img) return person_img_lst - - def init_gallery(self, frames): """_summary_ @@ -236,6 +234,26 @@ def faiss_index_init(self): self.faiss_index.add_with_ids(vector, np.array(label)) print('Done') + + def collect_gallery_data(self, detect_model, video_path): + print('Collecting Gallery Data...') + cap = cv2.VideoCapture(video_path) + fps = cap.get(cv2.CAP_PROP_FPS) + total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + frame_num = cap.get(cv2.CAP_PROP_POS_FRAMES) + + gallery_img_lst = [] + + for index, f in enumerate(tqdm(range(total_frames_num))): + interval = 15 + if index%interval==0: + ret, frame = cap.read() + img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + results = detect_model.predict(img) + person_idx_lst, person_img_lst = self.person_query_lst(img, results, 0.9) + gallery_img_lst.append(person_img_lst) + + return gallery_img_lst class GalleryDataset(Dataset): diff --git a/utils/visualize.py b/utils/visualize.py index 9755057..4452129 100644 --- a/utils/visualize.py +++ b/utils/visualize.py @@ -20,28 +20,6 @@ (0, 125, 92), (209, 0, 151), (188, 208, 182), (0, 220, 176), ] - -def collect_gallery_data(detect_model, re_id, video_path): - print('Collecting Gallery Data...') - cap = cv2.VideoCapture(video_path) - fps = cap.get(cv2.CAP_PROP_FPS) - total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - frame_num = cap.get(cv2.CAP_PROP_POS_FRAMES) - - gallery_img_lst = [] - - for index, f in enumerate(tqdm(range(total_frames_num))): - interval = 15 - if index%interval==0: - ret, frame = cap.read() - img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - results = detect_model.predict(img) - person_idx_lst, person_img_lst = re_id.person_query_lst(img, results, 0.9) - gallery_img_lst.append(person_img_lst) - - return gallery_img_lst - - def make_predicted_video(detect_model, re_id, video_path, save_path): cap = cv2.VideoCapture(video_path) if not cap.isOpened(): @@ -60,7 +38,7 @@ def make_predicted_video(detect_model, re_id, video_path, save_path): first_frame = True re_id.shot_id = -1 - gallery_samples = collect_gallery_data(detect_model, re_id,video_path) + gallery_samples = re_id.collect_gallery_data(detect_model, video_path) re_id.init_gallery(gallery_samples) print("Inferencing...") @@ -162,10 +140,8 @@ def draw_id(img, id_dict, thr=0.5): cv2.rectangle(img, p1, p2, id_color[id], -1, cv2.LINE_AA) cv2.putText(img, f'ID_{str(id)}', (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), 0, thick / 3, txt_color, thickness=tf, lineType=cv2.LINE_AA) - return img - def draw_scoreboard(draw_img, re_id, side=False): s_w, s_h = (50,65) ply_num = 0 @@ -197,7 +173,6 @@ def draw_scoreboard(draw_img, re_id, side=False): cv2.putText(draw_img, f'Shoot_Count: {tracker.shot_count}', (10,830), 0, 1, (255,255,255), thickness=2, lineType=cv2.LINE_AA) cv2.putText(draw_img, f'Made_Count: {tracker.made_count}', (10,860), 0, 1, (255,255,255), thickness=2, lineType=cv2.LINE_AA) - return draw_img def side_results(results):