
Commit

Update agent and trainer configuration to avoid duplicated data in distributed runs
Toni-SM committed Jun 20, 2024
1 parent e06dcfd commit 5abed23
Showing 2 changed files with 16 additions and 10 deletions.
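In a distributed run every process executes the same agent and trainer code, so before this commit each rank opened its own TensorBoard writer, saved its own checkpoints, and drew its own progress bar, duplicating all of that output. The fix gates this work on the process rank, which skrl exposes via its config module. As a rough sketch of where those values come from (an assumption based on the environment variables that torchrun sets; not part of this diff):

import os

# torchrun exports these variables for every process it spawns
rank = int(os.environ.get("RANK", "0"))              # global index of this process
local_rank = int(os.environ.get("LOCAL_RANK", "0"))  # index of this process on its node
world_size = int(os.environ.get("WORLD_SIZE", "1"))  # total number of processes
is_distributed = world_size > 1

# only rank 0 should log, checkpoint, and render progress bars
is_main_process = rank == 0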
19 changes: 13 additions & 6 deletions skrl/agents/torch/base.py
@@ -11,7 +11,7 @@
 import torch
 from torch.utils.tensorboard import SummaryWriter

-from skrl import logger
+from skrl import config, logger
 from skrl.memories.torch import Memory
 from skrl.models.torch import Model
@@ -129,26 +129,33 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
This method should be called before the agent is used.
It will initialize the TensoBoard writer (and optionally Weights & Biases) and create the checkpoints directory
It will initialize the TensorBoard writer (and optionally Weights & Biases) and create the checkpoints directory
:param trainer_cfg: Trainer configuration
:type trainer_cfg: dict, optional
"""
trainer_cfg = trainer_cfg if trainer_cfg is not None else {}

# update agent configuration to avoid duplicated logging/checking in distributed runs
if config.torch.is_distributed and config.torch.rank:
self.write_interval = 0
self.checkpoint_interval = 0
# TODO: disable wandb

# setup Weights & Biases
if self.cfg.get("experiment", {}).get("wandb", False):
# save experiment config
# save experiment configuration
try:
models_cfg = {k: v.net._modules for (k, v) in self.models.items()}
except AttributeError:
models_cfg = {k: v._modules for (k, v) in self.models.items()}
config={**self.cfg, **trainer_cfg, **models_cfg}
wandb_config={**self.cfg, **trainer_cfg, **models_cfg}
# set default values
wandb_kwargs = copy.deepcopy(self.cfg.get("experiment", {}).get("wandb_kwargs", {}))
wandb_kwargs.setdefault("name", os.path.split(self.experiment_dir)[-1])
wandb_kwargs.setdefault("sync_tensorboard", True)
wandb_kwargs.setdefault("config", {})
wandb_kwargs["config"].update(config)
wandb_kwargs["config"].update(wandb_config)
# init Weights & Biases
import wandb
wandb.init(**wandb_kwargs)
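Two details in this hunk are easy to miss. First, config.torch.rank is non-zero (and therefore truthy) on every process except the main one, so zeroing write_interval and checkpoint_interval is enough to silence TensorBoard writes and checkpoint saves on all other ranks. Second, the rename of the local variable config to wandb_config is not cosmetic: with from skrl import config now at module level, an assignment to config anywhere inside init() would make the name local to the whole function, and the new config.torch.is_distributed check a few lines above would raise UnboundLocalError. A minimal sketch of the interval gating the first point relies on (a hypothetical method, not skrl's exact code; the real periodic-write logic lives in the agent's tracking hooks):

def maybe_write(self, timestep: int) -> None:
    # write_interval is 0 on non-main ranks, so this branch is never taken there
    if self.write_interval > 0 and timestep % self.write_interval == 0:
        for tag, values in self.tracking_data.items():
            self.writer.add_scalar(tag, sum(values) / len(values), timestep)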
@@ -386,7 +393,7 @@ def migrate(self,
                 name_map: Mapping[str, Mapping[str, str]] = {},
                 auto_mapping: bool = True,
                 verbose: bool = False) -> bool:
-        """Migrate the specified extrernal checkpoint to the current agent
+        """Migrate the specified external checkpoint to the current agent

         The final storage device is determined by the constructor of the agent.
         Only files generated by the *rl_games* library are supported at the moment
7 changes: 3 additions & 4 deletions skrl/trainers/torch/base.py
@@ -75,11 +75,10 @@ def close_env():
             self.env.close()
             logger.info("Environment closed")

-        # multi-gpu
+        # update trainer configuration to avoid duplicated info/data in distributed runs
         if config.torch.is_distributed:
             logger.info(f"rank: {config.torch.rank}, local rank: {config.torch.local_rank}, world size: {config.torch.world_size}")
-            torch.distributed.init_process_group("nccl", rank=config.torch.rank, world_size=config.torch.world_size)
-            torch.cuda.set_device(config.torch.local_rank)
+            if config.torch.rank:
+                self.disable_progressbar = True

     def __str__(self) -> str:
         """Generate a string representation of the trainer
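Note that this hunk also drops the torch.distributed.init_process_group("nccl", ...) and torch.cuda.set_device(...) calls from the trainer; process-group setup presumably moves out of the trainer so that it only consumes rank information. The disable_progressbar flag set here is consumed later by the training loop. A sketch of its effect, assuming the loop wraps its timestep range in tqdm (whose disable flag suppresses rendering entirely):

import tqdm

# inside a trainer's train() loop: only rank 0 renders a bar,
# every other rank iterates silently
for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps),
                          disable=self.disable_progressbar):
    ...  # pre_interaction / env.step / post_interaction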

0 comments on commit 5abed23
