
[Feature] Add reduction parameter to On-Policy losses. #1890

Merged on Feb 15, 2024 (31 commits)

Changes from 1 commit

Commits (31):
670108a  ppo reduction (albertbou92, Feb 9, 2024)
668d212  ppo reduction (albertbou92, Feb 9, 2024)
e624083  ppo reduction (albertbou92, Feb 9, 2024)
26cd568  ppo reduction (albertbou92, Feb 9, 2024)
3b253fc  ppo reduction (albertbou92, Feb 9, 2024)
81e6a3a  ppo reduction (albertbou92, Feb 9, 2024)
8560e8b  a2c / reinforce reduction (albertbou92, Feb 9, 2024)
e747cec  a2c / reinforce reduction (albertbou92, Feb 9, 2024)
b7c249c  a2c / reinforce tests (albertbou92, Feb 9, 2024)
ea93914  format (albertbou92, Feb 9, 2024)
666e7b1  Merge remote-tracking branch 'origin/main' into loss_reduction (vmoens, Feb 10, 2024)
b5ef409  fix recursion issue (vmoens, Feb 10, 2024)
6275f02  Merge remote-tracking branch 'origin/main' into loss_reduction (vmoens, Feb 10, 2024)
107e875  init (vmoens, Feb 11, 2024)
3163d1d  Merge branch 'fix-loss-exploration' into loss_reduction (vmoens, Feb 11, 2024)
331bd38  Update torchrl/objectives/reinforce.py (albertbou92, Feb 12, 2024)
61fc41b  Update torchrl/objectives/ppo.py (albertbou92, Feb 12, 2024)
6dbb622  Update torchrl/objectives/a2c.py (albertbou92, Feb 12, 2024)
2d6674e  Update torchrl/objectives/ppo.py (albertbou92, Feb 12, 2024)
95efebd  Update torchrl/objectives/ppo.py (albertbou92, Feb 12, 2024)
e64ee3d  suggestions added (albertbou92, Feb 12, 2024)
5368bdc  format (albertbou92, Feb 12, 2024)
efaa893  Merge remote-tracking branch 'origin/main' into loss_reduction (vmoens, Feb 12, 2024)
ac115a3  Merge branch 'loss_reduction' of https://github.com/PyTorchRL/rl into… (vmoens, Feb 12, 2024)
c218352  default reduction none (albertbou92, Feb 13, 2024)
eebcbb4  Merge branch 'main' into loss_reduction (albertbou92, Feb 15, 2024)
7e516f8  remove bs from loss (albertbou92, Feb 15, 2024)
2701bb8  fix test (albertbou92, Feb 15, 2024)
566b2b9  format (albertbou92, Feb 15, 2024)
7e6b4b2  better tests (albertbou92, Feb 15, 2024)
8052e33  better tests (albertbou92, Feb 15, 2024)
Commit 566b2b9b0741a9e2caf4bf6df3fe48bea29f5ab6 ("format")
albertbou92 committed Feb 15, 2024
4 changes: 1 addition & 3 deletions torchrl/objectives/ppo.py
@@ -547,9 +547,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:

     log_weight, dist = self._log_weight(tensordict)
     neg_loss = log_weight.exp() * advantage
-    td_out = TensorDict(
-        {"loss_objective": -neg_loss}, batch_size=[]
-    )
+    td_out = TensorDict({"loss_objective": -neg_loss}, batch_size=[])
     if self.entropy_bonus:
         entropy = self.get_entropy_bonus(dist)
         td_out.set("entropy", entropy.detach())  # for logging
4 changes: 1 addition & 3 deletions torchrl/objectives/reinforce.py
@@ -399,9 +399,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
     if log_prob.shape == advantage.shape[:-1]:
         log_prob = log_prob.unsqueeze(-1)
     loss_actor = -log_prob * advantage.detach()
-    td_out = TensorDict(
-        {"loss_actor": loss_actor}, batch_size=[]
-    )
+    td_out = TensorDict({"loss_actor": loss_actor}, batch_size=[])

     td_out.set("loss_value", self.loss_critic(tensordict))
     td_out = td_out.apply(
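
The hunk is cut off at the apply call; given the PR's goal, the remaining arguments presumably map every loss entry through the chosen reduction. A self-contained sketch of that final step, under that assumption (reduce_loss is the same hypothetical helper sketched above):

import torch
from tensordict import TensorDict

def reduce_loss(loss: torch.Tensor, reduction: str) -> torch.Tensor:
    # hypothetical helper, as sketched above
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss  # "none": keep per-sample values

td_out = TensorDict(
    {"loss_actor": torch.randn(8, 1), "loss_value": torch.randn(8, 1)},
    batch_size=[],
)
# apply() maps the reduction over every entry in one pass;
# batch_size=[] since "mean"/"sum" collapse entries to scalars
td_out = td_out.apply(lambda x: reduce_loss(x, "mean"), batch_size=[])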