Skip to content

Commit

Permalink
Merge branch 'main' into alex
Browse files Browse the repository at this point in the history
  • Loading branch information
loubbrad authored Mar 12, 2024
2 parents c8f0fd4 + d6fea7f commit 621156e
Show file tree
Hide file tree
Showing 8 changed files with 389 additions and 146 deletions.
51 changes: 37 additions & 14 deletions amt/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,12 @@ def __init__(
min_dist_gain: int = 0,
noise_ratio: float = 0.95,
reverb_ratio: float = 0.95,
applause_ratio: float = 0.01, # CHANGE
applause_ratio: float = 0.01,
bandpass_ratio: float = 0.1,
distort_ratio: float = 0.15,
reduce_ratio: float = 0.01,
spec_aug_ratio: float = 0.25,
codecs_ratio: float = 0.01,
spec_aug_ratio: float = 0.5,
):
super().__init__()
self.tokenizer = AmtTokenizer()
Expand All @@ -215,6 +216,7 @@ def __init__(
self.noise_ratio = noise_ratio
self.reverb_ratio = reverb_ratio
self.applause_ratio = applause_ratio
self.bandpass_ratio = bandpass_ratio
self.distort_ratio = distort_ratio
self.reduce_ratio = reduce_ratio
self.spec_aug_ratio = spec_aug_ratio
Expand Down Expand Up @@ -259,14 +261,16 @@ def __init__(
)
self.spec_aug = torch.nn.Sequential(
torchaudio.transforms.FrequencyMasking(
freq_mask_param=10, iid_masks=True
freq_mask_param=15, iid_masks=True
),
torchaudio.transforms.TimeMasking(
time_mask_param=1000, iid_masks=True
),
)

def _get_paths(self, dir_path):
os.makedirs(dir_path, exist_ok=True)

return [
os.path.join(dir_path, f)
for f in os.listdir(dir_path)
Expand Down Expand Up @@ -350,6 +354,14 @@ def apply_applause(self, wav: torch.tensor):

return AF.add_noise(waveform=wav, noise=applause, snr=snr_dbs)

def apply_bandpass(self, wav: torch.tensor):
    """Run ``wav`` through a band-pass biquad with randomised settings.

    The central frequency is drawn as an integer from 1000 to 3500
    (inclusive) and the Q factor uniformly between 0.707 and 1.41. Both
    are passed, together with ``self.sample_rate``, to
    ``torchaudio.functional.bandpass_biquad``.

    Args:
        wav: Waveform to filter.

    Returns:
        The filtered waveform as returned by ``bandpass_biquad``.
    """
    # Draw order (frequency first, then Q) is kept so that a seeded RNG
    # yields the same parameters as before.
    centre = random.randint(1000, 3500)
    q_factor = random.uniform(0.707, 1.41)
    filtered = torchaudio.functional.bandpass_biquad(
        wav, self.sample_rate, centre, q_factor
    )

    return filtered

def apply_reduction(self, wav: torch.tensor):
"""
Limit the high-band pass filter, the low-band pass filter and the sample rate
Expand All @@ -376,6 +388,17 @@ def apply_distortion(self, wav: torch.tensor):

return AF.overdrive(wav, gain=gain, colour=colour)

def distortion_aug_cpu(self, wav: torch.Tensor):
    """Randomly apply the reduction and distortion effects to ``wav``.

    This function should run on the cpu (i.e. in the dataloader collate
    function) in order to not be a bottleneck.

    Each effect is gated by its own random draw: reduction is applied with
    probability ``self.reduce_ratio`` and then distortion with probability
    ``self.distort_ratio``, so one, both or neither may be applied.

    Args:
        wav: Waveform to augment.

    Returns:
        The (possibly) augmented waveform.
    """
    stages = (
        (self.reduce_ratio, self.apply_reduction),
        (self.distort_ratio, self.apply_distortion),
    )
    # Draw immediately before each stage so the RNG is consumed in the
    # order: gate, effect, gate, effect.
    for probability, effect in stages:
        if random.random() < probability:
            wav = effect(wav)

    return wav

def apply_codec(self, wav: torch.tensor):
"""
Apply different audio codecs to the audio.
Expand All @@ -389,7 +412,6 @@ def apply_codec(self, wav: torch.tensor):
encoder = torchaudio.io.AudioEffector(format=format, encoder=encoder)
if random.random() < self.codecs_ratio:
wav = encoder.apply(wav, self.sample_rate)
return wav

def shift_spec(self, specs: torch.Tensor, shift: int):
if shift == 0:
Expand Down Expand Up @@ -417,24 +439,25 @@ def shift_spec(self, specs: torch.Tensor, shift: int):
return shifted_specs

def aug_wav(self, wav: torch.Tensor):
    """Apply the waveform-level augmentation chain to ``wav``.

    The stages run in a fixed order -- noise, applause, reverb, band-pass
    EQ -- and each one is applied independently with the probability held
    in the matching ``*_ratio`` attribute.

    This function doesn't apply distortion. If distortion is desired it
    should be run beforehand on the cpu with distortion_aug_cpu.

    Args:
        wav: Waveform to augment.

    Returns:
        The (possibly) augmented waveform.
    """
    # Noise
    if random.random() < self.noise_ratio:
        wav = self.apply_noise(wav)

    if random.random() < self.applause_ratio:
        wav = self.apply_applause(wav)

    # Reverb: reassign rather than return, so the EQ stage below can still
    # run on the reverberated signal.
    if random.random() < self.reverb_ratio:
        wav = self.apply_reverb(wav)

    # EQ
    if random.random() < self.bandpass_ratio:
        wav = self.apply_bandpass(wav)

    return wav

def norm_mel(self, mel_spec: torch.Tensor):
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
Expand All @@ -457,7 +480,7 @@ def log_mel(self, wav: torch.Tensor, shift: int | None = None):
return log_spec

def forward(self, wav: torch.Tensor, shift: int = 0):
# Noise, distortion, and reverb
# Noise, and reverb
wav = self.aug_wav(wav)

# Spec & pitch shift
Expand Down
55 changes: 54 additions & 1 deletion amt/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,17 @@ def __init__(self, load_path: str):
self.file_mmap = mmap.mmap(
self.file_buff.fileno(), 0, access=mmap.ACCESS_READ
)
self.index = self._build_index()

index_path = AmtDataset._get_index_path(load_path=load_path)
if os.path.isfile(index_path) is True:
self.index = self._load_index(load_path=index_path)
else:
print("Calculating index...")
self.index = self._build_index()
print(
f"Index of length {len(self.index)} calculated, saving to {index_path}"
)
self._save_index(index=self.index, save_path=index_path)

def close(self):
if self.file_buff:
Expand Down Expand Up @@ -242,6 +252,21 @@ def _build_index(self):

return index

def _save_index(self, index: list[int], save_path: str):
with open(save_path, "w") as file:
for idx in index:
file.write(f"{idx}\n")

def _load_index(self, load_path: str):
with open(load_path, "r") as file:
return [int(line.strip()) for line in file]

@staticmethod
def _get_index_path(load_path: str):
return (
f"{load_path.rsplit('.', 1)[0]}_index.{load_path.rsplit('.', 1)[1]}"
)

@classmethod
def build(
cls,
Expand All @@ -250,6 +275,12 @@ def build(
num_processes: int = 1,
):
assert os.path.isfile(save_path) is False, f"{save_path} already exists"

index_path = AmtDataset._get_index_path(load_path=save_path)
if os.path.isfile(index_path):
print(f"Removing existing index file at {index_path}")
os.remove(AmtDataset._get_index_path(load_path=save_path))

num_paths = len(matched_load_paths)
with Pool(processes=num_processes) as pool:
sharded_save_paths = []
Expand Down Expand Up @@ -277,3 +308,25 @@ def build(
os.system(shell_cmd)
for _path in sharded_save_paths:
os.remove(_path)

# Create index by loading object
AmtDataset(load_path=save_path)

def _build_index(self):
self.file_mmap.seek(0)
index = []
pos = 0
while True:
pos_buff = pos

pos = self.file_mmap.find(b"\n", pos)
if pos == -1:
break
pos = self.file_mmap.find(b"\n", pos + 1)
if pos == -1:
break

index.append(pos_buff)
pos += 1

return index
Loading

0 comments on commit 621156e

Please sign in to comment.