Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into spec-with-neg-shape
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed May 31, 2024
2 parents 3e579c1 + 765952a commit ed7267a
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 8 deletions.
28 changes: 27 additions & 1 deletion test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
RenameTransform,
)
from torchrl.envs.batched_envs import SerialEnv
from torchrl.envs.libs.brax import _has_brax, BraxEnv
from torchrl.envs.libs.brax import _has_brax, BraxEnv, BraxWrapper
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv, DMControlWrapper
from torchrl.envs.libs.envpool import _has_envpool, MultiThreadedEnvWrapper
from torchrl.envs.libs.gym import (
Expand Down Expand Up @@ -1943,6 +1943,32 @@ def test_env_device(self, env_name, frame_skip, transformed_out, device):
@pytest.mark.skipif(not _has_brax, reason="brax not installed")
@pytest.mark.parametrize("envname", ["fast"])
class TestBrax:
@pytest.mark.parametrize("requires_grad", [False, True])
def test_brax_constructor(self, envname, requires_grad):
env0 = BraxEnv(envname, requires_grad=requires_grad)
env1 = BraxWrapper(env0._env, requires_grad=requires_grad)

env0.set_seed(0)
torch.manual_seed(0)
init = env0.reset()
if requires_grad:
init = init.apply(
lambda x: x.requires_grad_(True) if x.is_floating_point() else x
)
r0 = env0.rollout(10, tensordict=init, auto_reset=False)
assert r0.requires_grad == requires_grad

env1.set_seed(0)
torch.manual_seed(0)
init = env1.reset()
if requires_grad:
init = init.apply(
lambda x: x.requires_grad_(True) if x.is_floating_point() else x
)
r1 = env1.rollout(10, tensordict=init, auto_reset=False)
assert r1.requires_grad == requires_grad
assert_allclose_td(r0.data, r1.data)

def test_brax_seeding(self, envname):
final_seed = []
tdreset = []
Expand Down
4 changes: 2 additions & 2 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,10 +867,10 @@ def max_size_along_dim0(data_shape):
return (self.max_size, *data_shape)

if is_tensor_collection(data):
out = data.expand(max_size_along_dim0(data.shape))
out = data.to(self.device)
out = out.expand(max_size_along_dim0(data.shape))
out = out.clone()
out = out.zero_()
out = out.to(self.device)
else:
# if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
out = tree_map(
Expand Down
13 changes: 8 additions & 5 deletions torchrl/envs/libs/brax.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Dict, Optional, Union

import torch
from packaging import version
from tensordict import TensorDict, TensorDictBase

from torchrl.data.tensor_specs import (
Expand All @@ -15,9 +16,6 @@
UnboundedContinuousTensorSpec,
)
from torchrl.envs.common import _EnvWrapper
from torchrl.envs.utils import _classproperty

_has_brax = importlib.util.find_spec("brax") is not None
from torchrl.envs.libs.jax_utils import (
_extract_spec,
_ndarray_to_tensor,
Expand All @@ -27,6 +25,9 @@
_tree_flatten,
_tree_reshape,
)
from torchrl.envs.utils import _classproperty

_has_brax = importlib.util.find_spec("brax") is not None


def _get_envs():
Expand Down Expand Up @@ -204,12 +205,14 @@ def __init__(self, env=None, categorical_action_encoding=False, **kwargs):

def _check_kwargs(self, kwargs: Dict):
brax = self.lib
if version.parse(brax.__version__) < version.parse("0.10.4"):
raise ImportError("Brax v0.10.4 or greater is required.")

if "env" not in kwargs:
raise TypeError("Could not find environment key 'env' in kwargs.")
env = kwargs["env"]
if not isinstance(env, brax.envs.env.Env):
raise TypeError("env is not of type 'brax.envs.env.Env'.")
if not isinstance(env, brax.envs.Env):
raise TypeError("env is not of type 'brax.envs.Env'.")

def _build_env(
self,
Expand Down

0 comments on commit ed7267a

Please sign in to comment.