From 8923cdd574186fa0eb16afcea656b432883d58c2 Mon Sep 17 00:00:00 2001 From: Kaustubh Mani Date: Wed, 6 Sep 2023 20:51:23 -0500 Subject: [PATCH] made action scaling a parameter --- cleanrl/dqn.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/cleanrl/dqn.py b/cleanrl/dqn.py index 1e9d7ac4f..47cf77cb1 100644 --- a/cleanrl/dqn.py +++ b/cleanrl/dqn.py @@ -39,8 +39,6 @@ def parse_args(): parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, help="whether to save model into the `runs/{run_name}` folder") parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, - help="whether to upload the saved model to huggingface") - parser.add_argument("--hf-entity", type=str, default="", help="the user or org name of the model repository from the Hugging Face Hub") # Algorithm specific arguments @@ -52,6 +50,8 @@ def parse_args(): help="how many violations before you terminate") parser.add_argument("--failure-penalty", type=float, default=0.0, help="Reward Penalty when you fail") + parser.add_argument("--action-scale", type=float, default=0.2, + help="magnitude of the non-zero components in the discrete action map") parser.add_argument("--collect-data", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, help="store data while trianing") parser.add_argument("--storage-path", type=str, default="./data/ppo/term_1", @@ -126,13 +126,6 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int): -action_map = {0: np.array([1., 1.]), 1: np.array([0., 1.]), 2: np.array([1., 0.])} -def action_map_fn(action): - return action_map[action] - -def get_random_action(): - return random.choice(range(len(action_map))) - if __name__ == "__main__": @@ -147,6 +140,14 @@ def get_random_action(): ) args = parse_args() run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + action_map = {0: np.array([args.action_scale, 
args.action_scale]), 1: np.array([0., args.action_scale]), 2: np.array([args.action_scale, 0.])} + def action_map_fn(action): + return action_map[action] + + def get_random_action(): + return random.choice(range(len(action_map))) + + if args.track: import wandb