Skip to content

Commit

Permalink
old thing that was working
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaustubh Mani committed Jan 21, 2024
1 parent 3ebb991 commit 9166449
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions cleanrl/ppo_continuous_action_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def train_risk(cfg, model, data, criterion, opt, device):
model.train()
dataset = RiskyDataset(data["next_obs"].to('cpu'), None, data["risks"].to('cpu'), False, risk_type=cfg.risk_type,
fear_clip=None, fear_radius=cfg.fear_radius, one_hot=True, quantile_size=cfg.quantile_size, quantile_num=cfg.quantile_num)
dataloader = DataLoader(dataset, batch_size=cfg.risk_batch_size, shuffle=True, num_workers=10, generator=torch.Generator(device='cpu'))
dataloader = DataLoader(dataset, batch_size=cfg.risk_batch_size, shuffle=True, num_workers=4, generator=torch.Generator(device='cpu'))
net_loss = 0
for batch in dataloader:
pred = model(get_risk_obs(cfg, batch[0]).to(device))
Expand Down Expand Up @@ -505,10 +505,20 @@ def train(cfg):
cfg.use_risk = False if cfg.risk_model_path == "None" else True

import wandb
run = wandb.init(config=vars(cfg), entity="kaustubh95",
run = wandb.init(config=vars(cfg), entity="manila95",
project="risk_aware_exploration",
monitor_gym=True,
sync_tensorboard=True, save_code=True)
experiment = Experiment(
api_key="FlhfmY238jUlHpcRzzuIw3j2t",
project_name="risk-aware-exploration",
workspace="hbutsuak95",
)

experiment.add_tag(run.sweep_id)
experiment.log_parameters(cfg)
#experiment.log_parameters(cfg.risk)
#experiment.log_parameters(cfg.env)

#run_name = "something"
run_name = run.name
Expand Down Expand Up @@ -737,7 +747,7 @@ def train(cfg):
# writer.add_scalar("risk/risk_loss", risk_loss, global_step)

if cfg.fine_tune_risk == "sync" and cfg.use_risk:
if cfg.use_risk and buffer_num > cfg.risk_batch_size and cfg.fine_tune_risk:
if cfg.use_risk and len(rb) > cfg.start_risk_update and cfg.fine_tune_risk:
if cfg.finetune_risk_online:
print("I am online")
data = rb.slice_data(-cfg.risk_batch_size, 0)
Expand Down

0 comments on commit 9166449

Please sign in to comment.