Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 15, 2024
1 parent 23cab41 commit ad1f0a4
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions sota-implementations/gail/gail.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

from torchrl.envs import set_gym_backend
Expand Down Expand Up @@ -84,7 +84,7 @@ def main(cfg: "DictConfig"): # noqa: F821

# Create data buffer
data_buffer = TensorDictReplayBuffer(
storage=LazyMemmapStorage(cfg.ppo.collector.frames_per_batch),
storage=LazyTensorStorage(cfg.ppo.collector.frames_per_batch, device=device),
sampler=SamplerWithoutReplacement(),
batch_size=cfg.ppo.loss.mini_batch_size,
)
Expand Down Expand Up @@ -134,7 +134,6 @@ def main(cfg: "DictConfig"): # noqa: F821
frames_per_batch=cfg.ppo.collector.frames_per_batch,
total_frames=cfg.ppo.collector.total_frames,
device=device,
storing_device=device,
max_frames_per_traj=-1,
compile_policy={"mode": compile_mode} if compile_mode is not None else False,
cudagraph_policy=cfg.compile.cudagraphs,
Expand Down

0 comments on commit ad1f0a4

Please sign in to comment.