171 | 171 | )
172 | 172 | @pytest.mark.parametrize("size", [3, 5, 100])
173 | 173 | class TestComposableBuffers:
174 |     | -     def _get_rb(self, rb_type, size, sampler, writer, storage):
    | 174 | +     def _get_rb(self, rb_type, size, sampler, writer, storage, compilable=False):
175 | 175 |
176 | 176 |         if storage is not None:
177 |     | -             storage = storage(size)
    | 177 | +             storage = storage(size, compilable=compilable)
178 | 178 |
179 | 179 |         sampler_args = {}
180 | 180 |         if sampler is samplers.PrioritizedSampler:
181 | 181 |             sampler_args = {"max_capacity": size, "alpha": 0.8, "beta": 0.9}
182 | 182 |
183 | 183 |         sampler = sampler(**sampler_args)
184 |     | -         writer = writer()
185 |     | -         rb = rb_type(storage=storage, sampler=sampler, writer=writer, batch_size=3)
    | 184 | +         writer = writer(compilable=compilable)
    | 185 | +         rb = rb_type(
    | 186 | +             storage=storage,
    | 187 | +             sampler=sampler,
    | 188 | +             writer=writer,
    | 189 | +             batch_size=3,
    | 190 | +             compilable=compilable,
    | 191 | +         )
186 | 192 |         return rb
187 | 193 |
188 | 194 |     def _get_datum(self, datatype):
@@ -407,6 +413,84 @@ def data_iter():
407 | 413 |         ) if cond else contextlib.nullcontext():
408 | 414 |             rb.extend(data2)
409 | 415 |
    | 416 | +     @pytest.mark.skipif(
    | 417 | +         TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
    | 418 | +     )
    | 419 | +     # Compiling on Windows requires "cl" compiler to be installed.
    | 420 | +     # <https://github.com/pytorch/pytorch/blob/8231180147a096a703d8891756068c89365292e0/torch/_inductor/cpp_builder.py#L143>
    | 421 | +     # Our Windows CI jobs do not have "cl", so skip this test.
    | 422 | +     @pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile")
    | 423 | +     @pytest.mark.parametrize("avoid_max_size", [False, True])
    | 424 | +     def test_extend_sample_recompile(
    | 425 | +         self, rb_type, sampler, writer, storage, size, datatype, avoid_max_size
    | 426 | +     ):
    | 427 | +         if rb_type is not ReplayBuffer:
    | 428 | +             pytest.skip(
    | 429 | +                 "Only replay buffer of type 'ReplayBuffer' is currently supported."
    | 430 | +             )
    | 431 | +         if sampler is not RandomSampler:
    | 432 | +             pytest.skip("Only sampler of type 'RandomSampler' is currently supported.")
    | 433 | +         if storage is not LazyTensorStorage:
    | 434 | +             pytest.skip(
    | 435 | +                 "Only storage of type 'LazyTensorStorage' is currently supported."
    | 436 | +             )
    | 437 | +         if writer is not RoundRobinWriter:
    | 438 | +             pytest.skip(
    | 439 | +                 "Only writer of type 'RoundRobinWriter' is currently supported."
    | 440 | +             )
    | 441 | +         if datatype == "tensordict":
    | 442 | +             pytest.skip("'tensordict' datatype is not currently supported.")
    | 443 | +
    | 444 | +         torch._dynamo.reset_code_caches()
    | 445 | +
    | 446 | +         # Number of times to extend the replay buffer
    | 447 | +         num_extend = 10
    | 448 | +         data_size = size
    | 449 | +
    | 450 | +         # These two cases are separated because when the max storage size is
    | 451 | +         # reached, the code execution path changes, causing necessary
    | 452 | +         # recompiles.
    | 453 | +         if avoid_max_size:
    | 454 | +             storage_size = (num_extend + 1) * data_size
    | 455 | +         else:
    | 456 | +             storage_size = 2 * data_size
    | 457 | +
    | 458 | +         rb = self._get_rb(
    | 459 | +             rb_type=rb_type,
    | 460 | +             sampler=sampler,
    | 461 | +             writer=writer,
    | 462 | +             storage=storage,
    | 463 | +             size=storage_size,
    | 464 | +             compilable=True,
    | 465 | +         )
    | 466 | +         data = self._get_data(datatype, size=data_size)
    | 467 | +
    | 468 | +         @torch.compile
    | 469 | +         def extend_and_sample(data):
    | 470 | +             rb.extend(data)
    | 471 | +             return rb.sample()
    | 472 | +
    | 473 | +         # NOTE: The first three calls to 'extend' and 'sample' can currently
    | 474 | +         # cause recompilations, so avoid capturing those.
    | 475 | +         num_extend_before_capture = 3
    | 476 | +
    | 477 | +         for _ in range(num_extend_before_capture):
    | 478 | +             extend_and_sample(data)
    | 479 | +
    | 480 | +         try:
    | 481 | +             torch._logging.set_logs(recompiles=True)
    | 482 | +             records = []
    | 483 | +             capture_log_records(records, "torch._dynamo", "recompiles")
    | 484 | +
    | 485 | +             for _ in range(num_extend - num_extend_before_capture):
    | 486 | +                 extend_and_sample(data)
    | 487 | +
    | 488 | +         finally:
    | 489 | +             torch._logging.set_logs()
    | 490 | +
    | 491 | +         assert len(rb) == min((num_extend * data_size), storage_size)
    | 492 | +         assert len(records) == 0
    | 493 | +
410 | 494 |     def test_sample(self, rb_type, sampler, writer, storage, size, datatype):
411 | 495 |         if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows:
412 | 496 |             pytest.skip(
@@ -730,6 +814,52 @@ def test_storage_state_dict(self, storage_in, storage_out, init_out, backend):
730 | 814 |         s = new_replay_buffer.sample()
731 | 815 |         assert (s.exclude("index") == 1).all()
732 | 816 |
    | 817 | +     @pytest.mark.skipif(
    | 818 | +         TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
    | 819 | +     )
    | 820 | +     @pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile")
    | 821 | +     # This test checks if the `torch._dynamo.disable` wrapper around
    | 822 | +     # `TensorStorage._rand_given_ndim` is still necessary.
    | 823 | +     def test__rand_given_ndim_recompile(self):
    | 824 | +         torch._dynamo.reset_code_caches()
    | 825 | +
    | 826 | +         # Number of times to extend the replay buffer
    | 827 | +         num_extend = 10
    | 828 | +         data_size = 100
    | 829 | +         storage_size = (num_extend + 1) * data_size
    | 830 | +         sample_size = 3
    | 831 | +
    | 832 | +         storage = LazyTensorStorage(storage_size, compilable=True)
    | 833 | +         sampler = RandomSampler()
    | 834 | +
    | 835 | +         # Override to avoid the `torch._dynamo.disable` wrapper
    | 836 | +         storage._rand_given_ndim = storage._rand_given_ndim_impl
    | 837 | +
    | 838 | +         @torch.compile
    | 839 | +         def extend_and_sample(data):
    | 840 | +             storage.set(torch.arange(data_size) + len(storage), data)
    | 841 | +             return sampler.sample(storage, sample_size)
    | 842 | +
    | 843 | +         data = torch.randint(100, (data_size, 1))
    | 844 | +
    | 845 | +         try:
    | 846 | +             torch._logging.set_logs(recompiles=True)
    | 847 | +             records = []
    | 848 | +             capture_log_records(records, "torch._dynamo", "recompiles")
    | 849 | +
    | 850 | +             for _ in range(num_extend):
    | 851 | +                 extend_and_sample(data)
    | 852 | +
    | 853 | +         finally:
    | 854 | +             torch._logging.set_logs()
    | 855 | +
    | 856 | +         assert len(storage) == num_extend * data_size
    | 857 | +         assert len(records) == 8, (
    | 858 | +             "If this ever decreases, that's probably good news and the "
    | 859 | +             "`torch._dynamo.disable` wrapper around "
    | 860 | +             "`TensorStorage._rand_given_ndim` can be removed."
    | 861 | +         )
    | 862 | +
733 | 863 |     @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
734 | 864 |     def test_extend_lazystack(self, storage_type):
735 | 865 |
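
For context on what the new `compilable` flag enables, below is a minimal usage sketch (not part of the diff) mirroring the pattern the new test compiles: the buffer, storage, and writer are all constructed with `compilable=True`, and a single `torch.compile`-wrapped function both extends and samples the buffer. The capacity, batch size, and data shape are illustrative, and Torch >= 2.5.0 is assumed as in the test's skip condition.

import torch
from torchrl.data.replay_buffers import (
    LazyTensorStorage,
    RandomSampler,
    ReplayBuffer,
    RoundRobinWriter,
)

# Build a buffer whose components are all marked compilable, as the updated
# _get_rb helper does when compilable=True is passed.
rb = ReplayBuffer(
    storage=LazyTensorStorage(1000, compilable=True),  # capacity is illustrative
    sampler=RandomSampler(),
    writer=RoundRobinWriter(compilable=True),
    batch_size=3,
    compilable=True,
)

@torch.compile
def extend_and_sample(data):
    # Extending and sampling inside one compiled function is the path the
    # new test watches for recompiles.
    rb.extend(data)
    return rb.sample()

sample = extend_and_sample(torch.randn(100, 1))  # illustrative data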