Skip to content

Commit

Permalink
made action scaling a parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
manila95 committed Sep 7, 2023
1 parent 6c0f4aa commit 8923cdd
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions cleanrl/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ def parse_args():
parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to save model into the `runs/{run_name}` folder")
parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to upload the saved model to huggingface")
parser.add_argument("--hf-entity", type=str, default="",
help="the user or org name of the model repository from the Hugging Face Hub")

# Algorithm specific arguments
Expand All @@ -52,6 +50,8 @@ def parse_args():
help="how many violations before you terminate")
parser.add_argument("--failure-penalty", type=float, default=0.0,
help="Reward Penalty when you fail")
parser.add_argument("--action-scale", type=float, default=0.2,
help="Reward Penalty when you fail")
parser.add_argument("--collect-data", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="store data while trianing")
parser.add_argument("--storage-path", type=str, default="./data/ppo/term_1",
Expand Down Expand Up @@ -126,13 +126,6 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):



action_map = {0: np.array([1., 1.]), 1: np.array([0., 1.]), 2: np.array([1., 0.])}
def action_map_fn(action):
return action_map[action]

def get_random_action():
return random.choice(range(len(action_map)))



if __name__ == "__main__":
Expand All @@ -147,6 +140,14 @@ def get_random_action():
)
args = parse_args()
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
action_map = {0: np.array([args.action_scale, args.action_scale]), 1: np.array([0., args.action_scale]), 2: np.array([args.action_scale, 0.])}
def action_map_fn(action):
return action_map[action]

def get_random_action():
return random.choice(range(len(action_map)))


if args.track:
import wandb

Expand Down

0 comments on commit 8923cdd

Please sign in to comment.