diff --git a/supervision/detection/line_counter.py b/supervision/detection/line_counter.py index ae662d116..283c901dc 100644 --- a/supervision/detection/line_counter.py +++ b/supervision/detection/line_counter.py @@ -51,8 +51,10 @@ def __init__( self.vector = Vector(start=start, end=end) self.limits = self.calculate_region_of_interest_limits(vector=self.vector) self.tracker_state: Dict[str, bool] = {} - self.in_count: int = 0 - self.out_count: int = 0 + # self.in_count: int = 0 + # self.out_count: int = 0 + self.class_in_count: Dict[int, int] = {} # Per-class in count + self.class_out_count: Dict[int, int] = {} # Per-class out count self.triggering_anchors = triggering_anchors @staticmethod @@ -124,6 +126,7 @@ def trigger(self, detections: Detections) -> Tuple[np.ndarray, np.ndarray]: if tracker_id is None: continue + class_label = detections.class_labels[i] # To get class label box_anchors = [Point(x=x, y=y) for x, y in all_anchors[:, i, :]] in_limits = all( @@ -156,10 +159,21 @@ def trigger(self, detections: Detections) -> Tuple[np.ndarray, np.ndarray]: if tracker_state: self.in_count += 1 crossed_in[i] = True + + # Update per-class in count + if class_label not in self.class_in_count: + self.class_in_count[class_label] = 0 + self.class_in_count[class_label] += 1 + else: self.out_count += 1 crossed_out[i] = True + # Update per-class out count + if class_label not in self.class_out_count: + self.class_out_count[class_label] = 0 + self.class_out_count[class_label] += 1 + return crossed_in, crossed_out @@ -284,28 +298,31 @@ def annotate(self, frame: np.ndarray, line_counter: LineZone) -> np.ndarray: ) if self.display_in_count: - in_text = ( - f"{self.custom_in_text}: {line_counter.in_count}" - if self.custom_in_text is not None - else f"in: {line_counter.in_count}" - ) - self._annotate_count( - frame=frame, - center_text_anchor=text_anchor.center, - text=in_text, - is_in_count=True, - ) + for class_label, count in line_counter.class_in_count.items(): + in_text = ( + f"{self.custom_in_text}: {count} - Class {class_label}" + if self.custom_in_text is not None + else f"in: {count} - Class {class_label}" + ) + self._annotate_count( + frame=frame, + center_text_anchor=text_anchor.center, + text=in_text, + is_in_count=True, + ) if self.display_out_count: - out_text = ( - f"{self.custom_out_text}: {line_counter.out_count}" - if self.custom_out_text is not None - else f"out: {line_counter.out_count}" - ) - self._annotate_count( - frame=frame, - center_text_anchor=text_anchor.center, - text=out_text, - is_in_count=False, - ) + for class_label, count in line_counter.class_out_count.items(): + out_text = ( + f"{self.custom_out_text}: {count} - Class {class_label}" + if self.custom_out_text is not None + else f"out: {count} - Class {class_label}" + ) + self._annotate_count( + frame=frame, + center_text_anchor=text_anchor.center, + text=out_text, + is_in_count=False, + ) + return frame