[Feature] Add reduction parameter to On-Policy losses. #1890

Merged
merged 31 commits into main from loss_reduction on Feb 15, 2024

Changes from 12 of 31 commits

Commits
670108a
ppo reduction
albertbou92 Feb 9, 2024
668d212
ppo reduction
albertbou92 Feb 9, 2024
e624083
ppo reduction
albertbou92 Feb 9, 2024
26cd568
ppo reduction
albertbou92 Feb 9, 2024
3b253fc
ppo reduction
albertbou92 Feb 9, 2024
81e6a3a
ppo reduction
albertbou92 Feb 9, 2024
8560e8b
a2c / reinforce reduction
albertbou92 Feb 9, 2024
e747cec
a2c / reinforce reduction
albertbou92 Feb 9, 2024
b7c249c
a2c / reinforce tests
albertbou92 Feb 9, 2024
ea93914
format
albertbou92 Feb 9, 2024
666e7b1
Merge remote-tracking branch 'origin/main' into loss_reduction
vmoens Feb 10, 2024
b5ef409
fix recursion issue
vmoens Feb 10, 2024
6275f02
Merge remote-tracking branch 'origin/main' into loss_reduction
vmoens Feb 10, 2024
107e875
init
vmoens Feb 11, 2024
3163d1d
Merge branch 'fix-loss-exploration' into loss_reduction
vmoens Feb 11, 2024
331bd38
Update torchrl/objectives/reinforce.py
albertbou92 Feb 12, 2024
61fc41b
Update torchrl/objectives/ppo.py
albertbou92 Feb 12, 2024
6dbb622
Update torchrl/objectives/a2c.py
albertbou92 Feb 12, 2024
2d6674e
Update torchrl/objectives/ppo.py
albertbou92 Feb 12, 2024
95efebd
Update torchrl/objectives/ppo.py
albertbou92 Feb 12, 2024
e64ee3d
suggestions added
albertbou92 Feb 12, 2024
5368bdc
format
albertbou92 Feb 12, 2024
efaa893
Merge remote-tracking branch 'origin/main' into loss_reduction
vmoens Feb 12, 2024
ac115a3
Merge branch 'loss_reduction' of https://github.com/PyTorchRL/rl into…
vmoens Feb 12, 2024
c218352
default reduction none
albertbou92 Feb 13, 2024
eebcbb4
Merge branch 'main' into loss_reduction
albertbou92 Feb 15, 2024
7e516f8
remove bs from loss
albertbou92 Feb 15, 2024
2701bb8
fix test
albertbou92 Feb 15, 2024
566b2b9
format
albertbou92 Feb 15, 2024
7e6b4b2
better tests
albertbou92 Feb 15, 2024
8052e33
better tests
albertbou92 Feb 15, 2024
113 changes: 95 additions & 18 deletions test/test_cost.py
@@ -2,7 +2,6 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import contextlib
import functools
@@ -5817,8 +5816,16 @@ def _create_seq_mock_data_ppo(
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@pytest.mark.parametrize("functional", [True, False])
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
def test_ppo(
self, loss_class, device, gradient_mode, advantage, td_est, functional
self,
loss_class,
device,
gradient_mode,
advantage,
td_est,
functional,
reduction,
):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_ppo(device=device)
@@ -5849,14 +5856,24 @@ def test_ppo(
else:
raise NotImplementedError

loss_fn = loss_class(actor, value, loss_critic_type="l2", functional=functional)
loss_fn = loss_class(
actor,
value,
loss_critic_type="l2",
functional=functional,
reduction=reduction,
)
if advantage is not None:
advantage(td)
else:
if td_est is not None:
loss_fn.make_value_estimator(td_est)

loss = loss_fn(td)
if reduction is None:
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])

loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
@@ -5904,7 +5921,8 @@ def test_ppo_state_dict(self, loss_class, device, gradient_mode):
@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
def test_ppo_shared(self, loss_class, device, advantage):
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
albertbou92 marked this conversation as resolved.
def test_ppo_shared(self, loss_class, device, advantage, reduction):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_ppo(device=device)

@@ -5941,11 +5959,16 @@ def test_ppo_shared(self, loss_class, device, advantage):
value,
loss_critic_type="l2",
separate_losses=True,
reduction=reduction,
)

if advantage is not None:
advantage(td)
loss = loss_fn(td)
if reduction is None:
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])

loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
@@ -5989,7 +6012,10 @@ def test_ppo_shared(self, loss_class, device, advantage):
)
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("separate_losses", [True, False])
def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses):
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
albertbou92 marked this conversation as resolved.
def test_ppo_shared_seq(
self, loss_class, device, advantage, separate_losses, reduction
):
"""Tests PPO with shared module with and without passing twice across the common module."""
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_ppo(device=device)
@@ -6027,6 +6053,7 @@ def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses):
loss_critic_type="l2",
separate_losses=separate_losses,
entropy_coef=0.0,
reduction=reduction,
)

loss_fn2 = loss_class(
@@ -6035,16 +6062,25 @@ def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses):
loss_critic_type="l2",
separate_losses=separate_losses,
entropy_coef=0.0,
reduction=reduction,
)

if advantage is not None:
advantage(td)
loss = loss_fn(td).exclude("entropy")
if reduction is None:
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])

sum(val for key, val in loss.items() if key.startswith("loss_")).backward()
grad = TensorDict(dict(model.named_parameters()), []).apply(
lambda x: x.grad.clone()
)
loss2 = loss_fn2(td).exclude("entropy")
if reduction is None:
assert loss2.batch_size == td.batch_size
loss2 = loss2.apply(lambda x: x.float().mean(), batch_size=[])

model.zero_grad()
sum(val for key, val in loss2.items() if key.startswith("loss_")).backward()
grad2 = TensorDict(dict(model.named_parameters()), []).apply(
@@ -6061,7 +6097,8 @@ def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses):
@pytest.mark.parametrize("gradient_mode", (True, False))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
def test_ppo_diff(self, loss_class, device, gradient_mode, advantage):
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
albertbou92 marked this conversation as resolved.
def test_ppo_diff(self, loss_class, device, gradient_mode, advantage, reduction):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_ppo(device=device)

@@ -6091,7 +6128,7 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage):
else:
raise NotImplementedError

loss_fn = loss_class(actor, value, loss_critic_type="l2")
loss_fn = loss_class(actor, value, loss_critic_type="l2", reduction=reduction)

params = TensorDict.from_module(loss_fn, as_module=True)

@@ -6107,6 +6144,9 @@ def zero_param(p):
if advantage is not None:
advantage(td)
loss = loss_fn(td)
if reduction is None:
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])

loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
@@ -6194,7 +6234,8 @@ def test_ppo_tensordict_keys(self, loss_class, td_est):
@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
albertbou92 marked this conversation as resolved.
def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est, reduction):
"""Test PPO loss module with non-default tensordict keys."""
torch.manual_seed(self.seed)
gradient_mode = True
@@ -6247,7 +6288,7 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
else:
raise NotImplementedError

loss_fn = loss_class(actor, value, loss_critic_type="l2")
loss_fn = loss_class(actor, value, loss_critic_type="l2", reduction=reduction)
loss_fn.set_keys(**tensor_keys)
if advantage is not None:
# collect tensordict key names for the advantage module
@@ -6263,6 +6304,9 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
loss_fn.make_value_estimator(td_est)

loss = loss_fn(td)
if reduction is None:
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])

loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
@@ -6537,7 +6581,8 @@ def _create_seq_mock_data_a2c(
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@pytest.mark.parametrize("functional", (True, False))
def test_a2c(self, device, gradient_mode, advantage, td_est, functional):
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
def test_a2c(self, device, gradient_mode, advantage, td_est, functional, reduction):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_a2c(device=device)

@@ -6567,7 +6612,13 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional):
else:
raise NotImplementedError

loss_fn = A2CLoss(actor, value, loss_critic_type="l2", functional=functional)
loss_fn = A2CLoss(
actor,
value,
loss_critic_type="l2",
functional=functional,
reduction=reduction,
)

# Check error is raised when actions require grads
td["action"].requires_grad = True
@@ -6584,6 +6635,9 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional):
elif td_est is not None:
loss_fn.make_value_estimator(td_est)
loss = loss_fn(td)
if reduction is None:
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])
loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
@@ -6626,11 +6680,15 @@ def test_a2c_state_dict(self, device, gradient_mode):
loss_fn2.load_state_dict(sd)

@pytest.mark.parametrize("separate_losses", [False, True])
def test_a2c_separate_losses(self, separate_losses):
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
def test_a2c_separate_losses(self, separate_losses, reduction):
torch.manual_seed(self.seed)
actor, critic, common, td = self._create_mock_common_layer_setup()
loss_fn = A2CLoss(
actor_network=actor, critic_network=critic, separate_losses=separate_losses
actor_network=actor,
critic_network=critic,
separate_losses=separate_losses,
reduction=reduction,
)

# Check error is raised when actions require grads
@@ -6644,6 +6702,9 @@ def test_a2c_separate_losses(self, separate_losses):

td = td.exclude(loss_fn.tensor_keys.value_target)
loss = loss_fn(td)
if reduction is None:
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])
loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
@@ -6684,7 +6745,8 @@ def test_a2c_separate_losses(self, separate_losses):
@pytest.mark.parametrize("gradient_mode", (True, False))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
def test_a2c_diff(self, device, gradient_mode, advantage):
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
albertbou92 marked this conversation as resolved.
def test_a2c_diff(self, device, gradient_mode, advantage, reduction):
if pack_version.parse(torch.__version__) > pack_version.parse("1.14"):
raise pytest.skip("make_functional_with_buffers needs to be changed")
torch.manual_seed(self.seed)
@@ -6716,13 +6778,16 @@ def test_a2c_diff(self, device, gradient_mode, advantage):
else:
raise NotImplementedError

loss_fn = A2CLoss(actor, value, loss_critic_type="l2")
loss_fn = A2CLoss(actor, value, loss_critic_type="l2", reduction=reduction)

floss_fn, params, buffers = make_functional_with_buffers(loss_fn)

if advantage is not None:
advantage(td)
loss = floss_fn(params, buffers, td)
if reduction is None:
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])
loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
@@ -6983,8 +7048,9 @@ class TestReinforce(LossModuleTestBase):
@pytest.mark.parametrize(
"delay_value,functional", [[False, True], [False, False], [True, True]]
)
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
def test_reinforce_value_net(
self, advantage, gradient_mode, delay_value, td_est, functional
self, advantage, gradient_mode, delay_value, td_est, functional, reduction
):
n_obs = 3
n_act = 5
@@ -7032,6 +7098,7 @@ def test_reinforce_value_net(
critic_network=value_net,
delay_value=delay_value,
functional=functional,
reduction=reduction,
)

td = TensorDict(
@@ -7063,6 +7130,9 @@ def test_reinforce_value_net(
elif td_est is not None:
loss_fn.make_value_estimator(td_est)
loss_td = loss_fn(td)
if reduction is None:
assert loss_td.batch_size == td.batch_size
loss_td = loss_td.apply(lambda x: x.float().mean(), batch_size=[])
albertbou92 marked this conversation as resolved.
autograd.grad(
loss_td.get("loss_actor"),
actor_net.parameters(),
@@ -7210,14 +7280,21 @@ def _create_mock_common_layer_setup(
return actor, critic, common, td

@pytest.mark.parametrize("separate_losses", [False, True])
def test_reinforce_tensordict_separate_losses(self, separate_losses):
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
def test_reinforce_tensordict_separate_losses(self, separate_losses, reduction):
torch.manual_seed(self.seed)
actor, critic, common, td = self._create_mock_common_layer_setup()
loss_fn = ReinforceLoss(
actor_network=actor, critic_network=critic, separate_losses=separate_losses
actor_network=actor,
critic_network=critic,
separate_losses=separate_losses,
reduction=reduction,
)

loss = loss_fn(td)
if reduction is None:
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])

assert all(
(p.grad is None) or (p.grad == 0).all()
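A note on the recurring pattern in the tests above: with `reduction=None` the loss module returns an unreduced TensorDict whose batch size matches the input, and each test collapses it to scalars before backpropagating. A condensed, standalone sketch of that pattern (the random tensors stand in for the per-sample loss terms a real loss module would produce):

```python
import torch
from tensordict import TensorDict

# Stand-in for the unreduced output of a loss module called with reduction=None:
loss = TensorDict(
    {"loss_objective": torch.randn(10), "loss_critic": torch.randn(10)},
    batch_size=[10],
)
assert loss.batch_size == torch.Size([10])  # one loss entry per sample

# Collapse every entry to a scalar, exactly as the tests do:
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])
assert loss["loss_objective"].shape == torch.Size([])
```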
22 changes: 16 additions & 6 deletions torchrl/objectives/a2c.py
@@ -30,6 +30,8 @@
VTrace,
)

from .utils import _reduce
albertbou92 marked this conversation as resolved.


class A2CLoss(LossModule):
"""TorchRL implementation of the A2C loss.
@@ -68,6 +70,10 @@ class A2CLoss(LossModule):
Functionalizing permits features like meta-RL, but makes it
impossible to use distributed models (DDP, FSDP, ...) and comes
with a little cost. Defaults to ``True``.
reduction (str, optional): Specifies the reduction to apply to the output:
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
``"mean"``: the sum of the output will be divided by the number of
elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
albertbou92 marked this conversation as resolved.

.. note:
The advantage (typically GAE) can be computed by the loss function or
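As a quick illustration of the three documented modes (plain torch, mirroring the `torch.nn` loss convention the docstring follows; the toy tensor below is illustrative):

```python
import torch

per_sample = torch.tensor([0.5, 1.5, 2.0])  # toy per-element loss values

print(per_sample)         # "none": tensor([0.5000, 1.5000, 2.0000])
print(per_sample.mean())  # "mean": tensor(1.3333)
print(per_sample.sum())   # "sum":  tensor(4.)
```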
@@ -234,6 +240,7 @@ def __init__(
functional: bool = True,
actor: ProbabilisticTensorDictSequential = None,
critic: ProbabilisticTensorDictSequential = None,
reduction: str = "mean",
Contributor:

I would use None as the default; it can come in handy if some day we change the default, with something like `if reduction is None: reduction = "mean"` later in the constructor.

Contributor Author (@albertbou92, Feb 12, 2024):

I am not sure I follow. The way I thought about it is that "mean" as the default ensures backward compatibility. Torch losses also use "mean" as their default, which is consistent. So it seemed a reasonable choice to me.

Contributor:

It will still be "mean". But not hard-coding "mean" is good practice because it lets us detect whether or not the user passed a reduction. That might come in handy at some point: for instance, if we want to change the default for one loss someday, we will just have to check whether the reduction is None.

Contributor Author:

And how do you define a no-reduction option then?

Contributor:

"none" (not None).

):
if actor is not None:
actor_network = actor
@@ -276,6 +283,7 @@ def __init__(

self.samples_mc_entropy = samples_mc_entropy
self.entropy_bonus = entropy_bonus and entropy_coef
self.reduction = reduction

try:
device = next(self.parameters()).device
@@ -388,7 +396,7 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
entropy = dist.entropy()
except NotImplementedError:
x = dist.rsample((self.samples_mc_entropy,))
entropy = -dist.log_prob(x)
entropy = -dist.log_prob(x).mean(0)
albertbou92 marked this conversation as resolved.
return entropy.unsqueeze(-1)

def _log_probs(
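Why the `.mean(0)` above matters: with reduction now handled at the end of `forward`, `get_entropy_bonus` must return one entropy estimate per batch element rather than a value that still carries the Monte Carlo sample dimension. A standalone shape check (plain torch.distributions, independent of TorchRL):

```python
import torch
from torch import distributions as d

dist = d.Normal(torch.zeros(4), torch.ones(4))  # a batch of 4 distributions
x = dist.rsample((16,))                         # 16 MC samples -> shape [16, 4]
entropy = -dist.log_prob(x).mean(0)             # average over the sample dim only
print(entropy.shape)                            # torch.Size([4]): one estimate per element
```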
@@ -457,14 +465,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
assert not advantage.requires_grad
log_probs, dist = self._log_probs(tensordict)
loss = -(log_probs * advantage)
td_out = TensorDict({"loss_objective": loss.mean()}, [])
td_out = TensorDict({"loss_objective": loss}, batch_size=tensordict.batch_size)
if self.entropy_bonus:
entropy = self.get_entropy_bonus(dist)
td_out.set("entropy", entropy.mean().detach()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy.mean())
td_out.set("entropy", entropy.detach()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy)
if self.critic_coef:
loss_critic = self.loss_critic(tensordict).mean()
td_out.set("loss_critic", loss_critic.mean())
loss_critic = self.loss_critic(tensordict)
td_out.set("loss_critic", loss_critic)
if self.reduction is not None:
td_out = td_out.apply(lambda x: _reduce(x, self.reduction), batch_size=[])
albertbou92 marked this conversation as resolved.
return td_out

def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
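A closing note on the new control flow in `A2CLoss.forward`: every loss term now enters `td_out` unreduced, carrying the input tensordict's batch size, and the configured reduction is applied once at the end; when `reduction` is `None`, the apply step is skipped and the caller receives the per-sample TensorDict (the case the tests assert with `loss.batch_size == td.batch_size`). The `_reduce` utility imported at the top of the file is not shown in this diff; a plausible minimal version, assuming it follows the docstring's `"none" | "mean" | "sum"` semantics, is:

```python
import torch

def _reduce(tensor: torch.Tensor, reduction: str) -> torch.Tensor:
    # Assumed behavior only; the actual helper lives in torchrl/objectives/utils.py.
    if reduction == "none":
        return tensor          # keep per-sample values
    if reduction == "mean":
        return tensor.mean()   # collapse to a scalar average
    if reduction == "sum":
        return tensor.sum()    # collapse to a scalar total
    raise NotImplementedError(f"Unknown reduction method {reduction}")
```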