diff --git a/audiocraft/utils/audio_effects.py b/audiocraft/utils/audio_effects.py index 70fe4dbe..f8af66c9 100644 --- a/audiocraft/utils/audio_effects.py +++ b/audiocraft/utils/audio_effects.py @@ -9,6 +9,7 @@ import random import typing as tp from functools import partial +import torchaudio import julius import omegaconf @@ -250,9 +251,9 @@ def echo( # Define a few reflections with decreasing amplitude impulse_response[0] = 1.0 # Direct sound - impulse_response[ - int(sample_rate * duration) - 1 - ] = volume # First reflection after 100ms + impulse_response[int(sample_rate * duration) - 1] = ( + volume # First reflection after 100ms + ) # Add batch and channel dimensions to the impulse response impulse_response = impulse_response.unsqueeze(0).unsqueeze(0) @@ -455,3 +456,213 @@ def aac_compression( tensor, get_aac, sr=sample_rate, bitrate=bitrate, lowpass_freq=lowpass_freq ) return audio_effect_return(tensor=out, mask=mask) + + @staticmethod + def pitcch_shift( + tensor: torch.Tensor, + n_steps: float = 2.0, + sample_rate: int = 16000, + mask: tp.Optional[torch.Tensor] = None, + ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + """ + Change the pitch of the audio signal by a given number of steps. + + Parameters: + - tensor (torch.Tensor): Input audio tensor, assuming shape (batch_size, channels, time). + - n_steps (float): Number of pitch steps to shift (positive for higher pitch, negative for lower pitch). + - sample_rate (int): Sample rate of the audio signal. + - mask (torch.Tensor): Optional mask tensor. + + Returns: + - torch.Tensor: Pitch-shifted audio tensor. + """ + shifted_tensor = torchaudio.transforms.PitchShift(sample_rate, n_steps=n_steps)( + tensor + ) + return audio_effect_return(tensor=shifted_tensor, mask=mask) + + @staticmethod + def reverse( + tensor: torch.Tensor, + mask: tp.Optional[torch.Tensor] = None, + ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + """ + Reverse the audio signal. + + Parameters: + - tensor (torch.Tensor): Input audio tensor, assuming shape (batch_size, channels, time). + - mask (torch.Tensor): Optional mask tensor. + + Returns: + - torch.Tensor: Reversed audio tensor. + """ + reversed_tensor = torch.flip(tensor, dims=[-1]) + return audio_effect_return(tensor=reversed_tensor, mask=mask) + + @staticmethod + def clipping( + tensor: torch.Tensor, + clip_value: float = 0.5, + mask: tp.Optional[torch.Tensor] = None, + ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + """ + Clip the audio signal to a specific threshold value, distorting the signal. + + Parameters: + - tensor (torch.Tensor): Input audio tensor, assuming shape (batch_size, channels, time). + - clip_value (float): Threshold for clipping the audio signal. + - mask (torch.Tensor): Optional mask tensor. + + Returns: + - torch.Tensor: Clipped audio tensor. + """ + clipped_tensor = torch.clamp(tensor, min=-clip_value, max=clip_value) + return audio_effect_return(tensor=clipped_tensor, mask=mask) + + @staticmethod + def time_stretch( + tensor: torch.Tensor, + stretch_factor: float = 1.2, + mask: tp.Optional[torch.Tensor] = None, + ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + """ + Stretch the audio signal in time without changing its pitch. + + Parameters: + - tensor (torch.Tensor): Input audio tensor, assuming shape (batch_size, channels, time). + - stretch_factor (float): Factor by which to stretch the audio. + - mask (torch.Tensor): Optional mask tensor. + + Returns: + - torch.Tensor: Time-stretched audio tensor. + """ + stretched_tensor = julius.time_stretch(tensor, stretch_factor) + return audio_effect_return(tensor=stretched_tensor, mask=mask) + + @staticmethod + def tremolo( + tensor: torch.Tensor, + frequency: float = 5.0, + depth: float = 0.5, + sample_rate: int = 16000, + mask: tp.Optional[torch.Tensor] = None, + ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + """ + Apply a tremolo effect to the audio signal by modulating its amplitude. + + Parameters: + - tensor (torch.Tensor): Input audio tensor, assuming shape (batch_size, channels, time). + - frequency (float): Frequency of the tremolo effect in Hz. + - depth (float): Depth of modulation (between 0 and 1). + - sample_rate (int): Sample rate of the audio signal. + - mask (torch.Tensor): Optional mask tensor. + + Returns: + - torch.Tensor: Audio tensor with tremolo effect applied. + """ + time = torch.arange(tensor.shape[-1], device=tensor.device) / sample_rate + modulation = (1.0 + depth * torch.sin(2 * torch.pi * frequency * time)) / 2.0 + tremolo_tensor = tensor * modulation.unsqueeze(0).unsqueeze(0) + return audio_effect_return(tensor=tremolo_tensor, mask=mask) + + @staticmethod + def flanger( + tensor: torch.Tensor, + delay: float = 0.002, + depth: float = 0.002, + rate: float = 0.25, + sample_rate: int = 16000, + mask: tp.Optional[torch.Tensor] = None, + ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + """ + Apply a flanger effect to the audio signal by mixing a delayed version of the signal with itself. + + Parameters: + - tensor (torch.Tensor): Input audio tensor, assuming shape (batch_size, channels, time). + - delay (float): Base delay time in seconds. + - depth (float): Depth of the delay modulation. + - rate (float): Rate of modulation in Hz. + - sample_rate (int): Sample rate of the audio signal. + - mask (torch.Tensor): Optional mask tensor. + + Returns: + - torch.Tensor: Audio tensor with flanger effect applied. + """ + time = torch.arange(tensor.shape[-1], device=tensor.device) / sample_rate + lfo = torch.sin(2 * torch.pi * rate * time) * depth + delay + lfo_samples = (lfo * sample_rate).long().clamp(0, tensor.shape[-1] - 1) + delayed_signal = tensor[..., lfo_samples] + flanger_tensor = tensor + delayed_signal + return audio_effect_return(tensor=flanger_tensor, mask=mask) + + @staticmethod + def bit_crusher( + tensor: torch.Tensor, + bit_depth: int = 8, + mask: tp.Optional[torch.Tensor] = None, + ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + """ + Apply a bit crusher effect by reducing the bit depth of the audio signal. + + Parameters: + - tensor (torch.Tensor): Input audio tensor, assuming shape (batch_size, channels, time). + - bit_depth (int): Bit depth to reduce to (e.g., 8 bits). + - mask (torch.Tensor): Optional mask tensor. + + Returns: + - torch.Tensor: Audio tensor with reduced bit depth. + """ + scale = 2**bit_depth + crushed_tensor = torch.round(tensor * scale) / scale + return audio_effect_return(tensor=crushed_tensor, mask=mask) + + @staticmethod + def ring_modulation( + tensor: torch.Tensor, + modulation_frequency: float = 30.0, + sample_rate: int = 16000, + mask: tp.Optional[torch.Tensor] = None, + ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + """ + Apply a ring modulation effect to the audio signal, creating a metallic sound. + + Parameters: + - tensor (torch.Tensor): Input audio tensor, assuming shape (batch_size, channels, time). + - modulation_frequency (float): Frequency of the modulation in Hz. + - sample_rate (int): Sample rate of the audio signal. + - mask (torch.Tensor): Optional mask tensor. + + Returns: + - torch.Tensor: Ring-modulated audio tensor. + """ + time = torch.arange(tensor.shape[-1], device=tensor.device) / sample_rate + modulation = torch.sin(2 * torch.pi * modulation_frequency * time) + ring_modulated_tensor = tensor * modulation.unsqueeze(0).unsqueeze(0) + return audio_effect_return(tensor=ring_modulated_tensor, mask=mask) + + @staticmethod + def granulate( + tensor: torch.Tensor, + grain_size: int = 512, + overlap: float = 0.5, + mask: tp.Optional[torch.Tensor] = None, + ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + """ + Apply a granulation effect by breaking the audio into small overlapping grains. + + Parameters: + - tensor (torch.Tensor): Input audio tensor, assuming shape (batch_size, channels, time). + - grain_size (int): Size of each grain in samples. + - overlap (float): Overlap ratio between grains (0 to 1). + - mask (torch.Tensor): Optional mask tensor. + + Returns: + - torch.Tensor: Granulated audio tensor. + """ + step_size = int(grain_size * (1 - overlap)) + grains = [ + tensor[..., i:i+grain_size] + for i in range(0, tensor.shape[-1] - grain_size, step_size) + ] + granulated_tensor = torch.cat(grains, dim=-1) + return audio_effect_return(tensor=granulated_tensor, mask=mask) diff --git a/config/augmentations/default.yaml b/config/augmentations/default.yaml index 120887b0..ae54474d 100644 --- a/config/augmentations/default.yaml +++ b/config/augmentations/default.yaml @@ -41,6 +41,31 @@ audio_effects: encodec: ckpt: "//pretrained/facebook/encodec_24khz" n_qs: [4, 8, 16] + pitch_shift: + sample_rate: ${sample_rate} + n_steps: 2.0 + clipping: + clip_value: 0.5 + time_stretch: + sample_rate: ${sample_rate} + stretch_factor: 1.2 + tremolo: + frequency: 5.0 + depth: 0.5 + sample_rate: ${sample_rate} + flanger: + delay: 0.002 + depth: 0.002 + rate: 0.25 + sample_rate: ${sample_rate} + bit_crusher: + bit_depth: 8 + ring_modulation: + modulation_frequency: 30.0 + sample_rate: ${sample_rate} + granulate: + grain_size: 512 + overlap: 0.5 select_aug_mode: "use_eval" # other are 'all' and 'use_eval_acc', used to sample augmentations, `fixed` uses the prob from aug_weights, `all` uses all agmentations every step @@ -61,5 +86,14 @@ aug_weights: aac_compression: 0.1 # eval only never use in training even if eval_acc low encodec: 0.1 identity: 1 # no augmentation + pitch_shift: 0.1 + reverse: 0.1 + clipping: 0.1 + time_stretch: 0.1 + tremolo: 0.1 + flanger: 0.1 + bit_crusher: 0.1 + ring_modulation: 0.1 + granulate: 0.1 n_max_aug: null \ No newline at end of file