diff --git a/src/transformers/pipelines/audio_classification.py b/src/transformers/pipelines/audio_classification.py
index f6acbb3096e07d..1d7f5172a57bb6 100644
--- a/src/transformers/pipelines/audio_classification.py
+++ b/src/transformers/pipelines/audio_classification.py
@@ -211,6 +211,8 @@ def preprocess(self, inputs):
         processed = self.feature_extractor(
             inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
         )
+        if self.torch_dtype is not None:
+            processed = processed.to(dtype=self.torch_dtype)
         return processed
 
     def _forward(self, model_inputs):
diff --git a/tests/pipelines/test_pipelines_audio_classification.py b/tests/pipelines/test_pipelines_audio_classification.py
index 10b8a859ff7fb3..73534598d7d007 100644
--- a/tests/pipelines/test_pipelines_audio_classification.py
+++ b/tests/pipelines/test_pipelines_audio_classification.py
@@ -17,7 +17,11 @@
 import numpy as np
 from huggingface_hub import AudioClassificationOutputElement
 
-from transformers import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
+from transformers import (
+    MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
+    TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
+    is_torch_available,
+)
 from transformers.pipelines import AudioClassificationPipeline, pipeline
 from transformers.testing_utils import (
     compare_pipeline_output_to_hub_spec,
@@ -32,6 +36,10 @@
 from .test_pipelines_common import ANY
 
 
+if is_torch_available():
+    import torch
+
+
 @is_pipeline_test
 class AudioClassificationPipelineTests(unittest.TestCase):
     model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
@@ -127,6 +135,33 @@ def test_small_model_pt(self):
         output = audio_classifier(audio_dict, top_k=4)
         self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])
 
+    @require_torch
+    def test_small_model_pt_fp16(self):
+        model = "anton-l/wav2vec2-random-tiny-classifier"
+
+        audio_classifier = pipeline("audio-classification", model=model, torch_dtype=torch.float16)
+
+        audio = np.ones((8000,))
+        output = audio_classifier(audio, top_k=4)
+
+        EXPECTED_OUTPUT = [
+            {"score": 0.0839, "label": "no"},
+            {"score": 0.0837, "label": "go"},
+            {"score": 0.0836, "label": "yes"},
+            {"score": 0.0835, "label": "right"},
+        ]
+        EXPECTED_OUTPUT_PT_2 = [
+            {"score": 0.0845, "label": "stop"},
+            {"score": 0.0844, "label": "on"},
+            {"score": 0.0841, "label": "right"},
+            {"score": 0.0834, "label": "left"},
+        ]
+        self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])
+
+        audio_dict = {"array": np.ones((8000,)), "sampling_rate": audio_classifier.feature_extractor.sampling_rate}
+        output = audio_classifier(audio_dict, top_k=4)
+        self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])
+
     @require_torch
     @slow
     def test_large_model_pt(self):
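
Usage sketch (not part of the diff): the preprocess change above means a pipeline built with torch_dtype=torch.float16 casts the feature extractor output to half precision before the forward pass, so the fp16 model receives matching inputs. The model name and the 8000-sample dummy waveform below are taken from the new test; printing the result is only illustrative.

import numpy as np
import torch

from transformers import pipeline

# Build the pipeline in fp16; preprocess() now casts the extracted features
# to torch.float16 as well, matching the model weights.
audio_classifier = pipeline(
    "audio-classification",
    model="anton-l/wav2vec2-random-tiny-classifier",
    torch_dtype=torch.float16,
)

# Dummy 8000-sample waveform, as in the test above.
print(audio_classifier(np.ones((8000,)), top_k=4))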