[Feature] Allow separate config for critic model #17

Merged (3 commits) on Sep 26, 2023
4 changes: 4 additions & 0 deletions benchmarl/algorithms/common.py
@@ -23,6 +23,7 @@ def __init__(
         self,
         experiment_config: "DictConfig",  # noqa: F821
         model_config: ModelConfig,
+        critic_model_config: ModelConfig,
         observation_spec: CompositeSpec,
         action_spec: CompositeSpec,
         state_spec: Optional[CompositeSpec],
@@ -34,6 +35,7 @@ def __init__(
 
         self.experiment_config = experiment_config
         self.model_config = model_config
+        self.critic_model_config = critic_model_config
         self.on_policy = on_policy
         self.group_map = group_map
         self.observation_spec = observation_spec
@@ -225,6 +227,7 @@ def get_algorithm(
         self,
         experiment_config,
         model_config: ModelConfig,
+        critic_model_config: ModelConfig,
         observation_spec: CompositeSpec,
         action_spec: CompositeSpec,
         state_spec: CompositeSpec,
@@ -235,6 +238,7 @@
             **self.__dict__,
             experiment_config=experiment_config,
             model_config=model_config,
+            critic_model_config=critic_model_config,
             observation_spec=observation_spec,
             action_spec=action_spec,
             state_spec=state_spec,
2 changes: 1 addition & 1 deletion benchmarl/algorithms/iddpg.py
@@ -211,7 +211,7 @@ def get_value_module(self, group: str) -> TensorDictModule:
         )
 
         modules.append(
-            self.model_config.get_model(
+            self.critic_model_config.get_model(
                 input_spec=critic_input_spec,
                 output_spec=critic_output_spec,
                 n_agents=n_agents,
2 changes: 1 addition & 1 deletion benchmarl/algorithms/ippo.py
@@ -259,7 +259,7 @@ def get_critic(self, group: str) -> TensorDictModule:
                 )
             }
         )
-        value_module = self.model_config.get_model(
+        value_module = self.critic_model_config.get_model(
             input_spec=critic_input_spec,
             output_spec=critic_output_spec,
             n_agents=n_agents,
4 changes: 2 additions & 2 deletions benchmarl/algorithms/isac.py
@@ -281,7 +281,7 @@ def get_discrete_value_module(self, group: str) -> TensorDictModule:
                 )
             }
         )
-        value_module = self.model_config.get_model(
+        value_module = self.critic_model_config.get_model(
             input_spec=critic_input_spec,
             output_spec=critic_output_spec,
             n_agents=n_agents,
@@ -336,7 +336,7 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule:
         )
 
         modules.append(
-            self.model_config.get_model(
+            self.critic_model_config.get_model(
                 input_spec=critic_input_spec,
                 output_spec=critic_output_spec,
                 n_agents=n_agents,
4 changes: 2 additions & 2 deletions benchmarl/algorithms/maddpg.py
@@ -216,7 +216,7 @@ def get_value_module(self, group: str) -> TensorDictModule:
         )
 
         modules.append(
-            self.model_config.get_model(
+            self.critic_model_config.get_model(
                 input_spec=critic_input_spec,
                 output_spec=critic_output_spec,
                 n_agents=n_agents,
@@ -257,7 +257,7 @@ def get_value_module(self, group: str) -> TensorDictModule:
         )
 
         modules.append(
-            self.model_config.get_model(
+            self.critic_model_config.get_model(
                 input_spec=critic_input_spec,
                 output_spec=critic_output_spec,
                 n_agents=n_agents,
4 changes: 2 additions & 2 deletions benchmarl/algorithms/mappo.py
@@ -257,7 +257,7 @@ def get_critic(self, group: str) -> TensorDictModule:
         )
 
         if self.state_spec is not None:
-            value_module = self.model_config.get_model(
+            value_module = self.critic_model_config.get_model(
                 input_spec=self.state_spec,
                 output_spec=critic_output_spec,
                 n_agents=n_agents,
@@ -281,7 +281,7 @@ def get_critic(self, group: str) -> TensorDictModule:
                     )
                 }
             )
-            value_module = self.model_config.get_model(
+            value_module = self.critic_model_config.get_model(
                 input_spec=critic_input_spec,
                 output_spec=critic_output_spec,
                 n_agents=n_agents,
8 changes: 4 additions & 4 deletions benchmarl/algorithms/masac.py
@@ -275,7 +275,7 @@ def get_discrete_value_module(self, group: str) -> TensorDictModule:
         )
 
         if self.state_spec is not None:
-            value_module = self.model_config.get_model(
+            value_module = self.critic_model_config.get_model(
                 input_spec=self.state_spec,
                 output_spec=critic_output_spec,
                 n_agents=n_agents,
@@ -299,7 +299,7 @@ def get_discrete_value_module(self, group: str) -> TensorDictModule:
                     )
                 }
             )
-            value_module = self.model_config.get_model(
+            value_module = self.critic_model_config.get_model(
                 input_spec=critic_input_spec,
                 output_spec=critic_output_spec,
                 n_agents=n_agents,
@@ -365,7 +365,7 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule:
         )
 
         modules.append(
-            self.model_config.get_model(
+            self.critic_model_config.get_model(
                 input_spec=critic_input_spec,
                 output_spec=critic_output_spec,
                 n_agents=n_agents,
@@ -406,7 +406,7 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule:
         )
 
         modules.append(
-            self.model_config.get_model(
+            self.critic_model_config.get_model(
                 input_spec=critic_input_spec,
                 output_spec=critic_output_spec,
                 n_agents=n_agents,
7 changes: 6 additions & 1 deletion benchmarl/benchmark.py
@@ -1,4 +1,4 @@
-from typing import Iterator, Sequence, Set
+from typing import Iterator, Optional, Sequence, Set
 
 from benchmarl.algorithms.common import AlgorithmConfig
 from benchmarl.environments import Task
@@ -14,12 +14,16 @@ def __init__(
         tasks: Sequence[Task],
         seeds: Set[int],
         experiment_config: ExperimentConfig,
+        critic_model_config: Optional[ModelConfig] = None,
     ):
         self.algorithm_configs = algorithm_configs
         self.tasks = tasks
         self.seeds = seeds
 
         self.model_config = model_config
+        self.critic_model_config = (
+            critic_model_config if critic_model_config is not None else model_config
+        )
         self.experiment_config = experiment_config
 
         print(f"Created benchmark with {self.n_experiments} experiments.")
@@ -37,6 +41,7 @@ def get_experiments(self) -> Iterator[Experiment]:
                     algorithm_config=algorithm_config,
                     seed=seed,
                     model_config=self.model_config,
+                    critic_model_config=self.critic_model_config,
                     config=self.experiment_config,
                 )
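
For reference, a minimal sketch (not part of this diff) of how the new argument can be used from the Python API. The MAPPO/VMAS choices, the critic sizes, and the run() call are illustrative assumptions:

from torch import nn

from benchmarl.algorithms import MappoConfig
from benchmarl.benchmark import Benchmark
from benchmarl.environments import VmasTask
from benchmarl.experiment import ExperimentConfig
from benchmarl.models import MlpConfig

# The actor keeps the default MLP; the critic gets its own, larger MLP.
policy_model_config = MlpConfig.get_from_yaml()
critic_model_config = MlpConfig(
    num_cells=[256, 256], activation_class=nn.Tanh, layer_class=nn.Linear
)

benchmark = Benchmark(
    algorithm_configs=[MappoConfig.get_from_yaml()],
    tasks=[VmasTask.BALANCE.get_from_yaml()],
    seeds={0},
    experiment_config=ExperimentConfig.get_from_yaml(),
    model_config=policy_model_config,
    critic_model_config=critic_model_config,  # omit this to reuse model_config for the critic
)
for experiment in benchmark.get_experiments():
    experiment.run()
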
1 change: 1 addition & 0 deletions benchmarl/conf/config.yaml
@@ -3,6 +3,7 @@ defaults:
   - algorithm: ???
   - task: ???
   - model: layers/mlp
+  - model@critic_model: layers/mlp
   - _self_
 
 seed: 0
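
With this default in place, the critic model can be selected independently through the new critic_model config group. A hedged sketch using Hydra's compose API; the config_path, the layers/gnn choice, and the task/algorithm overrides are illustrative assumptions, not something this PR adds:

from hydra import compose, initialize

with initialize(version_base=None, config_path="benchmarl/conf"):
    cfg = compose(
        config_name="config",
        overrides=[
            "algorithm=mappo",
            "task=vmas/balance",
            "model=layers/mlp",  # actor model
            "model@critic_model=layers/gnn",  # critic model, chosen independently
        ],
    )
# cfg.critic_model now holds the critic's model configuration.

The same model@critic_model=... override syntax works on the command line of the Hydra entry point.
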
5 changes: 5 additions & 0 deletions benchmarl/experiment/experiment.py
@@ -127,11 +127,15 @@ def __init__(
         model_config: ModelConfig,
         seed: int,
         config: ExperimentConfig,
+        critic_model_config: Optional[ModelConfig] = None,
     ):
         self.config = config
 
         self.task = task
         self.model_config = model_config
+        self.critic_model_config = (
+            critic_model_config if critic_model_config is not None else model_config
+        )
         self.algorithm_config = algorithm_config
         self.seed = seed
 
@@ -233,6 +237,7 @@ def _setup_algorithm(self):
         self.algorithm = self.algorithm_config.get_algorithm(
             experiment_config=self.config,
             model_config=self.model_config,
+            critic_model_config=self.critic_model_config,
             observation_spec=self.observation_spec,
             action_spec=self.action_spec,
             state_spec=self.state_spec,
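
A small sketch of the fallback behaviour introduced above: when critic_model_config is omitted (or None), the critic reuses the actor's model_config unchanged. The algorithm and task used here are illustrative assumptions:

from benchmarl.algorithms import MaddpgConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models import MlpConfig

experiment = Experiment(
    task=VmasTask.BALANCE.get_from_yaml(),
    algorithm_config=MaddpgConfig.get_from_yaml(),
    model_config=MlpConfig.get_from_yaml(),
    # critic_model_config not passed: the critic falls back to the MLP above
    seed=0,
    config=ExperimentConfig.get_from_yaml(),
)
assert experiment.critic_model_config is experiment.model_config
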
2 changes: 2 additions & 0 deletions benchmarl/hydra_config.py
@@ -12,11 +12,13 @@ def load_experiment_from_hydra(cfg: DictConfig, task_name: str) -> Experiment:
     experiment_config = load_experiment_config_from_hydra(cfg.experiment)
     task_config = load_task_config_from_hydra(cfg.task, task_name)
     model_config = load_model_config_from_hydra(cfg.model)
+    critic_model_config = load_model_config_from_hydra(cfg.critic_model)
 
     return Experiment(
         task=task_config,
         algorithm_config=algorithm_config,
         model_config=model_config,
+        critic_model_config=critic_model_config,
         seed=cfg.seed,
         config=experiment_config,
     )
7 changes: 4 additions & 3 deletions test/test_models.py
@@ -33,13 +33,14 @@ def test_loading_sequence_models(model_name, intermidiate_size=10):
             "model=sequence",
             f"model/[email protected]={model_name}",
             f"model/[email protected]={model_name}",
-            f"model.intermediate_sizes={[intermidiate_size]}",
+            f"+model/[email protected]={model_name}",
+            f"model.intermediate_sizes={[intermidiate_size,intermidiate_size]}",
         ],
     )
     hydra_model_config = load_model_config_from_hydra(cfg.model)
     layer_config = model_config_registry[model_name].get_from_yaml()
     yaml_config = SequenceModelConfig(
-        model_configs=[layer_config, layer_config],
-        intermediate_sizes=[intermidiate_size],
+        model_configs=[layer_config, layer_config, layer_config],
+        intermediate_sizes=[intermidiate_size, intermidiate_size],
     )
     assert hydra_model_config == yaml_config
6 changes: 6 additions & 0 deletions test/test_vmas.py
@@ -13,6 +13,8 @@
 from benchmarl.algorithms.common import AlgorithmConfig
 from benchmarl.environments import Task, VmasTask
 from benchmarl.experiment import Experiment
+from benchmarl.models import MlpConfig
+from torch import nn
 from utils_experiment import ExperimentUtils
 
 _has_vmas = importlib.util.find_spec("vmas") is not None
@@ -78,10 +80,14 @@ def test_share_policy_params(
         mlp_sequence_config,
     ):
         experiment_config.share_policy_params = share_params
+        critic_model_config = MlpConfig(
+            num_cells=[6], activation_class=nn.Tanh, layer_class=nn.Linear
+        )
         task = task.get_from_yaml()
         experiment = Experiment(
             algorithm_config=algo_config.get_from_yaml(),
             model_config=mlp_sequence_config,
+            critic_model_config=critic_model_config,
             seed=0,
             config=experiment_config,
             task=task,