From ec85aa9b95baa2e2d9d01d2f50e855f206ba3afb Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 13 Feb 2024 14:39:51 +0100 Subject: [PATCH] Fix library detection (#1690) * fix task detection * remove unnecessary workflow --- optimum/exporters/tasks.py | 10 +++++----- tests/exporters/common/test_tasks_manager.py | 10 ++++++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index cd2d0f3f1bd..e8e8af2bce9 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -1655,6 +1655,11 @@ def infer_library_from_model( if "model_index.json" in all_files: library_name = "diffusers" + elif ( + any(file_path.startswith("sentence_") for file_path in all_files) + or "config_sentence_transformers.json" in all_files + ): + library_name = "sentence_transformers" elif CONFIG_NAME in all_files: # We do not use PretrainedConfig.from_pretrained which has unwanted warnings about model type. kwargs = { @@ -1671,11 +1676,6 @@ def infer_library_from_model( library_name = "diffusers" else: library_name = "transformers" - elif ( - any(file_path.startswith("sentence_") for file_path in all_files) - or "config_sentence_transformers.json" in all_files - ): - library_name = "sentence_transformers" else: library_name = "transformers" diff --git a/tests/exporters/common/test_tasks_manager.py b/tests/exporters/common/test_tasks_manager.py index fc0d3eb8dbc..32cb1afc14b 100644 --- a/tests/exporters/common/test_tasks_manager.py +++ b/tests/exporters/common/test_tasks_manager.py @@ -177,3 +177,13 @@ def test_custom_class(self): model = TasksManager.get_model_from_task("question-answering", "uclanlp/visualbert-vqa") self.assertTrue(isinstance(model, VisualBertForQuestionAnswering)) + + def test_library_detection(self): + self.assertEqual( + TasksManager.infer_library_from_model("intfloat/multilingual-e5-large"), "sentence_transformers" + ) + self.assertEqual( + TasksManager.infer_library_from_model("stabilityai/stable-diffusion-xl-base-1.0"), "diffusers" + ) + self.assertEqual(TasksManager.infer_library_from_model("gpt2"), "transformers") + self.assertEqual(TasksManager.infer_library_from_model("timm/mobilenetv3_large_100.ra_in1k"), "timm")