From c84828523877b198e35656acbdb1ff1d33ad00f3 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 18 Nov 2024 15:06:40 +0000 Subject: [PATCH] [BugFix] Fix collector length with non-empty batch size ghstack-source-id: 0c6a7a49f0570fad083340a64dd89c0f4c220c06 Pull Request resolved: https://github.com/pytorch/rl/pull/2575 --- test/test_collector.py | 22 ++++++++++++++++++++++ torchrl/collectors/collectors.py | 5 +++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 1309254ce2d..7c185830a92 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -3172,6 +3172,28 @@ def make_and_test_policy( ) +@pytest.mark.parametrize( + "ctype", [SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector] +) +def test_no_stopiteration(ctype): + # Tests that there is no StopIteration raised and that the length of the collector is properly set + if ctype is SyncDataCollector: + envs = SerialEnv(16, CountingEnv) + else: + envs = [SerialEnv(8, CountingEnv), SerialEnv(8, CountingEnv)] + + collector = ctype(create_env_fn=envs, frames_per_batch=173, total_frames=300) + try: + c_iter = iter(collector) + for i in range(len(collector)): # noqa: B007 + c = next(c_iter) + assert c is not None + assert i == 1 + finally: + collector.shutdown() + del collector + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 20128e4f6a2..abf817861ff 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -138,6 +138,7 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): _iterator = None total_frames: int frames_per_batch: int + requested_frames_per_batch: int trust_policy: bool compiled_policy: bool cudagraphed_policy: bool @@ -296,7 +297,7 @@ def __class_getitem__(self, index): def __len__(self) -> int: if self.total_frames > 0: - return -(self.total_frames // -self.frames_per_batch) + return -(self.total_frames // -self.requested_frames_per_batch) raise RuntimeError("Non-terminating collectors do not have a length") @@ -691,7 +692,7 @@ def __init__( remainder = total_frames % frames_per_batch if remainder != 0 and RL_WARNINGS: warnings.warn( - f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch})." + f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). " f"This means {frames_per_batch - remainder} additional frames will be collected." "To silence this message, set the environment variable RL_WARNINGS to False." )