[BugFix] Support for tensor collection in the PPOLoss (#2543)
Co-authored-by: Pau Riba <[email protected]>
priba and Pau Riba authored Nov 6, 2024
1 parent 997d90e commit 0eabb78
Showing 2 changed files with 94 additions and 1 deletion.
93 changes: 92 additions & 1 deletion test/test_cost.py
@@ -7650,6 +7650,7 @@ def _create_mock_actor(
observation_key="observation",
sample_log_prob_key="sample_log_prob",
composite_action_dist=False,
aggregate_probabilities=True,
):
# Actor
action_spec = Bounded(
@@ -7668,7 +7669,7 @@ def _create_mock_actor(
"action1": (action_key, "action1"),
},
log_prob_key=sample_log_prob_key,
aggregate_probabilities=True,
aggregate_probabilities=aggregate_probabilities,
)
module_out_keys = [
("params", "action1", "loc"),
@@ -8038,6 +8039,96 @@ def test_ppo(
assert counter == 2
actor.zero_grad()

@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@pytest.mark.parametrize("gradient_mode", (True, False))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@pytest.mark.parametrize("functional", [True, False])
def test_ppo_composite_no_aggregate(
self, loss_class, device, gradient_mode, advantage, td_est, functional
):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_ppo(device=device, composite_action_dist=True)

actor = self._create_mock_actor(
device=device,
composite_action_dist=True,
aggregate_probabilities=False,
)
value = self._create_mock_value(device=device)
if advantage == "gae":
advantage = GAE(
gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode
)
elif advantage == "vtrace":
advantage = VTrace(
gamma=0.9,
value_network=value,
actor_network=actor,
differentiable=gradient_mode,
)
elif advantage == "td":
advantage = TD1Estimator(
gamma=0.9, value_network=value, differentiable=gradient_mode
)
elif advantage == "td_lambda":
advantage = TDLambdaEstimator(
gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode
)
elif advantage is None:
pass
else:
raise NotImplementedError

loss_fn = loss_class(
actor,
value,
loss_critic_type="l2",
functional=functional,
)
if advantage is not None:
advantage(td)
else:
if td_est is not None:
loss_fn.make_value_estimator(td_est)

loss = loss_fn(td)
if isinstance(loss_fn, KLPENPPOLoss):
kl = loss.pop("kl_approx")
assert (kl != 0).any()

loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
# check that grads are independent and non null
named_parameters = loss_fn.named_parameters()
counter = 0
for name, p in named_parameters:
if p.grad is not None and p.grad.norm() > 0.0:
counter += 1
assert "actor" not in name
assert "critic" in name
if p.grad is None:
assert ("actor" in name) or ("target_" in name)
assert ("critic" not in name) or ("target_" in name)
assert counter == 2

value.zero_grad()
loss_objective.backward()
counter = 0
named_parameters = loss_fn.named_parameters()
for name, p in named_parameters:
if p.grad is not None and p.grad.norm() > 0.0:
counter += 1
assert "actor" in name
assert "critic" not in name
if p.grad is None:
assert ("actor" not in name) or ("target_" in name)
assert ("critic" in name) or ("target_" in name)
assert counter == 2
actor.zero_grad()

@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@pytest.mark.parametrize("gradient_mode", (True,))
@pytest.mark.parametrize("device", get_default_devices())
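As a reading aid for the new test: with aggregate_probabilities=False, a composite distribution no longer reduces its per-head log-probabilities and entropies to a single tensor, so the loss modules receive a tensor collection instead. Below is a minimal sketch of that behaviour; it is not part of the commit, it assumes tensordict's CompositeDistribution, and the exact entry names of the returned collection depend on the tensordict version.

import torch
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution

# Per-head distribution parameters, one head named "action1".
params = TensorDict(
    {"action1": {"loc": torch.zeros(4, 2), "scale": torch.ones(4, 2)}},
    batch_size=[4],
)
dist = CompositeDistribution(
    params,
    distribution_map={"action1": torch.distributions.Normal},
    aggregate_probabilities=False,  # keep per-head quantities, do not sum them
)
sample = dist.sample()            # a TensorDict holding one sample per head
log_prob = dist.log_prob(sample)  # a tensor collection, not a single tensor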
2 changes: 2 additions & 0 deletions torchrl/objectives/ppo.py
@@ -463,6 +463,8 @@ def reset(self) -> None:
def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
try:
entropy = dist.entropy()
if is_tensor_collection(entropy):
entropy = entropy.get(dist.entropy_key)
except NotImplementedError:
x = dist.rsample((self.samples_mc_entropy,))
log_prob = dist.log_prob(x)
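The ppo.py change above is the fix itself: when the entropy returned by the distribution is a tensor collection (the aggregate_probabilities=False case), the aggregated value is read back under dist.entropy_key. A standalone sketch of that pattern follows, using a plain TensorDict as a stand-in for the composite entropy; the entry names are illustrative only.

import torch
from tensordict import TensorDict, is_tensor_collection

# Stand-in for what a composite distribution's entropy() may return:
# per-head entries plus an aggregated one (the real key is dist.entropy_key).
entropy = TensorDict(
    {"action1_entropy": torch.rand(4), "entropy": torch.rand(4)},
    batch_size=[4],
)
if is_tensor_collection(entropy):
    entropy = entropy.get("entropy")  # mirrors entropy.get(dist.entropy_key) in PPOLoss
assert isinstance(entropy, torch.Tensor)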

1 comment on commit 0eabb78

@github-actions


⚠️ Performance Alert ⚠️

A possible performance regression was detected for the benchmark 'CPU Benchmark Results'.
The result of this commit is worse than the previous benchmark result, exceeding the threshold of 2.

Benchmark suite: benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400]
Current (0eabb78): 37.61636596480376 iter/sec (stddev: 0.1724681286738312)
Previous (997d90e): 440.6111781120217 iter/sec (stddev: 0.0007071230066077521)
Ratio: 11.71

This comment was automatically generated by a workflow using github-action-benchmark.

CC: @vmoens
