Skip to content
This repository has been archived by the owner on Mar 23, 2024. It is now read-only.

Commit

Permalink
Added logic to handle ros parameter to choose planners (#80)
Browse files Browse the repository at this point in the history
* Added logic to handle ros parameter to choose planners

* Updated documentation and fixed test errors

* Changed logic according to pr comments and changed tests accordingly

* Added some comments

* Error log if planner not implemented

* Made changes according to comments

* changed location of get_params

* Cleanup

---------

Co-authored-by: Patrick Creighton <[email protected]>
  • Loading branch information
Krithik1 and patrick-5546 authored Feb 14, 2024
1 parent 7d6f2c9 commit b7ad6c7
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 16 deletions.
14 changes: 12 additions & 2 deletions local_pathfinding/local_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class LocalPathState:
navigating along.
`wind_speed` (float): Wind speed.
`wind_direction` (int): Wind direction.
`planner` (str): Planner to use for the OMPL query.
"""

def __init__(
Expand All @@ -30,6 +31,7 @@ def __init__(
ais_ships: AISShips,
global_path: Path,
filtered_wind_sensor: WindSensor,
planner: str,
):
"""Initializes the state from ROS msgs."""
if gps: # TODO: remove when mock can be run
Expand All @@ -50,6 +52,8 @@ def __init__(
self.wind_speed = filtered_wind_sensor.speed.speed
self.wind_direction = filtered_wind_sensor.direction

self.planner = planner


class LocalPath:
"""Sets and updates the OMPL path and the local waypoints
Expand All @@ -62,6 +66,7 @@ class LocalPath:
"""

def __init__(self, parent_logger: RcutilsLogger):
"""Initializes the LocalPath class."""
self._logger = parent_logger.get_child(name="local_path")
self._ompl_path: Optional[OMPLPath] = None
self.waypoints: Optional[List[Tuple[float, float]]] = None
Expand All @@ -72,6 +77,7 @@ def update_if_needed(
ais_ships: AISShips,
global_path: Path,
filtered_wind_sensor: WindSensor,
planner: str,
):
"""Updates the OMPL path and waypoints. The path is updated if a new path is found.
Expand All @@ -81,8 +87,12 @@ def update_if_needed(
`global_path` (Path): Path to the destination.
`filtered_wind_sensor` (WindSensor): Wind data.
"""
state = LocalPathState(gps, ais_ships, global_path, filtered_wind_sensor)
ompl_path = OMPLPath(parent_logger=self._logger, max_runtime=1.0, local_path_state=state)
state = LocalPathState(gps, ais_ships, global_path, filtered_wind_sensor, planner)
ompl_path = OMPLPath(
parent_logger=self._logger,
max_runtime=1.0,
local_path_state=state,
)
if ompl_path.solved:
self._logger.info("Updating local path")
self._update(ompl_path)
Expand Down
39 changes: 34 additions & 5 deletions local_pathfinding/node_navigate.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class Sailbot(Node):
lpath_data_pub (Publisher): Publish the local path in a `LPathData` msg.
Publisher timers:
pub_period_sec (float): The period of the publisher timers.
desired_heading_timer (Timer): Call the desired heading callback function.
lpath_data_timer (Timer): Call the local path callback function.
Expand All @@ -42,6 +43,7 @@ class Sailbot(Node):
Attributes:
local_path (LocalPath): The path that `Sailbot` is following.
planner (str): The path planner that `Sailbot` is using.
"""

def __init__(self):
Expand All @@ -51,6 +53,7 @@ def __init__(self):
namespace="",
parameters=[
("pub_period_sec", rclpy.Parameter.Type.DOUBLE),
("path_planner", rclpy.Parameter.Type.STRING),
],
)

Expand Down Expand Up @@ -86,13 +89,15 @@ def __init__(self):
)

# publisher timers
pub_period_sec = self.get_parameter("pub_period_sec").get_parameter_value().double_value
self.get_logger().debug(f"Got parameter: {pub_period_sec=}")
self.pub_period_sec = (
self.get_parameter("pub_period_sec").get_parameter_value().double_value
)
self.get_logger().debug(f"Got parameter: {self.pub_period_sec=}")
self.desired_heading_timer = self.create_timer(
timer_period_sec=pub_period_sec, callback=self.desired_heading_callback
timer_period_sec=self.pub_period_sec, callback=self.desired_heading_callback
)
self.lpath_data_timer = self.create_timer(
timer_period_sec=pub_period_sec, callback=self.lpath_data_callback
timer_period_sec=self.pub_period_sec, callback=self.lpath_data_callback
)

