diff --git a/f1tenth_gym/envs/reset/__init__.py b/f1tenth_gym/envs/reset/__init__.py index bae55b2f..c1a80191 100644 --- a/f1tenth_gym/envs/reset/__init__.py +++ b/f1tenth_gym/envs/reset/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations from .masked_reset import GridResetFn, AllTrackResetFn +from .map_reset import AllMapResetFn from .reset_fn import ResetFn from ..track import Track @@ -10,6 +11,12 @@ def make_reset_fn(type: str | None, track: Track, num_agents: int, **kwargs) -> try: refline_token, reset_token, shuffle_token = type.split("_") + if refline_token == "map": + reset_fn = {"random": AllMapResetFn}[reset_token] + shuffle = {"static": False, "random": True}[shuffle_token] + return reset_fn(track=track, num_agents=num_agents, shuffle=shuffle, **kwargs) + + # "cl" or "rl" refline = {"cl": track.centerline, "rl": track.raceline}[refline_token] reset_fn = {"grid": GridResetFn, "random": AllTrackResetFn}[reset_token] shuffle = {"static": False, "random": True}[shuffle_token] diff --git a/f1tenth_gym/envs/reset/map_reset.py b/f1tenth_gym/envs/reset/map_reset.py new file mode 100644 index 00000000..f92ab4de --- /dev/null +++ b/f1tenth_gym/envs/reset/map_reset.py @@ -0,0 +1,85 @@ +from abc import abstractmethod + +import cv2 +import numpy as np + +from .reset_fn import ResetFn +from .utils import sample_around_pose +from ..track import Track + + +class MapResetFn(ResetFn): + @abstractmethod + def get_mask(self) -> np.ndarray: + pass + + def __init__( + self, + track: Track, + num_agents: int, + move_laterally: bool, + min_dist: float, + max_dist: float, + ): + self.track = track + self.n_agents = num_agents + self.min_dist = min_dist + self.max_dist = max_dist + self.move_laterally = move_laterally + # Mask is a 2D array of booleans of where the agents can be placed + # Should acount for max_dist from obstacles + self.mask = self.get_mask() + + + def sample(self) -> np.ndarray: + # Random ample an x-y position from the mask + valid_x, valid_y = np.where(self.mask) + idx = np.random.choice(len(valid_x)) + pose_x = valid_x[idx] * self.track.spec.resolution + self.track.spec.origin[0] + pose_y = valid_y[idx] * self.track.spec.resolution + self.track.spec.origin[1] + pose_theta = np.random.uniform(-np.pi, np.pi) + pose = np.array([pose_x, pose_y, pose_theta]) + + poses = sample_around_pose( + pose=pose, + n_agents=self.n_agents, + min_dist=self.min_dist, + max_dist=self.max_dist, + ) + return poses + +class AllMapResetFn(MapResetFn): + def __init__( + self, + track: Track, + num_agents: int, + move_laterally: bool = True, + shuffle: bool = True, + min_dist: float = 0.5, + max_dist: float = 1.0, + ): + super().__init__( + track=track, + num_agents=num_agents, + move_laterally=move_laterally, + min_dist=min_dist, + max_dist=max_dist, + ) + self.shuffle = shuffle + + def get_mask(self) -> np.ndarray: + # Create mask from occupancy grid enlarged by max_dist + dilation_size = int(self.max_dist / self.track.spec.resolution) + kernel = np.ones((dilation_size, dilation_size), np.uint8) + inverted_occ_map = (255 - self.track.occupancy_map) + dilated = cv2.dilate(inverted_occ_map, kernel, iterations=1) + dilated_inverted = (255 - dilated) + return dilated_inverted == 255 + + def sample(self) -> np.ndarray: + poses = super().sample() + + if self.shuffle: + np.random.shuffle(poses) + + return poses diff --git a/f1tenth_gym/envs/reset/utils.py b/f1tenth_gym/envs/reset/utils.py index 74846cf8..9571d195 100644 --- a/f1tenth_gym/envs/reset/utils.py +++ b/f1tenth_gym/envs/reset/utils.py @@ -83,3 +83,36 @@ def sample_around_waypoint( ) return np.array(poses) + +def sample_around_pose( + pose: np.ndarray, + n_agents: int, + min_dist: float, + max_dist: float, +) -> np.ndarray: + """ + Compute n poses around a given pose. + It iteratively samples the next agent within a distance range from the previous one. + Note: no guarantee that the agents are on the track nor that they are not colliding with the environment. + + Args: + - pose: the initial pose + - n_agents: the number of agents + - min_dist: the minimum distance between two consecutive agents + - max_dist: the maximum distance between two consecutive agents + """ + current_pose = pose + + poses = [] + for i in range(n_agents): + x, y, theta = current_pose + pose = np.array([x, y, theta]) + poses.append(pose) + # sample next pose + dist = np.random.uniform(min_dist, max_dist) + theta = np.random.uniform(-np.pi, np.pi) + x += dist * np.cos(theta) + y += dist * np.sin(theta) + current_pose = np.array([x, y, theta]) + + return np.array(poses) \ No newline at end of file diff --git a/f1tenth_gym/envs/track/track.py b/f1tenth_gym/envs/track/track.py index f95bda3d..52d17b7a 100644 --- a/f1tenth_gym/envs/track/track.py +++ b/f1tenth_gym/envs/track/track.py @@ -164,6 +164,66 @@ def from_track_name(track: str, track_scale: float = 1.0) -> Track: print(ex) raise FileNotFoundError(f"It could not load track {track}") from ex + @staticmethod + def from_track_path(path: pathlib.Path): + """ + Load track from track path. + + Parameters + ---------- + path : pathlib.Path + path to the track yaml file + + Returns + ------- + Track + track object + + Raises + ------ + FileNotFoundError + if the track cannot be loaded + """ + try: + if type(path) is str: + path = pathlib.Path(path) + + track_spec = Track.load_spec( + track=path.stem, filespec=path + ) + + # load occupancy grid + # Image path is from path + image name from track_spec + image_path = path.parent / track_spec.image + image = Image.open(image_path).transpose(Transpose.FLIP_TOP_BOTTOM) + occupancy_map = np.array(image).astype(np.float32) + occupancy_map[occupancy_map <= 128] = 0.0 + occupancy_map[occupancy_map > 128] = 255.0 + + # if exists, load centerline + if (path / f"{path.stem}_centerline.csv").exists(): + centerline = Raceline.from_centerline_file(path / f"{path.stem}_centerline.csv") + else: + centerline = None + + # if exists, load raceline + if (path / f"{path.stem}_raceline.csv").exists(): + raceline = Raceline.from_raceline_file(path / f"{path.stem}_raceline.csv") + else: + raceline = centerline + + return Track( + spec=track_spec, + filepath=str(path.absolute()), + ext=image_path.suffix, + occupancy_map=occupancy_map, + centerline=centerline, + raceline=raceline, + ) + except Exception as ex: + print(ex) + raise FileNotFoundError(f"It could not load track {path}") from ex + @staticmethod def from_refline(x: np.ndarray, y: np.ndarray, velx: np.ndarray): """