From 88fa82caacb949a85fd7a3e8b9ed00b251aa70fa Mon Sep 17 00:00:00 2001 From: Katze2664 <40237250+Katze2664@users.noreply.github.com> Date: Wed, 31 Jul 2024 11:57:30 +1000 Subject: [PATCH 1/2] (Issue #99) Fix disc_episode_returns off-by-one error self.disc_episode_returns should be calculated before incrementing self.episode_lengths --- mo_gymnasium/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mo_gymnasium/utils.py b/mo_gymnasium/utils.py index def90471..f4c07357 100644 --- a/mo_gymnasium/utils.py +++ b/mo_gymnasium/utils.py @@ -253,13 +253,14 @@ def step(self, action): infos, dict ), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order." self.episode_returns += rewards - self.episode_lengths += 1 # CHANGE: The discounted returns are also computed here self.disc_episode_returns += rewards * np.repeat(self.gamma**self.episode_lengths, self.reward_dim).reshape( self.episode_returns.shape ) + self.episode_lengths += 1 + dones = np.logical_or(terminations, truncations) num_dones = np.sum(dones) if num_dones: From 27ca5d154548c90eab99929ba1446b52e59b81ea Mon Sep 17 00:00:00 2001 From: Katze2664 <40237250+Katze2664@users.noreply.github.com> Date: Tue, 13 Aug 2024 18:41:53 +1000 Subject: [PATCH 2/2] Updated test_mo_record_ep_statistic Corrected off-by-one error in calculation of discounted returns --- tests/test_wrappers.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 9cf42354..61c61c91 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -14,9 +14,9 @@ def go_to_8_3(env): Goes to (8.2, -3) treasure, returns the rewards """ env.reset() - env.step(3) # right - env.step(1) # down - _, rewards, _, _, infos = env.step(1) + env.step(3) # action: right, rewards: [0, -1] + env.step(1) # action: down, rewards: [0, -1] + _, rewards, _, _, infos = env.step(1) # action: down, rewards: [8.2, -1] return rewards, infos @@ -98,10 +98,9 @@ def test_mo_record_ep_statistic(): assert info["episode"]["r"].shape == (2,) assert info["episode"]["dr"].shape == (2,) assert tuple(info["episode"]["r"]) == (np.float32(8.2), np.float32(-3.0)) - assert tuple(np.round(info["episode"]["dr"], 2)) == ( - np.float32(7.48), - np.float32(-2.82), - ) + np.testing.assert_allclose(info["episode"]["dr"], [7.71538, -2.9109], rtol=0, atol=1e-2) + # 0 * 0.97**0 + 0 * 0.97**1 + 8.2 * 0.97**2 == 7.71538 + # -1 * 0.97**0 + -1 * 0.97**1 + -1 * 0.97**2 == -2.9109 assert isinstance(info["episode"]["l"], np.int32) assert info["episode"]["l"] == 3 assert isinstance(info["episode"]["t"], np.float32)