feat: Add classifier explainability based on token importance #198
base: main
Conversation
I think the implementation contains a bug w.r.t. the tokenization: you retokenize tokens.
Here's what you could do: both the pipeline and the model already contain embeddings for all tokens separately; it's just the embedding matrix itself.
So getting the token logits is, in both cases, just (paraphrased):
model.model_head.predict_logits(model.model.embeddings)
No need to tokenize anything. So this is what I would start with, and then we'll see.
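A minimal sketch of that idea, using the paraphrased attribute names from the comment above (they are assumptions, not the actual API):

```python
import numpy as np

def all_token_logits(model) -> np.ndarray:
    """Score every vocabulary token in one pass over the embedding matrix.

    Each row of the static embedding matrix is already the embedding of a
    single token, so feeding the whole matrix to the classification head
    yields per-token logits directly -- no retokenization needed.
    """
    embeddings = model.model.embeddings                    # shape: (vocab_size, dim)
    logits = model.model_head.predict_logits(embeddings)   # shape: (vocab_size, n_classes)
    return logits
```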
Just some nitpicks. I'm curious why you only check the logits for the predicted class, and not all classes. It seems to me that that could, or should, be part of what it means for a classifier to be explainable.
for token_id in unique_ids:
    # Get the token string and logit
    token_str = model.tokenizer.id_to_token(token_id)
    token_logit = model.token_logits_cache.get(token_id)
Can this ever be None?
    if token_logit is None:
        continue
    # Get the logit for the predicted label
    score = float(token_logit[label_idx])
This looks kind of odd to me. Why would the logit of the predicted label, and not the collected logits of the token over all classes, determine the classification? For example, consider a situation in which a token gets high logits for 2 out of many classes. In that case, it might not be important at all, right?
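One way to take all classes into account would be to score a token by its margin rather than by its raw logit for the predicted class. A sketch (not part of this PR; names are illustrative):

```python
import numpy as np

def margin_score(token_logits: np.ndarray, label_idx: int) -> float:
    """Score a token by how strongly it favors the predicted class over the runner-up.

    token_logits has shape (n_classes,). A token with high logits for two of
    many classes gets a small margin here, and is therefore ranked as less
    important than one that favors the predicted class alone.
    """
    others = np.delete(token_logits, label_idx)
    return float(token_logits[label_idx] - others.max())
```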
    results.append((token_str, score))

# Sort tokens by descending score
results.sort(key=lambda x: x[1], reverse=True)
Could be nicer not to do an in-place sort, e.g.:
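```python
# Return a new sorted list instead of mutating `results` in place.
sorted_results = sorted(results, key=lambda x: x[1], reverse=True)
```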
unique_ids = set(input_ids[0])

# Identify tokens that are not yet cached and compute their logits
tokens_to_compute = [token_id for token_id in unique_ids if token_id not in model.token_logits_cache]
Maybe you can just do this in the loop below? For each token, you get it from the cache, and if you miss, you compute it. Just a small idea, something like:
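A sketch only; it reuses the cache attribute from the PR and assumes a predict_logits(token_ids) helper like the one added here (the exact call site may differ):

```python
for token_id in unique_ids:
    token_logit = model.token_logits_cache.get(token_id)
    if token_logit is None:
        # Cache miss: compute and store the logits for this single token.
        token_logit = model.predict_logits([token_id])[0]
        model.token_logits_cache[token_id] = token_logit
    # ... scoring of token_logit continues as before
```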
@@ -246,6 +267,15 @@ def evaluate(

        return report

    def get_most_important_tokens(self, text: str) -> list[tuple[str, float]]:
I think you can leave these functions out? I don't really see why they exist, except that they strongly couple the modules where they could be decoupled before. That is, if this function doesn't exist, there is never a need to change this code when the explainability module changes, but now there suddenly is.
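For illustration, the call site could stay outside the model entirely, something like this (module and argument names are purely hypothetical):

```python
# Explainability as a free function that takes the trained classifier as an
# argument, instead of a method on the classifier itself.
from explainability import get_most_important_tokens  # hypothetical import path

tokens = get_most_important_tokens(classifier, "some input text")
```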
        mlp.out_activation_ = original_activation
        return logits

    def get_most_important_tokens(self, text: str) -> list[tuple[str, float]]:
Same comment as in the training code.
"""Predict the logits for the specified token IDs.""" | ||
# Extract embeddings for the specified token IDs. | ||
token_embeddings = self.embeddings[token_ids] | ||
mlp = self.head[-1] |
Just a fair warning that this doesn't work if the head is ever not a pipeline with an MLP as the final estimator, which we do support in principle.
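If this stays, a defensive check in front of the MLP-specific code might be worthwhile; a sketch, assuming the head is either a scikit-learn Pipeline or a bare estimator:

```python
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import Pipeline

head = self.head
final_estimator = head[-1] if isinstance(head, Pipeline) else head
if not isinstance(final_estimator, MLPClassifier):
    raise ValueError(
        "Token-logit explainability currently assumes an MLPClassifier as the "
        f"final estimator of the head, got {type(final_estimator).__name__}."
    )
mlp = final_estimator
```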
This PR adds explainability for classifiers by sorting the tokens of the input based on output-layer logits for the predicted class. A couple of open questions/concerns:
- Token logits are computed lazily when get_most_important_tokens is called. I think this is fine since it's really fast, but should we log this, or is it ok as is?