diff --git a/sota-implementations/dqn/config_atari.yaml b/sota-implementations/dqn/config_atari.yaml
index 021e7fd6132..bcbada5dc36 100644
--- a/sota-implementations/dqn/config_atari.yaml
+++ b/sota-implementations/dqn/config_atari.yaml
@@ -7,7 +7,7 @@ env:
 # collector
 collector:
   total_frames: 40_000_100
-  frames_per_batch: 16
+  frames_per_batch: 1600
   eps_start: 1.0
   eps_end: 0.01
   annealing_frames: 4_000_000
@@ -38,7 +38,7 @@ optim:
 loss:
   gamma: 0.99
   hard_update_freq: 10_000
-  num_updates: 1
+  num_updates: 100
 
 compile:
   compile: False
diff --git a/sota-implementations/dqn/config_cartpole.yaml b/sota-implementations/dqn/config_cartpole.yaml
index 58be7fb3bb5..199533ba9be 100644
--- a/sota-implementations/dqn/config_cartpole.yaml
+++ b/sota-implementations/dqn/config_cartpole.yaml
@@ -7,7 +7,7 @@ env:
 # collector
 collector:
   total_frames: 500_100
-  frames_per_batch: 10
+  frames_per_batch: 1000
   eps_start: 1.0
   eps_end: 0.05
   annealing_frames: 250_000
@@ -37,7 +37,7 @@ optim:
 loss:
   gamma: 0.99
   hard_update_freq: 50
-  num_updates: 1
+  num_updates: 100
 
 compile:
   compile: False
diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py
index f6bcf3044cb..ae012cadec8 100644
--- a/sota-implementations/dqn/dqn_atari.py
+++ b/sota-implementations/dqn/dqn_atari.py
@@ -9,7 +9,7 @@
 """
 from __future__ import annotations
 
-import tempfile
+import functools
 import warnings
 
 import hydra
@@ -20,7 +20,7 @@
 
 from torchrl._utils import timeit
 from torchrl.collectors import SyncDataCollector
-from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+from torchrl.data import LazyMemmapStorage, LazyTensorStorage, TensorDictReplayBuffer
 from torchrl.envs import ExplorationType, set_exploration_type
 from torchrl.modules import EGreedyModule
 from torchrl.objectives import DQNLoss, HardUpdate
@@ -64,20 +64,28 @@ def main(cfg: "DictConfig"):  # noqa: F821
     )
 
     # Create the replay buffer
-    if cfg.buffer.scratch_dir is None:
-        tempdir = tempfile.TemporaryDirectory()
-        scratch_dir = tempdir.name
+    # No scratch dir: keep the buffer in memory, directly on the training device.
+    # Otherwise: memory-map it under scratch_dir; a transform moves samples to device.
+    if cfg.buffer.scratch_dir in ("", None):
+        storage_cls = functools.partial(LazyTensorStorage, device=device)
+        transform = None
     else:
-        scratch_dir = cfg.buffer.scratch_dir
+        storage_cls = functools.partial(
+            LazyMemmapStorage, scratch_dir=cfg.buffer.scratch_dir
+        )
+
+        def transform(td):
+            return td.to(device)
+
     replay_buffer = TensorDictReplayBuffer(
         pin_memory=False,
-        prefetch=3,
-        storage=LazyMemmapStorage(
+        storage=storage_cls(
             max_size=cfg.buffer.buffer_size,
-            scratch_dir=scratch_dir,
         ),
         batch_size=cfg.buffer.batch_size,
     )
+    if transform is not None:
+        replay_buffer.append_transform(transform)
 
     # Create the loss module
     loss_module = DQNLoss(
@@ -210,7 +218,6 @@ def update(sampled_tensordict):
         for j in range(num_updates):
             with timeit("rb - sample"):
                 sampled_tensordict = replay_buffer.sample()
-                sampled_tensordict = sampled_tensordict.to(device)
             with timeit("update"):
                 q_loss = update(sampled_tensordict)
             q_losses[j].copy_(q_loss)
diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py
index 69689dd4c92..4fde452fba9 100644
--- a/sota-implementations/dqn/dqn_cartpole.py
+++ b/sota-implementations/dqn/dqn_cartpole.py
@@ -55,11 +55,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
     # Create the replay buffer
     replay_buffer = TensorDictReplayBuffer(
         pin_memory=False,
-        prefetch=10,
-        storage=LazyTensorStorage(
-            max_size=cfg.buffer.buffer_size,
-            device="cpu",
-        ),
+        storage=LazyTensorStorage(max_size=cfg.buffer.buffer_size, device=device),
         batch_size=cfg.buffer.batch_size,
     )
 