diff --git a/obelisk/python/obelisk_py/core/control.py b/obelisk/python/obelisk_py/core/control.py index 3c36221c..7d04db5a 100644 --- a/obelisk/python/obelisk_py/core/control.py +++ b/obelisk/python/obelisk_py/core/control.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod +from typing import Type from obelisk_py.core.node import ObeliskNode -from obelisk_py.core.obelisk_typing import ObeliskControlMsg, ObeliskEstimatorMsg class ObeliskController(ABC, ObeliskNode): @@ -17,7 +17,7 @@ class ObeliskController(ABC, ObeliskNode): the control message should be of type ObeliskControlMsg to be compatible with the Obelisk ecosystem. """ - def __init__(self, node_name: str) -> None: + def __init__(self, node_name: str, ctrl_msg_type: Type, est_msg_type: Type) -> None: """Initialize the Obelisk controller.""" super().__init__(node_name) self.register_obk_timer( @@ -27,18 +27,18 @@ def __init__(self, node_name: str) -> None: ) self.register_obk_publisher( "pub_ctrl_setting", + msg_type=ctrl_msg_type, key="pub_ctrl", - msg_type=None, # generic, specified in config file ) self.register_obk_subscription( "sub_est_setting", self.update_x_hat, + msg_type=est_msg_type, key="sub_est", - msg_type=None, # generic, specified in config file ) @abstractmethod - def update_x_hat(self, x_hat_msg: ObeliskEstimatorMsg) -> None: + def update_x_hat(self, x_hat_msg: Type) -> None: """Update the state estimate. Parameters: @@ -46,7 +46,7 @@ def update_x_hat(self, x_hat_msg: ObeliskEstimatorMsg) -> None: """ @abstractmethod - def compute_control(self) -> ObeliskControlMsg: + def compute_control(self) -> Type: """Compute the control signal. This is the control timer callback and is expected to call 'publisher_ctrl' internally. Note that the control diff --git a/obelisk/python/obelisk_py/core/estimation.py b/obelisk/python/obelisk_py/core/estimation.py index 101331ab..9d34f4ec 100644 --- a/obelisk/python/obelisk_py/core/estimation.py +++ b/obelisk/python/obelisk_py/core/estimation.py @@ -1,12 +1,7 @@ from abc import ABC, abstractmethod -from typing import Union - -import obelisk_sensor_msgs.msg as osm -from rclpy.lifecycle.node import LifecycleState, TransitionCallbackReturn +from typing import Type from obelisk_py.core.node import ObeliskNode -from obelisk_py.core.obelisk_typing import ObeliskEstimatorMsg, ObeliskSensorMsg -from obelisk_py.core.utils.internal import get_classes_in_module class ObeliskEstimator(ABC, ObeliskNode): @@ -32,13 +27,12 @@ def update_X(self, X_msg: ObeliskSensorMsg) -> None: ``` """ - def __init__(self, node_name: str) -> None: + def __init__(self, node_name: str, est_msg_type: Type) -> None: """Initialize the Obelisk estimator. [NOTE] In derived classes, you should declare settings for sensor subscribers. """ super().__init__(node_name) - self._has_sensor_subscriber = False self.register_obk_timer( "timer_est_setting", self.compute_state_estimate, @@ -46,26 +40,12 @@ def __init__(self, node_name: str) -> None: ) self.register_obk_publisher( "pub_est_setting", + msg_type=est_msg_type, key="pub_est", - msg_type=None, # generic, specified in config file ) - def on_configure(self, state: LifecycleState) -> TransitionCallbackReturn: - """Configure the estimator.""" - super().on_configure(state) - - # ensure there is at least one sensor subscriber - for sub_dict in self._obk_sub_settings: - msg_type = sub_dict["msg_type"] - if msg_type in get_classes_in_module(osm): - self._has_sensor_subscriber = True - break - assert self._has_sensor_subscriber, "At least one sensor subscriber is required in an ObeliskEstimator!" - - return TransitionCallbackReturn.SUCCESS - @abstractmethod - def compute_state_estimate(self) -> Union[ObeliskEstimatorMsg, ObeliskSensorMsg]: + def compute_state_estimate(self) -> Type: """Compute the state estimate. This is the state estimate timer callback and is expected to call 'publisher_est' internally. Note that the diff --git a/obelisk/python/obelisk_py/core/node.py b/obelisk/python/obelisk_py/core/node.py index db095afe..dfa76f33 100644 --- a/obelisk/python/obelisk_py/core/node.py +++ b/obelisk/python/obelisk_py/core/node.py @@ -1,25 +1,15 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union import rclpy -from rclpy._rclpy_pybind11 import RCLError from rclpy.callback_groups import CallbackGroup, MutuallyExclusiveCallbackGroup, ReentrantCallbackGroup from rclpy.lifecycle import LifecycleNode from rclpy.lifecycle.node import LifecycleState, TransitionCallbackReturn -from rclpy.publisher import Publisher -from rclpy.qos import QoSProfile -from rclpy.qos_event import PublisherEventCallbacks, SubscriptionEventCallbacks -from rclpy.qos_overriding_options import QoSOverridingOptions -from rclpy.subscription import Subscription - -from obelisk_py.core.exceptions import ObeliskMsgError -from obelisk_py.core.obelisk_typing import ObeliskAllowedMsg, ObeliskMsg, is_in_bound -from obelisk_py.core.utils.internal import check_and_get_obelisk_msg_type MsgType = TypeVar("MsgType") # hack to denote any message type class ObeliskNode(LifecycleNode): - """A lifecycle node whose publishers and subscribers can only publish and subscribe to Obelisk messages. + """A lifecycle node that automatically performs component registration according to Obelisk standards. By convention, the initialization function should only declare ROS parameters and define stateful quantities. Some guidelines for the on_configure, on_activate, and on_deactivate callbacks are provided below. @@ -90,17 +80,17 @@ def t(self) -> float: def register_obk_publisher( self, ros_parameter: str, + msg_type: Type, key: Optional[str] = None, - msg_type: Optional[Type] = None, default_config_str: Optional[str] = None, ) -> None: """Register a publisher using a configuration string. Parameters: ros_parameter: The ROS parameter that contains the configuration string. - key: The key for the publisher. If None, we look for the key as a field in the config string later. msg_type: The message type. If passed, we just use this directly. Otherwise, you can configure this at runtime by passing a msg_type field in the config string that corresponds to an Obelisk message type. + key: The key for the publisher. If None, we look for the key as a field in the config string later. default_config_str: The default configuration string. If None, the parameter must be initialized. Raises: @@ -115,9 +105,9 @@ def register_obk_publisher( def register_obk_subscription( self, ros_parameter: str, - callback: Callable[[Union[ObeliskAllowedMsg, MsgType]], None], + callback: Callable[[MsgType], None], + msg_type: Type, key: Optional[str] = None, - msg_type: Optional[Type] = None, default_config_str: Optional[str] = None, ) -> None: """Register a subscription using a configuration string. @@ -125,9 +115,9 @@ def register_obk_subscription( Parameters: ros_parameter: The ROS parameter that contains the configuration string. callback: The callback function. - key: The key for the subscription. If None, we look for the key as a field in the config string later. msg_type: The message type. If passed, we just use this directly. Otherwise, you can configure this at runtime by passing a msg_type field in the config string that corresponds to an Obelisk message type. + key: The key for the subscription. If None, we look for the key as a field in the config string later. default_config_str: The default configuration string. If None, the parameter must be initialized. Raises: @@ -258,21 +248,6 @@ def _get_key_from_config_dict(config_dict: Dict) -> str: assert isinstance(config_dict["key"], str), "The 'key' field must be a string!" return config_dict["key"] - @staticmethod - def _get_msg_type_from_config_dict(config_dict: Dict) -> Type: - """Get the message type from a configuration dictionary.""" - assert config_dict.get("msg_type") is not None, "No message type supplied!" - assert isinstance(config_dict["msg_type"], str), "The 'msg_type' field must be a string!" - - if "non_obelisk" in config_dict: - assert isinstance(config_dict["non_obelisk"], str), "The 'non_obelisk' field must be a string!" - assert config_dict["non_obelisk"].lower() != "true", "non_obelisk=True but no message type supplied!" - msg_type = check_and_get_obelisk_msg_type(config_dict["msg_type"], ObeliskMsg) - else: - msg_type = check_and_get_obelisk_msg_type(config_dict["msg_type"], ObeliskAllowedMsg) - - return msg_type - @staticmethod def _create_callback_groups_from_config_str(config_str: str) -> Dict[str, CallbackGroup]: """Create callback groups from a configuration string. @@ -335,9 +310,9 @@ def _get_callback_group_from_config_dict(self, config_dict: Dict) -> Optional[Ca def _create_publisher_from_config_str( self, config_str: str, + msg_type: Type, key: Optional[str] = None, - msg_type: Optional[Type] = None, - ) -> Tuple[str, Type]: + ) -> str: """Create a publisher from a configuration string and adds it to the publisher dictionary. Also creates a key attribute for the publisher. For example, if the key is "pub_ctrl", then the publisher can be @@ -349,13 +324,12 @@ def _create_publisher_from_config_str( Parameters: config_str: The configuration string. - key: The key for the publisher. If None, we look for the key as a field in the config string. msg_type: The message type. If passed, we just use this directly. Otherwise, you can configure this at runtime by passing a msg_type field in the config string that corresponds to an Obelisk message type. + key: The key for the publisher. If None, we look for the key as a field in the config string. Returns: key: The key for the publisher. - msg_type: The message type for the publisher. Raises: AssertionError: If the configuration string is invalid. @@ -364,7 +338,7 @@ def _create_publisher_from_config_str( # parse and check the configuration string field_names, value_names = ObeliskNode._parse_config_str(config_str) required_field_names = ["topic"] - optional_field_names = ["key", "msg_type", "history_depth", "callback_group", "non_obelisk"] + optional_field_names = ["key", "msg_type", "history_depth", "callback_group"] ObeliskNode._check_fields(field_names, required_field_names, optional_field_names) config_dict = dict(zip(field_names, value_names)) @@ -377,43 +351,30 @@ def _create_publisher_from_config_str( f"string. Using the value 'key'={key}, as hardcoded specifications take precedence!" ) - # parse the message type - if msg_type is None: - msg_type = ObeliskNode._get_msg_type_from_config_dict(config_dict) - elif "msg_type" in field_names: - self.get_logger().warn( - f"'msg_type'={msg_type} registered for this publisher, and 'msg_type'={config_dict['msg_type']} " - f"specified in the config string. Using the value 'msg_type'={msg_type}, as hardcoded specifications " - "take precedence!" - ) - # set the callback group callback_group = self._get_callback_group_from_config_dict(config_dict) # run type assertions and create the publisher history_depth = config_dict.get("history_depth", 10) - non_obelisk_field = config_dict.get("non_obelisk", "False") assert isinstance(config_dict["topic"], str), "The 'topic' field must be a string!" assert isinstance(history_depth, int), "The 'history_depth' field must be an int!" - assert isinstance(non_obelisk_field, str), "The 'non_obelisk' field must be a str!" self.obk_publishers[key] = self.create_publisher( msg_type=msg_type, topic=config_dict["topic"], qos_profile=history_depth, callback_group=callback_group, - non_obelisk=non_obelisk_field.lower() == "true", ) assert not hasattr(self, key), f"Attribute {key} already exists in the node!" setattr(self, key + "_key", self.obk_publishers[key]) # create key attribute for publisher - return key, msg_type + return key def _create_subscription_from_config_str( self, config_str: str, - callback: Callable[[Union[ObeliskAllowedMsg, MsgType]], None], + msg_type: Type, + callback: Callable[[MsgType], None], key: Optional[str] = None, - msg_type: Optional[Type] = None, - ) -> Tuple[str, Type]: + ) -> str: """Create a subscription from a configuration string and adds it to the subscription dictionary. Also creates a key attribute for the subscription. For example, if the key is "sub_ctrl", then the subscription @@ -425,14 +386,13 @@ def _create_subscription_from_config_str( Parameters: config_str: The configuration string. - callback: The callback function. - key: The key for the subscription. If None, we look for the key as a field in the config string. msg_type: The message type. If passed, we just use this directly. Otherwise, you can configure this at runtime by passing a msg_type field in the config string that corresponds to an Obelisk message type. + callback: The callback function. + key: The key for the subscription. If None, we look for the key as a field in the config string. Returns: key: The key for the subscription. - msg_type: The message type for the subscription. Raises: AssertionError: If the configuration string is invalid. @@ -441,7 +401,7 @@ def _create_subscription_from_config_str( # parse and check the configuration string field_names, value_names = ObeliskNode._parse_config_str(config_str) required_field_names = ["topic"] - optional_field_names = ["key", "msg_type", "history_depth", "callback_group", "non_obelisk"] + optional_field_names = ["key", "msg_type", "history_depth", "callback_group"] ObeliskNode._check_fields(field_names, required_field_names, optional_field_names) config_dict = dict(zip(field_names, value_names)) @@ -454,25 +414,13 @@ def _create_subscription_from_config_str( f"string. Using the value 'key'={key}, as hardcoded specifications take precedence!" ) - # parse the message type - if msg_type is None: - msg_type = ObeliskNode._get_msg_type_from_config_dict(config_dict) - elif "msg_type" in field_names: - self.get_logger().warn( - f"'msg_type'={msg_type} registered for this subscription, and 'msg_type'={config_dict['msg_type']} " - f"specified in the config string. Using the value 'msg_type'={msg_type}, as hardcoded specifications " - "take precedence!" - ) - # set the callback group callback_group = self._get_callback_group_from_config_dict(config_dict) # run type assertions and return the subscription history_depth = config_dict.get("history_depth", 10) - non_obelisk_field = config_dict.get("non_obelisk", "False") assert isinstance(config_dict["topic"], str), "The 'topic' field must be a string!" assert isinstance(history_depth, int), "The 'history_depth' field must be an int!" - assert isinstance(non_obelisk_field, str), "The 'non_obelisk' field must be a str!" self.obk_subscriptions[key] = self.create_subscription( msg_type=msg_type, @@ -480,11 +428,10 @@ def _create_subscription_from_config_str( callback=callback, # type: ignore qos_profile=history_depth, callback_group=callback_group, - non_obelisk=non_obelisk_field.lower() == "true", ) assert not hasattr(self, key), f"Attribute {key} already exists in the node!" setattr(self, key + "_key", self.obk_subscriptions[key]) # create key attribute for subscription - return key, msg_type + return key def _create_timer_from_config_str( self, @@ -544,110 +491,6 @@ def _create_timer_from_config_str( setattr(self, key + "_key", self.obk_timers[key]) # create key attribute for timer return key - # ################ # - # PUB/SUB CREATION # - # ################ # - - def create_publisher( - self, - msg_type: ObeliskAllowedMsg, - topic: str, - qos_profile: Union[QoSProfile, int], - *, - callback_group: Optional[CallbackGroup] = None, - event_callbacks: Optional[PublisherEventCallbacks] = None, - qos_overriding_options: Optional[QoSOverridingOptions] = None, - publisher_class: Type[Publisher] = Publisher, - non_obelisk: bool = False, - ) -> Publisher: - """Create a new publisher that can only publish Obelisk messages. - - See: github.com/ros2/rclpy/blob/e4042398d6f0403df2fafdadbdfc90b6f6678d13/rclpy/rclpy/node.py#L1242 - - Parameters: - non_obelisk: If True, the publisher can publish non-Obelisk messages. Default is False. - - Raises: - ObeliskMsgError: If the message type is not an Obelisk message. - """ - if not non_obelisk and not is_in_bound(msg_type, ObeliskAllowedMsg): - if get_origin(ObeliskAllowedMsg.__bound__) is Union: - valid_msg_types = [a.__name__ for a in get_args(ObeliskAllowedMsg.__bound__)] - else: - valid_msg_types = [ObeliskAllowedMsg.__name__] - raise ObeliskMsgError( - f"msg_type must be one of {valid_msg_types}. " - "Got {msg_type.__name__}. If you are sure that the message type is correct, " - "set non_obelisk=True. Note that this may cause certain API incompatibilies." - ) - - try: - return super().create_publisher( - msg_type=msg_type, - topic=topic, - qos_profile=qos_profile, - callback_group=callback_group, - event_callbacks=event_callbacks, - qos_overriding_options=qos_overriding_options, - publisher_class=publisher_class, - ) - except RCLError as e: - self.get_logger().error( - "Failed to create publisher: verify that you haven't declared the same topic twice!" - ) - raise e - - def create_subscription( - self, - msg_type: ObeliskAllowedMsg, - topic: str, - callback: Callable[[ObeliskAllowedMsg], None], - qos_profile: Union[QoSProfile, int], - *, - callback_group: Optional[CallbackGroup] = None, - event_callbacks: Optional[SubscriptionEventCallbacks] = None, - qos_overriding_options: Optional[QoSOverridingOptions] = None, - raw: bool = False, - non_obelisk: bool = False, - ) -> Subscription: - """Create a new subscription that can only subscribe to Obelisk messages. - - See: github.com/ros2/rclpy/blob/e4042398d6f0403df2fafdadbdfc90b6f6678d13/rclpy/rclpy/node.py#L1316 - - Parameters: - non_obelisk: If True, the subscriber can receive non-Obelisk messages. Default is False. - - Raises: - ObeliskMsgError: If the message type is not an Obelisk message. - """ - if not non_obelisk and not is_in_bound(msg_type, ObeliskAllowedMsg): - if get_origin(ObeliskAllowedMsg.__bound__) is Union: - valid_msg_types = [a.__name__ for a in get_args(ObeliskAllowedMsg.__bound__)] - else: - valid_msg_types = [ObeliskAllowedMsg.__name__] - raise ObeliskMsgError( - f"msg_type must be one of {valid_msg_types}. " - "Got {msg_type.__name__}. If you are sure that the message type is correct, " - "set non_obelisk=True. Note that this may cause certain API incompatibilies." - ) - - try: - return super().create_subscription( - msg_type=msg_type, - topic=topic, - callback=callback, - qos_profile=qos_profile, - callback_group=callback_group, - event_callbacks=event_callbacks, - qos_overriding_options=qos_overriding_options, - raw=raw, - ) - except RCLError as e: - self.get_logger().error( - "Failed to create subscription: verify that you haven't declared the same topic twice!" - ) - raise e - # ################### # # LIFECYCLE CALLBACKS # # ################### # @@ -678,11 +521,8 @@ def on_configure(self, state: LifecycleState) -> TransitionCallbackReturn: if pub_config_str == "": self.get_logger().warn(f"Publisher {key} has no configuration string!") continue - final_key, final_msg_type = self._create_publisher_from_config_str( - pub_config_str, key=key, msg_type=msg_type - ) + final_key = self._create_publisher_from_config_str(pub_config_str, key=key, msg_type=msg_type) pub_dict["key"] = final_key # if no key passed, use value from config file - pub_dict["msg_type"] = final_msg_type # if no msg_type passed, use value from config file for sub_dict in self._obk_sub_settings: key = sub_dict["key"] @@ -694,11 +534,10 @@ def on_configure(self, state: LifecycleState) -> TransitionCallbackReturn: if sub_config_str == "": self.get_logger().warn(f"Subscription {key} has no configuration string!") continue - final_key, final_msg_type = self._create_subscription_from_config_str( + final_key = self._create_subscription_from_config_str( sub_config_str, callback=callback, key=key, msg_type=msg_type ) sub_dict["key"] = final_key # if no key passed, use value from config file - sub_dict["msg_type"] = final_msg_type # if no msg_type passed, use value from config file for timer_dict in self._obk_timer_settings: key = timer_dict["key"] diff --git a/obelisk/python/obelisk_py/core/robot.py b/obelisk/python/obelisk_py/core/robot.py index 71c7422d..80b55b42 100644 --- a/obelisk/python/obelisk_py/core/robot.py +++ b/obelisk/python/obelisk_py/core/robot.py @@ -1,12 +1,11 @@ import multiprocessing from abc import ABC, abstractmethod -from typing import List +from typing import List, Type import obelisk_sensor_msgs.msg as osm from rclpy.lifecycle.node import LifecycleState, TransitionCallbackReturn from obelisk_py.core.node import ObeliskNode -from obelisk_py.core.obelisk_typing import ObeliskControlMsg class ObeliskRobot(ABC, ObeliskNode): @@ -17,18 +16,18 @@ class ObeliskRobot(ABC, ObeliskNode): of a real system. """ - def __init__(self, node_name: str) -> None: + def __init__(self, node_name: str, ctrl_msg_type: Type) -> None: """Initialize the Obelisk robot.""" super().__init__(node_name) self.register_obk_subscription( "sub_ctrl_setting", self.apply_control, + ctrl_msg_type, key="sub_ctrl", - msg_type=None, # generic, specified in config file ) @abstractmethod - def apply_control(self, control_msg: ObeliskControlMsg) -> None: + def apply_control(self, control_msg: Type) -> None: """Apply the control message to the robot. Code interfacing with the hardware should be implemented here. @@ -49,9 +48,9 @@ class ObeliskSimRobot(ObeliskRobot): preventing the end user from implementing their own simulator of choice or us from implementing other simulators. """ - def __init__(self, node_name: str) -> None: + def __init__(self, node_name: str, ctrl_msg_type: Type) -> None: """Initialize the Obelisk sim robot.""" - super().__init__(node_name) + super().__init__(node_name, ctrl_msg_type) self.register_obk_timer( "timer_true_sim_state_setting", self.publish_true_sim_state, @@ -60,8 +59,8 @@ def __init__(self, node_name: str) -> None: ) self.register_obk_publisher( "pub_true_sim_state_setting", + osm.TrueSimState, key="pub_true_sim_state", - msg_type=osm.TrueSimState, default_config_str="", ) self.shared_ctrl = None diff --git a/obelisk/python/obelisk_py/core/sim/mujoco.py b/obelisk/python/obelisk_py/core/sim/mujoco.py index 0849f778..497faaa5 100644 --- a/obelisk/python/obelisk_py/core/sim/mujoco.py +++ b/obelisk/python/obelisk_py/core/sim/mujoco.py @@ -10,24 +10,25 @@ import rclpy from ament_index_python.packages import get_package_share_directory from mujoco import MjData, MjModel, mj_forward, mj_name2id, mj_step, mju_copy # type: ignore +from obelisk_control_msgs.msg import PositionSetpoint from rclpy.callback_groups import ReentrantCallbackGroup from rclpy.lifecycle import LifecycleState, TransitionCallbackReturn from rclpy.publisher import Publisher -from obelisk_py.core.obelisk_typing import ObeliskControlMsg, ObeliskSensorMsg, is_in_bound +from obelisk_py.core.obelisk_typing import ObeliskSensorMsg, is_in_bound from obelisk_py.core.robot import ObeliskSimRobot class ObeliskMujocoRobot(ObeliskSimRobot): """Simulator that runs Mujoco.""" - def __init__(self, node_name: str = "obelisk_mujoco_robot") -> None: + def __init__(self, node_name: str = "obelisk_mujoco_robot", ctrl_msg_type: Type = PositionSetpoint) -> None: """Initialize the mujoco simulator.""" - super().__init__(node_name) + super().__init__(node_name, ctrl_msg_type) self.declare_parameter("mujoco_setting", rclpy.Parameter.Type.STRING) self.declare_parameter("ic_keyframe", "ic") - def _get_msg_type_from_string(self, msg_type_str: str) -> Type[ObeliskSensorMsg]: + def _get_msg_type_from_string(self, msg_type_str: str) -> Type: """Get the message type from a string. Parameters: @@ -61,7 +62,7 @@ def _get_time_from_sim(self) -> Tuple[float, float]: def _create_timer_callback_from_msg_type( self, - msg_type: ObeliskSensorMsg, + msg_type: Type, mj_sensor_names: List[str], obk_sensor_fields: List[str], pub_sensor: Publisher, @@ -408,7 +409,7 @@ def on_configure(self, state: LifecycleState) -> TransitionCallbackReturn: # no return TransitionCallbackReturn.SUCCESS - def apply_control(self, control_msg: ObeliskControlMsg) -> None: + def apply_control(self, control_msg: Type) -> None: """Apply the control message. We assume that the control message is a vector of control inputs and is fully compatible with the data.ctrl diff --git a/obelisk/python/obelisk_py/core/utils/ros.py b/obelisk/python/obelisk_py/core/utils/ros.py index f109baef..a115f4cb 100644 --- a/obelisk/python/obelisk_py/core/utils/ros.py +++ b/obelisk/python/obelisk_py/core/utils/ros.py @@ -10,6 +10,7 @@ def spin_obelisk( args: Optional[List], node_type: Type[ObeliskNode], executor_type: Union[Type[SingleThreadedExecutor], Type[MultiThreadedExecutor]], + node_kwargs: Optional[dict] = None, ) -> None: """Spin an Obelisk node. @@ -17,9 +18,10 @@ def spin_obelisk( args: Command-line arguments. node_type: Obelisk node type to spin. executor_type: Executor type to use. + node_kwargs: Keyword arguments to pass to the node """ rclpy.init(args=args) - node = node_type(node_name="obelisk_node") + node = node_type(node_name="obelisk_node", **(node_kwargs or {})) executor = executor_type() executor.add_node(node) try: diff --git a/obelisk/python/obelisk_py/zoo/control/example/example_position_setpoint_controller.py b/obelisk/python/obelisk_py/zoo/control/example/example_position_setpoint_controller.py index 2943ec91..365b3e3e 100644 --- a/obelisk/python/obelisk_py/zoo/control/example/example_position_setpoint_controller.py +++ b/obelisk/python/obelisk_py/zoo/control/example/example_position_setpoint_controller.py @@ -1,9 +1,12 @@ +from typing import Type + import numpy as np from obelisk_control_msgs.msg import PositionSetpoint +from obelisk_estimator_msgs.msg import EstimatedState from rclpy.lifecycle import LifecycleState, TransitionCallbackReturn from obelisk_py.core.control import ObeliskController -from obelisk_py.core.obelisk_typing import ObeliskControlMsg, ObeliskEstimatorMsg, is_in_bound +from obelisk_py.core.obelisk_typing import ObeliskControlMsg, is_in_bound class ExamplePositionSetpointController(ObeliskController): @@ -11,7 +14,7 @@ class ExamplePositionSetpointController(ObeliskController): def __init__(self, node_name: str = "example_position_setpoint_controller") -> None: """Initialize the example position setpoint controller.""" - super().__init__(node_name) + super().__init__(node_name, PositionSetpoint, EstimatedState) self.declare_parameter("test_param", "default_value") self.get_logger().info(f"test_param: {self.get_parameter('test_param').get_parameter_value().string_value}") @@ -21,7 +24,7 @@ def on_configure(self, state: LifecycleState) -> TransitionCallbackReturn: self.joint_pos = None return TransitionCallbackReturn.SUCCESS - def update_x_hat(self, x_hat_msg: ObeliskEstimatorMsg) -> None: + def update_x_hat(self, x_hat_msg: Type) -> None: """Update the state estimate. Parameters: @@ -29,7 +32,7 @@ def update_x_hat(self, x_hat_msg: ObeliskEstimatorMsg) -> None: """ pass # do nothing - def compute_control(self) -> ObeliskControlMsg: + def compute_control(self) -> Type: """Compute the control signal for the dummy 2-link robot. Returns: diff --git a/obelisk/python/obelisk_py/zoo/control/example/leap_example_pos_setpoint_controller.py b/obelisk/python/obelisk_py/zoo/control/example/leap_example_pos_setpoint_controller.py index 460fa2c4..ba9bb6d8 100644 --- a/obelisk/python/obelisk_py/zoo/control/example/leap_example_pos_setpoint_controller.py +++ b/obelisk/python/obelisk_py/zoo/control/example/leap_example_pos_setpoint_controller.py @@ -1,9 +1,12 @@ +from typing import Type + import numpy as np from obelisk_control_msgs.msg import PositionSetpoint +from obelisk_estimator_msgs.msg import EstimatedState from rclpy.lifecycle import LifecycleState, TransitionCallbackReturn from obelisk_py.core.control import ObeliskController -from obelisk_py.core.obelisk_typing import ObeliskControlMsg, ObeliskEstimatorMsg, is_in_bound +from obelisk_py.core.obelisk_typing import ObeliskControlMsg, is_in_bound class LeapExamplePositionSetpointController(ObeliskController): @@ -11,7 +14,7 @@ class LeapExamplePositionSetpointController(ObeliskController): def __init__(self, node_name: str = "leap_example_position_setpoint_controller") -> None: """Initialize the example position setpoint controller.""" - super().__init__(node_name) + super().__init__(node_name, PositionSetpoint, EstimatedState) def on_configure(self, state: LifecycleState) -> TransitionCallbackReturn: """Configure the controller.""" @@ -19,7 +22,7 @@ def on_configure(self, state: LifecycleState) -> TransitionCallbackReturn: self.joint_pos = None return TransitionCallbackReturn.SUCCESS - def update_x_hat(self, x_hat_msg: ObeliskEstimatorMsg) -> None: + def update_x_hat(self, x_hat_msg: Type) -> None: """Update the state estimate. Parameters: @@ -27,7 +30,7 @@ def update_x_hat(self, x_hat_msg: ObeliskEstimatorMsg) -> None: """ pass # do nothing - def compute_control(self) -> ObeliskControlMsg: + def compute_control(self) -> Type: """Compute the control signal for the LEAP hand. Returns: diff --git a/obelisk/python/obelisk_py/zoo/estimation/jointencoders_passthrough_estimator.py b/obelisk/python/obelisk_py/zoo/estimation/jointencoders_passthrough_estimator.py index 26df3ec4..ebc11705 100644 --- a/obelisk/python/obelisk_py/zoo/estimation/jointencoders_passthrough_estimator.py +++ b/obelisk/python/obelisk_py/zoo/estimation/jointencoders_passthrough_estimator.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Any from obelisk_estimator_msgs.msg import EstimatedState from obelisk_sensor_msgs.msg import ObkJointEncoders @@ -12,12 +12,12 @@ class JointEncodersPassthroughEstimator(ObeliskEstimator): def __init__(self, node_name: str = "joint_encoders_passthrough_estimator") -> None: """Initialize the joint encoders passthrough estimator.""" - super().__init__(node_name) + super().__init__(node_name, EstimatedState) self.register_obk_subscription( "sub_sensor_setting", self.joint_encoder_callback, # type: ignore + ObkJointEncoders, key="sub_sensor", # key can be specified here or in the config file - msg_type=ObkJointEncoders, ) def on_configure(self, state: LifecycleState) -> TransitionCallbackReturn: @@ -30,7 +30,7 @@ def joint_encoder_callback(self, msg: ObkJointEncoders) -> None: """Callback for joint encoder messages.""" self.joint_encoder_values = msg.joint_pos - def compute_state_estimate(self) -> Union[EstimatedState, None]: + def compute_state_estimate(self) -> Any: """Compute the state estimate.""" estimated_state_msg = EstimatedState() if self.joint_encoder_values is not None: diff --git a/obelisk/python/obelisk_py/zoo/hardware/robots/leap/leap_hand_interface.py b/obelisk/python/obelisk_py/zoo/hardware/robots/leap/leap_hand_interface.py index 29bd12d9..502bcf2f 100644 --- a/obelisk/python/obelisk_py/zoo/hardware/robots/leap/leap_hand_interface.py +++ b/obelisk/python/obelisk_py/zoo/hardware/robots/leap/leap_hand_interface.py @@ -1,11 +1,12 @@ import math +from typing import Type from dynamixel_sdk import DXL_HIBYTE, DXL_HIWORD, DXL_LOBYTE, DXL_LOWORD +from obelisk_control_msgs.msg import PositionSetpoint from obelisk_sensor_msgs.msg import ObkJointEncoders from rclpy.lifecycle.node import LifecycleState, TransitionCallbackReturn import obelisk_py.zoo.hardware.robots.leap.dxl_motor_helper as dxl -from obelisk_py.core.obelisk_typing import ObeliskControlMsg from obelisk_py.core.robot import ObeliskRobot @@ -16,11 +17,11 @@ class ObeliskLeapRobot(ObeliskRobot): def __init__(self, node_name: str) -> None: """Initialize the Obelisk Leap Hand robot.""" - super().__init__(node_name) + super().__init__(node_name, PositionSetpoint) self.register_obk_publisher( "pub_sensor_setting", + ObkJointEncoders, key="pub_sensor", - msg_type=ObkJointEncoders, ) self.register_obk_timer( "timer_sensor_setting", @@ -39,7 +40,7 @@ def _radians_to_dxl_pos(radians: float) -> int: def _dxl_pos_to_radians(dxl_pos: int) -> float: return (dxl_pos / (dxl.MAX_POS - dxl.MIN_POS)) * 2 * math.pi - math.pi - def apply_control(self, control_msg: ObeliskControlMsg) -> None: + def apply_control(self, control_msg: Type) -> None: """Apply the control message to the robot.""" for i in range(self.N_MOTORS): val = self._radians_to_dxl_pos(control_msg.q_des[i]) diff --git a/obelisk_ws/src/obelisk_ros/config/dummy_py.yaml b/obelisk_ws/src/obelisk_ros/config/dummy_py.yaml index 6104e96a..bf9a2018 100644 --- a/obelisk_ws/src/obelisk_ros/config/dummy_py.yaml +++ b/obelisk_ws/src/obelisk_ros/config/dummy_py.yaml @@ -9,18 +9,14 @@ onboard: publishers: - ros_parameter: pub_ctrl_setting topic: /obelisk/dummy/ctrl - msg_type: PositionSetpoint key: "asdf" history_depth: 10 callback_group: None - non_obelisk: False subscribers: - ros_parameter: sub_est_setting topic: /obelisk/dummy/est - msg_type: EstimatedState history_depth: 10 callback_group: None - non_obelisk: False timers: - ros_parameter: timer_ctrl_setting timer_period_sec: 0.001 @@ -32,18 +28,14 @@ onboard: publishers: - ros_parameter: pub_est_setting topic: /obelisk/dummy/est - msg_type: EstimatedState history_depth: 10 callback_group: None - non_obelisk: False subscribers: - ros_parameter: sub_sensor_setting # key: sub1 topic: /obelisk/dummy/joint_encoders - msg_type: ObkJointEncoders history_depth: 10 callback_group: None - non_obelisk: False timers: - ros_parameter: timer_est_setting timer_period_sec: 0.001 @@ -62,10 +54,8 @@ onboard: - ros_parameter: sub_ctrl_setting # key: sub1 topic: /obelisk/dummy/ctrl - msg_type: PositionSetpoint history_depth: 10 callback_group: None - non_obelisk: False sim: - ros_parameter: mujoco_setting n_u: 1 diff --git a/obelisk_ws/src/obelisk_ros/config/dummy_py_zed.yaml b/obelisk_ws/src/obelisk_ros/config/dummy_py_zed.yaml new file mode 100644 index 00000000..aa6c9e15 --- /dev/null +++ b/obelisk_ws/src/obelisk_ros/config/dummy_py_zed.yaml @@ -0,0 +1,91 @@ +config: dummy +onboard: + control: + - pkg: obelisk_control_py + executable: example_position_setpoint_controller + # callback_groups: + params: + test_param: value_configured_in_yaml + publishers: + - ros_parameter: pub_ctrl_setting + topic: /obelisk/dummy/ctrl + key: "asdf" + history_depth: 10 + callback_group: None + subscribers: + - ros_parameter: sub_est_setting + topic: /obelisk/dummy/est + history_depth: 10 + callback_group: None + timers: + - ros_parameter: timer_ctrl_setting + timer_period_sec: 0.001 + callback_group: None + estimation: + - pkg: obelisk_estimation_py + executable: jointencoders_passthrough_estimator + # callback_groups: + publishers: + - ros_parameter: pub_est_setting + topic: /obelisk/dummy/est + history_depth: 10 + callback_group: None + subscribers: + - ros_parameter: sub_sensor_setting + # key: sub1 + topic: /obelisk/dummy/joint_encoders + history_depth: 10 + callback_group: None + # TODO(ahl): make a new estimator object that has a zed camera for testing + # - ros_parameter: sub_zed_setting + # topic: /obelisk/dummy/zedmini/img + # msg_type: ObkImage + # history_depth: 1 + # callback_group: sub_cbg + timers: + - ros_parameter: timer_est_setting + timer_period_sec: 0.001 + callback_group: None + # sensing: + robot: + # `is_simulated` is critical for parsing correct package in the obelisk launch file + - is_simulated: True + pkg: obelisk_sim_py + executable: obelisk_mujoco_robot + params: + ic_keyframe: other + # callback_groups: + # publishers: + subscribers: + - ros_parameter: sub_ctrl_setting + # key: sub1 + topic: /obelisk/dummy/ctrl + history_depth: 10 + callback_group: None + sim: + - ros_parameter: mujoco_setting + n_u: 1 + time_step: 0.002 + num_steps_per_viz: 5 + robot_pkg: dummy_bot + model_xml_path: dummy.xml + sensor_settings: + - topic: /obelisk/dummy/joint_encoders + dt: 0.001 + msg_type: ObkJointEncoders + sensor_names: + joint_pos: jointpos + joint_vel: jointvel + - topic: /obelisk/dummy/imu + dt: 0.002 + msg_type: ObkImu + sensor_names: + tip_acc_sensor: accelerometer + tip_gyro_sensor: gyro + tip_frame_sensor: framequat + - topic: /obelisk/dummy/framepose + dt: 0.002 + msg_type: ObkFramePose + sensor_names: + tip_pos_sensor: framepos + tip_orientation_sensor: framequat diff --git a/obelisk_ws/src/obelisk_ros/config/leap_py.yaml b/obelisk_ws/src/obelisk_ros/config/leap_py.yaml index d3c3328e..9167bac9 100644 --- a/obelisk_ws/src/obelisk_ros/config/leap_py.yaml +++ b/obelisk_ws/src/obelisk_ros/config/leap_py.yaml @@ -7,17 +7,13 @@ onboard: publishers: - ros_parameter: pub_ctrl_setting topic: /obelisk/leap/ctrl - msg_type: PositionSetpoint history_depth: 10 callback_group: None - non_obelisk: False subscribers: - ros_parameter: sub_est_setting topic: /obelisk/leap/est - msg_type: EstimatedState history_depth: 10 callback_group: None - non_obelisk: False timers: - ros_parameter: timer_ctrl_setting timer_period_sec: 0.001 @@ -29,18 +25,14 @@ onboard: publishers: - ros_parameter: pub_est_setting topic: /obelisk/leap/est - msg_type: EstimatedState history_depth: 10 callback_group: None - non_obelisk: False subscribers: - ros_parameter: sub_sensor_setting # key: sub1 topic: /obelisk/leap/joint_encoders - msg_type: ObkJointEncoders history_depth: 10 callback_group: None - non_obelisk: False timers: - ros_parameter: timer_est_setting timer_period_sec: 0.001 @@ -57,21 +49,17 @@ onboard: # executable: obelisk_leap_robot # ================== # callback_groups: - publishers: + publishers: # for hardware only - ros_parameter: pub_sensor_setting topic: /obelisk/leap/joint_encoders - msg_type: ObkJointEncoders history_depth: 10 callback_group: None - non_obelisk: False subscribers: - ros_parameter: sub_ctrl_setting # key: sub1 topic: /obelisk/leap/ctrl - msg_type: PositionSetpoint history_depth: 10 callback_group: None - non_obelisk: False sim: - ros_parameter: mujoco_setting n_u: 16 diff --git a/obelisk_ws/src/python/obelisk_sim_py/obelisk_sim_py/obelisk_mujoco_robot.py b/obelisk_ws/src/python/obelisk_sim_py/obelisk_sim_py/obelisk_mujoco_robot.py index e54212fd..126cecf8 100644 --- a/obelisk_ws/src/python/obelisk_sim_py/obelisk_sim_py/obelisk_mujoco_robot.py +++ b/obelisk_ws/src/python/obelisk_sim_py/obelisk_sim_py/obelisk_mujoco_robot.py @@ -1,5 +1,6 @@ from typing import List, Optional +from obelisk_control_msgs.msg import PositionSetpoint from rclpy.executors import MultiThreadedExecutor from obelisk_py.core.sim.mujoco import ObeliskMujocoRobot @@ -8,7 +9,8 @@ def main(args: Optional[List] = None) -> None: """Main entrypoint.""" - spin_obelisk(args, ObeliskMujocoRobot, MultiThreadedExecutor) + ctrl_msg_type = PositionSetpoint + spin_obelisk(args, ObeliskMujocoRobot, MultiThreadedExecutor, {"ctrl_msg_type": ctrl_msg_type}) if __name__ == "__main__": diff --git a/tests/tests_python/tests_core/test_control.py b/tests/tests_python/tests_core/test_control.py index a7b881cc..0d0103d1 100644 --- a/tests/tests_python/tests_core/test_control.py +++ b/tests/tests_python/tests_core/test_control.py @@ -4,6 +4,7 @@ from rclpy.publisher import Publisher from rclpy.subscription import Subscription from rclpy.timer import Timer +from std_msgs.msg import String from obelisk_py.core.control import ObeliskController from obelisk_py.core.obelisk_typing import ObeliskControlMsg, ObeliskEstimatorMsg @@ -28,7 +29,9 @@ def compute_control(self) -> ObeliskControlMsg: @pytest.fixture def test_controller(ros_context: Any) -> Generator[TestController, None, None]: """Fixture for the TestController class.""" - controller = TestController("test_controller") + ctrl_msg_type = String + est_msg_type = String + controller = TestController("test_controller", ctrl_msg_type, est_msg_type) yield controller controller.destroy_node() @@ -65,18 +68,6 @@ def test_timer_registration(test_controller: TestController) -> None: assert timer_setting["callback"] == test_controller.compute_control -def test_publisher_registration(test_controller: TestController) -> None: - """Test the registration of the control publisher. - - This test verifies that the control publisher is properly registered with the correct key and message type. - - Parameters: - test_controller: An instance of TestController. - """ - pub_setting = next(s for s in test_controller._obk_pub_settings if s["key"] == "pub_ctrl") - assert pub_setting["msg_type"] is None # Should be specified in config file - - def test_subscription_registration(test_controller: TestController) -> None: """Test the registration of the estimator subscription. @@ -88,7 +79,6 @@ def test_subscription_registration(test_controller: TestController) -> None: """ sub_setting = next(s for s in test_controller._obk_sub_settings if s["key"] == "sub_est") assert sub_setting["callback"] == test_controller.update_x_hat - assert sub_setting["msg_type"] is None # Should be specified in config file def test_controller_configuration(test_controller: TestController, set_node_parameters: Callable) -> None: @@ -141,6 +131,6 @@ def update_x_hat(self, x_hat_msg: ObeliskEstimatorMsg) -> None: def compute_control(self) -> ObeliskControlMsg: return ObeliskControlMsg() - complete_controller = CompleteController("complete_controller") + complete_controller = CompleteController("complete_controller", String, String) assert hasattr(complete_controller, "update_x_hat") assert hasattr(complete_controller, "compute_control") diff --git a/tests/tests_python/tests_core/test_estimation.py b/tests/tests_python/tests_core/test_estimation.py index ca544448..12328b7f 100644 --- a/tests/tests_python/tests_core/test_estimation.py +++ b/tests/tests_python/tests_core/test_estimation.py @@ -1,10 +1,12 @@ -from typing import Any, Callable, Generator +from typing import Any, Callable, Generator, Type import pytest import rclpy +from obelisk_sensor_msgs.msg import ObkJointEncoders from rclpy.lifecycle.node import TransitionCallbackReturn from rclpy.publisher import Publisher from rclpy.timer import Timer +from std_msgs.msg import String from obelisk_py.core.estimation import ObeliskEstimator from obelisk_py.core.obelisk_typing import ObeliskEstimatorMsg, ObeliskSensorMsg @@ -17,14 +19,14 @@ class TestEstimator(ObeliskEstimator): """A concrete implementation of ObeliskEstimator for testing purposes.""" - def __init__(self, node_name: str) -> None: + def __init__(self, node_name: str, est_msg_type: Type) -> None: """Initialize the TestEstimator.""" - super().__init__(node_name) + super().__init__(node_name, est_msg_type) self.register_obk_subscription( "sub_sensor_setting", self.update_sensor, + ObkJointEncoders, key="sub_sensor", - msg_type=None, # generic, specified in config file ) def update_sensor(self, sensor_msg: ObeliskSensorMsg) -> None: @@ -39,7 +41,7 @@ def compute_state_estimate(self) -> ObeliskEstimatorMsg: @pytest.fixture def test_estimator(ros_context: Any) -> Generator[TestEstimator, None, None]: """Fixture for the TestEstimator class.""" - estimator = TestEstimator("test_estimator") + estimator = TestEstimator("test_estimator", String) yield estimator estimator.destroy_node() @@ -76,18 +78,6 @@ def test_timer_registration(test_estimator: TestEstimator) -> None: assert timer_setting["callback"] == test_estimator.compute_state_estimate -def test_publisher_registration(test_estimator: TestEstimator) -> None: - """Test the registration of the estimation publisher. - - This test verifies that the estimation publisher is properly registered with the correct key and message type. - - Parameters: - test_estimator: An instance of TestEstimator. - """ - pub_setting = next(s for s in test_estimator._obk_pub_settings if s["key"] == "pub_est") - assert pub_setting["msg_type"] is None # Should be specified in config file - - def test_subscription_registration(test_estimator: TestEstimator) -> None: """Test the registration of the sensor subscription. @@ -99,7 +89,6 @@ def test_subscription_registration(test_estimator: TestEstimator) -> None: """ sub_setting = next(s for s in test_estimator._obk_sub_settings if s["key"] == "sub_sensor") assert sub_setting["callback"] == test_estimator.update_sensor - assert sub_setting["msg_type"] is None # Should be specified in config file def test_estimator_configuration(test_estimator: TestEstimator, set_node_parameters: Callable) -> None: @@ -117,7 +106,7 @@ def test_estimator_configuration(test_estimator: TestEstimator, set_node_paramet { "timer_est_setting": "timer_period_sec:0.1", "pub_est_setting": "topic:/test_estimate,msg_type:EstimatedState", - "sub_sensor_setting": "topic:/test_sensor,msg_type:ObkJointEncoders", + "sub_sensor_setting": "topic:/test_sensor", }, ) result = test_estimator.on_configure(None) @@ -128,7 +117,6 @@ def test_estimator_configuration(test_estimator: TestEstimator, set_node_paramet assert "pub_est" in test_estimator.obk_publishers assert isinstance(test_estimator.obk_publishers["pub_est"], Publisher) assert "sub_sensor" in test_estimator.obk_subscriptions - assert test_estimator._has_sensor_subscriber def test_abstract_methods() -> None: @@ -148,7 +136,7 @@ class CompleteEstimator(ObeliskEstimator): def compute_state_estimate(self) -> ObeliskEstimatorMsg: return ObeliskEstimatorMsg() - complete_estimator = CompleteEstimator("complete_estimator") + complete_estimator = CompleteEstimator("complete_estimator", String) assert hasattr(complete_estimator, "compute_state_estimate") @@ -162,6 +150,6 @@ class NoSensorEstimator(ObeliskEstimator): def compute_state_estimate(self) -> ObeliskEstimatorMsg: return ObeliskEstimatorMsg() - estimator = NoSensorEstimator("no_sensor_estimator") + estimator = NoSensorEstimator("no_sensor_estimator", String) with pytest.raises(rclpy.exceptions.ParameterUninitializedException): estimator.on_configure(None) diff --git a/tests/tests_python/tests_core/test_node.py b/tests/tests_python/tests_core/test_node.py index 8d95461d..841ee4c2 100644 --- a/tests/tests_python/tests_core/test_node.py +++ b/tests/tests_python/tests_core/test_node.py @@ -9,7 +9,6 @@ from rclpy.timer import Timer from std_msgs.msg import String -from obelisk_py.core.exceptions import ObeliskMsgError from obelisk_py.core.node import ObeliskNode # ##### # @@ -52,7 +51,7 @@ def test_register_obk_publisher(test_node: ObeliskNode) -> None: Parameters: test_node: An instance of ObeliskNode. """ - test_node.register_obk_publisher("test_pub_param", key="test_pub", msg_type=String) + test_node.register_obk_publisher("test_pub_param", String, key="test_pub") assert test_node.has_parameter("test_pub_param") @@ -69,7 +68,7 @@ def test_register_obk_subscription(test_node: ObeliskNode) -> None: def callback(msg: String) -> None: pass - test_node.register_obk_subscription("test_sub_param", callback, key="test_sub", msg_type=String) + test_node.register_obk_subscription("test_sub_param", callback, String, key="test_sub") assert test_node.has_parameter("test_sub_param") @@ -101,9 +100,6 @@ def test_create_publisher(test_node: ObeliskNode) -> None: pub = test_node.create_publisher(ocm.PositionSetpoint, "test_topic", 10) assert isinstance(pub, Publisher) - with pytest.raises(ObeliskMsgError): - test_node.create_publisher(String, "test_topic", 10) - def test_create_subscription(test_node: ObeliskNode) -> None: """Test the creation of a subscription. @@ -121,9 +117,6 @@ def callback(msg: ocm.PositionSetpoint) -> None: sub = test_node.create_subscription(ocm.PositionSetpoint, "test_topic", callback, 10) assert isinstance(sub, Subscription) - with pytest.raises(ObeliskMsgError): - test_node.create_subscription(String, "test_topic", callback, 10) - def test_lifecycle_callbacks(test_node: ObeliskNode) -> None: """Test the lifecycle callbacks of the ObeliskNode. @@ -169,7 +162,7 @@ def test_publisher_creation_from_config(test_node: ObeliskNode, set_node_paramet test_node: An instance of ObeliskNode. set_node_parameters: A fixture to set node parameters. """ - test_node.register_obk_publisher("test_pub_param", key="test_pub", msg_type=ocm.PositionSetpoint) + test_node.register_obk_publisher("test_pub_param", ocm.PositionSetpoint, key="test_pub") set_node_parameters(test_node, {"test_pub_param": "topic:/test_topic,msg_type:PositionSetpoint,history_depth:10"}) test_node.on_configure(None) @@ -190,7 +183,7 @@ def test_subscription_creation_from_config(test_node: ObeliskNode, set_node_para def callback(msg: ocm.PositionSetpoint) -> None: pass - test_node.register_obk_subscription("test_sub_param", callback, key="test_sub", msg_type=ocm.PositionSetpoint) + test_node.register_obk_subscription("test_sub_param", callback, ocm.PositionSetpoint, key="test_sub") set_node_parameters(test_node, {"test_sub_param": "topic:/test_topic,msg_type:PositionSetpoint,history_depth:10"}) test_node.on_configure(None) diff --git a/tests/tests_python/tests_core/test_robot.py b/tests/tests_python/tests_core/test_robot.py index dad81e4f..b535c693 100644 --- a/tests/tests_python/tests_core/test_robot.py +++ b/tests/tests_python/tests_core/test_robot.py @@ -6,6 +6,7 @@ from rclpy.publisher import Publisher from rclpy.subscription import Subscription from rclpy.timer import Timer +from std_msgs.msg import String from obelisk_py.core.obelisk_typing import ObeliskControlMsg from obelisk_py.core.robot import ObeliskRobot, ObeliskSimRobot @@ -38,7 +39,7 @@ def run_simulator(self) -> None: @pytest.fixture def test_robot(ros_context: Any) -> Generator[TestRobot, None, None]: """Fixture for the TestRobot class.""" - robot = TestRobot("test_robot") + robot = TestRobot("test_robot", String) yield robot robot.destroy_node() @@ -46,7 +47,7 @@ def test_robot(ros_context: Any) -> Generator[TestRobot, None, None]: @pytest.fixture def test_sim_robot(ros_context: Any) -> Generator[TestSimRobot, None, None]: """Fixture for the TestSimRobot class.""" - sim_robot = TestSimRobot("test_sim_robot") + sim_robot = TestSimRobot("test_sim_robot", String) yield sim_robot sim_robot.destroy_node() @@ -80,7 +81,6 @@ def test_robot_subscription_registration(test_robot: TestRobot) -> None: """ sub_setting = next(s for s in test_robot._obk_sub_settings if s["key"] == "sub_ctrl") assert sub_setting["callback"] == test_robot.apply_control - assert sub_setting["msg_type"] is None # Should be specified in config file def test_robot_configuration(test_robot: TestRobot, set_node_parameters: Callable) -> None: @@ -190,7 +190,7 @@ class CompleteRobot(ObeliskRobot): def apply_control(self, control_msg: ObeliskControlMsg) -> None: pass - complete_robot = CompleteRobot("complete_robot") + complete_robot = CompleteRobot("complete_robot", String) assert hasattr(complete_robot, "apply_control") class CompleteSimRobot(ObeliskSimRobot): @@ -200,6 +200,6 @@ def apply_control(self, control_msg: ObeliskControlMsg) -> None: def run_simulator(self) -> None: pass - complete_sim_robot = CompleteSimRobot("complete_sim_robot") + complete_sim_robot = CompleteSimRobot("complete_sim_robot", String) assert hasattr(complete_sim_robot, "apply_control") assert hasattr(complete_sim_robot, "run_simulator") diff --git a/tests/tests_python/tests_core/test_sensing.py b/tests/tests_python/tests_core/test_sensing.py index ab05f762..f92ff0b5 100644 --- a/tests/tests_python/tests_core/test_sensing.py +++ b/tests/tests_python/tests_core/test_sensing.py @@ -48,7 +48,7 @@ def test_sensor_configuration_with_publisher(test_sensor: ObeliskSensor, set_nod test_sensor: An instance of ObeliskSensor. set_node_parameters: A fixture to set node parameters. """ - test_sensor.register_obk_publisher("test_pub_param", key="test_pub", msg_type=osm.ObkJointEncoders) + test_sensor.register_obk_publisher("test_pub_param", osm.ObkJointEncoders, key="test_pub") set_node_parameters(test_sensor, {"test_pub_param": "topic:/test_topic,msg_type:ObkJointEncoders,history_depth:10"}) result = test_sensor.on_configure(None)