Skip to content

Commit

Permalink
half baked risk model fine tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaustubh Mani committed Aug 6, 2023
1 parent 859ff36 commit 1b868f4
Showing 1 changed file with 31 additions and 7 deletions.
38 changes: 31 additions & 7 deletions cleanrl/ppo_continuous_action_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,16 @@ def parse_args():
help="Use risk model in the critic or not ")
parser.add_argument("--model-type", type=str, default="mlp",
help="specify the NN to use for the risk model")

parser.add_argument("--fear-radius", type=int, default=5,
help="fear radius for training the risk model")
parser.add_argument("--num-risk-datapoints", type=int, default=10000,
help="fear radius for training the risk model")
parser.add_argument("--update-risk-model", type=int, default=100,
help="number of epochs to update the risk model")
parser.add_argument("--risk-sgd-steps", type=int, default=100,
help="number of epochs to update the risk model")


args = parser.parse_args()
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)
Expand Down Expand Up @@ -230,12 +239,12 @@ def train(cfg):
# api_key="FlhfmY238jUlHpcRzzuIw3j2t",
# project_name="risk-aware-exploration",
# workspace="hbutsuak95",
#)
import wandb
wandb.init(config=vars(cfg), entity="kaustubh95",
project="risk_aware_exploration",
name=run_name, monitor_gym=True,
sync_tensorboard=True, save_code=True)
#)
# import wandb
# wandb.init(config=vars(cfg), entity="kaustubh95",
# project="risk_aware_exploration",
# name=run_name, monitor_gym=True,
# sync_tensorboard=True, save_code=True)

writer = SummaryWriter(f"runs/{run_name}")
#writer.add_text(
Expand Down Expand Up @@ -309,6 +318,14 @@ def train(cfg):
last_step = 0
episode = 0
step_log = 0

## Finetuning data collection
f_obs = next_obs
f_risks = torch.Tensor([[0.]]).to(device)

print(f_obs.size(), f_risks.size())


if cfg.collect_data:
storage_path = os.path.join(cfg.storage_path, run_name)
make_dirs(storage_path, episode)
Expand All @@ -321,6 +338,7 @@ def train(cfg):
optimizer.param_groups[0]["lr"] = lrnow

for step in range(0, cfg.num_steps):
risk = torch.Tensor([[0.]]).to(device)
global_step += 1 * cfg.num_envs
obs[step] = next_obs
dones[step] = next_done
Expand Down Expand Up @@ -359,7 +377,13 @@ def train(cfg):
store_data(next_obs, info_dict, storage_path, episode, step_log)
step_log+=1
next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
f_obs = torch.concat([f_obs, next_obs], axis=0)
f_risks = torch.concat([f_risks, risk], axis=0)
print(f_risks.size(), f_obs.size())

if cost > 0:
f_risks[global_step-cfg.fear_radius:, 0] = 1.

if not done:
cost = torch.Tensor(infos["cost"]).to(device).view(-1)
ep_cost += infos["cost"]; cum_cost += infos["cost"]
Expand Down

0 comments on commit 1b868f4

Please sign in to comment.