diff --git a/README.md b/README.md index 17c572d..53d4c77 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,7 @@ Currently, the dataloader supports interfacing with the following datasets: | nuScenes Train/TrainVal/Val | `nusc_trainval` | `train`, `train_val`, `val` | `boston`, `singapore` | nuScenes prediction challenge training/validation/test splits (500/200/150 scenes) | 0.5s (2Hz) | :white_check_mark: | | nuScenes Test | `nusc_test` | `test` | `boston`, `singapore` | nuScenes' test split, no annotations (150 scenes) | 0.5s (2Hz) | :white_check_mark: | | nuScenes Mini | `nusc_mini` | `mini_train`, `mini_val` | `boston`, `singapore` | nuScenes mini training/validation splits (8/2 scenes) | 0.5s (2Hz) | :white_check_mark: | +| nuPlan Mini | `nuplan_mini` | `mini_train`, `mini_val`, `mini_test` | `boston`, `singapore`, `pittsburgh`, `las_vegas` | nuPlan mini training/validation/test splits (942/197/224 scenes) | 0.05s (20Hz) | :white_check_mark: | | Lyft Level 5 Train | `lyft_train` | `train` | `palo_alto` | Lyft Level 5 training data - part 1/2 (8.4 GB) | 0.1s (10Hz) | :white_check_mark: | | Lyft Level 5 Train Full | `lyft_train_full` | `train` | `palo_alto` | Lyft Level 5 training data - part 2/2 (70 GB) | 0.1s (10Hz) | :white_check_mark: | | Lyft Level 5 Validation | `lyft_val` | `val` | `palo_alto` | Lyft Level 5 validation data (8.2 GB) | 0.1s (10Hz) | :white_check_mark: | @@ -127,6 +128,19 @@ dataset = UnifiedDataset( **Note**: Be careful about loading multiple datasets without an associated `desired_dt` argument; many datasets do not share the same underlying data annotation frequency. To address this, we've implemented timestep interpolation to a common frequency which will ensure that all batched data shares the same dt. Interpolation can only be performed to integer multiples of the original data annotation frequency. For example, nuScenes' `dt=0.5` and the ETH BIWI dataset's `dt=0.4` can be interpolated to a common `desired_dt=0.1`. +## Map API +`trajdata` also provides an API to access the raw vector map information from datasets that provide it. + +```py +from pathlib import Path +from trajdata import MapAPI, VectorMap + +cache_path = Path("~/.unified_data_cache").expanduser() +map_api = MapAPI(cache_path) + +vector_map: VectorMap = map_api.get_map("nusc_mini:boston-seaport") +``` + ## Simulation Interface One additional feature of trajdata is that it can be used to initialize simulations from real data and track resulting agent motion, metrics, etc. @@ -159,7 +173,7 @@ sim_scene = SimulationScene( ) obs: AgentBatch = sim_scene.reset() -for t in range(1, sim_scene.scene_info.length_timesteps): +for t in range(1, sim_scene.scene.length_timesteps): new_xyh_dict: Dict[str, np.ndarray] = dict() # Everything inside the forloop just sets @@ -181,4 +195,3 @@ for t in range(1, sim_scene.scene_info.length_timesteps): ## TODO - Create a method like finalize() which writes all the batch information to a TFRecord/WebDataset/some other format which is (very) fast to read from for higher epoch training. - Add more examples to the README. - diff --git a/examples/batch_example.py b/examples/batch_example.py index fa8f88b..385861b 100644 --- a/examples/batch_example.py +++ b/examples/batch_example.py @@ -20,8 +20,12 @@ def main(): only_predict=[AgentType.VEHICLE], agent_interaction_distances=defaultdict(lambda: 30.0), incl_robot_future=False, - incl_map=True, - map_params={"px_per_m": 2, "map_size_px": 224, "offset_frac_xy": (-0.5, 0.0)}, + incl_raster_map=True, + raster_map_params={ + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, augmentations=[noise_hists], num_workers=0, verbose=True, diff --git a/examples/cache_and_filter_example.py b/examples/cache_and_filter_example.py new file mode 100644 index 0000000..64c13d1 --- /dev/null +++ b/examples/cache_and_filter_example.py @@ -0,0 +1,97 @@ +import os +from collections import defaultdict + +from torch.utils.data import DataLoader +from tqdm import tqdm + +from trajdata import AgentBatch, AgentType, UnifiedDataset +from trajdata.augmentation import NoiseHistories +from trajdata.data_structures.batch_element import AgentBatchElement +from trajdata.visualization.vis import plot_agent_batch + + +def main(): + noise_hists = NoiseHistories() + + create_dataset = lambda: UnifiedDataset( + desired_data=["nusc_mini-mini_val"], + centric="agent", + desired_dt=0.5, + history_sec=(2.0, 2.0), + future_sec=(4.0, 4.0), + only_predict=[AgentType.VEHICLE], + agent_interaction_distances=defaultdict(lambda: 30.0), + incl_robot_future=False, + incl_raster_map=False, + # map_params={"px_per_m": 2, "map_size_px": 224, "offset_frac_xy": (-0.5, 0.0)}, + augmentations=[noise_hists], + num_workers=0, + verbose=True, + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + }, + ) + + dataset = create_dataset() + + print(f"# Data Samples: {len(dataset):,}") + + print( + "To demonstrate how to use caching we will first save the " + "entire dataset (all BatchElements) to a cache file and then load from " + "the cache file. Note that for large datasets and/or high time resolution " + "this will create a large file and will use a lot of RAM." + ) + cache_path = "./temp_cache_file.dill" + + print( + "We also use a custom filter function that only keeps elements with more " + "than 5 neighbors" + ) + + def my_filter(el: AgentBatchElement) -> bool: + return el.num_neighbors > 5 + + print( + f"In the first run we will iterate through the entire dataset and save all " + f"BatchElements to the cache file {cache_path}" + ) + print("This may take several minutes.") + dataset.load_or_create_cache( + cache_path=cache_path, num_workers=0, filter_fn=my_filter + ) + assert os.path.isfile(cache_path) + + print( + "To demonstrate a consecuitve run we create a new dataset and load elements " + "from the cache file." + ) + del dataset + dataset = create_dataset() + + dataset.load_or_create_cache( + cache_path=cache_path, num_workers=0, filter_fn=my_filter + ) + + # Remove the temp cache file, we dont need it anymore. + os.remove(cache_path) + + print( + "We can iterate through the dataset the same way as normally, but this " + "time it will be much faster because all BatchElements are in memory." + ) + dataloader = DataLoader( + dataset, + batch_size=4, + shuffle=True, + collate_fn=dataset.get_collate_fn(), + num_workers=0, + ) + + batch: AgentBatch + for batch in tqdm(dataloader): + plot_agent_batch(batch, batch_idx=0) + + +if __name__ == "__main__": + main() diff --git a/examples/custom_batch_data.py b/examples/custom_batch_data.py index 24828eb..7559dbd 100644 --- a/examples/custom_batch_data.py +++ b/examples/custom_batch_data.py @@ -11,9 +11,7 @@ from tqdm import tqdm from trajdata import AgentBatch, AgentType, UnifiedDataset -from trajdata.augmentation import NoiseHistories from trajdata.data_structures.batch_element import AgentBatchElement, SceneBatchElement -from trajdata.visualization.vis import plot_agent_batch def custom_random_data( @@ -74,8 +72,12 @@ def main(): only_types=[AgentType.VEHICLE], agent_interaction_distances=defaultdict(lambda: 30.0), incl_robot_future=False, - incl_map=True, - map_params={"px_per_m": 2, "map_size_px": 224, "offset_frac_xy": (-0.5, 0.0)}, + incl_raster_map=True, + raster_map_params={ + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, num_workers=0, verbose=True, data_dirs={ # Remember to change this to match your filesystem! diff --git a/examples/lane_query_example.py b/examples/lane_query_example.py new file mode 100644 index 0000000..2ecc56e --- /dev/null +++ b/examples/lane_query_example.py @@ -0,0 +1,124 @@ +""" +This is an example of how to extend a batch with lane information +""" + +from collections import defaultdict + +import matplotlib.pyplot as plt +import numpy as np +from torch.utils.data import DataLoader +from tqdm import tqdm + +from trajdata import AgentBatch, AgentType, UnifiedDataset +from trajdata.data_structures.batch_element import AgentBatchElement, SceneBatchElement +from trajdata.maps import VectorMap +from trajdata.maps.vec_map_elements import RoadLane +from trajdata.utils.arr_utils import batch_nd_transform_points_np +from trajdata.visualization.vis import plot_agent_batch + + +def get_closest_lane_point(element: AgentBatchElement) -> np.ndarray: + """Closest lane for predicted agent.""" + + # Transform from agent coordinate frame to world coordinate frame. + vector_map: VectorMap = element.vec_map + world_from_agent_tf = np.linalg.inv(element.agent_from_world_tf) + agent_future_xy_world = batch_nd_transform_points_np( + element.agent_future_np[:, :2], world_from_agent_tf + ) + + # Use cached kdtree to find closest lane point + lane_points_world = [] + for xy_world in agent_future_xy_world: + point_4d = np.array([[xy_world[0], xy_world[1], 0.0, 0.0]]) + closest_lane: RoadLane = vector_map.get_closest_lane(point_4d.squeeze(axis=0)) + lane_points_world.append(closest_lane.center.project_onto(point_4d)) + + lane_points_world = np.concatenate(lane_points_world, axis=0) + + # Transform lane points to agent coordinate frame + lane_points = batch_nd_transform_points_np( + lane_points_world[:, :2], element.agent_from_world_tf + ) + + return lane_points + + +def main(): + dataset = UnifiedDataset( + desired_data=[ + "nusc_mini-mini_train", + "lyft_sample-mini_val", + ], + centric="agent", + desired_dt=0.1, + history_sec=(3.2, 3.2), + future_sec=(4.8, 4.8), + only_types=[AgentType.VEHICLE], + agent_interaction_distances=defaultdict(lambda: 30.0), + incl_robot_future=False, + incl_raster_map=True, + raster_map_params={ + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, + incl_vector_map=True, + num_workers=0, + verbose=True, + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + "lyft_sample": "~/datasets/lyft/scenes/sample.zarr", + }, + # A dictionary that contains functions that generate our custom data. + # Can be any function and has access to the batch element. + extras={ + "closest_lane_point": get_closest_lane_point, + }, + ) + + print(f"# Data Samples: {len(dataset):,}") + + dataloader = DataLoader( + dataset, + batch_size=4, + shuffle=False, + collate_fn=dataset.get_collate_fn(), + num_workers=0, + ) + + # Visualize selected examples + num_plots = 3 + batch_idxs = [10876, 10227, 1284] + # batch_idxs = random.sample(range(len(dataset)), num_plots) + batch: AgentBatch = dataset.get_collate_fn(pad_format="right")( + [dataset[i] for i in batch_idxs] + ) + assert "closest_lane_point" in batch.extras + + for batch_i in range(num_plots): + ax = plot_agent_batch( + batch, batch_idx=batch_i, legend=False, show=False, close=False + ) + lane_points = batch.extras["closest_lane_point"][batch_i] + ax.plot( + lane_points[:, 0], + lane_points[:, 1], + "o-", + markersize=3, + label="Lane points", + ) + ax.legend(loc="best", frameon=True) + plt.show() + plt.close("all") + + # Scan through dataset + batch: AgentBatch + for idx, batch in enumerate(tqdm(dataloader)): + assert "closest_lane_point" in batch.extras + if idx > 50: + break + + +if __name__ == "__main__": + main() diff --git a/examples/map_api_example.py b/examples/map_api_example.py new file mode 100644 index 0000000..5851c42 --- /dev/null +++ b/examples/map_api_example.py @@ -0,0 +1,288 @@ +import time +from pathlib import Path +from typing import Dict, List, Optional + +import matplotlib.pyplot as plt +import numpy as np + +from trajdata import MapAPI, VectorMap +from trajdata.caching.df_cache import DataFrameCache +from trajdata.caching.env_cache import EnvCache +from trajdata.caching.scene_cache import SceneCache +from trajdata.data_structures.scene_metadata import Scene +from trajdata.maps.vec_map import Polyline, RoadLane +from trajdata.utils import map_utils + + +def load_random_scene(cache_path: Path, env_name: str, scene_dt: float) -> Scene: + env_cache = EnvCache(cache_path) + scenes_list = env_cache.load_env_scenes_list(env_name) + random_scene_name = scenes_list[np.random.randint(0, len(scenes_list))].name + + return env_cache.load_scene(env_name, random_scene_name, scene_dt) + + +def main(): + cache_path = Path("~/.unified_data_cache").expanduser() + map_api = MapAPI(cache_path) + + ### Loading random scene and initializing VectorMap. + env_name: str = np.random.choice(["nusc_mini", "lyft_sample", "nuplan_mini"]) + scene_cache: Optional[SceneCache] = None + if env_name == "nuplan_mini": + # Hardcoding scene_dt = 0.05s for now + # (using nuPlan as our traffic light data example). + random_scene: Scene = load_random_scene(cache_path, env_name, scene_dt=0.05) + scene_cache = DataFrameCache(cache_path, random_scene) + + vec_map: VectorMap = map_api.get_map( + f"{env_name}:{random_scene.location}", scene_cache=scene_cache + ) + else: + random_location: Dict[str, str] = { + "nusc_mini": np.random.choice(["boston-seaport", "singapore-onenorth"]), + "lyft_sample": "palo_alto", + } + + vec_map: VectorMap = map_api.get_map( + f"{env_name}:{random_location[env_name]}", scene_cache=scene_cache + ) + + print(f"Randomly chose {vec_map.env_name}, {vec_map.map_name} map.") + + ### Loading Lane used for the next few figures. + lane: RoadLane = vec_map.lanes[np.random.randint(0, len(vec_map.lanes))] + + ### Lane Interpolation (max_dist) + start = time.perf_counter() + interpolated: Polyline = lane.center.interpolate(max_dist=0.01) + end = time.perf_counter() + print(f"interpolate (max_dist) took {(end - start)*1000:.2f} ms") + + fig, ax = plt.subplots() + ax.scatter( + lane.center.points[:, 0], lane.center.points[:, 1], label="Original", s=80 + ) + ax.quiver( + lane.center.points[:, 0], + lane.center.points[:, 1], + np.cos(lane.center.points[:, -1]), + np.sin(lane.center.points[:, -1]), + ) + + ax.scatter( + interpolated.points[:, 0], interpolated.points[:, 1], label="Interpolated" + ) + ax.quiver( + interpolated.points[:, 0], + interpolated.points[:, 1], + np.cos(interpolated.points[:, -1]), + np.sin(interpolated.points[:, -1]), + ) + + ax.legend(loc="best") + ax.axis("equal") + + ### Lane Interpolation (num_pts) + start = time.perf_counter() + interpolated: Polyline = lane.center.interpolate(num_pts=10) + end = time.perf_counter() + print(f"interpolate (num_pts) took {(end - start)*1000:.2f} ms") + + fig, ax = plt.subplots() + ax.scatter( + lane.center.points[:, 0], lane.center.points[:, 1], label="Original", s=80 + ) + ax.quiver( + lane.center.points[:, 0], + lane.center.points[:, 1], + np.cos(lane.center.points[:, -1]), + np.sin(lane.center.points[:, -1]), + ) + + ax.scatter( + interpolated.points[:, 0], interpolated.points[:, 1], label="Interpolated" + ) + ax.quiver( + interpolated.points[:, 0], + interpolated.points[:, 1], + np.cos(interpolated.points[:, -1]), + np.sin(interpolated.points[:, -1]), + ) + + ax.legend(loc="best") + ax.axis("equal") + + ### Projection onto Lane + num_pts = 15 + orig_pts = lane.center.midpoint + np.concatenate( + [ + np.random.uniform(-3, 3, size=(num_pts, 2)), # x,y offsets + np.zeros(shape=(num_pts, 1)), # no z offsets + np.random.uniform(-np.pi, np.pi, size=(num_pts, 1)), # headings + ], + axis=-1, + ) + start = time.perf_counter() + proj_pts = lane.center.project_onto(orig_pts) + end = time.perf_counter() + print(f"project_onto ({num_pts} points) took {(end - start)*1000:.2f} ms") + + fig, ax = plt.subplots() + ax.plot(lane.center.points[:, 0], lane.center.points[:, 1], label="Lane") + ax.scatter(orig_pts[:, 0], orig_pts[:, 1], label="Original") + ax.quiver( + orig_pts[:, 0], + orig_pts[:, 1], + np.cos(orig_pts[:, -1]), + np.sin(orig_pts[:, -1]), + ) + + ax.scatter(proj_pts[:, 0], proj_pts[:, 1], label="Projected") + ax.quiver( + proj_pts[:, 0], + proj_pts[:, 1], + np.cos(proj_pts[:, -1]), + np.sin(proj_pts[:, -1]), + ) + + ax.legend(loc="best") + ax.axis("equal") + + ### Lane Graph Visualization (with rasterized map in background) + fig, ax = plt.subplots() + map_img, raster_from_world = vec_map.rasterize( + resolution=2, + return_tf_mat=True, + incl_centerlines=False, + area_color=(255, 255, 255), + edge_color=(0, 0, 0), + scene_ts=100, + ) + ax.imshow(map_img, alpha=0.5, origin="lower") + vec_map.visualize_lane_graph( + origin_lane=np.random.randint(0, len(vec_map.lanes)), + num_hops=10, + raster_from_world=raster_from_world, + ax=ax, + ) + ax.axis("equal") + ax.grid(None) + + ### Closest Lane Query (with rasterized map in background) + # vec_map.extent is [min_x, min_y, min_z, max_x, max_y, max_z] + min_x, min_y, _, max_x, max_y, _ = vec_map.extent + + # Adding some heading to the query point. + mean_pt_heading: float = np.random.uniform(-np.pi, np.pi) + + mean_pt: np.ndarray = np.array( + [ + np.random.uniform(min_x, max_x), + np.random.uniform(min_y, max_y), + 0, + mean_pt_heading, + ] + ) + + start = time.perf_counter() + lane: RoadLane = vec_map.get_closest_lane(mean_pt) + end = time.perf_counter() + print(f"get_closest_lane took {(end - start)*1000:.2f} ms") + + fig, ax = plt.subplots() + map_img, raster_from_world = vec_map.rasterize( + resolution=2, + return_tf_mat=True, + incl_centerlines=False, + area_color=(255, 255, 255), + edge_color=(0, 0, 0), + ) + ax.imshow(map_img, alpha=0.5, origin="lower") + query_pt_map: np.ndarray = map_utils.transform_points( + mean_pt[None, :2], raster_from_world + )[0] + ax.scatter(*query_pt_map, label="Query Point") + ax.quiver( + [query_pt_map[0]], + [query_pt_map[1]], + [np.cos(mean_pt_heading)], + [np.sin(mean_pt_heading)], + ) + vec_map.visualize_lane_graph( + origin_lane=lane, num_hops=0, raster_from_world=raster_from_world, ax=ax + ) + ax.axis("equal") + ax.grid(None) + + ### Lanes Within Range Query (with rasterized map in background) + radius: float = 20.0 + + # vec_map.extent is [min_x, min_y, min_z, max_x, max_y, max_z] + min_x, min_y, _, max_x, max_y, _ = vec_map.extent + + # Adding some heading to the query point. + mean_pt_heading: float = np.random.uniform(-np.pi, np.pi) + + mean_pt: np.ndarray = np.array( + [ + np.random.uniform(min_x, max_x), + np.random.uniform(min_y, max_y), + 0, + mean_pt_heading, + ] + ) + + start = time.perf_counter() + lanes: List[RoadLane] = vec_map.get_lanes_within(mean_pt, radius) + end = time.perf_counter() + print(f"get_lanes_within took {(end - start)*1000:.2f} ms") + + fig, ax = plt.subplots() + img_resolution: float = 2 + map_img, raster_from_world = vec_map.rasterize( + resolution=img_resolution, + return_tf_mat=True, + incl_centerlines=False, + area_color=(255, 255, 255), + edge_color=(0, 0, 0), + ) + ax.imshow(map_img, alpha=0.5, origin="lower") + + query_pt_map: np.ndarray = map_utils.transform_points( + mean_pt[None, :2], raster_from_world + )[0] + ax.scatter(*query_pt_map, label="Query Point") + ax.quiver( + [query_pt_map[0]], + [query_pt_map[1]], + [np.cos(mean_pt_heading)], + [np.sin(mean_pt_heading)], + ) + circle2 = plt.Circle( + (query_pt_map[0], query_pt_map[1]), + radius * img_resolution, + color="b", + fill=False, + ) + ax.add_patch(circle2) + + for l in lanes: + vec_map.visualize_lane_graph( + origin_lane=l, + num_hops=0, + raster_from_world=raster_from_world, + ax=ax, + legend=False, + ) + + ax.axis("equal") + ax.grid(None) + ax.legend(loc="best", frameon=True) + + plt.show() + plt.close("all") + + +if __name__ == "__main__": + main() diff --git a/examples/preprocess_data.py b/examples/preprocess_data.py index fc03ef5..e6faf03 100644 --- a/examples/preprocess_data.py +++ b/examples/preprocess_data.py @@ -5,7 +5,7 @@ def main(): dataset = UnifiedDataset( - desired_data=["nusc_mini", "lyft_sample"], + desired_data=["nusc_mini", "lyft_sample", "nuplan_mini"], rebuild_cache=True, rebuild_maps=True, num_workers=os.cpu_count(), @@ -13,6 +13,7 @@ def main(): data_dirs={ # Remember to change this to match your filesystem! "nusc_mini": "~/datasets/nuScenes", "lyft_sample": "~/datasets/lyft/scenes/sample.zarr", + "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", }, ) print(f"Total Data Samples: {len(dataset):,}") diff --git a/examples/preprocess_maps.py b/examples/preprocess_maps.py index 21042a2..f5dcbf9 100644 --- a/examples/preprocess_maps.py +++ b/examples/preprocess_maps.py @@ -1,17 +1,17 @@ -import os - from trajdata import UnifiedDataset # @profile def main(): dataset = UnifiedDataset( - desired_data=["nusc_mini", "lyft_sample"], + desired_data=["nusc_mini", "lyft_sample", "nuplan_mini"], rebuild_maps=True, data_dirs={ # Remember to change this to match your filesystem! "nusc_mini": "~/datasets/nuScenes", "lyft_sample": "~/datasets/lyft/scenes/sample.zarr", + "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", }, + verbose=True, ) print(f"Finished Caching Maps!") diff --git a/examples/scene_batch_example.py b/examples/scene_batch_example.py index 7320b00..9340e07 100644 --- a/examples/scene_batch_example.py +++ b/examples/scene_batch_example.py @@ -20,8 +20,12 @@ def main(): only_types=[AgentType.VEHICLE], agent_interaction_distances=defaultdict(lambda: 30.0), incl_robot_future=True, - incl_map=True, - map_params={"px_per_m": 2, "map_size_px": 224, "offset_frac_xy": (-0.5, 0.0)}, + incl_raster_map=True, + raster_map_params={ + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, augmentations=[noise_hists], max_agent_num=20, num_workers=4, diff --git a/examples/simple_map_api_example.py b/examples/simple_map_api_example.py new file mode 100644 index 0000000..57c9d93 --- /dev/null +++ b/examples/simple_map_api_example.py @@ -0,0 +1,75 @@ +import time +from pathlib import Path +from typing import Dict + +import matplotlib.pyplot as plt +import numpy as np + +from trajdata import MapAPI, VectorMap + + +def main(): + cache_path = Path("~/.unified_data_cache").expanduser() + map_api = MapAPI(cache_path) + + ### Loading random scene and initializing VectorMap. + env_name: str = np.random.choice(["nusc_mini", "lyft_sample", "nuplan_mini"]) + random_location_dict: Dict[str, str] = { + "nuplan_mini": np.random.choice( + ["boston", "singapore", "pittsburgh", "las_vegas"] + ), + "nusc_mini": np.random.choice(["boston-seaport", "singapore-onenorth"]), + "lyft_sample": "palo_alto", + } + + start = time.perf_counter() + vec_map: VectorMap = map_api.get_map(f"{env_name}:{random_location_dict[env_name]}") + end = time.perf_counter() + print(f"Map loading took {(end - start)*1000:.2f} ms") + + start = time.perf_counter() + vec_map: VectorMap = map_api.get_map(f"{env_name}:{random_location_dict[env_name]}") + end = time.perf_counter() + print(f"Repeated (cached in memory) map loading took {(end - start)*1000:.2f} ms") + + print(f"Randomly chose {vec_map.env_name}, {vec_map.map_name} map.") + + ### Lane Graph Visualization (with rasterized map in background) + fig, ax = plt.subplots() + + print(f"Rasterizing Map...") + start = time.perf_counter() + map_img, raster_from_world = vec_map.rasterize( + resolution=2, + return_tf_mat=True, + incl_centerlines=False, + area_color=(255, 255, 255), + edge_color=(0, 0, 0), + scene_ts=100, + ) + end = time.perf_counter() + print(f"Map rasterization took {(end - start)*1000:.2f} ms") + + ax.imshow(map_img, alpha=0.5, origin="lower") + + lane_idx = np.random.randint(0, len(vec_map.lanes)) + print(f"Visualizing random lane index {lane_idx}...") + start = time.perf_counter() + vec_map.visualize_lane_graph( + origin_lane=lane_idx, + num_hops=10, + raster_from_world=raster_from_world, + ax=ax, + ) + end = time.perf_counter() + print(f"Lane visualization took {(end - start)*1000:.2f} ms") + + ax.axis("equal") + ax.grid(None) + + plt.show() + plt.close("all") + + +if __name__ == "__main__": + main() diff --git a/examples/speed_example.py b/examples/speed_example.py new file mode 100644 index 0000000..4710020 --- /dev/null +++ b/examples/speed_example.py @@ -0,0 +1,54 @@ +import os +from collections import defaultdict + +from torch.utils.data import DataLoader +from tqdm import tqdm + +from trajdata import AgentBatch, AgentType, UnifiedDataset +from trajdata.augmentation import NoiseHistories + + +def main(): + noise_hists = NoiseHistories() + + dataset = UnifiedDataset( + desired_data=["nusc_mini-mini_train"], + centric="agent", + desired_dt=0.1, + history_sec=(3.2, 3.2), + future_sec=(4.8, 4.8), + only_predict=[AgentType.VEHICLE], + agent_interaction_distances=defaultdict(lambda: 30.0), + incl_robot_future=True, + incl_raster_map=True, + raster_map_params={ + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, + incl_vector_map=True, + augmentations=[noise_hists], + num_workers=0, + verbose=True, + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + }, + ) + + print(f"# Data Samples: {len(dataset):,}") + + dataloader = DataLoader( + dataset, + batch_size=64, + shuffle=True, + collate_fn=dataset.get_collate_fn(), + num_workers=os.cpu_count() // 2, + ) + + batch: AgentBatch + for batch in tqdm(dataloader): + pass + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt index 17d7c79..1fa112e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ tqdm matplotlib dill pandas +seaborn pyarrow torch zarr diff --git a/setup.cfg b/setup.cfg index 8ecbf21..6f31175 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = trajdata -version = 1.0.8 +version = 1.1.0 author = Boris Ivanovic author_email = bivanovic@nvidia.com description = A unified interface to many trajectory forecasting datasets. @@ -29,6 +29,7 @@ install_requires = torch>=1.10.2 zarr>=2.11.0 kornia>=0.6.4 + seaborn>=0.12 [options.packages.find] where = src diff --git a/src/trajdata/__init__.py b/src/trajdata/__init__.py index 2da9e8e..b639ce9 100644 --- a/src/trajdata/__init__.py +++ b/src/trajdata/__init__.py @@ -1,2 +1,3 @@ from .data_structures import AgentBatch, AgentType, SceneBatch from .dataset import UnifiedDataset +from .maps import MapAPI, VectorMap diff --git a/src/trajdata/caching/__init__.py b/src/trajdata/caching/__init__.py index 64ed55b..17e837a 100644 --- a/src/trajdata/caching/__init__.py +++ b/src/trajdata/caching/__init__.py @@ -1,3 +1,2 @@ -from .df_cache import DataFrameCache from .env_cache import EnvCache from .scene_cache import SceneCache diff --git a/src/trajdata/caching/df_cache.py b/src/trajdata/caching/df_cache.py index 8f9c9f9..1dd7da7 100644 --- a/src/trajdata/caching/df_cache.py +++ b/src/trajdata/caching/df_cache.py @@ -1,7 +1,19 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from trajdata.maps import ( + RasterizedMap, + RasterizedMapMetadata, + VectorMap, + ) + from trajdata.maps.map_kdtree import MapElementKDTree + import pickle from math import ceil, floor from pathlib import Path -from typing import Callable, Dict, Final, List, Optional, Tuple +from typing import Any, Dict, Final, List, Optional, Tuple import dill import kornia @@ -14,9 +26,8 @@ from trajdata.caching.scene_cache import SceneCache from trajdata.data_structures.agent import AgentMetadata, FixedExtent from trajdata.data_structures.scene_metadata import Scene -from trajdata.maps import RasterizedMap, RasterizedMapMetadata -from trajdata.proto.vectorized_map_pb2 import VectorizedMap -from trajdata.utils import arr_utils +from trajdata.maps.traffic_light_status import TrafficLightStatus +from trajdata.utils import arr_utils, df_utils, raster_utils STATE_COLS: Final[List[str]] = ["x", "y", "vx", "vy", "ax", "ay"] EXTENT_COLS: Final[List[str]] = ["length", "width", "height"] @@ -27,7 +38,6 @@ def __init__( self, cache_path: Path, scene: Scene, - scene_ts: Optional[int] = 0, augmentations: Optional[List[Augmentation]] = None, ) -> None: """ @@ -36,7 +46,7 @@ def __init__( and pickle for miscellaneous supporting objects. Maps are pre-rasterized and stored as Zarr arrays. """ - super().__init__(cache_path, scene, scene_ts, augmentations) + super().__init__(cache_path, scene, augmentations) agent_data_path: Path = self.scene_dir / DataFrameCache._agent_data_file( scene.dt @@ -53,6 +63,8 @@ def __init__( # Setting default data transformation parameters. self.reset_transforms() + self._kdtrees = None + if augmentations: dataset_augments: List[DatasetAugmentation] = [ augment @@ -143,7 +155,9 @@ def save_agent_data( cache_path: Path, scene: Scene, ) -> None: - scene_cache_dir: Path = cache_path / scene.env_name / scene.name + scene_cache_dir: Path = DataFrameCache.scene_cache_dir( + cache_path, scene.env_name, scene.name + ) scene_cache_dir.mkdir(parents=True, exist_ok=True) index_dict: Dict[Tuple[str, int], int] = { @@ -189,10 +203,17 @@ def get_value(self, agent_id: str, scene_ts: int, attribute: str) -> float: ) return transformed_pair[0, 0].item() + def get_raw_state(self, agent_id: str, scene_ts: int) -> np.ndarray: + return ( + self.scene_data_df.iloc[ + self.index_dict[(agent_id, scene_ts)], : self._state_dim + ] + .to_numpy() + .copy() + ) + def get_state(self, agent_id: str, scene_ts: int) -> np.ndarray: - state = self.scene_data_df.iloc[ - self.index_dict[(agent_id, scene_ts)], : self._state_dim - ].to_numpy() + state = self.get_raw_state(agent_id, scene_ts) return self._transform_state(state) @@ -223,7 +244,7 @@ def transform_data(self, **kwargs) -> None: if "sincos_heading" in kwargs: self._sincos_heading = True - self.obs_dim += 1 + self.obs_dim = self._state_dim + 1 def reset_transforms(self) -> None: self._transf_mean: Optional[np.ndarray] = None @@ -389,7 +410,7 @@ def get_agent_history( agent_extent_np: np.ndarray if isinstance(agent_info.extent, FixedExtent): agent_extent_np = agent_info.extent.get_extents( - self.scene_ts - agent_history_df.shape[0] + 1, self.scene_ts + scene_ts - agent_history_df.shape[0] + 1, scene_ts ) else: agent_extent_np = agent_history_df.iloc[:, self.extent_cols].to_numpy() @@ -432,7 +453,7 @@ def get_agent_future( agent_extent_np: np.ndarray if isinstance(agent_info.extent, FixedExtent): agent_extent_np: np.ndarray = agent_info.extent.get_extents( - self.scene_ts + 1, self.scene_ts + agent_future_df.shape[0] + scene_ts + 1, scene_ts + agent_future_df.shape[0] ) else: agent_extent_np = agent_future_df.iloc[:, self.extent_cols].to_numpy() @@ -451,7 +472,7 @@ def get_agents_history( history_sec: Tuple[Optional[float], Optional[float]], ) -> Tuple[List[np.ndarray], List[np.ndarray], np.ndarray]: first_timesteps = np.array( - [agent.first_timestep for agent in agents], dtype=np.long + [agent.first_timestep for agent in agents], dtype=int ) if history_sec[1] is not None: max_history: int = floor(history_sec[1] / self.dt) @@ -462,10 +483,10 @@ def get_agents_history( self.index_dict[(agent.name, first_timesteps[idx])] for idx, agent in enumerate(agents) ], - dtype=np.long, + dtype=int, ) last_index_incl: np.ndarray = np.array( - [self.index_dict[(agent.name, scene_ts)] for agent in agents], dtype=np.long + [self.index_dict[(agent.name, scene_ts)] for agent in agents], dtype=int ) concat_idxs = arr_utils.vrange(first_index_incl, last_index_incl + 1) @@ -491,8 +512,8 @@ def get_agents_history( else: neighbor_extents = [ agent.extent.get_extents( - self.scene_ts - neighbor_history_lens_np[idx].item() + 1, - self.scene_ts, + scene_ts - neighbor_history_lens_np[idx].item() + 1, + scene_ts, ) for idx, agent in enumerate(agents) ] @@ -509,9 +530,7 @@ def get_agents_future( agents: List[AgentMetadata], future_sec: Tuple[Optional[float], Optional[float]], ) -> Tuple[List[np.ndarray], List[np.ndarray], np.ndarray]: - last_timesteps = np.array( - [agent.last_timestep for agent in agents], dtype=np.long - ) + last_timesteps = np.array([agent.last_timestep for agent in agents], dtype=int) first_timesteps = np.minimum(scene_ts + 1, last_timesteps) @@ -524,14 +543,14 @@ def get_agents_future( self.index_dict[(agent.name, first_timesteps[idx])] for idx, agent in enumerate(agents) ], - dtype=np.long, + dtype=int, ) last_index_incl: np.ndarray = np.array( [ self.index_dict[(agent.name, last_timesteps[idx])] for idx, agent in enumerate(agents) ], - dtype=np.long, + dtype=int, ) concat_idxs = arr_utils.vrange(first_index_incl, last_index_incl + 1) @@ -557,8 +576,8 @@ def get_agents_future( else: neighbor_extents = [ agent.extent.get_extents( - self.scene_ts - neighbor_future_lens_np[idx].item() + 1, - self.scene_ts, + scene_ts - neighbor_future_lens_np[idx].item() + 1, + scene_ts, ) for idx, agent in enumerate(agents) ] @@ -569,6 +588,80 @@ def get_agents_future( neighbor_future_lens_np, ) + # TRAFFIC LIGHT INFO + @staticmethod + def _tls_data_file(scene_dt: float) -> str: + return f"tls_data_dt{scene_dt:.2f}.feather" + + @staticmethod + def save_traffic_light_data( + traffic_light_status_data: pd.DataFrame, + cache_path: Path, + scene: Scene, + dt: Optional[float] = None, + ) -> None: + """ + Assumes traffic_light_status_data is a MultiIndex dataframe with + lane_connector_id and scene_ts as the indices, and has a column "status" with integer + values for traffic status given by the TrafficLightStatus enum + """ + scene_cache_dir: Path = DataFrameCache.scene_cache_dir( + cache_path, scene.env_name, scene.name + ) + scene_cache_dir.mkdir(parents=True, exist_ok=True) + + if dt is None: + dt = scene.dt + + traffic_light_status_data.reset_index().to_feather( + scene_cache_dir / DataFrameCache._tls_data_file(dt) + ) + + def is_traffic_light_data_cached(self, desired_dt: Optional[float] = None) -> bool: + desired_dt = self.dt if desired_dt is None else desired_dt + tls_data_path: Path = self.scene_dir / DataFrameCache._tls_data_file(desired_dt) + return tls_data_path.exists() + + def get_traffic_light_status_dict( + self, desired_dt: Optional[float] = None + ) -> Dict[Tuple[int, int], TrafficLightStatus]: + """ + Returns dict mapping Lane Id, scene_ts to traffic light status for the + particular scene. If data doesn't exist for the current dt, interpolates and + saves the interpolated data to disk for loading later. + """ + desired_dt = self.dt if desired_dt is None else desired_dt + + tls_data_path: Path = self.scene_dir / DataFrameCache._tls_data_file(desired_dt) + if not tls_data_path.exists(): + # Load the original dt traffic light data + tls_orig_dt_df: pd.DataFrame = pd.read_feather( + self.scene_dir + / DataFrameCache._tls_data_file(self.scene.env_metadata.dt), + use_threads=False, + ).set_index(["lane_connector_id", "scene_ts"]) + + # Interpolate it to the desired dt. + tls_data_df = df_utils.interpolate_multi_index_df( + tls_orig_dt_df, self.scene.env_metadata.dt, desired_dt, method="nearest" + ) + + # Save it for the future + DataFrameCache.save_traffic_light_data( + tls_data_df, self.path, self.scene, desired_dt + ) + else: + # Load the data with the desired dt. + tls_data_df: pd.DataFrame = pd.read_feather( + tls_data_path, + use_threads=False, + ).set_index(["lane_connector_id", "scene_ts"]) + + # Return data as dict + return { + idx: TrafficLightStatus(v["status"]) for idx, v in tls_data_df.iterrows() + } + # MAPS @staticmethod def get_maps_path(cache_path: Path, env_name: str) -> Path: @@ -581,14 +674,21 @@ def are_maps_cached(cache_path: Path, env_name: str) -> bool: @staticmethod def get_map_paths( cache_path: Path, env_name: str, map_name: str, resolution: float - ) -> Tuple[Path, Path, Path, Path]: + ) -> Tuple[Path, Path, Path, Path, Path]: maps_path: Path = DataFrameCache.get_maps_path(cache_path, env_name) vector_map_path: Path = maps_path / f"{map_name}.pb" + kdtrees_path: Path = maps_path / f"{map_name}_kdtrees.dill" raster_map_path: Path = maps_path / f"{map_name}_{resolution:.2f}px_m.zarr" raster_metadata_path: Path = maps_path / f"{map_name}_{resolution:.2f}px_m.dill" - return maps_path, vector_map_path, raster_map_path, raster_metadata_path + return ( + maps_path, + vector_map_path, + kdtrees_path, + raster_map_path, + raster_metadata_path, + ) @staticmethod def is_map_cached( @@ -597,75 +697,60 @@ def is_map_cached( ( maps_path, vector_map_path, + kdtrees_path, raster_map_path, raster_metadata_path, ) = DataFrameCache.get_map_paths(cache_path, env_name, map_name, resolution) return ( maps_path.exists() and vector_map_path.exists() + and kdtrees_path.exists() and raster_metadata_path.exists() and raster_map_path.exists() ) @staticmethod - def cache_map( - cache_path: Path, vec_map: VectorizedMap, map_obj: RasterizedMap, env_name: str + def finalize_and_cache_map( + cache_path: Path, + vector_map: VectorMap, + map_params: Dict[str, Any], ) -> None: + raster_resolution: float = map_params["px_per_m"] + ( maps_path, vector_map_path, + kdtrees_path, raster_map_path, raster_metadata_path, ) = DataFrameCache.get_map_paths( - cache_path, env_name, map_obj.metadata.name, map_obj.metadata.resolution + cache_path, vector_map.env_name, vector_map.map_name, raster_resolution ) - # Ensuring the maps directory exists. - maps_path.mkdir(parents=True, exist_ok=True) - - # Saving the vectorized map data. - with open(vector_map_path, "wb") as f: - f.write(vec_map.SerializeToString()) - - # Saving the rasterized map data. - zarr.save(raster_map_path, map_obj.data) - - # Saving the rasterized map metadata. - with open(raster_metadata_path, "wb") as f: - dill.dump(map_obj.metadata, f) - - @staticmethod - def cache_map_layers( - cache_path: Path, - vec_map: VectorizedMap, - map_info: RasterizedMapMetadata, - layer_fn: Callable[[str], np.ndarray], - env_name: str, - ) -> None: - ( - maps_path, - vector_map_path, - raster_map_path, - raster_metadata_path, - ) = DataFrameCache.get_map_paths( - cache_path, env_name, map_info.name, map_info.resolution + pbar_kwargs = {"position": 2, "leave": False} + rasterized_map: RasterizedMap = raster_utils.rasterize_map( + vector_map, raster_resolution, **pbar_kwargs ) + vector_map.compute_search_indices() + # Ensuring the maps directory exists. maps_path.mkdir(parents=True, exist_ok=True) # Saving the vectorized map data. with open(vector_map_path, "wb") as f: - f.write(vec_map.SerializeToString()) + f.write(vector_map.to_proto().SerializeToString()) + + # Saving precomputed map element kdtrees. + with open(kdtrees_path, "wb") as f: + dill.dump(vector_map.search_kdtrees, f) # Saving the rasterized map data. - disk_data = zarr.open_array(raster_map_path, mode="w", shape=map_info.shape) - for idx, layer_name in enumerate(map_info.layers): - disk_data[idx] = layer_fn(layer_name) + zarr.save(raster_map_path, rasterized_map.data) # Saving the rasterized map metadata. with open(raster_metadata_path, "wb") as f: - dill.dump(map_info, f) + dill.dump(rasterized_map.metadata, f) def pad_map_patch( self, @@ -698,6 +783,33 @@ def pad_map_patch( return np.pad(patch, [(0, 0), (pad_top, pad_bot), (pad_left, pad_right)]) + def load_kdtrees(self) -> Dict[str, MapElementKDTree]: + _, _, kdtrees_path, _, _ = DataFrameCache.get_map_paths( + self.path, self.scene.env_name, self.scene.location, 0.0 + ) + + with open(kdtrees_path, "rb") as f: + kdtrees: Dict[str, MapElementKDTree] = dill.load(f) + + return kdtrees + + def get_kdtrees(self, load_only_once: bool = True): + """Loads and returns the kdtrees dictionary from the cache file. + + Args: + load_only_once (bool): store the kdtree dictionary in self so that we + dont have to load it from the cache file more than once. + """ + if self._kdtrees is None: + kdtrees = self.load_kdtrees() + if load_only_once: + self._kdtrees = kdtrees + + return kdtrees + + else: + return self._kdtrees + def load_map_patch( self, world_x: float, @@ -713,6 +825,7 @@ def load_map_patch( ( maps_path, _, + _, raster_map_path, raster_metadata_path, ) = DataFrameCache.get_map_paths( diff --git a/src/trajdata/caching/scene_cache.py b/src/trajdata/caching/scene_cache.py index 8f8b1fc..7594d80 100644 --- a/src/trajdata/caching/scene_cache.py +++ b/src/trajdata/caching/scene_cache.py @@ -1,13 +1,19 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from trajdata.maps import TrafficLightStatus, VectorMap + + from pathlib import Path -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import numpy as np from trajdata.augmentation.augmentation import Augmentation from trajdata.data_structures.agent import AgentMetadata from trajdata.data_structures.scene_metadata import Scene -from trajdata.maps import RasterizedMap, RasterizedMapMetadata -from trajdata.proto.vectorized_map_pb2 import VectorizedMap class SceneCache: @@ -15,7 +21,6 @@ def __init__( self, cache_path: Path, scene: Scene, - scene_ts: Optional[int] = 0, augmentations: Optional[List[Augmentation]] = None, ) -> None: """ @@ -24,13 +29,19 @@ def __init__( self.path = cache_path self.scene = scene self.dt = scene.dt - self.scene_ts = scene_ts self.augmentations = augmentations # Ensuring the scene cache folder exists - self.scene_dir: Path = self.path / self.scene.env_name / self.scene.name + self.scene_dir: Path = SceneCache.scene_cache_dir( + self.path, self.scene.env_name, self.scene.name + ) self.scene_dir.mkdir(parents=True, exist_ok=True) + @staticmethod + def scene_cache_dir(cache_path: Path, env_name: str, scene_name: str) -> Path: + """Standardized convention to compute scene cache folder path""" + return cache_path / env_name / scene_name + def write_cache_to_disk(self) -> None: """Saves agent data to disk for fast loading later (just like save_agent_data), but using the class attributes for the sources of data and file paths. @@ -53,6 +64,12 @@ def get_value(self, agent_id: str, scene_ts: int, attribute: str) -> float: """ raise NotImplementedError() + def get_raw_state(self, agent_id: str, scene_ts: int) -> np.ndarray: + """ + Get an agent's raw state (without transformations applied) + """ + raise NotImplementedError() + def get_state(self, agent_id: str, scene_ts: int) -> np.ndarray: """ Get an agent's state at a specific timestep. @@ -120,6 +137,24 @@ def get_agents_future( ) -> Tuple[List[np.ndarray], List[np.ndarray], np.ndarray]: raise NotImplementedError() + # TRAFFIC LIGHT INFO + @staticmethod + def save_traffic_light_data( + traffic_light_status_data: Any, cache_path: Path, scene: Scene + ) -> None: + """Saves traffic light status to disk for easy access later""" + raise NotImplementedError() + + def is_traffic_light_data_cached(self, desired_dt: Optional[float] = None) -> bool: + raise NotImplementedError() + + def get_traffic_light_status_dict( + self, + ) -> Dict[Tuple[int, int], TrafficLightStatus]: + """Returns lookup table for traffic light status in the current scene + lane_id, scene_ts -> TrafficLightStatus""" + raise NotImplementedError() + # MAPS @staticmethod def are_maps_cached(cache_path: Path, env_name: str) -> bool: @@ -132,18 +167,10 @@ def is_map_cached( raise NotImplementedError() @staticmethod - def cache_map( - cache_path: Path, vec_map: VectorizedMap, map_obj: RasterizedMap, env_name: str - ) -> None: - raise NotImplementedError() - - @staticmethod - def cache_map_layers( + def finalize_and_cache_map( cache_path: Path, - vec_map: VectorizedMap, - map_info: RasterizedMapMetadata, - layer_fn: Callable[[str], np.ndarray], - env_name: str, + vector_map: VectorMap, + map_params: Dict[str, Any], ) -> None: raise NotImplementedError() diff --git a/src/trajdata/data_structures/batch.py b/src/trajdata/data_structures/batch.py index 0844ef1..26de656 100644 --- a/src/trajdata/data_structures/batch.py +++ b/src/trajdata/data_structures/batch.py @@ -7,12 +7,14 @@ from torch import Tensor from trajdata.data_structures.agent import AgentType +from trajdata.maps import VectorMap from trajdata.utils.arr_utils import PadDirection @dataclass class AgentBatch: data_idx: Tensor + scene_ts: Tensor dt: Tensor agent_name: List[str] agent_type: Tensor @@ -33,8 +35,10 @@ class AgentBatch: neigh_fut_len: Tensor robot_fut: Optional[Tensor] robot_fut_len: Optional[Tensor] + map_names: Optional[List[str]] maps: Optional[Tensor] maps_resolution: Optional[Tensor] + vector_maps: Optional[List[VectorMap]] rasters_from_world_tf: Optional[Tensor] agents_from_world_tf: Tensor scene_ids: Optional[List] @@ -53,6 +57,8 @@ def to(self, device) -> None: "neigh_types", "num_neigh", "robot_fut_len", + "map_names", + "vector_maps", "scene_ids", "history_pad_dir", "extras", @@ -64,7 +70,11 @@ def to(self, device) -> None: setattr(self, val, tensor_val.to(device, non_blocking=True)) for key, val in self.extras.items(): - self.extras[key] = val.to(device, non_blocking=True) + # Allow for custom .to() method for objects that define a __to__ function. + if hasattr(val, "__to__"): + self.extras[key] = val.__to__(device, non_blocking=True) + else: + self.extras[key] = val.to(device, non_blocking=True) def agent_types(self) -> List[AgentType]: unique_types: Tensor = torch.unique(self.agent_type) @@ -74,6 +84,7 @@ def for_agent_type(self, agent_type: AgentType) -> AgentBatch: match_type = self.agent_type == agent_type return AgentBatch( data_idx=self.data_idx[match_type], + scene_ts=self.scene_ts[match_type], dt=self.dt[match_type], agent_name=[ name for idx, name in enumerate(self.agent_name) if match_type[idx] @@ -100,10 +111,22 @@ def for_agent_type(self, agent_type: AgentType) -> AgentBatch: robot_fut_len=self.robot_fut_len[match_type] if self.robot_fut_len is not None else None, + map_names=[ + name for idx, name in enumerate(self.map_names) if match_type[idx] + ] + if self.map_names is not None + else None, maps=self.maps[match_type] if self.maps is not None else None, maps_resolution=self.maps_resolution[match_type] if self.maps_resolution is not None else None, + vector_maps=[ + vector_map + for idx, vector_map in enumerate(self.vector_maps) + if match_type[idx] + ] + if self.vector_maps is not None + else None, rasters_from_world_tf=self.rasters_from_world_tf[match_type] if self.rasters_from_world_tf is not None else None, @@ -121,10 +144,12 @@ def for_agent_type(self, agent_type: AgentType) -> AgentBatch: @dataclass class SceneBatch: data_idx: Tensor + scene_ts: Tensor dt: Tensor num_agents: Tensor agent_type: Tensor centered_agent_state: Tensor + agent_names: List[str] agent_hist: Tensor agent_hist_extent: Tensor agent_hist_len: Tensor @@ -133,8 +158,10 @@ class SceneBatch: agent_fut_len: Tensor robot_fut: Optional[Tensor] robot_fut_len: Optional[Tensor] + map_names: Optional[Tensor] maps: Optional[Tensor] maps_resolution: Optional[Tensor] + vector_maps: Optional[List[VectorMap]] rasters_from_world_tf: Optional[Tensor] centered_agent_from_world_tf: Tensor centered_world_from_agent_tf: Tensor @@ -144,6 +171,9 @@ class SceneBatch: def to(self, device) -> None: excl_vals = { + "agent_names", + "map_names", + "vector_maps", "history_pad_dir", "extras", } @@ -164,10 +194,16 @@ def for_agent_type(self, agent_type: AgentType) -> SceneBatch: match_type = self.agent_type == agent_type return SceneBatch( data_idx=self.data_idx[match_type], + scene_ts=self.scene_ts[match_type], dt=self.dt[match_type], num_agents=self.num_agents[match_type], agent_type=self.agent_type[match_type], centered_agent_state=self.centered_agent_state[match_type], + agent_names=[ + agent_name + for idx, agent_name in enumerate(self.agent_names) + if match_type[idx] + ], agent_hist=self.agent_hist[match_type], agent_hist_extent=self.agent_hist_extent[match_type], agent_hist_len=self.agent_hist_len[match_type], @@ -180,10 +216,22 @@ def for_agent_type(self, agent_type: AgentType) -> SceneBatch: robot_fut_len=self.robot_fut_len[match_type] if self.robot_fut_len is not None else None, + map_names=[ + name for idx, name in enumerate(self.map_names) if match_type[idx] + ] + if self.map_names is not None + else None, maps=self.maps[match_type] if self.maps is not None else None, maps_resolution=self.maps_resolution[match_type] if self.maps_resolution is not None else None, + vector_maps=[ + vector_map + for idx, vector_map in enumerate(self.vector_maps) + if match_type[idx] + ] + if self.vector_maps is not None + else None, rasters_from_world_tf=self.rasters_from_world_tf[match_type] if self.rasters_from_world_tf is not None else None, diff --git a/src/trajdata/data_structures/batch_element.py b/src/trajdata/data_structures/batch_element.py index 2b177f9..d09f2a8 100644 --- a/src/trajdata/data_structures/batch_element.py +++ b/src/trajdata/data_structures/batch_element.py @@ -7,7 +7,7 @@ from trajdata.caching import SceneCache from trajdata.data_structures.agent import AgentMetadata, AgentType from trajdata.data_structures.scene import SceneTime, SceneTimeAgent -from trajdata.maps import RasterizedMapPatch +from trajdata.maps import MapAPI, RasterizedMapPatch, VectorMap class AgentBatchElement: @@ -25,8 +25,10 @@ def __init__( Tuple[AgentType, AgentType], float ] = defaultdict(lambda: np.inf), incl_robot_future: bool = False, - incl_map: bool = False, - map_params: Optional[Dict[str, Any]] = None, + incl_raster_map: bool = False, + raster_map_params: Optional[Dict[str, Any]] = None, + map_api: Optional[MapAPI] = None, + vector_map_params: Optional[Dict[str, Any]] = None, standardize_data: bool = False, standardize_derivatives: bool = False, max_neighbor_num: Optional[int] = None, @@ -43,7 +45,7 @@ def __init__( self.agent_type: AgentType = agent_info.type self.max_neighbor_num = max_neighbor_num - self.curr_agent_state_np: np.ndarray = cache.get_state( + self.curr_agent_state_np: np.ndarray = cache.get_raw_state( agent_info.name, self.scene_ts ) @@ -62,7 +64,7 @@ def __init__( ) self.agent_from_world_tf: np.ndarray = np.linalg.inv(world_from_agent_tf) - offset = self.curr_agent_state_np + offset = self.curr_agent_state_np.copy() if not standardize_derivatives: offset[2:6] = 0.0 @@ -129,9 +131,23 @@ def distance_limit(agent_types: np.ndarray, target_type: int) -> np.ndarray: self.robot_future_len: int = self.robot_future_np.shape[0] - 1 ### MAP ### + self.map_name: Optional[str] = None self.map_patch: Optional[RasterizedMapPatch] = None - if incl_map: - self.map_patch = self.get_agent_map_patch(map_params) + + map_name: str = ( + f"{scene_time_agent.scene.env_name}:{scene_time_agent.scene.location}" + ) + if incl_raster_map: + self.map_name = map_name + self.map_patch = self.get_agent_map_patch(raster_map_params) + + self.vec_map: Optional[VectorMap] = None + if map_api is not None: + self.vec_map = map_api.get_map( + map_name, + self.cache if self.cache.is_traffic_light_data_cached() else None, + **vector_map_params if vector_map_params is not None else None, + ) self.scene_id = scene_time_agent.scene.name @@ -315,8 +331,10 @@ def __init__( Tuple[AgentType, AgentType], float ] = defaultdict(lambda: np.inf), incl_robot_future: bool = False, - incl_map: bool = False, - map_params: Optional[Dict[str, Any]] = None, + incl_raster_map: bool = False, + raster_map_params: Optional[Dict[str, Any]] = None, + map_api: Optional[MapAPI] = None, + vector_map_params: Optional[Dict[str, Any]] = None, standardize_data: bool = False, standardize_derivatives: bool = False, max_agent_num: Optional[int] = None, @@ -385,6 +403,7 @@ def distance_limit(agent_types: np.ndarray, target_type: int) -> np.ndarray: ) self.num_agents = len(nearby_agents) + self.agent_names = [agent.name for agent in nearby_agents] ( self.agent_histories, self.agent_history_extents, @@ -397,12 +416,26 @@ def distance_limit(agent_types: np.ndarray, target_type: int) -> np.ndarray: ) = self.get_agents_future(future_sec, nearby_agents) ### MAP ### + self.map_name: Optional[str] = None self.map_patches: Optional[RasterizedMapPatch] = None - if incl_map: + + map_name: str = f"{scene_time.scene.env_name}:{scene_time.scene.location}" + if incl_raster_map: + self.map_name = map_name self.map_patches = self.get_agents_map_patch( - map_params, self.agent_histories + raster_map_params, self.agent_histories ) + + self.vec_map: Optional[VectorMap] = None + if map_api is not None: + self.vec_map = map_api.get_map( + map_name, + self.cache if self.cache.is_traffic_light_data_cached() else None, + **vector_map_params if vector_map_params is not None else None, + ) + self.scene_id = scene_time.scene.name + ### ROBOT DATA ### self.robot_future_np: Optional[np.ndarray] = None diff --git a/src/trajdata/data_structures/collation.py b/src/trajdata/data_structures/collation.py index dfea7a1..ee12452 100644 --- a/src/trajdata/data_structures/collation.py +++ b/src/trajdata/data_structures/collation.py @@ -1,5 +1,4 @@ from dataclasses import asdict -from enum import IntEnum from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -12,26 +11,45 @@ from trajdata.augmentation import BatchAugmentation from trajdata.data_structures.batch import AgentBatch, SceneBatch from trajdata.data_structures.batch_element import AgentBatchElement, SceneBatchElement +from trajdata.maps import VectorMap from trajdata.utils import arr_utils -def map_collate_fn_agent( +class CustomCollateData: + @staticmethod + def __collate__(elements: list) -> any: + raise NotImplementedError + + def __to__(self, device, non_blocking=False): + raise NotImplementedError + + +def _collate_data(elems): + if hasattr(elems[0], "__collate__"): + return elems[0].__collate__(elems) + else: + return torch.as_tensor(np.stack(elems)) + + +def raster_map_collate_fn_agent( batch_elems: List[AgentBatchElement], ): if batch_elems[0].map_patch is None: - return None, None, None + return None, None, None, None + + map_names = [batch_elem.map_name for batch_elem in batch_elems] # Ensuring that any empty map patches have the correct number of channels # prior to collation. has_data: np.ndarray = np.array( [batch_elem.map_patch.has_data for batch_elem in batch_elems], - dtype=np.bool, + dtype=bool, ) no_data: np.ndarray = ~has_data patch_channels: np.ndarray = np.array( [batch_elem.map_patch.data.shape[0] for batch_elem in batch_elems], - dtype=np.int, + dtype=int, ) desired_num_channels: int @@ -148,26 +166,28 @@ def map_collate_fn_agent( ) return ( + map_names, rot_crop_patches, resolution, rasters_from_world_tf, ) -def map_collate_fn_scene( +def raster_map_collate_fn_scene( batch_elems: List[SceneBatchElement], max_agent_num: Optional[int] = None, pad_value: Any = np.nan, ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: if batch_elems[0].map_patches is None: - return None, None, None + return None, None, None, None patch_size: int = batch_elems[0].map_patches[0].crop_size assert all( batch_elem.map_patches[0].crop_size == patch_size for batch_elem in batch_elems ) + map_names: List[str] = list() num_agents: List[int] = list() agents_rasters_from_world_tfs: List[np.ndarray] = list() agents_patches: List[np.ndarray] = list() @@ -175,6 +195,7 @@ def map_collate_fn_scene( agents_res_list: List[float] = list() for elem in batch_elems: + map_names.append(elem.map_name) num_agents.append(min(elem.num_agents, max_agent_num)) agents_rasters_from_world_tfs += [ x.raster_from_world_tf for x in elem.map_patches[:max_agent_num] @@ -253,7 +274,7 @@ def map_collate_fn_scene( agents_resolution, num_agents, pad_value=0, desired_size=max_agent_num ) - return rot_crop_patches, agents_resolution, agents_rasters_from_world_tf + return map_names, rot_crop_patches, agents_resolution, agents_rasters_from_world_tf def agent_collate_fn( @@ -270,6 +291,7 @@ def agent_collate_fn( ) data_index_t: Tensor = torch.zeros((batch_size,), dtype=torch.int) + scene_ts_t: Tensor = torch.zeros((batch_size,), dtype=torch.int) dt_t: Tensor = torch.zeros((batch_size,), dtype=torch.float) agent_type_t: Tensor = torch.zeros((batch_size,), dtype=torch.int) agent_names: List[str] = list() @@ -329,6 +351,7 @@ def agent_collate_fn( elem: AgentBatchElement for idx, elem in enumerate(batch_elems): data_index_t[idx] = elem.data_index + scene_ts_t[idx] = elem.scene_ts dt_t[idx] = elem.dt agent_names.append(elem.agent_name) agent_type_t[idx] = elem.agent_type.value @@ -615,10 +638,15 @@ def agent_collate_fn( ) ( + map_names, map_patches, maps_resolution, rasters_from_world_tf, - ) = map_collate_fn_agent(batch_elems) + ) = raster_map_collate_fn_agent(batch_elems) + + vector_maps: Optional[List[VectorMap]] = None + if batch_elems[0].vec_map is not None: + vector_maps = [batch_elem.vec_map for batch_elem in batch_elems] agents_from_world_tf = torch.as_tensor( np.stack([batch_elem.agent_from_world_tf for batch_elem in batch_elems]), @@ -629,12 +657,13 @@ def agent_collate_fn( extras: Dict[str, Tensor] = {} for key in batch_elems[0].extras.keys(): - extras[key] = torch.as_tensor( - np.stack([batch_elem.extras[key] for batch_elem in batch_elems]) + extras[key] = _collate_data( + [batch_elem.extras[key] for batch_elem in batch_elems] ) batch = AgentBatch( data_idx=data_index_t, + scene_ts=scene_ts_t, dt=dt_t, agent_name=agent_names, agent_type=agent_type_t, @@ -655,8 +684,10 @@ def agent_collate_fn( neigh_fut_len=neighbor_future_lens_t, robot_fut=robot_future_t, robot_fut_len=robot_future_len, + map_names=map_names, maps=map_patches, maps_resolution=maps_resolution, + vector_maps=vector_maps, rasters_from_world_tf=rasters_from_world_tf, agents_from_world_tf=agents_from_world_tf, scene_ids=scene_ids, @@ -733,6 +764,7 @@ def scene_collate_fn( ) data_index_t: Tensor = torch.zeros((batch_size,), dtype=torch.int) + scene_ts_t: Tensor = torch.zeros((batch_size,), dtype=torch.int) dt_t: Tensor = torch.zeros((batch_size,), dtype=torch.float) max_agent_num: int = max(elem.num_agents for elem in batch_elems) @@ -762,6 +794,7 @@ def scene_collate_fn( for idx, elem in enumerate(batch_elems): data_index_t[idx] = elem.data_index + scene_ts_t[idx] = elem.scene_ts dt_t[idx] = elem.dt centered_agent_state.append(elem.centered_agent_state_np) agents_types.append(elem.agent_types_np) @@ -871,9 +904,17 @@ def scene_collate_fn( agents_types_t, num_agents, pad_value=-1, desired_size=max_agent_num ) - map_patches, maps_resolution, rasters_from_world_tf = map_collate_fn_scene( - batch_elems, max_agent_num - ) + ( + map_names, + map_patches, + maps_resolution, + rasters_from_world_tf, + ) = raster_map_collate_fn_scene(batch_elems, max_agent_num) + + vector_maps: Optional[List[VectorMap]] = None + if batch_elems[0].vec_map is not None: + vector_maps = [batch_elem.vec_map for batch_elem in batch_elems] + centered_agent_from_world_tf = torch.as_tensor( np.stack( [batch_elem.centered_agent_from_world_tf for batch_elem in batch_elems] @@ -893,6 +934,8 @@ def scene_collate_fn( else None ) + agent_names = [batch_elem.agent_names for batch_elem in batch_elems] + scene_ids = [batch_elem.scene_id for batch_elem in batch_elems] extras: Dict[str, Tensor] = {} @@ -903,10 +946,12 @@ def scene_collate_fn( batch = SceneBatch( data_idx=data_index_t, + scene_ts=scene_ts_t, dt=dt_t, num_agents=num_agents_t, agent_type=agents_types_t, centered_agent_state=centered_agent_state_t, + agent_names=agent_names, agent_hist=agents_histories_t, agent_hist_extent=agents_history_extents_t, agent_hist_len=agents_history_len, @@ -915,8 +960,10 @@ def scene_collate_fn( agent_fut_len=agents_future_len, robot_fut=robot_future_t, robot_fut_len=robot_future_len, + map_names=map_names, maps=map_patches, maps_resolution=maps_resolution, + vector_maps=vector_maps, rasters_from_world_tf=rasters_from_world_tf, centered_agent_from_world_tf=centered_agent_from_world_tf, centered_world_from_agent_tf=centered_world_from_agent_tf, diff --git a/src/trajdata/data_structures/environment.py b/src/trajdata/data_structures/environment.py index e167865..a33ef13 100644 --- a/src/trajdata/data_structures/environment.py +++ b/src/trajdata/data_structures/environment.py @@ -1,6 +1,6 @@ import itertools from pathlib import Path -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple from trajdata.data_structures.scene_tag import SceneTag @@ -13,10 +13,12 @@ def __init__( dt: float, parts: List[Tuple[str]], scene_split_map: Dict[str, str], + map_locations: Optional[Tuple[str]] = None, ) -> None: self.name = name self.data_dir = Path(data_dir).expanduser().resolve() self.dt = dt + self.map_locations = map_locations self.parts = parts self.scene_tags: List[SceneTag] = [ SceneTag(tag_tuple) diff --git a/src/trajdata/data_structures/map.py b/src/trajdata/data_structures/map.py deleted file mode 100644 index 7ab1172..0000000 --- a/src/trajdata/data_structures/map.py +++ /dev/null @@ -1,55 +0,0 @@ -from typing import List, Optional, Tuple - -import numpy as np -import torch -from torch import Tensor - - -class MapMetadata: - def __init__( - self, - name: str, - shape: Tuple[int, int], - layers: List[str], - layer_rgb_groups: Tuple[List[int], List[int], List[int]], - resolution: float, # px/m - map_from_world: np.ndarray, # Transformation from world coordinates [m] to map coordinates [px] - ) -> None: - self.name: str = name - self.shape: Tuple[int, int] = shape - self.layers: List[str] = layers - self.layer_rgb_groups: Tuple[List[int], List[int], List[int]] = layer_rgb_groups - self.resolution: float = resolution - self.map_from_world: np.ndarray = map_from_world - - -class Map: - def __init__( - self, - metadata: MapMetadata, - data: np.ndarray, - ) -> None: - assert data.shape == metadata.shape - self.metadata: MapMetadata = metadata - self.data: np.ndarray = data - - @property - def shape(self) -> Tuple[int, ...]: - return self.data.shape - - @staticmethod - def to_img( - map_arr: Tensor, - idx_groups: Optional[Tuple[List[int], List[int], List[int]]] = None, - ) -> Tensor: - if idx_groups is None: - return map_arr.permute(1, 2, 0).numpy() - - return torch.stack( - [ - torch.amax(map_arr[idx_groups[0]], dim=0), - torch.amax(map_arr[idx_groups[1]], dim=0), - torch.amax(map_arr[idx_groups[2]], dim=0), - ], - dim=-1, - ).numpy() diff --git a/src/trajdata/data_structures/map_patch.py b/src/trajdata/data_structures/map_patch.py deleted file mode 100644 index 371ae1c..0000000 --- a/src/trajdata/data_structures/map_patch.py +++ /dev/null @@ -1,17 +0,0 @@ -import numpy as np - - -class MapPatch: - def __init__( - self, - data: np.ndarray, - rot_angle: float, - crop_size: int, - resolution: float, - raster_from_world_tf: np.ndarray, - ) -> None: - self.data = data - self.rot_angle = rot_angle - self.crop_size = crop_size - self.resolution = resolution - self.raster_from_world_tf = raster_from_world_tf diff --git a/src/trajdata/dataset.py b/src/trajdata/dataset.py index 5979b0b..7e30eb7 100644 --- a/src/trajdata/dataset.py +++ b/src/trajdata/dataset.py @@ -1,17 +1,32 @@ import gc +import time from collections import defaultdict from functools import partial from itertools import chain +from os.path import isfile from pathlib import Path -from typing import Any, Callable, Dict, Final, List, Optional, Set, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Final, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) +import dill import numpy as np +from torch import distributed from torch.utils.data import DataLoader, Dataset from tqdm import tqdm from trajdata import filtering from trajdata.augmentation.augmentation import Augmentation, BatchAugmentation -from trajdata.caching import DataFrameCache, EnvCache, SceneCache +from trajdata.caching import EnvCache, SceneCache, df_cache from trajdata.data_structures import ( AgentBatchElement, AgentDataIndex, @@ -29,12 +44,11 @@ scene_collate_fn, ) from trajdata.dataset_specific import RawDataset -from trajdata.parallel import ( - ParallelDatasetPreprocessor, - parallel_iapply, - scene_paths_collate_fn, -) +from trajdata.maps import VectorMap +from trajdata.maps.map_api import MapAPI +from trajdata.parallel import ParallelDatasetPreprocessor, scene_paths_collate_fn from trajdata.utils import agent_utils, env_utils, scene_utils, string_utils +from trajdata.utils.parallel_utils import parallel_iapply # TODO(bivanovic): Move this to a better place in the codebase. DEFAULT_PX_PER_M: Final[float] = 2.0 @@ -60,8 +74,11 @@ def __init__( Tuple[AgentType, AgentType], float ] = defaultdict(lambda: np.inf), incl_robot_future: bool = False, - incl_map: bool = False, - map_params: Optional[Dict[str, Any]] = None, + incl_raster_map: bool = False, + raster_map_params: Optional[Dict[str, Any]] = None, + incl_vector_map: bool = False, + vector_map_params: Optional[Dict[str, Any]] = None, + require_map_cache: bool = True, only_types: Optional[List[AgentType]] = None, only_predict: Optional[List[AgentType]] = None, no_types: Optional[List[AgentType]] = None, @@ -72,18 +89,19 @@ def __init__( max_neighbor_num: Optional[int] = None, ego_only: Optional[bool] = False, data_dirs: Dict[str, str] = { - # "nusc_trainval": "~/datasets/nuScenes", - # "nusc_test": "~/datasets/nuScenes", "eupeds_eth": "~/datasets/eth_ucy_peds", "eupeds_hotel": "~/datasets/eth_ucy_peds", "eupeds_univ": "~/datasets/eth_ucy_peds", "eupeds_zara1": "~/datasets/eth_ucy_peds", "eupeds_zara2": "~/datasets/eth_ucy_peds", "nusc_mini": "~/datasets/nuScenes", + # "nusc_trainval": "~/datasets/nuScenes", + # "nusc_test": "~/datasets/nuScenes", "lyft_sample": "~/datasets/lyft/scenes/sample.zarr", # "lyft_train": "~/datasets/lyft/scenes/train.zarr", # "lyft_train_full": "~/datasets/lyft/scenes/train_full.zarr", # "lyft_val": "~/datasets/lyft/scenes/validate.zarr", + # "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", }, cache_type: str = "dataframe", cache_location: str = "~/.unified_data_cache", @@ -92,6 +110,10 @@ def __init__( num_workers: int = 0, verbose: bool = False, extras: Dict[str, Callable[..., np.ndarray]] = dict(), + transforms: Iterable[ + Callable[..., Union[AgentBatchElement, SceneBatchElement]] + ] = (), + rank: int = 0, ) -> None: """Instantiates a PyTorch Dataset object which aggregates data from multiple trajectory forecasting datasets. @@ -105,8 +127,11 @@ def __init__( future_sec (Tuple[Optional[float], Optional[float]], optional): A tuple containing (the minimum seconds of future data each batch element must contain, the maximum seconds of future data to return). Both inclusive. Defaults to ( None, None, ). agent_interaction_distances: (Dict[Tuple[AgentType, AgentType], float]): A dictionary mapping agent-agent interaction distances in meters (determines which agents are included as neighbors to the predicted agent). Defaults to infinity for all types. incl_robot_future (bool, optional): Include the ego agent's future trajectory in batches (accordingly, never predict the ego's future). Defaults to False. - incl_map (bool, optional): Include a local cropping of the rasterized map (if the dataset provides a map) per agent. Defaults to False. - map_params (Optional[Dict[str, Any]], optional): Local map cropping parameters, must be specified if incl_map is True. Must contain keys {"px_per_m", "map_size_px"} and can optionally contain {"offset_frac_xy"}. Defaults to None. + incl_raster_map (bool, optional): Include a local cropping of the rasterized map (if the dataset provides a map) per agent. Defaults to False. + raster_map_params (Optional[Dict[str, Any]], optional): Local map cropping parameters, must be specified if incl_map is True. Must contain keys {"px_per_m", "map_size_px"} and can optionally contain {"offset_frac_xy"}. Defaults to None. + incl_vector_map (bool, optional): Include information about the scene's vector map (e.g., for use in nearest lane queries as an `extras` batch element function), + vector_map_params (Optional[Dict[str, Any]], optional): Vector map loading parameters. Defaults to None (by default only road lanes will be loaded as part of the map). + require_map_cache (bool, optional): Cache map objects (if the dataset provides a map) regardless of the value of incl_map. Defaults to True. only_types (Optional[List[AgentType]], optional): Filter out all agents EXCEPT for those of the specified types. Defaults to None. only_predict (Optional[List[AgentType]], optional): Only predict the specified types of agents. Importantly, this keeps other agent types in the scene, e.g., as neighbors of the agent to be predicted. Defaults to None. no_types (Optional[List[AgentType]], optional): Filter out all agents with the specified types. Defaults to None. @@ -124,33 +149,53 @@ def __init__( num_workers (int, optional): Number of parallel workers to use for dataset preprocessing and loading. Defaults to 0. verbose (bool, optional): If True, print internal data loading information. Defaults to False. extras (Dict[str, Callable[..., np.ndarray]], optional): Adds extra data to each batch element. Each Callable must take as input a filled {Agent,Scene}BatchElement and return an ndarray which will subsequently be added to the batch element's `extra` dict. + transforms (Iterable[Callable], optional): Allows for custom modifications of batch elements. Each Callable must take in a filled {Agent,Scene}BatchElement and return a {Agent,Scene}BatchElement. + rank (int, optional): Proccess rank when using torch DistributedDataParallel for multi-GPU training. Only the rank 0 process will be used for caching. """ self.centric: str = centric self.desired_dt: float = desired_dt if cache_type == "dataframe": - self.cache_class = DataFrameCache + self.cache_class = df_cache.DataFrameCache self.rebuild_cache: bool = rebuild_cache self.cache_path: Path = Path(cache_location).expanduser().resolve() self.cache_path.mkdir(parents=True, exist_ok=True) self.env_cache: EnvCache = EnvCache(self.cache_path) - if incl_map: + if incl_raster_map: assert ( - map_params is not None + raster_map_params is not None ), r"Path size information, i.e., {'px_per_m': ..., 'map_size_px': ...}, must be provided if incl_map=True" assert ( - map_params["map_size_px"] % 2 == 0 + raster_map_params["map_size_px"] % 2 == 0 ), "Patch parameter 'map_size_px' must be divisible by 2" + require_map_cache = require_map_cache or incl_raster_map + self.history_sec = history_sec self.future_sec = future_sec self.agent_interaction_distances = agent_interaction_distances self.incl_robot_future = incl_robot_future - self.incl_map = incl_map - self.map_params = ( - map_params if map_params is not None else {"px_per_m": DEFAULT_PX_PER_M} + self.incl_raster_map = incl_raster_map + self.raster_map_params = ( + raster_map_params + if raster_map_params is not None + else {"px_per_m": DEFAULT_PX_PER_M} + ) + self.incl_vector_map = incl_vector_map + self.vector_map_params = ( + vector_map_params + if vector_map_params is not None + else { + "incl_road_lanes": True, + "incl_road_areas": False, + "incl_ped_crosswalks": False, + "incl_ped_walkways": False, + # Collation can be quite slow if vector maps are included, + # so we do not unless the user requests it. + "no_collate": True, + } ) self.only_types = None if only_types is None else set(only_types) self.only_predict = None if only_predict is None else set(only_predict) @@ -159,6 +204,7 @@ def __init__( self.standardize_derivatives = standardize_derivatives self.augmentations = augmentations self.extras = extras + self.transforms = transforms self.verbose = verbose self.max_agent_num = max_agent_num self.max_neighbor_num = max_neighbor_num @@ -179,12 +225,15 @@ def __init__( flush=True, ) + self._map_api: Optional[MapAPI] = None + if self.incl_vector_map: + self._map_api = MapAPI(self.cache_path) + all_scenes_list: Union[List[SceneMetadata], List[Scene]] = list() for env in self.envs: if any(env.name in dataset_tuple for dataset_tuple in matching_datasets): all_data_cached: bool = False - all_maps_cached: bool = not env.has_maps or not self.incl_map - + all_maps_cached: bool = not env.has_maps or not require_map_cache if self.env_cache.env_is_cached(env.name) and not self.rebuild_cache: scenes_list: List[Scene] = self.get_desired_scenes_from_env( matching_datasets, scene_description_contains, env @@ -199,13 +248,13 @@ def __init__( all_maps_cached: bool = ( not env.has_maps - or not self.incl_map + or not require_map_cache or all( self.cache_class.is_map_cached( self.cache_path, env.name, scene.location, - self.map_params["px_per_m"], + self.raster_map_params["px_per_m"], ) for scene in scenes_list ) @@ -228,16 +277,29 @@ def __init__( self.cache_path, env.name ) ): - env.cache_maps( - self.cache_path, - self.cache_class, - self.map_params, - ) + # Use only rank 0 process for caching when using multi-GPU torch training. + if rank == 0: + env.cache_maps( + self.cache_path, + self.cache_class, + self.raster_map_params, + ) + + # Wait for rank 0 process to be done with caching. + if ( + distributed.is_initialized() + and distributed.get_world_size() > 1 + ): + distributed.barrier() scenes_list: List[SceneMetadata] = self.get_desired_scenes_from_env( matching_datasets, scene_description_contains, env ) + if self.incl_vector_map: + for map_name in env.metadata.map_locations: + self._map_api.get_map(f"{env.name}:{map_name}") + all_scenes_list += scenes_list # List of cached scene paths. @@ -271,6 +333,125 @@ def __init__( ) self._data_len: int = len(self._data_index) + self._cached_batch_elements = None + + def load_or_create_cache( + self, cache_path: str, num_workers=0, filter_fn=None + ) -> None: + if isfile(cache_path): + print(f"Loading cache from {cache_path} ...", end="") + t = time.time() + with open(cache_path, "rb") as f: + self._cached_batch_elements, keep_mask = dill.load(f, encoding="latin1") + print(f" done in {time.time() - t:.1f}s.") + + else: + # Build cache + cached_batch_elements = [] + keep_mask = [] + + if num_workers <= 0: + cache_data_iterator = self + else: + # Use DataLoader as a generic multiprocessing framework. + # We set batchsize=1 and a custom collate function. + # In effect this will just call self.__getitem__ in parallel. + cache_data_iterator = DataLoader( + self, + batch_size=1, + num_workers=num_workers, + shuffle=False, + collate_fn=lambda xlist: xlist[0], + ) + + for element in tqdm( + cache_data_iterator, + desc=f"Caching batch elements ({num_workers} CPUs): ", + disable=False, + ): + if filter_fn is None or filter_fn(element): + cached_batch_elements.append(element) + keep_mask.append(True) + else: + keep_mask.append(False) + + # Just deletes the variable cache_data_iterator, + # not self (in case it is set to that)! + del cache_data_iterator + + print(f"Saving cache to {cache_path} ....", end="") + t = time.time() + with open(cache_path, "wb") as f: + dill.dump((cached_batch_elements, keep_mask), f) + print(f" done in {time.time() - t:.1f}s.") + + self._cached_batch_elements = cached_batch_elements + + # Verify + if len(keep_mask) != self._data_len: + raise ValueError("Current data and keep_mask lengths do not match!") + + # Remove unwanted elements + self.remove_elements(keep_mask=keep_mask) + + # Verify + if len(self._cached_batch_elements) != self._data_len: + raise ValueError("Current data and cached data lengths do not match!") + + def apply_filter( + self, + filter_fn: Callable[[Union[AgentBatchElement, SceneBatchElement]], bool], + num_workers: int = 0, + ) -> None: + keep_mask = [] + + if num_workers <= 0: + cache_data_iterator = self + else: + # Use DataLoader as a generic multiprocessing framework. + # We set batchsize=1 and a custom collate function. + # In effect this will just call self.__getitem__ in parallel. + cache_data_iterator = DataLoader( + self, + batch_size=1, + num_workers=num_workers, + shuffle=False, + collate_fn=lambda xlist: xlist[0], + ) + + for element in tqdm( + cache_data_iterator, + desc=f"Filtering dataset ({num_workers} CPUs): ", + disable=False, + ): + if filter_fn is None or filter_fn(element): + keep_mask.append(True) + else: + keep_mask.append(False) + + # Just deletes the variable cache_data_iterator, + # not self (in case it is set to that)! + del cache_data_iterator + + # Verify + if len(keep_mask) != self._data_len: + raise ValueError("Current data and keep_mask lengths do not match!") + + # Remove unwanted elements + self.remove_elements(keep_mask=keep_mask) + + def remove_elements(self, keep_mask: List[bool]): + assert len(keep_mask) == self._data_len + old_len = self._data_len + self._data_index = [ + self._data_index[i] for i in range(len(keep_mask)) if keep_mask[i] + ] + self._data_len = len(self._data_index) + + print( + f"Kept {self._data_len}/{old_len} elements, {self._data_len/old_len*100.0:.2f}%." + ) + def get_data_index( self, num_workers: int, scene_paths: List[Path] ) -> Union[ @@ -377,7 +558,7 @@ def _get_data_index_scene( (scene if ret_scene_info else None), scene_info_path, len(index_elems), - np.array(index_elems, dtype=np.int), + np.array(index_elems, dtype=int), ) @staticmethod @@ -422,7 +603,7 @@ def _get_data_index_agent( num_agent_ts: int = valid_ts[1] - valid_ts[0] + 1 if num_agent_ts > 0: index_elems_len += num_agent_ts - index_elems.append((agent_info.name, np.array(valid_ts, dtype=np.int))) + index_elems.append((agent_info.name, np.array(valid_ts, dtype=int))) return ( (scene if ret_scene_info else None), @@ -548,7 +729,7 @@ def preprocess_scene_data( scene_dt: float = ( self.desired_dt if self.desired_dt is not None else scene_info.dt ) - if self.env_cache.scene_is_cached( + if not self.rebuild_cache and self.env_cache.scene_is_cached( scene_info.env_name, scene_info.name, scene_dt ): # This is a fast path in case we don't need to @@ -639,7 +820,11 @@ def preprocess_scene_data( desc=f"Calculating Agent Data ({num_workers} CPUs)", disable=not self.verbose, ): - scene_paths += [Path(path_str) for path_str in processed_scene_paths] + scene_paths += [ + Path(path_str) + for path_str in processed_scene_paths + if path_str is not None + ] return scene_paths @@ -655,10 +840,17 @@ def scenes(self) -> Scene: for scene_idx in range(self.num_scenes()): yield self.get_scene(scene_idx) + def __iter__(self): + for i in range(len(self)): + yield self[i] + def __len__(self) -> int: return self._data_len def __getitem__(self, idx: int) -> Union[SceneBatchElement, AgentBatchElement]: + if self._cached_batch_elements is not None: + return self._cached_batch_elements[idx] + if self.centric == "scene": scene_path, ts = self._data_index[idx] elif self.centric == "agent": @@ -667,7 +859,7 @@ def __getitem__(self, idx: int) -> Union[SceneBatchElement, AgentBatchElement]: scene: Scene = EnvCache.load(scene_path) scene_utils.enforce_desired_dt(scene, self.desired_dt) scene_cache: SceneCache = self.cache_class( - self.cache_path, scene, ts, self.augmentations + self.cache_path, scene, self.augmentations ) if self.centric == "scene": @@ -687,8 +879,10 @@ def __getitem__(self, idx: int) -> Union[SceneBatchElement, AgentBatchElement]: self.future_sec, self.agent_interaction_distances, self.incl_robot_future, - self.incl_map, - self.map_params, + self.incl_raster_map, + self.raster_map_params, + self._map_api, + self.vector_map_params, self.standardize_data, self.standardize_derivatives, self.max_agent_num, @@ -712,8 +906,10 @@ def __getitem__(self, idx: int) -> Union[SceneBatchElement, AgentBatchElement]: self.future_sec, self.agent_interaction_distances, self.incl_robot_future, - self.incl_map, - self.map_params, + self.incl_raster_map, + self.raster_map_params, + self._map_api, + self.vector_map_params, self.standardize_data, self.standardize_derivatives, self.max_neighbor_num, @@ -722,4 +918,10 @@ def __getitem__(self, idx: int) -> Union[SceneBatchElement, AgentBatchElement]: for key, extra_fn in self.extras.items(): batch_element.extras[key] = extra_fn(batch_element) + for transform_fn in self.transforms: + batch_element = transform_fn(batch_element) + + if self.vector_map_params.get("no_collate", True): + batch_element.vec_map = None + return batch_element diff --git a/src/trajdata/dataset_specific/eth_ucy_peds/eupeds_dataset.py b/src/trajdata/dataset_specific/eth_ucy_peds/eupeds_dataset.py index 2bca501..bf17979 100644 --- a/src/trajdata/dataset_specific/eth_ucy_peds/eupeds_dataset.py +++ b/src/trajdata/dataset_specific/eth_ucy_peds/eupeds_dataset.py @@ -278,8 +278,6 @@ def get_agent_info( last_frame: int = frames.iat[-1].item() if frames.shape[0] < last_frame - start_frame + 1: - # Fun fact: this is never hit which means Lyft has no missing - # timesteps (which could be caused by, e.g., occlusion). raise ValueError("ETH/UCY indeed can have missing frames :(") agent_metadata = AgentMetadata( diff --git a/src/trajdata/dataset_specific/lyft/lyft_dataset.py b/src/trajdata/dataset_specific/lyft/lyft_dataset.py index aeb8634..de61994 100644 --- a/src/trajdata/dataset_specific/lyft/lyft_dataset.py +++ b/src/trajdata/dataset_specific/lyft/lyft_dataset.py @@ -1,19 +1,14 @@ -import warnings from collections import defaultdict from functools import partial -from math import ceil from pathlib import Path from random import Random from typing import Any, Dict, List, Optional, Tuple, Type -import l5kit.data.proto.road_network_pb2 as l5_pb2 import numpy as np import pandas as pd from l5kit.configs.config import load_metadata from l5kit.data import ChunkedDataset, LocalDataManager -from l5kit.data.map_api import InterpolationMethod, MapAPI -from l5kit.rasterization import RenderContext -from tqdm import tqdm +from l5kit.data.map_api import MapAPI from trajdata.caching import EnvCache, SceneCache from trajdata.data_structures import ( @@ -25,16 +20,9 @@ ) from trajdata.data_structures.agent import Agent, AgentType, VariableExtent from trajdata.dataset_specific.lyft import lyft_utils -from trajdata.dataset_specific.lyft.rasterizer import MapSemanticRasterizer from trajdata.dataset_specific.raw_dataset import RawDataset from trajdata.dataset_specific.scene_records import LyftSceneRecord -from trajdata.maps import RasterizedMap, RasterizedMapMetadata, map_utils -from trajdata.proto.vectorized_map_pb2 import ( - MapElement, - PedCrosswalk, - RoadLane, - VectorizedMap, -) +from trajdata.maps import VectorMap from trajdata.utils import arr_utils @@ -82,6 +70,8 @@ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: ] scene_split_map = defaultdict(partial(const_lambda, const_val="val")) + else: + raise ValueError(f"Unknown Lyft environment name: {env_name}") return EnvMetadata( name=env_name, @@ -89,6 +79,9 @@ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: dt=lyft_utils.LYFT_DT, parts=dataset_parts, scene_split_map=scene_split_map, + # The location names should match the map names used in + # the unified data cache. + map_locations=("palo_alto",), ) def load_dataset_obj(self, verbose: bool = False) -> None: @@ -329,80 +322,6 @@ def get_agent_info( return agent_list, agent_presence - def extract_vectorized(self, mapAPI: MapAPI) -> VectorizedMap: - vec_map = VectorizedMap() - maximum_bound: np.ndarray = np.full((3,), np.nan) - minimum_bound: np.ndarray = np.full((3,), np.nan) - for l5_element in tqdm(mapAPI.elements, desc="Creating Vectorized Map"): - if mapAPI.is_lane(l5_element): - l5_element_id: str = mapAPI.id_as_str(l5_element.id) - l5_lane: l5_pb2.Lane = l5_element.element.lane - - lane_dict = mapAPI.get_lane_coords(l5_element_id) - left_pts = lane_dict["xyz_left"] - right_pts = lane_dict["xyz_right"] - - # Ensuring the left and right bounds have the same numbers of points. - if len(left_pts) < len(right_pts): - left_pts = mapAPI.interpolate( - left_pts, len(right_pts), InterpolationMethod.INTER_ENSURE_LEN - ) - elif len(right_pts) < len(left_pts): - right_pts = mapAPI.interpolate( - right_pts, len(left_pts), InterpolationMethod.INTER_ENSURE_LEN - ) - - midlane_pts: np.ndarray = (left_pts + right_pts) / 2 - - # Computing the maximum and minimum map coordinates. - maximum_bound = np.fmax(maximum_bound, left_pts.max(axis=0)) - minimum_bound = np.fmin(minimum_bound, left_pts.min(axis=0)) - - maximum_bound = np.fmax(maximum_bound, right_pts.max(axis=0)) - minimum_bound = np.fmin(minimum_bound, right_pts.min(axis=0)) - - maximum_bound = np.fmax(maximum_bound, midlane_pts.max(axis=0)) - minimum_bound = np.fmin(minimum_bound, midlane_pts.min(axis=0)) - - # Adding the element to the map. - new_element: MapElement = vec_map.elements.add() - new_element.id = l5_element.id.id - - new_lane: RoadLane = new_element.road_lane - map_utils.populate_lane_polylines( - new_lane, midlane_pts, left_pts, right_pts - ) - - new_lane.exit_lanes.extend([gid.id for gid in l5_lane.lanes_ahead]) - new_lane.adjacent_lanes_left.append( - l5_lane.adjacent_lane_change_left.id - ) - new_lane.adjacent_lanes_right.append( - l5_lane.adjacent_lane_change_right.id - ) - - if mapAPI.is_crosswalk(l5_element): - l5_element_id: str = mapAPI.id_as_str(l5_element.id) - crosswalk_pts: np.ndarray = mapAPI.get_crosswalk_coords(l5_element_id)[ - "xyz" - ] - - # Computing the maximum and minimum map coordinates. - maximum_bound = np.fmax(maximum_bound, crosswalk_pts.max(axis=0)) - minimum_bound = np.fmin(minimum_bound, crosswalk_pts.min(axis=0)) - - new_element: MapElement = vec_map.elements.add() - new_element.id = l5_element.id.id - - new_crosswalk: PedCrosswalk = new_element.ped_crosswalk - map_utils.populate_polygon(new_crosswalk.polygon, crosswalk_pts) - - # Setting the map bounds. - vec_map.max_pt.x, vec_map.max_pt.y, vec_map.max_pt.z = maximum_bound - vec_map.min_pt.x, vec_map.min_pt.y, vec_map.min_pt.z = minimum_bound - - return vec_map - def cache_maps( self, cache_path: Path, @@ -421,75 +340,7 @@ def cache_maps( world_to_ecef = np.array(dataset_meta["world_to_ecef"], dtype=np.float64) mapAPI = MapAPI(semantic_map_filepath, world_to_ecef) - if map_params.get("original_format", False): - warnings.warn( - "Using a dataset's original map format is deprecated, and will be removed in the next version of trajdata!", - FutureWarning, - ) - - mins = np.stack( - [ - map_elem["bounds"][:, 0].min(axis=0) - for map_elem in mapAPI.bounds_info.values() - ] - ).min(axis=0) - maxs = np.stack( - [ - map_elem["bounds"][:, 1].max(axis=0) - for map_elem in mapAPI.bounds_info.values() - ] - ).max(axis=0) - - world_right, world_top = maxs - world_left, world_bottom = mins - - world_center: np.ndarray = np.array( - [(world_left + world_right) / 2, (world_bottom + world_top) / 2] - ) - raster_size_px: np.ndarray = np.array( - [ - ceil((world_right - world_left) * resolution), - ceil((world_top - world_bottom) * resolution), - ] - ) - - render_context = RenderContext( - raster_size_px=raster_size_px, - pixel_size_m=np.array([1 / resolution, 1 / resolution]), - center_in_raster_ratio=np.array([0.5, 0.5]), - set_origin_to_bottom=False, - ) - - map_from_world: np.ndarray = render_context.raster_from_world( - world_center, 0.0 - ) - - rasterizer = MapSemanticRasterizer( - render_context, semantic_map_filepath, world_to_ecef - ) + vector_map = VectorMap(map_id=f"{self.name}:{map_name}") + lyft_utils.populate_vector_map(vector_map, mapAPI) - print("Rendering palo_alto Map...", flush=True, end=" ") - map_data: np.ndarray = rasterizer.render_semantic_map( - world_center, map_from_world - ) - print("done!", flush=True) - - vectorized_map = VectorizedMap() - else: - vectorized_map: VectorizedMap = self.extract_vectorized(mapAPI) - map_data, map_from_world = map_utils.rasterize_map( - vectorized_map, resolution - ) - - rasterized_map_info: RasterizedMapMetadata = RasterizedMapMetadata( - name=map_name, - shape=map_data.shape, - layers=["drivable_area", "lane_divider", "ped_area"], - layer_rgb_groups=([0], [1], [2]), - resolution=resolution, - map_from_world=map_from_world, - ) - rasterized_map_obj: RasterizedMap = RasterizedMap(rasterized_map_info, map_data) - map_cache_class.cache_map( - cache_path, vectorized_map, rasterized_map_obj, self.name - ) + map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params) diff --git a/src/trajdata/dataset_specific/lyft/lyft_utils.py b/src/trajdata/dataset_specific/lyft/lyft_utils.py index 0fc6528..25cecb9 100644 --- a/src/trajdata/dataset_specific/lyft/lyft_utils.py +++ b/src/trajdata/dataset_specific/lyft/lyft_utils.py @@ -1,9 +1,12 @@ -from typing import Final, List +from typing import Dict, Final, List +import l5kit.data.proto.road_network_pb2 as l5_pb2 import numpy as np import pandas as pd from l5kit.data import ChunkedDataset +from l5kit.data.map_api import InterpolationMethod, MapAPI from l5kit.geometry import rotation33_as_yaw +from tqdm import tqdm from trajdata.data_structures import ( Agent, @@ -13,6 +16,9 @@ Scene, VariableExtent, ) +from trajdata.maps.vec_map import VectorMap +from trajdata.maps.vec_map_elements import PedCrosswalk, Polyline, RoadLane +from trajdata.utils import map_utils LYFT_DT: Final[float] = 0.1 @@ -93,3 +99,82 @@ def lyft_type_to_unified_type(lyft_type: int) -> AgentType: return AgentType.MOTORCYCLE elif lyft_type == 14: return AgentType.PEDESTRIAN + + +def populate_vector_map(vector_map: VectorMap, mapAPI: MapAPI) -> None: + maximum_bound: np.ndarray = np.full((3,), np.nan) + minimum_bound: np.ndarray = np.full((3,), np.nan) + for l5_element in tqdm(mapAPI.elements, desc="Creating Vectorized Map"): + if mapAPI.is_lane(l5_element): + l5_element_id: str = mapAPI.id_as_str(l5_element.id) + l5_lane: l5_pb2.Lane = l5_element.element.lane + + lane_dict = mapAPI.get_lane_coords(l5_element_id) + left_pts = lane_dict["xyz_left"] + right_pts = lane_dict["xyz_right"] + + # Ensuring the left and right bounds have the same numbers of points. + if len(left_pts) < len(right_pts): + left_pts = mapAPI.interpolate( + left_pts, len(right_pts), InterpolationMethod.INTER_ENSURE_LEN + ) + elif len(right_pts) < len(left_pts): + right_pts = mapAPI.interpolate( + right_pts, len(left_pts), InterpolationMethod.INTER_ENSURE_LEN + ) + + midlane_pts: np.ndarray = (left_pts + right_pts) / 2 + + # Computing the maximum and minimum map coordinates. + maximum_bound = np.fmax(maximum_bound, left_pts.max(axis=0)) + minimum_bound = np.fmin(minimum_bound, left_pts.min(axis=0)) + + maximum_bound = np.fmax(maximum_bound, right_pts.max(axis=0)) + minimum_bound = np.fmin(minimum_bound, right_pts.min(axis=0)) + + maximum_bound = np.fmax(maximum_bound, midlane_pts.max(axis=0)) + minimum_bound = np.fmin(minimum_bound, midlane_pts.min(axis=0)) + + # Adding the element to the map. + new_lane = RoadLane( + id=l5_element_id, + center=Polyline(midlane_pts), + left_edge=Polyline(left_pts), + right_edge=Polyline(right_pts), + ) + + new_lane.next_lanes.update( + [mapAPI.id_as_str(gid) for gid in l5_lane.lanes_ahead] + ) + + left_lane_change_id: str = mapAPI.id_as_str( + l5_lane.adjacent_lane_change_left + ) + if left_lane_change_id: + new_lane.adj_lanes_left.add(left_lane_change_id) + + right_lane_change_id: str = mapAPI.id_as_str( + l5_lane.adjacent_lane_change_right + ) + if right_lane_change_id: + new_lane.adj_lanes_right.add(right_lane_change_id) + + vector_map.add_map_element(new_lane) + + if mapAPI.is_crosswalk(l5_element): + l5_element_id: str = mapAPI.id_as_str(l5_element.id) + crosswalk_pts: np.ndarray = mapAPI.get_crosswalk_coords(l5_element_id)[ + "xyz" + ] + + # Computing the maximum and minimum map coordinates. + maximum_bound = np.fmax(maximum_bound, crosswalk_pts.max(axis=0)) + minimum_bound = np.fmin(minimum_bound, crosswalk_pts.min(axis=0)) + + vector_map.add_map_element( + PedCrosswalk(id=l5_element_id, polygon=Polyline(crosswalk_pts)) + ) + + # Setting the map bounds. + # vector_map.extent is [min_x, min_y, min_z, max_x, max_y, max_z] + vector_map.extent = np.concatenate((minimum_bound, maximum_bound)) diff --git a/src/trajdata/dataset_specific/lyft/rasterizer.py b/src/trajdata/dataset_specific/lyft/rasterizer.py deleted file mode 100644 index c3bbe53..0000000 --- a/src/trajdata/dataset_specific/lyft/rasterizer.py +++ /dev/null @@ -1,124 +0,0 @@ -from collections import defaultdict -from typing import Dict - -import cv2 -import numpy as np -from l5kit.data.map_api import InterpolationMethod -from l5kit.geometry import transform_points -from l5kit.rasterization.semantic_rasterizer import ( - CV2_SUB_VALUES, - INTERPOLATION_POINTS, - RasterEls, - SemanticRasterizer, - cv2_subpixel, -) - - -def indices_in_bounds( - center: np.ndarray, bounds: np.ndarray, half_extent: float -) -> np.ndarray: - """ - Get indices of elements for which the bounding box described by bounds intersects the one defined around - center (square with side 2*half_side) - - Args: - center (float): XY of the center - bounds (np.ndarray): array of shape Nx2x2 [[x_min,y_min],[x_max, y_max]] - half_extent (float): half the side of the bounding box centered around center - - Returns: - np.ndarray: indices of elements inside radius from center - """ - return np.arange(bounds.shape[0], dtype=np.long) - - -class MapSemanticRasterizer(SemanticRasterizer): - def render_semantic_map( - self, center_in_world: np.ndarray, raster_from_world: np.ndarray - ) -> np.ndarray: - """Renders the semantic map at given x,y coordinates. - - Args: - center_in_world (np.ndarray): XY of the image center in world ref system - raster_from_world (np.ndarray): - Returns: - np.ndarray: RGB raster - - """ - lane_area_img: np.ndarray = np.zeros( - shape=(self.raster_size[1], self.raster_size[0], 3), dtype=np.uint8 - ) - lane_line_img: np.ndarray = np.zeros( - shape=(self.raster_size[1], self.raster_size[0], 3), dtype=np.uint8 - ) - ped_area_img: np.ndarray = np.zeros( - shape=(self.raster_size[1], self.raster_size[0], 3), dtype=np.uint8 - ) - - # filter using half a radius from the center - raster_radius = float(np.linalg.norm(self.raster_size * self.pixel_size)) / 2 - - # get all lanes as interpolation so that we can transform them all together - lane_indices = indices_in_bounds( - center_in_world, self.mapAPI.bounds_info["lanes"]["bounds"], raster_radius - ) - lanes_mask: Dict[str, np.ndarray] = defaultdict( - lambda: np.zeros(len(lane_indices) * 2, dtype=np.bool) - ) - lanes_area = np.zeros((len(lane_indices) * 2, INTERPOLATION_POINTS, 2)) - - for idx, lane_idx in enumerate(lane_indices): - lane_idx = self.mapAPI.bounds_info["lanes"]["ids"][lane_idx] - - # interpolate over polyline to always have the same number of points - lane_coords = self.mapAPI.get_lane_as_interpolation( - lane_idx, INTERPOLATION_POINTS, InterpolationMethod.INTER_ENSURE_LEN - ) - lanes_area[idx * 2] = lane_coords["xyz_left"][:, :2] - lanes_area[idx * 2 + 1] = lane_coords["xyz_right"][::-1, :2] - - lanes_mask[RasterEls.LANE_NOTL.name][idx * 2 : idx * 2 + 2] = True - - if len(lanes_area): - lanes_area = cv2_subpixel( - transform_points(lanes_area.reshape((-1, 2)), raster_from_world) - ) - - for lane_area in lanes_area.reshape((-1, INTERPOLATION_POINTS * 2, 2)): - # need to for-loop otherwise some of them are empty - cv2.fillPoly(lane_area_img, [lane_area], (255, 0, 0), **CV2_SUB_VALUES) - - lanes_area = lanes_area.reshape((-1, INTERPOLATION_POINTS, 2)) - for ( - name, - mask, - ) in lanes_mask.items(): # draw each type of lane with its own color - cv2.polylines( - lane_line_img, - lanes_area[mask], - False, - (0, 255, 0), - **CV2_SUB_VALUES - ) - - # plot crosswalks - crosswalks = [] - for idx in indices_in_bounds( - center_in_world, - self.mapAPI.bounds_info["crosswalks"]["bounds"], - raster_radius, - ): - crosswalk = self.mapAPI.get_crosswalk_coords( - self.mapAPI.bounds_info["crosswalks"]["ids"][idx] - ) - xy_cross = cv2_subpixel( - transform_points(crosswalk["xyz"][:, :2], raster_from_world) - ) - crosswalks.append(xy_cross) - - cv2.fillPoly(ped_area_img, crosswalks, (0, 0, 255), **CV2_SUB_VALUES) - - map_img: np.ndarray = (lane_area_img + lane_line_img + ped_area_img).astype( - np.float32 - ) / 255 - return map_img.transpose(2, 0, 1) diff --git a/src/trajdata/dataset_specific/nuplan/__init__.py b/src/trajdata/dataset_specific/nuplan/__init__.py new file mode 100644 index 0000000..022a1a0 --- /dev/null +++ b/src/trajdata/dataset_specific/nuplan/__init__.py @@ -0,0 +1 @@ +from .nuplan_dataset import NuplanDataset diff --git a/src/trajdata/dataset_specific/nuplan/nuplan_dataset.py b/src/trajdata/dataset_specific/nuplan/nuplan_dataset.py new file mode 100644 index 0000000..5504f6d --- /dev/null +++ b/src/trajdata/dataset_specific/nuplan/nuplan_dataset.py @@ -0,0 +1,402 @@ +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Type + +import numpy as np +import pandas as pd +from nuplan.common.maps.nuplan_map import map_factory +from nuplan.common.maps.nuplan_map.nuplan_map import NuPlanMap +from tqdm import tqdm + +from trajdata.caching import EnvCache, SceneCache +from trajdata.data_structures.agent import ( + Agent, + AgentMetadata, + AgentType, + FixedExtent, + VariableExtent, +) +from trajdata.data_structures.environment import EnvMetadata +from trajdata.data_structures.scene_metadata import Scene, SceneMetadata +from trajdata.data_structures.scene_tag import SceneTag +from trajdata.dataset_specific.nuplan import nuplan_utils +from trajdata.dataset_specific.raw_dataset import RawDataset +from trajdata.dataset_specific.scene_records import NuPlanSceneRecord +from trajdata.maps.vec_map import VectorMap +from trajdata.utils import arr_utils + + +class NuplanDataset(RawDataset): + def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: + all_log_splits: Dict[str, List[str]] = nuplan_utils.create_splits_logs() + + nup_log_splits: Dict[str, List[str]] + if env_name == "nuplan_mini": + nup_log_splits = { + k: all_log_splits[k[5:]] + for k in ["mini_train", "mini_val", "mini_test"] + } + + # nuScenes possibilities are the Cartesian product of these + dataset_parts: List[Tuple[str, ...]] = [ + ("mini_train", "mini_val", "mini_test"), + nuplan_utils.NUPLAN_LOCATIONS, + ] + elif env_name.startswith("nuplan"): + split_str = env_name.split("_")[-1] + nup_log_splits = {split_str: all_log_splits[split_str]} + + # nuScenes possibilities are the Cartesian product of these + dataset_parts: List[Tuple[str, ...]] = [ + (split_str,), + nuplan_utils.NUPLAN_LOCATIONS, + ] + else: + raise ValueError(f"Unknown nuPlan environment name: {env_name}") + + # Inverting the dict from above, associating every log with its data split. + nup_log_split_map: Dict[str, str] = { + v_elem: k for k, v in nup_log_splits.items() for v_elem in v + } + + return EnvMetadata( + name=env_name, + data_dir=data_dir, + dt=nuplan_utils.NUPLAN_DT, + parts=dataset_parts, + scene_split_map=nup_log_split_map, + # The location names should match the map names used in + # the unified data cache. + map_locations=nuplan_utils.NUPLAN_LOCATIONS, + ) + + def load_dataset_obj(self, verbose: bool = False) -> None: + if verbose: + print(f"Loading {self.name} dataset...", flush=True) + + if self.name == "nuplan_mini": + subfolder = "mini" + elif self.name.startswith("nuplan"): + subfolder = "trainval" + + self.dataset_obj = nuplan_utils.NuPlanObject(self.metadata.data_dir, subfolder) + + def _get_matching_scenes_from_obj( + self, + scene_tag: SceneTag, + scene_desc_contains: Optional[List[str]], + env_cache: EnvCache, + ) -> List[SceneMetadata]: + all_scenes_list: List[NuPlanSceneRecord] = list() + + default_split = "mini_train" if "mini" in self.metadata.name else "train" + + scenes_list: List[SceneMetadata] = list() + for idx, scene_record in enumerate(self.dataset_obj.scenes): + scene_name: str = scene_record["name"] + originating_log: str = scene_name.split("=")[0] + # scene_desc: str = scene_record["description"].lower() + scene_location: str = scene_record["location"] + scene_split: str = self.metadata.scene_split_map.get( + originating_log, default_split + ) + scene_length: int = scene_record["num_timesteps"] + + if scene_length == 1: + # nuPlan has scenes with only a single frame of data which we + # can't do much with in terms of prediction/planning/etc. As a + # result, we skip it. + # As an example, nuplan_mini scene e958b276c7a65197 + # from log 2021.06.14.19.22.11_veh-38_01480_01860. + continue + + # Saving all scene records for later caching. + all_scenes_list.append( + NuPlanSceneRecord( + scene_name, + scene_location, + scene_length, + scene_split, + # scene_desc, + idx, + ) + ) + + if ( + scene_location in scene_tag + and scene_split in scene_tag + and scene_desc_contains is None + ): + # if scene_desc_contains is not None and not any( + # desc_query in scene_desc for desc_query in scene_desc_contains + # ): + # continue + + scene_metadata = SceneMetadata( + env_name=self.metadata.name, + name=scene_name, + dt=self.metadata.dt, + raw_data_idx=idx, + ) + scenes_list.append(scene_metadata) + + self.cache_all_scenes_list(env_cache, all_scenes_list) + return scenes_list + + def _get_matching_scenes_from_cache( + self, + scene_tag: SceneTag, + scene_desc_contains: Optional[List[str]], + env_cache: EnvCache, + ) -> List[Scene]: + all_scenes_list: List[NuPlanSceneRecord] = env_cache.load_env_scenes_list( + self.name + ) + + scenes_list: List[SceneMetadata] = list() + for scene_record in all_scenes_list: + ( + scene_name, + scene_location, + scene_length, + scene_split, + # scene_desc, + data_idx, + ) = scene_record + + if ( + scene_location in scene_tag + and scene_split in scene_tag + and scene_desc_contains is None + ): + # if scene_desc_contains is not None and not any( + # desc_query in scene_desc for desc_query in scene_desc_contains + # ): + # continue + + scene_metadata = Scene( + self.metadata, + scene_name, + scene_location, + scene_split, + scene_length, + data_idx, + None, # This isn't used if everything is already cached. + # scene_desc, + ) + scenes_list.append(scene_metadata) + + return scenes_list + + def get_scene(self, scene_info: SceneMetadata) -> Scene: + _, _, _, data_idx = scene_info + default_split = "mini_train" if "mini" in self.metadata.name else "train" + + scene_record: Dict[str, str] = self.dataset_obj.scenes[data_idx] + + scene_name: str = scene_record["name"] + originating_log: str = scene_name.split("=")[0] + # scene_desc: str = scene_record["description"].lower() + scene_location: str = scene_record["location"] + scene_split: str = self.metadata.scene_split_map.get( + originating_log, default_split + ) + scene_length: int = scene_record["num_timesteps"] + + return Scene( + self.metadata, + scene_name, + scene_location, + scene_split, + scene_length, + data_idx, + scene_record, + # scene_desc, + ) + + def get_agent_info( + self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache] + ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]: + # instantiate VectorMap from map_api if necessary + self.dataset_obj.open_db(scene.name.split("=")[0] + ".db") + + ego_agent_info: AgentMetadata = AgentMetadata( + name="ego", + agent_type=AgentType.VEHICLE, + first_timestep=0, + last_timestep=scene.length_timesteps - 1, + # From https://github.com/motional/nuplan-devkit/blob/761cdbd52d699560629c79ba1b10b29c18ebc068/nuplan/common/actor_state/vehicle_parameters.py#L125 + extent=FixedExtent(length=4.049 + 1.127, width=1.1485 * 2.0, height=1.777), + ) + + agent_list: List[AgentMetadata] = [ego_agent_info] + agent_presence: List[List[AgentMetadata]] = [ + [ego_agent_info] for _ in range(scene.length_timesteps) + ] + + all_frames: pd.DataFrame = self.dataset_obj.get_scene_frames(scene) + + ego_df = ( + all_frames[["ego_x", "ego_y", "ego_vx", "ego_vy", "ego_ax", "ego_ay"]] + .rename(columns=lambda name: name[4:]) + .reset_index(drop=True) + ) + ego_df["heading"] = arr_utils.quaternion_to_yaw( + all_frames[["ego_qw", "ego_qx", "ego_qy", "ego_qz"]].values + ) + ego_df["scene_ts"] = np.arange(len(ego_df)) + ego_df["agent_id"] = "ego" + + lpc_tokens: List[bytearray] = all_frames.index.tolist() + agents_df: pd.DataFrame = self.dataset_obj.get_detected_agents(lpc_tokens) + tls_df: pd.DataFrame = self.dataset_obj.get_traffic_light_status(lpc_tokens) + + self.dataset_obj.close_db() + + agents_df["scene_ts"] = agents_df["lidar_pc_token"].map( + {lpc_token: scene_ts for scene_ts, lpc_token in enumerate(lpc_tokens)} + ) + agents_df["agent_id"] = agents_df["track_token"].apply(lambda x: x.hex()) + + # Recording agent metadata for later. + agent_metadata_dict: Dict[str, Dict[str, Any]] = dict() + for agent_id, agent_data in agents_df.groupby("agent_id").first().iterrows(): + if agent_id not in agent_metadata_dict: + agent_metadata_dict[agent_id] = { + "type": nuplan_utils.nuplan_type_to_unified_type( + agent_data["category_name"] + ), + "width": agent_data["width"], + "length": agent_data["length"], + "height": agent_data["height"], + } + + agents_df = agents_df.drop( + columns=[ + "lidar_pc_token", + "track_token", + "category_name", + "width", + "length", + "height", + ], + ).rename(columns={"yaw": "heading"}) + + # Sorting the agents' combined DataFrame here. + agents_df.set_index(["agent_id", "scene_ts"], inplace=True) + agents_df.sort_index(inplace=True) + agents_df.reset_index(level=1, inplace=True) + + one_detection_agents: List[str] = list() + for agent_id in agent_metadata_dict: + agent_metadata_entry = agent_metadata_dict[agent_id] + + agent_specific_df = agents_df.loc[agent_id] + if len(agent_specific_df.shape) <= 1 or agent_specific_df.shape[0] <= 1: + # Removing agents that are only observed once. + one_detection_agents.append(agent_id) + continue + + first_timestep: int = agent_specific_df.iat[0, 0].item() + last_timestep: int = agent_specific_df.iat[-1, 0].item() + agent_info: AgentMetadata = AgentMetadata( + name=agent_id, + agent_type=agent_metadata_entry["type"], + first_timestep=first_timestep, + last_timestep=last_timestep, + extent=FixedExtent( + length=agent_metadata_entry["length"], + width=agent_metadata_entry["width"], + height=agent_metadata_entry["height"], + ), + ) + + agent_list.append(agent_info) + for timestep in range( + agent_info.first_timestep, agent_info.last_timestep + 1 + ): + agent_presence[timestep].append(agent_info) + + # Removing agents with only one detection. + agents_df.drop(index=one_detection_agents, inplace=True) + + ### Calculating agent accelerations + agent_ids: np.ndarray = agents_df.index.get_level_values(0).to_numpy() + agents_df[["ax", "ay"]] = ( + arr_utils.agent_aware_diff(agents_df[["vx", "vy"]].to_numpy(), agent_ids) + / nuplan_utils.NUPLAN_DT + ) + + # for agent_id, frames in agents_df.groupby("agent_id")["scene_ts"]: + # if frames.shape[0] <= 1: + # raise ValueError("nuPlan can have one-detection agents :(") + + # start_frame: int = frames.iat[0].item() + # last_frame: int = frames.iat[-1].item() + + # if frames.shape[0] < last_frame - start_frame + 1: + # raise ValueError("nuPlan indeed can have missing frames :(") + + overall_agents_df = pd.concat([ego_df, agents_df.reset_index()]).set_index( + ["agent_id", "scene_ts"] + ) + cache_class.save_agent_data(overall_agents_df, cache_path, scene) + + # similar process to clean up and traffic light data + tls_df["scene_ts"] = tls_df["lidar_pc_token"].map( + {lpc_token: scene_ts for scene_ts, lpc_token in enumerate(lpc_tokens)} + ) + tls_df = tls_df.drop(columns=["lidar_pc_token"]).set_index( + ["lane_connector_id", "scene_ts"] + ) + + cache_class.save_traffic_light_data(tls_df, cache_path, scene) + + return agent_list, agent_presence + + def cache_map( + self, + map_name: str, + cache_path: Path, + map_cache_class: Type[SceneCache], + map_params: Dict[str, Any], + ) -> None: + nuplan_map: NuPlanMap = map_factory.get_maps_api( + map_root=str(self.metadata.data_dir.parent / "maps"), + map_version=nuplan_utils.NUPLAN_MAP_VERSION, + map_name=nuplan_utils.NUPLAN_FULL_MAP_NAME_DICT[map_name], + ) + + # Loading all layer geometries. + nuplan_map.initialize_all_layers() + + # This df has the normal lane_connectors with additional boundary information, + # which we want to use, however the default index is not the lane_connector_fid, + # although it is a 1:1 mapping so we instead create another index with the + # lane_connector_fids as the key and the resulting integer indices as the value. + lane_connector_fids: pd.Series = nuplan_map._vector_map[ + "gen_lane_connectors_scaled_width_polygons" + ]["lane_connector_fid"] + lane_connector_idxs: pd.Series = pd.Series( + index=lane_connector_fids, data=range(len(lane_connector_fids)) + ) + + vector_map = VectorMap(map_id=f"{self.name}:{map_name}") + nuplan_utils.populate_vector_map(vector_map, nuplan_map, lane_connector_idxs) + + map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params) + + def cache_maps( + self, + cache_path: Path, + map_cache_class: Type[SceneCache], + map_params: Dict[str, Any], + ) -> None: + """ + Stores rasterized maps to disk for later retrieval. + """ + for map_name in tqdm( + nuplan_utils.NUPLAN_LOCATIONS, + desc=f"Caching {self.name} Maps at {map_params['px_per_m']:.2f} px/m", + position=0, + ): + self.cache_map(map_name, cache_path, map_cache_class, map_params) diff --git a/src/trajdata/dataset_specific/nuplan/nuplan_utils.py b/src/trajdata/dataset_specific/nuplan/nuplan_utils.py new file mode 100644 index 0000000..7f4726f --- /dev/null +++ b/src/trajdata/dataset_specific/nuplan/nuplan_utils.py @@ -0,0 +1,407 @@ +import glob +import sqlite3 +from collections import defaultdict +from pathlib import Path +from typing import Dict, Final, Generator, Iterable, List, Optional, Tuple + +import numpy as np +import nuplan.planning.script.config.common as common_cfg +import pandas as pd +import yaml +from nuplan.common.maps.nuplan_map.nuplan_map import NuPlanMap +from tqdm import tqdm + +from trajdata.data_structures.agent import AgentType +from trajdata.data_structures.scene_metadata import Scene +from trajdata.maps import TrafficLightStatus, VectorMap +from trajdata.maps.vec_map_elements import ( + MapElementType, + PedCrosswalk, + PedWalkway, + Polyline, + RoadArea, + RoadLane, +) +from trajdata.utils import map_utils + +NUPLAN_DT: Final[float] = 0.05 +NUPLAN_FULL_MAP_NAME_DICT: Final[Dict[str, str]] = { + "boston": "us-ma-boston", + "singapore": "sg-one-north", + "las_vegas": "us-nv-las-vegas-strip", + "pittsburgh": "us-pa-pittsburgh-hazelwood", +} +_NUPLAN_SQL_MAP_FRIENDLY_NAMES_DICT: Final[Dict[str, str]] = { + "us-ma-boston": "boston", + "sg-one-north": "singapore", + "las_vegas": "las_vegas", + "us-pa-pittsburgh-hazelwood": "pittsburgh", +} +NUPLAN_LOCATIONS: Final[Tuple[str, str, str, str]] = tuple( + NUPLAN_FULL_MAP_NAME_DICT.keys() +) +NUPLAN_MAP_VERSION: Final[str] = "nuplan-maps-v1.0" + +NUPLAN_TRAFFIC_STATUS_DICT: Final[Dict[str, TrafficLightStatus]] = { + "green": TrafficLightStatus.GREEN, + "red": TrafficLightStatus.RED, + "unknown": TrafficLightStatus.UNKNOWN, +} + + +class NuPlanObject: + def __init__(self, dataset_path: Path, subfolder: str) -> None: + self.base_path: Path = dataset_path / subfolder + + self.connection: sqlite3.Connection = None + self.cursor: sqlite3.Cursor = None + + self.scenes: List[Dict[str, str]] = self._load_scenes() + + def open_db(self, db_filename: str) -> None: + self.connection = sqlite3.connect(str(self.base_path / db_filename)) + self.connection.row_factory = sqlite3.Row + self.cursor = self.connection.cursor() + + def execute_query_one( + self, query_text: str, query_params: Optional[Iterable] = None + ) -> sqlite3.Row: + self.cursor.execute( + query_text, query_params if query_params is not None else [] + ) + return self.cursor.fetchone() + + def execute_query_all( + self, query_text: str, query_params: Optional[Iterable] = None + ) -> List[sqlite3.Row]: + self.cursor.execute( + query_text, query_params if query_params is not None else [] + ) + return self.cursor.fetchall() + + def execute_query_iter( + self, query_text: str, query_params: Optional[Iterable] = None + ) -> Generator[sqlite3.Row, None, None]: + self.cursor.execute( + query_text, query_params if query_params is not None else [] + ) + + for row in self.cursor: + yield row + + def _load_scenes(self) -> List[Dict[str, str]]: + scene_info_query = """ + SELECT sc.token AS scene_token, + log.location, + log.logfile, + ( + SELECT count(*) + FROM lidar_pc AS lpc + WHERE lpc.scene_token = sc.token + ) AS num_timesteps + FROM scene AS sc + LEFT JOIN log ON sc.log_token = log.token + """ + scenes: List[Dict[str, str]] = [] + + for log_filename in glob.glob(str(self.base_path / "*.db")): + self.open_db(log_filename) + + for row in self.execute_query_iter(scene_info_query): + scenes.append( + { + "name": f"{row['logfile']}={row['scene_token'].hex()}", + "location": _NUPLAN_SQL_MAP_FRIENDLY_NAMES_DICT[ + row["location"] + ], + "num_timesteps": row["num_timesteps"], + } + ) + + self.close_db() + + return scenes + + def get_scene_frames(self, scene: Scene) -> pd.DataFrame: + query = """ + SELECT lpc.token AS lpc_token, + ep.x AS ego_x, + ep.y AS ego_y, + ep.qw AS ego_qw, + ep.qx AS ego_qx, + ep.qy AS ego_qy, + ep.qz AS ego_qz, + ep.vx AS ego_vx, + ep.vy AS ego_vy, + ep.acceleration_x AS ego_ax, + ep.acceleration_y AS ego_ay + FROM lidar_pc AS lpc + LEFT JOIN ego_pose AS ep ON lpc.ego_pose_token = ep.token + WHERE scene_token = ? + ORDER BY lpc.timestamp ASC; + """ + log_filename, scene_token_str = scene.name.split("=") + scene_token = bytearray.fromhex(scene_token_str) + + return pd.read_sql_query( + query, self.connection, index_col="lpc_token", params=(scene_token,) + ) + + def get_detected_agents(self, binary_lpc_tokens: List[bytearray]) -> pd.DataFrame: + query = f""" + SELECT lb.lidar_pc_token, + lb.track_token, + (SELECT category.name FROM category WHERE category.token = tr.category_token) AS category_name, + tr.width, + tr.length, + tr.height, + lb.x, + lb.y, + lb.vx, + lb.vy, + lb.yaw + FROM lidar_box AS lb + LEFT JOIN track AS tr ON lb.track_token = tr.token + + WHERE lidar_pc_token IN ({('?,'*len(binary_lpc_tokens))[:-1]}) AND category_name IN ('vehicle', 'bicycle', 'pedestrian') + """ + return pd.read_sql_query(query, self.connection, params=binary_lpc_tokens) + + def get_traffic_light_status( + self, binary_lpc_tokens: List[bytearray] + ) -> pd.DataFrame: + query = f""" + SELECT tls.lidar_pc_token AS lidar_pc_token, + tls.lane_connector_id AS lane_connector_id, + tls.status AS raw_status + FROM traffic_light_status AS tls + WHERE lidar_pc_token IN ({('?,'*len(binary_lpc_tokens))[:-1]}); + """ + df = pd.read_sql_query(query, self.connection, params=binary_lpc_tokens) + df["status"] = df["raw_status"].map(NUPLAN_TRAFFIC_STATUS_DICT) + return df.drop(columns=["raw_status"]) + + def close_db(self) -> None: + self.cursor.close() + self.connection.close() + + +def nuplan_type_to_unified_type(nuplan_type: str) -> AgentType: + if nuplan_type == "pedestrian": + return AgentType.PEDESTRIAN + elif nuplan_type == "bicycle": + return AgentType.BICYCLE + elif nuplan_type == "vehicle": + return AgentType.VEHICLE + else: + return AgentType.UNKNOWN + + +def create_splits_logs() -> Dict[str, List[str]]: + yaml_filepath = Path(common_cfg.__path__[0]) / "splitter" / "nuplan.yaml" + with open(yaml_filepath, "r") as stream: + splits = yaml.safe_load(stream) + + return splits["log_splits"] + + +def extract_lane_and_edges( + nuplan_map: NuPlanMap, lane_record, lane_connector_idxs: pd.Series +) -> Tuple[str, np.ndarray, np.ndarray, np.ndarray, Tuple[str, str]]: + lane_midline = np.stack(lane_record["geometry"].xy, axis=-1) + + # Getting the bounding polygon vertices. + boundary_df = nuplan_map._vector_map["boundaries"] + if np.isfinite(lane_record["lane_fid"]): + fid = str(int(lane_record["lane_fid"])) + lane_info = nuplan_map._vector_map["lanes_polygons"].loc[fid] + elif np.isfinite(lane_record["lane_connector_fid"]): + fid = int(lane_record["lane_connector_fid"]) + lane_info = nuplan_map._vector_map[ + "gen_lane_connectors_scaled_width_polygons" + ].iloc[lane_connector_idxs[fid]] + else: + raise ValueError("Both lane_fid and lane_connector_fid are NaN!") + + lane_fid = str(fid) + boundary_info = ( + str(lane_info["left_boundary_fid"]), + str(lane_info["right_boundary_fid"]), + ) + + left_pts = np.stack(boundary_df.loc[boundary_info[0]]["geometry"].xy, axis=-1) + right_pts = np.stack(boundary_df.loc[boundary_info[1]]["geometry"].xy, axis=-1) + + # Final ordering check, ensuring that left_pts and right_pts can be combined + # into a polygon without the endpoints intersecting. + # Reversing the one lane edge that does not match the ordering of the midline. + if map_utils.endpoints_intersect(left_pts, right_pts): + if not map_utils.order_matches(left_pts, lane_midline): + left_pts = left_pts[::-1] + else: + right_pts = right_pts[::-1] + + # Ensuring that left and right have the same number of points. + # This is necessary, not for data storage but for later rasterization. + if left_pts.shape[0] < right_pts.shape[0]: + left_pts = map_utils.interpolate(left_pts, num_pts=right_pts.shape[0]) + elif right_pts.shape[0] < left_pts.shape[0]: + right_pts = map_utils.interpolate(right_pts, num_pts=left_pts.shape[0]) + + return (lane_fid, lane_midline, left_pts, right_pts, boundary_info) + + +def extract_area(nuplan_map: NuPlanMap, area_record) -> np.ndarray: + return np.stack(area_record["geometry"].exterior.xy, axis=-1) + + +def populate_vector_map( + vector_map: VectorMap, nuplan_map: NuPlanMap, lane_connector_idxs: pd.Series +) -> None: + # Setting the map bounds. + # NOTE: min_pt is especially important here since the world coordinates of nuPlan + # are quite large in magnitude. We make them relative to the bottom-left by + # subtracting all positions by min_pt and registering that offset as part of + # the map_from_world (and related) transforms later. + min_pt = np.min( + [ + layer_df["geometry"].total_bounds[:2] + for layer_df in nuplan_map._vector_map.values() + ], + axis=0, + ) + max_pt = np.max( + [ + layer_df["geometry"].total_bounds[2:] + for layer_df in nuplan_map._vector_map.values() + ], + axis=0, + ) + + # vector_map.extent is [min_x, min_y, min_z, max_x, max_y, max_z] + vector_map.extent = np.array( + [ + min_pt[0], + min_pt[1], + 0.0, + max_pt[0], + max_pt[1], + 0.0, + ] + ) + + overall_pbar = tqdm( + total=len(nuplan_map._vector_map["baseline_paths"]) + + len(nuplan_map._vector_map["drivable_area"]) + + len(nuplan_map._vector_map["crosswalks"]) + + len(nuplan_map._vector_map["walkways"]), + desc=f"Getting {nuplan_map.map_name} Elements", + position=1, + leave=False, + ) + + # This dict stores boundary IDs and which lanes are to the left and right of them. + boundary_connectivity_dict: Dict[str, Dict[str, List[str]]] = defaultdict( + lambda: defaultdict(list) + ) + + # This dict stores lanes' boundary IDs. + lane_boundary_dict: Dict[str, Tuple[str, str]] = dict() + for _, lane_info in nuplan_map._vector_map["baseline_paths"].iterrows(): + ( + lane_id, + center_pts, + left_pts, + right_pts, + boundary_info, + ) = extract_lane_and_edges(nuplan_map, lane_info, lane_connector_idxs) + + lane_boundary_dict[lane_id] = boundary_info + left_boundary_id, right_boundary_id = boundary_info + + # The left boundary of Lane A has Lane A to its right. + boundary_connectivity_dict[left_boundary_id]["right"].append(lane_id) + + # The right boundary of Lane A has Lane A to its left. + boundary_connectivity_dict[right_boundary_id]["left"].append(lane_id) + + # "partial" because we aren't adding lane connectivity until later. + partial_new_lane = RoadLane( + id=lane_id, + center=Polyline(center_pts), + left_edge=Polyline(left_pts), + right_edge=Polyline(right_pts), + ) + vector_map.add_map_element(partial_new_lane) + overall_pbar.update() + + for fid, polygon_info in nuplan_map._vector_map["drivable_area"].iterrows(): + polygon_pts = extract_area(nuplan_map, polygon_info) + + new_road_area = RoadArea(id=fid, exterior_polygon=Polyline(polygon_pts)) + for hole in polygon_info["geometry"].interiors: + hole_pts = extract_area(nuplan_map, hole) + new_road_area.interior_holes.append(Polyline(hole_pts)) + + vector_map.add_map_element(new_road_area) + overall_pbar.update() + + for fid, ped_area_record in nuplan_map._vector_map["crosswalks"].iterrows(): + polygon_pts = extract_area(nuplan_map, ped_area_record) + + new_ped_crosswalk = PedCrosswalk(id=fid, polygon=Polyline(polygon_pts)) + vector_map.add_map_element(new_ped_crosswalk) + overall_pbar.update() + + for fid, ped_area_record in nuplan_map._vector_map["walkways"].iterrows(): + polygon_pts = extract_area(nuplan_map, ped_area_record) + + new_ped_walkway = PedWalkway(id=fid, polygon=Polyline(polygon_pts)) + vector_map.add_map_element(new_ped_walkway) + overall_pbar.update() + + overall_pbar.close() + + # Lane connectivity + lane_connectivity_exit_dict = defaultdict(list) + lane_connectivity_entry_dict = defaultdict(list) + for lane_connector_fid, lane_connector in tqdm( + nuplan_map._vector_map["lane_connectors"].iterrows(), + desc="Getting Lane Connectivity", + total=len(nuplan_map._vector_map["lane_connectors"]), + position=1, + leave=False, + ): + lane_connectivity_exit_dict[str(lane_connector["exit_lane_fid"])].append( + lane_connector_fid + ) + lane_connectivity_entry_dict[lane_connector_fid].append( + str(lane_connector["exit_lane_fid"]) + ) + + lane_connectivity_exit_dict[lane_connector_fid].append( + str(lane_connector["entry_lane_fid"]) + ) + lane_connectivity_entry_dict[str(lane_connector["entry_lane_fid"])].append( + lane_connector_fid + ) + + map_elem: RoadLane + for map_elem in tqdm( + vector_map.elements[MapElementType.ROAD_LANE].values(), + desc="Storing Lane Connectivity", + position=1, + leave=False, + ): + map_elem.prev_lanes.update(lane_connectivity_entry_dict[map_elem.id]) + map_elem.next_lanes.update(lane_connectivity_exit_dict[map_elem.id]) + + lane_id: str = map_elem.id + left_boundary_id, right_boundary_id = lane_boundary_dict[lane_id] + + map_elem.adj_lanes_left.update( + boundary_connectivity_dict[left_boundary_id]["left"] + ) + map_elem.adj_lanes_right.update( + boundary_connectivity_dict[right_boundary_id]["right"] + ) diff --git a/src/trajdata/dataset_specific/nusc/nusc_dataset.py b/src/trajdata/dataset_specific/nusc/nusc_dataset.py index 401658a..209824e 100644 --- a/src/trajdata/dataset_specific/nusc/nusc_dataset.py +++ b/src/trajdata/dataset_specific/nusc/nusc_dataset.py @@ -3,14 +3,11 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Type, Union -import numpy as np import pandas as pd from nuscenes.eval.prediction.splits import NUM_IN_TRAIN_VAL -from nuscenes.map_expansion import arcline_path_utils from nuscenes.map_expansion.map_api import NuScenesMap, locations from nuscenes.nuscenes import NuScenes from nuscenes.utils.splits import create_splits_scenes -from scipy.spatial.distance import cdist from tqdm import tqdm from trajdata.caching import EnvCache, SceneCache @@ -27,16 +24,7 @@ from trajdata.dataset_specific.nusc import nusc_utils from trajdata.dataset_specific.raw_dataset import RawDataset from trajdata.dataset_specific.scene_records import NuscSceneRecord -from trajdata.maps import RasterizedMap, RasterizedMapMetadata, map_utils -from trajdata.proto.vectorized_map_pb2 import ( - MapElement, - PedCrosswalk, - PedWalkway, - Polyline, - RoadArea, - RoadLane, - VectorizedMap, -) +from trajdata.maps import VectorMap class NuscDataset(RawDataset): @@ -82,6 +70,8 @@ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: ("mini_train", "mini_val"), ("boston", "singapore"), ] + else: + raise ValueError(f"Unknown nuScenes environment name: {env_name}") # Inverting the dict from above, associating every scene with its data split. nusc_scene_split_map: Dict[str, str] = { @@ -94,6 +84,9 @@ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: dt=nusc_utils.NUSC_DT, parts=dataset_parts, scene_split_map=nusc_scene_split_map, + # The location names should match the map names used in + # the unified data cache. + map_locations=tuple(locations), ) def load_dataset_obj(self, verbose: bool = False) -> None: @@ -274,223 +267,6 @@ def get_agent_info( return agent_list, agent_presence - def extract_lane_and_edges( - self, nusc_map: NuScenesMap, lane_record - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - # Getting the bounding polygon vertices. - lane_polygon_obj = nusc_map.get("polygon", lane_record["polygon_token"]) - polygon_nodes = [ - nusc_map.get("node", node_token) - for node_token in lane_polygon_obj["exterior_node_tokens"] - ] - polygon_pts: np.ndarray = np.array( - [(node["x"], node["y"]) for node in polygon_nodes] - ) - - # Getting the lane center's points. - curr_lane = nusc_map.arcline_path_3.get(lane_record["token"], []) - lane_midline: np.ndarray = np.array( - arcline_path_utils.discretize_lane(curr_lane, resolution_meters=0.5) - )[:, :2] - - # For some reason, nuScenes duplicates a few entries - # (likely how they're building their arcline representation). - # We delete those duplicate entries here. - duplicate_check: np.ndarray = np.where( - np.linalg.norm(np.diff(lane_midline, axis=0, prepend=0), axis=1) < 1e-10 - )[0] - if duplicate_check.size > 0: - lane_midline = np.delete(lane_midline, duplicate_check, axis=0) - - # Computing the closest lane center point to each bounding polygon vertex. - closest_midlane_pt: np.ndarray = np.argmin( - cdist(polygon_pts, lane_midline), axis=1 - ) - # Computing the local direction of the lane at each lane center point. - direction_vectors: np.ndarray = np.diff( - lane_midline, - axis=0, - prepend=lane_midline[[0]] - (lane_midline[[1]] - lane_midline[[0]]), - ) - - # Selecting the direction vectors at the closest lane center point per polygon vertex. - local_dir_vecs: np.ndarray = direction_vectors[closest_midlane_pt] - # Calculating the vectors from the the closest lane center point per polygon vertex to the polygon vertex. - origin_to_polygon_vecs: np.ndarray = ( - polygon_pts - lane_midline[closest_midlane_pt] - ) - - # Computing the perpendicular dot product. - # See https://www.xarg.org/book/linear-algebra/2d-perp-product/ - # If perp_dot_product < 0, then the associated polygon vertex is - # on the right edge of the lane. - perp_dot_product: np.ndarray = ( - local_dir_vecs[:, 0] * origin_to_polygon_vecs[:, 1] - - local_dir_vecs[:, 1] * origin_to_polygon_vecs[:, 0] - ) - - # Determining which indices are on the right of the lane center. - on_right: np.ndarray = perp_dot_product < 0 - # Determining the boundary between the left/right polygon vertices - # (they will be together in blocks due to the ordering of the polygon vertices). - idx_changes: int = np.where(np.roll(on_right, 1) < on_right)[0].item() - - if idx_changes > 0: - # If the block of left/right points spreads across the bounds of the array, - # roll it until the boundary between left/right points is at index 0. - # This is important so that the following index selection orders points - # without jumps. - polygon_pts = np.roll(polygon_pts, shift=-idx_changes, axis=0) - on_right = np.roll(on_right, shift=-idx_changes) - - left_pts: np.ndarray = polygon_pts[~on_right] - right_pts: np.ndarray = polygon_pts[on_right] - - # Final ordering check, ensuring that the beginning of left_pts/right_pts - # matches the beginning of the lane. - left_order_correct: bool = np.linalg.norm( - left_pts[0] - lane_midline[0] - ) < np.linalg.norm(left_pts[0] - lane_midline[-1]) - right_order_correct: bool = np.linalg.norm( - right_pts[0] - lane_midline[0] - ) < np.linalg.norm(right_pts[0] - lane_midline[-1]) - - # Reversing left_pts/right_pts in case their first index is - # at the end of the lane. - if not left_order_correct: - left_pts = left_pts[::-1] - if not right_order_correct: - right_pts = right_pts[::-1] - - # Ensuring that left and right have the same number of points. - # This is necessary, not for data storage but for later rasterization. - if left_pts.shape[0] < right_pts.shape[0]: - left_pts = map_utils.interpolate(left_pts, right_pts.shape[0]) - elif right_pts.shape[0] < left_pts.shape[0]: - right_pts = map_utils.interpolate(right_pts, left_pts.shape[0]) - - return ( - lane_midline, - left_pts, - right_pts, - ) - - def extract_area(self, nusc_map: NuScenesMap, area_record) -> np.ndarray: - token_key: str - if "exterior_node_tokens" in area_record: - token_key = "exterior_node_tokens" - elif "node_tokens" in area_record: - token_key = "node_tokens" - - polygon_nodes = [ - nusc_map.get("node", node_token) for node_token in area_record[token_key] - ] - - return np.array([(node["x"], node["y"]) for node in polygon_nodes]) - - def extract_vectorized(self, nusc_map: NuScenesMap) -> VectorizedMap: - vec_map = VectorizedMap() - - # Setting the map bounds. - vec_map.max_pt.x, vec_map.max_pt.y, vec_map.max_pt.z = ( - nusc_map.explorer.canvas_max_x, - nusc_map.explorer.canvas_max_y, - 0.0, - ) - vec_map.min_pt.x, vec_map.min_pt.y, vec_map.min_pt.z = ( - nusc_map.explorer.canvas_min_x, - nusc_map.explorer.canvas_min_y, - 0.0, - ) - - overall_pbar = tqdm( - total=len(nusc_map.lane) - + len(nusc_map.drivable_area[0]["polygon_tokens"]) - + len(nusc_map.ped_crossing) - + len(nusc_map.walkway), - desc=f"Getting {nusc_map.map_name} Elements", - position=1, - leave=False, - ) - - for lane_record in nusc_map.lane: - center_pts, left_pts, right_pts = self.extract_lane_and_edges( - nusc_map, lane_record - ) - - lane_record_token: str = lane_record["token"] - - # Adding the element to the map. - new_element: MapElement = vec_map.elements.add() - new_element.id = lane_record_token.encode() - - new_lane: RoadLane = new_element.road_lane - map_utils.populate_lane_polylines(new_lane, center_pts, left_pts, right_pts) - - new_lane.entry_lanes.extend( - lane_id.encode() - for lane_id in nusc_map.get_incoming_lane_ids(lane_record_token) - ) - new_lane.exit_lanes.extend( - lane_id.encode() - for lane_id in nusc_map.get_outgoing_lane_ids(lane_record_token) - ) - - # new_lane.adjacent_lanes_left.append( - # l5_lane.adjacent_lane_change_left.id - # ) - # new_lane.adjacent_lanes_right.append( - # l5_lane.adjacent_lane_change_right.id - # ) - - overall_pbar.update() - - for polygon_token in nusc_map.drivable_area[0]["polygon_tokens"]: - polygon_record = nusc_map.get("polygon", polygon_token) - polygon_pts = self.extract_area(nusc_map, polygon_record) - - # Adding the element to the map. - new_element: MapElement = vec_map.elements.add() - new_element.id = lane_record["token"].encode() - - new_area: RoadArea = new_element.road_area - map_utils.populate_polygon(new_area.exterior_polygon, polygon_pts) - - for hole in polygon_record["holes"]: - polygon_pts = self.extract_area(nusc_map, hole) - new_hole: Polyline = new_area.interior_holes.add() - map_utils.populate_polygon(new_hole, polygon_pts) - - overall_pbar.update() - - for ped_area_record in nusc_map.ped_crossing: - polygon_pts = self.extract_area(nusc_map, ped_area_record) - - # Adding the element to the map. - new_element: MapElement = vec_map.elements.add() - new_element.id = ped_area_record["token"].encode() - - new_crosswalk: PedCrosswalk = new_element.ped_crosswalk - map_utils.populate_polygon(new_crosswalk.polygon, polygon_pts) - - overall_pbar.update() - - for ped_area_record in nusc_map.walkway: - polygon_pts = self.extract_area(nusc_map, ped_area_record) - - # Adding the element to the map. - new_element: MapElement = vec_map.elements.add() - new_element.id = ped_area_record["token"].encode() - - new_walkway: PedWalkway = new_element.ped_walkway - map_utils.populate_polygon(new_walkway.polygon, polygon_pts) - - overall_pbar.update() - - overall_pbar.close() - - return vec_map - def cache_map( self, map_name: str, @@ -498,81 +274,14 @@ def cache_map( map_cache_class: Type[SceneCache], map_params: Dict[str, Any], ) -> None: - resolution: float = map_params["px_per_m"] - nusc_map: NuScenesMap = NuScenesMap( dataroot=self.metadata.data_dir, map_name=map_name ) - if map_params.get("original_format", False): - warnings.warn( - "Using a dataset's original map format is deprecated, and will be removed in the next version of trajdata!", - FutureWarning, - ) - - width_m, height_m = nusc_map.canvas_edge - height_px, width_px = round(height_m * resolution), round( - width_m * resolution - ) + vector_map = VectorMap(map_id=f"{self.name}:{map_name}") + nusc_utils.populate_vector_map(vector_map, nusc_map) - def layer_fn(layer_name: str) -> np.ndarray: - # Getting rid of the channels dim by accessing index [0] - return nusc_map.get_map_mask( - patch_box=None, - patch_angle=0, - layer_names=[layer_name], - canvas_size=(height_px, width_px), - )[0].astype(np.bool) - - map_from_world: np.ndarray = np.array( - [[resolution, 0.0, 0.0], [0.0, resolution, 0.0], [0.0, 0.0, 1.0]] - ) - - layer_names: List[str] = [ - "lane", - "road_segment", - "drivable_area", - "road_divider", - "lane_divider", - "ped_crossing", - "walkway", - ] - map_info: RasterizedMapMetadata = RasterizedMapMetadata( - name=map_name, - shape=(len(layer_names), height_px, width_px), - layers=layer_names, - layer_rgb_groups=([0, 1, 2], [3, 4], [5, 6]), - resolution=resolution, - map_from_world=map_from_world, - ) - - map_cache_class.cache_map_layers( - cache_path, VectorizedMap(), map_info, layer_fn, self.name - ) - else: - vectorized_map: VectorizedMap = self.extract_vectorized(nusc_map) - - pbar_kwargs = {"position": 2, "leave": False} - map_data, map_from_world = map_utils.rasterize_map( - vectorized_map, resolution, **pbar_kwargs - ) - - rasterized_map_info: RasterizedMapMetadata = RasterizedMapMetadata( - name=map_name, - shape=map_data.shape, - layers=["drivable_area", "lane_divider", "ped_area"], - layer_rgb_groups=([0], [1], [2]), - resolution=resolution, - map_from_world=map_from_world, - ) - - rasterized_map_obj: RasterizedMap = RasterizedMap( - rasterized_map_info, map_data - ) - - map_cache_class.cache_map( - cache_path, vectorized_map, rasterized_map_obj, self.name - ) + map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params) def cache_maps( self, diff --git a/src/trajdata/dataset_specific/nusc/nusc_utils.py b/src/trajdata/dataset_specific/nusc/nusc_utils.py index 8cd4cb6..b8888b0 100644 --- a/src/trajdata/dataset_specific/nusc/nusc_utils.py +++ b/src/trajdata/dataset_specific/nusc/nusc_utils.py @@ -1,12 +1,25 @@ -from typing import Any, Dict, Final, List, Union +from typing import Any, Dict, Final, List, Tuple, Union import numpy as np import pandas as pd +from nuscenes.map_expansion import arcline_path_utils +from nuscenes.map_expansion.map_api import NuScenesMap from nuscenes.nuscenes import NuScenes from pyquaternion import Quaternion +from scipy.spatial.distance import cdist +from tqdm import tqdm from trajdata.data_structures import Agent, AgentMetadata, AgentType, FixedExtent, Scene -from trajdata.utils import arr_utils +from trajdata.maps import VectorMap +from trajdata.maps.vec_map_elements import ( + MapElementType, + PedCrosswalk, + PedWalkway, + Polyline, + RoadArea, + RoadLane, +) +from trajdata.utils import arr_utils, map_utils NUSC_DT: Final[float] = 0.5 @@ -233,3 +246,251 @@ def agg_ego_data(nusc_obj: NuScenes, scene: Scene) -> Agent: metadata=ego_metadata, data=ego_data_df, ) + + +def extract_lane_center(nusc_map: NuScenesMap, lane_record) -> np.ndarray: + # Getting the lane center's points. + curr_lane = nusc_map.arcline_path_3.get(lane_record["token"], []) + lane_midline: np.ndarray = np.array( + arcline_path_utils.discretize_lane(curr_lane, resolution_meters=0.5) + )[:, :2] + + # For some reason, nuScenes duplicates a few entries + # (likely how they're building their arcline representation). + # We delete those duplicate entries here. + duplicate_check: np.ndarray = np.where( + np.linalg.norm(np.diff(lane_midline, axis=0, prepend=0), axis=1) < 1e-10 + )[0] + if duplicate_check.size > 0: + lane_midline = np.delete(lane_midline, duplicate_check, axis=0) + + return lane_midline + + +def extract_lane_and_edges( + nusc_map: NuScenesMap, lane_record +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + # Getting the bounding polygon vertices. + lane_polygon_obj = nusc_map.get("polygon", lane_record["polygon_token"]) + polygon_nodes = [ + nusc_map.get("node", node_token) + for node_token in lane_polygon_obj["exterior_node_tokens"] + ] + polygon_pts: np.ndarray = np.array( + [(node["x"], node["y"]) for node in polygon_nodes] + ) + + # Getting the lane center's points. + lane_midline: np.ndarray = extract_lane_center(nusc_map, lane_record) + + # Computing the closest lane center point to each bounding polygon vertex. + closest_midlane_pt: np.ndarray = np.argmin(cdist(polygon_pts, lane_midline), axis=1) + # Computing the local direction of the lane at each lane center point. + direction_vectors: np.ndarray = np.diff( + lane_midline, + axis=0, + prepend=lane_midline[[0]] - (lane_midline[[1]] - lane_midline[[0]]), + ) + + # Selecting the direction vectors at the closest lane center point per polygon vertex. + local_dir_vecs: np.ndarray = direction_vectors[closest_midlane_pt] + # Calculating the vectors from the the closest lane center point per polygon vertex to the polygon vertex. + origin_to_polygon_vecs: np.ndarray = polygon_pts - lane_midline[closest_midlane_pt] + + # Computing the perpendicular dot product. + # See https://www.xarg.org/book/linear-algebra/2d-perp-product/ + # If perp_dot_product < 0, then the associated polygon vertex is + # on the right edge of the lane. + perp_dot_product: np.ndarray = ( + local_dir_vecs[:, 0] * origin_to_polygon_vecs[:, 1] + - local_dir_vecs[:, 1] * origin_to_polygon_vecs[:, 0] + ) + + # Determining which indices are on the right of the lane center. + on_right: np.ndarray = perp_dot_product < 0 + # Determining the boundary between the left/right polygon vertices + # (they will be together in blocks due to the ordering of the polygon vertices). + idx_changes: int = np.where(np.roll(on_right, 1) < on_right)[0].item() + + if idx_changes > 0: + # If the block of left/right points spreads across the bounds of the array, + # roll it until the boundary between left/right points is at index 0. + # This is important so that the following index selection orders points + # without jumps. + polygon_pts = np.roll(polygon_pts, shift=-idx_changes, axis=0) + on_right = np.roll(on_right, shift=-idx_changes) + + left_pts: np.ndarray = polygon_pts[~on_right] + right_pts: np.ndarray = polygon_pts[on_right] + + # Final ordering check, ensuring that left_pts and right_pts can be combined + # into a polygon without the endpoints intersecting. + # Reversing the one lane edge that does not match the ordering of the midline. + if map_utils.endpoints_intersect(left_pts, right_pts): + if not map_utils.order_matches(left_pts, lane_midline): + left_pts = left_pts[::-1] + else: + right_pts = right_pts[::-1] + + # Ensuring that left and right have the same number of points. + # This is necessary, not for data storage but for later rasterization. + if left_pts.shape[0] < right_pts.shape[0]: + left_pts = map_utils.interpolate(left_pts, num_pts=right_pts.shape[0]) + elif right_pts.shape[0] < left_pts.shape[0]: + right_pts = map_utils.interpolate(right_pts, num_pts=left_pts.shape[0]) + + return ( + lane_midline, + left_pts, + right_pts, + ) + + +def extract_area(nusc_map: NuScenesMap, area_record) -> np.ndarray: + token_key: str + if "exterior_node_tokens" in area_record: + token_key = "exterior_node_tokens" + elif "node_tokens" in area_record: + token_key = "node_tokens" + + polygon_nodes = [ + nusc_map.get("node", node_token) for node_token in area_record[token_key] + ] + + return np.array([(node["x"], node["y"]) for node in polygon_nodes]) + + +def populate_vector_map(vector_map: VectorMap, nusc_map: NuScenesMap) -> None: + # Setting the map bounds. + vector_map.extent = np.array( + [ + nusc_map.explorer.canvas_min_x, + nusc_map.explorer.canvas_min_y, + 0.0, + nusc_map.explorer.canvas_max_x, + nusc_map.explorer.canvas_max_y, + 0.0, + ] + ) + + overall_pbar = tqdm( + total=len(nusc_map.lane) + + len(nusc_map.lane_connector) + + len(nusc_map.drivable_area) + + len(nusc_map.ped_crossing) + + len(nusc_map.walkway), + desc=f"Getting {nusc_map.map_name} Elements", + position=1, + leave=False, + ) + + for lane_record in nusc_map.lane: + center_pts, left_pts, right_pts = extract_lane_and_edges(nusc_map, lane_record) + + lane_record_token: str = lane_record["token"] + + new_lane = RoadLane( + id=lane_record_token, + center=Polyline(center_pts), + left_edge=Polyline(left_pts), + right_edge=Polyline(right_pts), + ) + + for lane_id in nusc_map.get_incoming_lane_ids(lane_record_token): + # Need to do this because some incoming/outgoing lane_connector IDs + # do not exist as lane_connectors... + if lane_id in nusc_map._token2ind["lane_connector"]: + new_lane.prev_lanes.add(lane_id) + + for lane_id in nusc_map.get_outgoing_lane_ids(lane_record_token): + # Need to do this because some incoming/outgoing lane_connector IDs + # do not exist as lane_connectors... + if lane_id in nusc_map._token2ind["lane_connector"]: + new_lane.next_lanes.add(lane_id) + + # new_lane.adjacent_lanes_left.append( + # l5_lane.adjacent_lane_change_left.id + # ) + # new_lane.adjacent_lanes_right.append( + # l5_lane.adjacent_lane_change_right.id + # ) + + # Adding the element to the map. + vector_map.add_map_element(new_lane) + overall_pbar.update() + + for lane_record in nusc_map.lane_connector: + # Unfortunately lane connectors in nuScenes have very simple exterior + # polygons which make extracting their edges quite difficult, so we + # only extract the centerline. + center_pts = extract_lane_center(nusc_map, lane_record) + + lane_record_token: str = lane_record["token"] + + # Adding the element to the map. + new_lane = RoadLane( + id=lane_record_token, + center=Polyline(center_pts), + ) + + new_lane.prev_lanes.update(nusc_map.get_incoming_lane_ids(lane_record_token)) + new_lane.next_lanes.update(nusc_map.get_outgoing_lane_ids(lane_record_token)) + + # new_lane.adjacent_lanes_left.append( + # l5_lane.adjacent_lane_change_left.id + # ) + # new_lane.adjacent_lanes_right.append( + # l5_lane.adjacent_lane_change_right.id + # ) + + # Adding the element to the map. + vector_map.add_map_element(new_lane) + overall_pbar.update() + + for drivable_area in nusc_map.drivable_area: + for polygon_token in drivable_area["polygon_tokens"]: + if ( + polygon_token is None + and str(None) in vector_map.elements[MapElementType.ROAD_AREA] + ): + # See below, but essentially nuScenes has two None polygon_tokens + # back-to-back, so we don't need the second one. + continue + + polygon_record = nusc_map.get("polygon", polygon_token) + polygon_pts = extract_area(nusc_map, polygon_record) + + # NOTE: nuScenes has some polygon_tokens that are None, although that + # doesn't stop the above get(...) function call so it's fine, + # just have to be mindful of this when creating the id. + new_road_area = RoadArea( + id=str(polygon_token), exterior_polygon=Polyline(polygon_pts) + ) + + for hole in polygon_record["holes"]: + polygon_pts = extract_area(nusc_map, hole) + new_road_area.interior_holes.append(Polyline(polygon_pts)) + + # Adding the element to the map. + vector_map.add_map_element(new_road_area) + overall_pbar.update() + + for ped_area_record in nusc_map.ped_crossing: + polygon_pts = extract_area(nusc_map, ped_area_record) + + # Adding the element to the map. + vector_map.add_map_element( + PedCrosswalk(id=ped_area_record["token"], polygon=Polyline(polygon_pts)) + ) + overall_pbar.update() + + for ped_area_record in nusc_map.walkway: + polygon_pts = extract_area(nusc_map, ped_area_record) + + # Adding the element to the map. + vector_map.add_map_element( + PedWalkway(id=ped_area_record["token"], polygon=Polyline(polygon_pts)) + ) + overall_pbar.update() + + overall_pbar.close() diff --git a/src/trajdata/dataset_specific/raw_dataset.py b/src/trajdata/dataset_specific/raw_dataset.py index b26f392..57188d3 100644 --- a/src/trajdata/dataset_specific/raw_dataset.py +++ b/src/trajdata/dataset_specific/raw_dataset.py @@ -81,6 +81,13 @@ def get_scene(self, scene_info: SceneMetadata) -> Scene: def get_agent_info( self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache] ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]: + """ + Get frame-level information from source dataset, caching it + to cache_path. + + Always called after cache_maps, can load map if needed + to associate map information to positions. + """ raise NotImplementedError() def cache_maps( @@ -90,6 +97,9 @@ def cache_maps( map_params: Dict[str, Any], ) -> None: """ - resolution is in pixels per meter. + Get static, scene-level info from the source dataset, caching it + to cache_path. (Primarily this is info needed to construct VectorMap) + + Resolution is in pixels per meter. """ raise NotImplementedError() diff --git a/src/trajdata/dataset_specific/scene_records.py b/src/trajdata/dataset_specific/scene_records.py index 60d3abb..72b7d84 100644 --- a/src/trajdata/dataset_specific/scene_records.py +++ b/src/trajdata/dataset_specific/scene_records.py @@ -21,3 +21,12 @@ class LyftSceneRecord(NamedTuple): name: str length: str data_idx: int + + +class NuPlanSceneRecord(NamedTuple): + name: str + location: str + length: str + split: str + # desc: str + data_idx: int diff --git a/src/trajdata/maps/__init__.py b/src/trajdata/maps/__init__.py index ea04c8d..2ef0bf6 100644 --- a/src/trajdata/maps/__init__.py +++ b/src/trajdata/maps/__init__.py @@ -1,2 +1,4 @@ -from .map import RasterizedMap, RasterizedMapMetadata -from .map_patch import RasterizedMapPatch +from .map_api import MapAPI +from .raster_map import RasterizedMap, RasterizedMapMetadata, RasterizedMapPatch +from .traffic_light_status import TrafficLightStatus +from .vec_map import VectorMap diff --git a/src/trajdata/maps/lane_route.py b/src/trajdata/maps/lane_route.py new file mode 100644 index 0000000..fe70388 --- /dev/null +++ b/src/trajdata/maps/lane_route.py @@ -0,0 +1,7 @@ +from dataclasses import dataclass +from typing import Set + + +@dataclass +class LaneRoute: + lane_idxs: Set[int] diff --git a/src/trajdata/maps/map_api.py b/src/trajdata/maps/map_api.py new file mode 100644 index 0000000..c742770 --- /dev/null +++ b/src/trajdata/maps/map_api.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from trajdata.maps.map_kdtree import MapElementKDTree + from trajdata.caching.scene_cache import SceneCache + +from pathlib import Path +from typing import Dict + +from trajdata.maps.vec_map import VectorMap +from trajdata.proto.vectorized_map_pb2 import VectorizedMap +from trajdata.utils import map_utils + + +class MapAPI: + def __init__(self, unified_cache_path: Path) -> None: + self.unified_cache_path: Path = unified_cache_path + self.maps: Dict[str, VectorMap] = dict() + + def get_map( + self, map_id: str, scene_cache: Optional[SceneCache] = None, **kwargs + ) -> VectorMap: + if map_id not in self.maps: + env_name, map_name = map_id.split(":") + env_maps_path: Path = self.unified_cache_path / env_name / "maps" + stored_vec_map: VectorizedMap = map_utils.load_vector_map( + env_maps_path / f"{map_name}.pb" + ) + + vec_map: VectorMap = VectorMap.from_proto(stored_vec_map, **kwargs) + vec_map.search_kdtrees: Dict[ + str, MapElementKDTree + ] = map_utils.load_kdtrees(env_maps_path / f"{map_name}_kdtrees.dill") + + self.maps[map_id] = vec_map + + if scene_cache is not None: + self.maps[map_id].associate_scene_data( + scene_cache.get_traffic_light_status_dict() + ) + + return self.maps[map_id] diff --git a/src/trajdata/maps/map_kdtree.py b/src/trajdata/maps/map_kdtree.py new file mode 100644 index 0000000..1b200ef --- /dev/null +++ b/src/trajdata/maps/map_kdtree.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from trajdata.maps.vec_map import VectorMap + +from typing import Optional + +import numpy as np +from scipy.spatial import KDTree +from tqdm import tqdm + +from trajdata.maps.vec_map_elements import MapElement, MapElementType, Polyline + + +class MapElementKDTree: + """ + Constructs a KDTree of MapElements and exposes fast lookup functions. + + Inheriting classes need to implement the _extra_points function that defines for a MapElement + the coordinates we want to store in the KDTree. + """ + + def __init__(self, vector_map: VectorMap) -> None: + # Build kd-tree + self.kdtree, self.polyline_inds = self._build_kdtree(vector_map) + + def _build_kdtree(self, vector_map: VectorMap): + polylines = [] + polyline_inds = [] + + map_elem: MapElement + for map_elem in tqdm( + vector_map.iter_elems(), + desc=f"Building K-D Trees", + leave=False, + total=len(vector_map), + ): + points = self._extract_points(map_elem) + if points is not None: + polyline_inds.extend([len(polylines)] * points.shape[0]) + + # Apply any map offsets to ensure we're in the same coordinate area as the + # original world map. + polylines.append(points) + + points = np.concatenate(polylines, axis=0) + polyline_inds = np.array(polyline_inds) + + kdtree = KDTree(points) + return kdtree, polyline_inds + + def _extract_points(self, map_element: MapElement) -> Optional[np.ndarray]: + """Defines the coordinates we want to store in the KDTree for a MapElement. + Args: + map_element (MapElement): the MapElement to store in the KDTree. + Returns: + Optional[np.ndarray]: coordinates based on which we can search the KDTree, or None. + If None, the MapElement will not be stored. + """ + raise NotImplementedError() + + def closest_point(self, query_points: np.ndarray) -> np.ndarray: + """Find the closest KDTree points to (a batch of) query points. + + Args: + query_points: np.ndarray of shape (..., data_dim). + + Return: + np.ndarray of shape (..., data_dim), the KDTree points closest to query_point. + """ + _, data_inds = self.kdtree.query(query_points, k=1) + pts = self.kdtree.data[data_inds] + return pts + + def closest_polyline_ind(self, query_points: np.ndarray) -> np.ndarray: + """Find the index of the closest polyline(s) in self.polylines.""" + _, data_ind = self.kdtree.query(query_points, k=1) + return self.polyline_inds[data_ind] + + def polyline_inds_in_range(self, point: np.ndarray, range: float) -> np.ndarray: + """Find the index of polylines in self.polylines within 'range' distance to 'point'.""" + data_inds = self.kdtree.query_ball_point(point, range) + return np.unique(self.polyline_inds[data_inds], axis=0) + + +class LaneCenterKDTree(MapElementKDTree): + """KDTree for lane center polylines.""" + + def __init__( + self, vector_map: VectorMap, max_segment_len: Optional[float] = None + ) -> None: + """ + Args: + vec_map: the VectorizedMap object to build the KDTree for + max_segment_len (float, optional): if specified, we will insert extra points into the KDTree + such that all polyline segments are shorter then max_segment_len. + """ + self.max_segment_len = max_segment_len + super().__init__(vector_map) + + def _extract_points(self, map_element: MapElement) -> Optional[np.ndarray]: + if map_element.elem_type == MapElementType.ROAD_LANE: + pts: Polyline = map_element.center + if self.max_segment_len is not None: + pts = pts.interpolate(max_dist=self.max_segment_len) + + return pts.points + else: + return None diff --git a/src/trajdata/maps/map_patch.py b/src/trajdata/maps/map_patch.py deleted file mode 100644 index 4416ae8..0000000 --- a/src/trajdata/maps/map_patch.py +++ /dev/null @@ -1,19 +0,0 @@ -import numpy as np - - -class RasterizedMapPatch: - def __init__( - self, - data: np.ndarray, - rot_angle: float, - crop_size: int, - resolution: float, - raster_from_world_tf: np.ndarray, - has_data: bool, - ) -> None: - self.data = data - self.rot_angle = rot_angle - self.crop_size = crop_size - self.resolution = resolution - self.raster_from_world_tf = raster_from_world_tf - self.has_data = has_data diff --git a/src/trajdata/maps/map_utils.py b/src/trajdata/maps/map_utils.py deleted file mode 100644 index b40e14d..0000000 --- a/src/trajdata/maps/map_utils.py +++ /dev/null @@ -1,291 +0,0 @@ -from math import ceil -from typing import Any, Dict, Final, List, Optional, Tuple - -import cv2 -import numpy as np -from tqdm import tqdm - -from trajdata.proto.vectorized_map_pb2 import ( - MapElement, - Polyline, - RoadLane, - VectorizedMap, -) - -# Sub-pixel drawing precision constants. -# See https://github.com/woven-planet/l5kit/blob/master/l5kit/l5kit/rasterization/semantic_rasterizer.py#L16 -CV2_SUB_VALUES = {"shift": 9, "lineType": cv2.LINE_AA} -CV2_SHIFT_VALUE = 2 ** CV2_SUB_VALUES["shift"] - -MM_PER_M: Final[float] = 1000 - - -def cv2_subpixel(coords: np.ndarray) -> np.ndarray: - """ - Cast coordinates to numpy.int but keep fractional part by previously multiplying by 2**CV2_SHIFT - cv2 calls will use shift to restore original values with higher precision - - Args: - coords (np.ndarray): XY coords as float - - Returns: - np.ndarray: XY coords as int for cv2 shift draw - """ - return (coords * CV2_SHIFT_VALUE).astype(np.int) - - -def decompress_values(data: np.ndarray) -> np.ndarray: - # From https://github.com/woven-planet/l5kit/blob/master/l5kit/l5kit/data/proto/road_network.proto#L446 - # The delta for the first point is just its coordinates tuple, i.e. it is a "delta" from - # the origin. For subsequent points, this field stores the difference between the point's - # coordinates and the previous point's coordinates. This is for representation efficiency. - return np.cumsum(data, axis=0, dtype=np.float) / MM_PER_M - - -def compress_values(data: np.ndarray) -> np.ndarray: - return (np.diff(data, axis=0, prepend=0.0) * MM_PER_M).astype(np.int32) - - -def populate_lane_polylines( - new_lane: RoadLane, - midlane_pts: np.ndarray, - left_pts: np.ndarray, - right_pts: np.ndarray, -) -> None: - """Fill a Lane object's polyline attributes. - All points should be in world coordinates. - - Args: - new_lane (Lane): _description_ - midlane_pts (np.ndarray): _description_ - left_pts (np.ndarray): _description_ - right_pts (np.ndarray): _description_ - """ - compressed_mid_pts: np.ndarray = compress_values(midlane_pts) - compressed_left_pts: np.ndarray = compress_values(left_pts) - compressed_right_pts: np.ndarray = compress_values(right_pts) - - new_lane.center.dx_mm.extend(compressed_mid_pts[:, 0].tolist()) - new_lane.center.dy_mm.extend(compressed_mid_pts[:, 1].tolist()) - - new_lane.left_boundary.dx_mm.extend(compressed_left_pts[:, 0].tolist()) - new_lane.left_boundary.dy_mm.extend(compressed_left_pts[:, 1].tolist()) - - new_lane.right_boundary.dx_mm.extend(compressed_right_pts[:, 0].tolist()) - new_lane.right_boundary.dy_mm.extend(compressed_right_pts[:, 1].tolist()) - - if compressed_mid_pts.shape[-1] == 3: - new_lane.center.dz_mm.extend(compressed_mid_pts[:, 2].tolist()) - new_lane.left_boundary.dz_mm.extend(compressed_left_pts[:, 2].tolist()) - new_lane.right_boundary.dz_mm.extend(compressed_right_pts[:, 2].tolist()) - - -def populate_polygon( - polygon: Polyline, - polygon_pts: np.ndarray, -) -> None: - """Fill a Crosswalk object's polygon attribute. - All points should be in world coordinates. - - Args: - new_crosswalk (Lane): _description_ - polygon_pts (np.ndarray): _description_ - """ - - compressed_pts: np.ndarray = compress_values(polygon_pts) - - polygon.dx_mm.extend(compressed_pts[:, 0].tolist()) - polygon.dy_mm.extend(compressed_pts[:, 1].tolist()) - - if compressed_pts.shape[-1] == 3: - polygon.dz_mm.extend(compressed_pts[:, 2].tolist()) - - -def proto_to_np(polyline: Polyline) -> np.ndarray: - dx: np.ndarray = np.asarray(polyline.dx_mm) - dy: np.ndarray = np.asarray(polyline.dy_mm) - - if len(polyline.dz_mm) > 0: - dz: np.ndarray = np.asarray(polyline.dz_mm) - pts: np.ndarray = np.stack([dx, dy, dz], axis=1) - else: - pts: np.ndarray = np.stack([dx, dy], axis=1) - - return decompress_values(pts) - - -def transform_points(points: np.ndarray, transf_mat: np.ndarray): - n_dim = points.shape[-1] - return points @ transf_mat[:n_dim, :n_dim] + transf_mat[:n_dim, -1] - - -def interpolate(pts: np.ndarray, num_pts: int) -> np.ndarray: - """ - Interpolate points based on cumulative distances from the first one. In particular, - interpolate using a variable step such that we always get step values. - - Args: - xyz (np.ndarray): XYZ coords. - num_pts (int): How many points to interpolate to. - - Returns: - np.ndarray: The new interpolated coordinates. - """ - cum_dist = np.cumsum(np.linalg.norm(np.diff(pts, axis=0), axis=-1)) - cum_dist = np.insert(cum_dist, 0, 0) - - assert num_pts > 1, f"num_pts must be at least 2, but got {num_pts}" - steps = np.linspace(cum_dist[0], cum_dist[-1], num_pts) - - xyz_inter = np.empty((len(steps), pts.shape[-1]), dtype=pts.dtype) - xyz_inter[:, 0] = np.interp(steps, xp=cum_dist, fp=pts[:, 0]) - xyz_inter[:, 1] = np.interp(steps, xp=cum_dist, fp=pts[:, 1]) - if pts.shape[-1] == 3: - xyz_inter[:, 2] = np.interp(steps, xp=cum_dist, fp=pts[:, 2]) - - return xyz_inter - - -def rasterize_map( - vec_map: VectorizedMap, resolution: float, **pbar_kwargs -) -> np.ndarray: - """Renders the semantic map at the given resolution. - - Args: - vec_map (VectorizedMap): _description_ - resolution (float): The rasterized image's resolution in pixels per meter. - - Returns: - np.ndarray: The rasterized RGB image. - """ - world_center_m: Tuple[float, float] = ( - (vec_map.max_pt.x + vec_map.min_pt.x) / 2, - (vec_map.max_pt.y + vec_map.min_pt.y) / 2, - ) - - raster_size_x: int = ceil((vec_map.max_pt.x - vec_map.min_pt.x) * resolution) - raster_size_y: int = ceil((vec_map.max_pt.y - vec_map.min_pt.y) * resolution) - - raster_from_local: np.ndarray = np.array( - [ - [resolution, 0, raster_size_x / 2], - [0, resolution, raster_size_y / 2], - [0, 0, 1], - ] - ) - - # Compute pose from its position and rotation - pose_from_world: np.ndarray = np.array( - [ - [1, 0, -world_center_m[0]], - [0, 1, -world_center_m[1]], - [0, 0, 1], - ] - ) - - raster_from_world: np.ndarray = raster_from_local @ pose_from_world - - lane_area_img: np.ndarray = np.zeros( - shape=(raster_size_y, raster_size_x, 3), dtype=np.uint8 - ) - lane_line_img: np.ndarray = np.zeros( - shape=(raster_size_y, raster_size_x, 3), dtype=np.uint8 - ) - ped_area_img: np.ndarray = np.zeros( - shape=(raster_size_y, raster_size_x, 3), dtype=np.uint8 - ) - - map_elem: MapElement - for map_elem in tqdm( - vec_map.elements, - desc=f"Rasterizing Map at {resolution:.2f} px/m", - **pbar_kwargs, - ): - if map_elem.HasField("road_lane"): - left_pts: np.ndarray = proto_to_np(map_elem.road_lane.left_boundary) - right_pts: np.ndarray = proto_to_np(map_elem.road_lane.right_boundary) - - lane_area: np.ndarray = cv2_subpixel( - transform_points( - np.concatenate([left_pts[:, :2], right_pts[::-1, :2]], axis=0), - raster_from_world, - ) - ) - - # Need to for-loop because doing it all at once can make holes. - cv2.fillPoly( - img=lane_area_img, - pts=[lane_area], - color=(255, 0, 0), - **CV2_SUB_VALUES, - ) - - # Drawing lane lines. - cv2.polylines( - img=lane_line_img, - pts=lane_area.reshape((2, -1, 2)), - isClosed=False, - color=(0, 255, 0), - **CV2_SUB_VALUES, - ) - - elif map_elem.HasField("road_area"): - xyz_pts: np.ndarray = proto_to_np(map_elem.road_area.exterior_polygon) - road_area: np.ndarray = cv2_subpixel( - transform_points(xyz_pts[:, :2], raster_from_world) - ) - - # Drawing general road areas. - cv2.fillPoly( - img=lane_area_img, - pts=[road_area], - color=(255, 0, 0), - **CV2_SUB_VALUES, - ) - - for interior_hole in map_elem.road_area.interior_holes: - xyz_pts: np.ndarray = proto_to_np(interior_hole) - road_area: np.ndarray = cv2_subpixel( - transform_points(xyz_pts[:, :2], raster_from_world) - ) - - # Removing holes. - cv2.fillPoly( - img=lane_area_img, - pts=[road_area], - color=(0, 0, 0), - **CV2_SUB_VALUES, - ) - - elif map_elem.HasField("ped_crosswalk"): - xyz_pts: np.ndarray = proto_to_np(map_elem.ped_crosswalk.polygon) - crosswalk_area: np.ndarray = cv2_subpixel( - transform_points(xyz_pts[:, :2], raster_from_world) - ) - - # Drawing crosswalks. - cv2.fillPoly( - img=ped_area_img, - pts=[crosswalk_area], - color=(0, 0, 255), - **CV2_SUB_VALUES, - ) - - elif map_elem.HasField("ped_walkway"): - xyz_pts: np.ndarray = proto_to_np(map_elem.ped_walkway.polygon) - walkway_area: np.ndarray = cv2_subpixel( - transform_points(xyz_pts[:, :2], raster_from_world) - ) - - # Drawing walkways. - cv2.fillPoly( - img=ped_area_img, - pts=[walkway_area], - color=(0, 0, 255), - **CV2_SUB_VALUES, - ) - - map_img: np.ndarray = (lane_area_img + lane_line_img + ped_area_img).astype( - np.float32 - ) / 255 - return map_img.transpose(2, 0, 1), raster_from_world diff --git a/src/trajdata/maps/map.py b/src/trajdata/maps/raster_map.py similarity index 71% rename from src/trajdata/maps/map.py rename to src/trajdata/maps/raster_map.py index ad0ee07..4309896 100644 --- a/src/trajdata/maps/map.py +++ b/src/trajdata/maps/raster_map.py @@ -9,14 +9,14 @@ class RasterizedMapMetadata: def __init__( self, name: str, - shape: Tuple[int, int], + shape: Tuple[int, int, int], layers: List[str], layer_rgb_groups: Tuple[List[int], List[int], List[int]], resolution: float, # px/m map_from_world: np.ndarray, # Transformation from world coordinates [m] to map coordinates [px] ) -> None: self.name: str = name - self.shape: Tuple[int, int] = shape + self.shape: Tuple[int, int, int] = shape self.layers: List[str] = layers self.layer_rgb_groups: Tuple[List[int], List[int], List[int]] = layer_rgb_groups self.resolution: float = resolution @@ -34,7 +34,7 @@ def __init__( self.data: np.ndarray = data @property - def shape(self) -> Tuple[int, ...]: + def shape(self) -> Tuple[int, int, int]: return self.data.shape @staticmethod @@ -53,3 +53,21 @@ def to_img( ], dim=-1, ).numpy() + + +class RasterizedMapPatch: + def __init__( + self, + data: np.ndarray, + rot_angle: float, + crop_size: int, + resolution: float, + raster_from_world_tf: np.ndarray, + has_data: bool, + ) -> None: + self.data = data + self.rot_angle = rot_angle + self.crop_size = crop_size + self.resolution = resolution + self.raster_from_world_tf = raster_from_world_tf + self.has_data = has_data diff --git a/src/trajdata/maps/traffic_light_status.py b/src/trajdata/maps/traffic_light_status.py new file mode 100644 index 0000000..260fe46 --- /dev/null +++ b/src/trajdata/maps/traffic_light_status.py @@ -0,0 +1,8 @@ +from enum import IntEnum + + +class TrafficLightStatus(IntEnum): + NO_DATA = -1 + UNKNOWN = 0 + GREEN = 1 + RED = 2 diff --git a/src/trajdata/maps/vec_map.py b/src/trajdata/maps/vec_map.py new file mode 100644 index 0000000..26fa08f --- /dev/null +++ b/src/trajdata/maps/vec_map.py @@ -0,0 +1,556 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from trajdata.maps.map_kdtree import MapElementKDTree, LaneCenterKDTree + +from collections import defaultdict +from dataclasses import dataclass, field +from math import ceil +from typing import ( + DefaultDict, + Dict, + Iterator, + List, + Optional, + Set, + Tuple, + Union, + overload, +) + +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.axes import Axes +from tqdm import tqdm + +import trajdata.proto.vectorized_map_pb2 as map_proto +from trajdata.maps.map_kdtree import LaneCenterKDTree +from trajdata.maps.traffic_light_status import TrafficLightStatus +from trajdata.maps.vec_map_elements import ( + MapElement, + MapElementType, + PedCrosswalk, + PedWalkway, + Polyline, + RoadArea, + RoadLane, +) +from trajdata.utils import map_utils, raster_utils + + +@dataclass(repr=False) +class VectorMap: + map_id: str + extent: Optional[ + np.ndarray + ] = None # extent is [min_x, min_y, min_z, max_x, max_y, max_z] + elements: DefaultDict[MapElementType, Dict[str, MapElement]] = field( + default_factory=lambda: defaultdict(dict) + ) + search_kdtrees: Optional[Dict[MapElementType, MapElementKDTree]] = None + traffic_light_status: Optional[Dict[Tuple[int, int], TrafficLightStatus]] = None + + def __post_init__(self) -> None: + self.env_name, self.map_name = self.map_id.split(":") + + self.lanes: Optional[List[RoadLane]] = None + if MapElementType.ROAD_LANE in self.elements: + self.lanes = list(self.elements[MapElementType.ROAD_LANE].values()) + + def add_map_element(self, map_elem: MapElement) -> None: + self.elements[map_elem.elem_type][map_elem.id] = map_elem + + def compute_search_indices(self) -> None: + self.search_kdtrees = {MapElementType.ROAD_LANE: LaneCenterKDTree(self)} + + def iter_elems(self) -> Iterator[MapElement]: + for elems_dict in self.elements.values(): + for elem in elems_dict.values(): + yield elem + + def get_road_lane(self, lane_id: str) -> RoadLane: + return self.elements[MapElementType.ROAD_LANE][lane_id] + + def __len__(self) -> int: + return sum(len(elems_dict) for elems_dict in self.elements.values()) + + def _write_road_lanes( + self, vectorized_map: map_proto.VectorizedMap, shifted_origin: np.ndarray + ) -> None: + road_lane: RoadLane + for elem_id, road_lane in self.elements[MapElementType.ROAD_LANE].items(): + new_element: map_proto.MapElement = vectorized_map.elements.add() + new_element.id = elem_id.encode() + + new_lane: map_proto.RoadLane = new_element.road_lane + map_utils.populate_lane_polylines(new_lane, road_lane, shifted_origin) + + new_lane.entry_lanes.extend( + [lane_id.encode() for lane_id in road_lane.prev_lanes] + ) + new_lane.exit_lanes.extend( + [lane_id.encode() for lane_id in road_lane.next_lanes] + ) + + new_lane.adjacent_lanes_left.extend( + [lane_id.encode() for lane_id in road_lane.adj_lanes_left] + ) + new_lane.adjacent_lanes_right.extend( + [lane_id.encode() for lane_id in road_lane.adj_lanes_right] + ) + + def _write_road_areas( + self, vectorized_map: map_proto.VectorizedMap, shifted_origin: np.ndarray + ) -> None: + road_area: RoadArea + for elem_id, road_area in self.elements[MapElementType.ROAD_AREA].items(): + new_element: map_proto.MapElement = vectorized_map.elements.add() + new_element.id = elem_id.encode() + + new_area: map_proto.RoadArea = new_element.road_area + map_utils.populate_polygon( + new_area.exterior_polygon, + road_area.exterior_polygon.xyz, + shifted_origin, + ) + + hole: Polyline + for hole in road_area.interior_holes: + new_hole: map_proto.Polyline = new_area.interior_holes.add() + map_utils.populate_polygon( + new_hole, + hole.xyz, + shifted_origin, + ) + + def _write_ped_crosswalks( + self, vectorized_map: map_proto.VectorizedMap, shifted_origin: np.ndarray + ) -> None: + ped_crosswalk: PedCrosswalk + for elem_id, ped_crosswalk in self.elements[ + MapElementType.PED_CROSSWALK + ].items(): + new_element: map_proto.MapElement = vectorized_map.elements.add() + new_element.id = elem_id.encode() + + new_crosswalk: map_proto.PedCrosswalk = new_element.ped_crosswalk + map_utils.populate_polygon( + new_crosswalk.polygon, + ped_crosswalk.polygon.xyz, + shifted_origin, + ) + + def _write_ped_walkways( + self, vectorized_map: map_proto.VectorizedMap, shifted_origin: np.ndarray + ) -> None: + ped_walkway: PedWalkway + for elem_id, ped_walkway in self.elements[MapElementType.PED_WALKWAY].items(): + new_element: map_proto.MapElement = vectorized_map.elements.add() + new_element.id = elem_id.encode() + + new_walkway: map_proto.PedWalkway = new_element.ped_walkway + map_utils.populate_polygon( + new_walkway.polygon, + ped_walkway.polygon.xyz, + shifted_origin, + ) + + def to_proto(self) -> map_proto.VectorizedMap: + output_map = map_proto.VectorizedMap() + output_map.name = self.map_id + + ( + output_map.min_pt.x, + output_map.min_pt.y, + output_map.min_pt.z, + output_map.max_pt.x, + output_map.max_pt.y, + output_map.max_pt.z, + ) = self.extent + + shifted_origin: np.ndarray = self.extent[:3] + ( + output_map.shifted_origin.x, + output_map.shifted_origin.y, + output_map.shifted_origin.z, + ) = shifted_origin + + # Populating the elements in the vectorized map protobuf. + self._write_road_lanes(output_map, shifted_origin) + self._write_road_areas(output_map, shifted_origin) + self._write_ped_crosswalks(output_map, shifted_origin) + self._write_ped_walkways(output_map, shifted_origin) + + return output_map + + @classmethod + def from_proto(cls, vec_map: map_proto.VectorizedMap, **kwargs): + # Options for which map elements to include. + incl_road_lanes: bool = kwargs.get("incl_road_lanes", True) + incl_road_areas: bool = kwargs.get("incl_road_areas", False) + incl_ped_crosswalks: bool = kwargs.get("incl_ped_crosswalks", False) + incl_ped_walkways: bool = kwargs.get("incl_ped_walkways", False) + + # Add any map offset in case the map origin was shifted for storage efficiency. + shifted_origin: np.ndarray = np.array( + [ + vec_map.shifted_origin.x, + vec_map.shifted_origin.y, + vec_map.shifted_origin.z, + 0.0, # Some polylines also have heading so we're adding + # this (zero) coordinate to account for that. + ] + ) + + map_elem_dict: Dict[str, Dict[str, MapElement]] = defaultdict(dict) + + map_elem: MapElement + for map_elem in vec_map.elements: + elem_id: str = map_elem.id.decode() + if incl_road_lanes and map_elem.HasField("road_lane"): + road_lane_obj: map_proto.RoadLane = map_elem.road_lane + + center_pl: Polyline = Polyline( + map_utils.proto_to_np(road_lane_obj.center) + shifted_origin + ) + + # We do not care for the heading of the left and right edges + # (only the center matters). + left_pl: Optional[Polyline] = None + if road_lane_obj.HasField("left_boundary"): + left_pl = Polyline( + map_utils.proto_to_np( + road_lane_obj.left_boundary, incl_heading=False + ) + + shifted_origin[:3] + ) + + right_pl: Optional[Polyline] = None + if road_lane_obj.HasField("right_boundary"): + right_pl = Polyline( + map_utils.proto_to_np( + road_lane_obj.right_boundary, incl_heading=False + ) + + shifted_origin[:3] + ) + + adj_lanes_left: Set[str] = set( + [iden.decode() for iden in road_lane_obj.adjacent_lanes_left] + ) + adj_lanes_right: Set[str] = set( + [iden.decode() for iden in road_lane_obj.adjacent_lanes_right] + ) + + next_lanes: Set[str] = set( + [iden.decode() for iden in road_lane_obj.exit_lanes] + ) + prev_lanes: Set[str] = set( + [iden.decode() for iden in road_lane_obj.entry_lanes] + ) + + # Double-using the connectivity attributes for lane IDs now (will + # replace them with Lane objects after all Lane objects have been created). + curr_lane = RoadLane( + elem_id, + center_pl, + left_pl, + right_pl, + adj_lanes_left, + adj_lanes_right, + next_lanes, + prev_lanes, + ) + map_elem_dict[MapElementType.ROAD_LANE][elem_id] = curr_lane + + elif incl_road_areas and map_elem.HasField("road_area"): + road_area_obj: map_proto.RoadArea = map_elem.road_area + + exterior: Polyline = Polyline( + map_utils.proto_to_np( + road_area_obj.exterior_polygon, incl_heading=False + ) + + shifted_origin[:3] + ) + + interior_holes: List[Polyline] = list() + interior_hole: map_proto.Polyline + for interior_hole in road_area_obj.interior_holes: + interior_holes.append( + Polyline( + map_utils.proto_to_np(interior_hole, incl_heading=False) + + shifted_origin[:3] + ) + ) + + curr_area = RoadArea(elem_id, exterior, interior_holes) + map_elem_dict[MapElementType.ROAD_AREA][elem_id] = curr_area + + elif incl_ped_crosswalks and map_elem.HasField("ped_crosswalk"): + ped_crosswalk_obj: map_proto.PedCrosswalk = map_elem.ped_crosswalk + + polygon_vertices: Polyline = Polyline( + map_utils.proto_to_np(ped_crosswalk_obj.polygon, incl_heading=False) + + shifted_origin[:3] + ) + + curr_area = PedCrosswalk(elem_id, polygon_vertices) + map_elem_dict[MapElementType.PED_CROSSWALK][elem_id] = curr_area + + elif incl_ped_walkways and map_elem.HasField("ped_walkway"): + ped_walkway_obj: map_proto.PedCrosswalk = map_elem.ped_walkway + + polygon_vertices: Polyline = Polyline( + map_utils.proto_to_np(ped_walkway_obj.polygon, incl_heading=False) + + shifted_origin[:3] + ) + + curr_area = PedWalkway(elem_id, polygon_vertices) + map_elem_dict[MapElementType.PED_WALKWAY][elem_id] = curr_area + + return cls( + map_id=vec_map.name, + extent=np.array( + [ + vec_map.min_pt.x, + vec_map.min_pt.y, + vec_map.min_pt.z, + vec_map.max_pt.x, + vec_map.max_pt.y, + vec_map.max_pt.z, + ] + ), + elements=map_elem_dict, + search_kdtrees=None, + traffic_light_status=None, + ) + + def associate_scene_data( + self, traffic_light_status_dict: Dict[Tuple[int, int], TrafficLightStatus] + ) -> None: + """Associates vector map with scene-specific data like traffic light information""" + self.traffic_light_status = traffic_light_status_dict + + def get_closest_lane(self, xyzh: np.ndarray) -> RoadLane: + lane_kdtree: LaneCenterKDTree = self.search_kdtrees[MapElementType.ROAD_LANE] + return self.lanes[lane_kdtree.closest_polyline_ind(xyzh)] + + def get_lanes_within(self, xyzh: np.ndarray, dist: float) -> List[RoadLane]: + lane_kdtree: LaneCenterKDTree = self.search_kdtrees[MapElementType.ROAD_LANE] + return [ + self.lanes[idx] for idx in lane_kdtree.polyline_inds_in_range(xyzh, dist) + ] + + def get_traffic_light_status( + self, lane_id: str, scene_ts: int + ) -> TrafficLightStatus: + return ( + self.traffic_light_status.get( + (int(lane_id), scene_ts), TrafficLightStatus.NO_DATA + ) + if self.traffic_light_status is not None + else TrafficLightStatus.NO_DATA + ) + + def rasterize( + self, resolution: float = 2, **kwargs + ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: + """Renders this vector map at the specified resolution. + + Args: + resolution (float): The rasterized image's resolution in pixels per meter. + + Returns: + np.ndarray: The rasterized RGB image. + """ + return_tf_mat: bool = kwargs.get("return_tf_mat", False) + incl_centerlines: bool = kwargs.get("incl_centerlines", True) + incl_lane_edges: bool = kwargs.get("incl_lane_edges", True) + incl_lane_area: bool = kwargs.get("incl_lane_area", True) + + scene_ts: Optional[int] = kwargs.get("scene_ts", None) + + # (255, 102, 99) also looks nice. + center_color: Tuple[int, int, int] = kwargs.get("center_color", (129, 51, 255)) + # (86, 203, 249) also looks nice. + edge_color: Tuple[int, int, int] = kwargs.get("edge_color", (118, 185, 0)) + # (191, 215, 234) also looks nice. + area_color: Tuple[int, int, int] = kwargs.get("area_color", (214, 232, 181)) + + min_x, min_y, _, max_x, max_y, _ = self.extent + + world_center_m: Tuple[float, float] = ( + (max_x + min_x) / 2, + (max_y + min_y) / 2, + ) + + raster_size_x: int = ceil((max_x - min_x) * resolution) + raster_size_y: int = ceil((max_y - min_y) * resolution) + + raster_from_local: np.ndarray = np.array( + [ + [resolution, 0, raster_size_x / 2], + [0, resolution, raster_size_y / 2], + [0, 0, 1], + ] + ) + + # Compute pose from its position and rotation. + pose_from_world: np.ndarray = np.array( + [ + [1, 0, -world_center_m[0]], + [0, 1, -world_center_m[1]], + [0, 0, 1], + ] + ) + + raster_from_world: np.ndarray = raster_from_local @ pose_from_world + + map_img: np.ndarray = np.zeros( + shape=(raster_size_y, raster_size_x, 3), dtype=np.uint8 + ) + + lane_edges: List[np.ndarray] = list() + centerlines: List[np.ndarray] = list() + lane: RoadLane + for lane in tqdm( + self.elements[MapElementType.ROAD_LANE].values(), + desc=f"Rasterizing Map at {resolution:.2f} px/m", + leave=False, + ): + centerlines.append( + raster_utils.world_to_subpixel( + lane.center.points[:, :2], raster_from_world + ) + ) + if lane.left_edge is not None and lane.right_edge is not None: + left_pts: np.ndarray = lane.left_edge.points[:, :2] + right_pts: np.ndarray = lane.right_edge.points[:, :2] + + lane_edges += [ + raster_utils.world_to_subpixel(left_pts, raster_from_world), + raster_utils.world_to_subpixel(right_pts, raster_from_world), + ] + + lane_color = area_color + status = self.get_traffic_light_status(lane.id, scene_ts) + if status == TrafficLightStatus.GREEN: + lane_color = [0, 200, 0] + elif status == TrafficLightStatus.RED: + lane_color = [200, 0, 0] + elif status == TrafficLightStatus.UNKNOWN: + lane_color = [150, 150, 0] + + # Drawing lane areas. Need to do per loop because doing it all at once can + # create lots of wonky holes in the image. + # See https://stackoverflow.com/questions/69768620/cv2-fillpoly-failing-for-intersecting-polygons + if incl_lane_area: + lane_area: np.ndarray = np.concatenate( + [left_pts, right_pts[::-1]], axis=0 + ) + raster_utils.rasterize_world_polygon( + lane_area, + map_img, + raster_from_world, + color=lane_color, + ) + + # Drawing all lane edge lines at the same time. + if incl_lane_edges: + raster_utils.cv2_draw_polylines(lane_edges, map_img, color=edge_color) + + # Drawing centerlines last (on top of everything else). + if incl_centerlines: + raster_utils.cv2_draw_polylines(centerlines, map_img, color=center_color) + + if return_tf_mat: + return map_img.astype(float) / 255, raster_from_world + else: + return map_img.astype(float) / 255 + + @overload + def visualize_lane_graph( + self, + origin_lane: RoadLane, + num_hops: int, + **kwargs, + ) -> Axes: + ... + + @overload + def visualize_lane_graph(self, origin_lane: str, num_hops: int, **kwargs) -> Axes: + ... + + @overload + def visualize_lane_graph(self, origin_lane: int, num_hops: int, **kwargs) -> Axes: + ... + + def visualize_lane_graph( + self, origin_lane: Union[RoadLane, str, int], num_hops: int, **kwargs + ) -> Axes: + ax = kwargs.get("ax", None) + if ax is None: + fig, ax = plt.subplots() + + origin: str + if isinstance(origin_lane, RoadLane): + origin = origin_lane.id + elif isinstance(origin_lane, str): + origin = origin_lane + elif isinstance(origin_lane, int): + origin = self.lanes[origin_lane].id + + viridis = mpl.colormaps[kwargs.get("cmap", "rainbow")].resampled(num_hops + 1) + + already_seen: Set[str] = set() + lanes_to_plot: List[Tuple[str, int]] = [(origin, 0)] + + if kwargs.get("legend", True): + ax.scatter([], [], label=f"Lane Endpoints", color="k") + ax.plot([], [], label=f"Origin Lane ({origin})", color=viridis(0)) + for h in range(1, num_hops + 1): + ax.plot( + [], + [], + label=f"{h} Lane{'s' if h > 1 else ''} Away", + color=viridis(h), + ) + + raster_from_world = kwargs.get("raster_from_world", None) + while len(lanes_to_plot) > 0: + lane_id, curr_hops = lanes_to_plot.pop(0) + already_seen.add(lane_id) + lane: RoadLane = self.get_road_lane(lane_id) + + center: np.ndarray = lane.center.points[..., :2] + first_pt_heading: float = lane.center.points[0, -1] + mdpt: np.ndarray = lane.center.midpoint[..., :2] + + if raster_from_world is not None: + center = map_utils.transform_points(center, raster_from_world) + mdpt = map_utils.transform_points(mdpt[None, :], raster_from_world)[0] + + ax.plot(center[:, 0], center[:, 1], color=viridis(curr_hops)) + ax.scatter(center[[0, -1], 0], center[[0, -1], 1], color=viridis(curr_hops)) + ax.quiver( + [center[0, 0]], + [center[0, 1]], + [np.cos(first_pt_heading)], + [np.sin(first_pt_heading)], + color=viridis(curr_hops), + ) + ax.text(mdpt[0], mdpt[1], s=lane_id) + + if curr_hops < num_hops: + lanes_to_plot += [ + (l, curr_hops + 1) + for l in lane.reachable_lanes + if l not in already_seen + ] + + if kwargs.get("legend", True): + ax.legend(loc="best", frameon=True) + + return ax diff --git a/src/trajdata/maps/vec_map_elements.py b/src/trajdata/maps/vec_map_elements.py new file mode 100644 index 0000000..d42cc1e --- /dev/null +++ b/src/trajdata/maps/vec_map_elements.py @@ -0,0 +1,171 @@ +from dataclasses import dataclass, field +from enum import IntEnum +from typing import List, Optional, Set + +import numpy as np + +from trajdata.utils import map_utils + + +class MapElementType(IntEnum): + ROAD_LANE = 1 + ROAD_AREA = 2 + PED_CROSSWALK = 3 + PED_WALKWAY = 4 + + +@dataclass +class Polyline: + points: np.ndarray + + def __post_init__(self) -> None: + if self.points.shape[-1] < 2: + raise ValueError( + f"Polylines are expected to have 2 (xy), 3 (xyz), or 4 (xyzh) dimensions, but received {self.points.shape[-1]}." + ) + + if self.points.shape[-1] == 2: + # If only xy are passed in, then append zero to the end for z. + self.points = np.append( + self.points, np.zeros_like(self.points[:, [0]]), axis=-1 + ) + + @property + def midpoint(self) -> np.ndarray: + num_pts: int = self.points.shape[0] + return self.points[num_pts // 2] + + @property + def has_heading(self) -> bool: + return self.points.shape[-1] == 4 + + @property + def xy(self) -> np.ndarray: + return self.points[..., :2] + + @property + def xyz(self) -> np.ndarray: + return self.points[..., :3] + + @property + def xyzh(self) -> np.ndarray: + if self.has_heading: + return self.points[..., :4] + else: + raise ValueError( + f"This Polyline only has {self.points.shape[-1]} coordinates, expected 4." + ) + + @property + def h(self) -> np.ndarray: + return self.points[..., 3] + + def interpolate( + self, num_pts: Optional[int] = None, max_dist: Optional[float] = None + ) -> "Polyline": + return Polyline( + map_utils.interpolate(self.points, num_pts=num_pts, max_dist=max_dist) + ) + + def project_onto(self, xyzh: np.ndarray) -> np.ndarray: + """Project the given points onto this Polyline. + + Args: + xyzh (np.ndarray): Points to project, of shape (M, D) + + Returns: + np.ndarray: The projected points, of shape (M, D) + + Note: + D = 4 if this Polyline has headings, otherwise D = 3 + """ + # xyzh is now (M, 1, 3), we do not use heading for projection. + xyz = xyzh[:, np.newaxis, :3] + + # p0, p1 are (1, N, 3) + p0: np.ndarray = self.points[np.newaxis, :-1, :3] + p1: np.ndarray = self.points[np.newaxis, 1:, :3] + + # 1. Compute projections of each point to each line segment in a + # batched manner. + line_seg_diffs: np.ndarray = p1 - p0 + point_seg_diffs: np.ndarray = xyz - p0 + + dot_products: np.ndarray = (point_seg_diffs * line_seg_diffs).sum( + axis=-1, keepdims=True + ) + norms: np.ndarray = np.linalg.norm(line_seg_diffs, axis=-1, keepdims=True) ** 2 + + # Clip ensures that the projected point stays within the line segment boundaries. + projs: np.ndarray = ( + p0 + np.clip(dot_products / norms, a_min=0, a_max=1) * line_seg_diffs + ) + + # 2. Find the nearest projections to the original points. + closest_proj_idxs: int = np.linalg.norm(xyz - projs, axis=-1).argmin(axis=-1) + + if self.has_heading: + # Adding in the heading of the corresponding p0 point (which makes + # sense as p0 to p1 is a line => same heading along it). + return np.concatenate( + [ + projs[range(xyz.shape[0]), closest_proj_idxs], + np.expand_dims(self.points[closest_proj_idxs, -1], axis=-1), + ], + axis=-1, + ) + else: + return projs[range(xyz.shape[0]), closest_proj_idxs] + + +@dataclass +class MapElement: + id: str + + +@dataclass +class RoadLane(MapElement): + center: Polyline + left_edge: Optional[Polyline] = None + right_edge: Optional[Polyline] = None + adj_lanes_left: Set[str] = field(default_factory=lambda: set()) + adj_lanes_right: Set[str] = field(default_factory=lambda: set()) + next_lanes: Set[str] = field(default_factory=lambda: set()) + prev_lanes: Set[str] = field(default_factory=lambda: set()) + elem_type: MapElementType = MapElementType.ROAD_LANE + + def __post_init__(self) -> None: + if not self.center.has_heading: + self.center = Polyline( + np.append( + self.center.xyz, + map_utils.get_polyline_headings(self.center.xyz), + axis=-1, + ) + ) + + def __hash__(self) -> int: + return hash(self.id) + + @property + def reachable_lanes(self) -> Set[str]: + return self.adj_lanes_left | self.adj_lanes_right | self.next_lanes + + +@dataclass +class RoadArea(MapElement): + exterior_polygon: Polyline + interior_holes: List[Polyline] = field(default_factory=lambda: list()) + elem_type: MapElementType = MapElementType.ROAD_AREA + + +@dataclass +class PedCrosswalk(MapElement): + polygon: Polyline + elem_type: MapElementType = MapElementType.PED_CROSSWALK + + +@dataclass +class PedWalkway(MapElement): + polygon: Polyline + elem_type: MapElementType = MapElementType.PED_WALKWAY diff --git a/src/trajdata/parallel/__init__.py b/src/trajdata/parallel/__init__.py index 8d98449..7f88dbd 100644 --- a/src/trajdata/parallel/__init__.py +++ b/src/trajdata/parallel/__init__.py @@ -1,2 +1 @@ from .data_preprocessor import ParallelDatasetPreprocessor, scene_paths_collate_fn -from .parallel_utils import parallel_apply, parallel_iapply diff --git a/src/trajdata/parallel/data_preprocessor.py b/src/trajdata/parallel/data_preprocessor.py index 2a0909f..de209a8 100644 --- a/src/trajdata/parallel/data_preprocessor.py +++ b/src/trajdata/parallel/data_preprocessor.py @@ -6,8 +6,7 @@ from trajdata.caching import EnvCache, SceneCache from trajdata.data_structures import Scene, SceneMetadata -from trajdata.utils import agent_utils -from trajdata.utils.env_utils import get_raw_dataset +from trajdata.utils import agent_utils, env_utils def scene_paths_collate_fn(filled_scenes: List) -> List: @@ -61,7 +60,7 @@ def __getitem__(self, idx: int) -> str: scene_idx: int = self.scene_name_idxs[idx] env_name: str = str(self.env_names_arr[env_idx], encoding="utf-8") - raw_dataset = get_raw_dataset( + raw_dataset = env_utils.get_raw_dataset( env_name, str(self.data_dir_arr[env_idx], encoding="utf-8") ) @@ -84,6 +83,13 @@ def __getitem__(self, idx: int) -> str: ) raw_dataset.del_dataset_obj() + if scene is None: + # This provides an escape hatch in case there's a reason we + # don't want to add a scene to the list of scenes. As an example, + # nuPlan has a scene with only a single frame of data which we + # can't do much with in terms of prediction/planning/etc. + return None + scene_path: Path = EnvCache.scene_metadata_path( env_cache.path, scene.env_name, scene.name, scene.dt ) diff --git a/src/trajdata/parallel/temp_cache.py b/src/trajdata/parallel/temp_cache.py deleted file mode 100644 index 9495e7d..0000000 --- a/src/trajdata/parallel/temp_cache.py +++ /dev/null @@ -1,45 +0,0 @@ -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import List, Optional, Union - -import dill - -from trajdata.data_structures.scene_metadata import Scene - - -class TemporaryCache: - def __init__(self, temp_dir: Optional[str] = None) -> None: - self.temp_dir: Optional[TemporaryDirectory] = None - if temp_dir is None: - self.temp_dir: TemporaryDirectory = TemporaryDirectory() - self.path: Path = Path(self.temp_dir.name) - else: - self.path: Path = Path(temp_dir) - - def cache(self, scene: Scene, ret_str: bool = False) -> Union[Path, str]: - tmp_file_path: Path = self.path / TemporaryCache.get_file_path(scene) - with open(tmp_file_path, "wb") as f: - dill.dump(scene, f) - - if ret_str: - return str(tmp_file_path) - else: - return tmp_file_path - - def cache_scenes(self, scenes: List[Scene]) -> List[str]: - paths: List[str] = list() - for scene in scenes: - tmp_file_path: Path = self.path / TemporaryCache.get_file_path(scene) - with open(tmp_file_path, "wb") as f: - dill.dump(scene, f) - - paths.append(str(tmp_file_path)) - - return paths - - def cleanup(self) -> None: - self.temp_dir.cleanup() - - @staticmethod - def get_file_path(scene_info: Scene) -> Path: - return f"{scene_info.env_name}_{scene_info.name}.dill" diff --git a/src/trajdata/proto/vectorized_map.proto b/src/trajdata/proto/vectorized_map.proto index 2a9ca7c..e1cd502 100644 --- a/src/trajdata/proto/vectorized_map.proto +++ b/src/trajdata/proto/vectorized_map.proto @@ -3,13 +3,20 @@ syntax = "proto3"; package trajdata; message VectorizedMap { + // The name of this map in the format environment_name:map_name + string name = 1; + // The full set of map elements. - repeated MapElement elements = 1; + repeated MapElement elements = 2; // The coordinates of the cuboid (in m) // containing all elements in this map. - optional Point max_pt = 2; - optional Point min_pt = 3; + Point max_pt = 3; + Point min_pt = 4; + + // The original world coordinates (in m) of the bottom-left of the map + // (account for a change in the origin for storage efficiency). + Point shifted_origin = 5; } message MapElement { @@ -26,9 +33,9 @@ message MapElement { } message Point { - optional double x = 1; - optional double y = 2; - optional double z = 3; + double x = 1; + double y = 2; + double z = 3; } message Polyline { @@ -40,6 +47,7 @@ message Polyline { repeated sint32 dx_mm = 1; repeated sint32 dy_mm = 2; repeated sint32 dz_mm = 3; + repeated double h_rad = 4; } message RoadLane { @@ -47,11 +55,11 @@ message RoadLane { // segments defined between consecutive points. Polyline center = 1; - // The polyline data for the left boundary of this lane. - Polyline left_boundary = 2; + // The polyline data for the (optional) left boundary of this lane. + optional Polyline left_boundary = 2; - // The polyline data for the right boundary of this lane. - Polyline right_boundary = 3; + // The polyline data for the (optional) right boundary of this lane. + optional Polyline right_boundary = 3; // A list of IDs for lanes that this lane may be entered from. repeated bytes entry_lanes = 4; diff --git a/src/trajdata/proto/vectorized_map_pb2.py b/src/trajdata/proto/vectorized_map_pb2.py index 96bb4f1..7352109 100644 --- a/src/trajdata/proto/vectorized_map_pb2.py +++ b/src/trajdata/proto/vectorized_map_pb2.py @@ -14,7 +14,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x14vectorized_map.proto\x12\x08trajdata"\x99\x01\n\rVectorizedMap\x12&\n\x08\x65lements\x18\x01 \x03(\x0b\x32\x14.trajdata.MapElement\x12$\n\x06max_pt\x18\x02 \x01(\x0b\x32\x0f.trajdata.PointH\x00\x88\x01\x01\x12$\n\x06min_pt\x18\x03 \x01(\x0b\x32\x0f.trajdata.PointH\x01\x88\x01\x01\x42\t\n\x07_max_ptB\t\n\x07_min_pt"\xd8\x01\n\nMapElement\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\'\n\troad_lane\x18\x02 \x01(\x0b\x32\x12.trajdata.RoadLaneH\x00\x12\'\n\troad_area\x18\x03 \x01(\x0b\x32\x12.trajdata.RoadAreaH\x00\x12/\n\rped_crosswalk\x18\x04 \x01(\x0b\x32\x16.trajdata.PedCrosswalkH\x00\x12+\n\x0bped_walkway\x18\x05 \x01(\x0b\x32\x14.trajdata.PedWalkwayH\x00\x42\x0e\n\x0c\x65lement_data"I\n\x05Point\x12\x0e\n\x01x\x18\x01 \x01(\x01H\x00\x88\x01\x01\x12\x0e\n\x01y\x18\x02 \x01(\x01H\x01\x88\x01\x01\x12\x0e\n\x01z\x18\x03 \x01(\x01H\x02\x88\x01\x01\x42\x04\n\x02_xB\x04\n\x02_yB\x04\n\x02_z"7\n\x08Polyline\x12\r\n\x05\x64x_mm\x18\x01 \x03(\x11\x12\r\n\x05\x64y_mm\x18\x02 \x03(\x11\x12\r\n\x05\x64z_mm\x18\x03 \x03(\x11"\xe9\x01\n\x08RoadLane\x12"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\x12)\n\rleft_boundary\x18\x02 \x01(\x0b\x32\x12.trajdata.Polyline\x12*\n\x0eright_boundary\x18\x03 \x01(\x0b\x32\x12.trajdata.Polyline\x12\x13\n\x0b\x65ntry_lanes\x18\x04 \x03(\x0c\x12\x12\n\nexit_lanes\x18\x05 \x03(\x0c\x12\x1b\n\x13\x61\x64jacent_lanes_left\x18\x06 \x03(\x0c\x12\x1c\n\x14\x61\x64jacent_lanes_right\x18\x07 \x03(\x0c"d\n\x08RoadArea\x12,\n\x10\x65xterior_polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\x12*\n\x0einterior_holes\x18\x02 \x03(\x0b\x32\x12.trajdata.Polyline"3\n\x0cPedCrosswalk\x12#\n\x07polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline"1\n\nPedWalkway\x12#\n\x07polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polylineb\x06proto3' + b'\n\x14vectorized_map.proto\x12\x08trajdata"\xb0\x01\n\rVectorizedMap\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x08\x65lements\x18\x02 \x03(\x0b\x32\x14.trajdata.MapElement\x12\x1f\n\x06max_pt\x18\x03 \x01(\x0b\x32\x0f.trajdata.Point\x12\x1f\n\x06min_pt\x18\x04 \x01(\x0b\x32\x0f.trajdata.Point\x12\'\n\x0eshifted_origin\x18\x05 \x01(\x0b\x32\x0f.trajdata.Point"\xd8\x01\n\nMapElement\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\'\n\troad_lane\x18\x02 \x01(\x0b\x32\x12.trajdata.RoadLaneH\x00\x12\'\n\troad_area\x18\x03 \x01(\x0b\x32\x12.trajdata.RoadAreaH\x00\x12/\n\rped_crosswalk\x18\x04 \x01(\x0b\x32\x16.trajdata.PedCrosswalkH\x00\x12+\n\x0bped_walkway\x18\x05 \x01(\x0b\x32\x14.trajdata.PedWalkwayH\x00\x42\x0e\n\x0c\x65lement_data"(\n\x05Point\x12\t\n\x01x\x18\x01 \x01(\x01\x12\t\n\x01y\x18\x02 \x01(\x01\x12\t\n\x01z\x18\x03 \x01(\x01"F\n\x08Polyline\x12\r\n\x05\x64x_mm\x18\x01 \x03(\x11\x12\r\n\x05\x64y_mm\x18\x02 \x03(\x11\x12\r\n\x05\x64z_mm\x18\x03 \x03(\x11\x12\r\n\x05h_rad\x18\x04 \x03(\x01"\x98\x02\n\x08RoadLane\x12"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\x12.\n\rleft_boundary\x18\x02 \x01(\x0b\x32\x12.trajdata.PolylineH\x00\x88\x01\x01\x12/\n\x0eright_boundary\x18\x03 \x01(\x0b\x32\x12.trajdata.PolylineH\x01\x88\x01\x01\x12\x13\n\x0b\x65ntry_lanes\x18\x04 \x03(\x0c\x12\x12\n\nexit_lanes\x18\x05 \x03(\x0c\x12\x1b\n\x13\x61\x64jacent_lanes_left\x18\x06 \x03(\x0c\x12\x1c\n\x14\x61\x64jacent_lanes_right\x18\x07 \x03(\x0c\x42\x10\n\x0e_left_boundaryB\x11\n\x0f_right_boundary"d\n\x08RoadArea\x12,\n\x10\x65xterior_polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\x12*\n\x0einterior_holes\x18\x02 \x03(\x0b\x32\x12.trajdata.Polyline"3\n\x0cPedCrosswalk\x12#\n\x07polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline"1\n\nPedWalkway\x12#\n\x07polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polylineb\x06proto3' ) @@ -118,19 +118,19 @@ DESCRIPTOR._options = None _VECTORIZEDMAP._serialized_start = 35 - _VECTORIZEDMAP._serialized_end = 188 - _MAPELEMENT._serialized_start = 191 - _MAPELEMENT._serialized_end = 407 - _POINT._serialized_start = 409 - _POINT._serialized_end = 482 - _POLYLINE._serialized_start = 484 - _POLYLINE._serialized_end = 539 - _ROADLANE._serialized_start = 542 - _ROADLANE._serialized_end = 775 - _ROADAREA._serialized_start = 777 - _ROADAREA._serialized_end = 877 - _PEDCROSSWALK._serialized_start = 879 - _PEDCROSSWALK._serialized_end = 930 - _PEDWALKWAY._serialized_start = 932 - _PEDWALKWAY._serialized_end = 981 + _VECTORIZEDMAP._serialized_end = 211 + _MAPELEMENT._serialized_start = 214 + _MAPELEMENT._serialized_end = 430 + _POINT._serialized_start = 432 + _POINT._serialized_end = 472 + _POLYLINE._serialized_start = 474 + _POLYLINE._serialized_end = 544 + _ROADLANE._serialized_start = 547 + _ROADLANE._serialized_end = 827 + _ROADAREA._serialized_start = 829 + _ROADAREA._serialized_end = 929 + _PEDCROSSWALK._serialized_start = 931 + _PEDCROSSWALK._serialized_end = 982 + _PEDWALKWAY._serialized_start = 984 + _PEDWALKWAY._serialized_end = 1033 # @@protoc_insertion_point(module_scope) diff --git a/src/trajdata/simulation/sim_df_cache.py b/src/trajdata/simulation/sim_df_cache.py index d6e571e..c757b26 100644 --- a/src/trajdata/simulation/sim_df_cache.py +++ b/src/trajdata/simulation/sim_df_cache.py @@ -64,9 +64,7 @@ def get_agents_future( agents: List[AgentMetadata], future_sec: Tuple[Optional[float], Optional[float]], ) -> Tuple[np.ndarray, np.ndarray]: - last_timesteps = np.array( - [agent.last_timestep for agent in agents], dtype=np.long - ) + last_timesteps = np.array([agent.last_timestep for agent in agents], dtype=int) if np.all(np.greater(scene_ts, last_timesteps)): return ( diff --git a/src/trajdata/simulation/sim_scene.py b/src/trajdata/simulation/sim_scene.py index 56da477..8cca6ac 100644 --- a/src/trajdata/simulation/sim_scene.py +++ b/src/trajdata/simulation/sim_scene.py @@ -133,8 +133,8 @@ def get_obs( future_sec=self.dataset.future_sec, agent_interaction_distances=self.dataset.agent_interaction_distances, incl_robot_future=False, - incl_map=get_map and self.dataset.incl_map, - map_params=self.dataset.map_params, + incl_raster_map=get_map and self.dataset.incl_raster_map, + raster_map_params=self.dataset.raster_map_params, standardize_data=self.dataset.standardize_data, standardize_derivatives=self.dataset.standardize_derivatives, max_neighbor_num=self.dataset.max_neighbor_num, diff --git a/src/trajdata/simulation/sim_stats.py b/src/trajdata/simulation/sim_stats.py index 74f7fb5..499a900 100644 --- a/src/trajdata/simulation/sim_stats.py +++ b/src/trajdata/simulation/sim_stats.py @@ -83,12 +83,22 @@ def calc_stats( """ velocity: Tensor = ( - torch.diff(positions, dim=1, prepend=positions[:, [1]] - positions[:, [0]]) / dt + torch.diff( + positions, + dim=1, + prepend=positions[:, [0]] - (positions[:, [1]] - positions[:, [0]]), + ) + / dt ) velocity_norm: Tensor = torch.linalg.vector_norm(velocity, dim=-1) accel: Tensor = ( - torch.diff(positions, dim=1, prepend=velocity[:, [1]] - velocity[:, [0]]) / dt + torch.diff( + velocity, + dim=1, + prepend=velocity[:, [0]] - (velocity[:, [1]] - velocity[:, [0]]), + ) + / dt ) accel_norm: Tensor = torch.linalg.vector_norm(accel, dim=-1) @@ -96,7 +106,11 @@ def calc_stats( lat_acc: Tensor = accel_norm * torch.sin(heading.squeeze(-1)) jerk: Tensor = ( - torch.diff(accel_norm, dim=1, prepend=accel_norm[:, [1]] - accel_norm[:, [0]]) + torch.diff( + accel_norm, + dim=1, + prepend=accel_norm[:, [0]] - (accel_norm[:, [1]] - accel_norm[:, [0]]), + ) / dt ) diff --git a/src/trajdata/utils/agent_utils.py b/src/trajdata/utils/agent_utils.py index c457d90..1052a96 100644 --- a/src/trajdata/utils/agent_utils.py +++ b/src/trajdata/utils/agent_utils.py @@ -1,4 +1,4 @@ -from typing import Optional, Type, Union +from typing import Optional, Type from trajdata.caching import EnvCache, SceneCache from trajdata.data_structures import Scene, SceneMetadata @@ -47,6 +47,8 @@ def get_agent_data( agent_list, agent_presence = raw_dataset.get_agent_info( scene, env_cache.path, cache_class ) + if agent_list is None and agent_presence is None: + raise ValueError(f"Scene {scene_info.name} contains no agents!") scene.update_agent_info(agent_list, agent_presence) env_cache.save_scene(scene) diff --git a/src/trajdata/utils/arr_utils.py b/src/trajdata/utils/arr_utils.py index ef381e3..0499651 100644 --- a/src/trajdata/utils/arr_utils.py +++ b/src/trajdata/utils/arr_utils.py @@ -147,7 +147,7 @@ def transform_matrices(angles: Tensor, translations: Tensor) -> Tensor: ) -def batch_nd_transform_points_np(points, Mat): +def batch_nd_transform_points_np(points: np.ndarray, Mat: np.ndarray) -> np.ndarray: ndim = Mat.shape[-1] - 1 batch = list(range(Mat.ndim - 2)) + [Mat.ndim - 1] + [Mat.ndim - 2] Mat = np.transpose(Mat, batch) @@ -164,6 +164,24 @@ def batch_nd_transform_points_np(points, Mat): raise Exception("wrong shape") +def batch_nd_transform_angles_np(angles: np.ndarray, Mat: np.ndarray) -> np.ndarray: + cos_vals, sin_vals = Mat[..., 0, 0], Mat[..., 1, 0] + rot_angle = np.arctan2(sin_vals, cos_vals) + angles = angles + rot_angle + angles = angle_wrap(angles) + return angles + + +def batch_nd_transform_points_angles_np( + points_angles: np.ndarray, Mat: np.ndarray +) -> np.ndarray: + assert points_angles.shape[-1] == 3 + points = batch_nd_transform_points_np(points_angles[..., :2], Mat) + angles = batch_nd_transform_angles_np(points_angles[..., 2:3], Mat) + points_angles = np.concatenate([points, angles], axis=-1) + return points_angles + + def agent_aware_diff(values: np.ndarray, agent_ids: np.ndarray) -> np.ndarray: values_diff: np.ndarray = np.diff( values, axis=0, prepend=values[[0]] - (values[[1]] - values[[0]]) @@ -214,6 +232,7 @@ def batch_proj(x, line): delta_y, torch.unsqueeze(delta_psi, dim=-1), ) + elif isinstance(x, np.ndarray): delta = line[..., 0:2] - np.repeat( x[..., np.newaxis, 0:2], line_length, axis=-2 @@ -236,3 +255,11 @@ def batch_proj(x, line): delta_y, np.expand_dims(delta_psi, axis=-1), ) + + +def quaternion_to_yaw(q: np.ndarray): + # From https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L1025 + return np.arctan2( + 2 * (q[..., 0] * q[..., 3] - q[..., 1] * q[..., 2]), + 1 - 2 * (q[..., 2] ** 2 + q[..., 3] ** 2), + ) diff --git a/src/trajdata/utils/batch_utils.py b/src/trajdata/utils/batch_utils.py new file mode 100644 index 0000000..da0c663 --- /dev/null +++ b/src/trajdata/utils/batch_utils.py @@ -0,0 +1,86 @@ +from collections import defaultdict +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np + +from trajdata.data_structures import ( + AgentBatch, + AgentBatchElement, + AgentType, + SceneBatchElement, + SceneTimeAgent, +) +from trajdata.data_structures.collation import agent_collate_fn + + +def convert_to_agent_batch( + scene_batch_element: SceneBatchElement, + only_types: Optional[List[AgentType]] = None, + no_types: Optional[List[AgentType]] = None, + agent_interaction_distances: Dict[Tuple[AgentType, AgentType], float] = defaultdict( + lambda: np.inf + ), + incl_map: bool = False, + map_params: Optional[Dict[str, Any]] = None, + max_neighbor_num: Optional[int] = None, + standardize_data: bool = True, + standardize_derivatives: bool = False, + pad_format: str = "outside", +) -> AgentBatch: + """ + Converts a SceneBatchElement into a AgentBatch consisting of + AgentBatchElements for all agents present at the given scene at the given + time step. + + Args: + scene_batch_element (SceneBatchElement): element to process + only_types (Optional[List[AgentType]], optional): AgentsTypes to consider. Defaults to None. + no_types (Optional[List[AgentType]], optional): AgentTypes to ignore. Defaults to None. + agent_interaction_distances (_type_, optional): Distance threshold for interaction. Defaults to defaultdict(lambda: np.inf). + incl_map (bool, optional): Whether to include map info. Defaults to False. + map_params (Optional[Dict[str, Any]], optional): Map params. Defaults to None. + max_neighbor_num (Optional[int], optional): Max number of neighbors to allow. Defaults to None. + standardize_data (bool): Whether to return data relative to current agent state. Defaults to True. + standardize_derivatives: Whether to transform relative velocities and accelerations as well. Defaults to False. + pad_format (str, optional): Pad format when collating agent trajectories. Defaults to "outside". + + Returns: + AgentBatch: batch of AgentBatchElements corresponding to all agents in the SceneBatchElement + """ + data_idx = scene_batch_element.data_index + cache = scene_batch_element.cache + scene = cache.scene + dt = scene_batch_element.dt + ts = scene_batch_element.scene_ts + + batch_elems: List[AgentBatchElement] = [] + for j, agent_name in enumerate(scene_batch_element.agent_names): + history_sec = dt * (scene_batch_element.agent_histories[j].shape[0] - 1) + future_sec = dt * (scene_batch_element.agent_futures[j].shape[0]) + cache.reset_transforms() + scene_time_agent: SceneTimeAgent = SceneTimeAgent.from_cache( + scene, + ts, + agent_name, + cache, + only_types=only_types, + no_types=no_types, + ) + + batch_elems.append( + AgentBatchElement( + cache=cache, + data_index=data_idx, + scene_time_agent=scene_time_agent, + history_sec=(history_sec, history_sec), + future_sec=(future_sec, future_sec), + agent_interaction_distances=agent_interaction_distances, + incl_raster_map=incl_map, + raster_map_params=map_params, + standardize_data=standardize_data, + standardize_derivatives=standardize_derivatives, + max_neighbor_num=max_neighbor_num, + ) + ) + + return agent_collate_fn(batch_elems, return_dict=False, pad_format=pad_format) diff --git a/src/trajdata/utils/df_utils.py b/src/trajdata/utils/df_utils.py new file mode 100644 index 0000000..697369a --- /dev/null +++ b/src/trajdata/utils/df_utils.py @@ -0,0 +1,98 @@ +from typing import Callable, Optional + +import numpy as np +import pandas as pd + + +def downsample_multi_index_df( + df: pd.DataFrame, downsample_dt_factor: int +) -> pd.DataFrame: + """ + Downsamples MultiIndex dataframe, assuming level=1 of the index + corresponds to the scene timestep. + """ + subsampled_df = df.groupby(level=0).apply( + lambda g: g.reset_index(level=0, drop=True) + .iloc[::downsample_dt_factor] + .rename(index=lambda ts: ts // downsample_dt_factor) + ) + + return subsampled_df + + +def upsample_ts_index_df( + df: pd.DataFrame, + upsample_dt_factor: int, + method: str, + preprocess: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None, + postprocess: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None, +): + """ + Upsamples a time indexed dataframe, applying specified method. + Calls preprocess and postprocess before and after upsampling repsectively. + + If original data is at frames 2,3,4,5, and upsample_dt_factor is 3, then + the original data will live at frames 6,9,12,15, and new data will + be generated according to method for frames 7,8, 10,11, 13,14 (frames after the last frame are not generated) + """ + if preprocess: + df = preprocess(df) + + # first, we multiply ts index by upsample factor + df = df.rename(index=lambda ts: ts * upsample_dt_factor) + + # get the index by adding the number of frames needed per original index + new_index = pd.Index( + (df.index.to_numpy()[:, None] + np.arange(upsample_dt_factor)).flatten()[ + : -(upsample_dt_factor - 1) + ], + name=df.index.name, + ) + + # reindex and interpolate according to method + df = df.reindex(new_index).interpolate(method=method, limit_area="inside") + + if postprocess: + df = postprocess(df) + + return df + + +def upsample_multi_index_df( + df: pd.DataFrame, + upsample_dt_factor: int, + method: str, + preprocess: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None, + postprocess: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None, +) -> pd.DataFrame: + return df.groupby(level=[0]).apply( + lambda g: upsample_ts_index_df( + g.reset_index(level=[0], drop=True), + upsample_dt_factor, + method, + preprocess, + postprocess, + ) + ) + + +def interpolate_multi_index_df( + df: pd.DataFrame, data_dt: float, desired_dt: float, method: str = "linear" +) -> pd.DataFrame: + """ + Interpolates the given dataframe indexed with (elem_id, scene_ts) + where scene_ts corresponds to timesteps with increment data_dt to a new + desired_dt. + """ + upsample_dt_ratio: float = data_dt / desired_dt + downsample_dt_ratio: float = desired_dt / data_dt + if not upsample_dt_ratio.is_integer() and not downsample_dt_ratio.is_integer(): + raise ValueError( + f"Data's dt of {data_dt}s " + f"is not integer divisible by the desired dt {desired_dt}s." + ) + + if upsample_dt_ratio >= 1: + return upsample_multi_index_df(df, int(upsample_dt_ratio), method) + elif downsample_dt_ratio >= 1: + return downsample_multi_index_df(df, int(downsample_dt_ratio)) diff --git a/src/trajdata/utils/env_utils.py b/src/trajdata/utils/env_utils.py index 61980a0..2fdcd60 100644 --- a/src/trajdata/utils/env_utils.py +++ b/src/trajdata/utils/env_utils.py @@ -1,7 +1,7 @@ from typing import Dict, List -from trajdata.dataset_specific import RawDataset from trajdata.dataset_specific.eth_ucy_peds import EUPedsDataset +from trajdata.dataset_specific.raw_dataset import RawDataset try: from trajdata.dataset_specific.lyft import LyftDataset @@ -18,6 +18,14 @@ pass +try: + from trajdata.dataset_specific.nuplan import NuplanDataset +except ModuleNotFoundError: + # This can happen if the user did not install trajdata + # with the "trajdata[nuplan]" option. + pass + + def get_raw_dataset(dataset_name: str, data_dir: str) -> RawDataset: if "nusc" in dataset_name: return NuscDataset(dataset_name, data_dir, parallelizable=False, has_maps=True) @@ -30,6 +38,9 @@ def get_raw_dataset(dataset_name: str, data_dir: str) -> RawDataset: dataset_name, data_dir, parallelizable=True, has_maps=False ) + if "nuplan" in dataset_name: + return NuplanDataset(dataset_name, data_dir, parallelizable=True, has_maps=True) + raise ValueError(f"Dataset with name '{dataset_name}' is not supported") diff --git a/src/trajdata/utils/map_utils.py b/src/trajdata/utils/map_utils.py new file mode 100644 index 0000000..b4fe8e1 --- /dev/null +++ b/src/trajdata/utils/map_utils.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from tqdm import tqdm + +if TYPE_CHECKING: + from trajdata.maps import map_kdtree, vec_map + +from pathlib import Path +from typing import Dict, Final, Optional + +import dill +import numpy as np +from scipy.stats import circmean + +import trajdata.maps.vec_map_elements as vec_map_elems +import trajdata.proto.vectorized_map_pb2 as map_proto +from trajdata.utils import arr_utils + +MM_PER_M: Final[float] = 1000 + + +def decompress_values(data: np.ndarray) -> np.ndarray: + # From https://github.com/woven-planet/l5kit/blob/master/l5kit/l5kit/data/proto/road_network.proto#L446 + # The delta for the first point is just its coordinates tuple, i.e. it is a "delta" from + # the origin. For subsequent points, this field stores the difference between the point's + # coordinates and the previous point's coordinates. This is for representation efficiency. + return np.cumsum(data, axis=0, dtype=float) / MM_PER_M + + +def compress_values(data: np.ndarray) -> np.ndarray: + return (np.diff(data, axis=0, prepend=0.0) * MM_PER_M).astype(np.int32) + + +def get_polyline_headings(points: np.ndarray) -> np.ndarray: + """Get approximate heading angles for points in a polyline. + + Args: + points: XY points, np.ndarray of shape [N, 2] + + Returns: + np.ndarray: approximate heading angles in radians, shape [N, 1] + """ + if points.ndim < 2 and points.shape[-1] != 2 and points.shape[-2] <= 1: + raise ValueError("Unexpected shape") + + vectors = points[..., 1:, :] - points[..., :-1, :] + vec_headings = np.arctan2(vectors[..., 1], vectors[..., 0]) # -pi..pi + + # For internal points compute the mean heading of consecutive segments. + # Need to use circular mean to average directions. + # TODO(pkarkus) this would be more accurate if weighted with the distance to the neighbor + if vec_headings.shape[-1] <= 1: + # Handle special case because circmean unfortunately returns nan for such input. + mean_consec_headings = np.zeros( + list(vec_headings.shape[:-1]) + [0], dtype=vec_headings.dtype + ) + else: + mean_consec_headings = circmean( + np.stack([vec_headings[..., :-1], vec_headings[..., 1:]], axis=-1), + high=np.pi, + low=-np.pi, + axis=-1, + ) + + headings = np.concatenate( + [ + vec_headings[..., :1], # heading of first segment + mean_consec_headings, # mean heading of consecutive segments + vec_headings[..., -1:], # heading of last segment + ], + axis=-1, + ) + return headings[..., np.newaxis] + + +def populate_lane_polylines( + new_lane_proto: map_proto.RoadLane, + road_lane_py: vec_map.RoadLane, + origin: np.ndarray, +) -> None: + """Fill a Lane object's polyline attributes. + All points should be in world coordinates. + + Args: + new_lane (Lane): _description_ + midlane_pts (np.ndarray): _description_ + left_pts (np.ndarray): _description_ + right_pts (np.ndarray): _description_ + """ + compressed_mid_pts: np.ndarray = compress_values(road_lane_py.center.xyz - origin) + new_lane_proto.center.dx_mm.extend(compressed_mid_pts[:, 0].tolist()) + new_lane_proto.center.dy_mm.extend(compressed_mid_pts[:, 1].tolist()) + new_lane_proto.center.dz_mm.extend(compressed_mid_pts[:, 2].tolist()) + new_lane_proto.center.h_rad.extend(road_lane_py.center.h.tolist()) + + if road_lane_py.left_edge is not None: + compressed_left_pts: np.ndarray = compress_values( + road_lane_py.left_edge.xyz - origin + ) + new_lane_proto.left_boundary.dx_mm.extend(compressed_left_pts[:, 0].tolist()) + new_lane_proto.left_boundary.dy_mm.extend(compressed_left_pts[:, 1].tolist()) + new_lane_proto.left_boundary.dz_mm.extend(compressed_left_pts[:, 2].tolist()) + + if road_lane_py.right_edge is not None: + compressed_right_pts: np.ndarray = compress_values( + road_lane_py.right_edge.xyz - origin + ) + new_lane_proto.right_boundary.dx_mm.extend(compressed_right_pts[:, 0].tolist()) + new_lane_proto.right_boundary.dy_mm.extend(compressed_right_pts[:, 1].tolist()) + new_lane_proto.right_boundary.dz_mm.extend(compressed_right_pts[:, 2].tolist()) + + +def populate_polygon( + polygon_proto: map_proto.Polyline, + polygon_pts: np.ndarray, + origin: np.ndarray, +) -> None: + """Fill an object's polygon. + All points should be in world coordinates. + + Args: + polygon_proto (Polyline): _description_ + polygon_pts (np.ndarray): _description_ + """ + compressed_pts: np.ndarray = compress_values(polygon_pts - origin) + + polygon_proto.dx_mm.extend(compressed_pts[:, 0].tolist()) + polygon_proto.dy_mm.extend(compressed_pts[:, 1].tolist()) + polygon_proto.dz_mm.extend(compressed_pts[:, 2].tolist()) + + +def proto_to_np(polyline: map_proto.Polyline, incl_heading: bool = True) -> np.ndarray: + dx: np.ndarray = np.asarray(polyline.dx_mm) + dy: np.ndarray = np.asarray(polyline.dy_mm) + + if len(polyline.dz_mm) > 0: + dz: np.ndarray = np.asarray(polyline.dz_mm) + pts: np.ndarray = np.stack([dx, dy, dz], axis=1) + else: + # Default z is all zeros. + pts: np.ndarray = np.stack([dx, dy, np.zeros_like(dx)], axis=1) + + ret_pts: np.ndarray = decompress_values(pts) + + if incl_heading and len(polyline.h_rad) > 0: + headings: np.ndarray = np.asarray(polyline.h_rad) + ret_pts = np.concatenate((ret_pts, headings[:, np.newaxis]), axis=1) + elif incl_heading: + raise ValueError( + f"Polyline must have heading, but it does not (polyline.h_rad is empty)." + ) + + return ret_pts + + +def transform_points(points: np.ndarray, transf_mat: np.ndarray): + n_dim = points.shape[-1] + return points @ transf_mat[:n_dim, :n_dim] + transf_mat[:n_dim, -1] + + +def order_matches(pts: np.ndarray, ref: np.ndarray) -> bool: + """Evaluate whether `pts0` is ordered the same as `ref`, based on the distance from + `pts0`'s start and end points to `ref`'s start point. + + Args: + pts0 (np.ndarray): The first array of points, of shape (N, D). + pts1 (np.ndarray): The second array of points, of shape (M, D). + + Returns: + bool: True if `pts0`'s first point is closest to `ref`'s first point, + False if `pts0`'s endpoint is closer (e.g., they are flipped relative to each other). + """ + return np.linalg.norm(pts[0] - ref[0]) <= np.linalg.norm(pts[-1] - ref[0]) + + +def endpoints_intersect(left_edge: np.ndarray, right_edge: np.ndarray) -> bool: + def ccw(A, B, C): + return (C[1] - A[1]) * (B[0] - A[0]) > (B[1] - A[1]) * (C[0] - A[0]) + + A, B = left_edge[-1], right_edge[-1] + C, D = right_edge[0], left_edge[0] + return ccw(A, C, D) != ccw(B, C, D) and ccw(A, B, C) != ccw(A, B, D) + + +def interpolate( + pts: np.ndarray, num_pts: Optional[int] = None, max_dist: Optional[float] = None +) -> np.ndarray: + """ + Interpolate points either based on cumulative distances from the first one (`num_pts`) + or by adding extra points until neighboring points are within `max_dist` of each other. + + In particular, `num_pts` will interpolate using a variable step such that we always get + the requested number of points. + + Args: + pts (np.ndarray): XYZ(H) coords. + num_pts (int, optional): Desired number of total points. + max_dist (float, optional): Maximum distance between points of the polyline. + + Note: + Only one of `num_pts` or `max_dist` can be specified. + + Returns: + np.ndarray: The new interpolated coordinates. + """ + if num_pts is not None and max_dist is not None: + raise ValueError("Only one of num_pts or max_dist can be used!") + + if pts.ndim != 2: + raise ValueError("pts is expected to be 2 dimensional") + + # 3 because XYZ (heading does not count as a positional distance). + pos_dim: int = min(pts.shape[-1], 3) + has_heading: bool = pts.shape[-1] == 4 + + if num_pts is not None: + assert num_pts > 1, f"num_pts must be at least 2, but got {num_pts}" + + if pts.shape[0] == num_pts: + return pts + + cum_dist: np.ndarray = np.cumsum( + np.linalg.norm(np.diff(pts[..., :pos_dim], axis=0), axis=-1) + ) + cum_dist = np.insert(cum_dist, 0, 0) + + steps: np.ndarray = np.linspace(cum_dist[0], cum_dist[-1], num_pts) + xyz_inter: np.ndarray = np.empty((num_pts, pts.shape[-1]), dtype=pts.dtype) + for i in range(pos_dim): + xyz_inter[:, i] = np.interp(steps, xp=cum_dist, fp=pts[:, i]) + + if has_heading: + # Heading, so make sure to unwrap, interpolate, and wrap it. + xyz_inter[:, 3] = arr_utils.angle_wrap( + np.interp(steps, xp=cum_dist, fp=np.unwrap(pts[:, 3])) + ) + + return xyz_inter + + elif max_dist is not None: + unwrapped_pts: np.ndarray = pts + if has_heading: + unwrapped_pts[..., 3] = np.unwrap(unwrapped_pts[..., 3]) + + segments = unwrapped_pts[..., 1:, :] - unwrapped_pts[..., :-1, :] + seg_lens = np.linalg.norm(segments[..., :pos_dim], axis=-1) + new_pts = [unwrapped_pts[..., 0:1, :]] + for i in range(segments.shape[-2]): + num_extra_points = seg_lens[..., i] // max_dist + if num_extra_points > 0: + step_vec = segments[..., i, :] / (num_extra_points + 1) + new_pts.append( + unwrapped_pts[..., i, np.newaxis, :] + + step_vec[..., np.newaxis, :] + * np.arange(1, num_extra_points + 1)[:, np.newaxis] + ) + + new_pts.append(unwrapped_pts[..., i + 1 : i + 2, :]) + + new_pts = np.concatenate(new_pts, axis=-2) + if has_heading: + new_pts[..., 3] = arr_utils.angle_wrap(new_pts[..., 3]) + + return new_pts + + +def load_vector_map(vector_map_path: Path) -> map_proto.VectorizedMap: + if not vector_map_path.exists(): + raise ValueError(f"{vector_map_path} does not exist!") + + vec_map = map_proto.VectorizedMap() + + # Saving the vectorized map data. + with open(vector_map_path, "rb") as f: + vec_map.ParseFromString(f.read()) + + return vec_map + + +def load_kdtrees(kdtrees_path: Path) -> Dict[str, map_kdtree.MapElementKDTree]: + if not kdtrees_path.exists(): + raise ValueError(f"{kdtrees_path} does not exist!") + + with open(kdtrees_path, "rb") as f: + kdtrees: Dict[str, map_kdtree.MapElementKDTree] = dill.load(f) + + return kdtrees diff --git a/src/trajdata/parallel/parallel_utils.py b/src/trajdata/utils/parallel_utils.py similarity index 100% rename from src/trajdata/parallel/parallel_utils.py rename to src/trajdata/utils/parallel_utils.py diff --git a/src/trajdata/utils/raster_utils.py b/src/trajdata/utils/raster_utils.py new file mode 100644 index 0000000..120981f --- /dev/null +++ b/src/trajdata/utils/raster_utils.py @@ -0,0 +1,261 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from trajdata.maps import VectorMap + + +from math import ceil +from typing import List, Tuple + +import cv2 +import numpy as np +from tqdm import tqdm + +from trajdata.maps.raster_map import RasterizedMap, RasterizedMapMetadata +from trajdata.maps.vec_map import MapElement, MapElementType +from trajdata.utils import map_utils + +# Sub-pixel drawing precision constants. +# See https://github.com/woven-planet/l5kit/blob/master/l5kit/l5kit/rasterization/semantic_rasterizer.py#L16 +CV2_SUB_VALUES = {"shift": 9, "lineType": cv2.LINE_AA} +CV2_SHIFT_VALUE = 2 ** CV2_SUB_VALUES["shift"] + + +def cv2_subpixel(coords: np.ndarray) -> np.ndarray: + """ + Cast coordinates to numpy.int but keep fractional part by previously multiplying by 2**CV2_SHIFT + cv2 calls will use shift to restore original values with higher precision + + Args: + coords (np.ndarray): XY coords as float + + Returns: + np.ndarray: XY coords as int for cv2 shift draw + """ + return (coords * CV2_SHIFT_VALUE).astype(int) + + +def world_to_subpixel(pts: np.ndarray, raster_from_world: np.ndarray): + return cv2_subpixel(map_utils.transform_points(pts, raster_from_world)) + + +def cv2_draw_polygons( + polygon_pts: List[np.ndarray], + onto_img: np.ndarray, + color: Tuple[int, int, int], +) -> None: + cv2.fillPoly( + img=onto_img, + pts=polygon_pts, + color=color, + **CV2_SUB_VALUES, + ) + + +def cv2_draw_polylines( + polyline_pts: List[np.ndarray], + onto_img: np.ndarray, + color: Tuple[int, int, int], +) -> None: + cv2.polylines( + img=onto_img, + pts=polyline_pts, + isClosed=False, + color=color, + **CV2_SUB_VALUES, + ) + + +def rasterize_world_polygon( + polygon_pts: np.ndarray, + onto_img: np.ndarray, + raster_from_world: np.ndarray, + color: Tuple[int, int, int], +) -> None: + subpixel_area: np.ndarray = world_to_subpixel( + polygon_pts[..., :2], raster_from_world + ) + + # Drawing general road areas. + cv2_draw_polygons(polygon_pts=[subpixel_area], onto_img=onto_img, color=color) + + +def rasterize_world_polylines( + polyline_pts: List[np.ndarray], + onto_img: np.ndarray, + raster_from_world: np.ndarray, + color: Tuple[int, int, int], +) -> None: + subpixel_pts: List[np.ndarray] = [ + world_to_subpixel(pts[..., :2], raster_from_world) for pts in polyline_pts + ] + + # Drawing line. + cv2_draw_polylines( + polyline_pts=subpixel_pts, + onto_img=onto_img, + color=color, + ) + + +def rasterize_lane( + left_edge: np.ndarray, + right_edge: np.ndarray, + onto_img_area: np.ndarray, + onto_img_line: np.ndarray, + raster_from_world: np.ndarray, + area_color: Tuple[int, int, int], + line_color: Tuple[int, int, int], +) -> None: + lane_edges: List[np.ndarray] = [left_edge[:, :2], right_edge[::-1, :2]] + + # Drawing lane area. + rasterize_world_polygon( + np.concatenate(lane_edges, axis=0), + onto_img_area, + raster_from_world, + color=area_color, + ) + + # Drawing lane lines. + rasterize_world_polylines(lane_edges, onto_img_line, raster_from_world, line_color) + + +def rasterize_map( + vec_map: VectorMap, resolution: float, **pbar_kwargs +) -> RasterizedMap: + """Renders the semantic map at the given resolution. + + Args: + vec_map (VectorMap): _description_ + resolution (float): The rasterized image's resolution in pixels per meter. + + Returns: + np.ndarray: The rasterized RGB image. + """ + # extents is [min_x, min_y, min_z, max_x, max_y, max_z] + min_x, min_y, _, max_x, max_y, _ = vec_map.extent + world_center_m: Tuple[float, float] = ( + (min_x + max_x) / 2, + (min_y + max_y) / 2, + ) + + raster_size_x: int = ceil((max_x - min_x) * resolution) + raster_size_y: int = ceil((max_y - min_y) * resolution) + + raster_from_local: np.ndarray = np.array( + [ + [resolution, 0, raster_size_x / 2], + [0, resolution, raster_size_y / 2], + [0, 0, 1], + ] + ) + + # Compute pose from its position and rotation. + pose_from_world: np.ndarray = np.array( + [ + [1, 0, -world_center_m[0]], + [0, 1, -world_center_m[1]], + [0, 0, 1], + ] + ) + + raster_from_world: np.ndarray = raster_from_local @ pose_from_world + + lane_area_img: np.ndarray = np.zeros( + shape=(raster_size_y, raster_size_x, 3), dtype=np.uint8 + ) + lane_line_img: np.ndarray = np.zeros( + shape=(raster_size_y, raster_size_x, 3), dtype=np.uint8 + ) + ped_area_img: np.ndarray = np.zeros( + shape=(raster_size_y, raster_size_x, 3), dtype=np.uint8 + ) + + map_elem: MapElement + for map_elem in tqdm( + vec_map.iter_elems(), + desc=f"Rasterizing Map at {resolution:.2f} px/m", + total=len(vec_map), + **pbar_kwargs, + ): + if map_elem.elem_type == MapElementType.ROAD_LANE: + if map_elem.left_edge is not None and map_elem.right_edge is not None: + # Heading doesn't matter for rasterization. + left_pts: np.ndarray = map_elem.left_edge.xyz + right_pts: np.ndarray = map_elem.right_edge.xyz + + # Need to for-loop because doing it all at once can make holes. + # Drawing lane. + rasterize_lane( + left_pts, + right_pts, + lane_area_img, + lane_line_img, + raster_from_world, + area_color=(255, 0, 0), + line_color=(0, 255, 0), + ) + + # # This code helps visualize centerlines to check if the inferred headings are correct. + # center_pts = cv2_subpixel( + # transform_points( + # proto_to_np(map_elem.road_lane.center, incl_heading=False), + # raster_from_world, + # ) + # )[..., :2] + + # # Drawing lane centerlines. + # cv2.polylines( + # img=lane_line_img, + # pts=center_pts[None, :, :], + # isClosed=False, + # color=(255, 0, 0), + # **CV2_SUB_VALUES, + # ) + + # headings = np.asarray(map_elem.road_lane.center.h_rad) + # delta = cv2_subpixel(30*np.array([np.cos(headings[0]), np.sin(headings[0])])) + # cv2.arrowedLine(img=lane_line_img, pt1=tuple(center_pts[0]), pt2=tuple(center_pts[0] + 10*(center_pts[1] - center_pts[0])), color=(255, 0, 0), shift=9, line_type=cv2.LINE_AA) + # cv2.arrowedLine(img=lane_line_img, pt1=tuple(center_pts[0]), pt2=tuple(center_pts[0] + delta), color=(0, 255, 0), shift=9, line_type=cv2.LINE_AA) + + elif map_elem.elem_type == MapElementType.ROAD_AREA: + # Drawing general road areas. + rasterize_world_polygon( + map_elem.exterior_polygon.xy, + lane_area_img, + raster_from_world, + color=(255, 0, 0), + ) + + for interior_hole in map_elem.interior_holes: + # Removing holes. + rasterize_world_polygon( + interior_hole.xy, lane_area_img, raster_from_world, color=(0, 0, 0) + ) + + elif map_elem.elem_type in { + MapElementType.PED_CROSSWALK, + MapElementType.PED_WALKWAY, + }: + # Drawing crosswalks and walkways. + rasterize_world_polygon( + map_elem.polygon.xy, ped_area_img, raster_from_world, color=(0, 0, 255) + ) + + map_data: np.ndarray = (lane_area_img + lane_line_img + ped_area_img).astype( + np.float32 + ).transpose(2, 0, 1) / 255 + + rasterized_map_info = RasterizedMapMetadata( + name=vec_map.map_name, + shape=map_data.shape, + layers=["drivable_area", "lane_divider", "ped_area"], + layer_rgb_groups=([0], [1], [2]), + resolution=resolution, + map_from_world=raster_from_world, + ) + + return RasterizedMap(rasterized_map_info, map_data) diff --git a/src/trajdata/visualization/vis.py b/src/trajdata/visualization/vis.py index 3976722..db6bad4 100644 --- a/src/trajdata/visualization/vis.py +++ b/src/trajdata/visualization/vis.py @@ -1,8 +1,12 @@ from typing import Optional import matplotlib.pyplot as plt +import matplotlib.transforms as mtransforms +import numpy as np +import seaborn as sns import torch from matplotlib.axes import Axes +from matplotlib.patches import Circle, FancyBboxPatch from torch import Tensor from trajdata.data_structures.agent import AgentType @@ -10,10 +14,189 @@ from trajdata.maps import RasterizedMap +def draw_agent( + ax: Axes, + agent_type: AgentType, + agent_state: Tensor, + agent_extent: Tensor, + agent_to_world_tf: Tensor, + **kwargs, +) -> None: + """Draws a path with the correct location, heading, and dimensions onto the given axes + + Args: + ax (Axes): _description_ + agent_type (AgentType): _description_ + agent_state (Tensor): _description_ + agent_extent (Tensor): _description_ + agent_to_world_tf (Tensor): _description_ + """ + + if torch.any(torch.isnan(agent_extent)): + if agent_type == AgentType.VEHICLE: + length = 4.3 + width = 1.8 + elif agent_type == AgentType.PEDESTRIAN: + length = 0.5 + width = 0.5 + elif agent_type == AgentType.BICYCLE: + length = 1.9 + width = 0.5 + else: + length = 1.0 + width = 1.0 + else: + length = agent_extent[0].item() + width = agent_extent[1].item() + + patch = FancyBboxPatch( + [-length / 2, -width / 2], length, width, boxstyle="rarrow", **kwargs + ) + transform = ( + mtransforms.Affine2D() + .rotate(np.arctan2(agent_state[-2].item(), agent_state[-1].item())) + .translate(agent_state[0], agent_state[1]) + + mtransforms.Affine2D(matrix=agent_to_world_tf.cpu().numpy()) + + ax.transData + ) + patch.set_transform(transform) + + center_patch = Circle([0, 0], radius=0.25, **kwargs) + center_patch.set_transform(transform) + + ax.add_patch(patch) + ax.add_patch(center_patch) + + +def draw_history( + ax, + agent_type, + agent_history, + agent_extent, + agent_to_world_tf, + start_alpha=0.2, + end_alpha=0.5, + **kwargs, +): + T = agent_history.shape[0] + alphas = np.linspace(start_alpha, end_alpha, T) + for t in range(T): + draw_agent( + ax, + agent_type, + agent_history[t], + agent_extent, + agent_to_world_tf, + alpha=alphas[t], + **kwargs, + ) + + +def draw_map(ax: Axes, map: Tensor, base_frame_from_map_tf: Tensor, **kwargs): + patch_size: int = map.shape[-1] + map_array = RasterizedMap.to_img(map.cpu()) + brightened_map_array = map_array * 0.2 + 0.8 + + im = ax.imshow( + brightened_map_array, extent=[0, patch_size, patch_size, 0], **kwargs + ) + transform = ( + mtransforms.Affine2D(matrix=base_frame_from_map_tf.cpu().numpy()) + ax.transData + ) + im.set_transform(transform) + + +def plot_agent_batch_all( + batch: AgentBatch, + ax: Optional[Axes] = None, + show: bool = True, + close: bool = True, +) -> None: + if ax is None: + _, ax = plt.subplots() + + # Use first agent as common reference frame + base_frame_from_world_tf = batch.agents_from_world_tf[0].cpu() + + # plot maps over each other with proper transformations: + for i in range(len(batch.agent_name)): + base_frame_from_map_tf = base_frame_from_world_tf @ torch.linalg.inv( + batch.rasters_from_world_tf[i].cpu() + ) + draw_map(ax, batch.maps[i], base_frame_from_map_tf, alpha=1.0) + + for i in range(len(batch.agent_name)): + agent_type = batch.agent_type[i] + agent_name = batch.agent_name[i] + agent_hist = batch.agent_hist[i, :, :].cpu() + agent_fut = batch.agent_fut[i, :, :].cpu() + agent_extent = batch.agent_hist_extent[i, -1, :].cpu() + base_frame_from_agent_tf = base_frame_from_world_tf @ torch.linalg.inv( + batch.agents_from_world_tf[i].cpu() + ) + + palette = sns.color_palette("husl", 4) + if agent_type == AgentType.VEHICLE: + color = palette[0] + elif agent_type == AgentType.PEDESTRIAN: + color = palette[1] + elif agent_type == AgentType.BICYCLE: + color = palette[2] + else: + color = palette[3] + + transform = ( + mtransforms.Affine2D(matrix=base_frame_from_agent_tf.numpy()) + ax.transData + ) + draw_history( + ax, + agent_type, + agent_hist[:-1, :], + agent_extent, + base_frame_from_agent_tf, + color=color, + linewidth=0, + ) + ax.plot( + agent_hist[:, 0], + agent_hist[:, 1], + linestyle="-", + color=color, + transform=transform, + ) + draw_agent( + ax, + agent_type, + agent_hist[-1, :], + agent_extent, + base_frame_from_agent_tf, + facecolor=color, + edgecolor="k", + ) + ax.plot( + agent_fut[:, 0], + agent_fut[:, 1], + linestyle="--", + color=color, + transform=transform, + ) + + ax.set_ylim(-30, 40) + ax.set_xlim(-30, 40) + ax.grid(False) + + if show: + plt.show() + + if close: + plt.close() + + def plot_agent_batch( batch: AgentBatch, batch_idx: int, ax: Optional[Axes] = None, + legend: bool = True, show: bool = True, close: bool = True, ) -> None: @@ -73,13 +256,14 @@ def plot_agent_batch( ls="--", label="Agent History", ) - ax.quiver( - history_xy[..., 0], - history_xy[..., 1], - history_xy[..., -1], - history_xy[..., -2], - color="k", - ) + # ax.quiver( + # history_xy[..., 0], + # history_xy[..., 1], + # history_xy[..., -1], + # history_xy[..., -2], + # color="k", + # ) + ax.plot(future_xy[..., 0], future_xy[..., 1], c="violet", label="Agent Future") ax.scatter(center_xy[0], center_xy[1], s=20, c="orangered", label="Agent Current") @@ -123,15 +307,22 @@ def plot_agent_batch( ax.set_ylabel("y (m)") ax.grid(False) - ax.legend(loc="best", frameon=True) ax.axis("equal") + # Doing this because the imshow above makes the map origin at the top. + ax.invert_yaxis() + + if legend: + ax.legend(loc="best", frameon=True) + if show: plt.show() if close: plt.close() + return ax + def plot_scene_batch( batch: SceneBatch, @@ -241,6 +432,9 @@ def plot_scene_batch( ax.legend(loc="best", frameon=True) ax.axis("equal") + # Doing this because the imshow above makes the map origin at the top. + ax.invert_yaxis() + if show: plt.show() diff --git a/tests/test_batch_conversion.py b/tests/test_batch_conversion.py new file mode 100644 index 0000000..3fc0b25 --- /dev/null +++ b/tests/test_batch_conversion.py @@ -0,0 +1,192 @@ +import unittest +from collections import defaultdict + +import torch + +from trajdata import AgentType, UnifiedDataset +from trajdata.caching.env_cache import EnvCache +from trajdata.utils.batch_utils import convert_to_agent_batch + + +class TestSceneToAgentBatchConversion(unittest.TestCase): + def __init__(self, methodName: str = "batchConversion") -> None: + super().__init__(methodName) + + data_source = "nusc_mini" + history_sec = 2.0 + prediction_sec = 6.0 + + attention_radius = defaultdict( + lambda: 20.0 + ) # Default range is 20m unless otherwise specified. + attention_radius[(AgentType.PEDESTRIAN, AgentType.PEDESTRIAN)] = 10.0 + attention_radius[(AgentType.PEDESTRIAN, AgentType.VEHICLE)] = 20.0 + attention_radius[(AgentType.VEHICLE, AgentType.PEDESTRIAN)] = 20.0 + attention_radius[(AgentType.VEHICLE, AgentType.VEHICLE)] = 30.0 + + map_params = {"px_per_m": 2, "map_size_px": 100, "offset_frac_xy": (-0.75, 0.0)} + + self._scene_dataset = UnifiedDataset( + centric="scene", + desired_data=[data_source], + history_sec=(history_sec, history_sec), + future_sec=(prediction_sec, prediction_sec), + agent_interaction_distances=attention_radius, + incl_robot_future=False, + incl_raster_map=True, + raster_map_params=map_params, + only_predict=[AgentType.VEHICLE, AgentType.PEDESTRIAN], + no_types=[AgentType.UNKNOWN], + num_workers=0, + standardize_data=True, + data_dirs={ + "nusc_mini": "~/datasets/nuScenes", + }, + ) + + self._agent_dataset = UnifiedDataset( + centric="agent", + desired_data=[data_source], + history_sec=(history_sec, history_sec), + future_sec=(prediction_sec, prediction_sec), + agent_interaction_distances=attention_radius, + incl_robot_future=False, + incl_raster_map=True, + raster_map_params=map_params, + only_predict=[AgentType.VEHICLE, AgentType.PEDESTRIAN], + no_types=[AgentType.UNKNOWN], + num_workers=0, + standardize_data=True, + data_dirs={ + "nusc_mini": "~/datasets/nuScenes", + }, + ) + + def _assert_allclose_with_nans(self, tensor1, tensor2): + """ + asserts that the two tensors have nans in the same locations, and the non-nan + elements all are close. + """ + # Check nans are in the same place + self.assertFalse( + torch.any( # True if there's any mismatch + torch.logical_xor( # True where either tensor1 or tensor 2 has nans, but not both (mismatch) + torch.isnan(tensor1), # True where tensor1 has nans + torch.isnan(tensor2), # True where tensor2 has nans + ) + ), + msg="Nans occur in different places.", + ) + valid_mask = torch.logical_not(torch.isnan(tensor1)) + self.assertTrue( + torch.allclose(tensor1[valid_mask], tensor2[valid_mask]), + msg="Non-nan values don't match.", + ) + + def _test_agent_idx(self, agent_dataset_idx: int, verbose=False): + for offset in range(50): + agent_batch_element = self._agent_dataset[agent_dataset_idx] + agent_scene_path, _, _ = self._agent_dataset._data_index[agent_dataset_idx] + agent_batch = self._agent_dataset.get_collate_fn(pad_format="right")( + [agent_batch_element] + ) + scene_ts = agent_batch_element.scene_ts + scene_id = agent_batch_element.scene_id + agent_name = agent_batch_element.agent_name + if verbose: + print( + f"From the agent-centric dataset at index {agent_dataset_idx}, we're looking at:\nAgent {agent_name} in {scene_id} at timestep {scene_ts}" + ) + + # find same scene and ts in scene-centric dataset + scene_dataset_idx = 0 + for scene_dataset_idx in range(len(self._scene_dataset)): + scene_path, ts = self._scene_dataset._data_index[scene_dataset_idx] + if ts == scene_ts and scene_path == agent_scene_path: + # load scene to check scene name + scene = EnvCache.load(scene_path) + if scene.name == scene_id: + break + + if verbose: + print( + f"We found a matching scene in the scene-centric dataset at index {scene_dataset_idx}" + ) + + scene_batch_element = self._scene_dataset[scene_dataset_idx] + converted_agent_batch = convert_to_agent_batch( + scene_batch_element, + self._scene_dataset.only_types, + self._scene_dataset.no_types, + self._scene_dataset.agent_interaction_distances, + self._scene_dataset.incl_raster_map, + self._scene_dataset.raster_map_params, + self._scene_dataset.max_neighbor_num, + self._scene_dataset.standardize_data, + self._scene_dataset.standardize_derivatives, + pad_format="right", + ) + + agent_idx = -1 + for j, name in enumerate(converted_agent_batch.agent_name): + if name == agent_name: + agent_idx = j + + if agent_idx < 0: + if verbose: + print("no matching scene containing agent, checking next index") + agent_dataset_idx += 1 + else: + break + + self.assertTrue( + agent_idx >= 0, "Matching scene not found in scene-centric dataset!" + ) + + if verbose: + print( + f"Agent {converted_agent_batch.agent_name[agent_idx]} appears in {scene_batch_element.scene_id} at timestep {scene_batch_element.scene_ts}, as agent number {agent_idx}" + ) + + attrs_to_ignore = ["data_idx", "extras", "history_pad_dir"] + + variable_length_keys = { + "neigh_types": "num_neigh", + "neigh_hist": "num_neigh", + "neigh_hist_extents": "num_neigh", + "neigh_hist_len": "num_neigh", + "neigh_fut": "num_neigh", + "neigh_fut_extents": "num_neigh", + "neigh_fut_len": "num_neigh", + } + + for attr, val in converted_agent_batch.__dict__.items(): + if attr in attrs_to_ignore: + continue + if verbose: + print(f"Checking {attr}") + + if val is None: + self.assertTrue(agent_batch.__dict__[attr] is None) + elif type(val[agent_idx]) is torch.Tensor: + if attr in variable_length_keys: + attr_len = converted_agent_batch.__dict__[ + variable_length_keys[attr] + ][agent_idx] + convertedTensor = val[agent_idx, :attr_len, ...] + targetTensor = agent_batch.__dict__[attr][0, :attr_len, ...] + else: + convertedTensor = val[agent_idx] + targetTensor = agent_batch.__dict__[attr][0] + self._assert_allclose_with_nans(convertedTensor, targetTensor) + else: + self.assertTrue(val[agent_idx] == agent_batch.__dict__[attr][0]) + + def test_index_1(self): + self._test_agent_idx(0, verbose=False) + + def test_index_2(self): + self._test_agent_idx(116, verbose=False) + + def test_index_3(self): + self._test_agent_idx(222, verbose=False) diff --git a/tests/test_datasizes.py b/tests/test_datasizes.py index 1283230..6c3b00a 100644 --- a/tests/test_datasizes.py +++ b/tests/test_datasizes.py @@ -141,8 +141,8 @@ def test_interpolation(self): future_sec=(4.8, 4.8), only_types=[AgentType.VEHICLE], incl_robot_future=False, - incl_map=False, - map_params={ + incl_raster_map=False, + raster_map_params={ "px_per_m": 2, "map_size_px": 224, "offset_frac_xy": (-0.5, 0.0), diff --git a/tests/test_traffic_data.py b/tests/test_traffic_data.py new file mode 100644 index 0000000..e20c7a0 --- /dev/null +++ b/tests/test_traffic_data.py @@ -0,0 +1,158 @@ +import unittest + +from trajdata import UnifiedDataset +from trajdata.caching.df_cache import DataFrameCache + + +class TestTrafficLightData(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + kwargs = { + "desired_data": ["nuplan_mini-mini_val"], + "centric": "scene", + "history_sec": (3.2, 3.2), + "future_sec": (4.8, 4.8), + "incl_robot_future": False, + "incl_raster_map": True, + "cache_location": "~/.unified_data_cache", + "raster_map_params": { + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, + "num_workers": 64, + "verbose": True, + "data_dirs": { # Remember to change this to match your filesystem! + "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", + }, + } + + cls.dataset = UnifiedDataset( + **kwargs, + desired_dt=0.05, + ) + + cls.downsampled_dataset = UnifiedDataset( + **kwargs, + desired_dt=0.1, + ) + + cls.upsampled_dataset = UnifiedDataset( + **kwargs, + desired_dt=0.025, + ) + + cls.scene_num: int = 100 + + def test_traffic_light_loading(self): + # get random scene + scene = self.dataset.get_scene(self.scene_num) + scene_cache = DataFrameCache(self.dataset.cache_path, scene) + traffic_light_status = scene_cache.get_traffic_light_status_dict() + + # just check if the loading works without errors + self.assertTrue(traffic_light_status is not None) + + def test_downsampling(self): + # get random scene from both datasets + scene = self.dataset.get_scene(self.scene_num) + downsampled_scene = self.downsampled_dataset.get_scene(self.scene_num) + + self.assertEqual(scene.name, downsampled_scene.name) + + scene_cache = DataFrameCache(self.dataset.cache_path, scene) + downsampled_scene_cache = DataFrameCache( + self.downsampled_dataset.cache_path, downsampled_scene + ) + traffic_light_status = scene_cache.get_traffic_light_status_dict() + downsampled_traffic_light_status = ( + downsampled_scene_cache.get_traffic_light_status_dict() + ) + + orig_lane_ids = set(key[0] for key in traffic_light_status.keys()) + downsampled_lane_ids = set( + key[0] for key in downsampled_traffic_light_status.keys() + ) + self.assertSetEqual(orig_lane_ids, downsampled_lane_ids) + + # check that matching indices match + for ( + lane_id, + scene_ts, + ), downsampled_status in downsampled_traffic_light_status.items(): + if scene_ts % 2 == 0: + try: + prev_status = traffic_light_status[lane_id, scene_ts * 2] + except KeyError: + prev_status = None + + try: + next_status = traffic_light_status[lane_id, scene_ts * 2 + 1] + except KeyError: + next_status = None + + self.assertTrue( + prev_status is not None or next_status is not None, + f"Lane {lane_id} at t={scene_ts} has status {downsampled_status} " + f"in the downsampled dataset, but neither t={2*scene_ts} nor " + f"t={2*scene_ts + 1} were found in the original dataset.", + ) + self.assertTrue( + downsampled_status == prev_status + or downsampled_status == next_status, + f"Lane {lane_id} at t={scene_ts*2, scene_ts*2 + 1} in the original dataset " + f"had status {prev_status, next_status}, but in the downsampled dataset, " + f"{lane_id} at t={scene_ts} had status {downsampled_status}", + ) + + def test_upsampling(self): + # get random scene from both datasets + scene = self.dataset.get_scene(self.scene_num) + upsampled_scene = self.upsampled_dataset.get_scene(self.scene_num) + scene_cache = DataFrameCache(self.dataset.cache_path, scene) + upsampled_scene_cache = DataFrameCache( + self.upsampled_dataset.cache_path, upsampled_scene + ) + traffic_light_status = scene_cache.get_traffic_light_status_dict() + upsampled_traffic_light_status = ( + upsampled_scene_cache.get_traffic_light_status_dict() + ) + + # check that matching indices match + for (lane_id, scene_ts), status in upsampled_traffic_light_status.items(): + if scene_ts % 2 == 0: + orig_status = traffic_light_status[lane_id, scene_ts // 2] + self.assertEqual( + status, + orig_status, + f"Lane {lane_id} at t={scene_ts // 2} in the original dataset " + f"had status {orig_status}, but in the upsampled dataset, " + f"{lane_id} at t={scene_ts} had status {status}", + ) + else: + try: + prev_status = traffic_light_status[lane_id, scene_ts // 2] + except KeyError: + prev_status = None + try: + next_status = traffic_light_status[lane_id, scene_ts // 2 + 1] + except KeyError as k: + next_status = None + + self.assertTrue( + prev_status is not None or next_status is not None, + f"Lane {lane_id} at t={scene_ts} has status {status} " + f"in the upsampled dataset, but neither t={scene_ts // 2} nor " + f"t={scene_ts // 2 + 1} were found in the original dataset.", + ) + + self.assertTrue( + status == prev_status or status == next_status, + f"Lane {lane_id} at t={scene_ts // 2, scene_ts // 2 + 1} in the original dataset " + f"had status {prev_status, next_status}, but in the upsampled dataset, " + f"{lane_id} at t={scene_ts} had status {status}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vec_map.py b/tests/test_vec_map.py new file mode 100644 index 0000000..da80041 --- /dev/null +++ b/tests/test_vec_map.py @@ -0,0 +1,58 @@ +import unittest +from pathlib import Path +from typing import Dict, List + +from trajdata import MapAPI, VectorMap + + +class TestVectorMap(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + cache_path = Path("~/.unified_data_cache").expanduser() + cls.map_api = MapAPI(cache_path) + cls.proto_loading_kwargs = { + "incl_road_lanes": True, + "incl_road_areas": True, + "incl_ped_crosswalks": True, + "incl_ped_walkways": True, + } + + cls.location_dict: Dict[str, List[str]] = { + "nuplan_mini": ["boston", "singapore", "pittsburgh", "las_vegas"], + "nusc_mini": ["boston-seaport", "singapore-onenorth"], + "lyft_sample": ["palo_alto"], + } + + def test_map_existence(self): + for env_name, map_names in self.location_dict.items(): + for map_name in map_names: + vec_map: VectorMap = self.map_api.get_map( + f"{env_name}:{map_name}", **self.proto_loading_kwargs + ) + assert vec_map is not None + + def test_proto_equivalence(self): + for env_name, map_names in self.location_dict.items(): + for map_name in map_names: + vec_map: VectorMap = self.map_api.get_map( + f"{env_name}:{map_name}", **self.proto_loading_kwargs + ) + + assert maps_equal( + VectorMap.from_proto( + vec_map.to_proto(), **self.proto_loading_kwargs + ), + vec_map, + ) + + # TODO(bivanovic): Add more! + + +def maps_equal(map1: VectorMap, map2: VectorMap) -> bool: + elements1_set = set([elem.id for elem in map1.iter_elems()]) + elements2_set = set([elem.id for elem in map1.iter_elems()]) + return elements1_set == elements2_set + + +if __name__ == "__main__": + unittest.main()