From e80c032516c0d314b966b86ecfc5c089eb8e447c Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Wed, 10 Apr 2024 11:53:27 +0200 Subject: [PATCH] Fix the `done` flag setting in one-step episodes (#22) --- cogment_lab/envs/gymnasium.py | 2 +- cogment_lab/utils/trial_utils.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/cogment_lab/envs/gymnasium.py b/cogment_lab/envs/gymnasium.py index ba06473..a1a4938 100644 --- a/cogment_lab/envs/gymnasium.py +++ b/cogment_lab/envs/gymnasium.py @@ -198,7 +198,7 @@ async def reset(self, state: State): logging.info("Resetting environment") - obs, _info = state.env.reset(seed=state.session_cfg.seed, options=state.session_cfg.reset_args) # THIS + obs, _info = state.env.reset(seed=state.session_cfg.seed, options=state.session_cfg.reset_args) state.observation_space = state.session_helper.get_observation_space(self.actor_name) frame = state.env.render() if state.session_cfg.render else None diff --git a/cogment_lab/utils/trial_utils.py b/cogment_lab/utils/trial_utils.py index 69d0c4c..576a63e 100644 --- a/cogment_lab/utils/trial_utils.py +++ b/cogment_lab/utils/trial_utils.py @@ -179,7 +179,8 @@ def extract_data_from_samples( data.rewards = initialize_buffer(None, sample_count - 1) # type: ignore if "done" in fields: data.done = initialize_buffer(None, sample_count - 1) # type: ignore - data.done[-1] = True # type: ignore + if sample_count > 1: + data.done[-1] = True # type: ignore if "next_observations" in fields: data.next_observations = initialize_buffer(observation_space, sample_count - 1) if "last_observation" in fields: