From b7ad6c71e12c1131e610b76ae531a8df12fc7b0b Mon Sep 17 00:00:00 2001 From: Krithik Sehgal <62566536+Krithik1@users.noreply.github.com> Date: Tue, 13 Feb 2024 19:10:31 -0800 Subject: [PATCH] Added logic to handle ros parameter to choose planners (#80) * Added logic to handle ros parameter to choose planners * Updated documentation and fixed test errors * Changed logic according to pr comments and changed tests accordingly * Added some comments * Error log if planner not implemented * Made changes according to comments * changed location of get_params * Cleanup --------- Co-authored-by: Patrick Creighton --- local_pathfinding/local_path.py | 14 ++++++- local_pathfinding/node_navigate.py | 39 ++++++++++++++++--- local_pathfinding/ompl_path.py | 60 +++++++++++++++++++++++++++--- test/test_local_path.py | 1 + test/test_objectives.py | 11 +++++- test/test_ompl_path.py | 12 +++++- 6 files changed, 121 insertions(+), 16 deletions(-) diff --git a/local_pathfinding/local_path.py b/local_pathfinding/local_path.py index a4b39a3..75551d2 100644 --- a/local_pathfinding/local_path.py +++ b/local_pathfinding/local_path.py @@ -22,6 +22,7 @@ class LocalPathState: navigating along. `wind_speed` (float): Wind speed. `wind_direction` (int): Wind direction. + `planner` (str): Planner to use for the OMPL query. """ def __init__( @@ -30,6 +31,7 @@ def __init__( ais_ships: AISShips, global_path: Path, filtered_wind_sensor: WindSensor, + planner: str, ): """Initializes the state from ROS msgs.""" if gps: # TODO: remove when mock can be run @@ -50,6 +52,8 @@ def __init__( self.wind_speed = filtered_wind_sensor.speed.speed self.wind_direction = filtered_wind_sensor.direction + self.planner = planner + class LocalPath: """Sets and updates the OMPL path and the local waypoints @@ -62,6 +66,7 @@ class LocalPath: """ def __init__(self, parent_logger: RcutilsLogger): + """Initializes the LocalPath class.""" self._logger = parent_logger.get_child(name="local_path") self._ompl_path: Optional[OMPLPath] = None self.waypoints: Optional[List[Tuple[float, float]]] = None @@ -72,6 +77,7 @@ def update_if_needed( ais_ships: AISShips, global_path: Path, filtered_wind_sensor: WindSensor, + planner: str, ): """Updates the OMPL path and waypoints. The path is updated if a new path is found. @@ -81,8 +87,12 @@ def update_if_needed( `global_path` (Path): Path to the destination. `filtered_wind_sensor` (WindSensor): Wind data. """ - state = LocalPathState(gps, ais_ships, global_path, filtered_wind_sensor) - ompl_path = OMPLPath(parent_logger=self._logger, max_runtime=1.0, local_path_state=state) + state = LocalPathState(gps, ais_ships, global_path, filtered_wind_sensor, planner) + ompl_path = OMPLPath( + parent_logger=self._logger, + max_runtime=1.0, + local_path_state=state, + ) if ompl_path.solved: self._logger.info("Updating local path") self._update(ompl_path) diff --git a/local_pathfinding/node_navigate.py b/local_pathfinding/node_navigate.py index 4248937..9c00d59 100644 --- a/local_pathfinding/node_navigate.py +++ b/local_pathfinding/node_navigate.py @@ -31,6 +31,7 @@ class Sailbot(Node): lpath_data_pub (Publisher): Publish the local path in a `LPathData` msg. Publisher timers: + pub_period_sec (float): The period of the publisher timers. desired_heading_timer (Timer): Call the desired heading callback function. lpath_data_timer (Timer): Call the local path callback function. @@ -42,6 +43,7 @@ class Sailbot(Node): Attributes: local_path (LocalPath): The path that `Sailbot` is following. + planner (str): The path planner that `Sailbot` is using. """ def __init__(self): @@ -51,6 +53,7 @@ def __init__(self): namespace="", parameters=[ ("pub_period_sec", rclpy.Parameter.Type.DOUBLE), + ("path_planner", rclpy.Parameter.Type.STRING), ], ) @@ -86,13 +89,15 @@ def __init__(self): ) # publisher timers - pub_period_sec = self.get_parameter("pub_period_sec").get_parameter_value().double_value - self.get_logger().debug(f"Got parameter: {pub_period_sec=}") + self.pub_period_sec = ( + self.get_parameter("pub_period_sec").get_parameter_value().double_value + ) + self.get_logger().debug(f"Got parameter: {self.pub_period_sec=}") self.desired_heading_timer = self.create_timer( - timer_period_sec=pub_period_sec, callback=self.desired_heading_callback + timer_period_sec=self.pub_period_sec, callback=self.desired_heading_callback ) self.lpath_data_timer = self.create_timer( - timer_period_sec=pub_period_sec, callback=self.lpath_data_callback + timer_period_sec=self.pub_period_sec, callback=self.lpath_data_callback ) # attributes from subscribers @@ -103,6 +108,8 @@ def __init__(self): # attributes self.local_path = LocalPath(parent_logger=self.get_logger()) + self.planner = self.get_parameter("path_planner").get_parameter_value().string_value + self.get_logger().debug(f"Got parameter: {self.planner=}") # subscriber callbacks @@ -129,6 +136,8 @@ def desired_heading_callback(self): Warn if not following the heading conventions in custom_interfaces/msg/HelperHeading.msg. """ + self.update_params() + desired_heading = self.get_desired_heading() if desired_heading < 0 or 360 <= desired_heading: self.get_logger().warning(f"Heading {desired_heading} not in [0, 360)") @@ -163,12 +172,32 @@ def get_desired_heading(self) -> float: return -1.0 self.local_path.update_if_needed( - self.gps, self.ais_ships, self.global_path, self.filtered_wind_sensor + self.gps, self.ais_ships, self.global_path, self.filtered_wind_sensor, self.planner ) # TODO: create function to compute the heading from current position to next local waypoint return 0.0 + def update_params(self): + """Update instance variables that depend on parameters if they have changed.""" + pub_period_sec = self.get_parameter("pub_period_sec").get_parameter_value().double_value + if pub_period_sec != self.pub_period_sec: + self.get_logger().debug( + f"Updating pub period and timers from {self.pub_period_sec} to {pub_period_sec}" + ) + self.pub_period_sec = pub_period_sec + self.desired_heading_timer = self.create_timer( + timer_period_sec=self.pub_period_sec, callback=self.desired_heading_callback + ) + self.lpath_data_timer = self.create_timer( + timer_period_sec=self.pub_period_sec, callback=self.lpath_data_callback + ) + + planner = self.get_parameter("path_planner").get_parameter_value().string_value + if planner != self.planner: + self.get_logger().debug(f"Updating planner from {self.planner} to {planner}") + self.planner = planner + def _all_subs_active(self) -> bool: return True # TODO: this line is a placeholder, delete when mocks can be run return self.ais_ships and self.gps and self.global_path and self.filtered_wind_sensor diff --git a/local_pathfinding/ompl_path.py b/local_pathfinding/ompl_path.py index b47f62f..37d52a6 100644 --- a/local_pathfinding/ompl_path.py +++ b/local_pathfinding/ompl_path.py @@ -7,7 +7,7 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Tuple, Type from custom_interfaces.msg import HelperLatLon from ompl import base as ob @@ -26,7 +26,7 @@ class OMPLPathState: - def __init__(self, local_path_state: LocalPathState): + def __init__(self, local_path_state: LocalPathState, logger: RcutilsLogger): # TODO: derive OMPLPathState attributes from local_path_state self.heading_direction = 45.0 self.wind_direction = 10.0 @@ -43,6 +43,15 @@ def __init__(self, local_path_state: LocalPathState): else HelperLatLon(latitude=0.0, longitude=0.0) ) + if local_path_state: + planner = local_path_state.planner + supported_planner, _ = get_planner_class(planner) + if planner != supported_planner: + logger.error( + f"Planner {planner} is not implemented, defaulting to {supported_planner}" + ) + self.planner = supported_planner + class OMPLPath: """Represents the general OMPL Path. @@ -67,7 +76,7 @@ def __init__( local_path_state (LocalPathState): State of Sailbot. """ self._logger = parent_logger.get_child(name="ompl_path") - self.state = OMPLPathState(local_path_state) + self.state = OMPLPathState(local_path_state, self._logger) self._simple_setup = self._init_simple_setup() self.solved = self._simple_setup.solve(time=max_runtime) # time is in seconds @@ -177,8 +186,8 @@ def _init_simple_setup(self) -> og.SimpleSetup: simple_setup.setOptimizationObjective(objective) # set the planner of the simple setup object - # TODO: implement and add planner here - planner = og.RRTstar(space_information) + _, planner_class = get_planner_class(self.state.planner) + planner = planner_class(space_information) simple_setup.setPlanner(planner) return simple_setup @@ -196,3 +205,44 @@ def is_state_valid(state: ob.SE2StateSpace) -> bool: # TODO: implement obstacle avoidance here # note: `state` is of type `SE2StateInternal`, so we don't need to use the `()` operator. return state.getX() < 0.6 + + +def get_planner_class(planner: str) -> Tuple[str, Type[ob.Planner]]: + """Choose the planner to use for the OMPL query. + + Args: + planner (str): Name of the planner to use. + + Returns: + Tuple[str, Type[ob.Planner]]: The name and class of the planner to use for the OMPL query, + defaults to RRT* if `planner` is not implemented in this function. + """ + match planner.lower(): + case "bitstar": + return planner, og.BITstar + case "bfmtstar": + return planner, og.BFMT + case "fmtstar": + return planner, og.FMT + case "informedrrtstar": + return planner, og.InformedRRTstar + case "lazylbtrrt": + return planner, og.LazyLBTRRT + case "lazyprmstar": + return planner, og.LazyPRMstar + case "lbtrrt": + return planner, og.LBTRRT + case "prmstar": + return planner, og.PRMstar + case "rrtconnect": + return planner, og.RRTConnect + case "rrtsharp": + return planner, og.RRTsharp + case "rrtstar": + return planner, og.RRTstar + case "rrtxstatic": + return planner, og.RRTXstatic + case "sorrtstar": + return planner, og.SORRTstar + case _: + return "rrtstar", og.RRTstar diff --git a/test/test_local_path.py b/test/test_local_path.py index 3cd6eb7..001c45f 100644 --- a/test/test_local_path.py +++ b/test/test_local_path.py @@ -12,6 +12,7 @@ def test_LocalPath_update_if_needed(): ais_ships=AISShips(), global_path=Path(), filtered_wind_sensor=WindSensor(), + planner="bitstar", ) assert PATH.waypoints is not None, "waypoints is not initialized" assert len(PATH.waypoints) > 1, "waypoints length <= 1" diff --git a/test/test_objectives.py b/test/test_objectives.py index b8d5c04..9289b65 100644 --- a/test/test_objectives.py +++ b/test/test_objectives.py @@ -1,12 +1,13 @@ import math import pytest -from custom_interfaces.msg import HelperLatLon +from custom_interfaces.msg import GPS, AISShips, HelperLatLon, Path, WindSensor from rclpy.impl.rcutils_logger import RcutilsLogger import local_pathfinding.coord_systems as coord_systems import local_pathfinding.objectives as objectives import local_pathfinding.ompl_path as ompl_path +from local_pathfinding.local_path import LocalPathState # Upwind downwind cost multipliers UPWIND_MULTIPLIER = 3000.0 @@ -16,7 +17,13 @@ PATH = ompl_path.OMPLPath( parent_logger=RcutilsLogger(), max_runtime=1, - local_path_state=None, # type: ignore[arg-type] # None is placeholder + local_path_state=LocalPathState( + gps=GPS(), + ais_ships=AISShips(), + global_path=Path(), + filtered_wind_sensor=WindSensor(), + planner="bitstar", + ), ) diff --git a/test/test_ompl_path.py b/test/test_ompl_path.py index ac7809d..e766256 100644 --- a/test/test_ompl_path.py +++ b/test/test_ompl_path.py @@ -1,19 +1,27 @@ import pytest +from custom_interfaces.msg import GPS, AISShips, Path, WindSensor from ompl import base as ob from rclpy.impl.rcutils_logger import RcutilsLogger import local_pathfinding.coord_systems as cs import local_pathfinding.ompl_path as ompl_path +from local_pathfinding.local_path import LocalPathState PATH = ompl_path.OMPLPath( parent_logger=RcutilsLogger(), max_runtime=1, - local_path_state=None, # type: ignore[arg-type] # None is placeholder + local_path_state=LocalPathState( + gps=GPS(), + ais_ships=AISShips(), + global_path=Path(), + filtered_wind_sensor=WindSensor(), + planner="bitstar", + ), ) def test_OMPLPathState(): - state = ompl_path.OMPLPathState(local_path_state=None) + state = ompl_path.OMPLPathState(local_path_state=None, logger=RcutilsLogger()) assert state.state_domain == (-1, 1), "incorrect value for attribute state_domain" assert state.state_range == (-1, 1), "incorrect value for attribute start_state" assert state.start_state == pytest.approx(