Adding tensor/array classes with state information, visualization improvements, and more for v1.2.0!
BorisIvanovic committed Feb 1, 2023
1 parent 5e171d1 commit b74ff32
Showing 30 changed files with 1,663 additions and 575 deletions.
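
The headline change is the new StateArray/StateTensor classes. A minimal sketch of the API, assuming the behavior demonstrated in examples/state_example.py below (the variable names here are illustrative):

import numpy as np

from trajdata.data_structures.state import StateArray

# Wrap a plain array by declaring which state quantity each column holds.
state: StateArray = StateArray.from_array(np.array([[1.0, 2.0, 0.5]]), "x,y,h")

print(state.position)  # the x,y columns, via a named property
print(state.as_format("x,y,c,s"))  # heading expanded into cos/sin
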
2 changes: 1 addition & 1 deletion examples/custom_batch_data.py
@@ -25,7 +25,7 @@ def custom_goal_location(
    batch_elem: Union[AgentBatchElement, SceneBatchElement]
) -> np.ndarray:
    # simply access existing element attributes
    return batch_elem.agent_future_np[:, :2]
    return batch_elem.agent_future_np.position


def custom_min_distance_from_others(
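The one-line change above swaps raw column slicing for the new .position property. A quick sketch of the assumed equivalence, given that the element's state format starts with x,y:

import numpy as np

from trajdata.data_structures.state import StateArray

future = StateArray.from_array(np.random.randn(12, 8), "x,y,z,xd,yd,xdd,ydd,h")

# .position selects the x,y columns, so both expressions agree.
assert np.allclose(np.asarray(future)[:, :2], future.position)
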
59 changes: 36 additions & 23 deletions examples/lane_query_example.py
@@ -2,18 +2,20 @@
This is an example of how to extend a batch with lane information
"""

import random
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from trajdata import AgentBatch, AgentType, UnifiedDataset
from trajdata.data_structures.batch_element import AgentBatchElement, SceneBatchElement
from trajdata.data_structures.batch_element import AgentBatchElement
from trajdata.maps import VectorMap
from trajdata.maps.vec_map_elements import RoadLane
from trajdata.utils.arr_utils import batch_nd_transform_points_np
from trajdata.utils.arr_utils import transform_angles_np, transform_coords_np
from trajdata.utils.state_utils import transform_state_np_2d
from trajdata.visualization.vis import plot_agent_batch


@@ -23,38 +25,43 @@ def get_closest_lane_point(element: AgentBatchElement) -> np.ndarray:
    # Transform from agent coordinate frame to world coordinate frame.
    vector_map: VectorMap = element.vec_map
    world_from_agent_tf = np.linalg.inv(element.agent_from_world_tf)
    agent_future_xy_world = batch_nd_transform_points_np(
        element.agent_future_np[:, :2], world_from_agent_tf
    )
    agent_future_xyzh_world = transform_state_np_2d(
        element.agent_future_np, world_from_agent_tf
    ).as_format("x,y,z,h")

    # Use cached kdtree to find closest lane point
    lane_points_world = []
    for xy_world in agent_future_xy_world:
        point_xyz = np.array([[xy_world[0], xy_world[1], 0.0]])
        closest_lane: RoadLane = vector_map.get_closest_lane(point_xyz.squeeze(axis=0))
        lane_points_world.append(closest_lane.center.project_onto(point_xyz))

    lane_points_world = np.concatenate(lane_points_world, axis=0)

    # Transform lane points to agent coordinate frame
    lane_points = batch_nd_transform_points_np(
        lane_points_world[:, :2], element.agent_from_world_tf
    )

    lane_points = []
    for point_xyzh in agent_future_xyzh_world:
        possible_lanes = vector_map.get_current_lane(point_xyzh)
        xyzh_on_lane = np.full((1, 4), np.nan)
        if len(possible_lanes) > 0:
            xyzh_on_lane = possible_lanes[0].center.project_onto(point_xyzh[None, :3])
            xyzh_on_lane[:, :2] = transform_coords_np(
                xyzh_on_lane[:, :2], element.agent_from_world_tf
            )
            xyzh_on_lane[:, -1] = transform_angles_np(
                xyzh_on_lane[:, -1], element.agent_from_world_tf
            )

        lane_points.append(xyzh_on_lane)

    lane_points = np.concatenate(lane_points, axis=0)
    return lane_points
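
The rewritten loop queries lanes with the full (x, y, z, heading) state and maps the projected points back into the agent frame, rotating headings as well as translating positions. A sketch of what transform_coords_np and transform_angles_np are assumed to do for a 3x3 homogeneous 2D transform (the _sketch helpers are hypothetical stand-ins, not the library's implementations):

import numpy as np

def transform_coords_np_sketch(coords: np.ndarray, tf: np.ndarray) -> np.ndarray:
    # Rotate and translate (N, 2) points by a (3, 3) homogeneous transform.
    return coords @ tf[:2, :2].T + tf[:2, 2]

def transform_angles_np_sketch(angles: np.ndarray, tf: np.ndarray) -> np.ndarray:
    # Offset headings by the transform's rotation component and wrap to [-pi, pi).
    rot = np.arctan2(tf[1, 0], tf[0, 0])
    return (angles + rot + np.pi) % (2 * np.pi) - np.pi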


def main():
    dataset = UnifiedDataset(
        desired_data=[
            "nusc_mini-mini_train",
            # "nusc_mini-mini_train",
            "lyft_sample-mini_val",
        ],
        centric="agent",
        desired_dt=0.1,
        history_sec=(3.2, 3.2),
        future_sec=(4.8, 4.8),
        only_types=[AgentType.VEHICLE],
        state_format="x,y,z,xd,yd,xdd,ydd,h",
        obs_format="x,y,z,xd,yd,xdd,ydd,s,c",
        agent_interaction_distances=defaultdict(lambda: 30.0),
        incl_robot_future=False,
        incl_raster_map=True,
@@ -67,7 +74,7 @@ def main():
        num_workers=0,
        verbose=True,
        data_dirs={  # Remember to change this to match your filesystem!
            "nusc_mini": "~/datasets/nuScenes",
            # "nusc_mini": "~/datasets/nuScenes",
            "lyft_sample": "~/datasets/lyft/scenes/sample.zarr",
        },
        # A dictionary that contains functions that generate our custom data.
@@ -89,8 +96,8 @@ def main():

    # Visualize selected examples
    num_plots = 3
    batch_idxs = [10876, 10227, 1284]
    # batch_idxs = random.sample(range(len(dataset)), num_plots)
    # batch_idxs = [10876, 10227, 1284]
    batch_idxs = random.sample(range(len(dataset)), num_plots)
    batch: AgentBatch = dataset.get_collate_fn(pad_format="right")(
        [dataset[i] for i in batch_idxs]
    )
@@ -101,14 +108,20 @@ def main():
            batch, batch_idx=batch_i, legend=False, show=False, close=False
        )
        lane_points = batch.extras["closest_lane_point"][batch_i]
        lane_points = lane_points[
            torch.logical_not(torch.any(torch.isnan(lane_points), dim=1)), :
        ].numpy()

        ax.plot(
            lane_points[:, 0],
            lane_points[:, 1],
            "o-",
            markersize=3,
            label="Lane points",
        )

        ax.legend(loc="best", frameon=True)

    plt.show()
    plt.close("all")
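
Note the NaN filtering added above: with get_current_lane, timesteps that match no lane yield NaN-filled rows (see get_closest_lane_point), which would otherwise break plotting. A minimal sketch of the same row-masking idiom on a toy tensor:

import torch

pts = torch.tensor([[1.0, 2.0], [float("nan"), 0.0]])
valid = pts[~torch.isnan(pts).any(dim=1)]  # keeps only fully finite rows
print(valid)  # tensor([[1., 2.]])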

96 changes: 96 additions & 0 deletions examples/state_example.py
@@ -0,0 +1,96 @@
from collections import defaultdict

import numpy as np
from torch.utils.data import DataLoader

from trajdata import AgentBatch, AgentType, UnifiedDataset
from trajdata.data_structures.state import StateArray, StateTensor


def main():
    dataset = UnifiedDataset(
        desired_data=["lyft_sample-mini_val"],
        centric="agent",
        desired_dt=0.1,
        history_sec=(3.2, 3.2),
        future_sec=(4.8, 4.8),
        only_predict=[AgentType.VEHICLE],
        state_format="x,y,z,xd,yd,xdd,ydd,h",
        agent_interaction_distances=defaultdict(lambda: 30.0),
        incl_robot_future=False,
        incl_raster_map=True,
        raster_map_params={
            "px_per_m": 2,
            "map_size_px": 224,
            "offset_frac_xy": (-0.5, 0.0),
        },
        num_workers=0,
        verbose=True,
        data_dirs={  # Remember to change this to match your filesystem!
            "lyft_sample": "~/datasets/lyft_sample/scenes/sample.zarr",
        },
    )

    print(f"# Data Samples: {len(dataset):,}")

    dataloader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        collate_fn=dataset.get_collate_fn(),
        num_workers=4,
    )

    # A batch element has properties that correspond to agent states.
    ego_state = dataset[0].curr_agent_state_np.copy()
    print(ego_state)

    # StateArray types offer easy conversion to whatever format you want your state in,
    # e.g., if we want x,y position and cos/sin heading:
    print(ego_state.as_format("x,y,c,s"))

    # We can also access elements via properties.
    print(ego_state.position3d)
    print(ego_state.velocity)

    # We can set elements of states via properties. E.g., let's reset the heading to 0.
    ego_state.heading = 0
    print(ego_state)

    # We can request elements that aren't directly stored in the state, e.g., cos/sin heading.
    print(ego_state.heading_vector)

    # However, we can't set properties that aren't directly stored in the state tensor.
    try:
        ego_state.heading_vector = 0.0
    except AttributeError as e:
        print(e)

    # Finally, StateArrays are just np.ndarrays under the hood, and any normal np
    # operation should convert them to a normal array.
    print(ego_state**2)

    # To convert an np.ndarray into a StateArray, we just need to specify what format it is in.
    # Note that StateArrays can have an arbitrary number of batch elements.
    print(StateArray.from_array(np.random.randn(1, 2, 3), "x,y,z"))

    # Analogous to StateArray wrapping np.ndarrays, the StateTensor class gives the same
    # functionality to torch.Tensors.
    batch: AgentBatch = next(iter(dataloader))
    ego_state_t: StateTensor = batch.curr_agent_state

    print(ego_state_t.as_format("x,y,c,s"))
    print(ego_state_t.position3d)
    print(ego_state_t.velocity)
    ego_state_t.heading = 0
    print(ego_state_t)
    print(ego_state_t.heading_vector)

    # Furthermore, we can use the from_numpy() and numpy() methods to convert to and
    # from StateTensors with the same format.
    print(ego_state_t.numpy())
    print(StateTensor.from_numpy(ego_state))


if __name__ == "__main__":
    main()
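
Taken together with the custom-extras mechanism shown in examples/custom_batch_data.py above, the new state properties make derived quantities one-liners. A hypothetical sketch (the function name and the .velocity usage on the future StateArray are assumptions based on the properties printed above):

import numpy as np

from trajdata.data_structures.batch_element import AgentBatchElement

def custom_speed(batch_elem: AgentBatchElement) -> np.ndarray:
    # Speed per future timestep from the xd,yd columns of the future StateArray.
    return np.linalg.norm(batch_elem.agent_future_np.velocity, axis=-1)
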
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name = trajdata
version = 1.1.1
version = 1.2.0
author = Boris Ivanovic
author_email = [email protected]
description = A unified interface to many trajectory forecasting datasets.