From ef50d1b8383a7dd4b0b8cf4e0ac86f18cc5f287e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?=
Date: Sun, 14 Apr 2024 10:27:36 -0400
Subject: [PATCH] Log environment info - PyTorch implementation

---
 skrl/trainers/torch/base.py       | 25 +++++++++++++++++++++++++
 skrl/trainers/torch/parallel.py   |  1 +
 skrl/trainers/torch/sequential.py | 15 +++++++++++++++
 skrl/trainers/torch/step.py       | 31 +++++++++++++++++++++++++++++--
 4 files changed, 70 insertions(+), 2 deletions(-)

diff --git a/skrl/trainers/torch/base.py b/skrl/trainers/torch/base.py
index a13d70a4..232e01ee 100644
--- a/skrl/trainers/torch/base.py
+++ b/skrl/trainers/torch/base.py
@@ -59,6 +59,7 @@ def __init__(self,
         self.headless = self.cfg.get("headless", False)
         self.disable_progressbar = self.cfg.get("disable_progressbar", False)
         self.close_environment_at_exit = self.cfg.get("close_environment_at_exit", True)
+        self.environment_info = self.cfg.get("environment_info", "episode")
 
         self.initial_timestep = 0
 
@@ -190,6 +191,12 @@ def single_agent_train(self) -> None:
                                               timestep=timestep,
                                               timesteps=self.timesteps)
 
+                # log environment info
+                if self.environment_info in infos:
+                    for k, v in infos[self.environment_info].items():
+                        if isinstance(v, torch.Tensor) and v.numel() == 1:
+                            self.agents.track_data(f"Info / {k}", v.item())
+
             # post-interaction
             self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps)
 
@@ -244,6 +251,12 @@ def single_agent_eval(self) -> None:
                                                  timesteps=self.timesteps)
                 super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps)
 
+                # log environment info
+                if self.environment_info in infos:
+                    for k, v in infos[self.environment_info].items():
+                        if isinstance(v, torch.Tensor) and v.numel() == 1:
+                            self.agents.track_data(f"Info / {k}", v.item())
+
                 # reset environments
                 if self.env.num_envs > 1:
                     states = next_states
@@ -304,6 +317,12 @@ def multi_agent_train(self) -> None:
                                          timestep=timestep,
                                          timesteps=self.timesteps)
 
+                # log environment info
+                if self.environment_info in infos:
+                    for k, v in infos[self.environment_info].items():
+                        if isinstance(v, torch.Tensor) and v.numel() == 1:
+                            self.agents.track_data(f"Info / {k}", v.item())
+
             # post-interaction
             self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps)
 
@@ -361,6 +380,12 @@ def multi_agent_eval(self) -> None:
                                              timesteps=self.timesteps)
                 super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps)
 
+                # log environment info
+                if self.environment_info in infos:
+                    for k, v in infos[self.environment_info].items():
+                        if isinstance(v, torch.Tensor) and v.numel() == 1:
+                            self.agents.track_data(f"Info / {k}", v.item())
+
                 # reset environments
                 if not self.env.agents:
                     states, infos = self.env.reset()
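The block added to each of the four loops above is the whole mechanism: after every environment step it looks up a sub-dictionary of infos under the configured key ("episode" by default), keeps only entries that are single-element torch.Tensor values, and forwards them to the agent's tracking under an "Info / <name>" tag. A minimal, self-contained sketch of that filtering logic — the metric names are hypothetical, and print stands in for agents.track_data:

    import torch

    environment_info = "episode"  # the configurable key
    # hypothetical infos dict, as an environment might return it from step()
    infos = {
        "episode": {
            "success": torch.tensor([1.0]),     # numel() == 1 -> logged
            "distance": torch.tensor(0.42),     # numel() == 1 -> logged
            "per_env_reward": torch.zeros(64),  # numel() > 1  -> skipped
            "note": "not a tensor",             # not a tensor -> skipped
        }
    }

    if environment_info in infos:
        for k, v in infos[environment_info].items():
            if isinstance(v, torch.Tensor) and v.numel() == 1:
                print(f"Info / {k}", v.item())  # the trainer calls track_data here
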
diff --git a/skrl/trainers/torch/parallel.py b/skrl/trainers/torch/parallel.py
index 68b9b9d8..1a086e0e 100644
--- a/skrl/trainers/torch/parallel.py
+++ b/skrl/trainers/torch/parallel.py
@@ -18,6 +18,7 @@
     "headless": False,              # whether to use headless mode (no rendering)
     "disable_progressbar": False,   # whether to disable the progressbar. If None, disable on non-TTY
     "close_environment_at_exit": True,   # whether to close the environment on normal program termination
+    "environment_info": "episode",       # key used to get and log environment info
 }
 # [end-config-dict-torch]
 
diff --git a/skrl/trainers/torch/sequential.py b/skrl/trainers/torch/sequential.py
index 49952351..67b43ca7 100644
--- a/skrl/trainers/torch/sequential.py
+++ b/skrl/trainers/torch/sequential.py
@@ -17,6 +17,7 @@
     "headless": False,              # whether to use headless mode (no rendering)
     "disable_progressbar": False,   # whether to disable the progressbar. If None, disable on non-TTY
     "close_environment_at_exit": True,   # whether to close the environment on normal program termination
+    "environment_info": "episode",       # key used to get and log environment info
 }
 # [end-config-dict-torch]
 
@@ -116,6 +117,13 @@ def train(self) -> None:
                                             timestep=timestep,
                                             timesteps=self.timesteps)
 
+                # log environment info
+                if self.environment_info in infos:
+                    for k, v in infos[self.environment_info].items():
+                        if isinstance(v, torch.Tensor) and v.numel() == 1:
+                            for agent in self.agents:
+                                agent.track_data(f"Info / {k}", v.item())
+
             # post-interaction
             for agent in self.agents:
                 agent.post_interaction(timestep=timestep, timesteps=self.timesteps)
@@ -184,6 +192,13 @@ def eval(self) -> None:
                                                 timesteps=self.timesteps)
                     super(type(agent), agent).post_interaction(timestep=timestep, timesteps=self.timesteps)
 
+                # log environment info
+                if self.environment_info in infos:
+                    for k, v in infos[self.environment_info].items():
+                        if isinstance(v, torch.Tensor) and v.numel() == 1:
+                            for agent in self.agents:
+                                agent.track_data(f"Info / {k}", v.item())
+
                 # reset environments
                 if terminated.any() or truncated.any():
                     states, infos = self.env.reset()
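With the config entries above, choosing which infos sub-dictionary gets logged is just a matter of overriding "environment_info" in the trainer config. A hedged usage sketch — the env and agents objects are placeholders, assumed to be built as usual with skrl:

    from skrl.trainers.torch import SequentialTrainer

    # keys mirror SEQUENTIAL_TRAINER_DEFAULT_CONFIG above; "environment_info"
    # selects the infos sub-dict whose scalar tensors are written to TensorBoard
    cfg = {"timesteps": 100000, "headless": True, "environment_info": "episode"}
    trainer = SequentialTrainer(env=env, agents=agents, cfg=cfg)
    trainer.train()

Note that when the sequential trainer runs several simultaneous agents, each scalar is pushed to every agent's tracker, since each agent keeps its own experiment writer.
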
diff --git a/skrl/trainers/torch/step.py b/skrl/trainers/torch/step.py
index c60476f1..3bfd9acc 100644
--- a/skrl/trainers/torch/step.py
+++ b/skrl/trainers/torch/step.py
@@ -17,6 +17,7 @@
     "headless": False,              # whether to use headless mode (no rendering)
     "disable_progressbar": False,   # whether to disable the progressbar. If None, disable on non-TTY
     "close_environment_at_exit": True,   # whether to close the environment on normal program termination
+    "environment_info": "episode",       # key used to get and log environment info
 }
 # [end-config-dict-torch]
 
@@ -129,8 +130,8 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
             self.env.render()
 
         if self.num_simultaneous_agents == 1:
-            # record the environments' transitions
             with torch.no_grad():
+                # record the environments' transitions
                 self.agents.record_transition(states=self.states,
                                               actions=actions,
                                               rewards=rewards,
@@ -141,12 +142,18 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
                                               timestep=timestep,
                                               timesteps=timesteps)
 
+                # log environment info
+                if self.environment_info in infos:
+                    for k, v in infos[self.environment_info].items():
+                        if isinstance(v, torch.Tensor) and v.numel() == 1:
+                            self.agents.track_data(f"Info / {k}", v.item())
+
             # post-interaction
             self.agents.post_interaction(timestep=timestep, timesteps=timesteps)
 
         else:
-            # record the environments' transitions
             with torch.no_grad():
+                # record the environments' transitions
                 for agent, scope in zip(self.agents, self.agents_scope):
                     agent.record_transition(states=self.states[scope[0]:scope[1]],
                                             actions=actions[scope[0]:scope[1]],
@@ -158,6 +165,13 @@ def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
                                             timestep=timestep,
                                             timesteps=timesteps)
 
+                # log environment info
+                if self.environment_info in infos:
+                    for k, v in infos[self.environment_info].items():
+                        if isinstance(v, torch.Tensor) and v.numel() == 1:
+                            for agent in self.agents:
+                                agent.track_data(f"Info / {k}", v.item())
+
             # post-interaction
             for agent in self.agents:
                 agent.post_interaction(timestep=timestep, timesteps=timesteps)
@@ -242,6 +256,12 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
                                                  timesteps=timesteps)
                 super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=timesteps)
 
+                # log environment info
+                if self.environment_info in infos:
+                    for k, v in infos[self.environment_info].items():
+                        if isinstance(v, torch.Tensor) and v.numel() == 1:
+                            self.agents.track_data(f"Info / {k}", v.item())
+
         else:
             # write data to TensorBoard
             for agent, scope in zip(self.agents, self.agents_scope):
@@ -256,6 +276,13 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
                                                 timesteps=timesteps)
                 super(type(agent), agent).post_interaction(timestep=timestep, timesteps=timesteps)
 
+                # log environment info
+                if self.environment_info in infos:
+                    for k, v in infos[self.environment_info].items():
+                        if isinstance(v, torch.Tensor) and v.numel() == 1:
+                            for agent in self.agents:
+                                agent.track_data(f"Info / {k}", v.item())
+
         # reset environments
         if terminated.any() or truncated.any():
             self.states, infos = self.env.reset()
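The step trainer differs from the other two in that train() and eval() execute a single timestep and are meant to be driven by an external loop, which is why they take timestep/timesteps arguments (see the signatures above). A sketch of how the logging behaves there, with the same placeholder env and agents:

    from skrl.trainers.torch import StepTrainer

    cfg = {"timesteps": 10000, "environment_info": "episode"}
    trainer = StepTrainer(env=env, agents=agents, cfg=cfg)

    # each call steps the environment once; scalar entries found under
    # infos["episode"] are tracked as "Info / <name>" on that timestep
    for timestep in range(cfg["timesteps"]):
        trainer.train(timestep=timestep, timesteps=cfg["timesteps"])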