Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Jul 29, 2024
1 parent 6edcdc1 commit b80ccb1
Show file tree
Hide file tree
Showing 7 changed files with 349 additions and 8 deletions.
11 changes: 11 additions & 0 deletions benchmarl/algorithms/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,17 @@ def get_replay_buffer(
"""
memory_size = self.experiment_config.replay_buffer_memory_size(self.on_policy)
sampling_size = self.experiment_config.train_minibatch_size(self.on_policy)
if (
self.experiment.model_config.is_rnn
or self.experiment.critic_model_config.is_rnn
):
sequence_length = -(
-self.experiment_config.collected_frames_per_batch(self.on_policy)
// self.experiment_config.n_envs_per_worker(self.on_policy)
)
memory_size = -(-memory_size // sequence_length)
sampling_size = -(-sampling_size // sequence_length)

sampler = SamplerWithoutReplacement() if self.on_policy else RandomSampler()
return TensorDictReplayBuffer(
storage=LazyTensorStorage(
Expand Down
11 changes: 11 additions & 0 deletions benchmarl/conf/model/layers/gru.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@

# Default layer configuration for the model named "gru".
name: gru

# NOTE(review): presumably the size of the GRU hidden state — confirm against GruConfig.
hidden_size: 128

# mlp_* keys: presumably configure an MLP used together with the GRU — confirm against GruConfig.
mlp_num_cells: [256, 256]
mlp_layer_class: torch.nn.Linear
mlp_activation_class: torch.nn.Tanh
mlp_activation_kwargs: null
mlp_norm_class: null
mlp_norm_kwargs: null
9 changes: 6 additions & 3 deletions benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,15 +383,17 @@ def _setup_task(self):
continuous_actions=self.continuous_actions,
seed=self.seed,
device=self.config.sampling_device,
)
),
self.task,
)()
env_func = self.model_config.process_env_fun(
self.task.get_env_fun(
num_envs=self.config.n_envs_per_worker(self.on_policy),
continuous_actions=self.continuous_actions,
seed=self.seed,
device=self.config.sampling_device,
)
),
self.task,
)

transforms_env = self.task.get_env_transforms(test_env)
Expand Down Expand Up @@ -610,7 +612,8 @@ def _collection_loop(self):
for group in self.train_group_map.keys():
group_batch = batch.exclude(*self._get_excluded_keys(group))
group_batch = self.algorithm.process_batch(group, group_batch)
group_batch = group_batch.reshape(-1)
if not (self.model_config.is_rnn or self.critic_model_config.is_rnn):
group_batch = group_batch.reshape(-1)
self.replay_buffers[group].extend(group_batch)

training_tds = []
Expand Down
4 changes: 4 additions & 0 deletions benchmarl/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .common import Model, ModelConfig, SequenceModel, SequenceModelConfig
from .deepsets import Deepsets, DeepsetsConfig
from .gnn import Gnn, GnnConfig
from .gru import Gru, GruConfig
from .mlp import Mlp, MlpConfig

classes = [
Expand All @@ -19,11 +20,14 @@
"CnnConfig",
"Deepsets",
"DeepsetsConfig",
"Gru",
"GruConfig",
]

# Maps each model name (string key) to the configuration class that handles it.
model_config_registry = {
"mlp": MlpConfig,
"gnn": GnnConfig,
"cnn": CnnConfig,
"deepsets": DeepsetsConfig,
"gru": GruConfig,
}
30 changes: 26 additions & 4 deletions benchmarl/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,17 +298,27 @@ def associated_class():
"""
raise NotImplementedError

def process_env_fun(self, env_fun: Callable[[], EnvBase]) -> Callable[[], EnvBase]:
def process_env_fun(
    self,
    env_fun: Callable[[], EnvBase],
    task,
    model_index: int = 0,
) -> Callable[[], EnvBase]:
    """
    Hook that lets a model configuration wrap the environment constructor.

    The base implementation is a no-op: it returns ``env_fun`` unchanged.

    Args:
        env_fun (callable): a function that takes no args and creates an environment
        task (Task): the task
        model_index (int, optional): index of this model when it is one of
            several models processed in sequence. Unused by the base
            implementation. Defaults to ``0``.

    Returns: a function that takes no args and creates an environment

    """
    return env_fun

@property
def is_rnn(self) -> bool:
    """Whether this model is an RNN; the base configuration always reports ``False``."""
    return False

@staticmethod
def _load_from_yaml(name: str) -> Dict[str, Any]:
yaml_path = (
Expand Down Expand Up @@ -451,11 +461,23 @@ def get_model(
def associated_class():
return SequenceModel

def process_env_fun(self, env_fun: Callable[[], EnvBase]) -> Callable[[], EnvBase]:
for model_config in self.model_configs:
env_fun = model_config.process_env_fun(env_fun)
def process_env_fun(
    self,
    env_fun: Callable[[], EnvBase],
    task,
    model_index: int = 0,
) -> Callable[[], EnvBase]:
    """
    Wrap ``env_fun`` with every model configuration in ``self.model_configs``, in order.

    Each configuration receives the function returned by the previous one,
    together with ``task`` and its own position in ``self.model_configs``.

    Args:
        env_fun (callable): a function that takes no args and creates an environment
        task (Task): the task
        model_index (int, optional): accepted for signature compatibility;
            it is not read here, each sub-configuration is given its own
            index instead. Defaults to ``0``.

    Returns: a function that takes no args and creates an environment

    """
    for i, model_config in enumerate(self.model_configs):
        env_fun = model_config.process_env_fun(env_fun, task, i)
    return env_fun

@property
def is_rnn(self) -> bool:
    """
    Whether at least one model configuration in ``self.model_configs`` is an RNN.

    Returns: ``True`` if any sub-configuration reports ``is_rnn``, ``False``
        otherwise (including when there are no sub-configurations)

    """
    # ``any`` yields a real bool; summing the flags would return an int count
    # and contradict the annotated return type.
    return any(model_config.is_rnn for model_config in self.model_configs)

@classmethod
def get_from_yaml(cls, path: Optional[str] = None):
raise NotImplementedError
Loading

0 comments on commit b80ccb1

Please sign in to comment.