Describe the bug
If we introduce dropout layers into the module argument of the ValueOperator, a RuntimeError is raised.
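Condensed, the failing combination is a value network that contains nn.Dropout, wrapped in a ValueOperator and handed to GAE. The following sketch is distilled from the full script below and does nothing beyond what that script already does:

from torch import nn
from torchrl.modules import ValueOperator
from torchrl.objectives.value import GAE

value_net = nn.Sequential(
    nn.LazyLinear(128),
    nn.Dropout(0.3),  # removing the dropout layers avoids the error
    nn.Tanh(),
    nn.LazyLinear(1),
)
value_module = ValueOperator(module=value_net, in_keys=["observation"])
advantage_module = GAE(
    gamma=0.99, lmbda=0.95, value_network=value_module, average_gae=True
)
# advantage_module(tensordict_data)  # <- raises the RuntimeError shown below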
To Reproduce
Steps to reproduce the behavior (a minimally adapted version of the PPO example from the docs):
from __future__ import annotations

import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import (
    Compose,
    ObservationNorm,
    StepCounter,
    TransformedEnv,
)
from torchrl.envs.libs.gym import GymWrapper
from torchrl.envs.utils import check_env_specs
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from gymnasium.wrappers import FlattenObservation
import gymnasium as gym
device = "cpu" if not torch.cuda.is_available() else "cuda:0"
def make_env():
env = gym.make("MountainCarContinuous-v0")
env = FlattenObservation(env)
env = GymWrapper(env)
env = TransformedEnv(
env,
Compose(
# normalize observations
ObservationNorm(
in_keys=["observation"],
out_keys=["observation"],
),
StepCounter(),
),
)
env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)
return env
def main():
    train_env = make_env()
    check_env_specs(train_env)

    num_cells = 128
    frames_per_batch = 1_000
    total_frames = 1_000_000

    actor_net = nn.Sequential(
        nn.LazyLinear(num_cells, device=device),
        nn.Tanh(),
        nn.LazyLinear(num_cells, device=device),
        nn.Tanh(),
        nn.LazyLinear(num_cells, device=device),
        nn.Tanh(),
        nn.LazyLinear(2 * train_env.action_spec.shape[-1], device=device),
        NormalParamExtractor(),
    )
    # if the actor_net is a custom torch.nn.Module subclass, it can accept safety information!
    policy_module = TensorDictModule(
        actor_net, in_keys=["observation"], out_keys=["loc", "scale"]
    )
    policy_module = ProbabilisticActor(
        module=policy_module,
        spec=train_env.action_spec,
        in_keys=["loc", "scale"],
        distribution_class=TanhNormal,
        distribution_kwargs={
            "min": train_env.action_spec.space.minimum,
            "max": train_env.action_spec.space.maximum,
        },
        return_log_prob=True,
        # we'll need the log-prob for the numerator of the importance weights
    )

    value_net = nn.Sequential(
        nn.LazyLinear(num_cells, device=device),
        nn.Dropout(0.3),
        nn.Tanh(),
        nn.LazyLinear(num_cells, device=device),
        nn.Dropout(0.3),
        nn.Tanh(),
        nn.LazyLinear(num_cells, device=device),
        nn.Dropout(0.3),
        nn.Tanh(),
        nn.LazyLinear(1, device=device),
    )
    value_module = ValueOperator(
        module=value_net,
        in_keys=["observation"],
    )

    print("Running policy:", policy_module(train_env.reset().to(device)))
    print("Running value:", value_module(train_env.reset().to(device)))

    collector = SyncDataCollector(
        train_env,
        policy_module,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        split_trajs=False,
        device=device,
    )
    replay_buffer = ReplayBuffer(
        storage=LazyTensorStorage(frames_per_batch),
        sampler=SamplerWithoutReplacement(),
    )
    advantage_module = GAE(
        gamma=0.99, lmbda=0.95, value_network=value_module, average_gae=True
    )
    loss_module = ClipPPOLoss(
        actor=policy_module,
        critic=value_module,
        clip_epsilon=0.2,
        entropy_bonus=True,
        entropy_coef=1e-4,
        # these keys match by default but we set this for completeness
        value_target_key=advantage_module.value_target_key,
        critic_coef=1.0,
        gamma=0.99,
        loss_critic_type="smooth_l1",
    )
    optim = torch.optim.Adam(loss_module.parameters(), 1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optim, total_frames // frames_per_batch, 0.0
    )

    # We iterate over the collector until it reaches the total number of frames it was
    # designed to collect:
    for i, tensordict_data in enumerate(collector):
        # we now have a batch of data to work with. Let's learn something from it.
        for _ in range(1000):
            # We'll need an "advantage" signal to make PPO work.
            # We re-compute it at each epoch as its value depends on the value
            # network which is updated in the inner loop.
            with torch.no_grad():
                advantage_module(tensordict_data)
            data_view = tensordict_data.reshape(-1)
            replay_buffer.extend(data_view.cpu())
            for _ in range(frames_per_batch // 128):
                subdata = replay_buffer.sample(128)
                loss_vals = loss_module(subdata.to(device))
                loss_value = (
                    loss_vals["loss_objective"]
                    + loss_vals["loss_critic"]
                    + loss_vals["loss_entropy"]
                )
                # Optimization: backward, grad clipping and optim step
                loss_value.backward()
                # this is not strictly mandatory but it's good practice to keep
                # your gradient norm bounded
                torch.nn.utils.clip_grad_norm_(loss_module.parameters(), 2.0)
                optim.step()
                optim.zero_grad()

        # We're also using a learning rate scheduler. Like the gradient clipping,
        # this is a nice-to-have but nothing necessary for PPO to work.
        scheduler.step()


if __name__ == "__main__":
    main()
Expected behavior
Running the example above should start normal PPO training, the only difference being that the value network now contains dropout layers. Instead, the following exception is raised:
Traceback (most recent call last):
File "<project_dir>/.venv/lib/python3.8/site-packages/tensordict/nn/common.py", line 1178, in forward
raise err
File "<project_dir>/.venv/lib/python3.8/site-packages/tensordict/nn/common.py", line 1164, in forward
tensors = self._call_module(tensors, **kwargs)
File "<project_dir>/.venv/lib/python3.8/site-packages/tensordict/nn/common.py", line 1121, in _call_module
out = self.module(*tensors, **kwargs)
File "<project_dir>/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "<project_dir>/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "<project_dir>/.venv/lib/python3.8/site-packages/tensordict/nn/functional_modules.py", line 588, in new_fun
return getattr(type(self), fun_name)(self, *args, **kwargs)
File "<project_dir>/.venv/lib/python3.8/site-packages/torch/nn/modules/container.py", line 215, in forward
input = module(input)
File "<project_dir>/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "<project_dir>/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "<project_dir>/.venv/lib/python3.8/site-packages/tensordict/nn/functional_modules.py", line 588, in new_fun
return getattr(type(self), fun_name)(self, *args, **kwargs)
File "<project_dir>/.venv/lib/python3.8/site-packages/torch/nn/modules/dropout.py", line 58, in forward
return F.dropout(input, self.p, self.training, self.inplace)
File "<project_dir>/.venv/lib/python3.8/site-packages/torch/nn/functional.py", line 1266, in dropout
return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
RuntimeError: vmap: called random operation while in randomness error mode. Please either use the 'same' or 'different' randomness flags on vmap or perform the randomness operation out of vmap
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "mre.py", line 200, in <module>
main()
File "mre.py", line 175, in main
advantage_module(tensordict_data)
File "<project_dir>/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "<project_dir>/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "<project_dir>/.venv/lib/python3.8/site-packages/torchrl/objectives/value/advantages.py", line 63, in new_func
return fun(self, *args, **kwargs)
File "<project_dir>/.venv/lib/python3.8/site-packages/torchrl/objectives/value/advantages.py", line 52, in new_fun
return fun(self, *args, **kwargs)
File "<project_dir>/.venv/lib/python3.8/site-packages/tensordict/nn/common.py", line 282, in wrapper
return func(_self, tensordict, *args, **kwargs)
File "<project_dir>/.venv/lib/python3.8/site-packages/torchrl/objectives/value/advantages.py", line 1224, in forward
value, next_value = _call_value_nets(
File "<project_dir>/.venv/lib/python3.8/site-packages/torchrl/objectives/value/advantages.py", line 140, in _call_value_nets
data_out = vmap(value_net, (0,))(data_in)
File "<project_dir>/.venv/lib/python3.8/site-packages/torch/_functorch/apis.py", line 188, in wrapped
return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
File "<project_dir>/.venv/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 266, in vmap_impl
return _flat_vmap(
File "<project_dir>/.venv/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 38, in fn
return f(*args, **kwargs)
File "<project_dir>/.venv/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 379, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
File "<project_dir>/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "<project_dir>/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "<project_dir>/.venv/lib/python3.8/site-packages/tensordict/nn/functional_modules.py", line 588, in new_fun
return getattr(type(self), fun_name)(self, *args, **kwargs)
File "<project_dir>/.venv/lib/python3.8/site-packages/tensordict/nn/common.py", line 282, in wrapper
return func(_self, tensordict, *args, **kwargs)
File "<project_dir>/.venv/lib/python3.8/site-packages/tensordict/_contextlib.py", line 126, in decorate_context
return func(*args, **kwargs)
File "<project_dir>/.venv/lib/python3.8/site-packages/tensordict/nn/utils.py", line 254, in wrapper
return func(_self, tensordict, *args, **kwargs)
File "<project_dir>/.venv/lib/python3.8/site-packages/tensordict/nn/common.py", line 1204, in forward
raise RuntimeError(
RuntimeError: TensorDictModule failed with operation
Sequential(
(0): Linear(in_features=2, out_features=128, bias=True)
(1): Dropout(p=0.3, inplace=False)
(2): Tanh()
(3): Linear(in_features=128, out_features=128, bias=True)
(4): Dropout(p=0.3, inplace=False)
(5): Tanh()
(6): Linear(in_features=128, out_features=128, bias=True)
(7): Dropout(p=0.3, inplace=False)
(8): Tanh()
(9): Linear(in_features=128, out_features=1, bias=True)
)
in_keys=['observation']
out_keys=['state_value'].
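For context on the message itself: the traceback shows GAE calling the value network under vmap (data_out = vmap(value_net, (0,))(data_in) in _call_value_nets), and vmap defaults to randomness="error", so any random op such as dropout aborts. The snippet below is only a hedged, plain-PyTorch illustration of that behavior and of the escape hatches the error message mentions; it is not the torchrl code path, and the shapes are made up:

import torch
from torch import nn
from torch.func import vmap

net = nn.Sequential(nn.Linear(2, 128), nn.Dropout(0.3), nn.Tanh(), nn.Linear(128, 1))
x = torch.randn(2, 32, 2)  # leading dim of 2 mimics the stacked (value, next_value) call

# Default randomness="error", as in the traceback: dropout in training mode aborts.
# vmap(net, (0,))(x)  # RuntimeError: vmap: called random operation while in randomness error mode

# Escape hatch 1 from the error message: allow randomness under vmap.
out = vmap(net, (0,), randomness="different")(x)
print(out.shape)  # torch.Size([2, 32, 1])

# Escape hatch 2 (a workaround rather than a fix): put the module in eval mode,
# so dropout becomes a no-op and nothing random runs under vmap.
net.eval()
out = vmap(net, (0,))(x)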
System info
Describe the characteristics of your environment:
How the library was installed: pip (pip install torchrl)
Python version: 3.8.10
torch version: 2.1.2+cu121
torchrl version: 0.2.1
Checklist
I have checked that there is no similar issue in the repo (required)