From 6759002fadedd5e3f630a9c45b1d0a1190eafe46 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 27 Dec 2024 00:00:48 -0800 Subject: [PATCH 1/6] fix low-precision audio classification pipeline Signed-off-by: jiqing-feng --- src/transformers/pipelines/audio_classification.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/pipelines/audio_classification.py b/src/transformers/pipelines/audio_classification.py index 089c32d502d121..138abf3ce5ed91 100644 --- a/src/transformers/pipelines/audio_classification.py +++ b/src/transformers/pipelines/audio_classification.py @@ -211,6 +211,8 @@ def preprocess(self, inputs): processed = self.feature_extractor( inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt" ) + if self.torch_dtype is not None: + processed = processed.to(dtype=self.torch_dtype) return processed def _forward(self, model_inputs): From a47f5bd1abd5daea99d95a70d7584dba56393e8a Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 3 Jan 2025 16:34:14 +0000 Subject: [PATCH 2/6] add test Signed-off-by: jiqing-feng --- .../test_pipelines_audio_classification.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/pipelines/test_pipelines_audio_classification.py b/tests/pipelines/test_pipelines_audio_classification.py index 10b8a859ff7fb3..63ae6a42763ec2 100644 --- a/tests/pipelines/test_pipelines_audio_classification.py +++ b/tests/pipelines/test_pipelines_audio_classification.py @@ -14,6 +14,7 @@ import unittest +import torch import numpy as np from huggingface_hub import AudioClassificationOutputElement @@ -127,6 +128,33 @@ def test_small_model_pt(self): output = audio_classifier(audio_dict, top_k=4) self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2]) + @require_torch + def test_small_model_pt_fp16(self): + model = "anton-l/wav2vec2-random-tiny-classifier" + + audio_classifier = pipeline("audio-classification", model=model, torch_dtype=torch.float16) + + audio = np.ones((8000,)) + output = audio_classifier(audio, top_k=4) + + EXPECTED_OUTPUT = [ + {'score': 0.0839, 'label': 'no'}, + {'score': 0.0837, 'label': 'go'}, + {'score': 0.0836, 'label': 'yes'}, + {'score': 0.0835, 'label': 'right'} + ] + EXPECTED_OUTPUT_PT_2 = [ + {"score": 0.0845, "label": "stop"}, + {"score": 0.0844, "label": "on"}, + {"score": 0.0841, "label": "right"}, + {"score": 0.0834, "label": "left"}, + ] + self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2]) + + audio_dict = {"array": np.ones((8000,)), "sampling_rate": audio_classifier.feature_extractor.sampling_rate} + output = audio_classifier(audio_dict, top_k=4) + self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2]) + @require_torch @slow def test_large_model_pt(self): From 0d9d2e19d4ca3b8b72e61e6f6b789249be167726 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 3 Jan 2025 16:38:46 +0000 Subject: [PATCH 3/6] fix format Signed-off-by: jiqing-feng --- tests/pipelines/test_pipelines_audio_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipelines_audio_classification.py b/tests/pipelines/test_pipelines_audio_classification.py index 63ae6a42763ec2..2cfc25076e45aa 100644 --- a/tests/pipelines/test_pipelines_audio_classification.py +++ b/tests/pipelines/test_pipelines_audio_classification.py @@ -14,8 +14,8 @@ import unittest -import torch import numpy as np +import torch from huggingface_hub import 
AudioClassificationOutputElement from transformers import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING From 19f432804efb68939a4b1ba4182b34d74f1f175e Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 3 Jan 2025 16:45:59 +0000 Subject: [PATCH 4/6] fix torch import Signed-off-by: jiqing-feng --- tests/pipelines/test_pipelines_audio_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipelines_audio_classification.py b/tests/pipelines/test_pipelines_audio_classification.py index 2cfc25076e45aa..6e0354028106f1 100644 --- a/tests/pipelines/test_pipelines_audio_classification.py +++ b/tests/pipelines/test_pipelines_audio_classification.py @@ -15,7 +15,6 @@ import unittest import numpy as np -import torch from huggingface_hub import AudioClassificationOutputElement from transformers import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING @@ -130,6 +129,7 @@ def test_small_model_pt(self): @require_torch def test_small_model_pt_fp16(self): + import torch model = "anton-l/wav2vec2-random-tiny-classifier" audio_classifier = pipeline("audio-classification", model=model, torch_dtype=torch.float16) From 35d4d6d71567a44b7430743afd64878f4c29844a Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 3 Jan 2025 16:47:25 +0000 Subject: [PATCH 5/6] fix torch import Signed-off-by: jiqing-feng --- .../pipelines/test_pipelines_audio_classification.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/test_pipelines_audio_classification.py b/tests/pipelines/test_pipelines_audio_classification.py index 6e0354028106f1..b4a6c1fcec0ebc 100644 --- a/tests/pipelines/test_pipelines_audio_classification.py +++ b/tests/pipelines/test_pipelines_audio_classification.py @@ -17,7 +17,11 @@ import numpy as np from huggingface_hub import AudioClassificationOutputElement -from transformers import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING +from transformers import ( + MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, + TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, + is_torch_available, +) from transformers.pipelines import AudioClassificationPipeline, pipeline from transformers.testing_utils import ( compare_pipeline_output_to_hub_spec, @@ -32,6 +36,10 @@ from .test_pipelines_common import ANY +if is_torch_available(): + import torch + + @is_pipeline_test class AudioClassificationPipelineTests(unittest.TestCase): model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING @@ -129,7 +137,6 @@ def test_small_model_pt(self): @require_torch def test_small_model_pt_fp16(self): - import torch model = "anton-l/wav2vec2-random-tiny-classifier" audio_classifier = pipeline("audio-classification", model=model, torch_dtype=torch.float16) From b7954b03ecb8734cb076c8d50ce6111bd9664ba3 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 3 Jan 2025 16:49:35 +0000 Subject: [PATCH 6/6] fix format Signed-off-by: jiqing-feng --- tests/pipelines/test_pipelines_audio_classification.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/pipelines/test_pipelines_audio_classification.py b/tests/pipelines/test_pipelines_audio_classification.py index b4a6c1fcec0ebc..73534598d7d007 100644 --- a/tests/pipelines/test_pipelines_audio_classification.py +++ b/tests/pipelines/test_pipelines_audio_classification.py @@ -145,10 +145,10 @@ def test_small_model_pt_fp16(self): output = audio_classifier(audio, top_k=4) EXPECTED_OUTPUT = [ - {'score': 
0.0839, 'label': 'no'}, - {'score': 0.0837, 'label': 'go'}, - {'score': 0.0836, 'label': 'yes'}, - {'score': 0.0835, 'label': 'right'} + {"score": 0.0839, "label": "no"}, + {"score": 0.0837, "label": "go"}, + {"score": 0.0836, "label": "yes"}, + {"score": 0.0835, "label": "right"}, ] EXPECTED_OUTPUT_PT_2 = [ {"score": 0.0845, "label": "stop"},
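
Note (not part of the patch series): taken together, these commits let the
audio-classification pipeline run reduced-precision checkpoints end to end,
because preprocess() now casts the extracted features to the pipeline's
torch_dtype before the forward pass. A minimal usage sketch, mirroring the new
test_small_model_pt_fp16 test above (same checkpoint, input shape, and top_k;
assumes a PyTorch install with float16 support):

    import numpy as np
    import torch

    from transformers import pipeline

    # Load the pipeline in half precision; the model weights are float16.
    audio_classifier = pipeline(
        "audio-classification",
        model="anton-l/wav2vec2-random-tiny-classifier",
        torch_dtype=torch.float16,
    )

    # The feature extractor returns float32 tensors. Without the cast added in
    # PATCH 1/6, a float16 model would receive float32 inputs and the forward
    # pass would typically fail with a dtype mismatch.
    audio = np.ones((8000,))
    print(audio_classifier(audio, top_k=4))

The cast in preprocess() is guarded by "if self.torch_dtype is not None", so
pipelines created without an explicit torch_dtype keep their previous float32
behavior.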