Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Vmap randomness for value estimator #1942

Merged
merged 1 commit into from
Feb 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions torchrl/objectives/value/advantages.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from torchrl._utils import RL_WARNINGS
from torchrl.envs.utils import step_mdp

from torchrl.objectives.utils import _vmap_func, hold_out_net
from torchrl.objectives.utils import _vmap_func, hold_out_net, RANDOM_MODULE_LIST
from torchrl.objectives.value.functional import (
generalized_advantage_estimate,
td0_return_estimate,
Expand Down Expand Up @@ -78,6 +78,7 @@ def _call_value_nets(
single_call: bool,
value_key: NestedKey,
detach_next: bool,
vmap_randomness: str = "error",
):
in_keys = value_net.in_keys
if single_call:
Expand Down Expand Up @@ -141,9 +142,11 @@ def _call_value_nets(
)
elif params is not None:
params_stack = torch.stack([params, next_params], 0).contiguous()
data_out = _vmap_func(value_net, (0, 0))(data_in, params_stack)
data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)(
data_in, params_stack
)
else:
data_out = vmap(value_net, (0,))(data_in)
data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in)
value_est = data_out.get(value_key)
value, value_ = value_est[0], value_est[1]
data.set(value_key, value)
Expand Down Expand Up @@ -214,6 +217,7 @@ class _AcceptedKeys:

default_keys = _AcceptedKeys()
value_network: Union[TensorDictModule, Callable]
_vmap_randomness = None

@property
def advantage_key(self):
Expand Down Expand Up @@ -428,6 +432,28 @@ def _next_value(self, tensordict, target_params, kwargs):
next_value = step_td.get(self.tensor_keys.value)
return next_value

@property
def vmap_randomness(self):
    """Randomness mode forwarded to vmap calls ("error" or "different").

    Lazily resolved on first access: every ``torch.nn.Module`` stored as
    an attribute of this estimator is scanned for submodules belonging to
    ``RANDOM_MODULE_LIST``. If any such module is present, vmap needs
    ``"different"`` randomness; otherwise the default ``"error"`` mode is
    kept. The result is cached in ``self._vmap_randomness`` and can be
    overridden through ``set_vmap_randomness``.
    """
    if self._vmap_randomness is None:
        # Flattened walk over the submodules of every nn.Module attribute.
        submodules = (
            submodule
            for attr in self.__dict__.values()
            if isinstance(attr, torch.nn.Module)
            for submodule in attr.modules()
        )
        has_random_module = any(
            isinstance(submodule, RANDOM_MODULE_LIST) for submodule in submodules
        )
        self._vmap_randomness = "different" if has_random_module else "error"
    return self._vmap_randomness

def set_vmap_randomness(self, value):
    """Set the vmap randomness mode used when calling the value network.

    Args:
        value (str or None): one of ``"error"``, ``"same"`` or
            ``"different"`` (the modes accepted by ``torch.vmap``), or
            ``None`` to reset the cache so that the mode is auto-detected
            on the next access of ``vmap_randomness``.

    Raises:
        ValueError: if ``value`` is not a recognized randomness mode.
    """
    # Fail fast on typos instead of deferring to an obscure vmap error later.
    if value not in ("error", "same", "different", None):
        raise ValueError(
            f"Expected vmap randomness to be one of 'error', 'same', "
            f"'different' or None, but got {value!r}."
        )
    self._vmap_randomness = value


class TD0Estimator(ValueEstimatorBase):
"""Temporal Difference (TD(0)) estimate of advantage function.
Expand Down Expand Up @@ -589,6 +615,7 @@ def forward(
single_call=self.shifted,
value_key=self.tensor_keys.value,
detach_next=True,
vmap_randomness=self.vmap_randomness,
)
else:
value = tensordict.get(self.tensor_keys.value)
Expand Down Expand Up @@ -790,6 +817,7 @@ def forward(
single_call=self.shifted,
value_key=self.tensor_keys.value,
detach_next=True,
vmap_randomness=self.vmap_randomness,
)
else:
value = tensordict.get(self.tensor_keys.value)
Expand Down Expand Up @@ -1001,6 +1029,7 @@ def forward(
single_call=self.shifted,
value_key=self.tensor_keys.value,
detach_next=True,
vmap_randomness=self.vmap_randomness,
)
else:
value = tensordict.get(self.tensor_keys.value)
Expand Down Expand Up @@ -1247,6 +1276,7 @@ def forward(
single_call=self.shifted,
value_key=self.tensor_keys.value,
detach_next=True,
vmap_randomness=self.vmap_randomness,
)
else:
value = tensordict.get(self.tensor_keys.value)
Expand Down Expand Up @@ -1329,6 +1359,7 @@ def value_estimate(
single_call=self.shifted,
value_key=self.tensor_keys.value,
detach_next=True,
vmap_randomness=self.vmap_randomness,
)
else:
value = tensordict.get(self.tensor_keys.value)
Expand Down Expand Up @@ -1575,6 +1606,7 @@ def forward(
single_call=self.shifted,
value_key=self.tensor_keys.value,
detach_next=True,
vmap_randomness=self.vmap_randomness,
)
else:
value = tensordict.get(self.tensor_keys.value)
Expand Down
Loading