diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index 0f72128318d..969c7fc083e 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -23,7 +23,7 @@ from tensordict import TensorDict from tensordict.nn import CudaGraphModule from torchrl.collectors import SyncDataCollector -from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer +from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import set_gym_backend @@ -84,7 +84,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create data buffer data_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(cfg.ppo.collector.frames_per_batch), + storage=LazyTensorStorage(cfg.ppo.collector.frames_per_batch, device=device), sampler=SamplerWithoutReplacement(), batch_size=cfg.ppo.loss.mini_batch_size, ) @@ -134,7 +134,6 @@ def main(cfg: "DictConfig"): # noqa: F821 frames_per_batch=cfg.ppo.collector.frames_per_batch, total_frames=cfg.ppo.collector.total_frames, device=device, - storing_device=device, max_frames_per_traj=-1, compile_policy={"mode": compile_mode} if compile_mode is not None else False, cudagraph_policy=cfg.compile.cudagraphs,