Skip to content

Commit

Permalink
Changed FeaturesObservation to a Dictionary of named_tuples. Modified…
Browse files Browse the repository at this point in the history
… example to use this.
  • Loading branch information
nandantumu committed Oct 20, 2023
1 parent 15d8f03 commit 6e9c13b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 17 deletions.
6 changes: 3 additions & 3 deletions examples/waypoint_follow.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,9 +356,9 @@ def render_callback(env_renderer):
while not done:
agent_id = env.agent_ids[0]
speed, steer = planner.plan(
obs[agent_id]["pose_x"],
obs[agent_id]["pose_y"],
obs[agent_id]["pose_theta"],
obs[agent_id].pose_x,
obs[agent_id].pose_y,
obs[agent_id].pose_theta,
work["tlad"],
work["vgain"],
)
Expand Down
31 changes: 17 additions & 14 deletions gym/f110_gym/envs/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
from abc import abstractmethod
from typing import List
from collections import namedtuple

import gymnasium as gym
import numpy as np
Expand Down Expand Up @@ -100,7 +101,7 @@ def space(self):

def observe(self):
# state indices
xi, yi, deltai, vxi, yawi, yaw_ratei, slipi = range(
Xi, Yi, DELTAi, VXi, YAWi, YAW_RATEi, SLIPi = range(
7
) # 7 largest state size (ST Model)

Expand All @@ -124,10 +125,10 @@ def observe(self):
lap_count = self.env.lap_counts[i]
collision = self.env.sim.collisions[i]

x, y, theta = agent.state[xi], agent.state[yi], agent.state[yawi]
vx, vy = agent.state[vxi], 0.0
x, y, theta = agent.state[Xi], agent.state[Yi], agent.state[YAWi]
vx, vy = agent.state[VXi], 0.0
angvel = (
0.0 if len(agent.state) < 7 else agent.state[yaw_ratei]
0.0 if len(agent.state) < 7 else agent.state[YAW_RATEi]
) # set 0.0 when KST Model

observations["scans"].append(agent_scan)
Expand Down Expand Up @@ -155,6 +156,7 @@ class FeaturesObservation(Observation):
def __init__(self, env, features: List[str]):
super().__init__(env)
self.features = features
self.type = namedtuple("Observation", self.features)

def space(self):
scan_size = self.env.sim.agents[0].scan_simulator.num_beams
Expand Down Expand Up @@ -201,8 +203,8 @@ def space(self):
low=0.0, high=large_num, shape=(), dtype=np.float32
),
}
complete_space[agent_id] = gym.spaces.Dict(
{k: agent_dict[k] for k in self.features}
complete_space[agent_id] = gym.spaces.Tuple(
[agent_dict[k] for k in self.features]
)

obs_space = gym.spaces.Dict(complete_space)
Expand Down Expand Up @@ -249,16 +251,17 @@ def observe(self):
}

# add agent's observation to multi-agent observation
obs[agent_id] = {k: agent_obs[k] for k in self.features}
obs[agent_id] = [agent_obs[k] for k in self.features]

# cast to match observation space
for key in obs[agent_id].keys():
if isinstance(obs[agent_id][key], np.ndarray) or isinstance(
obs[agent_id][key], list
):
obs[agent_id][key] = np.array(obs[agent_id][key], dtype=np.float32)
if isinstance(obs[agent_id][key], float):
obs[agent_id][key] = np.float32(obs[agent_id][key])
for i, item in enumerate(obs[agent_id]):
if isinstance(item, np.ndarray) or \
isinstance(item, list):
obs[agent_id][i] = np.array(item, dtype=np.float32)
if isinstance(item, float):
obs[agent_id][i] = np.float32(item)

obs[agent_id] = self.type(*obs[agent_id]) # namedtuple

return obs

Expand Down

0 comments on commit 6e9c13b

Please sign in to comment.