
Commit e78ab4a

Update
[ghstack-poisoned]
vmoens committed Dec 17, 2024
1 parent 1a8360b commit e78ab4a
Showing 4 changed files with 20 additions and 19 deletions.
4 changes: 2 additions & 2 deletions sota-implementations/dqn/config_atari.yaml
```diff
@@ -7,7 +7,7 @@ env:
 # collector
 collector:
   total_frames: 40_000_100
-  frames_per_batch: 16
+  frames_per_batch: 1600
   eps_start: 1.0
   eps_end: 0.01
   annealing_frames: 4_000_000
@@ -38,7 +38,7 @@ optim:
 loss:
   gamma: 0.99
   hard_update_freq: 10_000
-  num_updates: 1
+  num_updates: 100
 
 compile:
   compile: False
```
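For context on the Atari config change above: `frames_per_batch` grows from 16 to 1600 while `num_updates` grows from 1 to 100, so each collector iteration gathers a much larger batch but the number of gradient updates per collected frame works out the same. The sketch below only checks that arithmetic; the interpretation is mine, not stated in the commit.

```python
# Quick arithmetic check of the Atari config change (interpretation assumed, not stated in the commit).
old_cfg = {"frames_per_batch": 16, "num_updates": 1}
new_cfg = {"frames_per_batch": 1600, "num_updates": 100}

for name, cfg in (("old", old_cfg), ("new", new_cfg)):
    ratio = cfg["num_updates"] / cfg["frames_per_batch"]
    print(f"{name}: {ratio:.4f} updates per collected frame")
# Both print 0.0625: the update-to-data ratio is unchanged, but each collector
# batch is 100x larger, so sampling and updating happen in bigger chunks.
```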
4 changes: 2 additions & 2 deletions sota-implementations/dqn/config_cartpole.yaml
```diff
@@ -7,7 +7,7 @@ env:
 # collector
 collector:
   total_frames: 500_100
-  frames_per_batch: 10
+  frames_per_batch: 1000
   eps_start: 1.0
   eps_end: 0.05
   annealing_frames: 250_000
@@ -37,7 +37,7 @@ optim:
 loss:
   gamma: 0.99
   hard_update_freq: 50
-  num_updates: 1
+  num_updates: 100
 
 compile:
   compile: False
```
25 changes: 15 additions & 10 deletions sota-implementations/dqn/dqn_atari.py
```diff
@@ -9,7 +9,7 @@
 """
 from __future__ import annotations
 
-import tempfile
+import functools
 import warnings
 
 import hydra
@@ -20,7 +20,7 @@
 from torchrl._utils import timeit
 
 from torchrl.collectors import SyncDataCollector
-from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
 from torchrl.envs import ExplorationType, set_exploration_type
 from torchrl.modules import EGreedyModule
 from torchrl.objectives import DQNLoss, HardUpdate
@@ -64,20 +64,26 @@ def main(cfg: "DictConfig"): # noqa: F821
     )
 
     # Create the replay buffer
-    if cfg.buffer.scratch_dir is None:
-        tempdir = tempfile.TemporaryDirectory()
-        scratch_dir = tempdir.name
+    if cfg.buffer.scratch_dir in ("", None):
+        storage_cls = functools.partial(LazyTensorStorage, device=device)
+        transform = None
     else:
-        scratch_dir = cfg.buffer.scratch_dir
+        storage_cls = functools.partial(
+            LazyTensorStorage, scratch_dir=cfg.buffer.scratch_dir
+        )
+
+        def transform(td):
+            return td.to(device)
 
     replay_buffer = TensorDictReplayBuffer(
         pin_memory=False,
         prefetch=3,
-        storage=LazyMemmapStorage(
+        storage=storage_cls(
             max_size=cfg.buffer.buffer_size,
-            scratch_dir=scratch_dir,
         ),
         batch_size=cfg.buffer.batch_size,
     )
+    if transform is not None:
+        replay_buffer.append_transform(transform)
 
     # Create the loss module
     loss_module = DQNLoss(
@@ -210,7 +216,6 @@ def update(sampled_tensordict):
         for j in range(num_updates):
             with timeit("rb - sample"):
                 sampled_tensordict = replay_buffer.sample()
-                sampled_tensordict = sampled_tensordict.to(device)
             with timeit("update"):
                 q_loss = update(sampled_tensordict)
             q_losses[j].copy_(q_loss)
```
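A standalone sketch of the replay-buffer pattern the dqn_atari.py diff moves to: when no scratch_dir is given, the storage lives directly on the training device and sampled batches need no extra copy; otherwise the buffer sits on disk and a small callable transform moves each sampled batch to the device. Buffer size, batch size and the prefetch value are placeholders, and the on-disk branch uses LazyMemmapStorage here (the class that documents scratch_dir) rather than the LazyTensorStorage call in the diff, so treat it as an approximation of the script, not the script itself.

```python
import functools

import torch
from torchrl.data import LazyMemmapStorage, LazyTensorStorage, TensorDictReplayBuffer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
scratch_dir = None  # set to a directory path to keep the buffer on disk instead

if scratch_dir in ("", None):
    # Whole buffer lives on `device`; sampled batches come back already on it.
    storage_cls = functools.partial(LazyTensorStorage, device=device)
    transform = None
else:
    # Memory-mapped buffer on disk; move each sampled batch to `device` afterwards.
    storage_cls = functools.partial(LazyMemmapStorage, scratch_dir=scratch_dir)

    def transform(td):
        return td.to(device)

replay_buffer = TensorDictReplayBuffer(
    pin_memory=False,
    prefetch=3,
    storage=storage_cls(max_size=100_000),  # placeholder size
    batch_size=256,  # placeholder batch size
)
if transform is not None:
    replay_buffer.append_transform(transform)
```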
6 changes: 1 addition & 5 deletions sota-implementations/dqn/dqn_cartpole.py
```diff
@@ -55,11 +55,7 @@ def main(cfg: "DictConfig"): # noqa: F821
     # Create the replay buffer
     replay_buffer = TensorDictReplayBuffer(
         pin_memory=False,
-        prefetch=10,
-        storage=LazyTensorStorage(
-            max_size=cfg.buffer.buffer_size,
-            device="cpu",
-        ),
+        storage=LazyTensorStorage(max_size=cfg.buffer.buffer_size, device=device),
         batch_size=cfg.buffer.batch_size,
     )
 
```
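The CartPole change (like the dropped `sampled_tensordict.to(device)` line in the Atari loop) relies on the storage already living on the training device, so sampling returns device-resident tensors. A minimal check, with made-up sizes and key names:

```python
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(max_size=10_000, device=device),
    batch_size=128,
)
# Collected data typically arrives on CPU; extend() writes it into the device storage.
rb.extend(TensorDict({"observation": torch.randn(256, 4)}, batch_size=[256]))

sample = rb.sample()
print(sample["observation"].device)  # already on `device`; no extra .to(device) needed
```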
