Interactive visualizations with Bokeh (+ Animations), the SceneTimeBatcher, and bugfixes.
BorisIvanovic committed Apr 21, 2023
1 parent 34048dc commit 5a0567b
Showing 16 changed files with 1,775 additions and 15 deletions.
38 changes: 38 additions & 0 deletions DATASETS.md
@@ -16,6 +16,44 @@ It should look like this after downloading:

**Note**: At a minimum, only the annotations need to be downloaded (not the raw radar/camera/lidar/etc data).

## nuPlan
Nothing special needs to be done for the nuPlan dataset; simply download v1.1 as per [the instructions in the devkit documentation](https://nuplan-devkit.readthedocs.io/en/latest/dataset_setup.html).

It should look like this after downloading:
```
/path/to/nuPlan/
└── dataset
├── maps
│ ├── nuplan-maps-v1.0.json
│ ├── sg-one-north
│ │ └── 9.17.1964
│ │ └── map.gpkg
│ ├── us-ma-boston
│ │ └── 9.12.1817
│ │ └── map.gpkg
│ ├── us-nv-las-vegas-strip
│ │ └── 9.15.1915
│ │ ├── drivable_area.npy.npz
│ │ ├── Intensity.npy.npz
│ │ └── map.gpkg
│ └── us-pa-pittsburgh-hazelwood
│ └── 9.17.1937
│ └── map.gpkg
└── nuplan-v1.1
├── mini
│ ├── 2021.05.12.22.00.38_veh-35_01008_01518.db
│ ├── 2021.06.09.17.23.18_veh-38_00773_01140.db
│ ├── ...
│ └── 2021.10.11.08.31.07_veh-50_01750_01948.db
└── trainval
├── 2021.05.12.22.00.38_veh-35_01008_01518.db
├── 2021.06.09.17.23.18_veh-38_00773_01140.db
├── ...
└── 2021.10.11.08.31.07_veh-50_01750_01948.db
```

**Note**: Not all dataset splits need to be downloaded. For example, you can download only the nuPlan Mini Split if you only need a small sample dataset.
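
Once downloaded, point trajdata at the dataset root via the `data_dirs` argument of `UnifiedDataset`, as the visualization example in this commit does. A minimal sketch (the `nuplan_mini` key is the one used in that example; the `desired_data` value and path are illustrative and should match your setup):

```python
from trajdata import UnifiedDataset

# Point trajdata at the downloaded nuPlan data; adjust the path to your filesystem.
dataset = UnifiedDataset(
    desired_data=["nuplan_mini"],
    data_dirs={
        "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1",
    },
)
print(f"# Data Samples: {len(dataset):,}")
```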

## Lyft Level 5
Nothing special needs to be done for the Lyft Level 5 dataset; simply install it as per [the instructions on the dataset website](https://woven-planet.github.io/l5kit/dataset.html).

53 changes: 53 additions & 0 deletions examples/scenetimebatcher_example.py
@@ -0,0 +1,53 @@
from collections import defaultdict

from torch.utils.data import DataLoader
from tqdm import tqdm

from trajdata import AgentBatch, AgentType, UnifiedDataset
from trajdata.utils.batch_utils import SceneTimeBatcher
from trajdata.visualization.vis import plot_agent_batch_all


def main():
    """
    Here, we use SceneTimeBatcher to loop through an
    agent-centric dataset with batches grouped by scene and timestep.
    """
    dataset = UnifiedDataset(
        desired_data=["nusc_mini-mini_train"],
        centric="agent",
        desired_dt=0.1,
        history_sec=(3.2, 3.2),
        future_sec=(4.8, 4.8),
        only_predict=[AgentType.VEHICLE],
        agent_interaction_distances=defaultdict(lambda: 30.0),
        incl_robot_future=False,
        incl_raster_map=True,
        raster_map_params={
            "px_per_m": 2,
            "map_size_px": 224,
            "offset_frac_xy": (-0.5, 0.0),
        },
        num_workers=0,
        verbose=True,
        data_dirs={  # Remember to change this to match your filesystem!
            "nusc_mini": "~/datasets/nuScenes",
        },
    )

    print(f"# Data Samples: {len(dataset):,}")

    dataloader = DataLoader(
        dataset,
        batch_sampler=SceneTimeBatcher(dataset),
        collate_fn=dataset.get_collate_fn(),
        num_workers=4,
    )

    batch: AgentBatch
    for batch in tqdm(dataloader):
        plot_agent_batch_all(batch)


if __name__ == "__main__":
    main()
68 changes: 68 additions & 0 deletions examples/visualization_example.py
@@ -0,0 +1,68 @@
from collections import defaultdict

from torch.utils.data import DataLoader
from tqdm import tqdm

from trajdata import AgentBatch, AgentType, UnifiedDataset
from trajdata.visualization.interactive_animation import (
    InteractiveAnimation,
    animate_agent_batch_interactive,
)
from trajdata.visualization.interactive_vis import plot_agent_batch_interactive
from trajdata.visualization.vis import plot_agent_batch


def main():
    dataset = UnifiedDataset(
        desired_data=["nusc_mini"],
        centric="agent",
        desired_dt=0.1,
        # history_sec=(3.2, 3.2),
        # future_sec=(4.8, 4.8),
        only_predict=[AgentType.VEHICLE],
        state_format="x,y,z,xd,yd,h",
        obs_format="x,y,z,xd,yd,s,c",
        # agent_interaction_distances=defaultdict(lambda: 30.0),
        incl_robot_future=False,
        incl_raster_map=True,
        raster_map_params={
            "px_per_m": 2,
            "map_size_px": 224,
            "offset_frac_xy": (-0.5, 0.0),
        },
        num_workers=4,
        verbose=True,
        data_dirs={  # Remember to change this to match your filesystem!
            "nusc_mini": "~/datasets/nuScenes",
            "lyft_sample": "~/datasets/lyft/scenes/sample.zarr",
            "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1",
        },
    )

    print(f"# Data Samples: {len(dataset):,}")

    dataloader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        collate_fn=dataset.get_collate_fn(),
        num_workers=0,
    )

    batch: AgentBatch
    for batch in tqdm(dataloader):
        plot_agent_batch_interactive(batch, batch_idx=0, cache_path=dataset.cache_path)
        plot_agent_batch(batch, batch_idx=0)

        animation = InteractiveAnimation(
            animate_agent_batch_interactive,
            batch=batch,
            batch_idx=0,
            cache_path=dataset.cache_path,
        )
        animation.show()
        # break


if __name__ == "__main__":
    main()
1 change: 1 addition & 0 deletions requirements.txt
@@ -8,6 +8,7 @@ pyarrow
torch
zarr
kornia
+bokeh

# nuScenes devkit
nuscenes-devkit==1.1.9
3 changes: 2 additions & 1 deletion setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name = trajdata
-version = 1.2.1
+version = 1.3.0
author = Boris Ivanovic
author_email = [email protected]
description = A unified interface to many trajectory forecasting datasets.
@@ -30,6 +30,7 @@ install_requires =
    zarr>=2.11.0
    kornia>=0.6.4
    seaborn>=0.12
+    bokeh>=3.0.3

[options.packages.find]
where = src
2 changes: 1 addition & 1 deletion src/trajdata/data_structures/state.py
@@ -340,7 +340,7 @@ class StateTensor(State, Tensor):
    _FUNCS = {
        "cos": torch.cos,
        "sin": torch.sin,
-        "arctan": torch.arctan2,
+        "arctan": torch.atan2,
        "lon_component": lon_component,
        "lat_component": lat_component,
        "x_component": x_component,
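
For context, the one-line fix above replaces `torch.arctan2` with `torch.atan2`; the two compute the same thing, but `torch.arctan2` is a newer alias that is missing from older PyTorch releases, so `torch.atan2` is the safer name. A minimal sketch of what this `_FUNCS` entry computes, recovering heading angles from sine/cosine components (as in the `s,c` fields of `obs_format` in the visualization example; values here are illustrative):

```python
import torch

s = torch.tensor([0.0, 1.0, 0.7071])  # sin(h)
c = torch.tensor([1.0, 0.0, 0.7071])  # cos(h)
h = torch.atan2(s, c)  # heading angles: ~[0, pi/2, pi/4]
print(h)
```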
2 changes: 1 addition & 1 deletion src/trajdata/maps/map_kdtree.py
@@ -6,7 +6,7 @@
if TYPE_CHECKING:
    from trajdata.maps.vec_map import VectorMap

-from typing import Optional
+from typing import Optional, Tuple

import numpy as np
from scipy.spatial import KDTree
10 changes: 5 additions & 5 deletions src/trajdata/utils/arr_utils.py
@@ -164,23 +164,23 @@ def transform_coords_2d_np(
"""
Args:
coords (np.ndarray): [..., 2] coordinates
offset (Optional[np.ndarray], optional): [..., 2] offset to transalte. Defaults to None.
offset (Optional[np.ndarray], optional): [..., 2] offset to translate. Defaults to None.
angle (Optional[np.ndarray], optional): [...] angle to rotate by. Defaults to None.
rot_mat (Optional[np.ndarray], optional): [..., 2,2] rotation matrix to apply. Defaults to None.
If rot_mat is given, angle is ignored.
Returns:
np.ndarray: transformed coords
"""
if offset is not None:
coords = coords + offset

if rot_mat is None and angle is not None:
rot_mat = rotation_matrix(angle)

if rot_mat is not None:
coords = np.einsum("...ij,...j->...i", rot_mat, coords)

if offset is not None:
coords += offset

return coords


@@ -240,7 +240,7 @@ def transform_xyh_np(xyh: np.ndarray, tf_mat: np.ndarray) -> np.ndarray:
        tf_mat (np.ndarray): shape [...,3,3]
    """
    transformed_xy = transform_coords_np(xyh[..., :2], tf_mat)
-    transformed_angles = transform_angles_np(xyh[..., 3], tf_mat)
+    transformed_angles = transform_angles_np(xyh[..., 2], tf_mat)
    return np.concatenate([transformed_xy, transformed_angles[..., None]], axis=-1)


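
Two fixes land in `arr_utils.py`: `transform_coords_2d_np` now applies the translation after the rotation rather than before, and `transform_xyh_np` reads the heading from index 2 of an `[x, y, h]` array instead of the incorrect index 3. A minimal standalone sketch of why the transform order matters (the helper below is illustrative, not the library function):

```python
import numpy as np

def rotate_then_translate(coords: np.ndarray, angle: float, offset: np.ndarray) -> np.ndarray:
    """Rotate first, then translate -- the corrected order."""
    c, s = np.cos(angle), np.sin(angle)
    rot_mat = np.array([[c, -s], [s, c]])
    return coords @ rot_mat.T + offset

# Rotating (1, 0) by 90 degrees, then shifting by (1, 0), yields (1, 1);
# translating first and then rotating would yield (0, 2) instead.
pt = np.array([[1.0, 0.0]])
print(rotate_then_translate(pt, np.pi / 2, np.array([1.0, 0.0])))  # ~[[1., 1.]]
```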
88 changes: 87 additions & 1 deletion src/trajdata/utils/batch_utils.py
@@ -1,18 +1,104 @@
from collections import defaultdict
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

import numpy as np
from torch.utils.data import Sampler

from trajdata import UnifiedDataset
from trajdata.data_structures import (
    AgentBatch,
    AgentBatchElement,
    AgentDataIndex,
    AgentType,
    SceneBatchElement,
    SceneTimeAgent,
)
from trajdata.data_structures.collation import agent_collate_fn


class SceneTimeBatcher(Sampler):
    _agent_data_index: AgentDataIndex
    _agent_idx: int

    def __init__(
        self, agent_centric_dataset: UnifiedDataset, agent_idx_to_follow: int = 0
    ) -> None:
        """
        Returns a sampler (to be used in a torch.utils.data.DataLoader)
        which works with an agent-centric UnifiedDataset, yielding
        batches consisting of whole scenes (AgentBatchElements for all agents
        in a particular scene at a particular time).

        Args:
            agent_centric_dataset (UnifiedDataset)
            agent_idx_to_follow (int): index of the agent to return batches for. Defaults to 0,
                meaning we include all scene frames where the ego agent appears, which
                usually covers the entire dataset.
        """
        super().__init__(agent_centric_dataset)
        self._agent_data_index = agent_centric_dataset._data_index
        self._agent_idx = agent_idx_to_follow
        self._cumulative_lengths = np.concatenate(
            [
                [0],
                np.cumsum(
                    [
                        cumulative_scene_length[self._agent_idx + 1]
                        - cumulative_scene_length[self._agent_idx]
                        for cumulative_scene_length in self._agent_data_index._cumulative_scene_lengths
                    ]
                ),
            ]
        )

    def __len__(self):
        return self._cumulative_lengths[-1]

    def __iter__(self) -> Iterator[List[int]]:
        for idx in range(len(self)):
            # TODO(apoorvas) May not need to do this search, since we only support iterable-style access?
            scene_idx: int = (
                np.searchsorted(self._cumulative_lengths, idx, side="right").item() - 1
            )

            # Offset into the dataset index to reach the current scene.
            scene_offset = self._agent_data_index._cumulative_lengths[scene_idx].item()

            # How far along we are in the current scene.
            scene_elem_index = idx - self._cumulative_lengths[scene_idx].item()

            # Convert to the scene timestep for the tracked agent.
            scene_ts = (
                scene_elem_index
                + self._agent_data_index._agent_times[scene_idx][self._agent_idx, 0]
            )

            # Build a set of indices into the agent-centric dataset for all
            # agents that exist in this scene at this timestep.
            indices = []
            for agent_idx, agent_times in enumerate(
                self._agent_data_index._agent_times[scene_idx]
            ):
                if scene_ts > agent_times[1]:
                    # We are past the last timestep for this agent (times are inclusive).
                    continue
                agent_offset = scene_ts - agent_times[0]
                if agent_offset < 0:
                    # This agent hasn't entered the scene yet.
                    continue

                # Compute the index into the original dataset: first into the scene,
                # then into this agent's portion of the scene, and then the offset.
                index_to_add = (
                    scene_offset
                    + self._agent_data_index._cumulative_scene_lengths[scene_idx][
                        agent_idx
                    ]
                    + agent_offset
                )
                indices.append(index_to_add)

            yield indices


def convert_to_agent_batch(
    scene_batch_element: SceneBatchElement,
    only_types: Optional[List[AgentType]] = None,
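
To make the index arithmetic in `__iter__` concrete: `_cumulative_lengths` accumulates, per scene, how many timesteps the tracked agent appears in, so `np.searchsorted` maps a flat sample index back to a scene and a within-scene timestep offset. A toy sketch with hypothetical scene lengths:

```python
import numpy as np

# Hypothetical: the tracked agent appears for 5, 3, and 7 timesteps
# in scenes 0, 1, and 2 respectively.
cumulative_lengths = np.concatenate([[0], np.cumsum([5, 3, 7])])  # [0, 5, 8, 15]

idx = 6  # flat sample index
scene_idx = np.searchsorted(cumulative_lengths, idx, side="right").item() - 1
scene_elem_index = idx - cumulative_lengths[scene_idx].item()
print(scene_idx, scene_elem_index)  # scene 1, timestep offset 1 within it
```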
3 changes: 0 additions & 3 deletions src/trajdata/utils/map_utils.py
@@ -2,8 +2,6 @@

from typing import TYPE_CHECKING

-from tqdm import tqdm
-
if TYPE_CHECKING:
    from trajdata.maps import map_kdtree, vec_map

@@ -14,7 +12,6 @@
import numpy as np
from scipy.stats import circmean

-import trajdata.maps.vec_map_elements as vec_map_elems
import trajdata.proto.vectorized_map_pb2 as map_proto
from trajdata.utils import arr_utils
