diff --git a/code/perception/src/traffic_light_detection/src/traffic_light_detection/traffic_light_inference.py b/code/perception/src/traffic_light_detection/src/traffic_light_detection/traffic_light_inference.py
index ada59c5f..cb3d9e5c 100644
--- a/code/perception/src/traffic_light_detection/src/traffic_light_detection/traffic_light_inference.py
+++ b/code/perception/src/traffic_light_detection/src/traffic_light_detection/traffic_light_inference.py
@@ -64,7 +64,7 @@ def __call__(self, img):
         else:
             out = self.model(img)
         _, prediction = torch.max(out.data, 1)
-        return prediction.item()
+        return (prediction.item(), out.data.cpu().numpy())
 
 
 # main function for testing purposes
diff --git a/code/perception/src/traffic_light_node.py b/code/perception/src/traffic_light_node.py
index 6f67b5b1..00034bab 100755
--- a/code/perception/src/traffic_light_node.py
+++ b/code/perception/src/traffic_light_node.py
@@ -1,5 +1,8 @@
 #!/usr/bin/env python3
 
+from datetime import datetime
+import threading
+from time import sleep
 from ros_compatibility.node import CompatibleNode
 import ros_compatibility as roscomp
 from rospy.numpy_msg import numpy_msg
@@ -18,6 +21,9 @@ def __init__(self, name, **kwargs):
         self.role_name = self.get_param("role_name", "hero")
         self.side = self.get_param("side", "Center")
         self.classifier = TrafficLightInference(self.get_param("model", ""))
+        self.last_info_time: datetime = None
+        self.last_state = None
+        threading.Thread(target=self.auto_invalidate_state).start()
 
         # publish / subscribe setup
         self.setup_camera_subscriptions()
@@ -38,14 +44,38 @@ def setup_traffic_light_publishers(self):
             qos_profile=1
         )
 
+    def auto_invalidate_state(self):
+        while True:
+            sleep(1)
+
+            if self.last_info_time is None:
+                continue
+
+            if (datetime.now() - self.last_info_time).total_seconds() >= 2:
+                msg = TrafficLightState()
+                msg.state = 0
+                self.traffic_light_publisher.publish(msg)
+                self.last_info_time = None
+
     def handle_camera_image(self, image):
-        result = self.classifier(self.bridge.imgmsg_to_cv2(image))
+        result, data = self.classifier(self.bridge.imgmsg_to_cv2(image))
+
+        if data[0][0] > 1e-15 and data[0][3] > 1e-15 or \
+           data[0][0] > 1e-10 or data[0][3] > 1e-10:
+            return  # too uncertain, may not be a traffic light
 
-        # 1: Green, 2: Red, 4: Yellow, 0: Unknown
-        msg = TrafficLightState()
-        msg.state = result if result in [1, 2, 4] else 0
+        state = result if result in [1, 2, 4] else 0
+        if self.last_state == state:
+            # 1: Green, 2: Red, 4: Yellow, 0: Unknown
+            msg = TrafficLightState()
+            msg.state = state
+            self.traffic_light_publisher.publish(msg)
+        else:
+            self.last_state = state
 
-        self.traffic_light_publisher.publish(msg)
+        # Automatically invalidates state (state=0) in auto_invalidate_state()
+        if state != 0:
+            self.last_info_time = datetime.now()
 
     def run(self):
         self.spin()
diff --git a/code/perception/src/vision_node.py b/code/perception/src/vision_node.py
index e9681a4c..938ee669 100755
--- a/code/perception/src/vision_node.py
+++ b/code/perception/src/vision_node.py
@@ -191,14 +191,19 @@ def process_traffic_lights(self, prediction, cv_image, image_header):
         indices = (prediction.boxes.cls == 9).nonzero().squeeze().cpu().numpy()
         indices = np.asarray([indices]) if indices.size == 1 else indices
 
-        min_x = 550
-        max_x = 700
-        min_prob = 0.35
+        max_y = 360  # middle of image
+        min_prob = 0.30
 
         for index in indices:
             box = prediction.boxes.cpu().data.numpy()[index]
-            if box[0] < min_x or box[2] > max_x or box[4] < min_prob:
+            if box[4] < min_prob:
+                continue
+
+            if (box[2] - box[0]) * 1.5 > box[3] - box[1]:
+                continue  # ignore horizontal boxes
+
+            if box[1] > max_y:
                 continue
 
             box = box[0:4].astype(int)
diff --git a/doc/06_perception/13_traffic_light_detection.md b/doc/06_perception/13_traffic_light_detection.md
index c627e3cf..68a061c6 100644
--- a/doc/06_perception/13_traffic_light_detection.md
+++ b/doc/06_perception/13_traffic_light_detection.md
@@ -35,3 +35,22 @@ This method sets up a publisher for the traffic light state. It publishes to the
 This method is called whenever a new image message is received. It performs traffic light detection by using `traffic_light_inference.py` on the image and publishes the result.
 
 The result is a `TrafficLightState` message where the state is set to the detected traffic light state (1 for green, 2 for red, 4 for yellow, 0 for unknown).
+
+## Filtering of images
+
+### Vision Node
+
+Objects that are detected as a traffic light by the RTDETR-X model (or others) must fulfill the following criteria to be published:
+
+- At least 30% (0.30) certainty/probability from the model
+- At least 1.5x as tall (height) as wide (width)
+- Top of the bounding box above y=360px, i.e. in the upper half of the 1280x720 image
+
+### Traffic Light Node
+
+Objects published by the Vision Node are further filtered by the following criteria:
+
+- The classification probabilities of "Unknown" and "Side" must both be below 1e-10, and at least one of them must be below 1e-15
+- "Side" is treated as "Unknown"
+- Valid states (Red, Green, Yellow) must be detected at least twice in a row to be published
+- A state decays to "Unknown" (state=0) after 2 seconds if no new info arrives in the meantime
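
The following sketches restate the filtering logic this patch introduces as small, self-contained Python snippets. They are illustrative re-implementations for testing the rules outside ROS, not part of the patch itself, and all helper names are hypothetical.

First, the vision_node box filter. As in `process_traffic_lights()`, a box is assumed to be `[x_min, y_min, x_max, y_max, confidence]`:

```python
import numpy as np

# Thresholds from the patch
MIN_PROB = 0.30   # minimum detection confidence
MAX_Y = 360       # vertical middle of the 1280x720 image

def is_candidate_traffic_light(box) -> bool:
    """Return True if a detected box passes the traffic light filter."""
    if box[4] < MIN_PROB:
        return False  # too uncertain
    if (box[2] - box[0]) * 1.5 > box[3] - box[1]:
        return False  # not clearly taller than wide ("horizontal" box)
    if box[1] > MAX_Y:
        return False  # box starts in the lower half of the image
    return True

# A 40x120 px box near the top of the image with 90% confidence passes:
print(is_candidate_traffic_light(np.array([600, 50, 640, 170, 0.9])))  # True
# The same box stretched to 160 px width is rejected as "horizontal":
print(is_candidate_traffic_light(np.array([600, 50, 760, 170, 0.9])))  # False
```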
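Next, the early-return check in `handle_camera_image()`. Following the documentation above, classifier output indices 0 and 3 are taken to be the "Unknown" and "Side" classes; `out_row` stands for `data[0]`:

```python
def is_confident_classification(out_row) -> bool:
    """Mirror of the early return in handle_camera_image(): accept only if
    the raw scores for "Unknown" (index 0) and "Side" (index 3) are both
    below 1e-10 and at least one of them is below 1e-15."""
    unknown, side = out_row[0], out_row[3]
    if unknown > 1e-15 and side > 1e-15 or unknown > 1e-10 or side > 1e-10:
        return False  # too uncertain, may not be a traffic light
    return True

print(is_confident_classification([1e-16, 0.0, 0.99, 1e-12, 0.01]))  # True
print(is_confident_classification([1e-12, 0.0, 0.99, 1e-12, 0.01]))  # False
```

Note the operator precedence: `and` binds tighter than `or`, so the image is discarded when both scores exceed 1e-15, or when either score exceeds 1e-10.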
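Finally, the two-in-a-row debounce and the 2-second decay from `traffic_light_node.py`, condensed into a plain-Python class. `StateDebouncer` is a hypothetical name, and time is passed in explicitly so the behavior is testable without threads:

```python
from datetime import datetime, timedelta

class StateDebouncer:
    """Report a state only once it has been seen twice in a row; decay to
    0 ("Unknown") when no valid state has arrived for 2 seconds."""

    def __init__(self):
        self.last_state = None
        self.last_info_time = None

    def update(self, state: int, now: datetime):
        published = None
        if self.last_state == state:
            published = state        # confirmed by two consecutive detections
        else:
            self.last_state = state  # remember, but do not publish yet
        if state != 0:
            self.last_info_time = now
        return published

    def decay(self, now: datetime):
        """Mirrors auto_invalidate_state(), which runs once per second."""
        if self.last_info_time is not None and \
           now - self.last_info_time >= timedelta(seconds=2):
            self.last_info_time = None
            return 0                 # invalidate: publish "Unknown"
        return None

d = StateDebouncer()
t = datetime.now()
assert d.update(2, t) is None                    # first "Red": held back
assert d.update(2, t) == 2                       # second "Red" in a row: published
assert d.decay(t + timedelta(seconds=3)) == 0    # no new info for 2s: decays
```

Separating the debounce from ROS publishing like this also makes the 2-second decay deterministic to unit-test, whereas the node itself relies on a background thread and wall-clock time.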