From 5e03a5518b00bdfb05da4d6b9506fbe30a6a7809 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Thu, 24 Oct 2024 18:50:18 -0700 Subject: [PATCH] [Benchmark] Add benchmark for compiled `ReplayBuffer.extend/sample` ghstack-source-id: d4562697e2c1a8392cf5bdcadb50f8b7b6939e41 Pull Request resolved: https://github.com/pytorch/rl/pull/2514 --- benchmarks/test_replaybuffer_benchmark.py | 60 +++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/benchmarks/test_replaybuffer_benchmark.py b/benchmarks/test_replaybuffer_benchmark.py index c10e7758361..34116ff9703 100644 --- a/benchmarks/test_replaybuffer_benchmark.py +++ b/benchmarks/test_replaybuffer_benchmark.py @@ -13,6 +13,7 @@ LazyMemmapStorage, LazyTensorStorage, ListStorage, + ReplayBuffer, TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer, ) @@ -172,6 +173,65 @@ def test_rb_populate(benchmark, rb, storage, sampler, size): ) +class create_tensor_rb: + def __init__(self, rb, storage, sampler, size=1_000_000, iters=100): + self.storage = storage + self.rb = rb + self.sampler = sampler + self.size = size + self.iters = iters + + def __call__(self): + kwargs = {} + if self.sampler is not None: + kwargs["sampler"] = self.sampler() + if self.storage is not None: + kwargs["storage"] = self.storage(10 * self.size) + + rb = self.rb(batch_size=3, **kwargs) + data = torch.randn(self.size, 1) + return ((rb, data, self.iters), {}) + + +def extend_and_sample(rb, td, iters): + for _ in range(iters): + rb.extend(td) + rb.sample() + + +def extend_and_sample_compiled(rb, td, iters): + @torch.compile + def fn(td): + rb.extend(td) + rb.sample() + + for _ in range(iters): + fn(td) + + +@pytest.mark.parametrize( + "rb,storage,sampler,size,iters,compiled", + [ + [ReplayBuffer, LazyTensorStorage, RandomSampler, 1000, 100, True], + [ReplayBuffer, LazyTensorStorage, RandomSampler, 1000, 100, False], + ], +) +def test_rb_extend_sample(benchmark, rb, storage, sampler, size, iters, compiled): + benchmark.pedantic( + extend_and_sample_compiled if compiled else extend_and_sample, + setup=create_tensor_rb( + rb=rb, + storage=storage, + sampler=sampler, + size=size, + iters=iters, + ), + iterations=1, + warmup_rounds=10, + rounds=50, + ) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)