Skip to content

Commit

Permalink
Merge pull request #151 from Toni-SM/log_episode_info
Browse files Browse the repository at this point in the history
Log environment info - PyTorch implementation
  • Loading branch information
Toni-SM authored May 2, 2024
2 parents a364aef + ef50d1b commit c96d722
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 2 deletions.
25 changes: 25 additions & 0 deletions skrl/trainers/torch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(self,
self.headless = self.cfg.get("headless", False)
self.disable_progressbar = self.cfg.get("disable_progressbar", False)
self.close_environment_at_exit = self.cfg.get("close_environment_at_exit", True)
self.environment_info = self.cfg.get("environment_info", "episode")

self.initial_timestep = 0

Expand Down Expand Up @@ -190,6 +191,12 @@ def single_agent_train(self) -> None:
timestep=timestep,
timesteps=self.timesteps)

# log environment info
if self.environment_info in infos:
for k, v in infos[self.environment_info].items():
if isinstance(v, torch.Tensor) and v.numel() == 1:
self.agents.track_data(f"Info / {k}", v.item())

# post-interaction
self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps)

Expand Down Expand Up @@ -244,6 +251,12 @@ def single_agent_eval(self) -> None:
timesteps=self.timesteps)
super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps)

# log environment info
if self.environment_info in infos:
for k, v in infos[self.environment_info].items():
if isinstance(v, torch.Tensor) and v.numel() == 1:
self.agents.track_data(f"Info / {k}", v.item())

# reset environments
if self.env.num_envs > 1:
states = next_states
Expand Down Expand Up @@ -304,6 +317,12 @@ def multi_agent_train(self) -> None:
timestep=timestep,
timesteps=self.timesteps)

# log environment info
if self.environment_info in infos:
for k, v in infos[self.environment_info].items():
if isinstance(v, torch.Tensor) and v.numel() == 1:
self.agents.track_data(f"Info / {k}", v.item())

# post-interaction
self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps)

Expand Down Expand Up @@ -361,6 +380,12 @@ def multi_agent_eval(self) -> None:
timesteps=self.timesteps)
super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps)

# log environment info
if self.environment_info in infos:
for k, v in infos[self.environment_info].items():
if isinstance(v, torch.Tensor) and v.numel() == 1:
self.agents.track_data(f"Info / {k}", v.item())

# reset environments
if not self.env.agents:
states, infos = self.env.reset()
Expand Down
1 change: 1 addition & 0 deletions skrl/trainers/torch/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"headless": False, # whether to use headless mode (no rendering)
"disable_progressbar": False, # whether to disable the progressbar. If None, disable on non-TTY
"close_environment_at_exit": True, # whether to close the environment on normal program termination
"environment_info": "episode", # key used to get and log environment info
}
# [end-config-dict-torch]

Expand Down
15 changes: 15 additions & 0 deletions skrl/trainers/torch/sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"headless": False, # whether to use headless mode (no rendering)
"disable_progressbar": False, # whether to disable the progressbar. If None, disable on non-TTY
"close_environment_at_exit": True, # whether to close the environment on normal program termination
"environment_info": "episode", # key used to get and log environment info
}
# [end-config-dict-torch]

Expand Down Expand Up @@ -116,6 +117,13 @@ def train(self) -> None:
timestep=timestep,
timesteps=self.timesteps)

# log environment info
if self.environment_info in infos:
for k, v in infos[self.environment_info].items():
if isinstance(v, torch.Tensor) and v.numel() == 1:
for agent in self.agents:
agent.track_data(f"Info / {k}", v.item())

# post-interaction
for agent in self.agents:
agent.post_interaction(timestep=timestep, timesteps=self.timesteps)
Expand Down Expand Up @@ -184,6 +192,13 @@ def eval(self) -> None:
timesteps=self.timesteps)
super(type(agent), agent).post_interaction(timestep=timestep, timesteps=self.timesteps)

# log environment info
if self.environment_info in infos:
for k, v in infos[self.environment_info].items():
if isinstance(v, torch.Tensor) and v.numel() == 1:
for agent in self.agents:
agent.track_data(f"Info / {k}", v.item())

# reset environments
if terminated.any() or truncated.any():
states, infos = self.env.reset()
Expand Down
31 changes: 29 additions & 2 deletions skrl/trainers/torch/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"headless": False, # whether to use headless mode (no rendering)
"disable_progressbar": False, # whether to disable the progressbar. If None, disable on non-TTY
"close_environment_at_exit": True, # whether to close the environment on normal program termination
"environment_info": "episode", # key used to get and log environment info
}
# [end-config-dict-torch]

Expand Down Expand Up @@ -129,8 +130,8 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
self.env.render()

if self.num_simultaneous_agents == 1:
# record the environments' transitions
with torch.no_grad():
# record the environments' transitions
self.agents.record_transition(states=self.states,
actions=actions,
rewards=rewards,
Expand All @@ -141,12 +142,18 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
timestep=timestep,
timesteps=timesteps)

# log environment info
if self.environment_info in infos:
for k, v in infos[self.environment_info].items():
if isinstance(v, torch.Tensor) and v.numel() == 1:
self.agents.track_data(f"Info / {k}", v.item())

# post-interaction
self.agents.post_interaction(timestep=timestep, timesteps=timesteps)

else:
# record the environments' transitions
with torch.no_grad():
# record the environments' transitions
for agent, scope in zip(self.agents, self.agents_scope):
agent.record_transition(states=self.states[scope[0]:scope[1]],
actions=actions[scope[0]:scope[1]],
Expand All @@ -158,6 +165,13 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
timestep=timestep,
timesteps=timesteps)

# log environment info
if self.environment_info in infos:
for k, v in infos[self.environment_info].items():
if isinstance(v, torch.Tensor) and v.numel() == 1:
for agent in self.agents:
agent.track_data(f"Info / {k}", v.item())

# post-interaction
for agent in self.agents:
agent.post_interaction(timestep=timestep, timesteps=timesteps)
Expand Down Expand Up @@ -242,6 +256,12 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
timesteps=timesteps)
super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=timesteps)

# log environment info
if self.environment_info in infos:
for k, v in infos[self.environment_info].items():
if isinstance(v, torch.Tensor) and v.numel() == 1:
self.agents.track_data(f"Info / {k}", v.item())

else:
# write data to TensorBoard
for agent, scope in zip(self.agents, self.agents_scope):
Expand All @@ -256,6 +276,13 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
timesteps=timesteps)
super(type(agent), agent).post_interaction(timestep=timestep, timesteps=timesteps)

# log environment info
if self.environment_info in infos:
for k, v in infos[self.environment_info].items():
if isinstance(v, torch.Tensor) and v.numel() == 1:
for agent in self.agents:
agent.track_data(f"Info / {k}", v.item())

# reset environments
if terminated.any() or truncated.any():
self.states, infos = self.env.reset()
Expand Down

0 comments on commit c96d722

Please sign in to comment.