Skip to content

Commit

Permalink
Fix library detection (#1690)
Browse files Browse the repository at this point in the history
* fix task detection

* remove unnecessary workflow
  • Loading branch information
fxmarty authored Feb 13, 2024
1 parent 96c6d48 commit ec85aa9
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
10 changes: 5 additions & 5 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1655,6 +1655,11 @@ def infer_library_from_model(

if "model_index.json" in all_files:
library_name = "diffusers"
elif (
any(file_path.startswith("sentence_") for file_path in all_files)
or "config_sentence_transformers.json" in all_files
):
library_name = "sentence_transformers"
elif CONFIG_NAME in all_files:
# We do not use PretrainedConfig.from_pretrained which has unwanted warnings about model type.
kwargs = {
Expand All @@ -1671,11 +1676,6 @@ def infer_library_from_model(
library_name = "diffusers"
else:
library_name = "transformers"
elif (
any(file_path.startswith("sentence_") for file_path in all_files)
or "config_sentence_transformers.json" in all_files
):
library_name = "sentence_transformers"
else:
library_name = "transformers"

Expand Down
10 changes: 10 additions & 0 deletions tests/exporters/common/test_tasks_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,13 @@ def test_custom_class(self):

model = TasksManager.get_model_from_task("question-answering", "uclanlp/visualbert-vqa")
self.assertTrue(isinstance(model, VisualBertForQuestionAnswering))

def test_library_detection(self):
self.assertEqual(
TasksManager.infer_library_from_model("intfloat/multilingual-e5-large"), "sentence_transformers"
)
self.assertEqual(
TasksManager.infer_library_from_model("stabilityai/stable-diffusion-xl-base-1.0"), "diffusers"
)
self.assertEqual(TasksManager.infer_library_from_model("gpt2"), "transformers")
self.assertEqual(TasksManager.infer_library_from_model("timm/mobilenetv3_large_100.ra_in1k"), "timm")

0 comments on commit ec85aa9

Please sign in to comment.