Add applause augmentation (#15)
* add data aug and clean

* fix reverb
loubbrad authored Mar 8, 2024
1 parent e00b374 commit b82a9da
Showing 6 changed files with 163 additions and 50 deletions.
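At its core, the new applause augmentation (added to amt/audio.py below) mixes pre-loaded applause recordings into the training audio at a random signal-to-noise ratio. A minimal standalone sketch of the idea, assuming torchaudio >= 2.0 (for torchaudio.functional.add_noise) and hypothetical file names:

import torch
import torchaudio
import torchaudio.functional as AF

wav, sr = torchaudio.load("piano.wav")  # hypothetical clean recording
applause, _ = torchaudio.load("applause.wav")  # hypothetical applause clip
applause = applause[:, : wav.shape[-1]]  # assumes the clip is at least as long
snr_db = torch.tensor([10.0])  # illustrative SNR in dB
augmented = AF.add_noise(waveform=wav, noise=applause, snr=snr_db)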
4 changes: 0 additions & 4 deletions .gitignore
@@ -15,10 +15,6 @@
*.xml
*.html
*.htm
-*.mid
-*.midi
-*.wav
-*.mp3

.idea/

84 changes: 62 additions & 22 deletions amt/audio.py
@@ -182,6 +182,7 @@ def log_mel_spectrogram(
return log_spec


# Refactor: default params are stored in config.json
class AudioTransform(torch.nn.Module):
def __init__(
self,
@@ -190,10 +191,12 @@ def __init__(
max_snr: int = 50,
max_dist_gain: int = 25,
min_dist_gain: int = 0,
-        # ratios for the reduction of the audio quality
-        distort_ratio: float = 0.2,
-        reduce_ratio: float = 0.2,
-        spec_aug_ratio: float = 0.2,
+        noise_ratio: float = 0.95,
+        reverb_ratio: float = 0.95,
+        applause_ratio: float = 0.01,  # CHANGE
+        distort_ratio: float = 0.15,
+        reduce_ratio: float = 0.01,
+        spec_aug_ratio: float = 0.25,
):
super().__init__()
self.tokenizer = AmtTokenizer()
@@ -208,9 +211,13 @@ def __init__(
self.chunk_len = self.config["chunk_len"]
self.num_samples = self.sample_rate * self.chunk_len

-        self.dist_ratio = distort_ratio
+        self.noise_ratio = noise_ratio
+        self.reverb_ratio = reverb_ratio
+        self.applause_ratio = applause_ratio
+        self.distort_ratio = distort_ratio
         self.reduce_ratio = reduce_ratio
         self.spec_aug_ratio = spec_aug_ratio
+        self.reduction_resample_rate = 6000  # Hardcoded?

# Audio aug
impulse_paths = self._get_paths(
@@ -219,6 +226,9 @@ def __init__(
noise_paths = self._get_paths(
os.path.join(os.path.dirname(__file__), "assets", "noise")
)
applause_paths = self._get_paths(
os.path.join(os.path.dirname(__file__), "assets", "applause")
)

# Register impulses and noises as buffers
self.num_impulse = 0
@@ -231,6 +241,11 @@ def __init__(
self.register_buffer(f"noise_{i}", noise)
self.num_noise += 1

self.num_applause = 0
for i, applause in enumerate(self._get_noise(applause_paths)):
self.register_buffer(f"applause_{i}", applause)
self.num_applause += 1

self.spec_transform = torchaudio.transforms.Spectrogram(
n_fft=self.config["n_fft"],
hop_length=self.config["hop_len"],
@@ -321,15 +336,37 @@ def apply_noise(self, wav: torch.tensor):

return AF.add_noise(waveform=wav, noise=noise, snr=snr_dbs)

def apply_applause(self, wav: torch.tensor):
batch_size, _ = wav.shape

snr_dbs = torch.tensor(
[random.randint(1, self.min_snr) for _ in range(batch_size)]
).to(wav.device)
applause_type = random.randint(5, self.num_applause - 1)

applause = getattr(self, f"applause_{applause_type}")

return AF.add_noise(waveform=wav, noise=applause, snr=snr_dbs)

def apply_reduction(self, wav: torch.tensor):
"""
Limit the high-band pass filter, the low-band pass filter and the sample rate
Designed to mimic the effect of recording on a low-quality microphone or phone.
"""
wav = AF.highpass_biquad(wav, self.sample_rate, cutoff_freq=1200)
wav = AF.lowpass_biquad(wav, self.sample_rate, cutoff_freq=1400)
resample_rate = 6000
return AF.resample(wav, orig_freq=self.sample_rate, new_freq=resample_rate, lowpass_filter_width=3)
wav = AF.highpass_biquad(wav, self.sample_rate, cutoff_freq=300)
wav = AF.lowpass_biquad(wav, self.sample_rate, cutoff_freq=3400)
wav_downsampled = AF.resample(
wav,
orig_freq=self.sample_rate,
new_freq=self.reduction_resample_rate,
lowpass_filter_width=3,
)

return AF.resample(
wav_downsampled,
self.reduction_resample_rate,
self.sample_rate,
)

def apply_distortion(self, wav: torch.tensor):
gain = random.randint(self.min_dist_gain, self.max_dist_gain)
@@ -363,20 +400,23 @@ def shift_spec(self, specs: torch.Tensor, shift: int):
return shifted_specs

    def aug_wav(self, wav: torch.Tensor):
-        """
-        pipeline for audio augmentation:
-        1. apply noise
-        2. apply distortion (x% of the time)
-        3. apply reduction (x% of the time)
-        4. apply reverb
-        """
+        # Noise
+        if random.random() < self.noise_ratio:
+            wav = self.apply_noise(wav)
+        if random.random() < self.applause_ratio:
+            wav = self.apply_applause(wav)
+
-        wav = self.apply_noise(wav)
-        if random.random() < self.dist_ratio:
-            wav = self.apply_distortion(wav)
+        # Distortion
        if random.random() < self.reduce_ratio:
            wav = self.apply_reduction(wav)
-        return self.apply_reverb(wav)
+        elif random.random() < self.distort_ratio:
+            wav = self.apply_distortion(wav)
+
+        # Reverb
+        if random.random() < self.reverb_ratio:
+            return self.apply_reverb(wav)
+        else:
+            return wav

def norm_mel(self, mel_spec: torch.Tensor):
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
@@ -399,13 +439,13 @@ def log_mel(self, wav: torch.Tensor, shift: int | None = None):
return log_spec

def forward(self, wav: torch.Tensor, shift: int = 0):
-        # noise, distortion, reduction and reverb
+        # Noise, distortion, and reverb
wav = self.aug_wav(wav)

# Spec & pitch shift
log_mel = self.log_mel(wav, shift)

-        # Spec aug in 20% of the cases
+        # Spec aug
if random.random() < self.spec_aug_ratio:
log_mel = self.spec_aug(log_mel)

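With these changes, a forward pass through AudioTransform draws noise (95% of batches by default), applause (1%), then either quality reduction (1%) or distortion (15%), and finally reverb (95%), before the log-mel projection and optional SpecAugment. A hedged sketch of exercising the pipeline on a dummy batch (batch size is illustrative, and the bundled impulse/noise/applause assets must be present):

import torch
from amt.audio import AudioTransform

audio_transform = AudioTransform(applause_ratio=0.1)  # hypothetical ratio, raised for testing
wav = torch.randn(4, audio_transform.num_samples)  # (batch, samples) dummy audio
log_mel = audio_transform(wav, shift=0)  # aug_wav -> log_mel -> (maybe) spec_aug
print(log_mel.shape)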
48 changes: 36 additions & 12 deletions amt/evaluate.py
@@ -6,6 +6,7 @@
import json
import os


def midi_to_intervals_and_pitches(midi_file_path):
"""
This function reads a MIDI file and extracts note intervals and pitches
@@ -55,18 +56,26 @@ def evaluate_mir_eval(est_dir, ref_dir, output_stats_file=None, est_shift=0):
if ref_fpath in ref_midi_files:
est_ref_pairs.append((est_fpath, ref_fpath))
if ref_fpath.replace(".mid", ".midi") in ref_midi_files:
-            est_ref_pairs.append((est_fpath, ref_fpath.replace(".mid", ".midi")))
+            est_ref_pairs.append(
+                (est_fpath, ref_fpath.replace(".mid", ".midi"))
+            )
else:
print(f"Reference file not found for {est_fpath} (ref file: {ref_fpath})")
print(
f"Reference file not found for {est_fpath} (ref file: {ref_fpath})"
)

-    output_fhandle = open(output_stats_file, "w") if output_stats_file is not None else None
+    output_fhandle = (
+        open(output_stats_file, "w") if output_stats_file is not None else None
+    )

for est_file, ref_file in tqdm(est_ref_pairs):
ref_intervals, ref_pitches = midi_to_intervals_and_pitches(ref_file)
est_intervals, est_pitches = midi_to_intervals_and_pitches(est_file)
ref_pitches_hz = midi_to_hz(ref_pitches)
est_pitches_hz = midi_to_hz(est_pitches, est_shift)
-        scores = mir_eval.transcription.evaluate(ref_intervals, ref_pitches_hz, est_intervals, est_pitches_hz)
+        scores = mir_eval.transcription.evaluate(
+            ref_intervals, ref_pitches_hz, est_intervals, est_pitches_hz
+        )
if output_fhandle is not None:
output_fhandle.write(json.dumps(scores))
output_fhandle.write("\n")
@@ -76,30 +85,43 @@ def evaluate_mir_eval(est_dir, ref_dir, output_stats_file=None, est_shift=0):

if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(usage="evaluate <command> [<args>]")
parser.add_argument(
"--est-dir",
type=str,
help="Path to the directory containing either the transcribed MIDI files or WAV files to be transcribed."
help="Path to the directory containing either the transcribed MIDI files or WAV files to be transcribed.",
)
parser.add_argument(
"--ref-dir",
type=str,
help="Path to the directory containing the reference files (we'll use gold MIDI for mir_eval, WAV for dtw)."
help="Path to the directory containing the reference files (we'll use gold MIDI for mir_eval, WAV for dtw).",
)
parser.add_argument(
-        '--output-stats-file',
+        "--output-stats-file",
        default=None,
-        type=str, help="Path to the file to save the evaluation stats"
+        type=str,
+        help="Path to the file to save the evaluation stats",
)

# add mir_eval and dtw subparsers
subparsers = parser.add_subparsers(help="sub-command help")
mir_eval_parse = subparsers.add_parser("run_mir_eval", help="Run standard mir_eval evaluation on MAESTRO test set.")
mir_eval_parse.add_argument('--shift', type=int, default=0, help="Shift to apply to the estimated pitches.")
mir_eval_parse = subparsers.add_parser(
"run_mir_eval",
help="Run standard mir_eval evaluation on MAESTRO test set.",
)
mir_eval_parse.add_argument(
"--shift",
type=int,
default=0,
help="Shift to apply to the estimated pitches.",
)

# to come
dtw_eval_parse = subparsers.add_parser("run_dtw", help="Run dynamic time warping evaluation on a specified dataset.")
dtw_eval_parse = subparsers.add_parser(
"run_dtw",
help="Run dynamic time warping evaluation on a specified dataset.",
)

args = parser.parse_args()
if not hasattr(args, "command"):
@@ -112,6 +134,8 @@ def evaluate_mir_eval(est_dir, ref_dir, output_stats_file=None, est_shift=0):
# -> We expect that baseline methods will fall flat on these, while aria-amt will be OK.

if args.command == "run_mir_eval":
-        evaluate_mir_eval(args.est_dir, args.ref_dir, args.output_stats_file, args.shift)
+        evaluate_mir_eval(
+            args.est_dir, args.ref_dir, args.output_stats_file, args.shift
+        )
elif args.command == "run_dtw":
pass
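For reference, a hedged sketch of calling the evaluation entry point directly from Python (directory and file names are hypothetical):

from amt.evaluate import evaluate_mir_eval

evaluate_mir_eval(
    est_dir="transcriptions/",  # transcribed MIDI files
    ref_dir="maestro/test/",  # gold reference MIDI files
    output_stats_file="scores.jsonl",  # one JSON dict of mir_eval scores per line
    est_shift=0,  # shift applied to the estimated pitches
)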
4 changes: 4 additions & 0 deletions amt/infer.py
@@ -23,6 +23,7 @@

# TODO: Profile and fix gpu util


def calculate_vel(
logits: torch.Tensor,
init_vel: int,
@@ -89,6 +90,8 @@ def calculate_onset(

from functools import wraps
from torch.cuda import is_bf16_supported


def optional_bf16_autocast(func):
@wraps(func)
def wrapper(*args, **kwargs):
Expand All @@ -100,6 +103,7 @@ def wrapper(*args, **kwargs):
# Call the function with float16 if bfloat16 is not supported
with torch.autocast("cuda", dtype=torch.float16):
return func(*args, **kwargs)

return wrapper


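The new optional_bf16_autocast decorator runs the wrapped call under bfloat16 autocast when the GPU supports it and falls back to float16 otherwise. A hedged usage sketch (process_batch is a hypothetical function, not part of this commit):

@optional_bf16_autocast
def process_batch(model, batch):
    # Body executes under torch.autocast("cuda", ...), in bf16 where
    # supported and fp16 otherwise.
    return model(batch)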
37 changes: 25 additions & 12 deletions amt/run.py
@@ -25,19 +25,25 @@ def _add_maestro_args(subparser):

def _add_transcribe_args(subparser):
subparser.add_argument("model_name", help="name of model config file")
-    subparser.add_argument('checkpoint_path', help="checkpoint path")
-    subparser.add_argument("-load_path", help="path to mp3/wav file", required=False)
+    subparser.add_argument("checkpoint_path", help="checkpoint path")
+    subparser.add_argument(
+        "-load_path", help="path to mp3/wav file", required=False
+    )
subparser.add_argument(
"-load_dir", help="dir containing mp3/wav files", required=False
)
subparser.add_argument("-save_dir", help="dir to save midi files", required=True)
subparser.add_argument(
"-save_dir", help="dir to save midi files", required=True
)
subparser.add_argument(
"-multi_gpu", help="use all GPUs", action="store_true", default=False
)
subparser.add_argument("-bs", help="batch size", type=int, default=16)


-def build_maestro(maestro_dir, maestro_csv_file, train_file, val_file, test_file, num_procs):
+def build_maestro(
+    maestro_dir, maestro_csv_file, train_file, val_file, test_file, num_procs
+):
from amt.data import AmtDataset

assert os.path.isdir(maestro_dir), "MAESTRO directory not found"
@@ -101,9 +107,14 @@ def build_maestro(maestro_dir, maestro_csv_file, train_file, val_file, test_file


def transcribe(
-    model_name, checkpoint_path, save_dir, load_path=None, load_dir=None,
-    batch_size=16, multi_gpu=False,
-    augment=None,
+    model_name,
+    checkpoint_path,
+    save_dir,
+    load_path=None,
+    load_dir=None,
+    batch_size=16,
+    multi_gpu=False,
+    augment=None,
):
"""
Transcribe audio files to midi using the given model and checkpoint.
@@ -139,9 +150,7 @@ def transcribe(
assert os.path.isfile(checkpoint_path), "model checkpoint file not found"
assert load_path or load_dir, "must give either load path or dir"
if load_path:
-        assert os.path.isfile(
-            load_path
-        ), f"audio file not found: {load_path}"
+        assert os.path.isfile(load_path), f"audio file not found: {load_path}"
trans_mode = "single"
if load_dir:
assert os.path.isdir(load_dir), "load directory doesn't exist"
@@ -232,8 +241,12 @@ def main():
parser = argparse.ArgumentParser(usage="amt <command> [<args>]")
subparsers = parser.add_subparsers(help="sub-command help")
# add maestro and transcribe subparsers
subparser_maestro = subparsers.add_parser("maestro", help="Commands to build the maestro dataset.")
subparser_transcribe = subparsers.add_parser("transcribe", help="Commands to run transcription.")
subparser_maestro = subparsers.add_parser(
"maestro", help="Commands to build the maestro dataset."
)
subparser_transcribe = subparsers.add_parser(
"transcribe", help="Commands to run transcription."
)
_add_maestro_args(subparser_maestro)
_add_transcribe_args(subparser_transcribe)

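With the reformatted argument parsers above, a transcription run might look like this (assuming the package installs an amt entry point; the model name, checkpoint, and paths are hypothetical):

amt transcribe my-model my-checkpoint.pt -load_dir recordings/ -save_dir out_midi/ -bs 8 -multi_gpu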