StaticModelForClassification with distilled model #199
This code also does not work:
Hello @Mahhos, I'll tackle these in reverse order. This snippet, `distilled_model = distill("stella_m2v_model")`, is incorrect: you shouldn't try to distill a model that has already been distilled. Instead, load it directly:

`classifier = StaticModelForClassification.from_pretrained(model_name="stella_m2v_model")`

The problem is that you are passing the model name as the first positional argument to the classifier, but the first positional argument is the number of output items. Please use keyword arguments; the code above should work.

Regarding it taking 3 hours: that's a long time. Could you give more information on your data? Did the model work in the end?

Stéphan
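To illustrate why keyword arguments avoid this kind of mix-up, here is a small sketch. `ToyClassifier` and `SaferToyClassifier` are hypothetical stand-ins, not model2vec's real classes or signatures: the first puts a numeric parameter first, so a positional model name lands in the wrong slot, and the second uses a bare `*` to make every parameter keyword-only.

```python
from typing import Optional


class ToyClassifier:
    """Hypothetical stand-in, NOT model2vec's real class: a numeric
    parameter comes first, so a positional model name lands in it."""

    def __init__(self, out_dim: int = 2, model_name: Optional[str] = None) -> None:
        # Reject non-integers so the mix-up fails loudly instead of later.
        if not isinstance(out_dim, int):
            raise TypeError(f"out_dim must be an int, got {out_dim!r}")
        self.out_dim = out_dim
        self.model_name = model_name


class SaferToyClassifier:
    """Same hypothetical parameters, but the bare `*` makes them keyword-only."""

    def __init__(self, *, out_dim: int = 2, model_name: Optional[str] = None) -> None:
        self.out_dim = out_dim
        self.model_name = model_name


def positional_error() -> str:
    """Return the error raised when the model name is passed positionally."""
    try:
        ToyClassifier("stella_m2v_model")  # binds the string to out_dim
    except TypeError as exc:
        return str(exc)
    return "no error"


# Keyword arguments bind each value to the parameter you intended.
safe = SaferToyClassifier(model_name="stella_m2v_model")
print(safe.model_name, safe.out_dim)  # stella_m2v_model 2
print(positional_error())  # out_dim must be an int, got 'stella_m2v_model'
```

With the keyword-only version, `SaferToyClassifier("stella_m2v_model")` is rejected by Python itself with a `TypeError` before any model code runs.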
Thank you for your response! My training set consists of 1 million search queries, where each query can belong to any of my 10,000 categories (a multi-label task). Originally, my labels were formatted as ["label_1", "label_2", "label_3"], which I have mapped to indices, resulting in [1080, 243, 21]. Yes, the model worked in the end, but it took 4 hours. Do you have any idea why this is happening?
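The label-to-index mapping described above can be sketched in plain Python. This is an illustration only: the thread does not say whether the classifier's `fit` expects raw label strings, index lists, or multi-hot vectors, so the helper names here (`build_label_index`, `to_indices`, `to_multi_hot`) are hypothetical and the expected input format should be checked against the library's documentation.

```python
from typing import Dict, List, Sequence


def build_label_index(label_lists: Sequence[Sequence[str]]) -> Dict[str, int]:
    """Give every distinct label a stable integer id (sorted, so reruns agree)."""
    distinct = sorted({label for labels in label_lists for label in labels})
    return {label: idx for idx, label in enumerate(distinct)}


def to_indices(labels: Sequence[str], label_to_idx: Dict[str, int]) -> List[int]:
    """Map one query's string labels to their integer ids."""
    return [label_to_idx[label] for label in labels]


def to_multi_hot(indices: Sequence[int], n_classes: int) -> List[int]:
    """Encode a set of class ids as a 0/1 vector of length n_classes."""
    vector = [0] * n_classes
    for idx in indices:
        vector[idx] = 1
    return vector


# Toy data in the format described above (placeholder label names).
raw = [["label_1", "label_2", "label_3"], ["label_2"]]
label_to_idx = build_label_index(raw)
ids = [to_indices(labels, label_to_idx) for labels in raw]
print(ids)  # [[0, 1, 2], [1]]
print(to_multi_hot(ids[1], len(label_to_idx)))  # [0, 1, 0]
```

Sorting the distinct labels keeps the mapping reproducible across runs, which matters when the same indices must be used again at inference time.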
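Hello, this is likely due to the dimensionality of the output. If you have many classes, the output layer could become a bottleneck: it has to produce a score for every class for every query, so its cost grows linearly with the number of classes.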
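A rough back-of-the-envelope calculation shows how quickly the output layer grows. The 1 million queries and 10,000 classes come from the previous comment; `HIDDEN_DIM = 512` is an assumption (the thread does not state the hidden size), and the sketch assumes a single dense linear output layer, counting forward-pass multiply-accumulates only.

```python
def output_layer_params(hidden_dim: int, n_classes: int) -> int:
    """Weights (hidden_dim x n_classes) plus one bias per class."""
    return hidden_dim * n_classes + n_classes


def output_layer_macs(hidden_dim: int, n_classes: int, n_examples: int) -> int:
    """Multiply-accumulates for one forward pass of the output layer over
    n_examples (bias additions ignored; training also pays for the backward pass)."""
    return hidden_dim * n_classes * n_examples


HIDDEN_DIM = 512  # assumption, not a value confirmed in this thread

# Figures from the thread: ~1M queries, 10,000 classes.
print(output_layer_params(HIDDEN_DIM, 10_000))  # 5130000
print(output_layer_params(HIDDEN_DIM, 2))  # 1026
print(output_layer_macs(HIDDEN_DIM, 10_000, 1_000_000))  # 5120000000000
```

Under these assumptions the output layer is exactly 5,000 times larger than in a binary setup, and that cost is paid on every query in every epoch.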
The results are not meaningful either! Do you have any suggestions on how to incorporate your
Hi! I have distilled a model2vec model from a relatively small sentence transformer, and now I'm trying to train a classifier using my new m2v static model. Here's the error I'm encountering. Do you have any advice?
P.S. I can fit a classifier on my data using
```python
from model2vec.train import StaticModelForClassification

classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M")
```
However, this also took more than 3 hours.