Commit b3712ea

kurtamohler authored and Vincent Moens committed
[Performance] Improve performance of compiled ReplayBuffer (#2529)
(cherry picked from commit 2a07f4c)
1 parent c989891 commit b3712ea

File tree

6 files changed: +376 additions, -66 deletions


benchmarks/test_replaybuffer_benchmark.py

Lines changed: 76 additions & 0 deletions
@@ -172,6 +172,82 @@ def test_rb_populate(benchmark, rb, storage, sampler, size):
     )
 
 
+class create_compiled_tensor_rb:
+    def __init__(
+        self, rb, storage, sampler, storage_size, data_size, iters, compilable=False
+    ):
+        self.storage = storage
+        self.rb = rb
+        self.sampler = sampler
+        self.storage_size = storage_size
+        self.data_size = data_size
+        self.iters = iters
+        self.compilable = compilable
+
+    def __call__(self):
+        kwargs = {}
+        if self.sampler is not None:
+            kwargs["sampler"] = self.sampler()
+        if self.storage is not None:
+            kwargs["storage"] = self.storage(
+                self.storage_size, compilable=self.compilable
+            )
+
+        rb = self.rb(batch_size=3, compilable=self.compilable, **kwargs)
+        data = torch.randn(self.data_size, 1)
+        return ((rb, data, self.iters), {})
+
+
+def extend_and_sample(rb, td, iters):
+    for _ in range(iters):
+        rb.extend(td)
+        rb.sample()
+
+
+def extend_and_sample_compiled(rb, td, iters):
+    @torch.compile
+    def fn(td):
+        rb.extend(td)
+        rb.sample()
+
+    for _ in range(iters):
+        fn(td)
+
+
+@pytest.mark.parametrize(
+    "rb,storage,sampler,storage_size,data_size,iters,compiled",
+    [
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 10_000, 10_000, 100, True],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 10_000, 10_000, 100, False],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 100_000, 10_000, 100, True],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 100_000, 10_000, 100, False],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 1_000_000, 10_000, 100, True],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 1_000_000, 10_000, 100, False],
+    ],
+)
+def test_rb_extend_sample(
+    benchmark, rb, storage, sampler, storage_size, data_size, iters, compiled
+):
+    if compiled:
+        torch._dynamo.reset_code_caches()
+
+    benchmark.pedantic(
+        extend_and_sample_compiled if compiled else extend_and_sample,
+        setup=create_compiled_tensor_rb(
+            rb=rb,
+            storage=storage,
+            sampler=sampler,
+            storage_size=storage_size,
+            data_size=data_size,
+            iters=iters,
+            compilable=compiled,
+        ),
+        iterations=1,
+        warmup_rounds=10,
+        rounds=50,
+    )
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
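
For context, the compiled path that these parametrizations time reduces to the following standalone sketch (not part of the commit; the import paths assume the public torchrl.data namespace and a PyTorch version recent enough for torch.compile, as required by the tests below):

import torch
from torchrl.data import LazyTensorStorage, RandomSampler, ReplayBuffer

# Build a buffer whose storage and writer are marked compilable, mirroring the
# setup object above (batch_size=3, tensor data of shape [data_size, 1]).
storage_size, data_size = 100_000, 10_000
rb = ReplayBuffer(
    storage=LazyTensorStorage(storage_size, compilable=True),
    sampler=RandomSampler(),
    batch_size=3,
    compilable=True,
)
data = torch.randn(data_size, 1)

@torch.compile
def extend_and_sample(data):
    # Same body as the inner fn() compiled by extend_and_sample_compiled above.
    rb.extend(data)
    rb.sample()

for _ in range(100):  # iters=100 in the benchmark parametrization
    extend_and_sample(data)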

test/test_rb.py

Lines changed: 134 additions & 4 deletions
@@ -171,18 +171,24 @@
 )
 @pytest.mark.parametrize("size", [3, 5, 100])
 class TestComposableBuffers:
-    def _get_rb(self, rb_type, size, sampler, writer, storage):
+    def _get_rb(self, rb_type, size, sampler, writer, storage, compilable=False):
 
         if storage is not None:
-            storage = storage(size)
+            storage = storage(size, compilable=compilable)
 
         sampler_args = {}
         if sampler is samplers.PrioritizedSampler:
             sampler_args = {"max_capacity": size, "alpha": 0.8, "beta": 0.9}
 
         sampler = sampler(**sampler_args)
-        writer = writer()
-        rb = rb_type(storage=storage, sampler=sampler, writer=writer, batch_size=3)
+        writer = writer(compilable=compilable)
+        rb = rb_type(
+            storage=storage,
+            sampler=sampler,
+            writer=writer,
+            batch_size=3,
+            compilable=compilable,
+        )
         return rb
 
     def _get_datum(self, datatype):
@@ -407,6 +413,84 @@ def data_iter():
         ) if cond else contextlib.nullcontext():
             rb.extend(data2)
 
+    @pytest.mark.skipif(
+        TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
+    )
+    # Compiling on Windows requires "cl" compiler to be installed.
+    # <https://github.com/pytorch/pytorch/blob/8231180147a096a703d8891756068c89365292e0/torch/_inductor/cpp_builder.py#L143>
+    # Our Windows CI jobs do not have "cl", so skip this test.
+    @pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile")
+    @pytest.mark.parametrize("avoid_max_size", [False, True])
+    def test_extend_sample_recompile(
+        self, rb_type, sampler, writer, storage, size, datatype, avoid_max_size
+    ):
+        if rb_type is not ReplayBuffer:
+            pytest.skip(
+                "Only replay buffer of type 'ReplayBuffer' is currently supported."
+            )
+        if sampler is not RandomSampler:
+            pytest.skip("Only sampler of type 'RandomSampler' is currently supported.")
+        if storage is not LazyTensorStorage:
+            pytest.skip(
+                "Only storage of type 'LazyTensorStorage' is currently supported."
+            )
+        if writer is not RoundRobinWriter:
+            pytest.skip(
+                "Only writer of type 'RoundRobinWriter' is currently supported."
+            )
+        if datatype == "tensordict":
+            pytest.skip("'tensordict' datatype is not currently supported.")
+
+        torch._dynamo.reset_code_caches()
+
+        # Number of times to extend the replay buffer
+        num_extend = 10
+        data_size = size
+
+        # These two cases are separated because when the max storage size is
+        # reached, the code execution path changes, causing necessary
+        # recompiles.
+        if avoid_max_size:
+            storage_size = (num_extend + 1) * data_size
+        else:
+            storage_size = 2 * data_size
+
+        rb = self._get_rb(
+            rb_type=rb_type,
+            sampler=sampler,
+            writer=writer,
+            storage=storage,
+            size=storage_size,
+            compilable=True,
+        )
+        data = self._get_data(datatype, size=data_size)
+
+        @torch.compile
+        def extend_and_sample(data):
+            rb.extend(data)
+            return rb.sample()
+
+        # NOTE: The first three calls to 'extend' and 'sample' can currently
+        # cause recompilations, so avoid capturing those.
+        num_extend_before_capture = 3
+
+        for _ in range(num_extend_before_capture):
+            extend_and_sample(data)
+
+        try:
+            torch._logging.set_logs(recompiles=True)
+            records = []
+            capture_log_records(records, "torch._dynamo", "recompiles")
+
+            for _ in range(num_extend - num_extend_before_capture):
+                extend_and_sample(data)
+
+        finally:
+            torch._logging.set_logs()
+
+        assert len(rb) == min((num_extend * data_size), storage_size)
+        assert len(records) == 0
+
     def test_sample(self, rb_type, sampler, writer, storage, size, datatype):
         if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows:
             pytest.skip(
@@ -730,6 +814,52 @@ def test_storage_state_dict(self, storage_in, storage_out, init_out, backend):
         s = new_replay_buffer.sample()
         assert (s.exclude("index") == 1).all()
 
+    @pytest.mark.skipif(
+        TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
+    )
+    @pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile")
+    # This test checks if the `torch._dynamo.disable` wrapper around
+    # `TensorStorage._rand_given_ndim` is still necessary.
+    def test__rand_given_ndim_recompile(self):
+        torch._dynamo.reset_code_caches()
+
+        # Number of times to extend the replay buffer
+        num_extend = 10
+        data_size = 100
+        storage_size = (num_extend + 1) * data_size
+        sample_size = 3
+
+        storage = LazyTensorStorage(storage_size, compilable=True)
+        sampler = RandomSampler()
+
+        # Override to avoid the `torch._dynamo.disable` wrapper
+        storage._rand_given_ndim = storage._rand_given_ndim_impl
+
+        @torch.compile
+        def extend_and_sample(data):
+            storage.set(torch.arange(data_size) + len(storage), data)
+            return sampler.sample(storage, sample_size)
+
+        data = torch.randint(100, (data_size, 1))
+
+        try:
+            torch._logging.set_logs(recompiles=True)
+            records = []
+            capture_log_records(records, "torch._dynamo", "recompiles")
+
+            for _ in range(num_extend):
+                extend_and_sample(data)
+
+        finally:
+            torch._logging.set_logs()
+
+        assert len(storage) == num_extend * data_size
+        assert len(records) == 8, (
+            "If this ever decreases, that's probably good news and the "
+            "`torch._dynamo.disable` wrapper around "
+            "`TensorStorage._rand_given_ndim` can be removed."
+        )
+
     @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
     def test_extend_lazystack(self, storage_type):
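
Both new tests share the same recompile-detection pattern: enable the Dynamo "recompiles" log artifact, collect the emitted records, and assert on their count. The sketch below distills that pattern using a plain logging handler in place of TorchRL's capture_log_records test helper; the assumption (not stated in the commit) is that the recompile records propagate up to the "torch._dynamo" logger, which is also the logger name the tests pass to that helper:

import logging

import torch

class RecordCollector(logging.Handler):
    """Collect every log record emitted under the attached logger."""

    def __init__(self):
        super().__init__()
        self.records = []

    def emit(self, record):
        self.records.append(record)

def count_recompiles(compiled_fn, *args, iters=7):
    """Run compiled_fn repeatedly and return how many recompile records were logged."""
    handler = RecordCollector()
    logger = logging.getLogger("torch._dynamo")
    try:
        torch._logging.set_logs(recompiles=True)  # emit one record per recompilation
        logger.addHandler(handler)
        for _ in range(iters):
            compiled_fn(*args)
    finally:
        logger.removeHandler(handler)
        torch._logging.set_logs()  # restore default logging
    return len(handler.records)

In test_extend_sample_recompile the expected count is 0 after a three-call warm-up; in test__rand_given_ndim_recompile it is pinned at 8 so the test flags when the torch._dynamo.disable workaround becomes unnecessary.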

torchrl/_utils.py

Lines changed: 25 additions & 9 deletions
@@ -252,6 +252,11 @@ class implement_for:
     Keyword Args:
         class_method (bool, optional): if ``True``, the function will be written as a class method.
             Defaults to ``False``.
+        compilable (bool, optional): If ``False``, the module import happens
+            only on the first call to the wrapped function. If ``True``, the
+            module import happens when the wrapped function is initialized. This
+            allows the wrapped function to work well with ``torch.compile``.
+            Defaults to ``False``.
 
     Examples:
         >>> @implement_for("gym", "0.13", "0.14")
@@ -290,11 +295,13 @@ def __init__(
         to_version: str = None,
         *,
         class_method: bool = False,
+        compilable: bool = False,
     ):
         self.module_name = module_name
         self.from_version = from_version
         self.to_version = to_version
         self.class_method = class_method
+        self._compilable = compilable
         implement_for._setters.append(self)
 
     @staticmethod
@@ -386,18 +393,27 @@ def __call__(self, fn):
         self.fn = fn
         implement_for._lazy_impl[self.func_name].append(self._call)
 
-        @wraps(fn)
-        def _lazy_call_fn(*args, **kwargs):
-            # first time we call the function, we also do the replacement.
-            # This will cause the imports to occur only during the first call to fn
+        if self._compilable:
+            _call_fn = self._delazify(self.func_name)
 
-            result = self._delazify(self.func_name)(*args, **kwargs)
-            return result
+            if self.class_method:
+                return classmethod(_call_fn)
 
-        if self.class_method:
-            return classmethod(_lazy_call_fn)
+            return _call_fn
+        else:
+
+            @wraps(fn)
+            def _lazy_call_fn(*args, **kwargs):
+                # first time we call the function, we also do the replacement.
+                # This will cause the imports to occur only during the first call to fn
+
+                result = self._delazify(self.func_name)(*args, **kwargs)
+                return result
+
+            if self.class_method:
+                return classmethod(_lazy_call_fn)
 
-        return _lazy_call_fn
+            return _lazy_call_fn
 
     def _call(self):
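
The net effect of the new flag: with compilable=True, __call__ resolves the versioned implementation once, at decoration time, instead of returning the _lazy_call_fn wrapper that re-dispatches on the first call, so torch.compile never has to trace through the lazy dispatch. A hedged usage sketch (the "gym" module name and version bounds mirror the docstring example above and are illustrative only):

from torchrl._utils import implement_for

class Backend:
    # With compilable=True the version check and module import run here, when
    # the decorator is applied, and the concrete implementation is returned
    # directly. Without the flag, a thin wrapper defers that work to the first
    # call, which is the wrapper a compiled graph would otherwise trace.
    @implement_for("gym", "0.13", "0.14", compilable=True)
    def reset(self):
        ...

    @implement_for("gym", "0.14", None, compilable=True)  # noqa: F811
    def reset(self):
        ...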

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 23 additions & 4 deletions
@@ -19,6 +19,11 @@
 
 import torch
 
+try:
+    from torch.compiler import is_dynamo_compiling
+except ImportError:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
+
 from tensordict import (
     is_tensor_collection,
     is_tensorclass,
@@ -132,6 +137,9 @@ class ReplayBuffer:
             .. warning:: As of now, the generator has no effect on the transforms.
         shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
             Defaults to ``False``.
+        compilable (bool, optional): whether the writer is compilable.
+            If ``True``, the writer cannot be shared between multiple processes.
+            Defaults to ``False``.
 
     Examples:
         >>> import torch
@@ -217,11 +225,20 @@ def __init__(
         checkpointer: "StorageCheckpointerBase" | None = None,  # noqa: F821
         generator: torch.Generator | None = None,
         shared: bool = False,
+        compilable: bool = None,
     ) -> None:
-        self._storage = storage if storage is not None else ListStorage(max_size=1_000)
+        self._storage = (
+            storage
+            if storage is not None
+            else ListStorage(max_size=1_000, compilable=compilable)
+        )
         self._storage.attach(self)
         self._sampler = sampler if sampler is not None else RandomSampler()
-        self._writer = writer if writer is not None else RoundRobinWriter()
+        self._writer = (
+            writer
+            if writer is not None
+            else RoundRobinWriter(compilable=bool(compilable))
+        )
         self._writer.register_storage(self._storage)
 
         self._get_collate_fn(collate_fn)
@@ -600,7 +617,9 @@ def _add(self, data):
         return index
 
     def _extend(self, data: Sequence) -> torch.Tensor:
-        with self._replay_lock, self._write_lock:
+        is_compiling = is_dynamo_compiling()
+        nc = contextlib.nullcontext()
+        with self._replay_lock if not is_compiling else nc, self._write_lock if not is_compiling else nc:
             if self.dim_extend > 0:
                 data = self._transpose(data)
             index = self._writer.extend(data)
@@ -653,7 +672,7 @@ def update_priority(
 
     @pin_memory_output
     def _sample(self, batch_size: int) -> Tuple[Any, dict]:
-        with self._replay_lock:
+        with self._replay_lock if not is_dynamo_compiling() else contextlib.nullcontext():
             index, info = self._sampler.sample(self._storage, batch_size)
             info["index"] = index
             data = self._storage.get(index)
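
The _extend/_sample change swaps the replay and write locks for a null context whenever Dynamo is tracing, presumably because acquiring a threading lock inside a compiled region would not trace cleanly. A standalone sketch of the same pattern (the buffer and lock below are illustrative, not TorchRL code):

import contextlib
import threading

try:
    from torch.compiler import is_dynamo_compiling
except ImportError:  # older torch builds: same fallback as the commit
    from torch._dynamo import is_compiling as is_dynamo_compiling

_lock = threading.RLock()
_buffer = []

def guarded_append(value):
    # Real lock in eager mode; no-op context manager while compiling, so the
    # traced graph never tries to acquire the lock.
    cm = _lock if not is_dynamo_compiling() else contextlib.nullcontext()
    with cm:
        _buffer.append(value)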
