-
Notifications
You must be signed in to change notification settings - Fork 328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BugFix] Allow for composite action distributions in PPO/A2C losses #2391
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2391
Note: Links to docs will display an error until the docs builds have been completed. ❌ 3 New Failures, 6 Unrelated FailuresAs of commit 69922fa with merge base a6310ae (): NEW FAILURES - The following jobs have failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good work!
There's an exception with test_ppo_notensordict
in test_cost.py
Have a look at the couple of comments I left
Thanks a lot for the feedback! I have made a few changes:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's safe to assume that people won't use this feature with non-tensordict inputs, because the action will be a tensordict anyway.
I would just document it properly in the loss docstrings where we explain how to use the loss without tensordict.
torchrl/objectives/a2c.py
Outdated
@@ -383,26 +383,39 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: | |||
entropy = dist.entropy() | |||
except NotImplementedError: | |||
x = dist.rsample((self.samples_mc_entropy,)) | |||
entropy = -dist.log_prob(x).mean(0) | |||
log_prob = dist.log_prob(x) | |||
if isinstance(log_prob, TensorDict): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A lazy stack is not a TensorDict but a TensorDict base.
Also ideally we would want this to work with tensorclasses.
The way to go should be to use is_tensor_collection
from tensordict lib.
torchrl/objectives/ppo.py
Outdated
@@ -449,28 +449,38 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: | |||
entropy = dist.entropy() | |||
except NotImplementedError: | |||
x = dist.rsample((self.samples_mc_entropy,)) | |||
entropy = -dist.log_prob(x).mean(0) | |||
log_prob = dist.log_prob(x) | |||
if isinstance(log_prob, TensorDict): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
torchrl/objectives/ppo.py
Outdated
kl = (previous_dist.log_prob(x) - current_dist.log_prob(x)).mean(0) | ||
previous_log_prob = previous_dist.log_prob(x) | ||
current_log_prob = current_dist.log_prob(x) | ||
if isinstance(x, TensorDict): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's safe to assume that people won't use this feature with non-tensordict inputs, because the action will be a tensordict anyway.
I would just document it properly in the loss docstrings where we explain how to use the loss without tensordict.
Done! I will do the off-policy and offline losses in separate PRs. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR!
Do we need a simple test for this?
Like a dedicated function in PPOTest that runs it with composite dists?
@@ -383,26 +395,39 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: | |||
entropy = dist.entropy() | |||
except NotImplementedError: | |||
x = dist.rsample((self.samples_mc_entropy,)) | |||
entropy = -dist.log_prob(x).mean(0) | |||
log_prob = dist.log_prob(x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Previously, was this a bug or did we sum the log-probs automatically?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is simply because the log_prob()
method for a composite dist will return a TD instead of a Tensor, so we compute the entropy in 2 steps.
This is the old version:
def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
try:
entropy = dist.entropy()
except NotImplementedError:
x = dist.rsample((self.samples_mc_entropy,))
entropy = -dist.log_prob(x).mean(0)
return entropy.unsqueeze(-1)
This is the new version. It simply retrieves the log tensor before computing the entropy.
def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
try:
entropy = dist.entropy()
except NotImplementedError:
x = dist.rsample((self.samples_mc_entropy,))
log_prob = dist.log_prob(x)
if is_tensor_collection(log_prob):
log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
entropy = -log_prob.mean(0)
return entropy.unsqueeze(-1)
Regarding a dedicated test, how do you usually approach this decision? When I thought about testing different dists I saw it a bit like testing for different types of ValueEstimators. It should work both with single dists and with composite dists in all the tested situations. So I added it to all tests (except the We could probably switch to a single dedicated test function though. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM thanks for this!
Computing the entropy for composite distributions is not fully resolved, particularly when dealing with a composite distribution that includes some distributions with an implemented entropy() method and others without. We could add an entropy method to wdyt? |
From discord:
I think it's a good idea, I prefer a |
To add on my previous comment, here is how I would address this:
|
makes sense |
Merging this to clear space in the PR list but we should take care of #2391 (comment) sooner than later! Wanna give a shot at it or should I? |
I will give it a shot, give me a few days |
Description
At the moment objective classes do not allow to use an actor with a composite distribution.
This PR aims to fix this. I have started with PPO, it turned out to required more changes than I anticipated. In particular, I am struggling with the test
test_ppo_notensordict
.Once these modification are correct, I will move on the the tests of the other on-policy objectives and then to all other objectives.
This PR requires the TensorDict PR pytorch/tensordict#961 to be merged.