-
Notifications
You must be signed in to change notification settings - Fork 328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BUG] `SACLoss` module does not allow stochastic modules (i.e. `Dropout`, etc.) due to `vmap`
#2313
Comments
|
Hmm, that's surprising. Were it truly recursive, I'd expect I've got something running atm, so I can't provide proof of this just yet, but 1. was what I observed when stepping. |
I fixed a couple more things, but I can't try your example because I'm (as always) having problems with the PettingZoo dependencies. |
FWIW, I was able to replicate this issue on For a little extra information, I patched @property
def vmap_randomness(self):
modules = []
if self._vmap_randomness is None:
do_break = False
for val in self.__dict__.values():
if isinstance(val, torch.nn.Module):
for module in val.modules():
modules.append(str(type(module)))
if isinstance(module, RANDOM_MODULE_LIST):
self._vmap_randomness = "different"
do_break = True
break
if do_break:
# double break
break
else:
self._vmap_randomness = "error"
print(','.join(modules))
return self._vmap_randomness This is the script proper: # BEFORE THE PROGRAM EVEN RUNS, FORCE THE START METHOD TO BE 'SPAWN'
from torch import multiprocessing as mp
mp.set_start_method("spawn", force = True)
from copy import deepcopy
import tqdm
import numpy as np
import math
import torch
from torch import nn
import torch.distributions as D
from torchrl.envs import check_env_specs, VmasEnv, ParallelEnv
from torchrl.modules import ProbabilisticActor
from torchrl.modules.models.multiagent import MultiAgentNetBase
from torchrl.collectors import SyncDataCollector
from torchrl.objectives import SACLoss, ValueEstimators
from torchrl.data.replay_buffers import TensorDictPrioritizedReplayBuffer
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from tensordict.nn import TensorDictModule, TensorDictSequential, NormalParamExtractor
from torchrl.envs import TransformedEnv, Compose, RewardSum, StepCounter
EPS = 1e-7
class SMACCNet(MultiAgentNetBase):
    """Multi-agent MLP whose per-agent network contains a dropout layer.

    Every network produced by ``_build_single_net`` has the layout
    ``Linear(in, 400) -> act -> Linear(400, 300) -> act -> Dropout(0.5)
    -> Linear(300, n_agent_outputs)``.

    Args:
        n_agent_inputs: feature size of a single agent's input, or ``None``.
        n_agent_outputs: feature size of a single agent's output.
        n_agents: number of agents; inputs must have this size on dim ``-2``.
        centralised: when true, the last two input dims are flattened
            together and the first layer takes
            ``n_agent_inputs * n_agents`` features.
        share_params: forwarded unchanged to the base class.
        device: device the built networks are moved to.
        activation_class: constructor of the activation placed after each
            hidden linear layer.
        **kwargs: forwarded unchanged to the base class constructor.
    """

    def __init__(
        self,
        n_agent_inputs: int | None,
        n_agent_outputs: int,
        n_agents: int,
        centralised: bool,
        share_params: bool,
        device="cpu",
        activation_class=nn.Tanh,
        **kwargs,
    ):
        # These attributes are assigned before the base constructor runs.
        # NOTE(review): presumably the base constructor calls
        # _build_single_net, which reads them - confirm against torchrl.
        self.device = device
        self.activation_class = activation_class
        self.centralised = centralised
        self.share_params = share_params
        self.n_agent_outputs = n_agent_outputs
        self.n_agent_inputs = n_agent_inputs
        self.n_agents = n_agents
        super().__init__(
            n_agents=n_agents,
            centralised=centralised,
            share_params=share_params,
            agent_dim=-2,
            device=device,
            **kwargs,
        )

    def _pre_forward_check(self, inputs):
        """Validate the agent dimension and flatten it when centralised.

        Raises:
            ValueError: if ``inputs.shape[-2]`` differs from ``n_agents``.
        """
        if inputs.shape[-2] != self.n_agents:
            raise ValueError(
                f"Multi-agent network expected input with shape[-2]={self.n_agents},"
                f" but got {inputs.shape}"
            )
        return inputs.flatten(-2, -1) if self.centralised else inputs

    def init_net_params(self, net):
        """Initialise every ``nn.Linear`` inside ``net`` in place and return ``net``.

        Weights use Xavier-uniform initialisation; the gain is 1/100 for
        layers whose ``out_features`` equals ``n_agent_outputs`` and 1
        otherwise. Biases, when present, are zeroed.
        """

        def _init_linear(module):
            if not isinstance(module, nn.Linear):
                return
            divisor = 100 if module.out_features == self.n_agent_outputs else 1
            torch.nn.init.xavier_uniform_(module.weight, gain=1.0 / divisor)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

        net.apply(_init_linear)
        return net

    def _build_single_net(self, *, device, **kwargs):
        """Build, move to ``self.device`` and initialise one agent network.

        The ``device`` keyword is accepted but the network is placed on
        ``self.device``.
        """
        in_features = self.n_agent_inputs
        if self.centralised and in_features is not None:
            in_features = in_features * self.n_agents
        layers = [
            nn.Linear(in_features, 400),
            self.activation_class(),
            nn.Linear(400, 300),
            self.activation_class(),
            nn.Dropout(0.5),  # <- The dropout is here!
            nn.Linear(300, self.n_agent_outputs),
        ]
        network = nn.Sequential(*layers).to(self.device)
        return self.init_net_params(network)
