diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py index ae012cadec8..786e5d2ebb0 100644 --- a/sota-implementations/dqn/dqn_atari.py +++ b/sota-implementations/dqn/dqn_atari.py @@ -20,7 +20,7 @@ from torchrl._utils import timeit from torchrl.collectors import SyncDataCollector -from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer +from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.envs import ExplorationType, set_exploration_type from torchrl.modules import EGreedyModule from torchrl.objectives import DQNLoss, HardUpdate @@ -65,15 +65,14 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create the replay buffer if cfg.buffer.scratch_dir in ("", None): - storage_cls = functools.partial(LazyTensorStorage, device=device) - transform = None + storage_cls = LazyMemmapStorage else: storage_cls = functools.partial( - LazyTensorStorage, scratch_dir=cfg.buffer.scratch_dir + LazyMemmapStorage, scratch_dir=cfg.buffer.scratch_dir ) - def transform(td): - return td.to(device) + def transform(td): + return td.to(device) replay_buffer = TensorDictReplayBuffer( pin_memory=False,