Skip to content

Commit

Permalink
Tests are passing
Browse files Browse the repository at this point in the history
  • Loading branch information
ffelten committed May 23, 2024
1 parent 062849a commit 7fbaf38
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 10 deletions.
10 changes: 5 additions & 5 deletions mo_gymnasium/wrappers/vector/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Vector wrappers."""
import time
from copy import deepcopy
from typing import Any, Iterator
from typing import Any, Iterator, Tuple, Dict

import gymnasium as gym
import numpy as np
Expand Down Expand Up @@ -92,7 +92,7 @@ def __init__(
dtype=np.float32,
)

def step(self, actions: ActType) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
def step(self, actions: ActType) -> Tuple[ObsType, ArrayType, ArrayType, ArrayType, Dict[str, Any]]:
"""Steps through each of the environments returning the batched results.
Returns:
Expand Down Expand Up @@ -187,7 +187,7 @@ def reset(self, **kwargs):

return obs, info

def step(self, actions: ActType) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
def step(self, actions: ActType) -> Tuple[ObsType, ArrayType, ArrayType, ArrayType, Dict[str, Any]]:
"""Steps through the environment, recording the episode statistics."""
(
observations,
Expand Down Expand Up @@ -229,8 +229,8 @@ def step(self, actions: ActType) -> tuple[ObsType, ArrayType, ArrayType, ArrayTy

episode_time_length = np.round(time.perf_counter() - self.episode_start_times, 6)
infos[self._stats_key] = {
"r": np.where(dones, self.episode_returns, np.zeros(self.rewards_shape, dtype=np.float32)),
"dr": np.where(dones, self.disc_episode_returns, np.zeros(self.rewards_shape, dtype=np.float32)),
"r": episode_return,
"dr": disc_episode_return,
"l": np.where(dones, self.episode_lengths, 0),
"t": np.where(dones, episode_time_length, 0.0),
}
Expand Down
6 changes: 1 addition & 5 deletions tests/test_vector_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
import numpy as np

import mo_gymnasium as mo_gym
from mo_gymnasium.wrappers.vector import (
MORecordEpisodeStatistics,
MOSyncVectorEnv,
)
from mo_gymnasium.wrappers.vector import MORecordEpisodeStatistics, MOSyncVectorEnv


def test_mo_sync_wrapper():
def make_env(env_id):
def thunk():
env = mo_gym.make(env_id)
env = MORecordEpisodeStatistics(env, gamma=0.97)
return env

return thunk
Expand Down

0 comments on commit 7fbaf38

Please sign in to comment.