Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Jul 29, 2024
1 parent b80ccb1 commit 44a1e8a
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 32 deletions.
1 change: 1 addition & 0 deletions benchmarl/conf/model/layers/gru.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
name: gru

hidden_size: 128
compile: True

mlp_num_cells: [256, 256]
mlp_layer_class: torch.nn.Linear
Expand Down
8 changes: 8 additions & 0 deletions benchmarl/conf/task/vmas/repeat_last.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
defaults:
- vmas_repeat_last_config
- _self_

max_steps: 100
horizon: 100
n_agents: 4
k: 1
1 change: 1 addition & 0 deletions benchmarl/environments/vmas/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class VmasTask(Task):
SIMPLE_SPREAD = None
SIMPLE_TAG = None
SIMPLE_WORLD_COMM = None
REPEAT_LAST = None

def get_env_fun(
self,
Expand Down
15 changes: 15 additions & 0 deletions benchmarl/environments/vmas/repeat_last.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from dataclasses import dataclass, MISSING


@dataclass
class TaskConfig:
    """Typed configuration schema for a task.

    Every field defaults to ``MISSING``, i.e. no usable default is defined
    here and a value has to be supplied from outside this class.
    NOTE(review): presumably these are filled in from the task's YAML config
    and forwarded to the scenario -- confirm against the loader.
    """

    # presumably the maximum number of environment steps per episode -- TODO confirm
    max_steps: int = MISSING
    # presumably the number of agents in the scenario -- TODO confirm
    n_agents: int = MISSING
    # scenario-specific parameter; meaning is not visible here -- TODO confirm
    horizon: int = MISSING
    # scenario-specific parameter; meaning is not visible here -- TODO confirm
    k: int = MISSING
107 changes: 82 additions & 25 deletions benchmarl/models/gru.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,74 @@
from torchrl.modules import GRUCell, MLP, MultiAgentMLP

from benchmarl.models.common import Model, ModelConfig
from benchmarl.utils import DEVICE_TYPING


class GRU(torch.nn.Module):
    """Unrolls a single ``GRUCell`` over the time dimension of a sequence.

    At every time step the hidden state is zeroed wherever the matching
    ``is_init`` entry is true, and only then fed to the cell together with
    that step's input.

    Args:
        input_size: feature size handed to the ``GRUCell``.
        hidden_size: hidden-state size handed to the ``GRUCell``.
        device: device on which the cell is created.
        time_dim: dimension of ``input`` / ``is_init`` that indexes time
            (defaults to ``-2``).
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        device: DEVICE_TYPING,
        time_dim: int = -2,
    ):
        super().__init__()
        self.time_dim = time_dim
        self.device = device
        self.hidden_size = hidden_size
        self.input_size = input_size

        self.gru = GRUCell(input_size, hidden_size, device=device)

    def forward(
        self,
        input,
        is_init,
        h,
    ):
        """Run the cell over every time step.

        Args:
            input: sequence of inputs; sliced along ``time_dim``.
            is_init: boolean mask sliced along ``time_dim``; must broadcast
                against ``h`` once the time dimension is removed
                (assumed -- TODO confirm against callers).
            h: hidden state used before the first step.

        Returns:
            A tuple ``(output, h_n)``: the per-step hidden states stacked
            back along ``time_dim``, and the hidden state after the last step.
        """
        steps = zip(input.unbind(self.time_dim), is_init.unbind(self.time_dim))
        hidden_states = []
        for step_input, step_reset in steps:
            # Restart the recurrence where an episode begins at this step.
            h = self.gru(step_input, torch.where(step_reset, 0, h))
            hidden_states.append(h)

        return torch.stack(hidden_states, self.time_dim), h


class MultiAgentGRU(torch.nn.Module):
def __init__(
    self,
    input_size: int,
    hidden_size: int,
    n_agents: int,
    device: DEVICE_TYPING,
    compile: bool,
    centralised: bool,
):
    """Build the recurrent core that ``forward`` invokes as ``self.vmap_gru``.

    Args:
        input_size: number of input features of a single agent.
        hidden_size: size of the GRU hidden state.
        n_agents: number of agents.
        device: device on which the GRU parameters are created.
        compile: when true, the GRU is wrapped with ``torch.compile``
            (``mode="reduce-overhead"``, ``fullgraph=True``).
        centralised: when true, a single GRU consumes the features of all
            agents concatenated (``n_agents * input_size``); otherwise the
            GRU is vmapped over dimension ``-2`` of its inputs and outputs.
    """
    super().__init__()
    # Keep the per-agent feature size here: ``forward`` multiplies it by
    # ``n_agents`` itself when flattening the centralised input.
    self.input_size = input_size
    self.n_agents = n_agents
    self.hidden_size = hidden_size
    self.device = device
    self.compile = compile
    self.centralised = centralised

    gru_input_size = input_size * n_agents if centralised else input_size

    # ``self.gru`` stays the plain module so its parameters remain registered.
    self.gru = GRU(
        gru_input_size,
        hidden_size,
        device=self.device,
    )

    # Build the callable under a separate name: ``torch.vmap`` returns a
    # plain function, which cannot be assigned over the registered child
    # module ``gru`` (``nn.Module.__setattr__`` raises ``TypeError``).
    rnn = self.gru
    if self.compile:
        # NOTE(review): the compiled wrapper is itself a module, so in
        # centralised mode it is registered as a second child sharing the
        # same parameters -- confirm this is acceptable for checkpoints.
        rnn = torch.compile(rnn, mode="reduce-overhead", fullgraph=True)
    if not self.centralised:
        rnn = torch.vmap(rnn, in_dims=-2, out_dims=-2)
    self.vmap_gru = rnn

def forward(
self,
Expand Down Expand Up @@ -81,33 +133,29 @@ def forward(
device=self.device,
dtype=torch.float,
)
output = self.vmap_rnn(input, is_init, h_0)
h_n = output[..., -1, :, :]
if self.centralised:
input = input.view(batch, seq, self.n_agents * self.input_size)
h_0 = h_0[..., 0, :]
is_init = is_init.view(batch, seq, self.n_agents)

output, h_n = self.vmap_gru(input, is_init, h_0)

if self.centralised:
output = output.unsqueeze(-2).expand(
batch, seq, self.n_agents, self.hidden_size
)
h_n = h_n.unsqueeze(-2).expand(batch, self.n_agents, self.hidden_size)

if not training:
output = output.squeeze(1)
return output, h_n

# @torch.compile(mode="reduce-overhead", fullgraph=True)

@staticmethod
def get_for_loop(rnn):
def for_loop(input, is_init, h, time_dim=-3):
hs = []
for in_t, init_t in zip(input.unbind(time_dim), is_init.unbind(time_dim)):
h = torch.where(init_t, 0, h)
h = rnn(in_t, h)
hs.append(h)
output = torch.stack(hs, time_dim)
return output

return torch.vmap(for_loop)


class Gru(Model):
def __init__(
self,
hidden_size: int,
compile: bool,
**kwargs,
):

Expand All @@ -124,6 +172,7 @@ def __init__(
)

self.hidden_size = hidden_size
self.compile = compile

self.input_features = sum(
[spec.shape[-1] for spec in self.input_spec.values(True, True)]
Expand All @@ -132,7 +181,12 @@ def __init__(

if self.input_has_agent_dim:
self.gru = MultiAgentGRU(
self.input_features, self.hidden_size, self.n_agents, self.device
self.input_features,
self.hidden_size,
self.n_agents,
self.device,
centralised=self.centralised,
compile=self.compile,
)

mlp_net_kwargs = {
Expand Down Expand Up @@ -204,8 +258,10 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
is_init = tensordict.get("is_init")

# Has multi-agent input dimension
if self.input_has_agent_dim and self.share_params and not self.centralised:
if self.input_has_agent_dim:
output, h_n = self.gru(input, is_init, h_0)
if not self.output_has_agent_dim:
output = output[..., 0, :]
else:
pass

Expand All @@ -232,6 +288,7 @@ class GruConfig(ModelConfig):
"""Dataclass config for a :class:`~benchmarl.models.Gru`."""

hidden_size: int = MISSING
compile: bool = MISSING

mlp_num_cells: Sequence[int] = MISSING
mlp_layer_class: Type[nn.Module] = MISSING
Expand Down
14 changes: 7 additions & 7 deletions fine_tuned/vmas/conf/config.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
defaults:
- experiment: base_experiment
- algorithm: ???
- task: ???
- algorithm: mappo
- task: vmas/repeat_last
- model: layers/mlp
- model@critic_model: layers/mlp
- model@critic_model: layers/gru
- _self_

hydra:
Expand All @@ -15,9 +15,9 @@ seed: 0

experiment:

sampling_device: "cuda"
train_device: "cuda"
buffer_device: "cuda"
sampling_device: "cpu"
train_device: "cpu"
buffer_device: "cpu"

share_policy_params: True
prefer_continuous_actions: True
Expand Down Expand Up @@ -54,7 +54,7 @@ experiment:
evaluation_interval: 120_000
evaluation_episodes: 200

loggers: [wandb]
loggers: []
create_json: True

save_folder: null
Expand Down

0 comments on commit 44a1e8a

Please sign in to comment.