From 70add528cf1700737d0912c299cacf4532fd09a6 Mon Sep 17 00:00:00 2001 From: vangeliq Date: Tue, 3 Dec 2024 18:50:24 -0800 Subject: [PATCH] initial tracking implementation --- ros2_ws/src/custom_interfaces/CMakeLists.txt | 2 +- .../srv/GetOriginalImage.srv | 4 -- .../srv/GetRowPlantCount.srv | 3 + .../python_workspace/extermination_node.py | 17 ++++++ .../python_workspace/scripts/tracker.py | 57 +++++++++++++++++++ 5 files changed, 78 insertions(+), 5 deletions(-) delete mode 100644 ros2_ws/src/custom_interfaces/srv/GetOriginalImage.srv create mode 100644 ros2_ws/src/custom_interfaces/srv/GetRowPlantCount.srv create mode 100644 ros2_ws/src/python_workspace/python_workspace/scripts/tracker.py diff --git a/ros2_ws/src/custom_interfaces/CMakeLists.txt b/ros2_ws/src/custom_interfaces/CMakeLists.txt index f75ddef..270335a 100644 --- a/ros2_ws/src/custom_interfaces/CMakeLists.txt +++ b/ros2_ws/src/custom_interfaces/CMakeLists.txt @@ -17,7 +17,7 @@ find_package(sensor_msgs REQUIRED) rosidl_generate_interfaces(${PROJECT_NAME} "msg/ImageInput.msg" "msg/InferenceOutput.msg" - "srv/GetOriginalImage.srv" + "srv/GetRowPlantCount.srv" DEPENDENCIES sensor_msgs ) diff --git a/ros2_ws/src/custom_interfaces/srv/GetOriginalImage.srv b/ros2_ws/src/custom_interfaces/srv/GetOriginalImage.srv deleted file mode 100644 index 1703e66..0000000 --- a/ros2_ws/src/custom_interfaces/srv/GetOriginalImage.srv +++ /dev/null @@ -1,4 +0,0 @@ -string image_id # to make sure the image is correctly received ---- -sensor_msgs/Image original_image -float32 velocity \ No newline at end of file diff --git a/ros2_ws/src/custom_interfaces/srv/GetRowPlantCount.srv b/ros2_ws/src/custom_interfaces/srv/GetRowPlantCount.srv new file mode 100644 index 0000000..c6b6877 --- /dev/null +++ b/ros2_ws/src/custom_interfaces/srv/GetRowPlantCount.srv @@ -0,0 +1,3 @@ + # to make sure the image is correctly received +--- +int16 plant_count \ No newline at end of file diff --git a/ros2_ws/src/python_workspace/python_workspace/extermination_node.py b/ros2_ws/src/python_workspace/python_workspace/extermination_node.py index 6fb982c..9b04b24 100644 --- a/ros2_ws/src/python_workspace/python_workspace/extermination_node.py +++ b/ros2_ws/src/python_workspace/python_workspace/extermination_node.py @@ -9,8 +9,10 @@ from sensor_msgs.msg import Image from std_msgs.msg import Header, Int8, String from custom_interfaces.msg import InferenceOutput +from custom_interfaces.srv import GetRowPlantCount from .scripts.utils import ModelInference +from .scripts.tracker import EuclideanDistTracker class ExterminationNode(Node): def __init__(self): @@ -32,6 +34,7 @@ def __init__(self): self.minimum_confidence = 0.5 self.boxes_present = 0 self.model = ModelInference() + self.tracker = EuclideanDistTracker() self.bridge = CvBridge() self.boxes_msg = Int8() self.boxes_msg.data = 0 @@ -40,6 +43,18 @@ def __init__(self): self.box_publisher = self.create_publisher(Int8, f'{self.camera_side}_extermination_output', 10) self.timer = self.create_timer(self.publishing_rate, self.timer_callback) + self.get_tracker_row_count_service = self.create_service(GetRowPlantCount, 'reset_tracker', self.get_tracker_row_count_callback) + + + def get_tracker_row_count_callback(self,request,response): + """ + When navigation requests this service, reset the tracker count so that it knows to start a new row. + return the tracker's current row count + """ + row_count = self.tracker.reset() + response.row_count = row_count + return response + def inference_callback(self, msg): self.get_logger().info("Received Bounding Boxes") @@ -48,6 +63,8 @@ def inference_callback(self, msg): bounding_boxes = self.model.postprocess(msg.confidences.data,msg.bounding_boxes.data, raw_image,msg.velocity) final_image = self.model.draw_boxes(raw_image,bounding_boxes,velocity=msg.velocity) + self.tracker.update(bounding_boxes) + if self.use_display_node: cv2.imshow(self.window, final_image) cv2.waitKey(10) diff --git a/ros2_ws/src/python_workspace/python_workspace/scripts/tracker.py b/ros2_ws/src/python_workspace/python_workspace/scripts/tracker.py new file mode 100644 index 0000000..df23f80 --- /dev/null +++ b/ros2_ws/src/python_workspace/python_workspace/scripts/tracker.py @@ -0,0 +1,57 @@ +import math + +class EuclideanDistTracker: + def __init__(self): + # Store the center positions of the objects + self.center_points = {} + # Keep the count of the IDs + # each time a new object id detected, the count will increase by one + self.id_count = 0 + + def reset(self): + """ resets the count and also returns val of curr row + """ + self.center_points = {} + row_count = self.id_count + self.id_count = 0 + return row_count + + + def update(self, objects_rect): + # Objects boxes and ids + objects_bbs_ids = [] + + # Get center point of new object + for rect in objects_rect: + x1, y1, x2, y2 = rect + cx = (x1 + x2) // 2 + cy = (y1 + y2) // 2 + + # Find out if that object was detected already + same_object_detected = False + for id, pt in self.center_points.items(): + dist = math.hypot(cx - pt[0], cy - pt[1]) + + if dist < 25: + self.center_points[id] = (cx, cy) + # print(self.center_points) + objects_bbs_ids.append([x1, y1, x2, y2, id]) + same_object_detected = True + break + + # New object is detected we assign the ID to that object + if same_object_detected is False: + self.center_points[self.id_count] = (cx, cy) + objects_bbs_ids.append([x1, y1, x2, y2, id]) + self.id_count += 1 + + # Clean the dictionary by center points to remove IDS not used anymore + new_center_points = {} + for obj_bb_id in objects_bbs_ids: + _, _, _, _, object_id = obj_bb_id + center = self.center_points[object_id] + new_center_points[object_id] = center + + # Update dictionary with IDs not used removed + self.center_points = new_center_points.copy() + return objects_bbs_ids \ No newline at end of file