From e440b42b34119e56c04da5a1d91d9a16d3b0e9bb Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Fri, 18 Oct 2024 14:15:32 -0700 Subject: [PATCH 1/7] Update [ghstack-poisoned] --- test/_utils_internal.py | 27 +++++++++++++++++ test/test_rb.py | 64 ++++++++++++++++++++++++++++++++++++++++- test/test_utils.py | 27 ++++++++++++++++- 3 files changed, 116 insertions(+), 2 deletions(-) diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 51535afa606..48492459315 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -5,10 +5,12 @@ from __future__ import annotations import contextlib +import logging import os import os.path import time +import unittest from functools import wraps # Get relative file path @@ -204,6 +206,31 @@ def f_retry(*args, **kwargs): return deco_retry +# After calling this function, any log record whose name contains 'record_name' +# and is emitted from the logger that has qualified name 'logger_qname' is +# appended to the 'records' list. +# NOTE: This function is based on testing utilities for 'torch._logging' +def capture_log_records(records, logger_qname, record_name): + assert isinstance(records, list) + logger = logging.getLogger(logger_qname) + + class EmitWrapper: + def __init__(self, old_emit): + self.old_emit = old_emit + + def __call__(self, record): + nonlocal records + self.old_emit(record) + if record_name in record.name: + records.append(record) + + for handler in logger.handlers: + new_emit = EmitWrapper(handler.emit) + contextlib.ExitStack().enter_context( + unittest.mock.patch.object(handler, "emit", new_emit) + ) + + @pytest.fixture def dtype_fixture(): dtype = torch.get_default_dtype() diff --git a/test/test_rb.py b/test/test_rb.py index 24b33f89795..d2041930364 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -17,7 +17,12 @@ import pytest import torch -from _utils_internal import CARTPOLE_VERSIONED, get_default_devices, make_tc +from _utils_internal import ( + capture_log_records, + CARTPOLE_VERSIONED, + get_default_devices, + make_tc, +) from mocking_classes import CountingEnv from packaging import version @@ -399,6 +404,63 @@ def data_iter(): ) if cond else contextlib.nullcontext(): rb.extend(data2) + def test_extend_recompile(self, rb_type, sampler, writer, storage, size, datatype): + if rb_type is not ReplayBuffer: + pytest.skip( + "Only replay buffer of type 'ReplayBuffer' is currently supported." + ) + if sampler in (PrioritizedSampler,): + pytest.skip(f"Sampler of type '{sampler.__name__}' is not yet supported.") + if storage is not LazyTensorStorage: + pytest.skip( + "Only storage of type 'LazyTensorStorage' is currently supported." + ) + if writer is not RoundRobinWriter: + pytest.skip( + "Only writer of type 'RoundRobinWriter' is currently supported." + ) + + torch.compiler.reset() + + storage_size = 10 * size + rb = self._get_rb( + rb_type=rb_type, + sampler=sampler, + writer=writer, + storage=storage, + size=storage_size, + ) + data_size = size + data = self._get_data(datatype, size=data_size) + + @torch.compile + def extend(data): + rb.extend(data) + + # Number of times to extend the replay buffer + num_extend = 30 + + # NOTE: The first two calls to 'extend' currently cause recompilations, + # so avoid capturing those for now. 
+ num_extend_before_capture = 2 + + for _ in range(num_extend_before_capture): + extend(data) + + try: + torch._logging.set_logs(recompiles=True) + records = [] + capture_log_records(records, "torch._dynamo", "recompiles") + + for _ in range(num_extend - num_extend_before_capture): + extend(data) + + assert len(records) == 0 + assert len(rb) == storage_size + + finally: + torch._logging.set_logs() + def test_sample(self, rb_type, sampler, writer, storage, size, datatype): if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows: pytest.skip( diff --git a/test/test_utils.py b/test/test_utils.py index 4224a36b54f..af5dc09985c 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -14,7 +14,7 @@ import torch -from _utils_internal import get_default_devices +from _utils_internal import capture_log_records, get_default_devices from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend @@ -380,6 +380,31 @@ def test_rng_decorator(device): torch.testing.assert_close(s0b, s1b) +# Check that 'capture_log_records' captures records emitted when torch +# recompiles a function. +def test_capture_log_records_recompile(): + torch.compiler.reset() + + # This function recompiles each time it is called with a different string + # input. + @torch.compile + def str_to_tensor(s): + return bytes(s, "utf8") + + str_to_tensor("a") + + try: + torch._logging.set_logs(recompiles=True) + records = [] + capture_log_records(records, "torch._dynamo", "recompiles") + str_to_tensor("b") + + finally: + torch._logging.set_logs() + + assert len(records) == 1 + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) From 7a264c10466ac0d71f5282a4a0b1fa008f02e158 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Wed, 23 Oct 2024 12:51:20 -0700 Subject: [PATCH 2/7] Update [ghstack-poisoned] --- torchrl/data/replay_buffers/storages.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 217229b5d9b..c71ad56f554 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -146,10 +146,14 @@ def _empty(self): def _rand_given_ndim(self, batch_size): # a method to return random indices given the storage ndim + if isinstance(self, TensorStorage): + storage_len = self._len + else: + storage_len = len(self) if self.ndim == 1: return torch.randint( 0, - self._len, + storage_len, (batch_size,), generator=self._rng, device=getattr(self, "device", None), From 9c78d0a6a8ee9baba8406e77acaa688c247f8c6d Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Wed, 23 Oct 2024 14:11:39 -0700 Subject: [PATCH 3/7] Update [ghstack-poisoned] --- test/test_rb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_rb.py b/test/test_rb.py index 9db96e5d8c0..d451bcef9c4 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -429,7 +429,7 @@ def test_extend_sample_recompile( if datatype == "tensordict": pytest.skip("'tensordict' datatype is not currently supported.") - torch.compiler.reset() + torch._dynamo.reset_code_caches() storage_size = 10 * size rb = self._get_rb( From 289022da4fa21d167e1feff1a92edf93eb6f8a3a Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Wed, 23 Oct 2024 14:14:13 -0700 Subject: [PATCH 4/7] Update [ghstack-poisoned] --- test/test_rb.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/test/test_rb.py b/test/test_rb.py index d451bcef9c4..c08fb8e2d18 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -429,7 +429,7 @@ def test_extend_sample_recompile( if datatype == "tensordict": pytest.skip("'tensordict' datatype is not currently supported.") - torch._dynamo.reset_code_caches() + torch._dynamo.reset() storage_size = 10 * size rb = self._get_rb( From 6750343642c18c602433b9b04ea45a8c6b8573f7 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Wed, 23 Oct 2024 14:57:27 -0700 Subject: [PATCH 5/7] Update [ghstack-poisoned] --- torchrl/data/replay_buffers/storages.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index c71ad56f554..736776006ab 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -144,16 +144,19 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: def _empty(self): ... + # NOTE: This property is used to enable compiled Storages. A `len(self)` + # call can cause recompiles, but for some reason, wrapping the call in a + # `property` decorated function avoids the recompiles. + @property + def len(self): + return len(self) + def _rand_given_ndim(self, batch_size): # a method to return random indices given the storage ndim - if isinstance(self, TensorStorage): - storage_len = self._len - else: - storage_len = len(self) if self.ndim == 1: return torch.randint( 0, - storage_len, + self.len, (batch_size,), generator=self._rng, device=getattr(self, "device", None), From 08cb766d620583aa111c7367c912720b1a94b523 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Wed, 23 Oct 2024 17:58:30 -0700 Subject: [PATCH 6/7] Update [ghstack-poisoned] --- test/test_rb.py | 15 +++++++++------ torchrl/data/replay_buffers/storages.py | 16 +++++++++++++--- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/test/test_rb.py b/test/test_rb.py index c08fb8e2d18..490dd1f56df 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -116,6 +116,7 @@ _has_gym = importlib.util.find_spec("gym") is not None _has_snapshot = importlib.util.find_spec("torchsnapshot") is not None _os_is_windows = sys.platform == "win32" +TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) torch_2_3 = version.parse( ".".join([str(s) for s in version.parse(str(torch.__version__)).release]) @@ -404,14 +405,16 @@ def data_iter(): ) if cond else contextlib.nullcontext(): rb.extend(data2) + @pytest.mark.skipif( + TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0" + ) + # Compiling on Windows requires "cl" compiler to be installed. + # + # Our Windows CI jobs do not have "cl", so skip this test. + @pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile") def test_extend_sample_recompile( self, rb_type, sampler, writer, storage, size, datatype ): - if _os_is_windows: - # Compiling on Windows requires "cl" compiler to be installed. - # - # Our Windows CI jobs do not have "cl", so skip this test. - pytest.skip("This test does not support Windows.") if rb_type is not ReplayBuffer: pytest.skip( "Only replay buffer of type 'ReplayBuffer' is currently supported." 
@@ -429,7 +432,7 @@ def test_extend_sample_recompile( if datatype == "tensordict": pytest.skip("'tensordict' datatype is not currently supported.") - torch._dynamo.reset() + torch._dynamo.reset_code_caches() storage_size = 10 * size rb = self._get_rb( diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 736776006ab..beab68971b5 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -144,9 +144,19 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: def _empty(self): ... - # NOTE: This property is used to enable compiled Storages. A `len(self)` - # call can cause recompiles, but for some reason, wrapping the call in a - # `property` decorated function avoids the recompiles. + # NOTE: This property is used to enable compiled Storages. Calling + # `len(self)` on a TensorStorage should normally cause a graph break since + # it uses a `mp.Value`, and it does cause a break when the `len(self)` call + # happens within a method of TensorStorage itself. However, when the + # `len(self)` call happens in the Storage base class, for an unknown reason + # the compiler doesn't seem to recognize that there should be a graph break, + # and the lack of a break causes a recompile each time `len(self)` is called + # in this context. Also for an unknown reason, we can force the graph break + # to happen if we wrap the `len(self)` call with a `property`-decorated + # function. For another unknown reason, if we change + # `TensorStorage._len_value` from `mp.Value` to int, it seems like there + # should no longer be any need to recompile, but recompiles happen anyway. + # Ideally, this should all be investigated and understood in the future. @property def len(self): return len(self) From fff3e79e052e1f08598448fbac7e1eaf6fb26f2d Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Thu, 24 Oct 2024 10:34:43 -0700 Subject: [PATCH 7/7] Update [ghstack-poisoned] --- test/test_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/test_utils.py b/test/test_utils.py index af5dc09985c..6537c19ff54 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -15,10 +15,13 @@ import torch from _utils_internal import capture_log_records, get_default_devices +from packaging import version from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend +TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) + @pytest.mark.parametrize("value", ["True", "1", "true"]) def test_get_binary_env_var_positive(value): @@ -382,6 +385,9 @@ def test_rng_decorator(device): # Check that 'capture_log_records' captures records emitted when torch # recompiles a function. +@pytest.mark.skipif( + TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0" +) def test_capture_log_records_recompile(): torch.compiler.reset()
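
A minimal sketch of the pattern these patches rely on (not part of the patches themselves): both new tests reset the compiler, warm up the compiled function, enable dynamo's "recompiles" artifact logging, hook the log handlers with capture_log_records, run the workload, and then count the captured records. The snippet below condenses that flow, assuming torch >= 2.5 and that capture_log_records from test/_utils_internal.py is importable; it reuses the string-keyed recompile trick from test_capture_log_records_recompile above.

    import torch

    from _utils_internal import capture_log_records

    torch.compiler.reset()

    # Compiling over a plain Python string specializes on its value, so each
    # previously unseen string argument forces a recompile.
    @torch.compile
    def fn(s):
        return bytes(s, "utf8")

    fn("a")  # first call: initial compilation, not counted as a recompile

    try:
        # Route dynamo's "recompiles" artifact logs through standard logging,
        # then hook the handlers so matching records are collected.
        torch._logging.set_logs(recompiles=True)
        records = []
        capture_log_records(records, "torch._dynamo", "recompiles")

        fn("a")  # cache hit: no recompile expected
        fn("b")  # new string value: exactly one recompile expected
    finally:
        # Restore default logging so later tests are unaffected.
        torch._logging.set_logs()

    assert len(records) == 1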