Skip to content

Commit

Permalink
initial tracking implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
vangeliq committed Dec 4, 2024
1 parent 8f7f418 commit 70add52
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 5 deletions.
2 changes: 1 addition & 1 deletion ros2_ws/src/custom_interfaces/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ find_package(sensor_msgs REQUIRED)
rosidl_generate_interfaces(${PROJECT_NAME}
"msg/ImageInput.msg"
"msg/InferenceOutput.msg"
"srv/GetOriginalImage.srv"
"srv/GetRowPlantCount.srv"
DEPENDENCIES sensor_msgs

)
Expand Down
4 changes: 0 additions & 4 deletions ros2_ws/src/custom_interfaces/srv/GetOriginalImage.srv

This file was deleted.

3 changes: 3 additions & 0 deletions ros2_ws/src/custom_interfaces/srv/GetRowPlantCount.srv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# to make sure the image is correctly received
---
int16 plant_count
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from sensor_msgs.msg import Image
from std_msgs.msg import Header, Int8, String
from custom_interfaces.msg import InferenceOutput
from custom_interfaces.srv import GetRowPlantCount

from .scripts.utils import ModelInference
from .scripts.tracker import EuclideanDistTracker

class ExterminationNode(Node):
def __init__(self):
Expand All @@ -32,6 +34,7 @@ def __init__(self):
self.minimum_confidence = 0.5
self.boxes_present = 0
self.model = ModelInference()
self.tracker = EuclideanDistTracker()
self.bridge = CvBridge()
self.boxes_msg = Int8()
self.boxes_msg.data = 0
Expand All @@ -40,6 +43,18 @@ def __init__(self):
self.box_publisher = self.create_publisher(Int8, f'{self.camera_side}_extermination_output', 10)
self.timer = self.create_timer(self.publishing_rate, self.timer_callback)

self.get_tracker_row_count_service = self.create_service(GetRowPlantCount, 'reset_tracker', self.get_tracker_row_count_callback)


def get_tracker_row_count_callback(self,request,response):
"""
When navigation requests this service, reset the tracker count so that it knows to start a new row.
return the tracker's current row count
"""
row_count = self.tracker.reset()
response.row_count = row_count
return response

def inference_callback(self, msg):
self.get_logger().info("Received Bounding Boxes")

Expand All @@ -48,6 +63,8 @@ def inference_callback(self, msg):
bounding_boxes = self.model.postprocess(msg.confidences.data,msg.bounding_boxes.data, raw_image,msg.velocity)
final_image = self.model.draw_boxes(raw_image,bounding_boxes,velocity=msg.velocity)

self.tracker.update(bounding_boxes)

if self.use_display_node:
cv2.imshow(self.window, final_image)
cv2.waitKey(10)
Expand Down
57 changes: 57 additions & 0 deletions ros2_ws/src/python_workspace/python_workspace/scripts/tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import math

class EuclideanDistTracker:
def __init__(self):
# Store the center positions of the objects
self.center_points = {}
# Keep the count of the IDs
# each time a new object id detected, the count will increase by one
self.id_count = 0

def reset(self):
""" resets the count and also returns val of curr row
"""
self.center_points = {}
row_count = self.id_count
self.id_count = 0
return row_count


def update(self, objects_rect):
# Objects boxes and ids
objects_bbs_ids = []

# Get center point of new object
for rect in objects_rect:
x1, y1, x2, y2 = rect
cx = (x1 + x2) // 2
cy = (y1 + y2) // 2

# Find out if that object was detected already
same_object_detected = False
for id, pt in self.center_points.items():
dist = math.hypot(cx - pt[0], cy - pt[1])

if dist < 25:
self.center_points[id] = (cx, cy)
# print(self.center_points)
objects_bbs_ids.append([x1, y1, x2, y2, id])
same_object_detected = True
break

# New object is detected we assign the ID to that object
if same_object_detected is False:
self.center_points[self.id_count] = (cx, cy)
objects_bbs_ids.append([x1, y1, x2, y2, id])
self.id_count += 1

# Clean the dictionary by center points to remove IDS not used anymore
new_center_points = {}
for obj_bb_id in objects_bbs_ids:
_, _, _, _, object_id = obj_bb_id
center = self.center_points[object_id]
new_center_points[object_id] = center

# Update dictionary with IDs not used removed
self.center_points = new_center_points.copy()
return objects_bbs_ids

0 comments on commit 70add52

Please sign in to comment.