[BUG] Using Dropout in a ValueOperator results in RuntimeError #1935

Closed
BeFranke opened this issue Feb 20, 2024 · 2 comments · Fixed by #1942
Labels: bug (Something isn't working)

BeFranke commented Feb 20, 2024

Describe the bug

If dropout layers are introduced into the module argument of ValueOperator, a RuntimeError is raised.

To Reproduce

Steps to reproduce the behavior (minimally adapted PPO example from the docs):

from __future__ import annotations
import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import (
    Compose,
    ObservationNorm,
    StepCounter,
    TransformedEnv
)
from torchrl.envs.libs.gym import GymWrapper
from torchrl.envs.utils import check_env_specs
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from gymnasium.wrappers import FlattenObservation
import gymnasium as gym


device = "cpu" if not torch.cuda.is_available() else "cuda:0"


def make_env():
    env = gym.make("MountainCarContinuous-v0")
    env = FlattenObservation(env)
    env = GymWrapper(env)
    env = TransformedEnv(
        env,
        Compose(
            # normalize observations
            ObservationNorm(
                in_keys=["observation"], 
                out_keys=["observation"],
            ),
            StepCounter(),
        ),
    )
    env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)
    
    return env


def main():
    train_env = make_env()
    check_env_specs(train_env)

    num_cells = 128
    frames_per_batch = 1_000
    total_frames = 1_000_000

    actor_net = nn.Sequential(
        nn.LazyLinear(num_cells, device=device),
        nn.Tanh(),
        nn.LazyLinear(num_cells, device=device),
        nn.Tanh(),
        nn.LazyLinear(num_cells, device=device),
        nn.Tanh(),
        nn.LazyLinear(2 * train_env.action_spec.shape[-1], device=device),
        NormalParamExtractor(),
    )

    policy_module = TensorDictModule(           # if the actor_net is a custom torch.nn.Module subclass, it can accept safety information!
        actor_net, in_keys=["observation"], out_keys=["loc", "scale"]
    )

    policy_module = ProbabilisticActor(
        module=policy_module,
        spec=train_env.action_spec,
        in_keys=["loc", "scale"],
        distribution_class=TanhNormal,
        distribution_kwargs={
            "min": train_env.action_spec.space.minimum,
            "max": train_env.action_spec.space.maximum,
        },
        return_log_prob=True,
        # we'll need the log-prob for the numerator of the importance weights
    )

    value_net = nn.Sequential(
        nn.LazyLinear(num_cells, device=device),
        nn.Dropout(0.3),
        nn.Tanh(),
        nn.LazyLinear(num_cells, device=device),
        nn.Dropout(0.3),
        nn.Tanh(),
        nn.LazyLinear(num_cells, device=device),
        nn.Dropout(0.3),
        nn.Tanh(),
        nn.LazyLinear(1, device=device),
    )
    value_module = ValueOperator(
        module=value_net,
        in_keys=["observation"],
    )

    print("Running policy:", policy_module(train_env.reset().to(device)))
    print("Running value:", value_module(train_env.reset().to(device)))
    
    collector = SyncDataCollector(
        train_env,
        policy_module,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        split_trajs=False,
        device=device,
    )

    replay_buffer = ReplayBuffer(
        storage=LazyTensorStorage(frames_per_batch),
        sampler=SamplerWithoutReplacement(),
    )

    advantage_module = GAE(
        gamma=0.99, lmbda=0.95, value_network=value_module, average_gae=True
    )

    loss_module = ClipPPOLoss(
        actor=policy_module,
        critic=value_module,
        clip_epsilon=0.2,
        entropy_bonus=True,
        entropy_coef=1e-4,
        # these keys match by default but we set this for completeness
        value_target_key=advantage_module.value_target_key,
        critic_coef=1.0,
        gamma=0.99,
        loss_critic_type="smooth_l1",
    )

    optim = torch.optim.Adam(loss_module.parameters(), 1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optim, total_frames // frames_per_batch, 0.0
    )

    # We iterate over the collector until it reaches the total number of frames it was
    # designed to collect:
    for i, tensordict_data in enumerate(collector):
        # we now have a batch of data to work with. Let's learn something from it.
        for _ in range(1000):
            # We'll need an "advantage" signal to make PPO work.
            # We re-compute it at each epoch as its value depends on the value
            # network which is updated in the inner loop.
            with torch.no_grad():
                advantage_module(tensordict_data)
            data_view = tensordict_data.reshape(-1)
            replay_buffer.extend(data_view.cpu())
            for _ in range(frames_per_batch // 128):
                subdata = replay_buffer.sample(128)
                loss_vals = loss_module(subdata.to(device))
                loss_value = (
                    loss_vals["loss_objective"]
                    + loss_vals["loss_critic"]
                    + loss_vals["loss_entropy"]
                )

                # Optimization: backward, grad clipping and optim step
                loss_value.backward()
                # this is not strictly mandatory but it's good practice to keep
                # your gradient norm bounded
                torch.nn.utils.clip_grad_norm_(loss_module.parameters(), 2.0)
                optim.step()
                optim.zero_grad()

        # We're also using a learning rate scheduler. Like the gradient clipping,
        # this is a nice-to-have but nothing necessary for PPO to work.
        scheduler.step()
        
if __name__ == "__main__":
    main()

Expected behavior

Running the example above should start a normal PPO training, just that the value network now has dropout layers. However, the following exception is raised:

Traceback (most recent call last):
  File "<project_dir>/.venv/lib/python3.8/site-packages/tensordict/nn/common.py", line 1178, in forward
    raise err
  File "<project_dir>/.venv/lib/python3.8/site-packages/tensordict/nn/common.py", line 1164, in forward
    tensors = self._call_module(tensors, **kwargs)
  File "<project_dir>/.venv/lib/python3.8/site-packages/tensordict/nn/common.py", line 1121, in _call_module
    out = self.module(*tensors, **kwargs)
  File "<project_dir>/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "<project_dir>/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "<project_dir>/.venv/lib/python3.8/site-packages/tensordict/nn/functional_modules.py", line 588, in new_fun
    return getattr(type(self), fun_name)(self, *args, **kwargs)
  File "<project_dir>/.venv/lib/python3.8/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File "<project_dir>/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "<project_dir>/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "<project_dir>/.venv/lib/python3.8/site-packages/tensordict/nn/functional_modules.py", line 588, in new_fun
    return getattr(type(self), fun_name)(self, *args, **kwargs)
  File "<project_dir>/.venv/lib/python3.8/site-packages/torch/nn/modules/dropout.py", line 58, in forward
    return F.dropout(input, self.p, self.training, self.inplace)
  File "<project_dir>/.venv/lib/python3.8/site-packages/torch/nn/functional.py", line 1266, in dropout
    return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
RuntimeError: vmap: called random operation while in randomness error mode. Please either use the 'same' or 'different' randomness flags on vmap or perform the randomness operation out of vmap

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "mre.py", line 200, in <module>
    main()
  File "mre.py", line 175, in main
    advantage_module(tensordict_data)
  File "<project_dir>/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "<project_dir>/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "<project_dir>/.venv/lib/python3.8/site-packages/torchrl/objectives/value/advantages.py", line 63, in new_func
    return fun(self, *args, **kwargs)
  File "<project_dir>/.venv/lib/python3.8/site-packages/torchrl/objectives/value/advantages.py", line 52, in new_fun
    return fun(self, *args, **kwargs)
  File "<project_dir>/.venv/lib/python3.8/site-packages/tensordict/nn/common.py", line 282, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "<project_dir>/.venv/lib/python3.8/site-packages/torchrl/objectives/value/advantages.py", line 1224, in forward
    value, next_value = _call_value_nets(
  File "<project_dir>/.venv/lib/python3.8/site-packages/torchrl/objectives/value/advantages.py", line 140, in _call_value_nets
    data_out = vmap(value_net, (0,))(data_in)
  File "<project_dir>/.venv/lib/python3.8/site-packages/torch/_functorch/apis.py", line 188, in wrapped
    return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
  File "<project_dir>/.venv/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 266, in vmap_impl
    return _flat_vmap(
  File "<project_dir>/.venv/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 38, in fn
    return f(*args, **kwargs)
  File "<project_dir>/.venv/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 379, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "<project_dir>/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "<project_dir>/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "<project_dir>/.venv/lib/python3.8/site-packages/tensordict/nn/functional_modules.py", line 588, in new_fun
    return getattr(type(self), fun_name)(self, *args, **kwargs)
  File "<project_dir>/.venv/lib/python3.8/site-packages/tensordict/nn/common.py", line 282, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "<project_dir>/.venv/lib/python3.8/site-packages/tensordict/_contextlib.py", line 126, in decorate_context
    return func(*args, **kwargs)
  File "<project_dir>/.venv/lib/python3.8/site-packages/tensordict/nn/utils.py", line 254, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "<project_dir>/.venv/lib/python3.8/site-packages/tensordict/nn/common.py", line 1204, in forward
    raise RuntimeError(
RuntimeError: TensorDictModule failed with operation
    Sequential(
      (0): Linear(in_features=2, out_features=128, bias=True)
      (1): Dropout(p=0.3, inplace=False)
      (2): Tanh()
      (3): Linear(in_features=128, out_features=128, bias=True)
      (4): Dropout(p=0.3, inplace=False)
      (5): Tanh()
      (6): Linear(in_features=128, out_features=128, bias=True)
      (7): Dropout(p=0.3, inplace=False)
      (8): Tanh()
      (9): Linear(in_features=128, out_features=1, bias=True)
    )
    in_keys=['observation']
    out_keys=['state_value'].
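
For context, the error comes from functorch: by default, torch.func.vmap runs in randomness="error" mode and refuses to execute random operations such as dropout. Below is a minimal standalone sketch (an editorial illustration using plain torch only, not part of the original report) that reproduces the restriction and shows the flag that lifts it; "different" draws an independent dropout mask per batch element.

import torch
from torch.func import vmap

dropout = torch.nn.Dropout(p=0.3)  # a fresh module is in training mode, so dropout is an active random op
x = torch.randn(4, 8)

try:
    vmap(dropout)(x)  # default randomness="error" -> RuntimeError, as in the traceback above
except RuntimeError as err:
    print("vmap raised:", err)

out = vmap(dropout, randomness="different")(x)  # each batch element gets its own dropout mask
print(out.shape)  # torch.Size([4, 8])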

System info

Describe the characteristics of your environment:

  • Describe how the library was installed: pip install torchrl
  • Python version: 3.8.10
  • torch version: 2.1.2+cu121
  • torchrl version: 0.2.1

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
BeFranke added the bug label Feb 20, 2024
vmoens (Contributor) commented Feb 20, 2024

Right, I guess we covered this in #1740 but not for value functions.

Happy to make a PR if @BY571 is too busy :)

BY571 (Contributor) commented Feb 21, 2024

Yes, it should be very similar to what we did for the losses; I can have a look at it, @vmoens.
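
As a rough illustration of the direction discussed here (an editorial, hypothetical sketch; the helper name and plumbing are invented and need not match the actual patch in #1942): the traceback shows advantages.py calling vmap(value_net, (0,))(data_in) with vmap's default randomness mode, so one possible shape of the fix, similar in spirit to what the losses gained in #1740, is to thread a configurable randomness setting through that call.

from torch.func import vmap

def call_value_net(value_net, data_in, vmap_randomness="different"):
    # Hypothetical helper; a non-"error" randomness mode lets random ops such as
    # dropout run under vmap, with "different" drawing an independent mask per
    # batch element.
    return vmap(value_net, (0,), randomness=vmap_randomness)(data_in)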

vmoens linked a pull request Feb 21, 2024 that will close this issue