Merge remote-tracking branch 'origin/main' into sampler-doc
vmoens committed Feb 7, 2024
2 parents 712068c + 1fe745a commit e4a945a
Showing 7 changed files with 68 additions and 12 deletions.
23 changes: 23 additions & 0 deletions test/test_env.py
@@ -624,6 +624,29 @@ def test_parallel_env_with_policy(
         # env_serial.close()
         env0.close()
 
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
+    @pytest.mark.parametrize("heterogeneous", [False, True])
+    def test_transform_env_transform_no_device(self, heterogeneous):
+        # Tests non-regression on 1865
+        def make_env():
+            return TransformedEnv(
+                ContinuousActionVecMockEnv(), StepCounter(max_steps=3)
+            )
+
+        if heterogeneous:
+            make_envs = [EnvCreator(make_env), EnvCreator(make_env)]
+        else:
+            make_envs = make_env
+        penv = ParallelEnv(2, make_envs)
+        r = penv.rollout(6, break_when_any_done=False)
+        assert r.shape == (2, 6)
+        try:
+            env = TransformedEnv(penv)
+            r = env.rollout(6, break_when_any_done=False)
+            assert r.shape == (2, 6)
+        finally:
+            penv.close()
+
     @pytest.mark.skipif(not _has_gym, reason="no gym")
     @pytest.mark.parametrize(
         "env_name",
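For context, the regression this new test guards (issue 1865) shows up in user code whenever a ParallelEnv is re-wrapped in a device-less TransformedEnv. A minimal sketch of that scenario, under the assumption that gym and its Pendulum-v1 environment are installed (any simple env would do):

```python
# Hypothetical reproduction of the tested scenario, not part of the commit.
from torchrl.envs import ParallelEnv, StepCounter, TransformedEnv
from torchrl.envs.libs.gym import GymEnv


def make_env():
    # each worker env is itself transformed, as in the test above
    return TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(max_steps=3))


if __name__ == "__main__":
    penv = ParallelEnv(2, make_env)
    try:
        env = TransformedEnv(penv)  # no device passed: this used to fail
        rollout = env.rollout(6, break_when_any_done=False)
        assert rollout.shape == (2, 6)
    finally:
        penv.close()
```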
10 changes: 7 additions & 3 deletions torchrl/collectors/collectors.py
@@ -316,12 +316,14 @@ def _map_to_device_params(param, device):

         # Create a stateless policy, then populate this copy with params on device
         with param_and_buf.apply(
-            functools.partial(_map_to_device_params, device="meta")
+            functools.partial(_map_to_device_params, device="meta"),
+            filter_empty=False,
         ).to_module(policy):
             policy = deepcopy(policy)
 
         param_and_buf.apply(
-            functools.partial(_map_to_device_params, device=self.policy_device)
+            functools.partial(_map_to_device_params, device=self.policy_device),
+            filter_empty=False,
         ).to_module(policy)
 
         return policy, get_weights_fn
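The hunk above tweaks torchrl's stateless-copy idiom: parameters are mapped to the "meta" device and swapped into the policy while it is deep-copied, yielding a copy with structure but no storage. A self-contained sketch of that idiom (the nn.Sequential and lambda are illustrative stand-ins, not torchrl code):

```python
# Sketch of the stateless-copy idiom under stand-in names.
from copy import deepcopy

import torch
from tensordict import TensorDict
from torch import nn

module = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
params = TensorDict.from_module(module)

# filter_empty=False keeps any empty sub-tensordicts (e.g. one mirroring the
# parameterless ReLU), so to_module() still sees the full module tree.
meta_params = params.apply(
    lambda t: torch.empty_like(t, device="meta"), filter_empty=False
)
with meta_params.to_module(module):
    stateless_copy = deepcopy(module)  # holds shapes/dtypes but no storage
```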
@@ -1495,7 +1497,9 @@ def map_weight(
             weight = nn.Parameter(weight, requires_grad=False)
             return weight
 
-        local_policy_weights = TensorDictParams(policy_weights.apply(map_weight))
+        local_policy_weights = TensorDictParams(
+            policy_weights.apply(map_weight, filter_empty=False)
+        )
 
         def _get_weight_fn(weights=policy_weights):
             # This function will give the local_policy_weight the original weights.
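Every `filter_empty=False` in this commit addresses the same tensordict behavior: `apply` may prune empty nested tensordicts (such as those mirroring parameterless submodules) from its output, which breaks later structural round-trips like `to_module`. A small illustration, assuming a tensordict release where `apply` accepts `filter_empty`:

```python
# Illustration of the keyword's effect; the keys are made up for the example.
import torch
from tensordict import TensorDict

td = TensorDict({"layer": {"weight": torch.ones(2)}, "relu": {}}, batch_size=[])
doubled = td.apply(lambda x: x * 2, filter_empty=False)
assert "relu" in doubled.keys()  # the empty branch survives the apply
```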
25 changes: 23 additions & 2 deletions torchrl/envs/batched_envs.py
@@ -1284,7 +1284,22 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
                     tensordict_, keys_to_update=list(self._selected_reset_keys)
                 )
                 continue
-            out = ("reset", tensordict_)
+            if tensordict_ is not None:
+                tdkeys = list(tensordict_.keys(True, True))
+
+                # This way we can avoid calling select over all the keys in the shared tensordict
+                def tentative_update(val, other):
+                    if other is not None:
+                        val.copy_(other)
+                    return val
+
+                self.shared_tensordicts[i].apply_(
+                    tentative_update, tensordict_, default=None
+                )
+                out = ("reset", tdkeys)
+            else:
+                out = ("reset", False)
+
             channel.send(out)
             workers.append(i)
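The new `tentative_update` path avoids a `select` over every key of the shared tensordict: `apply_` zips the shared buffer with the (possibly partial) reset result and skips missing entries. The pattern in isolation (tensor names are illustrative):

```python
# Sketch of the zip-with-default pattern: apply_ walks `shared` in place,
# pairing each leaf with the matching leaf of `partial`; keys missing from
# `partial` arrive as None thanks to default=None and are left untouched.
import torch
from tensordict import TensorDict

shared = TensorDict({"a": torch.zeros(3), "b": torch.zeros(3)}, batch_size=[])
partial = TensorDict({"a": torch.ones(3)}, batch_size=[])


def tentative_update(val, other):
    if other is not None:
        val.copy_(other)
    return val


shared.apply_(tentative_update, partial, default=None)
assert (shared["a"] == 1).all() and (shared["b"] == 0).all()
```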

@@ -1509,7 +1524,13 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
             torchrl_logger.info(f"resetting worker {pid}")
             if not initialized:
                 raise RuntimeError("call 'init' before resetting")
-            cur_td = env.reset(tensordict=data)
+            # we use 'data' to pass the keys that we need to pass to reset,
+            # because passing the entire buffer may have unwanted consequences
+            cur_td = env.reset(
+                tensordict=root_shared_tensordict.select(*data, strict=False)
+                if data
+                else None
+            )
             shared_tensordict.update_(
                 cur_td,
                 keys_to_update=list(_selected_reset_keys),
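Worker-side, `data` now names the keys to forward to `reset()` instead of carrying tensors, and only those keys are picked out of the shared buffer. Roughly (buffer contents are illustrative):

```python
# Sketch of the key-based selection: only the keys named in the message are
# handed to reset(); strict=False tolerates keys absent from the buffer.
import torch
from tensordict import TensorDict

buffer = TensorDict(
    {"observation": torch.zeros(4), "_reset": torch.ones(1, dtype=torch.bool)},
    batch_size=[],
)
data = ["_reset", "not_in_buffer"]  # key names received from the main process
reset_input = buffer.select(*data, strict=False) if data else None
assert list(reset_input.keys()) == ["_reset"]
```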
6 changes: 4 additions & 2 deletions torchrl/envs/transforms/rlhf.py
@@ -112,7 +112,9 @@ def __init__(

         # check that the model has parameters
         params = TensorDict.from_module(actor)
-        with params.apply(_stateless_param, device="meta").to_module(actor):
+        with params.apply(
+            _stateless_param, device="meta", filter_empty=False
+        ).to_module(actor):
             # copy a stateless actor
             self.__dict__["functional_actor"] = deepcopy(actor)
             # we need to register these params as buffer to have `to` and similar
@@ -129,7 +131,7 @@ def _make_detached_param(x):
             )
             return x.clone()
 
-        self.frozen_params = params.apply(_make_detached_param)
+        self.frozen_params = params.apply(_make_detached_param, filter_empty=False)
         if requires_grad:
             # includes the frozen params/buffers in the module parameters/buffers
             self.frozen_params = TensorDictParams(self.frozen_params, no_convert=True)
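The `_make_detached_param` helper, only partially visible in this hunk, builds the frozen copy of the actor's weights used as the reference model. A loose standalone reconstruction (the helper body here is an assumption, not the verbatim source):

```python
# Loose reconstruction (not verbatim torchrl code): clone every leaf so the
# frozen reference params cannot drift with the trained actor.
import torch
from tensordict import TensorDict
from torch import nn

actor = nn.Linear(4, 4)
params = TensorDict.from_module(actor)


def _make_detached_param(x):
    if isinstance(x, nn.Parameter):
        return x.detach().clone()  # drop the autograd graph and storage link
    return x.clone()


frozen_params = params.apply(_make_detached_param, filter_empty=False)
assert not frozen_params["weight"].requires_grad
```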
10 changes: 7 additions & 3 deletions torchrl/objectives/common.py
@@ -262,7 +262,9 @@ def _compare_and_expand(param):

         params = TensorDictParams(
             params.apply(
-                _compare_and_expand, batch_size=[expand_dim, *params.shape]
+                _compare_and_expand,
+                batch_size=[expand_dim, *params.shape],
+                filter_empty=False,
             ),
             no_convert=True,
         )
@@ -283,7 +285,7 @@ def _compare_and_expand(param):
         # set the functional module: we need to convert the params to non-differentiable params
         # otherwise they will appear twice in parameters
         with params.apply(
-            self._make_meta_params, device=torch.device("meta")
+            self._make_meta_params, device=torch.device("meta"), filter_empty=False
         ).to_module(module):
             # avoid buffers and params being exposed
             self.__dict__[module_name] = deepcopy(module)
@@ -293,7 +295,9 @@ def _compare_and_expand(param):
         # if create_target_params:
         # we create a TensorDictParams to keep the target params as Buffer instances
         target_params = TensorDictParams(
-            params.apply(_make_target_param(clone=create_target_params)),
+            params.apply(
+                _make_target_param(clone=create_target_params), filter_empty=False
+            ),
             no_convert=True,
         )
         setattr(self, name_params_target + "_params", target_params)
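In the first hunk, `batch_size=[expand_dim, *params.shape]` works with `_compare_and_expand` to tile parameters along a new leading dimension, e.g. for an ensemble of critics. A reduced sketch, with a lambda standing in for `_compare_and_expand`:

```python
# Reduced sketch: prepend a dimension of size expand_dim to every leaf and
# record it in the result's batch_size, as done for ensembles of value nets.
import torch
from tensordict import TensorDict

params = TensorDict({"weight": torch.randn(4, 3)}, batch_size=[])
expand_dim = 2
expanded = params.apply(
    lambda p: p.expand(expand_dim, *p.shape).clone(),
    batch_size=[expand_dim, *params.shape],
    filter_empty=False,
)
assert expanded["weight"].shape == torch.Size([2, 4, 3])
assert expanded.batch_size == torch.Size([2])
```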
4 changes: 3 additions & 1 deletion torchrl/objectives/ddpg.py
@@ -197,7 +197,9 @@ def __init__(

         actor_critic = ActorCriticWrapper(actor_network, value_network)
         params = TensorDict.from_module(actor_critic)
-        params_meta = params.apply(self._make_meta_params, device=torch.device("meta"))
+        params_meta = params.apply(
+            self._make_meta_params, device=torch.device("meta"), filter_empty=False
+        )
         with params_meta.to_module(actor_critic):
             self.__dict__["actor_critic"] = deepcopy(actor_critic)

2 changes: 1 addition & 1 deletion torchrl/objectives/multiagent/qmixer.py
@@ -224,7 +224,7 @@ def __init__(
         global_value_network = SafeSequential(local_value_network, mixer_network)
         params = TensorDict.from_module(global_value_network)
         with params.apply(
-            self._make_meta_params, device=torch.device("meta")
+            self._make_meta_params, device=torch.device("meta"), filter_empty=False
         ).to_module(global_value_network):
             self.__dict__["global_value_network"] = deepcopy(global_value_network)

