From fc975422af657657a8832ea846a9d7cc2111ca22 Mon Sep 17 00:00:00 2001 From: Giorgio Pintaudi Date: Wed, 25 Dec 2024 10:06:30 +0900 Subject: [PATCH] Let the user inject callables after starting subscribers Signed-off-by: Giorgio Pintaudi --- .../launch_testing_ros/wait_for_topics.py | 39 +++++++++- launch_testing_ros/test/examples/repeater.py | 53 +++++++++++++ .../wait_for_topic_inject_callback_test.py | 75 +++++++++++++++++++ .../examples/wait_for_topic_launch_test.py | 19 +++++ 4 files changed, 183 insertions(+), 3 deletions(-) create mode 100644 launch_testing_ros/test/examples/repeater.py create mode 100644 launch_testing_ros/test/examples/wait_for_topic_inject_callback_test.py diff --git a/launch_testing_ros/launch_testing_ros/wait_for_topics.py b/launch_testing_ros/launch_testing_ros/wait_for_topics.py index 8e99423d..8b8c69ba 100644 --- a/launch_testing_ros/launch_testing_ros/wait_for_topics.py +++ b/launch_testing_ros/launch_testing_ros/wait_for_topics.py @@ -19,6 +19,8 @@ from threading import Thread import rclpy +from rclpy.event_handler import QoSSubscriptionMatchedInfo +from rclpy.event_handler import SubscriptionEventCallbacks from rclpy.executors import SingleThreadedExecutor from rclpy.node import Node @@ -50,12 +52,29 @@ def method_2(): print(wait_for_topics.topics_received()) # Should be {'topic_1', 'topic_2'} print(wait_for_topics.messages_received('topic_1')) # Should be [message_1, ...] wait_for_topics.shutdown() + + # Method3, calling a callback function before the wait. The callback function takes + # the WaitForTopics object as the first argument. Any additional arguments has + # to be passed to the wait(*args, **kwargs) method directly. + def callback_function(node, arg=""): + node.get_logger().info('Callback function called with argument: ' + arg) + + def method_3(): + topic_list = [('topic_1', String), ('topic_2', String)] + wait_for_topics = WaitForTopics(topic_list, timeout=5.0) + assert wait_for_topics.wait("Hello World!") + print('Given topics are receiving messages !') + wait_for_topics.shutdown() """ - def __init__(self, topic_tuples, timeout=5.0, messages_received_buffer_length=10): + def __init__(self, topic_tuples, timeout=5.0, messages_received_buffer_length=10, + callback=None): self.topic_tuples = topic_tuples self.timeout = timeout self.messages_received_buffer_length = messages_received_buffer_length + self.callback = callback + if self.callback is not None and not callable(self.callback): + raise TypeError('The passed callback is not callable') self.__ros_context = rclpy.Context() rclpy.init(context=self.__ros_context) self.__ros_executor = SingleThreadedExecutor(context=self.__ros_context) @@ -83,8 +102,11 @@ def _prepare_ros_node(self): ) self.__ros_executor.add_node(self.__ros_node) - def wait(self): + def wait(self, *args, **kwargs): self.__ros_node.start_subscribers(self.topic_tuples) + if self.callback: + self.callback(self.__ros_node, *args, **kwargs) + self.__ros_node._any_publisher_connected.wait() return self.__ros_node.msg_event_object.wait(self.timeout) def shutdown(self): @@ -131,6 +153,13 @@ def __init__( self.expected_topics = set() self.received_topics = set() self.received_messages_buffer = {} + self._any_publisher_connected = Event() + + def _sub_matched_event_callback(self, info: QoSSubscriptionMatchedInfo): + if info.current_count != 0: + self._any_publisher_connected.set() + else: + self._any_publisher_connected.clear() def _reset(self): self.msg_event_object.clear() @@ -149,12 +178,16 @@ def start_subscribers(self, topic_tuples): maxlen=self.messages_received_buffer_length ) # Create a subscriber + sub_event_callback = SubscriptionEventCallbacks( + matched=self._sub_matched_event_callback + ) self.subscriber_list.append( self.create_subscription( topic_type, topic_name, self.callback_template(topic_name), - 10 + 10, + event_callbacks=sub_event_callback, ) ) diff --git a/launch_testing_ros/test/examples/repeater.py b/launch_testing_ros/test/examples/repeater.py new file mode 100644 index 00000000..10009b1e --- /dev/null +++ b/launch_testing_ros/test/examples/repeater.py @@ -0,0 +1,53 @@ +# Copyright 2019 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import rclpy +from rclpy.node import Node + +from std_msgs.msg import String + + +class Repeater(Node): + + def __init__(self): + super().__init__('repeater') + self.count = 0 + self.subscription = self.create_subscription( + String, 'input', self.callback, 10 + ) + self.publisher = self.create_publisher(String, 'output', 10) + + def callback(self, input_msg): + self.get_logger().info('I heard: [%s]' % input_msg.data) + output_msg_data = input_msg.data + self.get_logger().info('Publishing: "{0}"'.format(output_msg_data)) + self.publisher.publish(String(data=output_msg_data)) + + +def main(args=None): + rclpy.init(args=args) + + node = Repeater() + + try: + rclpy.spin(node) + except KeyboardInterrupt: + pass + finally: + node.destroy_node() + rclpy.shutdown() + + +if __name__ == '__main__': + main() diff --git a/launch_testing_ros/test/examples/wait_for_topic_inject_callback_test.py b/launch_testing_ros/test/examples/wait_for_topic_inject_callback_test.py new file mode 100644 index 00000000..e778bc37 --- /dev/null +++ b/launch_testing_ros/test/examples/wait_for_topic_inject_callback_test.py @@ -0,0 +1,75 @@ +# Copyright 2021 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import time +import unittest + +import launch +import launch.actions +import launch_ros.actions +import launch_testing.actions +import launch_testing.markers +from launch_testing_ros import WaitForTopics +import pytest +from std_msgs.msg import String + + +def generate_node(): + """Return node and remap the topic based on the index provided.""" + path_to_test = os.path.dirname(__file__) + return launch_ros.actions.Node( + executable=sys.executable, + arguments=[os.path.join(path_to_test, 'repeater.py')], + name='demo_node', + additional_env={'PYTHONUNBUFFERED': '1'}, + ) + + +def trigger_callback(node): + if not hasattr(node, 'my_publisher'): + node.my_publisher = node.create_publisher(String, 'input', 10) + while node.my_publisher.get_subscription_count() == 0: + time.sleep(0.1) + msg = String() + msg.data = 'Hello World' + node.my_publisher.publish(msg) + print('Published message') + + +@pytest.mark.launch_test +@launch_testing.markers.keep_alive +def generate_test_description(): + description = [generate_node(), launch_testing.actions.ReadyToTest()] + return launch.LaunchDescription(description) + + +# TODO: Test cases fail on Windows debug builds +# https://github.com/ros2/launch_ros/issues/292 +if os.name != 'nt': + + class TestFixture(unittest.TestCase): + + def test_topics_successful(self): + """All the supplied topics should be read successfully.""" + topic_list = [('output', String)] + expected_topics = {'output'} + + # Method 1 : Using the magic methods and 'with' keyword + with WaitForTopics( + topic_list, timeout=10.0, callback=trigger_callback + ) as wait_for_node_object_1: + assert wait_for_node_object_1.topics_received() == expected_topics + assert wait_for_node_object_1.topics_not_received() == set() diff --git a/launch_testing_ros/test/examples/wait_for_topic_launch_test.py b/launch_testing_ros/test/examples/wait_for_topic_launch_test.py index c32ca361..246c6fff 100644 --- a/launch_testing_ros/test/examples/wait_for_topic_launch_test.py +++ b/launch_testing_ros/test/examples/wait_for_topic_launch_test.py @@ -104,3 +104,22 @@ def test_topics_unsuccessful(self, count: int): assert wait_for_node_object.topics_received() == expected_topics assert wait_for_node_object.topics_not_received() == {'invalid_topic'} wait_for_node_object.shutdown() + + def test_callback(self, count): + topic_list = [('chatter_' + str(i), String) for i in range(count)] + expected_topics = {'chatter_' + str(i) for i in range(count)} + + # Method 3 : Using a callback function + + # Using a list to store the callback function's argument as it is mutable + is_callback_called = [False] + + def callback(node, arg): + node.get_logger().info(f'Callback function called with argument: {arg[0]}') + arg[0] = True + + wait_for_node_object = WaitForTopics(topic_list, timeout=2.0, callback=callback) + assert wait_for_node_object.wait(is_callback_called) + assert wait_for_node_object.topics_received() == expected_topics + assert wait_for_node_object.topics_not_received() == set() + assert is_callback_called[0]