From 58184bdc809b6f2ff4c6659a2d9de392abbbcf6f Mon Sep 17 00:00:00 2001
From: Kurt Mohler
Date: Mon, 21 Oct 2024 12:20:50 -0700
Subject: [PATCH] [Feature] Add test for recompiles of `ReplayBuffer.extend`

ghstack-source-id: 9f3ab17a572ffd28a30ad4dd46305b2face65bef
Pull Request resolved: https://github.com/pytorch/rl/pull/2504
---
 test/_utils_internal.py | 27 ++++++++++++++++
 test/test_rb.py         | 69 ++++++++++++++++++++++++++++++++++++++++-
 test/test_utils.py      | 27 +++++++++++++++-
 3 files changed, 121 insertions(+), 2 deletions(-)

diff --git a/test/_utils_internal.py b/test/_utils_internal.py
index 51535afa606..48492459315 100644
--- a/test/_utils_internal.py
+++ b/test/_utils_internal.py
@@ -5,10 +5,12 @@
 from __future__ import annotations
 
 import contextlib
+import logging
 import os
 import os.path
 import time
+import unittest
 
 from functools import wraps
 
 # Get relative file path
@@ -204,6 +206,31 @@ def f_retry(*args, **kwargs):
     return deco_retry
 
 
+# After calling this function, any log record whose name contains 'record_name'
+# and is emitted from the logger that has qualified name 'logger_qname' is
+# appended to the 'records' list.
+# NOTE: This function is based on testing utilities for 'torch._logging'
+def capture_log_records(records, logger_qname, record_name):
+    assert isinstance(records, list)
+    logger = logging.getLogger(logger_qname)
+
+    class EmitWrapper:
+        def __init__(self, old_emit):
+            self.old_emit = old_emit
+
+        def __call__(self, record):
+            nonlocal records
+            self.old_emit(record)
+            if record_name in record.name:
+                records.append(record)
+
+    for handler in logger.handlers:
+        new_emit = EmitWrapper(handler.emit)
+        contextlib.ExitStack().enter_context(
+            unittest.mock.patch.object(handler, "emit", new_emit)
+        )
+
+
 @pytest.fixture
 def dtype_fixture():
     dtype = torch.get_default_dtype()
diff --git a/test/test_rb.py b/test/test_rb.py
index 0e10f534728..b21614376a6 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -17,7 +17,12 @@
 
 import pytest
 import torch
 
-from _utils_internal import CARTPOLE_VERSIONED, get_default_devices, make_tc
+from _utils_internal import (
+    capture_log_records,
+    CARTPOLE_VERSIONED,
+    get_default_devices,
+    make_tc,
+)
 from mocking_classes import CountingEnv
 from packaging import version
@@ -399,6 +404,68 @@
         ) if cond else contextlib.nullcontext():
             rb.extend(data2)
 
+    def test_extend_recompile(self, rb_type, sampler, writer, storage, size, datatype):
+        if _os_is_windows:
+            # Compiling on Windows requires "cl" compiler to be installed.
+            #
+            # Our Windows CI jobs do not have "cl", so skip this test.
+            pytest.skip("This test does not support Windows.")
+        if rb_type is not ReplayBuffer:
+            pytest.skip(
+                "Only replay buffer of type 'ReplayBuffer' is currently supported."
+            )
+        if sampler in (PrioritizedSampler,):
+            pytest.skip(f"Sampler of type '{sampler.__name__}' is not yet supported.")
+        if storage is not LazyTensorStorage:
+            pytest.skip(
+                "Only storage of type 'LazyTensorStorage' is currently supported."
+            )
+        if writer is not RoundRobinWriter:
+            pytest.skip(
+                "Only writer of type 'RoundRobinWriter' is currently supported."
+            )
+
+        torch.compiler.reset()
+
+        storage_size = 10 * size
+        rb = self._get_rb(
+            rb_type=rb_type,
+            sampler=sampler,
+            writer=writer,
+            storage=storage,
+            size=storage_size,
+        )
+        data_size = size
+        data = self._get_data(datatype, size=data_size)
+
+        @torch.compile
+        def extend(data):
+            rb.extend(data)
+
+        # Number of times to extend the replay buffer
+        num_extend = 30
+
+        # NOTE: The first two calls to 'extend' currently cause recompilations,
+        # so avoid capturing those for now.
+        num_extend_before_capture = 2
+
+        for _ in range(num_extend_before_capture):
+            extend(data)
+
+        try:
+            torch._logging.set_logs(recompiles=True)
+            records = []
+            capture_log_records(records, "torch._dynamo", "recompiles")
+
+            for _ in range(num_extend - num_extend_before_capture):
+                extend(data)
+
+            assert len(records) == 0
+            assert len(rb) == storage_size
+
+        finally:
+            torch._logging.set_logs()
+
     def test_sample(self, rb_type, sampler, writer, storage, size, datatype):
         if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows:
             pytest.skip(
diff --git a/test/test_utils.py b/test/test_utils.py
index 4224a36b54f..af5dc09985c 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -14,7 +14,7 @@
 
 import torch
 
-from _utils_internal import get_default_devices
+from _utils_internal import capture_log_records, get_default_devices
 from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for
 
 from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend
@@ -380,6 +380,31 @@
     torch.testing.assert_close(s0b, s1b)
 
 
+# Check that 'capture_log_records' captures records emitted when torch
+# recompiles a function.
+def test_capture_log_records_recompile():
+    torch.compiler.reset()
+
+    # This function recompiles each time it is called with a different string
+    # input.
+    @torch.compile
+    def str_to_tensor(s):
+        return bytes(s, "utf8")
+
+    str_to_tensor("a")
+
+    try:
+        torch._logging.set_logs(recompiles=True)
+        records = []
+        capture_log_records(records, "torch._dynamo", "recompiles")
+        str_to_tensor("b")
+
+    finally:
+        torch._logging.set_logs()
+
+    assert len(records) == 1
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
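
Usage sketch (not part of the applied diff): a minimal, self-contained example
of the capture pattern these tests rely on. It uses only the
'capture_log_records' helper added by this patch plus public 'torch' APIs;
'double_it' and the tensor shapes are hypothetical, invented for illustration,
and the snippet assumes it runs from torchrl's 'test/' directory so that
'_utils_internal' is importable.

    import torch

    from _utils_internal import capture_log_records

    torch.compiler.reset()

    @torch.compile
    def double_it(x):  # hypothetical compiled function
        return x * 2

    # Warm up: the first call compiles, which is expected and is not a
    # "recompile", so it should not be captured.
    double_it(torch.ones(3))

    try:
        # Make dynamo emit a log record on every recompilation ...
        torch._logging.set_logs(recompiles=True)
        records = []
        # ... and capture records from the 'torch._dynamo' logger whose
        # name contains 'recompiles'.
        capture_log_records(records, "torch._dynamo", "recompiles")

        # Same input rank: a cache hit, so no recompile and no record.
        double_it(torch.ones(3))
        assert len(records) == 0

        # A different input rank forces one recompile, which is captured.
        double_it(torch.ones(3, 3))
        assert len(records) == 1
    finally:
        # Always restore default logging settings.
        torch._logging.set_logs()

The ordering matters: 'set_logs(recompiles=True)' installs the handler on the
'torch._dynamo' logger, and 'capture_log_records' can only wrap handlers that
already exist on that logger, which is why the tests in this patch call
'set_logs' first.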