Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Algorithm] Discrete CQL #1666

Merged
merged 28 commits into from
Nov 10, 2023
Merged

[Algorithm] Discrete CQL #1666

merged 28 commits into from
Nov 10, 2023

Conversation

BY571
Copy link
Contributor

@BY571 BY571 commented Oct 30, 2023

Description

Adds discrete (DQN) CQL objective and example

Motivation and Context

Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax close #15213 if this solves the issue #15213

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 30, 2023
@BY571
Copy link
Contributor Author

BY571 commented Nov 2, 2023

image

Converges on Cartpole as expected. Just needs some cleanup + tests

@BY571 BY571 marked this pull request as ready for review November 3, 2023 15:37
@BY571 BY571 changed the title [WIP] Discrete CQL [Algorithm] Discrete CQL Nov 3, 2023
@vmoens vmoens added the new algo New algorithm request or PR label Nov 3, 2023
Copy link
Contributor

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work! I left some high level comments, can you have a look?
Thanks for this

examples/cql/discrete_cql_online.py Outdated Show resolved Hide resolved
test/test_cost.py Outdated Show resolved Hide resolved
test/test_cost.py Outdated Show resolved Hide resolved
torchrl/objectives/cql.py Outdated Show resolved Hide resolved
torchrl/objectives/cql.py Outdated Show resolved Hide resolved
logsumexp = torch.logsumexp(q_values, dim=-1, keepdim=True)
q_a = (q_values * current_action).sum(dim=-1, keepdim=True)

return (logsumexp - q_a).mean()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we return metadata too, like we're hoping to do for all losses in the future?

self._in_keys = values

@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDict:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should just be a couple of lines with dqn_loss and cql_loss IMO

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually tried to inherit from the DQN class and then do something like super.forward(tensordict) and only have the cql_loss calculation added but I got circular importing issues. Do you have any suggestions?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I wasn't suggesting to inherit from DQN, it's ok if they're separated. But the forward should just be a composition of loss_actor and loss_critic like we did in other losses (eg, TD3), where each sub-loss returns a tensor and a dict of metadata.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, got it! Should be adapted accordingly now.

Copy link

pytorch-bot bot commented Nov 6, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/1666

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 25 Unrelated Failures

As of commit 9941055 with merge base 4ab5b10 (image):

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Copy link
Contributor

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand it.
CQL loss is called in value loss, not in forward, why is that?
Why do we call item() on CQL loss value? Conventionally all losses in the output tensordict of a loss module should be differentials.
Can you give me some context?

@BY571
Copy link
Contributor Author

BY571 commented Nov 6, 2023

I don't understand it. CQL loss is called in value loss, not in forward, why is that? Why do we call item() on CQL loss value? Conventionally all losses in the output tensordict of a loss module should be differentials. Can you give me some context?

The CQL loss is more like an auxiliary term for the value loss not for a separate model like the actor. It just augments the value loss. We could separate it but then we would need to forward pass through the model again to obtain the current q values, which would slow down the process and I think there is no need to obtain only the cql loss as in itself it's incomplete.

Comment on lines 1083 to 1110
cql_loss = self.cql_loss(pred_val, action)

# calculate target value
with torch.no_grad():
target_value = self.value_estimator.value_estimate(
td_copy,
target_params=self._cached_detached_target_value_params,
).squeeze(-1)

with torch.no_grad():
td_error = (pred_val_index - target_value).pow(2)
td_error = td_error.unsqueeze(-1)
if tensordict.device is not None:
td_error = td_error.to(tensordict.device)

tensordict.set(
self.tensor_keys.priority,
td_error,
inplace=True,
)
loss = distance_loss(pred_val_index, target_value, self.loss_function).mean()

metadata = {
"td_error": td_error.mean(0).detach(),
"loss_cql": cql_loss.item(),
"pred_value": pred_val.mean().detach(),
"target_value": target_value.mean().detach(),
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where is the cql_loss used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you are right, I must have deleted it. Sorry for the confusion, I just updated and fixed it :)

@vmoens
Copy link
Contributor

vmoens commented Nov 6, 2023

What do you think of BY571#1? I think being able to run ablation studies has some value.

We need to fix the categorical case.

@BY571
Copy link
Contributor Author

BY571 commented Nov 7, 2023

What do you think of BY571#1? I think being able to run ablation studies has some value.

We need to fix the categorical case.

I think yes, if someone wants to check how the cql loss term influences the agent performance and want to have simple "on/off" capability it makes sense. The changes you did look good, I also pushed some adaption for the categorical case to calculate the cql loss.

@vmoens
Copy link
Contributor

vmoens commented Nov 7, 2023

Cool LMK when you've merge the PR

torchrl/objectives/cql.py Outdated Show resolved Hide resolved
torchrl/objectives/cql.py Outdated Show resolved Hide resolved
@BY571
Copy link
Contributor Author

BY571 commented Nov 8, 2023

Just merged and fixed the open issues. Let me know what you think.
Also, thank you for insisting on making the losses separate, I took advantage of it already and compared base DQN vs DQN+CQL loss :)
image

@vmoens
Copy link
Contributor

vmoens commented Nov 8, 2023

That looks great!
There are still 19 broken tests in the new test class and the example isn't running either.

Copy link
Contributor

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool let's merge this!

@vmoens vmoens merged commit 44dd79f into pytorch:main Nov 10, 2023
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. new algo New algorithm request or PR
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants