StaticModelForClassification with distilled model #199

Open
Mahhos opened this issue Feb 25, 2025 · 5 comments

Mahhos commented Feb 25, 2025

Hi! I have distilled a model2vec model from a relatively small sentence transformer, and now I'm trying to train a classifier using my new m2v static model. Here's the error I'm encountering. Do you have any advice?


import pandas as pd
from model2vec.train import StaticModelForClassification

classifier = StaticModelForClassification.from_pretrained("stella_m2v_model")

classifier.fit(df['search_query'], df['label_indices'].tolist())
TypeError: 'str' object cannot be interpreted as an integer
File <command-2318105830669928>, line 4
      1 import pandas as pd
      2 from model2vec.train import StaticModelForClassification
----> 4 classifier = StaticModelForClassification.from_pretrained("stella_m2v_model")
      6 classifier.fit(df['search_query'], df['label_indices'].tolist())
File /lib/python3.12/site-packages/model2vec/train/base.py:52, in FinetunableStaticModel.from_pretrained(cls, out_dim, model_name, **kwargs)
     50 """Load the model from a pretrained model2vec model."""
     51 model = StaticModel.from_pretrained(model_name)
---> 52 return cls.from_static_model(model, out_dim, **kwargs)
File /lib/python3.12/site-packages/model2vec/train/base.py:59, in FinetunableStaticModel.from_static_model(cls, model, out_dim, **kwargs)
     57 model.embedding = np.nan_to_num(model.embedding)
     58 embeddings_converted = torch.from_numpy(model.embedding)
---> 59 return cls(
     60     vectors=embeddings_converted,
     61     pad_id=model.tokenizer.token_to_id("[PAD]"),
     62     out_dim=out_dim,
     63     tokenizer=model.tokenizer,
     64     **kwargs,
     65 )
File /lib/python3.12/site-packages/model2vec/train/classifier.py:46, in StaticModelForClassification.__init__(self, vectors, tokenizer, n_layers, hidden_dim, out_dim, pad_id)
     44 self.hidden_dim = hidden_dim
     45 # Alias: Follows scikit-learn. Set to dummy classes
---> 46 self.classes_: list[str] = [str(x) for x in range(out_dim)]
     47 # multilabel flag will be set based on the type of `y` passed to fit.
     48 self.multilabel: bool = False

P.S. I can fit a classifier on my data using
classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M")
However, training this way took more than 3 hours.

How I used to load my distilled model:

from model2vec import StaticModel
model = StaticModel.from_pretrained("stella_m2v_model")

Mahhos commented Feb 25, 2025

This code also does not work:

import pandas as pd
from model2vec.train import StaticModelForClassification
from model2vec.distill import distill

# From a distilled model
distilled_model = distill("stella_m2v_model")
classifier = StaticModelForClassification.from_static_model("stella_m2v_model")

classifier.fit(df['search_query'], df['label_indices'].tolist())
ValueError: The checkpoint you are trying to load has model type `model2vec` but Transformers does not recognize this architecture. This could be because of an issue with the checkpoint, or because your version of Transformers is out of date.
File /databricks/python/lib/python3.12/site-packages/transformers/models/auto/configuration_auto.py:1034, in AutoConfig.from_pretrained(cls, pretrained_model_name_or_path, **kwargs)
   1033 try:
-> 1034     config_class = CONFIG_MAPPING[config_dict["model_type"]]
   1035 except KeyError:
File /databricks/python/lib/python3.12/site-packages/transformers/models/auto/configuration_auto.py:736, in _LazyConfigMapping.__getitem__(self, key)
    735 if key not in self._mapping:
--> 736     raise KeyError(key)
    737 value = self._mapping[key]
KeyError: 'model2vec'

stephantul (Member) commented

Hello @Mahhos,

I'll tackle them in reverse order: this snippet

distilled_model = distill("stella_m2v_model")

is incorrect: you shouldn't distill a model that has already been distilled. Instead, you should load it directly:

 classifier = StaticModelForClassification.from_pretrained(model_name="stella_m2v_model") 

In your first snippet, the problem is that you pass the model name as the first positional argument to from_pretrained. The first positional argument, however, is out_dim, the number of outputs, which is why the traceback ends in range(out_dim) with a string. Please use keyword arguments. The code above should work.
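To make the positional-argument pitfall concrete, here is a minimal sketch using a toy stand-in class. `ToyClassifier` and its default values are hypothetical; it only mirrors the argument order (`out_dim`, then `model_name`) and the `range(out_dim)` line visible in the traceback, and it is not the model2vec implementation.

```python
class ToyClassifier:
    """Hypothetical stand-in mirroring the argument order shown in the traceback."""

    def __init__(self, out_dim: int, model_name: str) -> None:
        self.model_name = model_name
        # Same pattern as the failing line in the traceback: range() needs an int.
        self.classes_ = [str(x) for x in range(out_dim)]

    @classmethod
    def from_pretrained(
        cls,
        out_dim: int = 2,  # assumed default, for illustration only
        model_name: str = "some-default-model",  # assumed default, for illustration only
    ) -> "ToyClassifier":
        return cls(out_dim=out_dim, model_name=model_name)


def try_positional() -> str:
    """Pass the model name positionally, so it is bound to out_dim."""
    try:
        ToyClassifier.from_pretrained("stella_m2v_model")
    except TypeError as exc:
        return str(exc)
    return "no error"


positional_error = try_positional()

# With a keyword argument, the name reaches model_name and out_dim keeps its default.
keyword_model = ToyClassifier.from_pretrained(model_name="stella_m2v_model")

print(positional_error)
print(keyword_model.model_name, keyword_model.classes_)
```

The positional call binds the string to `out_dim`, which reproduces the `TypeError` from the original report; the keyword call avoids that binding.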

Regarding the 3-hour training time: that is a long time. Could you give more information on your data? Did the model work in the end?

Stéphan

Mahhos commented Feb 26, 2025

Thank you for your response! My training set consists of 1 million search queries, where each query could belong to any of my 10,000 categories (multi-label task). Originally, my labels were formatted as ["label_1", "label_2", "label_3"], which I have mapped to indices, resulting in [1080, 243, 21].
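For reference, the label-to-index mapping described above can be sketched as follows. The helper names and the toy labels are hypothetical, and the sketch says nothing about which label format `fit` expects; it only shows one way to build a reversible mapping so that predictions can be translated back to the original label names.

```python
from typing import Dict, List


def build_label_index(label_lists: List[List[str]]) -> Dict[str, int]:
    """Assign each distinct label a stable integer index (sorted order)."""
    distinct = sorted({label for labels in label_lists for label in labels})
    return {label: index for index, label in enumerate(distinct)}


def encode(label_lists: List[List[str]], label_to_index: Dict[str, int]) -> List[List[int]]:
    """Replace every label name with its integer index."""
    return [[label_to_index[label] for label in labels] for labels in label_lists]


def decode(index_lists: List[List[int]], label_to_index: Dict[str, int]) -> List[List[str]]:
    """Invert the mapping to recover the original label names."""
    index_to_label = {index: label for label, index in label_to_index.items()}
    return [[index_to_label[index] for index in indices] for indices in index_lists]


# Toy data standing in for the real multi-label annotations.
raw_labels = [["label_1", "label_2", "label_3"], ["label_2"], ["label_3", "label_4"]]

label_to_index = build_label_index(raw_labels)
encoded = encode(raw_labels, label_to_index)
decoded = decode(encoded, label_to_index)

print(encoded)
print(decoded == raw_labels)
```

Building the mapping once on the training labels and reusing it for every test set keeps the indices consistent across splits.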

Yes, the model worked in the end, but training took 4 hours. Do you have any idea why this is happening?

stephantul (Member) commented

Hello,

This is likely due to the dimensionality of the output. If you have many classes, the output layer can become a bottleneck.
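A rough back-of-the-envelope calculation illustrates the scale. The sketch below counts only the final linear layer and assumes a hidden size of 512, which is an assumption not stated in this thread; the class and sample counts are the ones reported above. Actual training time also depends on hardware, batch size, number of epochs, and the rest of the model.

```python
def output_layer_params(hidden_dim: int, n_classes: int) -> int:
    """Weights plus biases of a single linear output layer."""
    return hidden_dim * n_classes + n_classes


def output_layer_macs_per_epoch(hidden_dim: int, n_classes: int, n_samples: int) -> int:
    """Multiply-accumulates for one forward pass of the output layer over the data."""
    return hidden_dim * n_classes * n_samples


HIDDEN_DIM = 512        # assumed hidden size, not taken from the thread
N_CLASSES = 10_000      # categories reported in the thread
N_SAMPLES = 1_000_000   # training queries reported in the thread

binary_params = output_layer_params(HIDDEN_DIM, 2)
large_params = output_layer_params(HIDDEN_DIM, N_CLASSES)
macs = output_layer_macs_per_epoch(HIDDEN_DIM, N_CLASSES, N_SAMPLES)

# A dense multi-hot target matrix would hold one entry per (sample, class) pair.
dense_target_entries = N_SAMPLES * N_CLASSES

print(f"output layer parameters, 2 classes:      {binary_params:,}")
print(f"output layer parameters, 10,000 classes: {large_params:,}")
print(f"output layer MACs per epoch (forward):   {macs:,}")
print(f"dense multi-hot target entries:          {dense_target_entries:,}")
```

Under these assumptions the output layer grows from roughly a thousand parameters to about five million, and a dense target matrix alone would have ten billion entries.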

Mahhos commented Feb 27, 2025

The results are not meaningful either! Do you have any suggestions on how to use StaticModelForClassification for such extreme multi-label tasks?

test set #1:
              precision    recall  f1-score
   micro avg       0.02      0.03      0.02
   macro avg       0.02      0.03      0.02
weighted avg       0.02      0.03      0.02
 samples avg       0.02      0.03      0.02

test set #2:
              precision    recall  f1-score
   micro avg       0.00      0.00      0.00
   macro avg       0.00      0.00      0.00
weighted avg       0.00      0.00      0.00
 samples avg       0.00      0.00      0.00

test set #3:
              precision    recall  f1-score
   micro avg       0.00      0.00      0.00
   macro avg       0.00      0.00      0.00
weighted avg       0.00      0.00      0.00
 samples avg       0.00      0.00      0.00
