diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d66a801c..efdba50a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -18,7 +18,7 @@ jobs: test-definitions: # sailbot_workspace: use locally-defined file # other repositories: set to UBCSailbot/sailbot_workspace/.github/workflows/test_definitions.yml@ - uses: UBCSailbot/sailbot_workspace/.github/workflows/test_definitions.yml@v1.6.1 + uses: UBCSailbot/sailbot_workspace/.github/workflows/test_definitions.yml@v1.7.0 # see https://github.com/UBCSailbot/sailbot_workspace/blob/main/.github/workflows/test_definitions.yml # for documentation on the inputs and secrets below with: diff --git a/.gitignore b/.gitignore index 983bc442..5ad2f1cf 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,3 @@ __pycache__/ # global paths with exceptions /global_paths/*.csv !/global_paths/mock_global_path.csv -!/global_paths/path.csv diff --git a/README.md b/README.md index e33e1d60..b92bfcdc 100644 --- a/README.md +++ b/README.md @@ -17,3 +17,8 @@ Launch arguments are added to the run command in the format `:=`. | `log_level` | Logging level | A [severity level][severity level] (case insensitive) | [severity level]: + +### Server Files + +The server files: `get_server.py` and `post_server.py` are basic http server files which are used for testing the +global_path module's GET and POST methods. diff --git a/global_paths/path_builder/path_builder.py b/global_paths/path_builder/path_builder.py index 778581ad..d712bee7 100644 --- a/global_paths/path_builder/path_builder.py +++ b/global_paths/path_builder/path_builder.py @@ -23,9 +23,9 @@ from custom_interfaces.msg import HelperLatLon, Path from flask import Flask, jsonify, render_template, request -from local_pathfinding.node_mock_global_path import ( +from local_pathfinding.global_path import ( + _interpolate_path, calculate_interval_spacing, - interpolate_path, write_to_file, ) @@ -65,7 +65,7 @@ def main(): path_spacing = calculate_interval_spacing(pos=pos, waypoints=waypoints) path = Path(waypoints=waypoints) - path = interpolate_path( + path = _interpolate_path( global_path=path, interval_spacing=args.interpolate, pos=pos, @@ -134,7 +134,7 @@ def _delete_paths(): @app.route("/interpolate_path", methods=["POST"]) -def _interpolate_path(): +def _interpolate_path_(): data = request.json result = _handle_interpolate(data) return jsonify(result) @@ -229,7 +229,7 @@ def _handle_interpolate(data): try: path_spacing = calculate_interval_spacing(pos=point1, waypoints=path.waypoints) - path = interpolate_path( + path = _interpolate_path( global_path=path, interval_spacing=interval_spacing, pos=point1, diff --git a/global_paths/path_builder/requirements.txt b/global_paths/path_builder/requirements.txt deleted file mode 100644 index 08fcee0c..00000000 --- a/global_paths/path_builder/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -Flask 3.0.0 -Werkzeug 3.0.1 diff --git a/local_pathfinding/global_path.py b/local_pathfinding/global_path.py new file mode 100644 index 00000000..1128a76c --- /dev/null +++ b/local_pathfinding/global_path.py @@ -0,0 +1,440 @@ +"""The Global Path Module, which retrieves the global path from a specified http source and +sends it to NET via POST request. + +The main function accepts two CLI arguments: + file_path (str): The path to the global path csv file. + --interval (float, Optional): The desired path interval length in km. +""" + +import argparse +import csv +import json +import os +import time +from datetime import datetime +from urllib.error import HTTPError, URLError +from urllib.request import urlopen + +import numpy as np +from custom_interfaces.msg import HelperLatLon, Path + +from local_pathfinding.coord_systems import GEODESIC, meters_to_km + +GPS_URL = "http://localhost:3005/api/gps" +PATH_URL = "http://localhost:8081/global-path" +GLOBAL_PATHS_FILE_PATH = "/workspaces/sailbot_workspace/src/local_pathfinding/global_paths" +PERIOD = 5 # seconds + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("file_path", help="The path to the global path csv file.") + parser.add_argument("--interval", help="Desired path interval length.", type=float) + args = parser.parse_args() + + file_path = args.file_path + path_mod_tmstmp = None + pos = None + + try: + path = get_path(file_path) + print(f"retrieved path from {file_path}", path_to_dict(path)) + except FileNotFoundError: + print(f"{file_path} not found. Please enter a valid file path.") + exit(1) + + # Main service loop + while True: + time.sleep(PERIOD) + timestamp = time.ctime(os.path.getmtime(file_path)) + + # We should try to retrieve the position on every loop + pos = get_pos() + + if pos is None: + print(f"Failed to retrieve position from {GPS_URL}") + continue + + position_delta = meters_to_km( + GEODESIC.inv( + lats1=pos.latitude, + lons1=pos.longitude, + lats2=path.waypoints[0].latitude, + lons2=path.waypoints[0].longitude, + )[2] + ) + + # exit loop if the path has not been modified or interval lengths are fine + if (timestamp == path_mod_tmstmp) and ( + (args.interval is None) or position_delta <= args.interval + ): + continue + + if args.interval is not None: + # interpolate path will interpolate new path and save it to a new csv file + path = interpolate_path( + path=path, + pos=pos, + interval_spacing=args.interval, + file_path=file_path, + ) + + if post_path(path): + print("Global path successfully updated.") + print(f"position was {pos}") + file_path = get_most_recent_file(GLOBAL_PATHS_FILE_PATH) + timestamp = time.ctime(os.path.getmtime(file_path)) + else: + # if the post was unsuccessful, we should try again + # so don't update the timestamp + continue + + path_mod_tmstmp = timestamp + + +def get_most_recent_file(directory_path: str) -> str: + """ + Returns the most recently modified file in the specified directory. + + Args: + directory_path (str): The path to the directory containing the files. + + Returns: + str: The path to the most recently modified file. + """ + all_files = os.listdir(directory_path) + + # Filter out directories and get the full file paths + files = [ + os.path.join(directory_path, file) + for file in all_files + if os.path.isfile(os.path.join(directory_path, file)) + ] + + # Sort the files based on their last modification time + files.sort(key=lambda x: os.path.getmtime(x), reverse=True) + + if files: + return files[0] + else: + return "" + + +def get_path(file_path: str) -> Path: + """Returns the global path from the specified file path. + + Args: + file_path (str): The path to the global path csv file. + + Returns: + (Path): The global path retrieved from the csv file. + """ + path = Path() + + with open(file_path, "r") as file: + reader = csv.reader(file) + # skip header + reader.__next__() + for row in reader: + path.waypoints.append(HelperLatLon(latitude=float(row[0]), longitude=float(row[1]))) + return path + + +def post_path(path: Path) -> bool: + """Sends the global path to NET via POST request. + + Args: + path (Path): The global path. + + Returns: + bool: Whether or not the global path was successfully posted. + """ + waypoints = [ + {"latitude": float(item.latitude), "longitude": float(item.longitude)} + for item in path.waypoints + ] + + # the timestamp format will be -- :: + timestamp = datetime.now().strftime("%y-%m-%d %H:%M:%S") + + data = {"waypoints": waypoints, "timestamp": timestamp} + + json_data = json.dumps(data).encode("utf-8") + try: + urlopen(PATH_URL, json_data) + return True + except HTTPError as http_error: + print(f"HTTP Error: {http_error.code}") + except URLError as url_error: + print(f"URL Error: {url_error.reason}") + except Exception as e: + print(f"An unexpected error occurred: {e}") + + return False + + +def get_pos() -> HelperLatLon: + """Returns the current position of sailbot, retrieved from the an http GET request. + + Returns: + HelperLatLon: The current position of sailbot + OR + None: If the position could not be retrieved. + """ + try: + position = json.loads(urlopen(GPS_URL).read()) + except HTTPError as http_error: + print(f"HTTP Error: {http_error.code}") + return None + except URLError as url_error: + print(f"URL Error: {url_error.reason}") + return None + except ConnectionResetError as connect_error: + print(f"Connection Reset Error: {connect_error}") + return None + except Exception as e: + print(f"An unexpected error occurred: {e}") + return None + + if len(position["data"]) == 0: + print(f"Connection to {GPS_URL} successful. No position data available.") + return None + + latitude = position["data"][-1]["latitude"] + longitude = position["data"][-1]["longitude"] + pos = HelperLatLon(latitude=latitude, longitude=longitude) + + return pos + + +def generate_path( + dest: HelperLatLon, + interval_spacing: float, + pos: HelperLatLon, + write: bool = False, + file_path: str = "", +) -> Path: + """Returns a path from the current GPS location to the destination point. + Waypoints are evenly spaced along the path according to the interval_spacing parameter. + Path does not include pos, but does include dest as the final element. + + If write is True, the path is written to a new csv file in the same directory as file_path, + with the name of the original file, appended with a timestamp. + + Args: + dest (HelperLatLon): The destination point + interval_spacing (float): The desired distance between waypoints on the path + pos (HelperLatLon): The current GPS location + write (bool, optional): Whether to write the path to a new csv file, default False + file_path (str, optional): The filepath to the global path csv file, default empty + + Returns: + Path: The generated path + """ + global_path = Path() + + lat1 = pos.latitude + lon1 = pos.longitude + + lat2 = dest.latitude + lon2 = dest.longitude + + distance = meters_to_km(GEODESIC.inv(lats1=lat1, lons1=lon1, lats2=lat2, lons2=lon2)[2]) + + # minimum number of waypoints to not exceed interval_spacing + n = np.floor(distance / interval_spacing) + n = max(1, n) + + # npts returns a path with neither pos nor dest included + global_path_tuples = GEODESIC.npts(lon1=lon1, lat1=lat1, lon2=lon2, lat2=lat2, npts=n) + + # npts returns (lon,lat) tuples, its backwards for some reason + for lon, lat in global_path_tuples: + global_path.waypoints.append(HelperLatLon(latitude=lat, longitude=lon)) + + # append the destination point + global_path.waypoints.append(HelperLatLon(latitude=lat2, longitude=lon2)) + + if write: + write_to_file(file_path=file_path, global_path=global_path) + + return global_path + + +def _interpolate_path( + global_path: Path, + interval_spacing: float, + pos: HelperLatLon, + path_spacing: list[float], + write: bool = False, + file_path: str = "", +) -> Path: + """Interpolates and inserts subpaths between any waypoints which are spaced too far apart. + + Args: + global_path (Path): The path to interpolate between + interval_spacing (float): The desired spacing between waypoints + pos (HelperLatLon): The current GPS location + path_spacing (list[float]): The distances between pairs of points in global_path + write (bool, optional): Whether to write the path to a new csv file, default False + file_path (str, optional): The filepath to the global path csv file, default empty + + Returns: + Path: The interpolated path + """ + + waypoints = [pos] + global_path.waypoints + + i, j = 0, 0 + while i < len(path_spacing): + if path_spacing[i] > interval_spacing: + # interpolate a new sub path between the two waypoints + pos = waypoints[j] + dest = waypoints[j + 1] + + sub_path = generate_path( + dest=dest, + interval_spacing=interval_spacing, + pos=pos, + ) + # insert sub path into path + waypoints[j + 1 : j + 1] = sub_path.waypoints[:-1] + # shift indices to account for path insertion + j += len(sub_path.waypoints) - 1 + + i += 1 + j += 1 + # remove pos from waypoints again + waypoints.pop(0) + + global_path.waypoints = waypoints + + if write: + write_to_file(file_path=file_path, global_path=global_path) + + return global_path + + +def interpolate_path( + path: Path, + pos: HelperLatLon, + interval_spacing: float, + file_path: str, + write=True, +) -> Path: + """Interpolates path to ensure the interval lengths are less than or equal to the specified + interval spacing. + + Args: + path (Path): The global path. + pos (HelperLatLon): The current position of the vehicle. + interval_spacing (float): The desired interval spacing. + file_path (str): The path to the global path csv file. + write (bool, optional): Whether or not to write the new path to a csv file. Default True. + + Returns: + Path: The interpolated path. + """ + + # obtain the actual distances between every waypoint in the path + path_spacing = calculate_interval_spacing(pos, path.waypoints) + + # check if global path is just a destination point + if len(path.waypoints) < 2: + path = generate_path( + dest=path.waypoints[0], + interval_spacing=interval_spacing, + pos=pos, + write=write, + file_path=file_path, + ) + # Check if any waypoints are too far apart + elif max(path_spacing) > interval_spacing: + path = _interpolate_path( + global_path=path, + interval_spacing=interval_spacing, + pos=pos, + path_spacing=path_spacing, + write=write, + file_path=file_path, + ) + + return path + + +def calculate_interval_spacing(pos: HelperLatLon, waypoints: list[HelperLatLon]) -> list[float]: + """Returns the distances between pairs of points in a list of latitudes and longitudes, + including pos as the first point. + + Args: + pos (HelperLatLon): The gps position of the boat + waypoints (list[HelperLatLon]): The list of waypoints + + Returns: + list[float]: The distances between pairs of points in waypoints [km] + """ + all_coords = [(pos.latitude, pos.longitude)] + [ + (waypoint.latitude, waypoint.longitude) for waypoint in waypoints + ] + + coords_array = np.array(all_coords) + + lats1, lons1 = coords_array[:-1].T + lats2, lons2 = coords_array[1:].T + + distances = GEODESIC.inv(lats1=lats1, lons1=lons1, lats2=lats2, lons2=lons2)[2] + + distances = [meters_to_km(distance) for distance in distances] + + return distances + + +def write_to_file(file_path: str, global_path: Path, tmstmp: bool = True) -> Path: + """Writes the global path to a new, timestamped csv file. + + Args + file_path (str): The filepath to the global path csv file + global_path (Path): The global path to write to file + tmstmp (bool, optional): Whether to append a timestamp to the file name, default True + + Raises: + ValueError: If file_path is not to an existing `global_paths` directory + """ + + # check if file_path is a valid file path + if not os.path.isdir(os.path.dirname(file_path)) or not str( + os.path.dirname(file_path) + ).endswith("global_paths"): + raise ValueError(f"Invalid file path: {file_path}") + + if tmstmp: + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + dst_file_path = file_path.removesuffix(".csv") + f"_{timestamp}.csv" + else: + dst_file_path = file_path + + with open(dst_file_path, "w") as file: + writer = csv.writer(file) + writer.writerow(["latitude", "longitude"]) + for waypoint in global_path.waypoints: + writer.writerow([waypoint.latitude, waypoint.longitude]) + + +def path_to_dict(path: Path, num_decimals: int = 4) -> dict[int, str]: + """Converts a Path msg to a dictionary suitable for printing. + + Args: + path (Path): The Path msg to be converted. + num_decimals (int, optional): The number of decimal places to round to, default 4. + + Returns: + dict[int, str]: Keys are the indices of the formatted latlon waypoints. + """ + return { + i: f"({waypoint.latitude:.{num_decimals}f}, {waypoint.longitude:.{num_decimals}f})" + for i, waypoint in enumerate(path.waypoints) + } + + +if __name__ == "__main__": + main() diff --git a/local_pathfinding/local_path.py b/local_pathfinding/local_path.py index a4b39a34..75551d21 100644 --- a/local_pathfinding/local_path.py +++ b/local_pathfinding/local_path.py @@ -22,6 +22,7 @@ class LocalPathState: navigating along. `wind_speed` (float): Wind speed. `wind_direction` (int): Wind direction. + `planner` (str): Planner to use for the OMPL query. """ def __init__( @@ -30,6 +31,7 @@ def __init__( ais_ships: AISShips, global_path: Path, filtered_wind_sensor: WindSensor, + planner: str, ): """Initializes the state from ROS msgs.""" if gps: # TODO: remove when mock can be run @@ -50,6 +52,8 @@ def __init__( self.wind_speed = filtered_wind_sensor.speed.speed self.wind_direction = filtered_wind_sensor.direction + self.planner = planner + class LocalPath: """Sets and updates the OMPL path and the local waypoints @@ -62,6 +66,7 @@ class LocalPath: """ def __init__(self, parent_logger: RcutilsLogger): + """Initializes the LocalPath class.""" self._logger = parent_logger.get_child(name="local_path") self._ompl_path: Optional[OMPLPath] = None self.waypoints: Optional[List[Tuple[float, float]]] = None @@ -72,6 +77,7 @@ def update_if_needed( ais_ships: AISShips, global_path: Path, filtered_wind_sensor: WindSensor, + planner: str, ): """Updates the OMPL path and waypoints. The path is updated if a new path is found. @@ -81,8 +87,12 @@ def update_if_needed( `global_path` (Path): Path to the destination. `filtered_wind_sensor` (WindSensor): Wind data. """ - state = LocalPathState(gps, ais_ships, global_path, filtered_wind_sensor) - ompl_path = OMPLPath(parent_logger=self._logger, max_runtime=1.0, local_path_state=state) + state = LocalPathState(gps, ais_ships, global_path, filtered_wind_sensor, planner) + ompl_path = OMPLPath( + parent_logger=self._logger, + max_runtime=1.0, + local_path_state=state, + ) if ompl_path.solved: self._logger.info("Updating local path") self._update(ompl_path) diff --git a/local_pathfinding/node_mock_global_path.py b/local_pathfinding/node_mock_global_path.py index 2fa79280..abb8a869 100644 --- a/local_pathfinding/node_mock_global_path.py +++ b/local_pathfinding/node_mock_global_path.py @@ -1,16 +1,26 @@ -"""Node that publishes the mock global path, represented by the `MockGlobalPath` class.""" +"""Node loads in Sailbot's position via GET request, loads a global path from a csv file, +and posts the mock global path via a POST request. +The node is represented by the `MockGlobalPath` class.""" -import csv import os import time -from datetime import datetime -import numpy as np import rclpy -from custom_interfaces.msg import GPS, HelperLatLon, Path +from custom_interfaces.msg import GPS, HelperLatLon from rclpy.node import Node from local_pathfinding.coord_systems import GEODESIC, meters_to_km +from local_pathfinding.global_path import ( + GPS_URL, + PATH_URL, + _interpolate_path, + calculate_interval_spacing, + generate_path, + get_path, + get_pos, + path_to_dict, + post_path, +) # Mock gps data to get things running until we have a running gps node # TODO Remove when NET publishes GPS @@ -27,188 +37,6 @@ def main(args=None): rclpy.shutdown() -def generate_path( - dest: HelperLatLon, - interval_spacing: float, - pos: HelperLatLon, - write: bool = False, - file_path: str = "", -) -> Path: - """Returns a path from the current GPS location to the destination point. - Waypoints are evenly spaced along the path according to the interval_spacing parameter. - Path does not include pos, but does include dest as the final element. - - If write is True, the path is written to a new csv file in the same directory as file_path, - with the name of the original file, appended with a timestamp. - - Args: - dest (Union[HelperLatLon, list[HelperLatLon]]): The destination point or partial path - interval_spacing (float): The desired distance between waypoints on the path - pos (HelperLatLon): The current GPS location - write (bool, optional): Whether to write the path to a new csv file, default False - file_path (str, optional): The filepath to the global path csv file, default empty - - Returns: - Path: The generated path - """ - global_path = Path() - - lat1 = pos.latitude - lon1 = pos.longitude - - lat2 = dest.latitude - lon2 = dest.longitude - - distance = meters_to_km(GEODESIC.inv(lats1=lat1, lons1=lon1, lats2=lat2, lons2=lon2)[2]) - - # minimum number of waypoints to not exceed interval_spacing - n = np.floor(distance / interval_spacing) - n = max(1, n) - - # npts returns a path with neither pos nor dest included - global_path_tuples = GEODESIC.npts(lon1=lon1, lat1=lat1, lon2=lon2, lat2=lat2, npts=n) - - # npts returns (lon,lat) tuples, its backwards for some reason - for lon, lat in global_path_tuples: - global_path.waypoints.append(HelperLatLon(latitude=lat, longitude=lon)) - - # append the destination point - global_path.waypoints.append(HelperLatLon(latitude=lat2, longitude=lon2)) - - if write: - write_to_file(file_path=file_path, global_path=global_path) - - return global_path - - -def interpolate_path( - global_path: Path, - interval_spacing: float, - pos: HelperLatLon, - path_spacing: list[float], - write: bool = False, - file_path: str = "", -) -> Path: - """Interpolates and inserts subpaths between any waypoints which are spaced too far apart. - - Args: - global_path (Path): The path to interpolate between - interval_spacing (float): The desired spacing between waypoints - pos (HelperLatLon): The current GPS location - path_spacing (list[float]): The distances between pairs of points in global_path - write (bool, optional): Whether to write the path to a new csv file, default False - file_path (str, optional): The filepath to the global path csv file, default empty - - Returns: - Path: The interpolated path - """ - - waypoints = [pos] + global_path.waypoints - - i, j = 0, 0 - while i < len(path_spacing): - if path_spacing[i] > interval_spacing: - # interpolate a new sub path between the two waypoints - pos = waypoints[j] - dest = waypoints[j + 1] - - sub_path = generate_path( - dest=dest, - interval_spacing=interval_spacing, - pos=pos, - ) - # insert sub path into path - waypoints[j + 1 : j + 1] = sub_path.waypoints[:-1] - # shift indices to account for path insertion - j += len(sub_path.waypoints) - 1 - - i += 1 - j += 1 - # remove pos from waypoints again - waypoints.pop(0) - - global_path.waypoints = waypoints - - if write: - write_to_file(file_path=file_path, global_path=global_path) - - return global_path - - -def calculate_interval_spacing(pos: HelperLatLon, waypoints: list[HelperLatLon]) -> list[float]: - """Returns the distances between pairs of points in a list of latitudes and longitudes, - including pos as the first point. - - Args: - pos (HelperLatLon): The gps position of the boat - waypoints (list[HelperLatLon]): The list of waypoints - - Returns: - list[float]: The distances between pairs of points in waypoints [km] - """ - all_coords = [(pos.latitude, pos.longitude)] + [ - (waypoint.latitude, waypoint.longitude) for waypoint in waypoints - ] - - coords_array = np.array(all_coords) - - lats1, lons1 = coords_array[:-1].T - lats2, lons2 = coords_array[1:].T - - distances = GEODESIC.inv(lats1=lats1, lons1=lons1, lats2=lats2, lons2=lons2)[2] - - distances = [meters_to_km(distance) for distance in distances] - - return distances - - -def write_to_file(file_path: str, global_path: Path, tmstmp: bool = True) -> Path: - """Writes the global path to a new, timestamped csv file. - - Args - file_path (str): The filepath to the global path csv file - global_path (Path): The global path to write to file - tmstmp (bool, optional): Whether to append a timestamp to the file name, default True - - Raises: - ValueError: If file_path is not to an existing `global_paths` directory - """ - - # check if file_path is a valid file path - if not os.path.isdir(os.path.dirname(file_path)) or not str( - os.path.dirname(file_path) - ).endswith("global_paths"): - raise ValueError(f"Invalid file path: {file_path}") - - if tmstmp: - timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - dst_file_path = file_path.removesuffix(".csv") + f"_{timestamp}.csv" - else: - dst_file_path = file_path - - with open(dst_file_path, "w") as file: - writer = csv.writer(file) - writer.writerow(["latitude", "longitude"]) - for waypoint in global_path.waypoints: - writer.writerow([waypoint.latitude, waypoint.longitude]) - - -def path_to_dict(path: Path, num_decimals: int = 4) -> dict[int, str]: - """Converts a Path msg to a dictionary suitable for printing. - - Args: - path (Path): The Path msg to be converted. - num_decimals (int, optional): The number of decimal places to round to, default 4. - - Returns: - dict[int, str]: Keys are the indices of the formatted latlon waypoints. - """ - return { - i: f"({waypoint.latitude:.{num_decimals}f}, {waypoint.longitude:.{num_decimals}f})" - for i, waypoint in enumerate(path.waypoints) - } - - class MockGlobalPath(Node): """Stores and publishes the mock global path to the global_path topic. @@ -245,67 +73,63 @@ def __init__(self): ("force", rclpy.Parameter.Type.BOOL), ], ) - - # Subscribers - self.gps_sub = self.create_subscription( - msg_type=GPS, topic="gps", callback=self.gps_callback, qos_profile=10 - ) - - # Publishers - self.global_path_pub = self.create_publisher( - msg_type=Path, topic="global_path", qos_profile=10 - ) - - # Path callback timer + # get the publishing period parameter to use for callbacks pub_period_sec = self.get_parameter("pub_period_sec").get_parameter_value().double_value self.get_logger().debug(f"Got parameter: {pub_period_sec=}") + # mock global path callback runs repeatedly on a timer self.global_path_timer = self.create_timer( timer_period_sec=pub_period_sec, callback=self.global_path_callback, ) - # Attributes - self.gps = MOCK_GPS # TODO Remove when NET publishes GPS + self.pos = MOCK_GPS.lat_lon self.path_mod_tmstmp = None self.file_path = None + self.period = pub_period_sec - # Subscriber callbacks - def gps_callback(self, msg: GPS): - """Store the gps data and check if the global path needs to be updated. + def check_pos(self): + """Get the gps data and check if the global path needs to be updated. If the position has changed by more than gps_threshold * interval_spacing since last step, - the global_path_callback is run with the force parameter set to true, bypassing any checks. + the force parameter set to true, bypassing any checks in the global_path_callback. """ - self.get_logger().debug(f"Received data from {self.gps_sub.topic}: {msg}") + self.get_logger().info( + f"Retreiving current position from {GPS_URL}", throttle_duration_sec=1 + ) + + pos = get_pos() + if pos is None: + return # error is logged in calling function position_delta = meters_to_km( GEODESIC.inv( - lats1=self.gps.lat_lon.latitude, - lons1=self.gps.lat_lon.longitude, - lats2=msg.lat_lon.latitude, - lons2=msg.lat_lon.longitude, + lats1=self.pos.latitude, + lons1=self.pos.longitude, + lats2=pos.latitude, + lons2=pos.longitude, )[2] ) gps_threshold = self.get_parameter("gps_threshold")._value interval_spacing = self.get_parameter("interval_spacing")._value + if position_delta > gps_threshold * interval_spacing: self.get_logger().info( - f"GPS data changed by more than {gps_threshold*interval_spacing} km. Running ", - "global path callback", + f"GPS data changed by more than {gps_threshold*interval_spacing} km. Running " + "global path callback" ) self.set_parameters([rclpy.Parameter("force", rclpy.Parameter.Type.BOOL, True)]) - self.global_path_callback() - self.gps = msg + self.pos = pos # Timer callbacks def global_path_callback(self): """Check if the global path csv file has changed. If it has, the new path is published. - This function is also called by the gps callback if the gps data has changed by more than - gps_threshold. + This function also checks if the gps data has changed by more than + gps_threshold. If it has, the force parameter is set to true, bypassing any checks and + updating the path. Depending on the boolean value of the write parameter, each generated path may be written to a new csv file in the same directory as the source csv file. @@ -314,14 +138,18 @@ def global_path_callback(self): global_path_filepath parameter. """ - if not self._all_subs_active(): - self._log_inactive_subs_warning() file_path = self.get_parameter("global_path_filepath")._value # check when global path was changed last path_mod_tmstmp = time.ctime(os.path.getmtime(file_path)) + self.check_pos() + + if self.pos is None: + self.log_no_pos() + return + # check if the global path has been forced to update by a parameter change force = self.get_parameter("force")._value @@ -329,87 +157,77 @@ def global_path_callback(self): if path_mod_tmstmp == self.path_mod_tmstmp and self.file_path == file_path and not force: return - else: + self.get_logger().info( + f"Global path file is: {os.path.basename(file_path)}\n Reading path" + ) + global_path = get_path(file_path=file_path) + + pos = self.pos + + # obtain the actual distances between every waypoint in the global path + path_spacing = calculate_interval_spacing(pos, global_path.waypoints) + + # obtain desired interval spacing + interval_spacing = self.get_parameter("interval_spacing")._value + + # check if global path is just a destination point + if len(global_path.waypoints) < 2: self.get_logger().info( - f"Global path file is: {os.path.basename(file_path)}\n Reading path" + f"Generating new path from {pos.latitude:.4f}, {pos.longitude:.4f} to " + f"{global_path.waypoints[0].latitude:.4f}, " + f"{global_path.waypoints[0].longitude:.4f}" ) - global_path = Path() - - with open(file_path, "r") as file: - reader = csv.reader(file) - # skip header - reader.__next__() - for row in reader: - global_path.waypoints.append( - HelperLatLon(latitude=float(row[0]), longitude=float(row[1])) - ) - - pos = self.gps.lat_lon - - # obtain the actual distances between every waypoint in the global path - path_spacing = calculate_interval_spacing(pos, global_path.waypoints) - - # obtain desired interval spacing - interval_spacing = self.get_parameter("interval_spacing")._value - - # check if global path is just a destination point - if len(global_path.waypoints) < 2: - self.get_logger().info( - f"Generating new path from {pos.latitude:.4f}, {pos.longitude:.4f} to " - f"{global_path.waypoints[0].latitude:.4f}, " - f"{global_path.waypoints[0].longitude:.4f}" - ) - - write = self.get_parameter("write")._value - if write: - self.get_logger().info("Writing generated path to new file") - - msg = generate_path( - dest=global_path.waypoints[0], - interval_spacing=interval_spacing, - pos=pos, - write=write, - file_path=file_path, - ) - # Check if any waypoints are too far apart - elif max(path_spacing) > interval_spacing: - self.get_logger().info( - f"Some waypoints in the global path exceed the maximum interval spacing of" - f" {interval_spacing} km. Interpolating between waypoints and generating path" - ) - - write = self.get_parameter("write")._value - if write: - self.get_logger().info("Writing generated path to new file") - - msg = interpolate_path( - global_path=global_path, - interval_spacing=interval_spacing, - pos=pos, - path_spacing=path_spacing, - write=write, - file_path=file_path, - ) - - else: - msg = global_path - - # publish global path - self.global_path_pub.publish(msg) + write = self.get_parameter("write")._value + if write: + self.get_logger().info("Writing generated path to new file") + + msg = generate_path( + dest=global_path.waypoints[0], + interval_spacing=interval_spacing, + pos=pos, + write=write, + file_path=file_path, + ) + # Check if any waypoints are too far apart + elif max(path_spacing) > interval_spacing: self.get_logger().info( - f"Publishing to {self.global_path_pub.topic}: {path_to_dict(msg)}" + f"Some waypoints in the global path exceed the maximum interval spacing of" + f" {interval_spacing} km. Interpolating between waypoints and generating path" + ) + + write = self.get_parameter("write")._value + if write: + self.get_logger().info("Writing generated path to new file") + + msg = _interpolate_path( + global_path=global_path, + interval_spacing=interval_spacing, + pos=pos, + path_spacing=path_spacing, + write=write, + file_path=file_path, ) + else: + msg = global_path + + # post global path + if post_path(msg): + self.get_logger().info(f"Posting path to {PATH_URL}: {path_to_dict(msg)}") self.set_parameters([rclpy.Parameter("force", rclpy.Parameter.Type.BOOL, False)]) self.path_mod_tmstmp = path_mod_tmstmp self.file_path = file_path + else: + self.log_failed_post() - def _all_subs_active(self) -> bool: - return self.gps is not None + def log_no_pos(self): + self.get_logger().warn( + f"Failed to get position from {GPS_URL} will retry in {self.period} seconds." + ) - def _log_inactive_subs_warning(self): - self.get_logger().warning("Waiting for GPS to be published") + def log_failed_post(self): + self.get_logger().warn(f"Failed to post path to {PATH_URL}") if __name__ == "__main__": diff --git a/local_pathfinding/node_navigate.py b/local_pathfinding/node_navigate.py index 42489371..bbf6d749 100644 --- a/local_pathfinding/node_navigate.py +++ b/local_pathfinding/node_navigate.py @@ -1,9 +1,9 @@ """The main node of the local_pathfinding package, represented by the `Sailbot` class.""" -import custom_interfaces.msg as ci import rclpy from rclpy.node import Node +import custom_interfaces.msg as ci from local_pathfinding.local_path import LocalPath @@ -31,6 +31,7 @@ class Sailbot(Node): lpath_data_pub (Publisher): Publish the local path in a `LPathData` msg. Publisher timers: + pub_period_sec (float): The period of the publisher timers. desired_heading_timer (Timer): Call the desired heading callback function. lpath_data_timer (Timer): Call the local path callback function. @@ -42,6 +43,7 @@ class Sailbot(Node): Attributes: local_path (LocalPath): The path that `Sailbot` is following. + planner (str): The path planner that `Sailbot` is using. """ def __init__(self): @@ -51,6 +53,7 @@ def __init__(self): namespace="", parameters=[ ("pub_period_sec", rclpy.Parameter.Type.DOUBLE), + ("path_planner", rclpy.Parameter.Type.STRING), ], ) @@ -86,13 +89,15 @@ def __init__(self): ) # publisher timers - pub_period_sec = self.get_parameter("pub_period_sec").get_parameter_value().double_value - self.get_logger().debug(f"Got parameter: {pub_period_sec=}") + self.pub_period_sec = ( + self.get_parameter("pub_period_sec").get_parameter_value().double_value + ) + self.get_logger().debug(f"Got parameter: {self.pub_period_sec=}") self.desired_heading_timer = self.create_timer( - timer_period_sec=pub_period_sec, callback=self.desired_heading_callback + timer_period_sec=self.pub_period_sec, callback=self.desired_heading_callback ) self.lpath_data_timer = self.create_timer( - timer_period_sec=pub_period_sec, callback=self.lpath_data_callback + timer_period_sec=self.pub_period_sec, callback=self.lpath_data_callback ) # attributes from subscribers @@ -103,6 +108,8 @@ def __init__(self): # attributes self.local_path = LocalPath(parent_logger=self.get_logger()) + self.planner = self.get_parameter("path_planner").get_parameter_value().string_value + self.get_logger().debug(f"Got parameter: {self.planner=}") # subscriber callbacks @@ -129,6 +136,8 @@ def desired_heading_callback(self): Warn if not following the heading conventions in custom_interfaces/msg/HelperHeading.msg. """ + self.update_params() + desired_heading = self.get_desired_heading() if desired_heading < 0 or 360 <= desired_heading: self.get_logger().warning(f"Heading {desired_heading} not in [0, 360)") @@ -163,19 +172,52 @@ def get_desired_heading(self) -> float: return -1.0 self.local_path.update_if_needed( - self.gps, self.ais_ships, self.global_path, self.filtered_wind_sensor + self.gps, self.ais_ships, self.global_path, self.filtered_wind_sensor, self.planner ) # TODO: create function to compute the heading from current position to next local waypoint return 0.0 + def update_params(self): + """Update instance variables that depend on parameters if they have changed.""" + pub_period_sec = self.get_parameter("pub_period_sec").get_parameter_value().double_value + if pub_period_sec != self.pub_period_sec: + self.get_logger().debug( + f"Updating pub period and timers from {self.pub_period_sec} to {pub_period_sec}" + ) + self.pub_period_sec = pub_period_sec + self.desired_heading_timer = self.create_timer( + timer_period_sec=self.pub_period_sec, callback=self.desired_heading_callback + ) + self.lpath_data_timer = self.create_timer( + timer_period_sec=self.pub_period_sec, callback=self.lpath_data_callback + ) + + planner = self.get_parameter("path_planner").get_parameter_value().string_value + if planner != self.planner: + self.get_logger().debug(f"Updating planner from {self.planner} to {planner}") + self.planner = planner + def _all_subs_active(self) -> bool: return True # TODO: this line is a placeholder, delete when mocks can be run return self.ais_ships and self.gps and self.global_path and self.filtered_wind_sensor def _log_inactive_subs_warning(self): - # TODO: log which subscribers are inactive - self.get_logger().warning("There are inactive subscribers") + """ + Logs a warning message for each inactive subscriber. + """ + inactive_subs = [] + if self.ais_ships_sub is None: + inactive_subs.append("ais_ships") + if self.gps_sub is None: + inactive_subs.append("gps") + if self.global_path_sub is None: + inactive_subs.append("global_path") + if self.filtered_wind_sensor_sub is None: + inactive_subs.append("filtered_wind_sensor") + if len(inactive_subs) == 0: + return + self._logger.warning("Inactive Subscribers: " + ", ".join(inactive_subs)) if __name__ == "__main__": diff --git a/local_pathfinding/ompl_path.py b/local_pathfinding/ompl_path.py index b47f62f3..37d52a62 100644 --- a/local_pathfinding/ompl_path.py +++ b/local_pathfinding/ompl_path.py @@ -7,7 +7,7 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Tuple, Type from custom_interfaces.msg import HelperLatLon from ompl import base as ob @@ -26,7 +26,7 @@ class OMPLPathState: - def __init__(self, local_path_state: LocalPathState): + def __init__(self, local_path_state: LocalPathState, logger: RcutilsLogger): # TODO: derive OMPLPathState attributes from local_path_state self.heading_direction = 45.0 self.wind_direction = 10.0 @@ -43,6 +43,15 @@ def __init__(self, local_path_state: LocalPathState): else HelperLatLon(latitude=0.0, longitude=0.0) ) + if local_path_state: + planner = local_path_state.planner + supported_planner, _ = get_planner_class(planner) + if planner != supported_planner: + logger.error( + f"Planner {planner} is not implemented, defaulting to {supported_planner}" + ) + self.planner = supported_planner + class OMPLPath: """Represents the general OMPL Path. @@ -67,7 +76,7 @@ def __init__( local_path_state (LocalPathState): State of Sailbot. """ self._logger = parent_logger.get_child(name="ompl_path") - self.state = OMPLPathState(local_path_state) + self.state = OMPLPathState(local_path_state, self._logger) self._simple_setup = self._init_simple_setup() self.solved = self._simple_setup.solve(time=max_runtime) # time is in seconds @@ -177,8 +186,8 @@ def _init_simple_setup(self) -> og.SimpleSetup: simple_setup.setOptimizationObjective(objective) # set the planner of the simple setup object - # TODO: implement and add planner here - planner = og.RRTstar(space_information) + _, planner_class = get_planner_class(self.state.planner) + planner = planner_class(space_information) simple_setup.setPlanner(planner) return simple_setup @@ -196,3 +205,44 @@ def is_state_valid(state: ob.SE2StateSpace) -> bool: # TODO: implement obstacle avoidance here # note: `state` is of type `SE2StateInternal`, so we don't need to use the `()` operator. return state.getX() < 0.6 + + +def get_planner_class(planner: str) -> Tuple[str, Type[ob.Planner]]: + """Choose the planner to use for the OMPL query. + + Args: + planner (str): Name of the planner to use. + + Returns: + Tuple[str, Type[ob.Planner]]: The name and class of the planner to use for the OMPL query, + defaults to RRT* if `planner` is not implemented in this function. + """ + match planner.lower(): + case "bitstar": + return planner, og.BITstar + case "bfmtstar": + return planner, og.BFMT + case "fmtstar": + return planner, og.FMT + case "informedrrtstar": + return planner, og.InformedRRTstar + case "lazylbtrrt": + return planner, og.LazyLBTRRT + case "lazyprmstar": + return planner, og.LazyPRMstar + case "lbtrrt": + return planner, og.LBTRRT + case "prmstar": + return planner, og.PRMstar + case "rrtconnect": + return planner, og.RRTConnect + case "rrtsharp": + return planner, og.RRTsharp + case "rrtstar": + return planner, og.RRTstar + case "rrtxstatic": + return planner, og.RRTXstatic + case "sorrtstar": + return planner, og.SORRTstar + case _: + return "rrtstar", og.RRTstar diff --git a/package.xml b/package.xml index 57744635..955fcff3 100644 --- a/package.xml +++ b/package.xml @@ -7,9 +7,19 @@ Patrick Creighton MIT - rclpy + custom_interfaces + rclpy + + + python3-numpy + python3-pyproj + python3-shapely + + ament_copyright + ament_flake8 + ament_pep257 python3-pytest diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..a4bbb4e8 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +# packages that aren't required on the main computer in production +# install them with pip3 install -r requirements.txt + +# global_paths/path_builder/path_builder.py +flask + +# test/test_obstacles.py +plotly diff --git a/test/post_server.py b/test/post_server.py new file mode 100644 index 00000000..7681b63e --- /dev/null +++ b/test/post_server.py @@ -0,0 +1,55 @@ +""" +This is a basic http server to handle POST requests from the global path module until the NET +endpoint is implemented. + +It receives a JSON payload with a list of waypoints and prints them to the console. +""" +import json +import threading +from http.server import BaseHTTPRequestHandler, HTTPServer + + +class CustomRequestHandler(BaseHTTPRequestHandler): + def _set_response(self, status_code=200, content_type="application/json"): + self.send_response(status_code) + self.send_header("Content-type", content_type) + self.end_headers() + + def do_POST(self): + content_length = int(self.headers["Content-Length"]) + post_data = self.rfile.read(content_length) + data = json.loads(post_data.decode("utf-8")) + + # Process the data as needed + waypoints = data.get("waypoints", []) + + # For now, just print the waypoints + print("Received waypoints:", waypoints) + + self._set_response(200) + self.wfile.write( + json.dumps({"message": "Global path received successfully"}).encode("utf-8") + ) + + +def run_server(port=8081) -> HTTPServer: + server_address = ("localhost", port) + httpd = HTTPServer(server_address, CustomRequestHandler) + + def run(): + print(f"Server running on http://localhost:{port}") + httpd.serve_forever() + + # Start the server in a separate thread + server_thread = threading.Thread(target=run) + server_thread.start() + + return httpd + + +def shutdown_server(httpd: HTTPServer): + httpd.shutdown() + + +if __name__ == "__main__": + run_server() diff --git a/test/test_node_mock_global_path.py b/test/test_global_path.py similarity index 59% rename from test/test_node_mock_global_path.py rename to test/test_global_path.py index 4e78aca7..ceca31e7 100644 --- a/test/test_node_mock_global_path.py +++ b/test/test_global_path.py @@ -1,31 +1,25 @@ +import os + +import post_server import pytest from custom_interfaces.msg import HelperLatLon, Path from local_pathfinding.coord_systems import GEODESIC, meters_to_km -from local_pathfinding.node_mock_global_path import ( +from local_pathfinding.global_path import ( + _interpolate_path, calculate_interval_spacing, generate_path, + get_most_recent_file, + get_path, + get_pos, interpolate_path, path_to_dict, + post_path, write_to_file, ) -# ------------------------- TEST WRITE_TO_FILE ------------------------------ -@pytest.mark.parametrize( - "file_path", - [ - ("/workspaces/sailbot_workspace/src/local_pathfinding/anywhere_else/mock_global_path.csv"), - (""), - ("/workspaces/sailbot_workspace/src/local_pathfinding/ global_paths/mock_global_path.csv"), - ], -) -def test_write_to_file(file_path: str): - with pytest.raises(ValueError): - write_to_file(file_path=file_path, global_path=None) - - -# ------------------------- TEST INTERPOLATE_PATH ------------------------- +# ------------------------- TEST _INTERPOLATE_PATH ------------------------- @pytest.mark.parametrize( "pos,global_path,interval_spacing", [ @@ -42,12 +36,12 @@ def test_write_to_file(file_path: str): ) ], ) -def test_interpolate_path( +def test__interpolate_path( pos: HelperLatLon, global_path: Path, interval_spacing: float, ): - """Test the interpolate_path method of MockGlobalPath. + """Test the _interpolate_path method of MockGlobalPath. Args: global_path (HelperLatLon): The global path. @@ -57,7 +51,7 @@ def test_interpolate_path( path_spacing = calculate_interval_spacing(pos, global_path.waypoints) - interpolated_path = interpolate_path( + interpolated_path = _interpolate_path( global_path=global_path, interval_spacing=interval_spacing, pos=pos, @@ -201,6 +195,107 @@ def test_generate_path( assert dist <= interval_spacing, "Interval spacing is not correct" +# ------------------------- TEST GET_MOST_RECENT_FILE ------------------------- +@pytest.mark.parametrize( + "file_path,global_path,tmstmp", + [ + ( + "/workspaces/sailbot_workspace/src/local_pathfinding/global_paths/test_file.csv", + Path(), + False, + ) + ], +) +def test_get_most_recent_file(file_path: str, global_path: Path, tmstmp: bool): + # create a file in the directory + write_to_file(file_path=file_path, global_path=global_path, tmstmp=tmstmp) + + assert ( + get_most_recent_file(directory_path=file_path[: -len(file_path.split("/")[-1])]) + == file_path + ), "Did not get most recent file" + + os.remove(file_path) + + +# ------------------------- TEST GET_PATH ------------------------- +@pytest.mark.parametrize( + "file_path", + [("/workspaces/sailbot_workspace/src/local_pathfinding/global_paths/mock_global_path.csv")], +) +def test_get_path(file_path: str): + """ " + Args: + file_path (str): The path to the global path csv file. + """ + global_path = get_path(file_path) + + assert isinstance(global_path, Path) + + # Check that the path is formatted correctly + for waypoint in global_path.waypoints: + assert isinstance(waypoint, HelperLatLon), "Waypoint is not a HelperLatLon" + assert isinstance(waypoint.latitude, float), "Waypoint latitude is not a float" + assert isinstance(waypoint.longitude, float), "Waypoint longitude is not a float" + + +# ------------------------- TEST GET_POS ------------------------- +@pytest.mark.parametrize( + "pos", [HelperLatLon(latitude=49.34175775635472, longitude=-123.35453636335373)] +) +def test_get_pos(pos: HelperLatLon): + """ + Args: + pos (HelperLatLon): The position of the Sailbot. + """ + + pos = get_pos() + assert pos is not None, "No position data received" + assert pos.latitude is not None, "No latitude" + assert pos.longitude is not None, "No longitude" + + +# ------------------------- TEST INTERPOLATE_PATH ------------------------- +@pytest.mark.parametrize( + "path,pos,interval_spacing", + [ + ( + Path( + waypoints=[ + HelperLatLon(latitude=48.95, longitude=123.56), + HelperLatLon(latitude=38.95, longitude=133.36), + HelperLatLon(latitude=28.95, longitude=143.36), + ] + ), + HelperLatLon(latitude=58.95, longitude=113.56), + 50.0, + ) + ], +) +def test_interpolate_path(path: Path, pos: HelperLatLon, interval_spacing: float): + """ + Args: + path (Path): The global path. + pos (HelperLatLon): The position of the Sailbot. + interval_spacing (float): The spacing between each waypoint. + """ + formatted_path = interpolate_path( + path=path, pos=pos, interval_spacing=interval_spacing, file_path="", write=False + ) + + assert isinstance(formatted_path, Path), "Formatted path is not a Path" + + # Check that the path is formatted correctly + for waypoint in formatted_path.waypoints: + assert isinstance(waypoint, HelperLatLon), "Waypoint is not a HelperLatLon" + assert isinstance(waypoint.latitude, float), "Waypoint latitude is not a float" + assert isinstance(waypoint.longitude, float), "Waypoint longitude is not a float" + + path_spacing = calculate_interval_spacing(pos, formatted_path.waypoints) + assert max(path_spacing) <= interval_spacing, "Path spacing is too large" + assert max(path_spacing) <= interval_spacing, "Path spacing is too large" + + # ------------------------- TEST PATH_TO_DICT ------------------------- @pytest.mark.parametrize( "path,expected", @@ -219,3 +314,46 @@ def test_generate_path( def test_path_to_dict(path: Path, expected: dict[int, str]): path_dict = path_to_dict(path) assert path_dict == expected, "Did not correctly convert path to dictionary" + + +# ------------------------- TEST POST_PATH ------------------------- +@pytest.mark.parametrize( + "global_path", + [ + ( + Path( + waypoints=[ + HelperLatLon(latitude=48.95, longitude=123.56), + HelperLatLon(latitude=38.95, longitude=133.36), + HelperLatLon(latitude=28.95, longitude=143.36), + ] + ) + ) + ], +) +def test_post_path(global_path: Path): + """ + Args: + global_path (Path): The global path to post. + """ + + # Launch http server + server = post_server.run_server() + + assert post_path(global_path), "Failed to post global path" + + post_server.shutdown_server(httpd=server) + + +# ------------------------- TEST WRITE_TO_FILE ------------------------------ +@pytest.mark.parametrize( + "file_path", + [ + ("/workspaces/sailbot_workspace/src/local_pathfinding/anywhere_else/mock_global_path.csv"), + (""), + ("/workspaces/sailbot_workspace/src/local_pathfinding/ global_paths/mock_global_path.csv"), + ], +) +def test_write_to_file(file_path: str): + with pytest.raises(ValueError): + write_to_file(file_path=file_path, global_path=None) diff --git a/test/test_local_path.py b/test/test_local_path.py index 3cd6eb7c..001c45f3 100644 --- a/test/test_local_path.py +++ b/test/test_local_path.py @@ -12,6 +12,7 @@ def test_LocalPath_update_if_needed(): ais_ships=AISShips(), global_path=Path(), filtered_wind_sensor=WindSensor(), + planner="bitstar", ) assert PATH.waypoints is not None, "waypoints is not initialized" assert len(PATH.waypoints) > 1, "waypoints length <= 1" diff --git a/test/test_objectives.py b/test/test_objectives.py index 7f942801..f9e928d0 100644 --- a/test/test_objectives.py +++ b/test/test_objectives.py @@ -2,12 +2,13 @@ import math import pytest -from custom_interfaces.msg import HelperLatLon +from custom_interfaces.msg import GPS, AISShips, HelperLatLon, Path, WindSensor from rclpy.impl.rcutils_logger import RcutilsLogger import local_pathfinding.coord_systems as coord_systems import local_pathfinding.objectives as objectives import local_pathfinding.ompl_path as ompl_path +from local_pathfinding.local_path import LocalPathState # Upwind downwind cost multipliers UPWIND_MULTIPLIER = 3000.0 @@ -17,7 +18,13 @@ PATH = ompl_path.OMPLPath( parent_logger=RcutilsLogger(), max_runtime=1, - local_path_state=None, # type: ignore[arg-type] # None is placeholder + local_path_state=LocalPathState( + gps=GPS(), + ais_ships=AISShips(), + global_path=Path(), + filtered_wind_sensor=WindSensor(), + planner="bitstar", + ), ) diff --git a/test/test_ompl_path.py b/test/test_ompl_path.py index ac7809dc..e7662566 100644 --- a/test/test_ompl_path.py +++ b/test/test_ompl_path.py @@ -1,19 +1,27 @@ import pytest +from custom_interfaces.msg import GPS, AISShips, Path, WindSensor from ompl import base as ob from rclpy.impl.rcutils_logger import RcutilsLogger import local_pathfinding.coord_systems as cs import local_pathfinding.ompl_path as ompl_path +from local_pathfinding.local_path import LocalPathState PATH = ompl_path.OMPLPath( parent_logger=RcutilsLogger(), max_runtime=1, - local_path_state=None, # type: ignore[arg-type] # None is placeholder + local_path_state=LocalPathState( + gps=GPS(), + ais_ships=AISShips(), + global_path=Path(), + filtered_wind_sensor=WindSensor(), + planner="bitstar", + ), ) def test_OMPLPathState(): - state = ompl_path.OMPLPathState(local_path_state=None) + state = ompl_path.OMPLPathState(local_path_state=None, logger=RcutilsLogger()) assert state.state_domain == (-1, 1), "incorrect value for attribute state_domain" assert state.state_range == (-1, 1), "incorrect value for attribute start_state" assert state.start_state == pytest.approx(