New vector map API! Lots of changes related to that, starting to support nuPlan, code reorganization, and general bugfixes.
BorisIvanovic committed Nov 18, 2022
1 parent fe85a94 commit 4351479
Showing 66 changed files with 4,870 additions and 1,234 deletions.
17 changes: 15 additions & 2 deletions README.md
@@ -86,6 +86,7 @@ Currently, the dataloader supports interfacing with the following datasets:
| nuScenes Train/TrainVal/Val | `nusc_trainval` | `train`, `train_val`, `val` | `boston`, `singapore` | nuScenes prediction challenge training/validation/test splits (500/200/150 scenes) | 0.5s (2Hz) | :white_check_mark: |
| nuScenes Test | `nusc_test` | `test` | `boston`, `singapore` | nuScenes' test split, no annotations (150 scenes) | 0.5s (2Hz) | :white_check_mark: |
| nuScenes Mini | `nusc_mini` | `mini_train`, `mini_val` | `boston`, `singapore` | nuScenes mini training/validation splits (8/2 scenes) | 0.5s (2Hz) | :white_check_mark: |
| nuPlan Mini | `nuplan_mini` | `mini_train`, `mini_val`, `mini_test` | `boston`, `singapore`, `pittsburgh`, `las_vegas` | nuPlan mini training/validation/test splits (942/197/224 scenes) | 0.05s (20Hz) | :white_check_mark: |
| Lyft Level 5 Train | `lyft_train` | `train` | `palo_alto` | Lyft Level 5 training data - part 1/2 (8.4 GB) | 0.1s (10Hz) | :white_check_mark: |
| Lyft Level 5 Train Full | `lyft_train_full` | `train` | `palo_alto` | Lyft Level 5 training data - part 2/2 (70 GB) | 0.1s (10Hz) | :white_check_mark: |
| Lyft Level 5 Validation | `lyft_val` | `val` | `palo_alto` | Lyft Level 5 validation data (8.2 GB) | 0.1s (10Hz) | :white_check_mark: |
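
For instance, a minimal sketch of loading the newly supported nuPlan Mini split (the `data_dirs` path is a placeholder; point it at your own nuPlan download):

```py
from trajdata import UnifiedDataset

# Load the new nuPlan Mini validation split.
# The data_dirs path below is a placeholder for your local nuPlan download.
dataset = UnifiedDataset(
    desired_data=["nuplan_mini-mini_val"],
    data_dirs={"nuplan_mini": "~/datasets/nuPlan"},
)
```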
@@ -127,6 +128,19 @@ dataset = UnifiedDataset(

**Note**: Be careful about loading multiple datasets without an associated `desired_dt` argument; many datasets do not share the same underlying data annotation frequency. To address this, we've implemented timestep interpolation to a common frequency, which ensures that all batched data shares the same `dt`. Interpolation can only be performed to integer multiples of the original data annotation frequency. For example, nuScenes' `dt=0.5` and the ETH BIWI dataset's `dt=0.4` can both be interpolated to a common `desired_dt=0.1`.
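
As a concrete sketch of mixing annotation frequencies (using the two dataset tags and paths that also appear in this commit's example files):

```py
from trajdata import UnifiedDataset

# nuScenes Mini is annotated at dt=0.5 and the Lyft sample at dt=0.1;
# interpolating both to desired_dt=0.1 yields batches with a uniform dt.
dataset = UnifiedDataset(
    desired_data=["nusc_mini-mini_val", "lyft_sample-mini_val"],
    desired_dt=0.1,
    data_dirs={  # Remember to change this to match your filesystem!
        "nusc_mini": "~/datasets/nuScenes",
        "lyft_sample": "~/datasets/lyft/scenes/sample.zarr",
    },
)
```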

## Map API
`trajdata` also provides an API to access the raw vector map information from datasets that provide it.

```py
from pathlib import Path
from trajdata import MapAPI, VectorMap

cache_path = Path("~/.unified_data_cache").expanduser()
map_api = MapAPI(cache_path)

vector_map: VectorMap = map_api.get_map("nusc_mini:boston-seaport")
```
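
Continuing the snippet above, the returned `VectorMap` can then be queried directly. A short sketch mirroring `examples/lane_query_example.py` from this commit (the query coordinates are made up for illustration):

```py
import numpy as np

from trajdata.maps.vec_map_elements import RoadLane

# Find the lane closest to a world-frame point. Following
# examples/lane_query_example.py, the query point is 4D: [x, y, 0.0, 0.0].
# These coordinates are hypothetical.
query_point = np.array([870.0, 1530.0, 0.0, 0.0])
closest_lane: RoadLane = vector_map.get_closest_lane(query_point)

# Project the (1, 4) query point onto the lane's centerline.
on_lane_point = closest_lane.center.project_onto(query_point[np.newaxis, :])
```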

## Simulation Interface
One additional feature of `trajdata` is that it can be used to initialize simulations from real data and track the resulting agent motion, metrics, etc.
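
Since this diff only shows a one-line rename in the simulation loop (`scene_info` to `scene`), here is a hedged sketch of the surrounding loop for context. It assumes the `sim_scene = SimulationScene(...)` constructed just below, and the loop body (folded out of this diff) is an assumption, not the repository's exact code:

```py
from typing import Dict

import numpy as np

from trajdata import AgentBatch

obs: AgentBatch = sim_scene.reset()
for t in range(1, sim_scene.scene.length_timesteps):
    new_xyh_dict: Dict[str, np.ndarray] = dict()

    # Assumed loop body: hold every agent at its current pose, i.e. the
    # next [x, y, heading] equals the current one. The state-vector layout
    # used here is an assumption.
    for idx, agent_name in enumerate(obs.agent_name):
        curr_state = obs.curr_agent_state[idx].numpy()
        new_xyh_dict[agent_name] = np.array(
            [curr_state[0], curr_state[1], curr_state[-1]]
        )

    obs = sim_scene.step(new_xyh_dict)
```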

@@ -159,7 +173,7 @@ sim_scene = SimulationScene(
)

obs: AgentBatch = sim_scene.reset()
for t in range(1, sim_scene.scene_info.length_timesteps):
for t in range(1, sim_scene.scene.length_timesteps):
new_xyh_dict: Dict[str, np.ndarray] = dict()

    # Everything inside the for loop just sets
@@ -181,4 +195,3 @@ for t in range(1, sim_scene.scene_info.length_timesteps):
## TODO
- Create a method like `finalize()` which writes all the batch information to a TFRecord/WebDataset/some other format that is (very) fast to read from, for faster multi-epoch training.
- Add more examples to the README.

8 changes: 6 additions & 2 deletions examples/batch_example.py
@@ -20,8 +20,12 @@ def main():
only_predict=[AgentType.VEHICLE],
agent_interaction_distances=defaultdict(lambda: 30.0),
incl_robot_future=False,
incl_map=True,
map_params={"px_per_m": 2, "map_size_px": 224, "offset_frac_xy": (-0.5, 0.0)},
incl_raster_map=True,
raster_map_params={
"px_per_m": 2,
"map_size_px": 224,
"offset_frac_xy": (-0.5, 0.0),
},
augmentations=[noise_hists],
num_workers=0,
verbose=True,
97 changes: 97 additions & 0 deletions examples/cache_and_filter_example.py
@@ -0,0 +1,97 @@
import os
from collections import defaultdict

from torch.utils.data import DataLoader
from tqdm import tqdm

from trajdata import AgentBatch, AgentType, UnifiedDataset
from trajdata.augmentation import NoiseHistories
from trajdata.data_structures.batch_element import AgentBatchElement
from trajdata.visualization.vis import plot_agent_batch


def main():
noise_hists = NoiseHistories()

create_dataset = lambda: UnifiedDataset(
desired_data=["nusc_mini-mini_val"],
centric="agent",
desired_dt=0.5,
history_sec=(2.0, 2.0),
future_sec=(4.0, 4.0),
only_predict=[AgentType.VEHICLE],
agent_interaction_distances=defaultdict(lambda: 30.0),
incl_robot_future=False,
incl_raster_map=False,
# map_params={"px_per_m": 2, "map_size_px": 224, "offset_frac_xy": (-0.5, 0.0)},
augmentations=[noise_hists],
num_workers=0,
verbose=True,
data_dirs={ # Remember to change this to match your filesystem!
"nusc_mini": "~/datasets/nuScenes",
},
)

dataset = create_dataset()

print(f"# Data Samples: {len(dataset):,}")

print(
"To demonstrate how to use caching we will first save the "
"entire dataset (all BatchElements) to a cache file and then load from "
"the cache file. Note that for large datasets and/or high time resolution "
"this will create a large file and will use a lot of RAM."
)
cache_path = "./temp_cache_file.dill"

print(
"We also use a custom filter function that only keeps elements with more "
"than 5 neighbors"
)

def my_filter(el: AgentBatchElement) -> bool:
return el.num_neighbors > 5

print(
f"In the first run we will iterate through the entire dataset and save all "
f"BatchElements to the cache file {cache_path}"
)
print("This may take several minutes.")
dataset.load_or_create_cache(
cache_path=cache_path, num_workers=0, filter_fn=my_filter
)
assert os.path.isfile(cache_path)

print(
"To demonstrate a consecuitve run we create a new dataset and load elements "
"from the cache file."
)
del dataset
dataset = create_dataset()

dataset.load_or_create_cache(
cache_path=cache_path, num_workers=0, filter_fn=my_filter
)

    # Remove the temp cache file; we don't need it anymore.
os.remove(cache_path)

print(
"We can iterate through the dataset the same way as normally, but this "
"time it will be much faster because all BatchElements are in memory."
)
dataloader = DataLoader(
dataset,
batch_size=4,
shuffle=True,
collate_fn=dataset.get_collate_fn(),
num_workers=0,
)

batch: AgentBatch
for batch in tqdm(dataloader):
plot_agent_batch(batch, batch_idx=0)


if __name__ == "__main__":
main()
10 changes: 6 additions & 4 deletions examples/custom_batch_data.py
@@ -11,9 +11,7 @@
from tqdm import tqdm

from trajdata import AgentBatch, AgentType, UnifiedDataset
from trajdata.augmentation import NoiseHistories
from trajdata.data_structures.batch_element import AgentBatchElement, SceneBatchElement
from trajdata.visualization.vis import plot_agent_batch


def custom_random_data(
@@ -74,8 +72,12 @@ def main():
only_types=[AgentType.VEHICLE],
agent_interaction_distances=defaultdict(lambda: 30.0),
incl_robot_future=False,
incl_map=True,
map_params={"px_per_m": 2, "map_size_px": 224, "offset_frac_xy": (-0.5, 0.0)},
incl_raster_map=True,
raster_map_params={
"px_per_m": 2,
"map_size_px": 224,
"offset_frac_xy": (-0.5, 0.0),
},
num_workers=0,
verbose=True,
data_dirs={ # Remember to change this to match your filesystem!
124 changes: 124 additions & 0 deletions examples/lane_query_example.py
@@ -0,0 +1,124 @@
"""
This is an example of how to extend a batch with lane information
"""

from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm

from trajdata import AgentBatch, AgentType, UnifiedDataset
from trajdata.data_structures.batch_element import AgentBatchElement, SceneBatchElement
from trajdata.maps import VectorMap
from trajdata.maps.vec_map_elements import RoadLane
from trajdata.utils.arr_utils import batch_nd_transform_points_np
from trajdata.visualization.vis import plot_agent_batch


def get_closest_lane_point(element: AgentBatchElement) -> np.ndarray:
"""Closest lane for predicted agent."""

# Transform from agent coordinate frame to world coordinate frame.
vector_map: VectorMap = element.vec_map
world_from_agent_tf = np.linalg.inv(element.agent_from_world_tf)
agent_future_xy_world = batch_nd_transform_points_np(
element.agent_future_np[:, :2], world_from_agent_tf
)

    # Use the cached KDTree to find the closest lane point.
lane_points_world = []
for xy_world in agent_future_xy_world:
point_4d = np.array([[xy_world[0], xy_world[1], 0.0, 0.0]])
closest_lane: RoadLane = vector_map.get_closest_lane(point_4d.squeeze(axis=0))
lane_points_world.append(closest_lane.center.project_onto(point_4d))

lane_points_world = np.concatenate(lane_points_world, axis=0)

# Transform lane points to agent coordinate frame
lane_points = batch_nd_transform_points_np(
lane_points_world[:, :2], element.agent_from_world_tf
)

return lane_points


def main():
dataset = UnifiedDataset(
desired_data=[
"nusc_mini-mini_train",
"lyft_sample-mini_val",
],
centric="agent",
desired_dt=0.1,
history_sec=(3.2, 3.2),
future_sec=(4.8, 4.8),
only_types=[AgentType.VEHICLE],
agent_interaction_distances=defaultdict(lambda: 30.0),
incl_robot_future=False,
incl_raster_map=True,
raster_map_params={
"px_per_m": 2,
"map_size_px": 224,
"offset_frac_xy": (-0.5, 0.0),
},
incl_vector_map=True,
num_workers=0,
verbose=True,
data_dirs={ # Remember to change this to match your filesystem!
"nusc_mini": "~/datasets/nuScenes",
"lyft_sample": "~/datasets/lyft/scenes/sample.zarr",
},
# A dictionary that contains functions that generate our custom data.
# Can be any function and has access to the batch element.
extras={
"closest_lane_point": get_closest_lane_point,
},
)

print(f"# Data Samples: {len(dataset):,}")

dataloader = DataLoader(
dataset,
batch_size=4,
shuffle=False,
collate_fn=dataset.get_collate_fn(),
num_workers=0,
)

# Visualize selected examples
num_plots = 3
batch_idxs = [10876, 10227, 1284]
# batch_idxs = random.sample(range(len(dataset)), num_plots)
batch: AgentBatch = dataset.get_collate_fn(pad_format="right")(
[dataset[i] for i in batch_idxs]
)
assert "closest_lane_point" in batch.extras

for batch_i in range(num_plots):
ax = plot_agent_batch(
batch, batch_idx=batch_i, legend=False, show=False, close=False
)
lane_points = batch.extras["closest_lane_point"][batch_i]
ax.plot(
lane_points[:, 0],
lane_points[:, 1],
"o-",
markersize=3,
label="Lane points",
)
ax.legend(loc="best", frameon=True)
plt.show()
plt.close("all")

# Scan through dataset
batch: AgentBatch
for idx, batch in enumerate(tqdm(dataloader)):
assert "closest_lane_point" in batch.extras
if idx > 50:
break


if __name__ == "__main__":
main()
