Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 18, 2024
2 parents b09e334 + b5b0b09 commit 3cf986c
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 25 deletions.
22 changes: 0 additions & 22 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3172,28 +3172,6 @@ def make_and_test_policy(
)


@pytest.mark.parametrize(
"ctype", [SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector]
)
def test_no_stopiteration(ctype):
# Tests that there is no StopIteration raised and that the length of the collector is properly set
if ctype is SyncDataCollector:
envs = SerialEnv(16, CountingEnv)
else:
envs = [SerialEnv(8, CountingEnv), SerialEnv(8, CountingEnv)]

collector = ctype(create_env_fn=envs, frames_per_batch=173, total_frames=300)
try:
c_iter = iter(collector)
for i in range(len(collector)): # noqa: B007
c = next(c_iter)
assert c is not None
assert i == 1
finally:
collector.shutdown()
del collector


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
5 changes: 2 additions & 3 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):
_iterator = None
total_frames: int
frames_per_batch: int
requested_frames_per_batch: int
trust_policy: bool
compiled_policy: bool
cudagraphed_policy: bool
Expand Down Expand Up @@ -306,7 +305,7 @@ def __class_getitem__(self, index):

def __len__(self) -> int:
if self.total_frames > 0:
return -(self.total_frames // -self.requested_frames_per_batch)
return -(self.total_frames // -self.frames_per_batch)
raise RuntimeError("Non-terminating collectors do not have a length")


Expand Down Expand Up @@ -701,7 +700,7 @@ def __init__(
remainder = total_frames % frames_per_batch
if remainder != 0 and RL_WARNINGS:
warnings.warn(
f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). "
f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch})."
f"This means {frames_per_batch - remainder} additional frames will be collected."
"To silence this message, set the environment variable RL_WARNINGS to False."
)
Expand Down
1 change: 1 addition & 0 deletions torchrl/collectors/distributed/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ def __init__(
self.policy_weights = policy_weights
self.num_workers = len(create_env_fn)
self.frames_per_batch = frames_per_batch
self.requested_frames_per_batch = frames_per_batch

self.device = device
self.storing_device = storing_device
Expand Down
1 change: 1 addition & 0 deletions torchrl/collectors/distributed/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ def __init__(
self.policy_weights = policy_weights
self.num_workers = len(create_env_fn)
self.frames_per_batch = frames_per_batch
self.requested_frames_per_batch = frames_per_batch

self.device = device
self.storing_device = storing_device
Expand Down
1 change: 1 addition & 0 deletions torchrl/collectors/distributed/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ def __init__(
self.policy_weights = policy_weights
self.num_workers = len(create_env_fn)
self.frames_per_batch = frames_per_batch
self.requested_frames_per_batch = frames_per_batch

self.device = device
self.storing_device = storing_device
Expand Down

0 comments on commit 3cf986c

Please sign in to comment.