[Feature] Add reduction parameter to On-Policy losses. #1890

Merged · 31 commits · Feb 15, 2024
Diff view: changes from 10 commits
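The tests in the diff below exercise a new `reduction` argument on the on-policy losses (PPOLoss, ClipPPOLoss, KLPENPPOLoss, A2CLoss, ReinforceLoss). The sketch that follows is not code from the PR; it is a minimal usage example inferred from those tests: at this point in the PR, passing `reduction=None` keeps the loss TensorDict at the input batch size so the caller can reduce per-sample losses manually, while "mean" and "sum" collapse each loss term to a scalar (a later commit reads "default reduction none", so the sentinel may have changed afterwards). The actor/critic construction, shapes, and key names here are assumptions following common torchrl patterns.

import torch
from tensordict import TensorDict
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss

n_obs, n_act, batch = 3, 4, 8

# Minimal actor/critic pair (standard torchrl construction, assumed here).
actor_net = torch.nn.Sequential(
    torch.nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()
)
actor = ProbabilisticActor(
    TensorDictModule(actor_net, in_keys=["observation"], out_keys=["loc", "scale"]),
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
)
value = ValueOperator(torch.nn.Linear(n_obs, 1), in_keys=["observation"])

# Fake rollout data carrying the keys the PPO loss reads when an advantage
# is already present (so no value estimator call is needed).
td = TensorDict(
    {
        "observation": torch.randn(batch, n_obs),
        "action": torch.rand(batch, n_act) * 1.8 - 0.9,  # inside TanhNormal support
        "sample_log_prob": torch.randn(batch),
        "advantage": torch.randn(batch, 1),
        "value_target": torch.randn(batch, 1),
    },
    batch_size=[batch],
)

# Unreduced losses: the returned TensorDict keeps the input batch size.
loss_fn = ClipPPOLoss(
    actor, value, loss_critic_type="l2", entropy_bonus=False, reduction=None
)
loss = loss_fn(td)
assert loss.batch_size == td.batch_size
# Reduce manually before backward(), mirroring the updated tests.
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])
(loss["loss_objective"] + loss["loss_critic"]).backward()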
Commits (31):

670108a  ppo reduction (albertbou92, Feb 9, 2024)
668d212  ppo reduction (albertbou92, Feb 9, 2024)
e624083  ppo reduction (albertbou92, Feb 9, 2024)
26cd568  ppo reduction (albertbou92, Feb 9, 2024)
3b253fc  ppo reduction (albertbou92, Feb 9, 2024)
81e6a3a  ppo reduction (albertbou92, Feb 9, 2024)
8560e8b  a2c / reinforce reduction (albertbou92, Feb 9, 2024)
e747cec  a2c / reinforce reduction (albertbou92, Feb 9, 2024)
b7c249c  a2c / reinforce tests (albertbou92, Feb 9, 2024)
ea93914  format (albertbou92, Feb 9, 2024)
666e7b1  Merge remote-tracking branch 'origin/main' into loss_reduction (vmoens, Feb 10, 2024)
b5ef409  fix recursion issue (vmoens, Feb 10, 2024)
6275f02  Merge remote-tracking branch 'origin/main' into loss_reduction (vmoens, Feb 10, 2024)
107e875  init (vmoens, Feb 11, 2024)
3163d1d  Merge branch 'fix-loss-exploration' into loss_reduction (vmoens, Feb 11, 2024)
331bd38  Update torchrl/objectives/reinforce.py (albertbou92, Feb 12, 2024)
61fc41b  Update torchrl/objectives/ppo.py (albertbou92, Feb 12, 2024)
6dbb622  Update torchrl/objectives/a2c.py (albertbou92, Feb 12, 2024)
2d6674e  Update torchrl/objectives/ppo.py (albertbou92, Feb 12, 2024)
95efebd  Update torchrl/objectives/ppo.py (albertbou92, Feb 12, 2024)
e64ee3d  suggestions added (albertbou92, Feb 12, 2024)
5368bdc  format (albertbou92, Feb 12, 2024)
efaa893  Merge remote-tracking branch 'origin/main' into loss_reduction (vmoens, Feb 12, 2024)
ac115a3  Merge branch 'loss_reduction' of https://github.com/PyTorchRL/rl into… (vmoens, Feb 12, 2024)
c218352  default reduction none (albertbou92, Feb 13, 2024)
eebcbb4  Merge branch 'main' into loss_reduction (albertbou92, Feb 15, 2024)
7e516f8  remove bs from loss (albertbou92, Feb 15, 2024)
2701bb8  fix test (albertbou92, Feb 15, 2024)
566b2b9  format (albertbou92, Feb 15, 2024)
7e6b4b2  better tests (albertbou92, Feb 15, 2024)
8052e33  better tests (albertbou92, Feb 15, 2024)
118 changes: 100 additions & 18 deletions test/test_cost.py
@@ -2,12 +2,12 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import contextlib
import functools
import itertools
import operator
import sys
import warnings
from copy import deepcopy
from dataclasses import asdict, dataclass
@@ -155,6 +155,10 @@
)


# Increase recursion limit, some objectives have more than 1000 tests
sys.setrecursionlimit(2000)


class _check_td_steady:
def __init__(self, td):
self.td_clone = td.clone()
@@ -5817,8 +5821,16 @@ def _create_seq_mock_data_ppo(
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@pytest.mark.parametrize("functional", [True, False])
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
def test_ppo(
self, loss_class, device, gradient_mode, advantage, td_est, functional
self,
loss_class,
device,
gradient_mode,
advantage,
td_est,
functional,
reduction,
):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_ppo(device=device)
@@ -5849,14 +5861,24 @@ def test_ppo(
else:
raise NotImplementedError

loss_fn = loss_class(actor, value, loss_critic_type="l2", functional=functional)
loss_fn = loss_class(
actor,
value,
loss_critic_type="l2",
functional=functional,
reduction=reduction,
)
if advantage is not None:
advantage(td)
else:
if td_est is not None:
loss_fn.make_value_estimator(td_est)

loss = loss_fn(td)
if reduction is None:
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])

loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
@@ -5904,7 +5926,8 @@ def test_ppo_state_dict(self, loss_class, device, gradient_mode):
@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
def test_ppo_shared(self, loss_class, device, advantage):
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
def test_ppo_shared(self, loss_class, device, advantage, reduction):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_ppo(device=device)

@@ -5941,11 +5964,16 @@ def test_ppo_shared(self, loss_class, device, advantage):
value,
loss_critic_type="l2",
separate_losses=True,
reduction=reduction,
)

if advantage is not None:
advantage(td)
loss = loss_fn(td)
if reduction is None:
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])

loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
@@ -5989,7 +6017,10 @@ def test_ppo_shared(self, loss_class, device, advantage):
)
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("separate_losses", [True, False])
def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses):
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
def test_ppo_shared_seq(
self, loss_class, device, advantage, separate_losses, reduction
):
"""Tests PPO with shared module with and without passing twice across the common module."""
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_ppo(device=device)
@@ -6027,6 +6058,7 @@ def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses):
loss_critic_type="l2",
separate_losses=separate_losses,
entropy_coef=0.0,
reduction=reduction,
)

loss_fn2 = loss_class(
@@ -6035,16 +6067,25 @@ def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses):
loss_critic_type="l2",
separate_losses=separate_losses,
entropy_coef=0.0,
reduction=reduction,
)

if advantage is not None:
advantage(td)
loss = loss_fn(td).exclude("entropy")
if reduction is None:
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])

sum(val for key, val in loss.items() if key.startswith("loss_")).backward()
grad = TensorDict(dict(model.named_parameters()), []).apply(
lambda x: x.grad.clone()
)
loss2 = loss_fn2(td).exclude("entropy")
if reduction is None:
assert loss2.batch_size == td.batch_size
loss2 = loss2.apply(lambda x: x.float().mean(), batch_size=[])

model.zero_grad()
sum(val for key, val in loss2.items() if key.startswith("loss_")).backward()
grad2 = TensorDict(dict(model.named_parameters()), []).apply(
@@ -6061,7 +6102,8 @@ def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses):
@pytest.mark.parametrize("gradient_mode", (True, False))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
def test_ppo_diff(self, loss_class, device, gradient_mode, advantage):
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
def test_ppo_diff(self, loss_class, device, gradient_mode, advantage, reduction):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_ppo(device=device)

@@ -6091,7 +6133,7 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage):
else:
raise NotImplementedError

loss_fn = loss_class(actor, value, loss_critic_type="l2")
loss_fn = loss_class(actor, value, loss_critic_type="l2", reduction=reduction)

params = TensorDict.from_module(loss_fn, as_module=True)

@@ -6107,6 +6149,9 @@ def zero_param(p):
if advantage is not None:
advantage(td)
loss = loss_fn(td)
if reduction is None:
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])

loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
@@ -6194,7 +6239,8 @@ def test_ppo_tensordict_keys(self, loss_class, td_est):
@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est, reduction):
"""Test PPO loss module with non-default tensordict keys."""
torch.manual_seed(self.seed)
gradient_mode = True
@@ -6247,7 +6293,7 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
else:
raise NotImplementedError

loss_fn = loss_class(actor, value, loss_critic_type="l2")
loss_fn = loss_class(actor, value, loss_critic_type="l2", reduction=reduction)
loss_fn.set_keys(**tensor_keys)
if advantage is not None:
# collect tensordict key names for the advantage module
@@ -6263,6 +6309,9 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
loss_fn.make_value_estimator(td_est)

loss = loss_fn(td)
if reduction is None:
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])

loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
@@ -6537,7 +6586,8 @@ def _create_seq_mock_data_a2c(
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@pytest.mark.parametrize("functional", (True, False))
def test_a2c(self, device, gradient_mode, advantage, td_est, functional):
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
def test_a2c(self, device, gradient_mode, advantage, td_est, functional, reduction):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_a2c(device=device)

@@ -6567,7 +6617,13 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional):
else:
raise NotImplementedError

loss_fn = A2CLoss(actor, value, loss_critic_type="l2", functional=functional)
loss_fn = A2CLoss(
actor,
value,
loss_critic_type="l2",
functional=functional,
reduction=reduction,
)

# Check error is raised when actions require grads
td["action"].requires_grad = True
@@ -6584,6 +6640,9 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional):
elif td_est is not None:
loss_fn.make_value_estimator(td_est)
loss = loss_fn(td)
if reduction is None:
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])
loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
@@ -6626,11 +6685,15 @@ def test_a2c_state_dict(self, device, gradient_mode):
loss_fn2.load_state_dict(sd)

@pytest.mark.parametrize("separate_losses", [False, True])
def test_a2c_separate_losses(self, separate_losses):
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
def test_a2c_separate_losses(self, separate_losses, reduction):
torch.manual_seed(self.seed)
actor, critic, common, td = self._create_mock_common_layer_setup()
loss_fn = A2CLoss(
actor_network=actor, critic_network=critic, separate_losses=separate_losses
actor_network=actor,
critic_network=critic,
separate_losses=separate_losses,
reduction=reduction,
)

# Check error is raised when actions require grads
@@ -6644,6 +6707,9 @@ def test_a2c_separate_losses(self, separate_losses):

td = td.exclude(loss_fn.tensor_keys.value_target)
loss = loss_fn(td)
if reduction is None:
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])
loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
@@ -6684,7 +6750,8 @@ def test_a2c_separate_losses(self, separate_losses):
@pytest.mark.parametrize("gradient_mode", (True, False))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
def test_a2c_diff(self, device, gradient_mode, advantage):
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
def test_a2c_diff(self, device, gradient_mode, advantage, reduction):
if pack_version.parse(torch.__version__) > pack_version.parse("1.14"):
raise pytest.skip("make_functional_with_buffers needs to be changed")
torch.manual_seed(self.seed)
@@ -6716,13 +6783,16 @@ def test_a2c_diff(self, device, gradient_mode, advantage):
else:
raise NotImplementedError

loss_fn = A2CLoss(actor, value, loss_critic_type="l2")
loss_fn = A2CLoss(actor, value, loss_critic_type="l2", reduction=reduction)

floss_fn, params, buffers = make_functional_with_buffers(loss_fn)

if advantage is not None:
advantage(td)
loss = floss_fn(params, buffers, td)
if reduction is None:
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])
loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
@@ -6983,8 +7053,9 @@ class TestReinforce(LossModuleTestBase):
@pytest.mark.parametrize(
"delay_value,functional", [[False, True], [False, False], [True, True]]
)
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
def test_reinforce_value_net(
self, advantage, gradient_mode, delay_value, td_est, functional
self, advantage, gradient_mode, delay_value, td_est, functional, reduction
):
n_obs = 3
n_act = 5
@@ -7032,6 +7103,7 @@ def test_reinforce_value_net(
critic_network=value_net,
delay_value=delay_value,
functional=functional,
reduction=reduction,
)

td = TensorDict(
@@ -7063,6 +7135,9 @@ def test_reinforce_value_net(
elif td_est is not None:
loss_fn.make_value_estimator(td_est)
loss_td = loss_fn(td)
if reduction is None:
assert loss_td.batch_size == td.batch_size
loss_td = loss_td.apply(lambda x: x.float().mean(), batch_size=[])
autograd.grad(
loss_td.get("loss_actor"),
actor_net.parameters(),
@@ -7210,14 +7285,21 @@ def _create_mock_common_layer_setup(
return actor, critic, common, td

@pytest.mark.parametrize("separate_losses", [False, True])
def test_reinforce_tensordict_separate_losses(self, separate_losses):
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
def test_reinforce_tensordict_separate_losses(self, separate_losses, reduction):
torch.manual_seed(self.seed)
actor, critic, common, td = self._create_mock_common_layer_setup()
loss_fn = ReinforceLoss(
actor_network=actor, critic_network=critic, separate_losses=separate_losses
actor_network=actor,
critic_network=critic,
separate_losses=separate_losses,
reduction=reduction,
)

loss = loss_fn(td)
if reduction is None:
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])

assert all(
(p.grad is None) or (p.grad == 0).all()
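A recurring pattern in the diff above: whenever the loss comes back unreduced, the test asserts that its batch size matches the input and then collapses every entry before calling backward(). Isolated as a sketch (hypothetical values; only torch and tensordict needed), the pattern looks like this:

import torch
from tensordict import TensorDict

# Stand-in for an unreduced loss TensorDict (reduction=None): one value per sample.
loss = TensorDict(
    {"loss_objective": torch.randn(8), "loss_critic": torch.randn(8)},
    batch_size=[8],
)

# Collapse every entry to a scalar and drop the batch dimension, as the tests
# do before summing the "loss_*" terms and calling backward().
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])
assert loss.batch_size == torch.Size([])
total = sum(val for key, val in loss.items() if key.startswith("loss_"))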