class CustomTanhTransform(D.transforms.TanhTransform):
    """Tanh transform with a clamped inverse and a softplus-based log-Jacobian."""

    def _inverse(self, y):
        """Return ``atanh(y)`` after clamping ``y`` into ``[-1 + EPS, 1 - EPS]``.

        Borrowed from SB3, which took it from Pyro
        (https://github.com/pyro-ppl/pyro):
        ``atanh(y) = 0.5 * log((1 + y) / (1 - y))``.
        """
        clipped = y.clamp(-1.0 + EPS, 1.0 - EPS)
        log_one_plus = clipped.log1p()
        log_one_minus = (-clipped).log1p()
        return 0.5 * (log_one_plus - log_one_minus)

    def log_abs_det_jacobian(self, x, y):
        """Return ``log(1 - tanh(x)^2)`` computed via softplus.

        Same formulation as PyTorch's ``TanhTransform``:

            log(1 - tanh^2(x)) = log(sech^2(x))
                               = 2 log(2 / (e^x + e^(-x)))
                               = 2 (log 2 - log(e^x (1 + e^(-2x))))
                               = 2 (log 2 - x - log(1 + e^(-2x)))
                               = 2 (log 2 - x - softplus(-2x))

        ``y`` is unused; only ``x`` enters the formula.
        """
        softplus_term = nn.functional.softplus(-2.0 * x)
        return 2.0 * (math.log(2.0) - x - softplus_term)
class TanhNormalStable(D.TransformedDistribution):
    """Numerically stable variant of TanhNormal. Employs clipping trick.

    The distribution is an ``Independent(Normal(loc, scale), event_dims)``
    pushed through a single ``CustomTanhTransform``. ``log_prob`` results
    are floored at ``log(EPS)``.

    Args:
        loc: location tensor of the underlying normal.
        scale: scale tensor of the underlying normal.
        event_dims: number of rightmost dims treated as event dims.
    """

    def __init__(self, loc, scale, event_dims=1):
        self._event_dims = event_dims
        self._t = [CustomTanhTransform()]
        # update() performs the actual TransformedDistribution construction.
        self.update(loc, scale)

    def log_prob(self, value):
        """
        Scores the sample by inverting the transform(s) and computing the score
        using the score of the base distribution and the log abs det jacobian.

        The result is clamped from below at ``log(EPS)``.
        """
        if self._validate_args:
            self._validate_sample(value)
        event_dim = len(self.event_shape)
        log_prob = 0.0
        y = value
        for transform in reversed(self.transforms):
            x = transform.inv(y)
            event_dim += transform.domain.event_dim - transform.codomain.event_dim
            log_prob = log_prob - D.utils._sum_rightmost(
                transform.log_abs_det_jacobian(x, y),
                event_dim - transform.domain.event_dim,
            )
            y = x
        log_prob = log_prob + D.utils._sum_rightmost(
            self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
        )
        # The accumulated value is a natural-log density (the base log_prob
        # and the Jacobian term both use natural log), so the floor must be
        # expressed with math.log as well. The previous math.log10(EPS)
        # mixed logarithm bases and floored at -7 instead of log(EPS).
        log_prob = torch.clamp(log_prob, min=math.log(EPS))
        return log_prob

    def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
        """Set new ``loc``/``scale``, reusing the base distribution if shapes match.

        When a base distribution already exists and both shapes are
        unchanged, its ``loc`` and ``scale`` are overwritten in place;
        otherwise the transformed distribution is (re)constructed.
        """
        self.loc = loc
        self.scale = scale
        if (
            hasattr(self, "base_dist")
            and (self.base_dist.base_dist.loc.shape == self.loc.shape)
            and (self.base_dist.base_dist.scale.shape == self.scale.shape)
        ):
            self.base_dist.base_dist.loc = self.loc
            self.base_dist.base_dist.scale = self.scale
        else:
            base = D.Independent(D.Normal(self.loc, self.scale), self._event_dims)
            super().__init__(base, self._t)

    @property
    def mode(self):
        """Mean of the underlying normal passed through every transform."""
        m = self.base_dist.base_dist.mean
        for t in self.transforms:
            m = t(m)
        return m