# attributes from subscribers
Expand All @@ -103,6 +108,8 @@ def __init__(self):

# attributes
self.local_path = LocalPath(parent_logger=self.get_logger())
self.planner = self.get_parameter("path_planner").get_parameter_value().string_value
self.get_logger().debug(f"Got parameter: {self.planner=}")

# subscriber callbacks

Expand All @@ -129,6 +136,8 @@ def desired_heading_callback(self):
Warn if not following the heading conventions in custom_interfaces/msg/HelperHeading.msg.
"""
self.update_params()

desired_heading = self.get_desired_heading()
if desired_heading < 0 or 360 <= desired_heading:
self.get_logger().warning(f"Heading {desired_heading} not in [0, 360)")
Expand Down Expand Up @@ -163,12 +172,32 @@ def get_desired_heading(self) -> float:
return -1.0

self.local_path.update_if_needed(
self.gps, self.ais_ships, self.global_path, self.filtered_wind_sensor
self.gps, self.ais_ships, self.global_path, self.filtered_wind_sensor, self.planner
)

# TODO: create function to compute the heading from current position to next local waypoint
return 0.0

def update_params(self):
"""Update instance variables that depend on parameters if they have changed."""
pub_period_sec = self.get_parameter("pub_period_sec").get_parameter_value().double_value
if pub_period_sec != self.pub_period_sec:
self.get_logger().debug(
f"Updating pub period and timers from {self.pub_period_sec} to {pub_period_sec}"
)
self.pub_period_sec = pub_period_sec
self.desired_heading_timer = self.create_timer(
timer_period_sec=self.pub_period_sec, callback=self.desired_heading_callback
)
self.lpath_data_timer = self.create_timer(
timer_period_sec=self.pub_period_sec, callback=self.lpath_data_callback
)

planner = self.get_parameter("path_planner").get_parameter_value().string_value
if planner != self.planner:
self.get_logger().debug(f"Updating planner from {self.planner} to {planner}")
self.planner = planner

def _all_subs_active(self) -> bool:
return True # TODO: this line is a placeholder, delete when mocks can be run
return self.ais_ships and self.gps and self.global_path and self.filtered_wind_sensor
Expand Down
60 changes: 55 additions & 5 deletions local_pathfinding/ompl_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""
from __future__ import annotations

from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, List, Tuple, Type

from custom_interfaces.msg import HelperLatLon
from ompl import base as ob
Expand All @@ -26,7 +26,7 @@


class OMPLPathState:
def __init__(self, local_path_state: LocalPathState):
def __init__(self, local_path_state: LocalPathState, logger: RcutilsLogger):
# TODO: derive OMPLPathState attributes from local_path_state
self.heading_direction = 45.0
self.wind_direction = 10.0
Expand All @@ -43,6 +43,15 @@ def __init__(self, local_path_state: LocalPathState):
else HelperLatLon(latitude=0.0, longitude=0.0)
)

if local_path_state:
planner = local_path_state.planner
supported_planner, _ = get_planner_class(planner)
if planner != supported_planner:
logger.error(
f"Planner {planner} is not implemented, defaulting to {supported_planner}"
)
self.planner = supported_planner


class OMPLPath:
"""Represents the general OMPL Path.
Expand All @@ -67,7 +76,7 @@ def __init__(
local_path_state (LocalPathState): State of Sailbot.
"""
self._logger = parent_logger.get_child(name="ompl_path")
self.state = OMPLPathState(local_path_state)
self.state = OMPLPathState(local_path_state, self._logger)
self._simple_setup = self._init_simple_setup()
self.solved = self._simple_setup.solve(time=max_runtime) # time is in seconds

Expand Down Expand Up @@ -177,8 +186,8 @@ def _init_simple_setup(self) -> og.SimpleSetup:
simple_setup.setOptimizationObjective(objective)

# set the planner of the simple setup object
# TODO: implement and add planner here
planner = og.RRTstar(space_information)
_, planner_class = get_planner_class(self.state.planner)
planner = planner_class(space_information)
simple_setup.setPlanner(planner)

return simple_setup
Expand All @@ -196,3 +205,44 @@ def is_state_valid(state: ob.SE2StateSpace) -> bool:
# TODO: implement obstacle avoidance here
# note: `state` is of type `SE2StateInternal`, so we don't need to use the `()` operator.
return state.getX() < 0.6


def get_planner_class(planner: str) -> Tuple[str, Type[ob.Planner]]:
"""Choose the planner to use for the OMPL query.
Args:
planner (str): Name of the planner to use.
Returns:
Tuple[str, Type[ob.Planner]]: The name and class of the planner to use for the OMPL query,
defaults to RRT* if `planner` is not implemented in this function.
"""
match planner.lower():
case "bitstar":
return planner, og.BITstar
case "bfmtstar":
return planner, og.BFMT
case "fmtstar":
return planner, og.FMT
case "informedrrtstar":
return planner, og.InformedRRTstar
case "lazylbtrrt":
return planner, og.LazyLBTRRT
case "lazyprmstar":
return planner, og.LazyPRMstar
case "lbtrrt":
return planner, og.LBTRRT
case "prmstar":
return planner, og.PRMstar
case "rrtconnect":
return planner, og.RRTConnect
case "rrtsharp":
return planner, og.RRTsharp
case "rrtstar":
return planner, og.RRTstar
case "rrtxstatic":
return planner, og.RRTXstatic
case "sorrtstar":
return planner, og.SORRTstar
case _:
return "rrtstar", og.RRTstar
1 change: 1 addition & 0 deletions test/test_local_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def test_LocalPath_update_if_needed():
ais_ships=AISShips(),
global_path=Path(),
filtered_wind_sensor=WindSensor(),
planner="bitstar",
)
assert PATH.waypoints is not None, "waypoints is not initialized"
assert len(PATH.waypoints) > 1, "waypoints length <= 1"
11 changes: 9 additions & 2 deletions test/test_objectives.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import math

import pytest
from custom_interfaces.msg import HelperLatLon
from custom_interfaces.msg import GPS, AISShips, HelperLatLon, Path, WindSensor
from rclpy.impl.rcutils_logger import RcutilsLogger

import local_pathfinding.coord_systems as coord_systems
import local_pathfinding.objectives as objectives
import local_pathfinding.ompl_path as ompl_path
from local_pathfinding.local_path import LocalPathState

# Upwind downwind cost multipliers
UPWIND_MULTIPLIER = 3000.0
Expand All @@ -16,7 +17,13 @@
PATH = ompl_path.OMPLPath(
parent_logger=RcutilsLogger(),
max_runtime=1,
local_path_state=None, # type: ignore[arg-type] # None is placeholder
local_path_state=LocalPathState(
gps=GPS(),
ais_ships=AISShips(),
global_path=Path(),
filtered_wind_sensor=WindSensor(),
planner="bitstar",
),
)


Expand Down
12 changes: 10 additions & 2 deletions test/test_ompl_path.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
import pytest
from custom_interfaces.msg import GPS, AISShips, Path, WindSensor
from ompl import base as ob
from rclpy.impl.rcutils_logger import RcutilsLogger

import local_pathfinding.coord_systems as cs
import local_pathfinding.ompl_path as ompl_path
from local_pathfinding.local_path import LocalPathState

PATH = ompl_path.OMPLPath(
parent_logger=RcutilsLogger(),
max_runtime=1,
local_path_state=None, # type: ignore[arg-type] # None is placeholder
local_path_state=LocalPathState(
gps=GPS(),
ais_ships=AISShips(),
global_path=Path(),
filtered_wind_sensor=WindSensor(),
planner="bitstar",
),
)


def test_OMPLPathState():
state = ompl_path.OMPLPathState(local_path_state=None)
state = ompl_path.OMPLPathState(local_path_state=None, logger=RcutilsLogger())
assert state.state_domain == (-1, 1), "incorrect value for attribute state_domain"
assert state.state_range == (-1, 1), "incorrect value for attribute start_state"
assert state.start_state == pytest.approx(
Expand Down

0 comments on commit b7ad6c7

Please sign in to comment.