
feat: Add classifier explainability based on token importance #198

Open · wants to merge 14 commits into main

Conversation

@Pringled (Member) commented Feb 19, 2025

This PR adds explainability for classifiers by sorting the tokens of the input based on the output-layer logits for the predicted class. A couple of open questions/concerns:

  • I feel like this adds even more functions (on top of the evaluate functions) to inference/model.py. Is this the best place for functions that are shared between inference/model.py and train/classifier.py, or should we have a utils.py or something similar? Putting them in the inference module makes sense since it's also installed when installing the train module, but perhaps a separate file for the shared functions would be cleaner.
  • Right now, if token_logits do not exist, they are computed when get_most_important_tokens is called. I think this is fine since it's really fast, but should we log this, or is it ok as is?
  • The typing for the model is set to Any since we get some nasty circular import issues if we type it correctly, but maybe the lord of the types knows how to fix this 👀 (see the TYPE_CHECKING sketch below).
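
A minimal sketch of the usual workaround for the circular-import point, assuming the cycle is between model2vec/inference/model.py and model2vec/train/classifier.py (the module, class, and function layout here is illustrative, not necessarily the real one):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by type checkers, so it cannot cause a circular import at runtime.
    from model2vec.train.classifier import StaticModelForClassification


def get_most_important_tokens(model: StaticModelForClassification, text: str) -> list[tuple[str, float]]:
    """Type the model precisely without importing the class at runtime."""
    ...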

@Pringled Pringled requested a review from stephantul February 19, 2025 19:21

codecov bot commented Feb 19, 2025

Codecov Report

Attention: Patch coverage is 94.11765% with 4 lines in your changes missing coverage. Please review.

Files with missing lines        Patch %   Lines
model2vec/inference/model.py    93.47%    3 Missing ⚠️
model2vec/train/base.py         75.00%    1 Missing ⚠️

Files with missing lines           Coverage Δ
model2vec/inference/__init__.py    100.00% <100.00%> (ø)
model2vec/train/classifier.py      97.48% <100.00%> (+0.10%) ⬆️
tests/test_inference.py            100.00% <100.00%> (ø)
tests/test_trainable.py            100.00% <100.00%> (ø)
model2vec/train/base.py            98.76% <75.00%> (-1.24%) ⬇️
model2vec/inference/model.py       92.66% <93.47%> (+0.28%) ⬆️

@stephantul (Member) left a comment


I think the implementation contains a bug w.r.t. the tokenization: you retokenize tokens.

Here's what you could do: both the pipeline and the model already contain embeddings for all tokens separately; it's just the embedding matrix itself.

So getting the token logits is in both cases just (paraphrased):

model.model_head.predict_logits(model.model.embeddings)

No need to tokenize anything. So this is what I would start with, and then we'll see.
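
A minimal sketch of that suggestion, assuming the head is a scikit-learn Pipeline whose final estimator supports predict_proba and that the model exposes its embedding matrix and tokenizer (attribute names are illustrative, not necessarily the real model2vec API):

import numpy as np

# One forward pass over the whole vocabulary: every row of the embedding matrix
# is the embedding of a single token, so nothing needs to be tokenized.
token_probs = head.predict_proba(model.embeddings)  # shape: (vocab_size, n_classes)

# Rank the vocabulary for one class, e.g. the predicted class of some input.
label_idx = 0
scores = token_probs[:, label_idx]
top_token_ids = np.argsort(scores)[::-1][:10]
top_tokens = [(model.tokenizer.id_to_token(int(i)), float(scores[i])) for i in top_token_ids]

(The PR itself gets raw logits by temporarily swapping the MLP's out_activation_; predict_proba is used here only to keep the sketch self-contained.)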

@Pringled Pringled requested a review from stephantul February 20, 2025 16:16
@stephantul (Member) left a comment


Just some nitpicks. I'm curious why you only check the logits for the predicted class, and not all classes. It seems to me that that could or should be part of what it means for a classifier to be explainable.

for token_id in unique_ids:
    # Get the token string and logit
    token_str = model.tokenizer.id_to_token(token_id)
    token_logit = model.token_logits_cache.get(token_id)

can this ever be None?

if token_logit is None:
    continue
# Get the logit for the predicted label
score = float(token_logit[label_idx])

This looks kind of odd to me. Why would the logit of the predicted label, and not the collected logits of the token over all classes, determine the classification? For example, consider a situation in which a token gets high logits for 2 out of many classes. In that case, it might not be important at all, right?
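
A small numeric illustration of that concern (made-up numbers; normalizing over all classes with a softmax is just one way to take the other classes into account, not necessarily what the PR should do):

import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max())
    return e / e.sum()

label_idx = 0
logits_a = np.array([4.0, 0.0, 0.0, 0.0])  # token A: only high for the predicted class
logits_b = np.array([4.0, 4.0, 0.0, 0.0])  # token B: equally high for a second class

# Scoring by the raw predicted-class logit ranks A and B identically (4.0 vs 4.0),
# while a class-normalized score shows B is much less specific (~0.95 vs ~0.49).
print(logits_a[label_idx], logits_b[label_idx])
print(softmax(logits_a)[label_idx], softmax(logits_b)[label_idx])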

results.append((token_str, score))

# Sort tokens by descending score
results.sort(key=lambda x: x[1], reverse=True)

Could be nicer to not do an in-place sort.
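
For instance, a one-line sketch of the non-mutating alternative:

# Return a new sorted list instead of mutating `results` in place.
return sorted(results, key=lambda x: x[1], reverse=True)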

unique_ids = set(input_ids[0])

# Identify tokens that are not yet cached and compute their logits
tokens_to_compute = [token_id for token_id in unique_ids if token_id not in model.token_logits_cache]

Maybe you can just do this in the loop below? For each token, you get it from the cache; if you miss, you compute it. Just a small idea.
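
A minimal sketch of that idea; _predict_single_token_logits is a hypothetical helper standing in for however the PR actually computes the logits:

for token_id in unique_ids:
    token_logit = model.token_logits_cache.get(token_id)
    if token_logit is None:
        # Cache miss: compute the logits for this token and store them for next time.
        token_logit = _predict_single_token_logits(model, token_id)  # hypothetical helper
        model.token_logits_cache[token_id] = token_logit
    token_str = model.tokenizer.id_to_token(token_id)
    results.append((token_str, float(token_logit[label_idx])))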

@@ -246,6 +267,15 @@ def evaluate(

return report

def get_most_important_tokens(self, text: str) -> list[tuple[str, float]]:

I think you can leave these functions out? I don't really see why they exist, except that they strongly couple the modules, where they could be decoupled before. I.e., if this function didn't exist, there would be no need to ever change this code when the explainability module changes, but now there suddenly is.

mlp.out_activation_ = original_activation
return logits

def get_most_important_tokens(self, text: str) -> list[tuple[str, float]]:

idem as in the training code.

"""Predict the logits for the specified token IDs."""
# Extract embeddings for the specified token IDs.
token_embeddings = self.embeddings[token_ids]
mlp = self.head[-1]

Just a fair warning that this doesn't work if the head is ever not a pipeline with an MLP as the final estimator, which we do support in principle.
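
A small sketch of a defensive check for that case, assuming a scikit-learn Pipeline head (the error message and placement are illustrative):

from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import Pipeline

# The out_activation_ logit trick only applies to an MLPClassifier at the end of a Pipeline.
if not isinstance(self.head, Pipeline) or not isinstance(self.head[-1], MLPClassifier):
    raise ValueError("Explainability currently requires a Pipeline whose final step is an MLPClassifier.")
mlp = self.head[-1]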
