-
Notifications
You must be signed in to change notification settings - Fork 41
/
cli_args.py
74 lines (61 loc) · 2.91 KB
/
cli_args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Copyright (c) 2022-2024, The ORBIT Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import argparse
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from omni.isaac.orbit_tasks.utils.wrappers.rsl_rl import RslRlOnPolicyRunnerCfg
def _parse_bool_arg(value: str) -> bool:
    """Convert a command-line string into a boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so the flag could never be
    switched off from the command line.

    Args:
        value: The raw string given on the command line.

    Returns:
        The boolean value the string stands for.

    Raises:
        argparse.ArgumentTypeError: If the string is not a recognized boolean spelling.
    """
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    # the empty string is kept as "false" since ``bool("")`` was the only way
    # to obtain ``False`` with the previous ``type=bool`` converter
    if normalized in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (e.g. True or False), got {value!r}")


def add_rsl_rl_args(parser: argparse.ArgumentParser):
    """Add RSL-RL arguments to the parser.

    All arguments default to ``None`` so that a consumer can tell "not given on
    the command line" apart from an explicit value.

    Args:
        parser: The parser to add the arguments to.
    """
    # create a new argument group
    arg_group = parser.add_argument_group("rsl_rl", description="Arguments for RSL-RL agent.")
    # -- experiment arguments
    arg_group.add_argument(
        "--experiment_name", type=str, default=None, help="Name of the experiment folder where logs will be stored."
    )
    arg_group.add_argument("--run_name", type=str, default=None, help="Run name suffix to the log directory.")
    # -- load arguments
    # accepts ``--resume`` (bare flag, means True), ``--resume True`` and ``--resume False``
    arg_group.add_argument(
        "--resume",
        type=_parse_bool_arg,
        nargs="?",
        const=True,
        default=None,
        help="Whether to resume from a checkpoint.",
    )
    arg_group.add_argument("--load_run", type=str, default=None, help="Name of the run folder to resume from.")
    arg_group.add_argument("--checkpoint", type=str, default=None, help="Checkpoint file to resume from.")
    # -- logger arguments
    # a tuple (not a set) keeps the order of the choices stable in help and error messages
    arg_group.add_argument(
        "--logger",
        type=str,
        default=None,
        choices=("wandb", "tensorboard", "neptune"),
        help="Logger module to use.",
    )
    arg_group.add_argument(
        "--log_project_name", type=str, default=None, help="Name of the logging project when using wandb or neptune."
    )
def parse_rsl_rl_cfg(task_name: str, args_cli: argparse.Namespace) -> RslRlOnPolicyRunnerCfg:
    """Parse configuration for RSL-RL agent based on inputs.

    The configuration registered for the task under the ``rsl_rl_cfg_entry_point``
    key is loaded first. Every CLI argument that is not ``None`` then overrides
    the matching attribute of that configuration.

    Args:
        task_name: The name of the environment.
        args_cli: The command line arguments.

    Returns:
        The parsed configuration for RSL-RL agent based on inputs.
    """
    # imported lazily, inside the function, as in the original implementation
    from omni.isaac.orbit_tasks.utils.parse_cfg import load_cfg_from_registry

    # load the default configuration
    rslrl_cfg: RslRlOnPolicyRunnerCfg = load_cfg_from_registry(task_name, "rsl_rl_cfg_entry_point")

    # override the default configuration with CLI arguments
    # note: ``--seed`` is not registered by ``add_rsl_rl_args``; it has to come from
    # the calling script. A namespace without it is treated as "no seed override"
    # instead of failing with an AttributeError.
    seed = getattr(args_cli, "seed", None)
    if seed is not None:
        rslrl_cfg.seed = seed
    if args_cli.resume is not None:
        rslrl_cfg.resume = args_cli.resume
    if args_cli.load_run is not None:
        rslrl_cfg.load_run = args_cli.load_run
    if args_cli.checkpoint is not None:
        rslrl_cfg.load_checkpoint = args_cli.checkpoint
    if args_cli.run_name is not None:
        rslrl_cfg.run_name = args_cli.run_name
    if args_cli.logger is not None:
        rslrl_cfg.logger = args_cli.logger

    # set the project name for wandb and neptune
    # (both project attributes are written, whichever of the two loggers is active)
    if rslrl_cfg.logger in {"wandb", "neptune"} and args_cli.log_project_name:
        rslrl_cfg.wandb_project = args_cli.log_project_name
        rslrl_cfg.neptune_project = args_cli.log_project_name

    return rslrl_cfg