From 7c044b37f3e9837963fd55ae39c8060353c0dcdb Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 17 Dec 2024 16:32:56 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- sota-implementations/dqn/dqn_atari.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py index ae012cadec8..786e5d2ebb0 100644 --- a/sota-implementations/dqn/dqn_atari.py +++ b/sota-implementations/dqn/dqn_atari.py @@ -20,7 +20,7 @@ from torchrl._utils import timeit from torchrl.collectors import SyncDataCollector -from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer +from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.envs import ExplorationType, set_exploration_type from torchrl.modules import EGreedyModule from torchrl.objectives import DQNLoss, HardUpdate @@ -65,15 +65,14 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create the replay buffer if cfg.buffer.scratch_dir in ("", None): - storage_cls = functools.partial(LazyTensorStorage, device=device) - transform = None + storage_cls = LazyMemmapStorage else: storage_cls = functools.partial( - LazyTensorStorage, scratch_dir=cfg.buffer.scratch_dir + LazyMemmapStorage, scratch_dir=cfg.buffer.scratch_dir ) - def transform(td): - return td.to(device) + def transform(td): + return td.to(device) replay_buffer = TensorDictReplayBuffer( pin_memory=False,