Merge remote-tracking branch 'origin/main' into dreamer_v1_refactor
# Conflicts:
#	torchrl/data/replay_buffers/storages.py
vmoens committed Apr 18, 2024
2 parents 7d0a158 + ee8cafb commit 4451f63
Showing 6 changed files with 211 additions and 53 deletions.
113 changes: 96 additions & 17 deletions test/test_transforms.py
@@ -56,6 +56,7 @@
LazyTensorStorage,
ReplayBuffer,
TensorDictReplayBuffer,
TensorSpec,
TensorStorage,
UnboundedContinuousTensorSpec,
)
@@ -120,6 +121,7 @@
_has_tv,
BatchSizeTransform,
FORWARD_NOT_IMPLEMENTED,
Transform,
)
from torchrl.envs.transforms.vc1 import _has_vc
from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
@@ -6423,17 +6425,11 @@ def test_trans_parallel_env_check(self):
finally:
env.close()

def test_trans_serial_env_check(self):
with pytest.raises(RuntimeError, match="The leading shape of the primer specs"):
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([4])),
)
_ = env.observation_spec

@pytest.mark.parametrize("spec_shape", [[4], [2, 4]])
def test_trans_serial_env_check(self, spec_shape):
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([2, 4])),
TensorDictPrimer(mykey=UnboundedContinuousTensorSpec(spec_shape)),
)
check_env_specs(env)
assert "mykey" in env.reset().keys()
@@ -6533,6 +6529,72 @@ def test_tensordictprimer_batching(self, batched_class, break_when_any_done):
r1 = env.rollout(100, break_when_any_done=break_when_any_done)
tensordict.tensordict.assert_allclose_td(r0, r1)

def test_callable_default_value(self):
def create_tensor():
return torch.ones(3)

env = TransformedEnv(
ContinuousActionVecMockEnv(),
TensorDictPrimer(
mykey=UnboundedContinuousTensorSpec([3]), default_value=create_tensor
),
)
check_env_specs(env)
assert "mykey" in env.reset().keys()
assert ("next", "mykey") in env.rollout(3).keys(True)

def test_dict_default_value(self):

# Test with a dict of float default values
key1_spec = UnboundedContinuousTensorSpec([3])
key2_spec = UnboundedContinuousTensorSpec([3])
env = TransformedEnv(
ContinuousActionVecMockEnv(),
TensorDictPrimer(
mykey1=key1_spec,
mykey2=key2_spec,
default_value={
"mykey1": 1.0,
"mykey2": 2.0,
},
),
)
check_env_specs(env)
reset_td = env.reset()
assert "mykey1" in reset_td.keys()
assert "mykey2" in reset_td.keys()
rollout_td = env.rollout(3)
assert ("next", "mykey1") in rollout_td.keys(True)
assert ("next", "mykey2") in rollout_td.keys(True)
assert (rollout_td.get(("next", "mykey1")) == 1.0).all()
assert (rollout_td.get(("next", "mykey2")) == 2.0).all()

# Test with a dict of callable default values
key1_spec = UnboundedContinuousTensorSpec([3])
key2_spec = DiscreteTensorSpec(3, dtype=torch.int64)
env = TransformedEnv(
ContinuousActionVecMockEnv(),
TensorDictPrimer(
mykey1=key1_spec,
mykey2=key2_spec,
default_value={
"mykey1": lambda: torch.ones(3),
"mykey2": lambda: torch.tensor(1, dtype=torch.int64),
},
),
)
check_env_specs(env)
reset_td = env.reset()
assert "mykey1" in reset_td.keys()
assert "mykey2" in reset_td.keys()
rollout_td = env.rollout(3)
assert ("next", "mykey1") in rollout_td.keys(True)
assert ("next", "mykey2") in rollout_td.keys(True)
assert (rollout_td.get(("next", "mykey1")) == torch.ones(3)).all()
assert (
rollout_td.get(("next", "mykey2")) == torch.tensor(1, dtype=torch.int64)
).all()


class TestTimeMaxPool(TransformBase):
@pytest.mark.parametrize("T", [2, 4])
@@ -6813,18 +6875,13 @@ def make_env():
finally:
env.close()

def test_trans_serial_env_check(self):
@pytest.mark.parametrize("shape", [(), (2,)])
def test_trans_serial_env_check(self, shape):
state_dim = 7
action_dim = 7
with pytest.raises(RuntimeError, match="The leading shape of the primer"):
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
gSDENoise(state_dim=state_dim, action_dim=action_dim, shape=()),
)
check_env_specs(env)
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
gSDENoise(state_dim=state_dim, action_dim=action_dim, shape=(2,)),
gSDENoise(state_dim=state_dim, action_dim=action_dim, shape=shape),
)
try:
check_env_specs(env)
@@ -7897,6 +7954,28 @@ def test_added_transforms_are_in_eval_mode():


class TestTransformedEnv:
def test_attr_error(self):
class BuggyTransform(Transform):
def transform_observation_spec(
self, observation_spec: TensorSpec
) -> TensorSpec:
raise AttributeError

def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
raise RuntimeError("reward!")

env = TransformedEnv(CountingEnv(), BuggyTransform())
with pytest.raises(
AttributeError, match="because an internal error was raised"
):
env.observation_spec
with pytest.raises(
AttributeError, match="'CountingEnv' object has no attribute 'tralala'"
):
env.tralala
with pytest.raises(RuntimeError, match="reward!"):
env.transform.transform_reward_spec(env.base_env.full_reward_spec)

def test_independent_obs_specs_from_shared_env(self):
obs_spec = CompositeSpec(
observation=BoundedTensorSpec(low=0, high=10, shape=torch.Size((1,)))
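
Note: the new tests above exercise the default_value argument of TensorDictPrimer with a callable and with per-key dicts. A minimal usage sketch of that pattern follows, based only on the keyword API shown in the diff; make_primed_env and base_env are illustrative names, and base_env stands in for any TorchRL environment.

import torch
from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.envs import TransformedEnv
from torchrl.envs.transforms import TensorDictPrimer

def make_primed_env(base_env):
    # Prime two extra entries; each key gets its own default through a dict,
    # mirroring the dict-of-callables case tested above.
    primer = TensorDictPrimer(
        mykey1=UnboundedContinuousTensorSpec([3]),
        mykey2=UnboundedContinuousTensorSpec([3]),
        default_value={
            "mykey1": lambda: torch.ones(3),
            "mykey2": lambda: torch.zeros(3),
        },
    )
    return TransformedEnv(base_env, primer)

# env = make_primed_env(some_env); env.reset() then contains "mykey1" and "mykey2".
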
3 changes: 3 additions & 0 deletions torchrl/_utils.py
@@ -48,6 +48,9 @@
VERBOSE = strtobool(os.environ.get("VERBOSE", "0"))
_os_is_windows = sys.platform == "win32"
RL_WARNINGS = strtobool(os.environ.get("RL_WARNINGS", "1"))
if RL_WARNINGS:
warnings.simplefilter("once", DeprecationWarning)

BATCHED_PIPE_TIMEOUT = float(os.environ.get("BATCHED_PIPE_TIMEOUT", "10000.0"))


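Note: the lines added above install a "once" filter so that, with RL_WARNINGS enabled, each distinct DeprecationWarning is surfaced a single time. A small standard-library sketch of what that filter does (no torchrl import; deprecated_call is an illustrative name):

import warnings

def deprecated_call():
    warnings.warn("this API is deprecated", DeprecationWarning)

with warnings.catch_warnings(record=True) as caught:
    # Same filter as the torchrl snippet above: report each distinct
    # DeprecationWarning only once.
    warnings.simplefilter("once", DeprecationWarning)
    for _ in range(5):
        deprecated_call()

print(len(caught))  # 1 -- identical warnings after the first are filtered out
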
8 changes: 6 additions & 2 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -612,7 +612,7 @@ def sample(self, batch_size: int | None = None, return_info: bool = False) -> An
def mark_update(self, index: Union[int, torch.Tensor]) -> None:
self._sampler.mark_update(index)

def append_transform(self, transform: "Transform") -> None: # noqa-F821
def append_transform(self, transform: "Transform") -> ReplayBuffer: # noqa-F821
"""Appends transform at the end.
Transforms are applied in order when `sample` is called.
@@ -626,8 +626,11 @@ def append_transform(self, transform: "Transform") -> None: # noqa-F821
transform = _CallableTransform(transform)
transform.eval()
self._transform.append(transform)
return self

def insert_transform(self, index: int, transform: "Transform") -> None: # noqa-F821
def insert_transform(
self, index: int, transform: "Transform" # noqa-F821
) -> ReplayBuffer:
"""Inserts transform.
Transforms are executed in order when `sample` is called.
@@ -638,6 +641,7 @@ def insert_transform(self, index: int, transform: "Transform") -> None: # noqa-
"""
transform.eval()
self._transform.insert(index, transform)
return self

def __iter__(self):
if self._sampler.ran_out:
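
Note: with append_transform and insert_transform now returning the buffer itself, transform registration can be chained. A short sketch under that assumption, using a plain ReplayBuffer over a LazyTensorStorage; plain callables are accepted because they are wrapped by the _CallableTransform path visible in the context above.

import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer

rb = ReplayBuffer(storage=LazyTensorStorage(100))
# Chaining only works after this change, since both methods now return the buffer.
rb.append_transform(lambda data: data.float()).append_transform(lambda data: data * 2)

rb.extend(torch.arange(10))
sample = rb.sample(4)  # sampled values pass through both transforms, in insertion order
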
22 changes: 14 additions & 8 deletions torchrl/data/replay_buffers/storages.py
@@ -1109,10 +1109,13 @@ def max_size_along_dim0(data_shape):
for key, tensor in sorted(
out.items(include_nested=True, leaves_only=True), key=str
):
filesize = os.path.getsize(tensor.filename) / 1024 / 1024
torchrl_logger.debug(
f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})."
)
try:
filesize = os.path.getsize(tensor.filename) / 1024 / 1024
torchrl_logger.debug(
f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})."
)
except RuntimeError:
pass
else:
out = _init_pytree(self.scratch_dir, max_size_along_dim0, data)
self._storage = out
@@ -1468,10 +1471,13 @@ def _init_pytree_common(tensor_path, scratch_dir, max_size_fn, tensor):
filename=total_tensor_path,
dtype=tensor.dtype,
)
filesize = os.path.getsize(out.filename) / 1024 / 1024
torchrl_logger.debug(
f"The storage was created in {out.filename} and occupies {filesize} Mb of storage."
)
try:
filesize = os.path.getsize(tensor.filename) / 1024 / 1024
torchrl_logger.debug(
f"The storage was created in {out.filename} and occupies {filesize} Mb of storage."
)
except RuntimeError:
pass
return out


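Note: both hunks above make the size-reporting debug log best-effort, since querying the filename or on-disk size of a storage tensor can raise when it is not file-backed. A generic sketch of that guarded-logging pattern, independent of torchrl (the helper name and the exact exception set are assumptions of this example):

import logging
import os

logger = logging.getLogger("storage_debug")

def log_storage_size(tensor) -> None:
    # Best-effort: `.filename` may be missing or raise for tensors that are not
    # backed by a file on disk, so failures are swallowed rather than propagated.
    try:
        filesize = os.path.getsize(tensor.filename) / 1024 / 1024
        logger.debug(
            "storage at %s occupies %.2f MB (shape: %s)",
            tensor.filename, filesize, tuple(tensor.shape),
        )
    except (AttributeError, RuntimeError, OSError):
        pass
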