synthetic data and baselines #19

Merged: 5 commits, Mar 12, 2024
4 changes: 3 additions & 1 deletion .gitignore
@@ -15,9 +15,11 @@
*.xml
*.html
*.htm
*.sf2

.idea/

notebooks/scratch
baselines/hft_transformer/model_files/

# Byte-compiled / optimized / DLL files
__pycache__/
17 changes: 17 additions & 0 deletions amt/audio.py
@@ -197,6 +197,7 @@ def __init__(
bandpass_ratio: float = 0.1,
distort_ratio: float = 0.15,
reduce_ratio: float = 0.01,
codecs_ratio: float = 0.01,
spec_aug_ratio: float = 0.5,
):
super().__init__()
@@ -219,6 +220,7 @@ def __init__(
self.distort_ratio = distort_ratio
self.reduce_ratio = reduce_ratio
self.spec_aug_ratio = spec_aug_ratio
self.codecs_ratio = codecs_ratio
self.reduction_resample_rate = 6000 # Hardcoded?

# Audio aug
@@ -397,6 +399,20 @@ def distortion_aug_cpu(self, wav: torch.Tensor):

return wav

def apply_codec(self, wav: torch.Tensor):
    """
    Randomly apply lossy audio codecs (mu-law WAV, G.722, Ogg Vorbis) to the
    audio to simulate compression artefacts.
    """
    format_encoder_pairs = [
        ("wav", "pcm_mulaw"),
        ("g722", None),
        ("ogg", "vorbis"),
    ]
    for fmt, encoder in format_encoder_pairs:
        if random.random() < self.codecs_ratio:
            effector = torchaudio.io.AudioEffector(format=fmt, encoder=encoder)
            wav = effector.apply(wav, self.sample_rate)

    return wav

def shift_spec(self, specs: torch.Tensor, shift: int):
if shift == 0:
return specs
@@ -429,6 +445,7 @@ def aug_wav(self, wav: torch.Tensor):
# Noise
if random.random() < self.noise_ratio:
wav = self.apply_noise(wav)

if random.random() < self.applause_ratio:
wav = self.apply_applause(wav)

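For reference, a minimal standalone sketch of the codec round-trip that apply_codec relies on, using torchaudio's AudioEffector directly (torchaudio >= 2.1). The sample rate and tensor shape below are illustrative assumptions, not values taken from this PR:

import torch
from torchaudio.io import AudioEffector

sample_rate = 16000  # assumed; AudioEffector expects a (time, channel) waveform
wav = torch.randn(sample_rate * 5, 1)

effector = AudioEffector(format="ogg", encoder="vorbis")
degraded = effector.apply(wav, sample_rate)  # re-encoded and decoded waveform
print(degraded.shape)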
75 changes: 75 additions & 0 deletions amt/data.py
@@ -11,6 +11,81 @@
from amt.tokenizer import AmtTokenizer
from amt.config import load_config
from amt.audio import pad_or_trim
from midi2audio import FluidSynth
import random


class SyntheticMidiHandler:
    def __init__(
        self,
        soundfont_path: str,
        soundfont_prob_dict: dict = None,
        num_wavs_per_midi: int = 1,
    ):
        """
        Loads MIDI files and converts them to audio using FluidSynth soundfonts.

        Parameters
        ----------
        soundfont_path : str
            Path to the directory containing soundfont files.
        soundfont_prob_dict : dict, optional
            Dictionary mapping soundfont file names to the probability of
            using that soundfont. If None is given, a uniform distribution
            is used.
        num_wavs_per_midi : int, optional
            Number of audio files to generate per MIDI file.
        """
        self.soundfont_path = soundfont_path
        self.soundfont_prob_dict = soundfont_prob_dict
        self.num_wavs_per_midi = num_wavs_per_midi

        self.fs_objs = self._load_soundfonts()
        self.soundfont_cumul_prob_dict = self._get_cumulative_prob_dict()

    def _load_soundfonts(self):
        """Loads the soundfonts into FluidSynth objects."""
        fs_files = os.listdir(self.soundfont_path)
        fs_objs = {}
        for fs_file in fs_files:
            fs_objs[fs_file] = FluidSynth(
                os.path.join(self.soundfont_path, fs_file)
            )
        return fs_objs

    def _get_cumulative_prob_dict(self):
        """Returns a dictionary mapping each soundfont to its cumulative
        probability interval. Used for sampling the soundfonts.
        """
        if self.soundfont_prob_dict is None:
            self.soundfont_prob_dict = {
                k: 1 / len(self.fs_objs) for k in self.fs_objs.keys()
            }
        total_prob = sum(self.soundfont_prob_dict.values())
        self.soundfont_prob_dict = {
            k: v / total_prob for k, v in self.soundfont_prob_dict.items()
        }
        cumul_prob_dict = {}
        cumul_prob = 0
        for k, v in self.soundfont_prob_dict.items():
            cumul_prob_dict[k] = (cumul_prob, cumul_prob + v)
            cumul_prob += v
        return cumul_prob_dict

    def _sample_soundfont(self):
        """Samples a soundfont file."""
        rand_num = random.random()
        for k, (v_s, v_e) in self.soundfont_cumul_prob_dict.items():
            if (rand_num >= v_s) and (rand_num < v_e):
                return self.fs_objs[k]

    def get_wav(self, midi_path: str, save_path: str):
        """
        Converts a MIDI file to audio.

        Parameters
        ----------
        midi_path : str
            Path to the MIDI file.
        save_path : str
            Path to save the audio file.
        """
        for i in range(self.num_wavs_per_midi):
            soundfont = self._sample_soundfont()
            if self.num_wavs_per_midi > 1:
                out_path = save_path[:-4] + f"_{i}.wav"
            else:
                out_path = save_path
            soundfont.midi_to_audio(midi_path, out_path)


def get_wav_mid_segments(
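A minimal usage sketch of the new SyntheticMidiHandler; the soundfont directory and file paths below are placeholders, and FluidSynth plus the midi2audio package must be installed:

from amt.data import SyntheticMidiHandler

handler = SyntheticMidiHandler(
    soundfont_path="soundfonts/",   # placeholder directory of .sf2 files
    num_wavs_per_midi=2,            # render each MIDI with two sampled soundfonts
)
handler.get_wav("example.mid", "example.wav")  # writes example_0.wav, example_1.wav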
67 changes: 67 additions & 0 deletions baselines/giantmidi/transcribe_new_files.py
@@ -0,0 +1,67 @@
import os
import argparse
import time
import torch
import piano_transcription_inference
import glob


def transcribe_piano(mp3s_dir, midis_dir, begin_index=None, end_index=None):
    """Transcribe piano solo mp3s to MIDI files."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    os.makedirs(midis_dir, exist_ok=True)

    # Transcriptor
    transcriptor = piano_transcription_inference.PianoTranscription(device=device)

    transcribe_time = time.time()
    for n, mp3_path in enumerate(glob.glob(os.path.join(mp3s_dir, '*.mp3'))[begin_index:end_index]):
        print(n, mp3_path)
        midi_file = os.path.basename(mp3_path).replace('.mp3', '.midi')
        midi_path = os.path.join(midis_dir, midi_file)
        if os.path.exists(midi_path):
            continue

        (audio, _) = piano_transcription_inference.load_audio(
            mp3_path, sr=piano_transcription_inference.sample_rate, mono=True
        )

        try:
            # Transcribe
            transcribed_dict = transcriptor.transcribe(audio, midi_path)
            print(transcribed_dict)
        except Exception as e:
            print(f'Failed for this audio: {e}')

    print('Time: {:.3f} s'.format(time.time() - transcribe_time))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Transcribe piano solo mp3s to MIDI files.'
    )
    parser.add_argument('--mp3s_dir', type=str, required=True, help='Directory containing the input mp3 files.')
    parser.add_argument('--midis_dir', type=str, required=True, help='Directory to write the transcribed MIDI files.')
    parser.add_argument(
        '--begin_index', type=int, required=False, default=None,
        help='Index, in the ordered list of files, to start transcribing from.'
    )
    parser.add_argument(
        '--end_index', type=int, required=False, default=None,
        help='Index, in the ordered list of files, at which to stop transcribing.'
    )

    # Parse arguments
    args = parser.parse_args()
    transcribe_piano(
        mp3s_dir=args.mp3s_dir,
        midis_dir=args.midis_dir,
        begin_index=args.begin_index,
        end_index=args.end_index,
    )

"""
python transcribe_new_files.py \
transcribe_piano \
--mp3s_dir /mnt/data10/spangher/aira-dl/hFT-Transformer/evaluation/glenn-gould-bach-data \
--midis_dir /mnt/data10/spangher/aira-dl/hFT-Transformer/evaluation/glenn-gould-bach-data/kong-model
"""