Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Apr 23, 2024
1 parent 4d1790a commit 6427212
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 11 deletions.
18 changes: 9 additions & 9 deletions sota-implementations/redq/redq.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def main(cfg: "DictConfig"): # noqa: F821
},
)
else:
logger = ""
logger = None

key, init_env_steps, stats = None, None, None
if not cfg.env.vecnorm and cfg.env.norm_stats:
Expand Down Expand Up @@ -174,14 +174,14 @@ def main(cfg: "DictConfig"): # noqa: F821
t.loc.fill_(0.0)

trainer = make_trainer(
collector,
loss_module,
recorder,
target_net_updater,
actor_model_explore,
replay_buffer,
logger,
cfg,
collector=collector,
loss_module=loss_module,
recorder=recorder,
target_net_updater=target_net_updater,
policy_exploration=actor_model_explore,
replay_buffer=replay_buffer,
logger=logger,
cfg=cfg,
)

trainer.train()
Expand Down
5 changes: 4 additions & 1 deletion torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from torchrl.data.replay_buffers.storages import (
_get_default_collate,
_stack_anything,
ListStorage,
Storage,
StorageEnsemble,
Expand Down Expand Up @@ -1541,8 +1542,10 @@ def __init__(
num_buffer_sampled: int | None = None,
**kwargs,
):

if collate_fn is None:
collate_fn = LazyStackedTensorDict.maybe_dense_stack
collate_fn = _stack_anything

if rbs:
if storages is not None or samplers is not None or writers is not None:
raise RuntimeError
Expand Down
8 changes: 7 additions & 1 deletion torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1323,10 +1323,16 @@ def _mem_map_tensor_as_tensor(mem_map_tensor: MemmapTensor) -> torch.Tensor:


def _collate_list_tensordict(x):
out = LazyStackedTensorDict.maybe_dense_stack(x, 0)
out = torch.stack(x, 0)
return out


def _stack_anything(x):
if is_tensor_collection(x[0]):
return LazyStackedTensorDict.maybe_dense_stack(x)
return torch.stack(x)


def _collate_id(x):
return x

Expand Down

0 comments on commit 6427212

Please sign in to comment.