TransformerSenderReinforce max_len parameter #247

Open
mitjanikolaus opened this issue Sep 8, 2022 · 10 comments

@mitjanikolaus

The docstring of TransformerSenderReinforce mentions that the max_len parameter includes the EOS token: https://github.com/facebookresearch/EGG/blob/main/egg/core/reinforce_wrappers.py#L687

However, the transformer can create a message of length max_len (without EOS), and the EOS token will be appended afterwards: https://github.com/facebookresearch/EGG/blob/main/egg/core/reinforce_wrappers.py#L835

So, to my understanding, the max_len parameter actually does not include the EOS token?
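
For illustration, here is a minimal sketch of the pattern I mean (placeholder names, not the actual EGG code): the sender samples max_len symbols and the EOS column is only appended afterwards, so the longest message effectively has max_len + 1 symbols.

```python
import torch

def generate_message(sample_symbol, max_len: int, batch_size: int) -> torch.Tensor:
    # sample_symbol is a placeholder for one decoding step of the transformer
    symbols = [sample_symbol() for _ in range(max_len)]   # EOS not counted against max_len
    message = torch.stack(symbols, dim=1)                 # (batch_size, max_len)
    eos = torch.zeros(batch_size, 1, dtype=torch.long)    # symbol 0 == EOS
    return torch.cat([message, eos], dim=1)               # (batch_size, max_len + 1)
```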

@robertodessi (Contributor)

True, the docstrings have not been updated since we went with the "force eos" decision. Do you want to open a PR for this as well? :)

@robertodessi (Contributor)

Btw, there are some known inconsistencies in EGG regarding max_len, see #137 and #138. If you feel like fixing them, contributions are always welcome :)

@mitjanikolaus (Author)

I think these two issues are related, but they require a bit more refactoring, as it's not straightforward to change the behavior of find_lengths to return lengths of 0 (these can't be handled by the PyTorch RNN implementations).
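
To illustrate what I mean, assuming the receiver packs the messages by length before running the RNN, a length of 0 is rejected outright by PyTorch:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

embedded = torch.zeros(2, 5, 16)   # (batch, time, embedding)
lengths = torch.tensor([3, 0])     # a 0 length, as an EOS-only message would get

# Raises a RuntimeError along the lines of
# "Length of all samples has to be greater than 0"
packed = pack_padded_sequence(embedded, lengths, batch_first=True, enforce_sorted=False)
```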

@robertodessi (Contributor)

Can we assume length is always > 0 unless the input is something like an empty tensor? We do check that length is greater than one in EGG, so that should not break when calling find_lengths.

@mitjanikolaus (Author)

I was referring to solution 2 that you proposed here: #138
In that case, a message starting with an EOS token would be treated as having a length of 0, which then causes issues when it is read by an RNN.

@robertodessi (Contributor)

Technically, a message starting with EOS should not be accepted. How would you handle it?
We are feeding the input through the receiver RNN regardless of the symbols (whether EOS or not) and just ignoring everything after the EOS when computing the loss. Therefore, I guess the RNN implementations shouldn't fail.
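
To be concrete, the masking I have in mind looks roughly like this (a sketch, not the actual EGG code): the RNN reads every position, and only the loss terms after the first EOS get zeroed out.

```python
import torch

def mask_after_eos(message: torch.Tensor) -> torch.Tensor:
    # message: (batch, time) of symbol ids, 0 == EOS
    not_eosed = torch.ones(message.size(0))
    mask = []
    for step in range(message.size(1)):
        mask.append(not_eosed.clone())
        not_eosed = not_eosed * (message[:, step] != 0).float()
    return torch.stack(mask, dim=1)  # 1.0 up to and including the first EOS, 0.0 after
```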

@mitjanikolaus (Author)

To my mind, the problem is this issue in PyTorch.

@robertodessi (Contributor)

But we'll never have a sequence of length 0. This is because:
1/ we enforce max_len to be greater than or equal to 1: https://github.com/facebookresearch/EGG/blob/main/egg/core/reinforce_wrappers.py#L266

2/ if the user sets max_len equal to 1 and, though unlikely, the sender generates only the EOS, we would still have a message of size 2: the sender-generated EOS and the one automatically appended to each message. This tensor is then given to the receiver RNN in the case of an RNN receiver, and all the EOS handling is done in the Game instance, see e.g. https://github.com/facebookresearch/EGG/blob/main/egg/core/reinforce_wrappers.py#L579-L588

For these reasons, I don't think we'll encounter the above error.
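
As a simplified sketch of what I mean by the length computation (not the actual implementation): the length includes the first EOS, so even a message consisting only of EOS symbols gets length 1, never 0.

```python
import torch

def find_lengths_including_eos(messages: torch.Tensor) -> torch.Tensor:
    # messages: (batch, time) of symbol ids, 0 == EOS
    max_k = messages.size(1)
    after_first_eos = (messages == 0).cumsum(dim=1) > 0
    lengths = max_k - after_first_eos.sum(dim=1) + 1
    return lengths.clamp(max=max_k)

# max_len = 1 and a sender that only emits EOS: message is [EOS, appended EOS]
print(find_lengths_including_eos(torch.tensor([[0, 0]])))  # tensor([1])
```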

@mitjanikolaus (Author)

If we consider solution 2 that you proposed here (#138), find_lengths would need to return 0 in case a message consists only of the EOS token. Or maybe I misunderstood the proposal?
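
Roughly, what I understood from solution 2 is something like this (a hypothetical sketch, not a concrete proposal): the length counts only the symbols before the first EOS, so an EOS-only message gets length 0 and we are back to the RNN problem above.

```python
import torch

def find_lengths_excluding_eos(messages: torch.Tensor) -> torch.Tensor:
    # messages: (batch, time) of symbol ids, 0 == EOS
    before_first_eos = ~((messages == 0).cumsum(dim=1) > 0)
    return before_first_eos.sum(dim=1)

print(find_lengths_excluding_eos(torch.tensor([[0, 0], [3, 0]])))  # tensor([0, 1])
```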

@robertodessi (Contributor)

Yes, but I think that was proposed before we started appending an EOS to every message, so the proposal is outdated. Do you have anything in mind? :)
