Skip to content

Commit

Permalink
new visual
Browse files Browse the repository at this point in the history
  • Loading branch information
Koldim2001 committed Sep 23, 2024
1 parent ae796c6 commit 53a455f
Showing 1 changed file with 4 additions and 12 deletions.
16 changes: 4 additions & 12 deletions patched_yolo_infer/functions_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,6 @@ def visualize_results_usual_yolo_inference(

class_names = model.names

# Map class IDs to indices in the color list
all_classes = set(cls for pred in predictions for cls in pred.boxes.cls.cpu().int().tolist())
class_to_color_index = {cls_id: idx for idx, cls_id in enumerate(all_classes)}

# Process each prediction
for pred in predictions:

Expand Down Expand Up @@ -120,7 +116,7 @@ def visualize_results_usual_yolo_inference(
random.seed(int(classes[i] + delta_colors))
color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
else:
color = list_of_class_colors[class_to_color_index[class_index]]
color = list_of_class_colors[class_index]

box = boxes[i]
x_min, y_min, x_max, y_max = box
Expand Down Expand Up @@ -148,7 +144,7 @@ def visualize_results_usual_yolo_inference(
label = str(class_name)
(text_width, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness)
background_color = (
color_class_background[class_to_color_index[class_index]]
color_class_background[class_index]
if isinstance(color_class_background, list)
else color_class_background
)
Expand Down Expand Up @@ -339,10 +335,6 @@ def visualize_results(
if random_object_colors:
random.seed(int(delta_colors))

# Map class IDs to indices in the color list
unique_classes = set(classes_ids)
class_to_color_index = {cls_id: idx for idx, cls_id in enumerate(unique_classes)}

# Process each prediction
for i in range(len(classes_ids)):
# Get the class for the current detection
Expand All @@ -361,7 +353,7 @@ def visualize_results(
random.seed(int(classes_ids[i] + delta_colors))
color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
else:
color = list_of_class_colors[class_to_color_index[classes_ids[i]]]
color = list_of_class_colors[classes_ids[i]]

box = boxes[i]
x_min, y_min, x_max, y_max = box
Expand Down Expand Up @@ -409,7 +401,7 @@ def visualize_results(
label = str(class_name)
(text_width, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness)
background_color = (
color_class_background[class_to_color_index[classes_ids[i]]]
color_class_background[classes_ids[i]]
if isinstance(color_class_background, list)
else color_class_background
)
Expand Down

0 comments on commit 53a455f

Please sign in to comment.