
Attention can be None in ModernBertForSequenceClassification #35917

Open
ashmikuz opened this issue Jan 27, 2025 · 5 comments · May be fixed by #35991
Comments

@ashmikuz

In the ModernBertForSequenceClassification class, the attention mask is never constructed outside of self.model (which is a ModernBertModel). Therefore, when no attention mask is passed to the model, the .unsqueeze() here fails.
I worked around this by assigning torch.ones(batch_size, seq_len) to attention_mask, but I am not sure whether this is correct.
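
For context, the workaround described above looks roughly like this (a sketch only; the checkpoint name is just an example, mean pooling is the configuration where the crash happens, and an all-ones mask treats any padding positions as real tokens):

```python
import torch
from transformers import AutoTokenizer, ModernBertForSequenceClassification

# Example checkpoint; classifier_pooling="mean" is the setting where the missing mask crashes.
model = ModernBertForSequenceClassification.from_pretrained(
    "answerdotai/ModernBERT-base", classifier_pooling="mean"
)
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

input_ids = tokenizer(["an example sentence"], return_tensors="pt")["input_ids"]
batch_size, seq_len = input_ids.shape

# Workaround: pass an explicit all-ones mask so the mean-pooling .unsqueeze() has something to work on.
# Caveat: this counts every position, including padding, which is why it may not be correct.
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

outputs = model(input_ids=input_ids, attention_mask=attention_mask)
```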

@Rocketknight1
Member

Rocketknight1 commented Jan 28, 2025

Hi @ashmikuz, when no attention mask is passed, we can't really work out which positions are masked! Although we could add code to estimate this (like attention_mask = input_ids != self.config.pad_token_id), this is error-prone, and I think a better solution is just to raise a clear error in this case, telling the user they have to pass an attention mask if they want to use self.config.classifier_pooling == "mean".

Would you be interested in making a PR for that?
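
For illustration, the pad-token-based estimation mentioned above would look roughly like this (a sketch only, not actual transformers code, and the helper name is hypothetical; the comments note why it is error-prone):

```python
import torch


def estimate_attention_mask(input_ids: torch.Tensor, pad_token_id) -> torch.Tensor:
    # Hypothetical helper: treat every non-pad token as attended.
    # Error-prone: the config may not define a pad_token_id, and sequences that
    # were never padded (or that legitimately contain the pad id) are mishandled.
    if pad_token_id is None:
        raise ValueError("Cannot estimate an attention mask without a pad_token_id.")
    return (input_ids != pad_token_id).long()
```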

@tom13878

Hi @Rocketknight1, @ashmikuz,

I had the same issue. Is anyone working on this? Otherwise I will raise a PR and add this RuntimeError:

if self.config.classifier_pooling == "mean" and attention_mask is None:
    raise RuntimeError(
        "Mean pooling requires an attention mask to properly compute the pooled output. "
        "Please provide an attention mask to indicate which tokens should be considered "
        "in the mean pooling calculation."
    )

@ashmikuz
Author

Sorry, I was quite busy over the last few days. Shouldn't this match how other models behave? As far as I understand, other models just print a warning and then create an attention mask with torch.ones, right?

@tom13878

Yes, you are right: I see the torch.ones fallback for e.g. DeBERTa here.

I don't see the warning, but I may have missed it...

@ashmikuz
Author

I'm working on a quick PR; I'll send it in a moment. Hopefully it fixes the issue and is in line with how other models behave.
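
For reference, the fallback pattern being discussed would look roughly like this (a sketch of the general approach only, not the actual PR; the helper name is hypothetical, and as noted above, DeBERTa appears to do the torch.ones part without an explicit warning):

```python
import logging
from typing import Optional

import torch

logger = logging.getLogger(__name__)


def resolve_attention_mask(input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
    # Hypothetical helper mirroring the behaviour described above: if no mask is
    # given, fall back to attending to every position, optionally with a warning.
    if attention_mask is None:
        logger.warning(
            "No attention_mask was provided; defaulting to all ones, which treats "
            "padding tokens as real tokens during mean pooling."
        )
        attention_mask = torch.ones_like(input_ids)
    return attention_mask
```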
