Skip to content

Commit

Permalink
Update annotation in wrapper test files
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Aug 27, 2024
1 parent 48459c4 commit 9fdb103
Show file tree
Hide file tree
Showing 7 changed files with 9 additions and 8 deletions.
2 changes: 1 addition & 1 deletion tests/jax/test_jax_wrapper_isaacgym.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def close(self) -> None:

@pytest.mark.parametrize("backend", ["jax", "numpy"])
@pytest.mark.parametrize("num_states", [0, 5])
def test_env(capsys: pytest.CaptureFixture, backend: str, num_states):
def test_env(capsys: pytest.CaptureFixture, backend: str, num_states: str):
config.jax.backend = backend
Array = jax.Array if backend == "jax" else np.ndarray

Expand Down
2 changes: 1 addition & 1 deletion tests/jax/test_jax_wrapper_isaaclab.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def close(self) -> None:

@pytest.mark.parametrize("backend", ["jax", "numpy"])
@pytest.mark.parametrize("num_states", [0, 5])
def test_env(capsys: pytest.CaptureFixture, backend: str, num_states):
def test_env(capsys: pytest.CaptureFixture, backend: str, num_states: int):
config.jax.backend = backend
Array = jax.Array if backend == "jax" else np.ndarray

Expand Down
2 changes: 1 addition & 1 deletion tests/jax/test_jax_wrapper_omniisaacgym.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def close(self) -> None:

@pytest.mark.parametrize("backend", ["jax", "numpy"])
@pytest.mark.parametrize("num_states", [0, 5])
def test_env(capsys: pytest.CaptureFixture, backend: str, num_states):
def test_env(capsys: pytest.CaptureFixture, backend: str, num_states: int):
config.jax.backend = backend
Array = jax.Array if backend == "jax" else np.ndarray

Expand Down
5 changes: 3 additions & 2 deletions tests/torch/test_torch_wrapper_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ def test_env(capsys: pytest.CaptureFixture):

env.close()

def test_vectorized_env(capsys: pytest.CaptureFixture):
@pytest.mark.parametrize("vectorization_mode", ["async", "sync"])
def test_vectorized_env(capsys: pytest.CaptureFixture, vectorization_mode: str):
num_envs = 10
action = torch.ones((num_envs, 1))

# load wrap the environment
original_env = gym.vector.make("Pendulum-v1", num_envs=num_envs, asynchronous=False)
original_env = gym.vector.make("Pendulum-v1", num_envs=num_envs, asynchronous=(vectorization_mode == "async"))
env = wrap_env(original_env, "auto")
assert isinstance(env, GymWrapper)
env = wrap_env(original_env, "gym")
Expand Down
2 changes: 1 addition & 1 deletion tests/torch/test_torch_wrapper_isaacgym.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def close(self) -> None:


@pytest.mark.parametrize("num_states", [0, 5])
def test_env(capsys: pytest.CaptureFixture, num_states):
def test_env(capsys: pytest.CaptureFixture, num_states: int):
num_envs = 10
action = torch.ones((num_envs, 1))

Expand Down
2 changes: 1 addition & 1 deletion tests/torch/test_torch_wrapper_isaaclab.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def close(self) -> None:


@pytest.mark.parametrize("num_states", [0, 5])
def test_env(capsys: pytest.CaptureFixture, num_states):
def test_env(capsys: pytest.CaptureFixture, num_states: int):
num_envs = 10
action = torch.ones((num_envs, 1))

Expand Down
2 changes: 1 addition & 1 deletion tests/torch/test_torch_wrapper_omniisaacgym.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def close(self) -> None:


@pytest.mark.parametrize("num_states", [0, 5])
def test_env(capsys: pytest.CaptureFixture, num_states):
def test_env(capsys: pytest.CaptureFixture, num_states: int):
num_envs = 10
action = torch.ones((num_envs, 1))

Expand Down

0 comments on commit 9fdb103

Please sign in to comment.