Skip to content

Commit

Permalink
Visualization bug fixed.
Browse files Browse the repository at this point in the history
  • Loading branch information
kadirnar committed Jan 26, 2023
1 parent 02357ea commit 293d4e5
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 28 deletions.
2 changes: 1 addition & 1 deletion torchyolo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from torchyolo.predict import YoloHub

__version__ = "1.0.2"
__version__ = "1.1.0"
11 changes: 3 additions & 8 deletions torchyolo/modelhub/yolov5.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,19 +82,14 @@ def predict(
label = f"Id:{track_id} {category_name} {float(score):.2f}"

if self.save or self.show:
frame = video_vis(
img_src = video_vis(
bbox=bbox,
label=label,
frame=img_src,
object_id=int(category_id),
)
if self.save:
video_writer.write(frame)

if self.show:
cv2.imshow("frame", frame)
if cv2.waitKey(1) & 0xFF == ord("q"):
break
if self.save:
video_writer.write(img_src)

else:
for pred in prediction.cpu().detach().numpy():
Expand Down
4 changes: 2 additions & 2 deletions torchyolo/modelhub/yolov6.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def predict(
frame=img_src,
object_id=int(category_id),
)
if self.save:
video_writer.write(frame)
if self.save:
video_writer.write(frame)
else:
for *xyxy, conf, cls in det:
label = f"{COCO_CLASSES[int(cls)]} {float(conf):.2f}"
Expand Down
9 changes: 2 additions & 7 deletions torchyolo/modelhub/yolov7.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,8 @@ def predict(
frame=img_src,
object_id=int(category_id),
)
if self.save:
video_writer.write(frame)

if self.show:
cv2.imshow("frame", frame)
if cv2.waitKey(1) & 0xFF == ord("q"):
break
if self.save:
video_writer.write(frame)

else:
for pred in prediction.cpu().detach().numpy():
Expand Down
10 changes: 3 additions & 7 deletions torchyolo/modelhub/yolov8.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def predict(
score = prediction[:].boxes.conf
category_id = prediction[:].boxes.cls
dets = torch.cat((boxes, score.unsqueeze(1), category_id.unsqueeze(1)), dim=1)
tracker_outputs[image_id] = tracker_module.update(dets, img_src)
tracker_outputs[image_id] = tracker_module.update(dets.cpu(), img_src)
for output in tracker_outputs[image_id]:
bbox, track_id, category_id, score = (
output[:4],
Expand All @@ -104,13 +104,9 @@ def predict(
frame=img_src,
object_id=int(category_id),
)
if self.save:
video_writer.write(frame)
if self.save:
video_writer.write(frame)

if self.show:
cv2.imshow("frame", frame)
if cv2.waitKey(1) & 0xFF == ord("q"):
break
else:

for image_id, prediction in enumerate(results[0].boxes.cpu().numpy()):
Expand Down
9 changes: 6 additions & 3 deletions torchyolo/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,12 @@ def predict(
if __name__ == "__main__":
model = YoloHub(
config_path="torchyolo/configs/default_config.yaml",
model_type="yolov6",
model_path="yolov6l.pt",
model_type="yolov8",
model_path="yolov8s.pt",
)
result = model.predict(
source="1.mp4", tracker_type="NORFAIR", tracker_config_path="torchyolo/configs/tracker/norfair_track.yaml"
source="../test.mp4",
tracker_type="SORT",
tracker_config_path="torchyolo/configs/tracker/sort_track.yaml",
tracker_weight_path="osnet_x1_0_imagenet.pt",
)

0 comments on commit 293d4e5

Please sign in to comment.