From e73927c3b23eb59f38980137ad9d4eb6e815b826 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 25 Dec 2024 17:35:19 -0800 Subject: [PATCH 1/3] fix vits dtype Signed-off-by: jiqing-feng --- src/transformers/models/vits/modeling_vits.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py index 66834167d15e06..2cda95b3169da2 100644 --- a/src/transformers/models/vits/modeling_vits.py +++ b/src/transformers/models/vits/modeling_vits.py @@ -1407,9 +1407,9 @@ def forward( raise NotImplementedError("Training of VITS is not supported yet.") if attention_mask is not None: - input_padding_mask = attention_mask.unsqueeze(-1).float() + input_padding_mask = attention_mask.unsqueeze(-1).to(self.dtype) else: - input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).float() + input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).to(self.dtype) if self.config.num_speakers > 1 and speaker_id is not None: if not 0 <= speaker_id < self.config.num_speakers: From 671fee8f71a220e880db6c92a6d1c294f3b8b503 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 3 Jan 2025 16:19:40 +0000 Subject: [PATCH 2/3] add tests Signed-off-by: jiqing-feng --- tests/models/vits/test_modeling_vits.py | 32 +++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/models/vits/test_modeling_vits.py b/tests/models/vits/test_modeling_vits.py index 366194090953f5..9733fb4bce1e65 100644 --- a/tests/models/vits/test_modeling_vits.py +++ b/tests/models/vits/test_modeling_vits.py @@ -27,6 +27,7 @@ is_flaky, is_torch_available, require_torch, + require_torch_fp16, require_torch_multi_gpu, slow, torch_device, @@ -434,3 +435,34 @@ def test_forward(self): ) # fmt: on self.assertTrue(torch.allclose(outputs.waveform[0, 10000:10030].cpu(), EXPECTED_LOGITS, atol=1e-4)) + + @require_torch_fp16 + def test_forward_fp16(self): + # GPU gives different results than CPU + torch_device = "cpu" + + model = VitsModel.from_pretrained("facebook/mms-tts-eng", torch_dtype=torch.float16) + model.to(torch_device) + + tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng") + + set_seed(555) # make deterministic + + input_text = "Mister quilter is the apostle of the middle classes and we are glad to welcome his gospel!" + input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(torch_device) + + with torch.no_grad(): + outputs = model(input_ids) + + self.assertEqual(outputs.waveform.shape, (1, 87040)) + # fmt: off + EXPECTED_LOGITS = torch.tensor( + [ + 0.0101, 0.0318, 0.0489, 0.0627, 0.0728, 0.0865, 0.1053, 0.1279, + 0.1514, 0.1703, 0.1827, 0.1829, 0.1694, 0.1509, 0.1332, 0.1188, + 0.1066, 0.0978, 0.0936, 0.0867, 0.0724, 0.0493, 0.0197, -0.0141, + -0.0501, -0.0817, -0.1065, -0.1223, -0.1311, -0.1339 + ] + ).to(torch.float16) + # fmt: on + self.assertTrue(torch.allclose(outputs.waveform[0, 10000:10030].cpu(), EXPECTED_LOGITS, atol=1e-4)) From 057804c624ad195dd56e629ea790d7a4c780c0af Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 17 Jan 2025 09:39:38 +0000 Subject: [PATCH 3/3] use weight dtype Signed-off-by: jiqing-feng --- src/transformers/models/vits/modeling_vits.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py index 3a3e0c4502e46f..7a506d497f9a26 100644 --- a/src/transformers/models/vits/modeling_vits.py +++ b/src/transformers/models/vits/modeling_vits.py @@ -1406,10 +1406,11 @@ def forward( if labels is not None: raise NotImplementedError("Training of VITS is not supported yet.") + mask_dtype = self.text_encoder.embed_tokens.weight.dtype if attention_mask is not None: - input_padding_mask = attention_mask.unsqueeze(-1).to(self.dtype) + input_padding_mask = attention_mask.unsqueeze(-1).to(mask_dtype) else: - input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).to(self.dtype) + input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).to(mask_dtype) if self.config.num_speakers > 1 and speaker_id is not None: if not 0 <= speaker_id < self.config.num_speakers: