Skip to content

Commit

Permalink
Quick fix for neighbor agent type order
Browse files Browse the repository at this point in the history
  • Loading branch information
BorisIvanovic authored Jan 9, 2023
1 parent 51fa572 commit 68cc897
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions src/trajdata/data_structures/batch_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,15 +199,14 @@ def get_nearby_agents(
nearby_agents: List[AgentMetadata] = [
scene_time.agents[idx] for idx in nb_idx if nearby_mask[idx]
]
neighbor_types_np: np.ndarray = neighbor_types[nearby_mask]

if self.max_neighbor_num is not None:
# Pruning nearby_agents and re-creating
# neighbor_types_np with the remaining agents.
nearby_agents = nearby_agents[: self.max_neighbor_num]
neighbor_types_np: np.ndarray = np.array(
[a.type.value for a in nearby_agents]
)

# Doing this here because the argsort above changes the order of agents.
neighbor_types_np: np.ndarray = np.array([a.type.value for a in nearby_agents])

return nearby_agents, neighbor_types_np

Expand Down

0 comments on commit 68cc897

Please sign in to comment.