Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 17, 2024
1 parent e78ab4a commit 7c044b3
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions sota-implementations/dqn/dqn_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torchrl._utils import timeit

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.envs import ExplorationType, set_exploration_type
from torchrl.modules import EGreedyModule
from torchrl.objectives import DQNLoss, HardUpdate
Expand Down Expand Up @@ -65,15 +65,14 @@ def main(cfg: "DictConfig"): # noqa: F821

# Create the replay buffer
if cfg.buffer.scratch_dir in ("", None):
storage_cls = functools.partial(LazyTensorStorage, device=device)
transform = None
storage_cls = LazyMemmapStorage
else:
storage_cls = functools.partial(
LazyTensorStorage, scratch_dir=cfg.buffer.scratch_dir
LazyMemmapStorage, scratch_dir=cfg.buffer.scratch_dir
)

def transform(td):
return td.to(device)
def transform(td):
return td.to(device)

replay_buffer = TensorDictReplayBuffer(
pin_memory=False,
Expand Down

0 comments on commit 7c044b3

Please sign in to comment.