# Main Function
# Script entry point. Everything that follows belongs to this guard; the
# scraped listing has lost its indentation, so nesting is not visible here.
if __name__ == "__main__":
# Run configuration.
# NOTE(review): NUM_CRITICS, EXPLORATION_STEPS, REPLAY_BUFFER_SIZE,
# VALUE_GAMMA, MAX_GRAD_NORM, LR, UPDATE_STEPS_PER_EXPLORATION and
# WARMUP_STEPS are not referenced anywhere else in the visible script.
NUM_AGENTS = 3
NUM_CRITICS = 2
NUM_EXPLORE_WORKERS = 1
EXPLORATION_STEPS = 256
MAX_EPISODE_STEPS = 1000
DEVICE = "cuda:0"
REPLAY_BUFFER_SIZE = int(1e6)
VALUE_GAMMA = 0.99
MAX_GRAD_NORM = 1.0
BATCH_SIZE = 256
LR = 3e-4
UPDATE_STEPS_PER_EXPLORATION = 1
WARMUP_STEPS = 0
TRAIN_TIMESTEPS = int(1e7)
SEED = 42
# Seed both the torch and numpy global RNGs with the same value.
torch.manual_seed(SEED)
np.random.seed(SEED)
# Returns a zero-argument factory (env_with_transforms) that builds a VMAS
# "navigation" environment wrapped in a TransformedEnv with StepCounter and
# RewardSum transforms. Nothing is instantiated until the factory is called.
def env_fn():
# Innermost factory: the bare VmasEnv on CPU with NUM_EXPLORE_WORKERS
# vectorised envs and NUM_AGENTS agents.
def base_env_fn():
return VmasEnv(
scenario="navigation",
num_envs=NUM_EXPLORE_WORKERS,
continuous_actions=True,
max_steps=200,
device="cpu",
seed=None,
# Scenario kwargs
n_agents=NUM_AGENTS,
)
env = base_env_fn # noqa: E731
# Builds the base env, then wraps it with the transforms.
def env_with_transforms():
init_env = env()
init_env = TransformedEnv(init_env, Compose(
StepCounter(max_steps = MAX_EPISODE_STEPS),
RewardSum(
# The same reward key / output key / reset key is repeated NUM_AGENTS times.
in_keys = [init_env.reward_key for _ in range(NUM_AGENTS)],
out_keys = [("agents", "episode_reward")] * NUM_AGENTS,
reset_keys = ["_reset"] * NUM_AGENTS
),
)
)
return init_env
return env_with_transforms
# Build one environment instance in this process; it is used for spec checks
# and to read observation/action sizes.
train_env = env_fn()()
if train_env.is_closed:
train_env.start()
check_env_specs(train_env)
# print(train_env.full_observation_spec)
# print(train_env.full_action_spec)
print(train_env.done_spec)
# breakpoint()
# Per-agent feature sizes are taken from the last dim of the env specs.
obs_dim = train_env.full_observation_spec["agents", "observation"].shape[-1]
action_dim = train_env.full_action_spec["agents", "action"].shape[-1]
# Policy: decentralised SMACCNet with 2 * action_dim outputs followed by
# NormalParamExtractor (presumably splitting them into loc and scale, which
# are the out_keys used below).
policy_net = nn.Sequential(
SMACCNet(n_agent_inputs = obs_dim,
n_agent_outputs = 2 * action_dim,
n_agents = NUM_AGENTS,
centralised = False,
share_params = True,
device = DEVICE,
activation_class = nn.LeakyReLU,
),
NormalParamExtractor(),
).to(DEVICE)
# Critic: centralised SMACCNet over concatenated observation+action features,
# producing a single value per agent.
critic_net = SMACCNet(n_agent_inputs = obs_dim + action_dim,
n_agent_outputs = 1,
n_agents = NUM_AGENTS,
centralised = True,
share_params = True,
device = DEVICE,
activation_class = nn.LeakyReLU,
).to(DEVICE)
policy_net_td_module = TensorDictModule(module = policy_net,
in_keys = [("agents", "observation")],
out_keys = [("agents", "loc"), ("agents", "scale")]
)
# Concatenates observation and action along the last dim to form the critic input.
obs_act_module = TensorDictModule(lambda obs, act: torch.cat([obs, act], dim = -1),
in_keys = [("agents", "observation"), ("agents", "action")],
out_keys = [("agents", "obs_act")]
)
critic_net_td_module = TensorDictModule(module = critic_net,
in_keys = [("agents", "obs_act")],
out_keys = [("agents", "state_action_value")]
)
# Attach our raw policy network to a probabilistic actor
policy_actor = ProbabilisticActor(
module = policy_net_td_module,
spec = train_env.full_action_spec["agents", "action"],
in_keys = [("agents", "loc"), ("agents", "scale")],
out_keys = [("agents", "action")],
distribution_class = TanhNormalStable,
return_log_prob = True,
)
# with torch.no_grad():
# fake_td = train_env.fake_tensordict()
# policy_actor(fake_td)
# Critic pipeline: build the obs_act entry first, then evaluate the critic on it.
critic_actor = TensorDictSequential(
obs_act_module, critic_net_td_module
)
# with torch.no_grad():
# reset_obs = train_env.reset()
# reset_obs_clean = deepcopy(reset_obs)
# action = policy_actor(reset_obs)
# state_action_value = critic_actor(action)
# reset_obs = train_env.reset()
# reset_obs["agents", "action"] = torch.zeros((*reset_obs["agents", "observation"].shape[:-1], action_dim))
# train_env.rand_action(reset_obs)
# action = train_env.step(reset_obs)
# Data collector: runs policy_actor in a ParallelEnv of NUM_EXPLORE_WORKERS
# spawned workers, yielding BATCH_SIZE frames per batch up to TRAIN_TIMESTEPS
# frames in total.
# NOTE(review): the policy networks were moved to DEVICE above while
# policy_device is 'cpu' here - confirm how the collector handles this.
collector = SyncDataCollector(
ParallelEnv(NUM_EXPLORE_WORKERS,
[
env_fn()
for _ in range(NUM_EXPLORE_WORKERS)
],
device = None,
mp_start_method = "spawn"
),
policy = policy_actor,
frames_per_batch = BATCH_SIZE,
max_frames_per_traj = -1,
total_frames = TRAIN_TIMESTEPS,
device = 'cpu',
policy_device = 'cpu',
reset_at_each_iter = False
)
# Dummy loss module
# Prioritized replay buffer backed by memory-mapped storage on CPU.
# NOTE(review): the storage size is the literal 1e5, not REPLAY_BUFFER_SIZE.
replay_buffer = TensorDictPrioritizedReplayBuffer(
alpha = 0.7,
beta = 0.9,
storage = LazyMemmapStorage(
1e5,
device = 'cpu',
scratch_dir = "googoogaagaa/"
),
priority_key = "td_error",
batch_size = BATCH_SIZE,
)
# SAC loss over the actor and the critic pipeline.
# NOTE(review): num_qvalue_nets is the literal 2, not NUM_CRITICS.
sac_loss = SACLoss(actor_network = policy_actor,
qvalue_network = critic_actor,
num_qvalue_nets = 2,
loss_function = "l2",
delay_actor = False,
delay_qvalue = True,
alpha_init = 0.1,
)
# Point the loss at the nested per-agent ("agents", ...) keys.
sac_loss.set_keys(
action = ("agents", "action"),
state_action_value = ("agents", "state_action_value"),
reward = ("agents", "reward"),
done = ("agents", "done"),
terminated = ("agents", "terminated"),
)
# NOTE(review): gamma is the literal 0.99, not VALUE_GAMMA.
sac_loss.make_value_estimator(
value_type = ValueEstimators.TD0,
gamma = 0.99,
)
# Compiling replay_buffer.sample works :D
# rb_sample: draw one batch from replay_buffer and return it on DEVICE. The
# batch is moved (blocking) when its device differs from DEVICE, otherwise a
# clone is returned.
@torch.compile(mode = "reduce-overhead")
def rb_sample():
td_sample = replay_buffer.sample()
if td_sample.device != torch.device(DEVICE):
td_sample = td_sample.to(
DEVICE,
non_blocking = False
)
else:
td_sample = td_sample.clone()
return td_sample
# test_compile: sample a batch via rb_sample and return the SACLoss output for it.
def test_compile():
td_sample = rb_sample()
return sac_loss(td_sample)
# Running count of frames pushed into the replay buffer.
samples = 0
# Iterate over batches produced by the collector.
# NOTE(review): the progress-bar total is TRAIN_TIMESTEPS (frames) while each
# iteration yields a batch of frames_per_batch frames - confirm this is intended.
# NOTE(review): indentation was lost in the scraped listing, so the exact
# extent of the loop body cannot be determined from here.
for i, tensordict in (pbar := tqdm.tqdm(enumerate(collector), total = TRAIN_TIMESTEPS)):
# Copy the env-level "done" flag into a per-agent entry, expanded to the
# shape of the per-agent reward.
tensordict.set(
("next", "agents", "done"),
tensordict.get(("next", "done"))
.unsqueeze(-1)
.expand(tensordict.get_item_shape(("next", "agents", "reward"))),
)
# Same for the env-level "terminated" flag.
tensordict.set(
("next", "agents", "terminated"),
tensordict.get(("next", "terminated"))
.unsqueeze(-1)
.expand(tensordict.get_item_shape(("next", "agents", "reward"))),
)
# Flatten the batch dims, count the frames and store them on CPU.
tensordict = tensordict.reshape(-1)
samples += tensordict.numel()
replay_buffer.extend(tensordict.to('cpu', non_blocking = True))
pbar.write("Hey Hey!!! :D")
# Run the SACLoss forward pass on a sampled batch and print the result.
a = test_compile()
print(a)
collector.shutdown()
train_env.close() Running the script now yields <class 'torchrl.modules.tensordict_module.actors.ProbabilisticActor'>,<class 'torch.nn.modules.container.ModuleList'>,<class 'tensordict.nn.common.TensorDictModule'>,<class 'torch.nn.modules.container.Sequential'>,<class '__main__.SMACCNet'>,<class 'tensordict.nn.params.TensorDictParams'>,<class 'tensordict.nn.distributions.continuous.NormalParamExtractor'>,<class 'torchrl.modules.tensordict_module.probabilistic.SafeProbabilisticModule'>,<class 'tensordict.nn.sequence.TensorDictSequential'>,<class 'torch.nn.modules.container.ModuleList'>,<class 'tensordict.nn.common.TensorDictModule'>,<class 'tensordict.nn.common.TensorDictModule'>,<class '__main__.SMACCNet'>,<class 'tensordict.nn.params.TensorDictParams'>
<...omitted for brevity...>
...
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/dropout.py", line 59, in forward
return F.dropout(input, self.p, self.training, self.inplace)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/functional.py", line 1295, in dropout
return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: vmap: called random operation while in randomness error mode. Please either use the 'same' or 'different' randomness flags on vmap or perform the randomness operation out of vmap Seems that the list of modules being checked does not go deep enough because |
Got it. Your problem here is that the dropout is hidden by the MARL model, which does not register its inner module in the usual way. This should be fairly easy to fix. |
Describe the bug
`SACLoss` has flawed checks for determining the nature of `vmap_randomness`. Therefore, stochastic modules cannot be used in constituent networks.
To Reproduce
Steps to reproduce the behavior.
Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.
Please use the markdown code blocks for both code and stack traces.
Expected behavior
`SACLoss` module performs forward passes successfully.
System info
Describe the characteristic of your environment:
Reason and Possible fixes
There are essentially two reasons for this error:
RANDOM_MODULE_LIST
LossModule.set_vmap_randomness
as `self.vmap_randomness` is accessed immediately during initialization.
Checklist
The text was updated successfully, but these errors were encountered: