Commit a3212db

Update

[ghstack-poisoned]

vmoens committed Sep 30, 2024
1 parent 82cc9c4 commit a3212db
Showing 7 changed files with 70 additions and 14 deletions.
benchmarks/test_collectors_benchmark.py (56 additions, 1 deletion)

@@ -3,16 +3,20 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+import time

 import pytest
 import torch.cuda
 import tqdm

 from torchrl.collectors import SyncDataCollector
 from torchrl.collectors.collectors import (
     MultiaSyncDataCollector,
     MultiSyncDataCollector,
 )
-from torchrl.envs import EnvCreator, GymEnv, StepCounter, TransformedEnv
+from torchrl.data import LazyTensorStorage, ReplayBuffer
+from torchrl.data.utils import CloudpickleWrapper
+from torchrl.envs import EnvCreator, GymEnv, ParallelEnv, StepCounter, TransformedEnv
 from torchrl.envs.libs.dm_control import DMControlEnv
 from torchrl.envs.utils import RandomPolicy

@@ -180,6 +184,57 @@ def test_async_pixels(benchmark):
     benchmark(execute_collector, c)


+class TestRBGCollector:
+    @pytest.mark.parametrize(
+        "n_col,n_workers_per_col",
+        [
+            [2, 2],
+            [4, 2],
+            [8, 2],
+            [16, 2],
+            [2, 1],
+            [4, 1],
+            [8, 1],
+            [16, 1],
+        ],
+    )
+    def test_multiasync_rb(self, n_col, n_workers_per_col):
+        make_env = EnvCreator(lambda: GymEnv("ALE/Pong-v5"))
+        if n_workers_per_col > 1:
+            make_env = ParallelEnv(n_workers_per_col, make_env)
+            env = make_env
+            policy = RandomPolicy(env.action_spec)
+        else:
+            env = make_env()
+            policy = RandomPolicy(env.action_spec)
+
+        storage = LazyTensorStorage(10_000)
+        rb = ReplayBuffer(storage=storage)
+        rb.extend(env.rollout(2, policy).reshape(-1))
+        rb.append_transform(CloudpickleWrapper(lambda x: x.reshape(-1)), invert=True)
+
+        fpb = n_workers_per_col * 100
+        total_frames = n_workers_per_col * 100_000
+        c = MultiaSyncDataCollector(
+            [make_env] * n_col,
+            policy,
+            frames_per_batch=fpb,
+            total_frames=total_frames,
+            replay_buffer=rb,
+        )
+        frames = 0
+        pbar = tqdm.tqdm(total=total_frames - (n_col * fpb))
+        for i, _ in enumerate(c):
+            if i == n_col:
+                t0 = time.time()
+            if i >= n_col:
+                frames += fpb
+            if i > n_col:
+                fps = frames / (time.time() - t0)
+                pbar.update(fpb)
+                pbar.set_description(f"fps: {fps: 4.4f}")
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
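
For context, the pattern the new benchmark exercises looks roughly like this: the MultiaSyncDataCollector receives a ReplayBuffer and writes each batch into it directly, so the iteration loop only drives collection. A minimal sketch follows (illustration only, not code from this commit; it assumes a Gym CartPole environment instead of ALE/Pong and skips the parallel-env and buffer pre-fill steps used in the benchmark):

# Minimal sketch: an async multi-collector feeding a replay buffer directly.
from torchrl.collectors.collectors import MultiaSyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import EnvCreator, GymEnv
from torchrl.envs.utils import RandomPolicy

if __name__ == "__main__":
    make_env = EnvCreator(lambda: GymEnv("CartPole-v1"))
    policy = RandomPolicy(make_env().action_spec)
    rb = ReplayBuffer(storage=LazyTensorStorage(10_000))

    collector = MultiaSyncDataCollector(
        [make_env] * 2,
        policy,
        frames_per_batch=100,
        total_frames=1_000,
        replay_buffer=rb,  # batches are written straight into the buffer
    )
    for _ in collector:  # yielded values are not used when a buffer is attached
        pass
    print(rb.sample(32))
    collector.shutdown()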
torchrl/collectors/collectors.py (0 additions, 1 deletion)

@@ -200,7 +200,6 @@ def map_weight(
             is_buffer = isinstance(weight, nn.Buffer)
             weight = weight.data
             if weight.device != policy_device:
-                has_different_device = True
                 weight = weight.to(policy_device)
             elif weight.device.type in ("cpu", "mps"):
                 weight = weight.share_memory_()
torchrl/collectors/distributed/generic.py (2 additions, 2 deletions)

@@ -173,10 +173,10 @@ def _run_collector(

     if isinstance(policy, nn.Module):
         policy_weights = TensorDict.from_module(policy)
-        policy_weights = policy_weights.data
+        policy_weights = policy_weights.data.lock_()
     else:
         warnings.warn(_NON_NN_POLICY_WEIGHTS)
-        policy_weights = TensorDict()
+        policy_weights = TensorDict(lock=True)

     collector = collector_class(
         env_make,
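
The change in this file (repeated in ray.py, rpc.py and sync.py below) detaches the policy weights and locks the resulting TensorDict, so the weight container can no longer gain or lose keys once it is handed to a worker. A rough sketch of what the lock buys, using a throwaway nn.Sequential as a stand-in policy (illustration only, not code from this commit):

# Illustration: locking a TensorDict of policy weights.
import torch
from tensordict import TensorDict
from torch import nn

policy = nn.Sequential(nn.Linear(4, 32), nn.Tanh(), nn.Linear(32, 2))

weights = TensorDict.from_module(policy).data.lock_()  # detached values, frozen structure
assert weights.is_locked

# In-place value updates still work (and write through to the module's parameters) ...
weights["0", "weight"].zero_()
# ... but adding or removing entries now raises.
try:
    weights["extra"] = torch.zeros(3)
except RuntimeError as err:
    print(f"locked: {err}")

# The non-nn.Module fallback becomes an empty, locked TensorDict as well:
assert TensorDict(lock=True).is_locked

Presumably the point is to fail fast: a structural mismatch between the trainer's and a worker's copy of the weights surfaces as an immediate error instead of a silent divergence.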
torchrl/collectors/distributed/ray.py (2 additions, 2 deletions)

@@ -402,10 +402,10 @@ def check_list_length_consistency(*lists):
         self._local_policy = policy
         if isinstance(self._local_policy, nn.Module):
             policy_weights = TensorDict.from_module(policy)
-            policy_weights = policy_weights.data
+            policy_weights = policy_weights.data.lock_()
         else:
             warnings.warn(_NON_NN_POLICY_WEIGHTS)
-            policy_weights = TensorDict()
+            policy_weights = TensorDict(lock=True)
         self.policy_weights = policy_weights
         self.collector_class = collector_class
         self.collected_frames = 0
torchrl/collectors/distributed/rpc.py (2 additions, 2 deletions)

@@ -302,10 +302,10 @@ def __init__(
         self.policy = policy
         if isinstance(policy, nn.Module):
             policy_weights = TensorDict.from_module(policy)
-            policy_weights = policy_weights.data
+            policy_weights = policy_weights.data.lock_()
         else:
             warnings.warn(_NON_NN_POLICY_WEIGHTS)
-            policy_weights = TensorDict()
+            policy_weights = TensorDict(lock=True)
         self.policy_weights = policy_weights
         self.num_workers = len(create_env_fn)
         self.frames_per_batch = frames_per_batch
torchrl/collectors/distributed/sync.py (8 additions, 5 deletions)

@@ -80,10 +80,10 @@ def _distributed_init_collection_node(

     if isinstance(policy, nn.Module):
         policy_weights = TensorDict.from_module(policy)
-        policy_weights = policy_weights.data
+        policy_weights = policy_weights.data.lock_()
     else:
         warnings.warn(_NON_NN_POLICY_WEIGHTS)
-        policy_weights = TensorDict()
+        policy_weights = TensorDict(lock=True)

     collector = collector_class(
         env_make,
@@ -309,11 +309,14 @@ def __init__(
         self.collector_class = collector_class
         self.env_constructors = create_env_fn
         self.policy = policy
+
         if isinstance(policy, nn.Module):
-            policy_weights = TensorDict(dict(policy.named_parameters()), [])
-            policy_weights = policy_weights.apply(lambda x: x.data)
+            policy_weights = TensorDict.from_module(policy)
+            policy_weights = policy_weights.data.lock_()
         else:
-            policy_weights = TensorDict({}, [])
+            warnings.warn(_NON_NN_POLICY_WEIGHTS)
+            policy_weights = TensorDict(lock=True)
+
         self.policy_weights = policy_weights
         self.num_workers = len(create_env_fn)
         self.frames_per_batch = frames_per_batch
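
Besides locking, this constructor now builds the weights with TensorDict.from_module rather than from dict(policy.named_parameters()), bringing it in line with the other distributed collectors, and adds the previously missing non-nn.Module warning. One practical difference, sketched below with a stand-in module (illustration only, not code from this commit): from_module mirrors the module tree and captures buffers as well as parameters.

# Illustration: what TensorDict.from_module captures for a stand-in policy.
from tensordict import TensorDict
from torch import nn

policy = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))

weights = TensorDict.from_module(policy)
print(sorted(weights.keys()))              # ['0', '1', '2'], one entry per submodule
print(weights["1", "running_mean"].shape)  # torch.Size([8]), buffers are included

# named_parameters() alone would have missed the BatchNorm buffers:
assert "1.running_mean" not in dict(policy.named_parameters())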
torchrl/envs/utils.py (0 additions, 1 deletion)

@@ -1561,7 +1561,6 @@ def _make_compatible_policy(


 def _policy_is_tensordict_compatible(policy: nn.Module):
-
     def is_compatible(policy):
         return isinstance(policy, (RandomPolicy, TensorDictModuleBase))

