Skip to content

Commit

Permalink
[Benchmark] Add benchmark for compiled ReplayBuffer.extend/sample
Browse files Browse the repository at this point in the history
ghstack-source-id: d4562697e2c1a8392cf5bdcadb50f8b7b6939e41
Pull Request resolved: #2514
  • Loading branch information
kurtamohler authored and vmoens committed Oct 25, 2024
1 parent 0f29c7e commit 5e03a55
Showing 1 changed file with 60 additions and 0 deletions.
60 changes: 60 additions & 0 deletions benchmarks/test_replaybuffer_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
LazyMemmapStorage,
LazyTensorStorage,
ListStorage,
ReplayBuffer,
TensorDictPrioritizedReplayBuffer,
TensorDictReplayBuffer,
)
Expand Down Expand Up @@ -172,6 +173,65 @@ def test_rb_populate(benchmark, rb, storage, sampler, size):
)


class create_tensor_rb:
def __init__(self, rb, storage, sampler, size=1_000_000, iters=100):
self.storage = storage
self.rb = rb
self.sampler = sampler
self.size = size
self.iters = iters

def __call__(self):
kwargs = {}
if self.sampler is not None:
kwargs["sampler"] = self.sampler()
if self.storage is not None:
kwargs["storage"] = self.storage(10 * self.size)

rb = self.rb(batch_size=3, **kwargs)
data = torch.randn(self.size, 1)
return ((rb, data, self.iters), {})


def extend_and_sample(rb, td, iters):
for _ in range(iters):
rb.extend(td)
rb.sample()


def extend_and_sample_compiled(rb, td, iters):
@torch.compile
def fn(td):
rb.extend(td)
rb.sample()

for _ in range(iters):
fn(td)


@pytest.mark.parametrize(
"rb,storage,sampler,size,iters,compiled",
[
[ReplayBuffer, LazyTensorStorage, RandomSampler, 1000, 100, True],
[ReplayBuffer, LazyTensorStorage, RandomSampler, 1000, 100, False],
],
)
def test_rb_extend_sample(benchmark, rb, storage, sampler, size, iters, compiled):
benchmark.pedantic(
extend_and_sample_compiled if compiled else extend_and_sample,
setup=create_tensor_rb(
rb=rb,
storage=storage,
sampler=sampler,
size=size,
iters=iters,
),
iterations=1,
warmup_rounds=10,
rounds=50,
)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

1 comment on commit 5e03a55

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 5e03a55 Previous: 0f29c7e Ratio
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 32.505583334513354 iter/sec (stddev: 0.18450450281117275) 249.41595015187497 iter/sec (stddev: 0.0005465164421569596) 7.67

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.