Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 15, 2024
2 parents f872d5c + aaec905 commit 9f08541
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 3 additions & 1 deletion sota-implementations/ppo/ppo_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def update(batch, num_network_updates):

# Update the networks
optim.step()
return loss.detach().set("alpha", alpha), num_network_updates.clone()
return loss.detach().set("alpha", alpha), num_network_updates
if cfg.compile.compile:
update = compile_with_warmup(update, mode=compile_mode, warmup=1)
adv_module = compile_with_warmup(adv_module, mode=compile_mode, warmup=1)
Expand Down Expand Up @@ -246,6 +246,8 @@ def update(batch, num_network_updates):
loss, num_network_updates = update(
batch, num_network_updates=num_network_updates
)
loss = loss.clone()
num_network_updates = num_network_updates.clone()
losses[j, k] = loss.select(
"loss_critic", "loss_entropy", "loss_objective"
)
Expand Down
4 changes: 3 additions & 1 deletion sota-implementations/ppo/ppo_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def update(batch, num_network_updates):

# Update the networks
optim.step()
return loss.detach().set("alpha", alpha), num_network_updates.clone()
return loss.detach().set("alpha", alpha), num_network_updates
if cfg.compile.compile:
update = compile_with_warmup(update, mode=compile_mode, warmup=1)
adv_module = compile_with_warmup(adv_module, mode=compile_mode, warmup=1)
Expand Down Expand Up @@ -235,6 +235,8 @@ def update(batch, num_network_updates):
loss, num_network_updates = update(
batch, num_network_updates=num_network_updates
)
loss = loss.clone()
num_network_updates = num_network_updates.clone()
losses[j, k] = loss.select(
"loss_critic", "loss_entropy", "loss_objective"
)
Expand Down

0 comments on commit 9f08541

Please sign in to comment.