From 4d36e67cc1a7c92815ccba379db5d789003bfb27 Mon Sep 17 00:00:00 2001 From: Boris Ivanovic Date: Thu, 6 Oct 2022 15:35:06 -0400 Subject: [PATCH] Adding a few new features like specifying the max number of agents and if only the ego should be the focus agent, as well as a hotfix bringing back rasterized maps in their original format. --- setup.cfg | 2 +- src/trajdata/caching/df_cache.py | 9 +- src/trajdata/caching/scene_cache.py | 1 + src/trajdata/data_structures/batch_element.py | 102 +++++++++--------- src/trajdata/dataset.py | 29 ++++- .../eth_ucy_peds/eupeds_dataset.py | 7 +- .../dataset_specific/lyft/lyft_dataset.py | 68 +++++++++++- .../dataset_specific/nusc/nusc_dataset.py | 102 +++++++++++++----- src/trajdata/dataset_specific/raw_dataset.py | 7 +- src/trajdata/filtering/filters.py | 6 +- src/trajdata/simulation/sim_scene.py | 1 + 11 files changed, 244 insertions(+), 90 deletions(-) diff --git a/setup.cfg b/setup.cfg index c48aa88..8ecbf21 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = trajdata -version = 1.0.7 +version = 1.0.8 author = Boris Ivanovic author_email = bivanovic@nvidia.com description = A unified interface to many trajectory forecasting datasets. diff --git a/src/trajdata/caching/df_cache.py b/src/trajdata/caching/df_cache.py index 4b855ef..8f9c9f9 100644 --- a/src/trajdata/caching/df_cache.py +++ b/src/trajdata/caching/df_cache.py @@ -637,13 +637,14 @@ def cache_map( @staticmethod def cache_map_layers( cache_path: Path, + vec_map: VectorizedMap, map_info: RasterizedMapMetadata, layer_fn: Callable[[str], np.ndarray], env_name: str, ) -> None: ( maps_path, - _, + vector_map_path, raster_map_path, raster_metadata_path, ) = DataFrameCache.get_map_paths( @@ -653,10 +654,16 @@ def cache_map_layers( # Ensuring the maps directory exists. maps_path.mkdir(parents=True, exist_ok=True) + # Saving the vectorized map data. + with open(vector_map_path, "wb") as f: + f.write(vec_map.SerializeToString()) + + # Saving the rasterized map data. disk_data = zarr.open_array(raster_map_path, mode="w", shape=map_info.shape) for idx, layer_name in enumerate(map_info.layers): disk_data[idx] = layer_fn(layer_name) + # Saving the rasterized map metadata. with open(raster_metadata_path, "wb") as f: dill.dump(map_info, f) diff --git a/src/trajdata/caching/scene_cache.py b/src/trajdata/caching/scene_cache.py index 261d891..8f8b1fc 100644 --- a/src/trajdata/caching/scene_cache.py +++ b/src/trajdata/caching/scene_cache.py @@ -140,6 +140,7 @@ def cache_map( @staticmethod def cache_map_layers( cache_path: Path, + vec_map: VectorizedMap, map_info: RasterizedMapMetadata, layer_fn: Callable[[str], np.ndarray], env_name: str, diff --git a/src/trajdata/data_structures/batch_element.py b/src/trajdata/data_structures/batch_element.py index 4c26b49..2b177f9 100644 --- a/src/trajdata/data_structures/batch_element.py +++ b/src/trajdata/data_structures/batch_element.py @@ -1,6 +1,6 @@ from collections import defaultdict from math import sqrt -from typing import Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import numpy as np @@ -26,9 +26,10 @@ def __init__( ] = defaultdict(lambda: np.inf), incl_robot_future: bool = False, incl_map: bool = False, - map_params: Optional[Dict[str, int]] = None, + map_params: Optional[Dict[str, Any]] = None, standardize_data: bool = False, standardize_derivatives: bool = False, + max_neighbor_num: Optional[int] = None, ) -> None: self.cache: SceneCache = cache self.data_index: int = data_index @@ -40,6 +41,7 @@ def __init__( agent_info: AgentMetadata = scene_time_agent.agent self.agent_name: str = agent_info.name self.agent_type: AgentType = agent_info.type + self.max_neighbor_num = max_neighbor_num self.curr_agent_state_np: np.ndarray = cache.get_state( agent_info.name, self.scene_ts @@ -156,82 +158,86 @@ def get_agent_future( ) return agent_future_np, agent_extent_future_np - def get_neighbor_history( + # @profile + def get_neighbor_data( self, scene_time: SceneTimeAgent, agent_info: AgentMetadata, - history_sec: Tuple[Optional[float], Optional[float]], + length_sec: Tuple[Optional[float], Optional[float]], distance_limit: Callable[[np.ndarray, int], np.ndarray], + mode: str, ) -> Tuple[int, np.ndarray, List[np.ndarray], List[np.ndarray], np.ndarray]: - # The indices of the returned ndarray match the scene_time agents list (including the index of the central agent, - # which would have a distance of 0 to itself). + # The indices of the returned ndarray match the scene_time agents list + # (including the index of the central agent, which would have a distance + # of 0 to itself). agent_distances: np.ndarray = scene_time.get_agent_distances_to(agent_info) agent_idx: int = scene_time.agents.index(agent_info) + neighbor_types: np.ndarray = np.array([a.type.value for a in scene_time.agents]) nearby_mask: np.ndarray = agent_distances <= distance_limit( neighbor_types, agent_info.type ) nearby_mask[agent_idx] = False + nb_idx = agent_distances.argsort() nearby_agents: List[AgentMetadata] = [ - agent for (idx, agent) in enumerate(scene_time.agents) if nearby_mask[idx] + scene_time.agents[idx] for idx in nb_idx if nearby_mask[idx] ] neighbor_types_np: np.ndarray = neighbor_types[nearby_mask] + if self.max_neighbor_num is not None: + # Pruning nearby_agents and re-creating + # neighbor_types_np with the remaining agents. + nearby_agents = nearby_agents[: self.max_neighbor_num] + neighbor_types_np: np.ndarray = np.array( + [a.type.value for a in nearby_agents] + ) + num_neighbors: int = len(nearby_agents) - ( - neighbor_histories, - neighbor_history_extents, - neighbor_history_lens_np, - ) = self.cache.get_agents_history(self.scene_ts, nearby_agents, history_sec) + + if mode == "history": + ( + neighbor_data, + neighbor_extents_data, + neighbor_data_lens_np, + ) = self.cache.get_agents_history(self.scene_ts, nearby_agents, length_sec) + elif mode == "future": + ( + neighbor_data, + neighbor_extents_data, + neighbor_data_lens_np, + ) = self.cache.get_agents_future(self.scene_ts, nearby_agents, length_sec) + else: + raise ValueError(f"Unknown mode {mode} passed in!") return ( num_neighbors, neighbor_types_np, - neighbor_histories, - neighbor_history_extents, - neighbor_history_lens_np, + neighbor_data, + neighbor_extents_data, + neighbor_data_lens_np, ) - # @profile - def get_neighbor_future( + def get_neighbor_history( self, scene_time: SceneTimeAgent, agent_info: AgentMetadata, - future_sec: Tuple[Optional[float], Optional[float]], + history_sec: Tuple[Optional[float], Optional[float]], distance_limit: Callable[[np.ndarray, int], np.ndarray], ) -> Tuple[int, np.ndarray, List[np.ndarray], List[np.ndarray], np.ndarray]: - scene_ts: int = self.scene_ts - - # The indices of the returned ndarray match the scene_time agents list (including the index of the central agent, - # which would have a distance of 0 to itself). - agent_distances: np.ndarray = scene_time.get_agent_distances_to(agent_info) - agent_idx: int = scene_time.agents.index(agent_info) - - neighbor_types: np.ndarray = np.array([a.type.value for a in scene_time.agents]) - nearby_mask: np.ndarray = agent_distances <= distance_limit( - neighbor_types, agent_info.type + return self.get_neighbor_data( + scene_time, agent_info, history_sec, distance_limit, mode="history" ) - nearby_mask[agent_idx] = False - - nearby_agents: List[AgentMetadata] = [ - agent for (idx, agent) in enumerate(scene_time.agents) if nearby_mask[idx] - ] - neighbor_types_np: np.ndarray = neighbor_types[nearby_mask] - - num_neighbors: int = len(nearby_agents) - ( - neighbor_futures, - neighbor_future_extents, - neighbor_future_lens_np, - ) = self.cache.get_agents_future(scene_ts, nearby_agents, future_sec) - return ( - num_neighbors, - neighbor_types_np, - neighbor_futures, - neighbor_future_extents, - neighbor_future_lens_np, + def get_neighbor_future( + self, + scene_time: SceneTimeAgent, + agent_info: AgentMetadata, + future_sec: Tuple[Optional[float], Optional[float]], + distance_limit: Callable[[np.ndarray, int], np.ndarray], + ) -> Tuple[int, np.ndarray, List[np.ndarray], List[np.ndarray], np.ndarray]: + return self.get_neighbor_data( + scene_time, agent_info, future_sec, distance_limit, mode="future" ) def get_robot_current_and_future( @@ -310,7 +316,7 @@ def __init__( ] = defaultdict(lambda: np.inf), incl_robot_future: bool = False, incl_map: bool = False, - map_params: Optional[Dict[str, int]] = None, + map_params: Optional[Dict[str, Any]] = None, standardize_data: bool = False, standardize_derivatives: bool = False, max_agent_num: Optional[int] = None, diff --git a/src/trajdata/dataset.py b/src/trajdata/dataset.py index eff2da7..5979b0b 100644 --- a/src/trajdata/dataset.py +++ b/src/trajdata/dataset.py @@ -3,7 +3,7 @@ from functools import partial from itertools import chain from pathlib import Path -from typing import Callable, Dict, Final, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, Final, List, Optional, Set, Tuple, Union import numpy as np from torch.utils.data import DataLoader, Dataset @@ -61,7 +61,7 @@ def __init__( ] = defaultdict(lambda: np.inf), incl_robot_future: bool = False, incl_map: bool = False, - map_params: Optional[Dict[str, int]] = None, + map_params: Optional[Dict[str, Any]] = None, only_types: Optional[List[AgentType]] = None, only_predict: Optional[List[AgentType]] = None, no_types: Optional[List[AgentType]] = None, @@ -69,6 +69,8 @@ def __init__( standardize_derivatives: bool = False, augmentations: Optional[List[Augmentation]] = None, max_agent_num: Optional[int] = None, + max_neighbor_num: Optional[int] = None, + ego_only: Optional[bool] = False, data_dirs: Dict[str, str] = { # "nusc_trainval": "~/datasets/nuScenes", # "nusc_test": "~/datasets/nuScenes", @@ -104,7 +106,7 @@ def __init__( agent_interaction_distances: (Dict[Tuple[AgentType, AgentType], float]): A dictionary mapping agent-agent interaction distances in meters (determines which agents are included as neighbors to the predicted agent). Defaults to infinity for all types. incl_robot_future (bool, optional): Include the ego agent's future trajectory in batches (accordingly, never predict the ego's future). Defaults to False. incl_map (bool, optional): Include a local cropping of the rasterized map (if the dataset provides a map) per agent. Defaults to False. - map_params (Optional[Dict[str, int]], optional): Local map cropping parameters, must be specified if incl_map is True. Must contain keys {"px_per_m", "map_size_px"} and can optionally contain {"offset_frac_xy"}. Defaults to None. + map_params (Optional[Dict[str, Any]], optional): Local map cropping parameters, must be specified if incl_map is True. Must contain keys {"px_per_m", "map_size_px"} and can optionally contain {"offset_frac_xy"}. Defaults to None. only_types (Optional[List[AgentType]], optional): Filter out all agents EXCEPT for those of the specified types. Defaults to None. only_predict (Optional[List[AgentType]], optional): Only predict the specified types of agents. Importantly, this keeps other agent types in the scene, e.g., as neighbors of the agent to be predicted. Defaults to None. no_types (Optional[List[AgentType]], optional): Filter out all agents with the specified types. Defaults to None. @@ -112,6 +114,8 @@ def __init__( standardize_derivatives (bool, optional): Make agent velocities and accelerations relative to the agent being predicted. Defaults to False. augmentations (Optional[List[Augmentation]], optional): Perform the specified augmentations to the batch or dataset. Defaults to None. max_agent_num (int, optional): The maximum number of agents to include in a batch for scene-centric batching. + max_neighbor_num (int, optional): The maximum number of neighbors to include in a batch for agent-centric batching. + ego_only (bool, optional): If True, only return batches where the ego-agent is the one being predicted. data_dirs (Optional[Dict[str, str]], optional): Dictionary mapping dataset names to their directories on disk. Defaults to { "eupeds_eth": "~/datasets/eth_ucy_peds", "eupeds_hotel": "~/datasets/eth_ucy_peds", "eupeds_univ": "~/datasets/eth_ucy_peds", "eupeds_zara1": "~/datasets/eth_ucy_peds", "eupeds_zara2": "~/datasets/eth_ucy_peds", "nusc_mini": "~/datasets/nuScenes", "lyft_sample": "~/datasets/lyft/scenes/sample.zarr", }. cache_type (str, optional): What type of cache to use to store preprocessed, cached data on disk. Defaults to "dataframe". cache_location (str, optional): Where to store and load preprocessed, cached data. Defaults to "~/.unified_data_cache". @@ -157,6 +161,8 @@ def __init__( self.extras = extras self.verbose = verbose self.max_agent_num = max_agent_num + self.max_neighbor_num = max_neighbor_num + self.ego_only = ego_only # Ensuring scene description queries are all lowercase if scene_description_contains is not None: @@ -178,6 +184,7 @@ def __init__( if any(env.name in dataset_tuple for dataset_tuple in matching_datasets): all_data_cached: bool = False all_maps_cached: bool = not env.has_maps or not self.incl_map + if self.env_cache.env_is_cached(env.name) and not self.rebuild_cache: scenes_list: List[Scene] = self.get_desired_scenes_from_env( matching_datasets, scene_description_contains, env @@ -224,7 +231,7 @@ def __init__( env.cache_maps( self.cache_path, self.cache_class, - resolution=self.map_params["px_per_m"], + self.map_params, ) scenes_list: List[SceneMetadata] = self.get_desired_scenes_from_env( @@ -294,7 +301,10 @@ def get_data_index( history_sec=self.history_sec, future_sec=self.future_sec, desired_dt=self.desired_dt, + ego_only=self.ego_only, ) + else: + raise ValueError(f"{self.centric}-centric data batches are not supported.") # data_index is either: # [(scene_path, total_index_len, valid_scene_ts)] for scene-centric data, or @@ -381,6 +391,7 @@ def _get_data_index_agent( future_sec: Tuple[Optional[float], Optional[float]], desired_dt: Optional[float], ret_scene_info: bool = False, + ego_only: bool = False, ) -> Tuple[Optional[Scene], Path, int, List[Tuple[str, np.ndarray]]]: index_elems_len: int = 0 index_elems: List[Tuple[str, np.ndarray]] = list() @@ -395,10 +406,15 @@ def _get_data_index_agent( ) for agent_info in filtered_agents: - # Don't want to predict the ego if we're going to be giving the model its future! + # Don't want to predict the ego if we're going to be + # giving the model its future! if incl_robot_future and agent_info.name == "ego": continue + if ego_only and agent_info.name != "ego": + # We only want to return the ego. + continue + valid_ts: Tuple[int, int] = filtering.get_valid_ts( agent_info, scene.dt, history_sec, future_sec ) @@ -443,6 +459,8 @@ def get_collate_fn( pad_format=pad_format, batch_augments=batch_augments, ) + else: + raise ValueError(f"{self.centric}-centric data batches are not supported.") return collate_fn @@ -698,6 +716,7 @@ def __getitem__(self, idx: int) -> Union[SceneBatchElement, AgentBatchElement]: self.map_params, self.standardize_data, self.standardize_derivatives, + self.max_neighbor_num, ) for key, extra_fn in self.extras.items(): diff --git a/src/trajdata/dataset_specific/eth_ucy_peds/eupeds_dataset.py b/src/trajdata/dataset_specific/eth_ucy_peds/eupeds_dataset.py index 17a25b0..2bca501 100644 --- a/src/trajdata/dataset_specific/eth_ucy_peds/eupeds_dataset.py +++ b/src/trajdata/dataset_specific/eth_ucy_peds/eupeds_dataset.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Dict, Final, List, Optional, Tuple, Type +from typing import Any, Dict, Final, List, Optional, Tuple, Type import numpy as np import pandas as pd @@ -325,7 +325,10 @@ def cache_map( pass def cache_maps( - self, cache_path: Path, map_cache_class: Type[SceneCache], resolution: float + self, + cache_path: Path, + map_cache_class: Type[SceneCache], + map_params: Dict[str, Any], ) -> None: """ No maps in this dataset! diff --git a/src/trajdata/dataset_specific/lyft/lyft_dataset.py b/src/trajdata/dataset_specific/lyft/lyft_dataset.py index 84b7a41..aeb8634 100644 --- a/src/trajdata/dataset_specific/lyft/lyft_dataset.py +++ b/src/trajdata/dataset_specific/lyft/lyft_dataset.py @@ -1,3 +1,4 @@ +import warnings from collections import defaultdict from functools import partial from math import ceil @@ -403,8 +404,12 @@ def extract_vectorized(self, mapAPI: MapAPI) -> VectorizedMap: return vec_map def cache_maps( - self, cache_path: Path, map_cache_class: Type[SceneCache], resolution: float + self, + cache_path: Path, + map_cache_class: Type[SceneCache], + map_params: Dict[str, Any], ) -> None: + resolution: float = map_params["px_per_m"] map_name: str = "palo_alto" print(f"Caching {map_name} Map at {resolution:.2f} px/m...", flush=True) @@ -416,8 +421,65 @@ def cache_maps( world_to_ecef = np.array(dataset_meta["world_to_ecef"], dtype=np.float64) mapAPI = MapAPI(semantic_map_filepath, world_to_ecef) - vectorized_map: VectorizedMap = self.extract_vectorized(mapAPI) - map_data, map_from_world = map_utils.rasterize_map(vectorized_map, resolution) + if map_params.get("original_format", False): + warnings.warn( + "Using a dataset's original map format is deprecated, and will be removed in the next version of trajdata!", + FutureWarning, + ) + + mins = np.stack( + [ + map_elem["bounds"][:, 0].min(axis=0) + for map_elem in mapAPI.bounds_info.values() + ] + ).min(axis=0) + maxs = np.stack( + [ + map_elem["bounds"][:, 1].max(axis=0) + for map_elem in mapAPI.bounds_info.values() + ] + ).max(axis=0) + + world_right, world_top = maxs + world_left, world_bottom = mins + + world_center: np.ndarray = np.array( + [(world_left + world_right) / 2, (world_bottom + world_top) / 2] + ) + raster_size_px: np.ndarray = np.array( + [ + ceil((world_right - world_left) * resolution), + ceil((world_top - world_bottom) * resolution), + ] + ) + + render_context = RenderContext( + raster_size_px=raster_size_px, + pixel_size_m=np.array([1 / resolution, 1 / resolution]), + center_in_raster_ratio=np.array([0.5, 0.5]), + set_origin_to_bottom=False, + ) + + map_from_world: np.ndarray = render_context.raster_from_world( + world_center, 0.0 + ) + + rasterizer = MapSemanticRasterizer( + render_context, semantic_map_filepath, world_to_ecef + ) + + print("Rendering palo_alto Map...", flush=True, end=" ") + map_data: np.ndarray = rasterizer.render_semantic_map( + world_center, map_from_world + ) + print("done!", flush=True) + + vectorized_map = VectorizedMap() + else: + vectorized_map: VectorizedMap = self.extract_vectorized(mapAPI) + map_data, map_from_world = map_utils.rasterize_map( + vectorized_map, resolution + ) rasterized_map_info: RasterizedMapMetadata = RasterizedMapMetadata( name=map_name, diff --git a/src/trajdata/dataset_specific/nusc/nusc_dataset.py b/src/trajdata/dataset_specific/nusc/nusc_dataset.py index ab76a21..401658a 100644 --- a/src/trajdata/dataset_specific/nusc/nusc_dataset.py +++ b/src/trajdata/dataset_specific/nusc/nusc_dataset.py @@ -1,7 +1,7 @@ import warnings from copy import deepcopy from pathlib import Path -from typing import Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union import numpy as np import pandas as pd @@ -496,37 +496,89 @@ def cache_map( map_name: str, cache_path: Path, map_cache_class: Type[SceneCache], - resolution: float, + map_params: Dict[str, Any], ) -> None: - """ - resolution is in pixels per meter. - """ + resolution: float = map_params["px_per_m"] + nusc_map: NuScenesMap = NuScenesMap( dataroot=self.metadata.data_dir, map_name=map_name ) - vectorized_map: VectorizedMap = self.extract_vectorized(nusc_map) + if map_params.get("original_format", False): + warnings.warn( + "Using a dataset's original map format is deprecated, and will be removed in the next version of trajdata!", + FutureWarning, + ) - pbar_kwargs = {"position": 2, "leave": False} - map_data, map_from_world = map_utils.rasterize_map( - vectorized_map, resolution, **pbar_kwargs - ) + width_m, height_m = nusc_map.canvas_edge + height_px, width_px = round(height_m * resolution), round( + width_m * resolution + ) - rasterized_map_info: RasterizedMapMetadata = RasterizedMapMetadata( - name=map_name, - shape=map_data.shape, - layers=["drivable_area", "lane_divider", "ped_area"], - layer_rgb_groups=([0], [1], [2]), - resolution=resolution, - map_from_world=map_from_world, - ) - rasterized_map_obj: RasterizedMap = RasterizedMap(rasterized_map_info, map_data) - map_cache_class.cache_map( - cache_path, vectorized_map, rasterized_map_obj, self.name - ) + def layer_fn(layer_name: str) -> np.ndarray: + # Getting rid of the channels dim by accessing index [0] + return nusc_map.get_map_mask( + patch_box=None, + patch_angle=0, + layer_names=[layer_name], + canvas_size=(height_px, width_px), + )[0].astype(np.bool) + + map_from_world: np.ndarray = np.array( + [[resolution, 0.0, 0.0], [0.0, resolution, 0.0], [0.0, 0.0, 1.0]] + ) + + layer_names: List[str] = [ + "lane", + "road_segment", + "drivable_area", + "road_divider", + "lane_divider", + "ped_crossing", + "walkway", + ] + map_info: RasterizedMapMetadata = RasterizedMapMetadata( + name=map_name, + shape=(len(layer_names), height_px, width_px), + layers=layer_names, + layer_rgb_groups=([0, 1, 2], [3, 4], [5, 6]), + resolution=resolution, + map_from_world=map_from_world, + ) + + map_cache_class.cache_map_layers( + cache_path, VectorizedMap(), map_info, layer_fn, self.name + ) + else: + vectorized_map: VectorizedMap = self.extract_vectorized(nusc_map) + + pbar_kwargs = {"position": 2, "leave": False} + map_data, map_from_world = map_utils.rasterize_map( + vectorized_map, resolution, **pbar_kwargs + ) + + rasterized_map_info: RasterizedMapMetadata = RasterizedMapMetadata( + name=map_name, + shape=map_data.shape, + layers=["drivable_area", "lane_divider", "ped_area"], + layer_rgb_groups=([0], [1], [2]), + resolution=resolution, + map_from_world=map_from_world, + ) + + rasterized_map_obj: RasterizedMap = RasterizedMap( + rasterized_map_info, map_data + ) + + map_cache_class.cache_map( + cache_path, vectorized_map, rasterized_map_obj, self.name + ) def cache_maps( - self, cache_path: Path, map_cache_class: Type[SceneCache], resolution: float + self, + cache_path: Path, + map_cache_class: Type[SceneCache], + map_params: Dict[str, Any], ) -> None: """ Stores rasterized maps to disk for later retrieval. @@ -551,7 +603,7 @@ def cache_maps( """ for map_name in tqdm( locations, - desc=f"Caching {self.name} Maps at {resolution:.2f} px/m", + desc=f"Caching {self.name} Maps at {map_params['px_per_m']:.2f} px/m", position=0, ): - self.cache_map(map_name, cache_path, map_cache_class, resolution) + self.cache_map(map_name, cache_path, map_cache_class, map_params) diff --git a/src/trajdata/dataset_specific/raw_dataset.py b/src/trajdata/dataset_specific/raw_dataset.py index c09843c..b26f392 100644 --- a/src/trajdata/dataset_specific/raw_dataset.py +++ b/src/trajdata/dataset_specific/raw_dataset.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import List, NamedTuple, Optional, Set, Tuple, Type, Union +from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, Type, Union from trajdata.caching import EnvCache, SceneCache from trajdata.data_structures import ( @@ -84,7 +84,10 @@ def get_agent_info( raise NotImplementedError() def cache_maps( - self, cache_path: Path, map_cache_class: Type[SceneCache], resolution: float + self, + cache_path: Path, + map_cache_class: Type[SceneCache], + map_params: Dict[str, Any], ) -> None: """ resolution is in pixels per meter. diff --git a/src/trajdata/filtering/filters.py b/src/trajdata/filtering/filters.py index 4c42be1..de198da 100644 --- a/src/trajdata/filtering/filters.py +++ b/src/trajdata/filtering/filters.py @@ -8,13 +8,13 @@ def agent_types( agents: List[AgentMetadata], no_types: Set[AgentType], only_types: Set[AgentType] ) -> List[AgentMetadata]: agents_list: List[AgentMetadata] = agents - + if no_types is not None: agents_list = [agent for agent in agents_list if agent.type not in no_types] - + if only_types is not None: agents_list = [agent for agent in agents_list if agent.type in only_types] - + return agents_list diff --git a/src/trajdata/simulation/sim_scene.py b/src/trajdata/simulation/sim_scene.py index 998495b..56da477 100644 --- a/src/trajdata/simulation/sim_scene.py +++ b/src/trajdata/simulation/sim_scene.py @@ -137,6 +137,7 @@ def get_obs( map_params=self.dataset.map_params, standardize_data=self.dataset.standardize_data, standardize_derivatives=self.dataset.standardize_derivatives, + max_neighbor_num=self.dataset.max_neighbor_num, ) )