working data collection with DQN
manila95 committed Sep 23, 2023
1 parent abc11a4 commit d519292
Showing 1 changed file with 61 additions and 12 deletions.
73 changes: 61 additions & 12 deletions cleanrl/dqn.py
@@ -12,11 +12,11 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
- from stable_baselines3.common.buffers import ReplayBuffer
+ import stable_baselines3.common.buffers as buffers  # import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter

import gymnasium.spaces as spaces

from src.utils import *
def parse_args():
# fmt: off
parser = argparse.ArgumentParser()
@@ -82,6 +82,51 @@ def parse_args():
help="timestep to start learning")
parser.add_argument("--train-frequency", type=int, default=10,
help="the frequency of training")

## Arguments related to the risk model
parser.add_argument("--use-risk", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to use a risk model")
parser.add_argument("--risk-actor", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="whether to use the risk model in the actor")
parser.add_argument("--risk-critic", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to use the risk model in the critic")
parser.add_argument("--risk-model-path", type=str, default="None",
help="path to a pretrained risk model checkpoint")
parser.add_argument("--binary-risk", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to binarize the risk signal")
parser.add_argument("--model-type", type=str, default="mlp",
help="which neural network architecture to use for the risk model")
parser.add_argument("--risk-bnorm", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="whether to use batch normalization in the risk model")
parser.add_argument("--risk-type", type=str, default="binary",
help="whether the risk is binary or continuous")
parser.add_argument("--fear-radius", type=int, default=5,
help="fear radius for training the risk model")
parser.add_argument("--num-risk-datapoints", type=int, default=1000,
help="number of data points used to train the risk model")
parser.add_argument("--risk-update-period", type=int, default=1000,
help="how frequently to update the risk model")
parser.add_argument("--num-update-risk", type=int, default=10,
help="number of SGD steps per risk model update")
parser.add_argument("--risk-lr", type=float, default=1e-7,
help="learning rate of the risk model optimizer")
parser.add_argument("--risk-batch-size", type=int, default=1000,
help="batch size for risk model updates")
parser.add_argument("--fine-tune-risk", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="whether to fine-tune the risk model during training")
parser.add_argument("--finetune-risk-online", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to fine-tune the risk model online, during data collection")
parser.add_argument("--start-risk-update", type=int, default=10000,
help="timestep at which to start updating the risk model")
parser.add_argument("--rb-type", type=str, default="balanced",
help="which type of replay buffer to use for risk model training")
parser.add_argument("--freeze-risk-layers", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to freeze the risk model's layers during fine-tuning")
parser.add_argument("--weight", type=float, default=1.0,
help="weight for the positive (1) class in the BCE loss")
parser.add_argument("--quantile-size", type=int, default=4, help="size of each risk quantile")
parser.add_argument("--quantile-num", type=int, default=5, help="number of quantiles to make")

args = parser.parse_args()
# fmt: on
assert args.num_envs == 1, "vectorized envs are not supported at the moment"
@@ -151,7 +196,7 @@ def get_random_action():
if args.track:
import wandb

- wandb.init(
+ run = wandb.init(
project=args.wandb_project_name,
entity=args.wandb_entity,
sync_tensorboard=True,
@@ -186,7 +231,7 @@ def get_random_action():
target_network = QNetwork(envs, action_size=len(action_map)).to(device)
target_network.load_state_dict(q_network.state_dict())
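# Note: the buffer stores discrete indices into action_map (a discretized action
# set defined elsewhere in this file), which is why a MultiDiscrete action space
# is passed to the replay buffer below.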

- rb = ReplayBuffer(
+ rb = buffers.ReplayBuffer(
args.buffer_size,
envs.single_observation_space,
spaces.MultiDiscrete(np.array([len(action_map)])),
@@ -206,15 +251,16 @@ def get_random_action():
f_dones = None
f_costs = None

step_log = 0
scores = []
if args.collect_data:
#os.system("rm -rf %s"%args.storage_path)
storage_path = os.path.join(args.storage_path, args.env_id, run.name)
- make_dirs(storage_path, episode)
+ make_dirs(storage_path, 0)  # episode
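# Storage layout: <args.storage_path>/<env_id>/<wandb run name>/...; make_dirs
# comes from src.utils (wildcard-imported above) and presumably creates the
# per-episode subdirectory -- episode 0 here, since collection starts fresh.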
# TRY NOT TO MODIFY: start the game
obs, _ = envs.reset(seed=args.seed)
cum_cost, ep_cost, ep_risk_cost_int, cum_risk_cost_int, ep_risk, cum_risk = 0, 0, 0, 0, 0, 0

last_step = 0
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
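# For reference, linear_schedule (defined earlier in cleanrl/dqn.py) anneals
# epsilon linearly from start_e to end_e over the first `duration` steps and
# holds it at end_e afterwards:
#   def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
#       slope = (end_e - start_e) / duration
#       return max(slope * t + start_e, end_e)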
@@ -229,7 +275,7 @@ def get_random_action():
# info_dict = {'reward': rewards, 'done': done, 'cost': cost, 'obs': obs}
# if args.collect_data:
# store_data(next_obs, info_dict, storage_path, episode, step_log)

done = np.logical_or(terminated, truncated)
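# gymnasium's step() returns separate `terminated` (true terminal state) and
# `truncated` (time-limit cutoff) flags; their union recovers the classic `done`.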

step_log += 1
if not done:
@@ -244,8 +290,8 @@ def get_random_action():
if args.fine_tune_risk or args.collect_data:
f_obs = torch.Tensor(obs) if f_obs is None else torch.concat([f_obs, torch.Tensor(obs)], axis=0)
f_next_obs = torch.Tensor(next_obs) if f_next_obs is None else torch.concat([f_next_obs, torch.Tensor(next_obs)], axis=0)
- f_actions = torch.Tensor(action) if f_actions is None else torch.concat([f_actions, torch.Tensor(action)], axis=0)
- f_rewards = torch.Tensor(reward) if f_rewards is None else torch.concat([f_rewards, torch.Tensor(reward)], axis=0)
+ f_actions = torch.Tensor(actions) if f_actions is None else torch.concat([f_actions, torch.Tensor(actions)], axis=0)
+ f_rewards = torch.Tensor(rewards) if f_rewards is None else torch.concat([f_rewards, torch.Tensor(rewards)], axis=0)
# f_risks = risk_ if f_risks is None else torch.concat([f_risks, risk_], axis=0)
f_costs = torch.Tensor(cost) if f_costs is None else torch.concat([f_costs, torch.Tensor(cost)], axis=0)
# f_dones = torch.Tensor(next_done) if f_dones is None else torch.concat([f_dones, torch.Tensor(next_done)], axis=0)
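# The flat tensors f_obs, f_next_obs, f_actions, f_rewards, and f_costs grow by
# one step of data per iteration and serve as the dataset for risk-model
# fine-tuning / offline storage; nothing truncates them, so memory grows
# linearly with total_timesteps.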
@@ -258,11 +304,13 @@ def get_random_action():
if "episode" not in info:
continue

- ep_cost = torch.sum(all_costs[last_step:global_step]).item()
- cum_cost += ep_cost
+ ep_cost = torch.sum(f_costs[last_step:global_step]).item()
+ cum_cost += ep_cost
+ last_step = global_step
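# Assuming f_costs receives exactly one entry per step (i.e. fine_tune_risk or
# collect_data is enabled), the slice [last_step:global_step] covers exactly the
# episode that just ended; last_step then advances to the new episode boundary.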
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_cost", ep_cost, global_step)
writer.add_scalar("charts/cummulative_cost", cum_cost, global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)

@@ -348,3 +396,4 @@ def get_random_action():

envs.close()
writer.close()
