From f2b78ef389ff8d50cde384b52a5a0bddf702ef62 Mon Sep 17 00:00:00 2001
From: young01ai
Date: Wed, 3 Jul 2024 22:37:51 +0800
Subject: [PATCH] fix roformer models run on cuda (#84)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* fix roformer models run on cuda

* fix roformer models run on cpu/cuda/mps

---------

Co-authored-by: 周志洋
---
 .../separator/uvr_lib_v5/attend.py            |  5 +-
 .../separator/uvr_lib_v5/bs_roformer.py       | 32 +++++-----
 .../separator/uvr_lib_v5/mel_band_roformer.py | 59 ++++++++++---------
 pyproject.toml                                |  2 +-
 4 files changed, 53 insertions(+), 45 deletions(-)

diff --git a/audio_separator/separator/uvr_lib_v5/attend.py b/audio_separator/separator/uvr_lib_v5/attend.py
index 96858c0..cc01093 100644
--- a/audio_separator/separator/uvr_lib_v5/attend.py
+++ b/audio_separator/separator/uvr_lib_v5/attend.py
@@ -70,8 +70,11 @@ def flash_attn(self, q, k, v):
 
         config = self.cuda_config if is_cuda else self.cpu_config
 
-        # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
+        # sdpa_flash kernel only supports float16 on sm80+ architecture gpu
+        if is_cuda and q.dtype != torch.float16:
+            config = FlashAttentionConfig(False, True, True)
 
+        # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
         with torch.backends.cuda.sdp_kernel(**config._asdict()):
             out = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout if self.training else 0.0)
 
diff --git a/audio_separator/separator/uvr_lib_v5/bs_roformer.py b/audio_separator/separator/uvr_lib_v5/bs_roformer.py
index ab7e031..6ce73bb 100644
--- a/audio_separator/separator/uvr_lib_v5/bs_roformer.py
+++ b/audio_separator/separator/uvr_lib_v5/bs_roformer.py
@@ -457,11 +457,10 @@ def forward(
         """
 
         original_device = raw_audio.device
-
         x_is_mps = True if original_device.type == 'mps' else False
 
-        if x_is_mps:
-            raw_audio = raw_audio.cpu()
+        # if x_is_mps:
+        #     raw_audio = raw_audio.cpu()
 
         device = raw_audio.device
 
@@ -517,13 +516,11 @@
 
         x = self.final_norm(x)
 
-        num_stems = len(self.mask_estimators)
-
         mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
         mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
 
-        if x_is_mps:
-            mask = mask.to('cpu')
+        # if x_is_mps:
+        #     mask = mask.to('cpu')
 
         # modulate frequency representation
 
@@ -540,11 +537,14 @@
 
         stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
 
-        recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False)
+        recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr,
+                                  **self.stft_kwargs,
+                                  window=stft_window.cpu() if x_is_mps else stft_window,
+                                  return_complex=False).to(device)
 
-        recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
+        recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=self.num_stems)
 
-        if num_stems == 1:
+        if self.num_stems == 1:
             recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
 
         # if a target is passed in, calculate loss for learning
 
@@ -585,15 +585,15 @@
 
         if not return_loss_breakdown:
             # Move the result back to the original device if it was moved to CPU for MPS compatibility
-            if x_is_mps:
-                total_loss = total_loss.to(original_device)
+            # if x_is_mps:
+            #     total_loss = total_loss.to(original_device)
             return total_loss
 
         # For detailed loss breakdown, ensure all components are moved back to the original device for MPS
-        if x_is_mps:
-            loss = loss.to(original_device)
-            multi_stft_resolution_loss = multi_stft_resolution_loss.to(original_device)
-            weighted_multi_resolution_loss = weighted_multi_resolution_loss.to(original_device)
+        # if x_is_mps:
+        #     loss = loss.to(original_device)
+        #     multi_stft_resolution_loss = multi_stft_resolution_loss.to(original_device)
+        #     weighted_multi_resolution_loss = weighted_multi_resolution_loss.to(original_device)
 
         return total_loss, (loss, multi_stft_resolution_loss)
 
diff --git a/audio_separator/separator/uvr_lib_v5/mel_band_roformer.py b/audio_separator/separator/uvr_lib_v5/mel_band_roformer.py
index 5cfecd9..c86b1f2 100644
--- a/audio_separator/separator/uvr_lib_v5/mel_band_roformer.py
+++ b/audio_separator/separator/uvr_lib_v5/mel_band_roformer.py
@@ -390,8 +390,8 @@ def forward(
 
         original_device = raw_audio.device
         x_is_mps = True if original_device.type == 'mps' else False
 
-        if x_is_mps:
-            raw_audio = raw_audio.cpu()
+        # if x_is_mps:
+        #     raw_audio = raw_audio.cpu()
 
         device = raw_audio.device
 
@@ -418,7 +418,8 @@
 
         batch_arange = torch.arange(batch, device=device)[..., None]
 
-        x = stft_repr[batch_arange, self.freq_indices.cpu()] if x_is_mps else stft_repr[batch_arange, self.freq_indices]
+        # x = stft_repr[batch_arange, self.freq_indices.cpu()] if x_is_mps else stft_repr[batch_arange, self.freq_indices]
+        x = stft_repr[batch_arange, self.freq_indices]
 
         x = rearrange(x, 'b f t c -> b t (f c)')
 
@@ -438,12 +439,10 @@
 
         x, = unpack(x, ps, '* f d')
 
-        num_stems = len(self.mask_estimators)
-
         masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
         masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)
-        if x_is_mps:
-            masks = masks.cpu()
+        # if x_is_mps:
+        #     masks = masks.cpu()
 
         stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
 
@@ -452,16 +451,19 @@
 
         masks = masks.type(stft_repr.dtype)
 
-        if x_is_mps:
-            scatter_indices = repeat(self.freq_indices.cpu(), 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
-        else:
-            scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
-        stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
-        masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)
+        # if x_is_mps:
+        #     scatter_indices = repeat(self.freq_indices.cpu(), 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
+        # else:
+        #     scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
+        scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=self.num_stems, t=stft_repr.shape[-1])
+        stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=self.num_stems)
+        masks_summed = torch.zeros_like(stft_repr_expanded_stems.cpu() if x_is_mps else stft_repr_expanded_stems
+                                        ).scatter_add_(2, scatter_indices.cpu() if x_is_mps else scatter_indices,
+                                                       masks.cpu() if x_is_mps else masks).to(device)
 
         denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)
 
-        if x_is_mps:
-            denom = denom.cpu()
+        # if x_is_mps:
+        #     denom = denom.cpu()
 
         masks_averaged = masks_summed / denom.clamp(min=1e-8)
 
@@ -469,12 +471,15 @@
 
         stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
 
-        recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
+        recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr,
+                                  **self.stft_kwargs,
+                                  window=stft_window.cpu() if x_is_mps else stft_window,
+                                  return_complex=False,
                                   length=istft_length)
 
-        recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)
+        recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=self.num_stems)
 
-        if num_stems == 1:
+        if self.num_stems == 1:
             recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
 
         if not exists(target):
 
@@ -512,17 +517,17 @@
 
         # Move the total loss back to the original device if necessary
-        if x_is_mps:
-            total_loss = total_loss.to(original_device)
+        # if x_is_mps:
+        #     total_loss = total_loss.to(original_device)
 
-        if not return_loss_breakdown:
-            return total_loss
+        # if not return_loss_breakdown:
+        #     return total_loss
 
         # If detailed loss breakdown is requested, ensure all components are on the original device
-        return total_loss, (loss.to(original_device) if x_is_mps else loss,
-                            multi_stft_resolution_loss.to(original_device) if x_is_mps else multi_stft_resolution_loss)
+        # return total_loss, (loss.to(original_device) if x_is_mps else loss,
+        #                     multi_stft_resolution_loss.to(original_device) if x_is_mps else multi_stft_resolution_loss)
 
-        # if not return_loss_breakdown:
-        #     return total_loss
+        if not return_loss_breakdown:
+            return total_loss
 
-        # return total_loss, (loss, multi_stft_resolution_loss)
\ No newline at end of file
+        return total_loss, (loss, multi_stft_resolution_loss)
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 25d1905..c1f835f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "audio-separator"
-version = "0.17.4"
+version = "0.17.5"
 description = "Easy to use audio stem separation, using various models from UVR trained primarily by @Anjok07"
 authors = ["Andrew Beveridge "]
 license = "MIT"
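
Note (editor's commentary, not part of the patch): the fix applies two device-portability patterns. On CUDA, the attend.py guard switches to FlashAttentionConfig(False, True, True) whenever q is not float16, since (per the patch's own comment) the sdpa_flash kernel only supports float16 on sm80+ GPUs, so scaled_dot_product_attention falls back to the math/mem-efficient kernels. On MPS, ops with incomplete backend support (torch.istft, scatter_add_) are computed on CPU and the result is moved back with .to(device). A minimal self-contained sketch of the second pattern follows; istft_any_device is a hypothetical helper name, not a function from the patch, and the example STFT parameters are assumptions.

import torch

def istft_any_device(stft_repr: torch.Tensor, window: torch.Tensor, **istft_kwargs) -> torch.Tensor:
    """Inverse STFT that tolerates inputs living on an MPS device.

    torch.istft is not fully supported on the MPS backend, so when the
    input is on MPS we compute on CPU and move the result back to the
    original device, mirroring the `x.cpu() if x_is_mps else x` pattern
    used throughout the patch.
    """
    device = stft_repr.device
    on_mps = device.type == "mps"

    recon = torch.istft(
        stft_repr.cpu() if on_mps else stft_repr,
        window=window.cpu() if on_mps else window,
        return_complex=False,
        **istft_kwargs,  # e.g. n_fft, hop_length, length
    )
    return recon.to(device)  # no-op copy when nothing was moved

# Usage sketch (assumed parameters): round-trip a complex spectrogram
# produced by torch.stft(..., return_complex=True) back to audio.
#   win = torch.hann_window(2048)
#   spec = torch.stft(audio, n_fft=2048, hop_length=512, window=win, return_complex=True)
#   audio_out = istft_any_device(spec, win, n_fft=2048, hop_length=512, length=audio.shape[-1])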