Skip to content

Commit

Permalink
fix: Async traffic light detection & ultralytics update
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxJa4 committed Feb 12, 2024
1 parent 13c225c commit 9d9e4d2
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 26 deletions.
70 changes: 45 additions & 25 deletions code/perception/src/vision_node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python3

from time import perf_counter
from ros_compatibility.node import CompatibleNode
import ros_compatibility as roscomp
import torch
Expand Down Expand Up @@ -80,6 +81,7 @@ def __init__(self, name, **kwargs):
self.device = torch.device("cuda"
if torch.cuda.is_available() else "cpu")
self.depth_images = []
self.DEBUG = False

# publish / subscribe setup
self.setup_camera_subscriptions()
Expand Down Expand Up @@ -211,34 +213,55 @@ def predict_torch(self, image):
return vision_result

def predict_ultralytics(self, image):
start = perf_counter()

cv_image = self.bridge.imgmsg_to_cv2(img_msg=image,
desired_encoding='passthrough')
cv_image = cv2.cvtColor(cv_image, cv2.COLOR_RGB2BGR)

output = self.model(cv_image, half=True, verbose=False)
output = self.model.predict(cv_image, half=True, verbose=False)[0]

s1 = perf_counter()

if 9 in output.boxes.cls:
self.process_traffic_lights(output, cv_image, image.header)

s2 = perf_counter()

box_img = self.calc_draw_distance(output.boxes, cv_image)

now = perf_counter()
if self.DEBUG:
self.loginfo("S1: {}, S2: {}, S3: {}, Total: {}".format(
round(s1 - start, 4), round(s2 - s1, 4), round(now - s2, 4),
round(now - start, 4)))

return box_img

def calc_draw_distance(self, boxes, cv_image):
distance_output = []
c_boxes = []
c_labels = []
for r in output:
boxes = r.boxes
for box in boxes:

for box in boxes:
if len(self.depth_images) > 0:
cls = box.cls.item()
pixels = box.xyxy[0]
if len(self.depth_images) > 0:
distances = np.asarray(
[self.depth_images[i][int(pixels[1]):int(pixels[3]):1,
int(pixels[0]):int(pixels[2]):1]
for i in range(len(self.depth_images))])
non_zero_filter = distances[distances != 0]

if len(non_zero_filter) > 0:
obj_dist = np.min(non_zero_filter)
else:
obj_dist = np.inf

c_boxes.append(torch.tensor(pixels))
c_labels.append(f"Class: {cls}, Meters: {obj_dist}")
distance_output.append([cls, obj_dist])

distances = np.asarray(
[self.depth_images[i][int(pixels[1]):int(pixels[3]):1,
int(pixels[0]):int(pixels[2]):1]
for i in range(len(self.depth_images))])
non_zero_filter = distances[distances != 0]

if len(non_zero_filter) > 0:
obj_dist = np.min(non_zero_filter)
else:
obj_dist = np.inf

c_boxes.append(torch.tensor(pixels))
c_labels.append(f"Class: {cls}, Meters: {obj_dist}")
distance_output.append([cls, obj_dist])

self.distance_publisher.publish(
Float32MultiArray(data=distance_output))
Expand All @@ -247,9 +270,6 @@ def predict_ultralytics(self, image):
image_np_with_detections = torch.tensor(transposed_image,
dtype=torch.uint8)

if 9 in output[0].boxes.cls:
self.process_traffic_lights(output[0], cv_image, image.header)

c_boxes = torch.stack(c_boxes)
print(image_np_with_detections.shape, c_boxes.shape, c_labels)
box = draw_bounding_boxes(image_np_with_detections,
Expand All @@ -259,12 +279,12 @@ def predict_ultralytics(self, image):
width=3,
font_size=12)

np_box_img = np.transpose(box.detach().numpy(),
(1, 2, 0))
np_box_img = np.transpose(box.detach().numpy(), (1, 2, 0))
box_img = cv2.cvtColor(np_box_img, cv2.COLOR_BGR2RGB)

return box_img

def process_traffic_lights(self, prediction, cv_image, image_header):
async def process_traffic_lights(self, prediction, cv_image, image_header):
indices = (prediction.boxes.cls == 9).nonzero().squeeze().cpu().numpy()
indices = np.asarray([indices]) if indices.size == 1 else indices

Expand Down
2 changes: 1 addition & 1 deletion code/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ scipy==1.10.1
xmltodict==0.13.0
py-trees==2.2.3
numpy==1.23.5
ultralytics==8.1.9
ultralytics==8.1.11
scikit-learn>=0.18
pandas==2.0.3

0 comments on commit 9d9e4d2

Please sign in to comment.