From 9166449a429d55daeb0717b74db0d90c7e06ce0c Mon Sep 17 00:00:00 2001 From: Kaustubh Mani Date: Sun, 21 Jan 2024 00:04:18 -0500 Subject: [PATCH] Restore old working version: 4 risk DataLoader workers, Experiment logging, start_risk_update gate --- cleanrl/ppo_continuous_action_wandb.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/cleanrl/ppo_continuous_action_wandb.py b/cleanrl/ppo_continuous_action_wandb.py index 24bdbf47a..20a886de9 100644 --- a/cleanrl/ppo_continuous_action_wandb.py +++ b/cleanrl/ppo_continuous_action_wandb.py @@ -411,7 +411,7 @@ def train_risk(cfg, model, data, criterion, opt, device): model.train() dataset = RiskyDataset(data["next_obs"].to('cpu'), None, data["risks"].to('cpu'), False, risk_type=cfg.risk_type, fear_clip=None, fear_radius=cfg.fear_radius, one_hot=True, quantile_size=cfg.quantile_size, quantile_num=cfg.quantile_num) - dataloader = DataLoader(dataset, batch_size=cfg.risk_batch_size, shuffle=True, num_workers=10, generator=torch.Generator(device='cpu')) + dataloader = DataLoader(dataset, batch_size=cfg.risk_batch_size, shuffle=True, num_workers=4, generator=torch.Generator(device='cpu')) net_loss = 0 for batch in dataloader: pred = model(get_risk_obs(cfg, batch[0]).to(device)) @@ -505,10 +505,20 @@ def train(cfg): cfg.use_risk = False if cfg.risk_model_path == "None" else True import wandb - run = wandb.init(config=vars(cfg), entity="kaustubh95", + run = wandb.init(config=vars(cfg), entity="manila95", project="risk_aware_exploration", monitor_gym=True, sync_tensorboard=True, save_code=True) + experiment = Experiment( + api_key="FlhfmY238jUlHpcRzzuIw3j2t", + project_name="risk-aware-exploration", + workspace="hbutsuak95", + ) + + experiment.add_tag(run.sweep_id) + experiment.log_parameters(cfg) + #experiment.log_parameters(cfg.risk) + #experiment.log_parameters(cfg.env) #run_name = "something" run_name = run.name @@ -737,7 +747,7 @@ def train(cfg): # writer.add_scalar("risk/risk_loss", risk_loss, global_step) if cfg.fine_tune_risk == "sync" and cfg.use_risk: - if
cfg.use_risk and buffer_num > cfg.risk_batch_size and cfg.fine_tune_risk: + if cfg.use_risk and len(rb) > cfg.start_risk_update and cfg.fine_tune_risk: if cfg.finetune_risk_online: print("I am online") data = rb.slice_data(-cfg.risk_batch_size, 0)