Added evals and changed Audio a little bit #23

Closed
wants to merge 4 commits
4 changes: 3 additions & 1 deletion .gitignore
@@ -18,7 +18,9 @@

.idea/
notebooks/scratch
baselines/hft_transformer/model_files/
experiments/baselines/hft_transformer/model_files/
experiments/baselines/google_t5/model_files/
experiments/aria-amt-intermediate-transcribed-data

# Byte-compiled / optimized / DLL files
__pycache__/
104 changes: 82 additions & 22 deletions amt/audio.py
@@ -1,5 +1,5 @@
"""Contains code taken from https://github.com/openai/whisper"""

import functools
import os
import random
import torch
@@ -197,9 +197,11 @@ def __init__(
bandpass_ratio: float = 0.15,
distort_ratio: float = 0.15,
reduce_ratio: float = 0.01,
max_num_transforms: int | None = None,  # currently there are 8 different transformations
detune_ratio: float = 0.0,
detune_max_shift: float = 0.0,
spec_aug_ratio: float = 0.9,
):
super().__init__()
self.tokenizer = AmtTokenizer()
@@ -223,7 +225,14 @@ def __init__(
self.detune_ratio = detune_ratio
self.detune_max_shift = detune_max_shift
self.spec_aug_ratio = spec_aug_ratio

# `self.t_count` and `self.max_num_transforms` are state variables that
# track the number of transformations applied so far.
# `self.t_count` is reset to 0 in the `forward` method, and a `t_count`
# value can also be passed into `distortion_aug_cpu`, `log_mel`, and
# `aug_wav` -- the methods that apply transformations stochastically.
# A little messy/stateful, but it keeps the code backwards compatible.
self.t_count = None
self.max_num_transforms = max_num_transforms
self.time_mask_param = 2500
self.freq_mask_param = 15
self.reduction_resample_rate = 6000
@@ -273,6 +282,34 @@ def __init__(
),
)

# inverse mel transform
self.inverse_mel = torchaudio.transforms.InverseMelScale(
n_mels=self.config["n_mels"],
sample_rate=self.config["sample_rate"],
n_stft=self.config["n_fft"] // 2 + 1,
)
self.inverse_spec_transform = torchaudio.transforms.GriffinLim(
n_fft=self.config["n_fft"],
hop_length=self.config["hop_len"],
)

def check_apply_transform(self, ratio: float):
"""
Check if a transformation should be applied based on the ratio and the
number of transformations already applied.
"""

if (
(self.max_num_transforms is not None) and
(self.t_count is not None) and
(self.t_count >= self.max_num_transforms)
):
return False
apply_transform = random.random() < ratio
if apply_transform:
self.t_count += 1
return apply_transform

def get_params(self):
return {
"noise_ratio": self.noise_ratio,
@@ -408,13 +445,16 @@ def apply_distortion(self, wav: torch.tensor):

return AF.overdrive(wav, gain=gain, colour=colour)

def distortion_aug_cpu(self, wav: torch.Tensor):
def distortion_aug_cpu(self, wav: torch.Tensor, t_count: int | None = None):
# This function should run on the cpu (i.e. in the dataloader collate
# function) in order to not be a bottleneck
if t_count is not None:
self.t_count = t_count

if random.random() < self.reduce_ratio:
if self.check_apply_transform(self.reduce_ratio):
wav = self.apply_reduction(wav)
if random.random() < self.distort_ratio:

if self.check_apply_transform(self.distort_ratio):
wav = self.apply_distortion(wav)

return wav
@@ -445,34 +485,34 @@ def shift_spec(self, specs: torch.Tensor, shift: int | float):
return shifted_specs

def detune_spec(self, specs: torch.Tensor):
if random.random() < self.detune_ratio:
detune_shift = random.uniform(
-self.detune_max_shift, self.detune_max_shift
)
detuned_specs = self.shift_spec(specs, shift=detune_shift)
detune_shift = random.uniform(
-self.detune_max_shift, self.detune_max_shift
)
detuned_specs = self.shift_spec(specs, shift=detune_shift)

return (specs + detuned_specs) / 2
else:
return specs
specs = (specs + detuned_specs) / 2
return specs

def aug_wav(self, wav: torch.Tensor):
def aug_wav(self, wav: torch.Tensor, t_count: int | None = None):
# This function doesn't apply distortion. If distortion is desired it
# should be run beforehand on the cpu with distortion_aug_cpu. Note
# also that detuning is done to the spectrogram in log_mel, not the wav.
if t_count is not None:
self.t_count = t_count

# Noise
if random.random() < self.noise_ratio:
if self.check_apply_transform(self.noise_ratio):
wav = self.apply_noise(wav)

if random.random() < self.applause_ratio:
if self.check_apply_transform(self.applause_ratio):
wav = self.apply_applause(wav)

# Reverb
if random.random() < self.reverb_ratio:
if self.check_apply_transform(self.reverb_ratio):
wav = self.apply_reverb(wav)

# EQ
if random.random() < self.bandpass_ratio:
if self.check_apply_transform(self.bandpass_ratio):
wav = self.apply_bandpass(wav)

return wav
@@ -487,15 +527,25 @@ def norm_mel(self, mel_spec: torch.Tensor):
return log_spec

def log_mel(
self, wav: torch.Tensor, shift: int | None = None, detune: bool = False
self,
wav: torch.Tensor,
shift: int | None = None,
detune: bool = False,
t_count: int | None = None,
):
if t_count is not None:
self.t_count = t_count

spec = self.spec_transform(wav)[..., :-1]

# check: are detune and shift mutually exclusive?
# should we also put a ratio on shift?
if shift is not None and shift != 0:
spec = self.shift_spec(spec, shift)
elif detune is True:
# Don't detune and spec shift at the same time
spec = self.detune_spec(spec)
if self.check_apply_transform(self.detune_ratio):
# Don't detune and spec shift at the same time
spec = self.detune_spec(spec)

mel_spec = self.mel_transform(spec)

@@ -504,15 +554,25 @@ def log_mel(

return log_spec

def inverse_log_mel(self, mel: torch.Tensor):
"""
Takes as input a log mel spectrogram and returns the corresponding audio.
"""
mel = (4 * mel) - 4
mel = torch.pow(10, mel)
mel = self.inverse_mel(mel)
return self.inverse_spec_transform(mel)

def forward(self, wav: torch.Tensor, shift: int = 0):
# Noise, and reverb
self.t_count = 0
wav = self.aug_wav(wav)

# Spec, detuning & pitch shift
log_mel = self.log_mel(wav, shift, detune=True)

# Spec aug
if random.random() < self.spec_aug_ratio:
if self.check_apply_transform(self.spec_aug_ratio):
log_mel = self.spec_aug(log_mel)

return log_mel
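
A minimal usage sketch of the new transform cap, based only on the signatures in this diff. The `AudioTransform` class name, the 16 kHz sample rate, and the tensor shapes are assumptions, not shown above:

import torch

from amt.audio import AudioTransform  # class name assumed

# Cap the whole stochastic pipeline at 3 transforms per example.
aug = AudioTransform(max_num_transforms=3)
wav = torch.randn(4, 16000 * 5)  # assumed: a batch of 5 s clips at 16 kHz

# Option 1: forward() resets t_count to 0 and applies the capped pipeline
# (wav augmentation, detuning/shift, spec aug) in one call.
log_mel = aug.forward(wav, shift=0)

# Option 2 (manual): thread t_count through the CPU- and GPU-side methods so
# distortion, wav augmentation, and detuning share a single budget.
wav = aug.distortion_aug_cpu(wav, t_count=0)  # starts the count at 0
wav = aug.aug_wav(wav, t_count=aug.t_count)  # continues the running count
log_mel = aug.log_mel(wav, detune=True, t_count=aug.t_count)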
67 changes: 0 additions & 67 deletions baselines/giantmidi/transcribe_new_files.py

This file was deleted.

3 changes: 0 additions & 3 deletions baselines/requirements-baselines.txt

This file was deleted.

50 changes: 50 additions & 0 deletions experiments/baselines/giantmidi/transcribe_new_files.py
@@ -0,0 +1,50 @@
import os
import argparse
import time
import torch
import piano_transcription_inference
import glob
from more_itertools import unique_everseen
from tqdm.auto import tqdm
from random import shuffle
import sys
here = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(here, '../..'))
import loader_util


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Transcribe piano audio files with the piano_transcription_inference (GiantMIDI) model.')
parser = loader_util.add_io_arguments(parser)
args = parser.parse_args()

files_to_transcribe = loader_util.get_files_to_transcribe(args)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
transcriptor = piano_transcription_inference.PianoTranscription(device=device)

# Transcriptor
for n, (input_fname, output_fname) in tqdm(enumerate(files_to_transcribe), total=len(files_to_transcribe)):
if os.path.exists(output_fname):
continue

now_start = time.time()
(audio, _) = (piano_transcription_inference
.load_audio(input_fname, sr=piano_transcription_inference.sample_rate, mono=True))
print(f'READING ELAPSED TIME: {time.time() - now_start}')
now_read = time.time()
try:
# Transcribe
transcribed_dict = transcriptor.transcribe(audio, output_fname)
except Exception as e:
print(f'Transcription failed for {input_fname}: {e}')
print(f'TRANSCRIPTION ELAPSED TIME: {time.time() - now_read}')
print(f'TOTAL ELAPSED TIME: {time.time() - now_start}')



"""
python transcribe_new_files.py \
--input_dir_to_transcribe /mnt/data10/spangher/aira-dl/hFT-Transformer/evaluation/glenn-gould-bach-data \
--output_dir /mnt/data10/spangher/aira-dl/hFT-Transformer/evaluation/glenn-gould-bach-data/kong-model
"""