Skip to content

Commit

Permalink
fix roformer models run on cuda (#84)
Browse files Browse the repository at this point in the history
* fix roformer models run on cuda

* fix roformer models run on cpu/cuda/mps

---------

Co-authored-by: 周志洋 <[email protected]>
  • Loading branch information
young01ai and 周志洋 authored Jul 3, 2024
1 parent 588a82f commit f2b78ef
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 45 deletions.
5 changes: 4 additions & 1 deletion audio_separator/separator/uvr_lib_v5/attend.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,11 @@ def flash_attn(self, q, k, v):

config = self.cuda_config if is_cuda else self.cpu_config

# pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
# sdpa_flash kernel only supports float16 on sm80+ architecture gpu
if is_cuda and q.dtype != torch.float16:
config = FlashAttentionConfig(False, True, True)

# pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout if self.training else 0.0)

Expand Down
32 changes: 16 additions & 16 deletions audio_separator/separator/uvr_lib_v5/bs_roformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,11 +457,10 @@ def forward(
"""

original_device = raw_audio.device

x_is_mps = True if original_device.type == 'mps' else False

if x_is_mps:
raw_audio = raw_audio.cpu()
# if x_is_mps:
# raw_audio = raw_audio.cpu()

device = raw_audio.device

Expand Down Expand Up @@ -517,13 +516,11 @@ def forward(

x = self.final_norm(x)

num_stems = len(self.mask_estimators)

mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)

if x_is_mps:
mask = mask.to('cpu')
# if x_is_mps:
# mask = mask.to('cpu')

# modulate frequency representation

Expand All @@ -540,11 +537,14 @@ def forward(

stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)

recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False)
recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr,
**self.stft_kwargs,
window=stft_window.cpu() if x_is_mps else stft_window,
return_complex=False).to(device)

recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=self.num_stems)

if num_stems == 1:
if self.num_stems == 1:
recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')

# if a target is passed in, calculate loss for learning
Expand Down Expand Up @@ -585,15 +585,15 @@ def forward(

if not return_loss_breakdown:
# Move the result back to the original device if it was moved to CPU for MPS compatibility
if x_is_mps:
total_loss = total_loss.to(original_device)
# if x_is_mps:
# total_loss = total_loss.to(original_device)
return total_loss

# For detailed loss breakdown, ensure all components are moved back to the original device for MPS
if x_is_mps:
loss = loss.to(original_device)
multi_stft_resolution_loss = multi_stft_resolution_loss.to(original_device)
weighted_multi_resolution_loss = weighted_multi_resolution_loss.to(original_device)
# if x_is_mps:
# loss = loss.to(original_device)
# multi_stft_resolution_loss = multi_stft_resolution_loss.to(original_device)
# weighted_multi_resolution_loss = weighted_multi_resolution_loss.to(original_device)

return total_loss, (loss, multi_stft_resolution_loss)

Expand Down
59 changes: 32 additions & 27 deletions audio_separator/separator/uvr_lib_v5/mel_band_roformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,8 @@ def forward(
original_device = raw_audio.device
x_is_mps = True if original_device.type == 'mps' else False

if x_is_mps:
raw_audio = raw_audio.cpu()
# if x_is_mps:
# raw_audio = raw_audio.cpu()

device = raw_audio.device

Expand All @@ -418,7 +418,8 @@ def forward(

batch_arange = torch.arange(batch, device=device)[..., None]

x = stft_repr[batch_arange, self.freq_indices.cpu()] if x_is_mps else stft_repr[batch_arange, self.freq_indices]
# x = stft_repr[batch_arange, self.freq_indices.cpu()] if x_is_mps else stft_repr[batch_arange, self.freq_indices]
x = stft_repr[batch_arange, self.freq_indices]

x = rearrange(x, 'b f t c -> b t (f c)')

Expand All @@ -438,12 +439,10 @@ def forward(

x, = unpack(x, ps, '* f d')

num_stems = len(self.mask_estimators)

masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)
if x_is_mps:
masks = masks.cpu()
# if x_is_mps:
# masks = masks.cpu()

stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')

Expand All @@ -452,29 +451,35 @@ def forward(

masks = masks.type(stft_repr.dtype)

if x_is_mps:
scatter_indices = repeat(self.freq_indices.cpu(), 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
else:
scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)
# if x_is_mps:
# scatter_indices = repeat(self.freq_indices.cpu(), 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
# else:
# scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=self.num_stems, t=stft_repr.shape[-1])
stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=self.num_stems)
masks_summed = torch.zeros_like(stft_repr_expanded_stems.cpu() if x_is_mps else stft_repr_expanded_stems
).scatter_add_(2, scatter_indices.cpu() if x_is_mps else scatter_indices,
masks.cpu() if x_is_mps else masks).to(device)

denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)
if x_is_mps:
denom = denom.cpu()
# if x_is_mps:
# denom = denom.cpu()

masks_averaged = masks_summed / denom.clamp(min=1e-8)

stft_repr = stft_repr * masks_averaged

stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)

recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr,
**self.stft_kwargs,
window=stft_window.cpu() if x_is_mps else stft_window,
return_complex=False,
length=istft_length)

recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)
recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=self.num_stems)

if num_stems == 1:
if self.num_stems == 1:
recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')

if not exists(target):
Expand Down Expand Up @@ -512,17 +517,17 @@ def forward(


# Move the total loss back to the original device if necessary
if x_is_mps:
total_loss = total_loss.to(original_device)
# if x_is_mps:
# total_loss = total_loss.to(original_device)

if not return_loss_breakdown:
return total_loss
# if not return_loss_breakdown:
# return total_loss

# If detailed loss breakdown is requested, ensure all components are on the original device
return total_loss, (loss.to(original_device) if x_is_mps else loss,
multi_stft_resolution_loss.to(original_device) if x_is_mps else multi_stft_resolution_loss)
# return total_loss, (loss.to(original_device) if x_is_mps else loss,
# multi_stft_resolution_loss.to(original_device) if x_is_mps else multi_stft_resolution_loss)

# if not return_loss_breakdown:
# return total_loss
if not return_loss_breakdown:
return total_loss

# return total_loss, (loss, multi_stft_resolution_loss)
return total_loss, (loss, multi_stft_resolution_loss)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "audio-separator"
version = "0.17.4"
version = "0.17.5"
description = "Easy to use audio stem separation, using various models from UVR trained primarily by @Anjok07"
authors = ["Andrew Beveridge <[email protected]>"]
license = "MIT"
Expand Down

0 comments on commit f2b78ef

Please sign in to comment.